In [1]:
import numpy as np

In [2]:
import keras.layers
import tensorflow as tf
import keras.backend as K
import keras.optimizers

In [3]:
def loss_function_NN(data:tf.Tensor, header:list, classification:tf.Tensor) -> tf.Tensor:
    '''
    The function to calculate and tell the model how badly it performs prediction.

    Args:
        data: a 2D tensor (tensorflow will automatically convert a numpy array to a tensor)
        recording the data of a batch. 1st index represents a data point, 2nd index represents
        the values of each column at that grid point. It is cast to the dtype of
        ``classification`` so that a float64 numpy array can be mixed with float32 predictions.

        header: a list storing the names of the variables, telling how temperature_gradient,
        velocity_magnitude_gradient and z_velocity_gradient map to the 2nd index of the data.

        classification: a tensor storing classification results between 0-1 for each grid point
        of this batch. Both shape (batch,) and shape (batch, 1) are accepted; it is flattened
        to 1D before use.

    Returns:
        loss: a SCALAR (single-value) tensor representing how badly a model predicts in a batch.
        A point with high gradient and low classification value, or low gradient and high
        classification value, contributes a higher loss. The loss is also higher when the
        classification is close to 0.5, to encourage decisive classification results.
        Points whose gradient sum or classification is NaN are ignored; if no valid point
        remains the loss is 0.

    Raises:
        ValueError: if one of the three gradient names is missing from ``header``.
    '''
    # Extract the indices of the gradients from the header
    temp_grad_idx = header.index('temperature_gradient')
    vel_mag_grad_idx = header.index('velocity_magnitude_gradient')
    z_vel_grad_idx = header.index('z_velocity_gradient')

    # Flatten to 1D: a (batch, 1) model output multiplied with a (batch,) gradient sum
    # would otherwise broadcast to a (batch, batch) matrix and pair unrelated points.
    classification = tf.reshape(tf.convert_to_tensor(classification), [-1])
    data = tf.cast(tf.convert_to_tensor(data), classification.dtype)

    # Extract the gradient values from the data tensor and combine them
    # NOTE(review): the terms below treat gradient_sum like a value in [0, 1];
    # confirm the gradients are scaled accordingly upstream.
    gradient_sum = (
        data[:, temp_grad_idx]
        + data[:, vel_mag_grad_idx]
        + data[:, z_vel_grad_idx]
    )

    # Points with NaN inputs or NaN predictions must not contribute to the loss.
    valid = tf.logical_not(
        tf.logical_or(tf.math.is_nan(gradient_sum), tf.math.is_nan(classification))
    )
    weights = tf.cast(valid, classification.dtype)

    # Replace NaN entries BEFORE any arithmetic: multiplying a NaN by a zero weight
    # afterwards would still yield NaN in both the loss and its gradient.
    gradient_sum = tf.where(valid, gradient_sum, tf.zeros_like(gradient_sum))
    classification = tf.where(valid, classification, tf.zeros_like(classification))

    # Primary loss: high classification with low gradient, or low classification with high gradient
    loss_high_class_low_grad = classification * (1 - gradient_sum)
    loss_low_class_high_grad = (1 - classification) * gradient_sum

    # Regularization: c * (1 - c) peaks at c = 0.5 and vanishes at 0 and 1, so undecided
    # predictions are penalized (a squared distance to 0.5 would reward them instead).
    indecision = classification * (1 - classification)

    per_point = (loss_high_class_low_grad + loss_low_class_high_grad + indecision) * weights

    # Mean over the valid points only; divide_no_nan returns 0 when there are none.
    loss = tf.math.divide_no_nan(tf.reduce_sum(per_point), tf.reduce_sum(weights))

    return loss

In [4]:
class CustomMasking(keras.layers.Layer):
    '''
    Mark masked points by turning every entry equal to ``mask_value`` into NaN.

    With the default ``mask_value=np.nan`` the entries that are already NaN are the
    masked ones, so the tensor passes through unchanged. With any other sentinel
    value, the entries equal to that sentinel are replaced by NaN.

    Args:
        mask_value: numeric sentinel that identifies a masked entry (default NaN).
        **kwargs: forwarded to ``keras.layers.Layer``.
    '''
    def __init__(self, mask_value=np.nan, **kwargs):
        super(CustomMasking, self).__init__(**kwargs)
        self.mask_value = mask_value

    def build(self, input_shape):
        super(CustomMasking, self).build(input_shape)

    def call(self, inputs):
        '''Return ``inputs`` with the masked entries set to NaN (same shape and dtype).'''
        if np.isnan(self.mask_value):
            # NaN never compares equal to anything, itself included, so an equality
            # test against a NaN sentinel can never match; use is_nan instead.
            masked = tf.math.is_nan(inputs)
        else:
            masked = tf.equal(inputs, tf.cast(self.mask_value, inputs.dtype))
        return tf.where(masked, tf.constant(np.nan, dtype=inputs.dtype), inputs)

    def compute_output_shape(self, input_shape):
        # Element-wise operation: the shape is unchanged.
        return input_shape

In [5]:
class NaNHandlingLayer(keras.layers.Layer):
    '''
    Keep NaN entries as NaN and pass every other entry through untouched,
    so points with NaN values keep a NaN marker.
    '''
    def call(self, inputs):
        # Select an explicit NaN constant wherever the input is NaN, the input itself elsewhere.
        is_missing = tf.math.is_nan(inputs)
        nan_fill = tf.constant(float('nan'), dtype=inputs.dtype)
        return tf.where(is_missing, nan_fill, inputs)

In [6]:
# Names of the data columns; loss_function_NN looks each name up in this list
# to find the matching column index of the data tensor.
header = [
    'temperature_gradient',
    'velocity_magnitude_gradient',
    'z_velocity_gradient',
]

In [23]:
# Toy batch of 8 grid points x 3 gradient columns; rows 2 and 6 are entirely NaN.
# NOTE: the name shadows the built-in input(); it is kept because the training
# code further down reads this variable.
input = np.asarray(
    [
        [0.9, 0.5, 0.4],
        [0.7, 0.3, 0.5],
        [np.nan, np.nan, np.nan],  # missing point
        [0.1, 0.4, 0.5],
        [0.2, 0.6, 0.1],
        [0.9, 0.8, 0.5],
        [np.nan, np.nan, np.nan],  # missing point
        [0.3, 0.5, 0.1],
    ]
)

In [24]:
class CustomModel(tf.keras.Model):
    '''
    Small dense classifier producing one sigmoid score per grid point, trained without
    labels through loss_function_NN.

    Points that contain a NaN feature get a NaN score and are excluded from the loss.

    Args:
        header: list of column names of the input data; its length sets the width of the
        hidden layers and it is handed to loss_function_NN during training.
    '''
    def __init__(self, header):
        super(CustomModel, self).__init__()
        self.header = header
        self.masking = CustomMasking(mask_value=np.nan)
        self.nan_handling = NaNHandlingLayer()
        self.dense1 = keras.layers.Dense(len(header), activation='relu')
        self.bn1 = keras.layers.BatchNormalization()
        self.dense2 = keras.layers.Dense(len(header), activation='relu')
        self.bn2 = keras.layers.BatchNormalization()
        self.dense3 = keras.layers.Dense(len(header), activation='relu')
        self.output_layer = keras.layers.Dense(1, activation='sigmoid')

    def call(self, inputs, training=None):
        '''
        Forward pass.

        Args:
            inputs: float tensor whose last axis holds the feature columns.
            training: forwarded to the BatchNormalization layers so that they use batch
            statistics while training and moving averages otherwise.

        Returns:
            Tensor with a last axis of size 1 holding the sigmoid score of each point;
            NaN for every point that has at least one NaN feature.
        '''
        x = self.masking(inputs)
        x = self.nan_handling(x)

        nan_entries = tf.math.is_nan(x)
        # keepdims=True so the row mask lines up with the size-1 output axis
        nan_rows = tf.reduce_any(nan_entries, axis=-1, keepdims=True)

        # Feed zeros instead of NaN through the trainable layers: a NaN activation gives
        # NaN weight gradients even when the upstream gradient is zero (0 * NaN = NaN),
        # which would overwrite every weight with NaN after one optimizer step.
        # The zero-filled rows still take part in the batch-norm statistics.
        x = tf.where(nan_entries, tf.zeros_like(x), x)

        x = self.dense1(x)
        x = self.bn1(x, training=training)
        x = self.dense2(x)
        x = self.bn2(x, training=training)
        x = self.dense3(x)
        scores = self.output_layer(x)

        # Restore the NaN marker on the points that had missing features.
        return tf.where(nan_rows, tf.constant(np.nan, dtype=scores.dtype), scores)

    def train_step(self, data):
        '''
        Run one optimization step on a batch; only the features are used (no labels).

        Args:
            data: the batch as delivered by ``fit``, either the feature tensor itself or
            a tuple whose first element is the feature tensor.

        Returns:
            dict with the scalar batch loss under the key "loss".
        '''
        if isinstance(data, tuple):
            x = data[0]
        else:
            x = data

        # Points without any NaN feature; only these may enter the loss.
        valid_rows = tf.logical_not(tf.reduce_any(tf.math.is_nan(x), axis=-1))

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # loss_function_NN documents a 1D classification tensor, so drop the
            # trailing size-1 axis of the model output before handing it over.
            scores = tf.squeeze(y_pred, axis=-1)
            loss = loss_function_NN(
                tf.boolean_mask(x, valid_rows),
                self.header,
                tf.boolean_mask(scores, valid_rows),
            )
            # A batch without any valid point has nothing to learn from: report 0
            # rather than the NaN that a mean over zero elements may produce.
            loss = tf.where(tf.reduce_any(valid_rows), loss, tf.zeros_like(loss))

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {"loss": loss}

# Instantiate the classifier for the configured columns and attach an Adam
# optimizer with its default settings (the loss is computed inside train_step).
model = CustomModel(header=header)
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
)

class LossHistory(tf.keras.callbacks.Callback):
    '''
    Callback recording the loss reported at the end of every training batch.

    Attributes:
        losses: list with one entry per finished batch (``None`` when the logs of a
        batch carry no 'loss' key); emptied at the start of every training run.
    '''
    def __init__(self):
        super(LossHistory, self).__init__()
        # Defined here so the attribute can be read even before training has started.
        self.losses = []

    def on_train_begin(self, logs=None):
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        # logs may be None; a None default also avoids a shared mutable default dict.
        self.losses.append((logs or {}).get('loss'))

# Fit on the toy batch one point at a time for two epochs, recording per-batch losses.
loss_hist = LossHistory()
input_tensor = tf.convert_to_tensor(input, dtype=tf.float32)
history = model.fit(
    x=input_tensor,
    batch_size=1,
    epochs=2,
    callbacks=[loss_hist],
)

Epoch 1/2
Epoch 2/2
