Utilities

In [None]:
# %pip install xarray[complete] netcdf4 h5netcdf
# %pip install matplotlib
# %pip install numpy
# %pip install pandas
# %pip install scipy
# %pip install dask
# %pip install tensorflow --user
# %pip install scikit-learn

In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Add, Conv2DTranspose, Concatenate, concatenate, AveragePooling2D, UpSampling2D
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

Dataset load and train/test splitting

In [None]:
# LOAD DATASETS
# Every array is stored as a .npy file in one directory; keeping that location
# in a single constant means relocating the data only requires one edit.
DATA_DIR = 'dati/mist/datasets'


def _load_array(name):
    """Load `<DATA_DIR>/<name>.npy` and return the array it contains."""
    return np.load(f'{DATA_DIR}/{name}.npy')


# Combined dataset (currently not loaded):
# dataset = _load_array('dataset')
# date = _load_array('date')

# Day (_d) and night (_n) image stacks and their dates. The generator below
# interprets each date as a number of days since the Unix epoch.
dataset_d = _load_array('dataset_d')
dataset_n = _load_array('dataset_n')
date_d = _load_array('date_d')
date_n = _load_array('date_n')

# Per-day-of-year baselines, indexed with `day_of_year - 1` by the generator.
# baseline = _load_array('baseline')
baseline_d = _load_array('baseline_d')
baseline_n = _load_array('baseline_n')

# Land-sea mask and the min/max values used to denormalise model outputs.
italy_mask = _load_array('italy_mask')
data_min = _load_array('data_min')
data_max = _load_array('data_max')

In [None]:
# Split the day and the night dataset into training, validation and testing sets

def _split_indices(n_samples, holdout_size=0.25, test_share=0.4, seed=42):
    """Shuffle `range(n_samples)` into train / validation / test index arrays.

    A fraction `holdout_size` of the samples is first set aside; of that
    holdout, a fraction `test_share` becomes the test set and the rest the
    validation set (roughly 75% / 15% / 10% with the defaults).

    Returns:
        (train_idx, holdout_idx, val_idx, test_idx), where `holdout_idx` is
        the union of the validation and test indices.
    """
    train_idx, holdout_idx = train_test_split(
        np.arange(n_samples), test_size=holdout_size, random_state=seed
    )
    val_idx, test_idx = train_test_split(
        holdout_idx, test_size=test_share, random_state=seed
    )
    return train_idx, holdout_idx, val_idx, test_idx


# Day dataset
train_indices_d, temp_indices_d, val_indices_d, test_indices_d = _split_indices(dataset_d.shape[0])

x_train_d = dataset_d[train_indices_d]
x_val_d = dataset_d[val_indices_d]
x_test_d = dataset_d[test_indices_d]

# Dates are indexed with the same arrays so they stay aligned with the images
dates_train_d = date_d[train_indices_d]
dates_val_d = date_d[val_indices_d]
dates_test_d = date_d[test_indices_d]


# Night dataset
train_indices_n, temp_indices_n, val_indices_n, test_indices_n = _split_indices(dataset_n.shape[0])

x_train_n = dataset_n[train_indices_n]
x_val_n = dataset_n[val_indices_n]
x_test_n = dataset_n[test_indices_n]

dates_train_n = date_n[train_indices_n]
dates_val_n = date_n[val_indices_n]
dates_test_n = date_n[test_indices_n]

Loss, Metrics and Hyperparameters

In [None]:
def customLoss(y_true, y_pred):
    """Squared error on clear-sea pixels, averaged over the whole tensor.

    Args:
        y_true: tensor whose last axis holds, at index 0, the true SST values
            (obfuscated areas already set to 0) and, at index 1, the clear
            mask (1 for clear sea, 0 for land/clouds).
        y_pred: predicted SST values, broadcastable against `y_true[..., 0:1]`.

    Returns:
        Scalar tensor. The squared error is zeroed outside the clear mask and
        then averaged over *every* element, masked ones included, so the value
        scales with the fraction of clear pixels.
    """
    targets = tf.cast(y_true, tf.float32)

    # Slicing with 0:1 / 1:2 keeps the channel axis, so shapes match y_pred
    sst_true = targets[:, :, :, 0:1]
    sea_mask = targets[:, :, :, 1:2]

    # Errors on land/cloud pixels are multiplied by 0 and do not contribute
    masked_sq_err = tf.square(sst_true - y_pred) * sea_mask

    return tf.reduce_mean(masked_sq_err)

In [None]:
def ClearMetric(y_true, y_pred):
    """Mean squared error computed over clear-sea pixels only.

    Args:
        y_true: tensor whose last axis holds, at index 0, the true SST values
            (obfuscated areas already set to 0) and, at index 1, the clear
            mask (1 for clear sea, 0 for land/clouds).
        y_pred: predicted SST values, broadcastable against `y_true[..., 0:1]`.

    Returns:
        Scalar tensor: the sum of squared errors on clear pixels divided by
        the number of clear pixels, or 0 when the batch has no clear pixel.
    """
    y_true = tf.cast(y_true, tf.float32)

    clear_mask = y_true[:, :, :, 1:2]   # 0 for land/clouds, 1 for clear sea
    sst_true = y_true[:, :, :, 0:1]     # true SST values

    # Squared error, zeroed everywhere except over clear sea
    clear_masked_error = tf.square(sst_true - y_pred) * clear_mask

    # A batch without any clear pixel would give 0/0 = NaN, which then
    # contaminates the running average Keras keeps for the metric.
    # divide_no_nan returns 0 in that case instead.
    return tf.math.divide_no_nan(
        tf.reduce_sum(clear_masked_error), tf.reduce_sum(clear_mask)
    )

In [None]:
def ArtificialMetric(y_true, y_pred):
    """Mean squared error computed over artificially clouded pixels only.

    Args:
        y_true: tensor whose last axis holds, at index 0, the true SST values
            (obfuscated areas already set to 0) and, at index 2, the
            artificial mask (1 for artificial clouds, 0 for the rest).
        y_pred: predicted SST values, broadcastable against `y_true[..., 0:1]`.

    Returns:
        Scalar tensor: the sum of squared errors on artificially clouded
        pixels divided by their count, or 0 when the batch has none.
    """
    # y_true reaches the metrics as float64 while the model works in float32;
    # without the cast the subtraction below fails on mismatched types.
    y_true = tf.cast(y_true, tf.float32)

    artificial_mask = y_true[:, :, :, 2:3]  # 1 for artificial clouds, 0 for the rest
    sst_true = y_true[:, :, :, 0:1]         # true SST values

    # Squared error, zeroed everywhere except under the artificial clouds
    artificial_masked_error = tf.square(sst_true - y_pred) * artificial_mask

    # The artificial mask can be empty (the sampled cloud masks may add nothing
    # to the real one); 0/0 would be NaN and contaminate the running average
    # of the metric, so divide_no_nan returns 0 instead.
    return tf.math.divide_no_nan(
        tf.reduce_sum(artificial_masked_error), tf.reduce_sum(artificial_mask)
    )

In [None]:
# Hyperparameters

# Optimisation schedule
epochs = 100
batch_size = 32
lr = 1e-4

# Training objective and the metrics reported alongside it
loss = customLoss
metrics = [ClearMetric, ArtificialMetric]

# Stop training once val_loss has gone 10 epochs without improving
early_stop = EarlyStopping(monitor='val_loss', patience=10, verbose=1)

# Batches drawn from the generators per epoch / per validation pass.
# The training value is capped at 100 and derived from the day training set.
steps_per_epoch = min(len(x_train_d) // batch_size, 100)
validation_steps = 20
testing_steps = 20  # not referenced by the cells visible here

# One input sample: 256 x 256 pixels with the 4 channels built by the generator
input_shape = (256, 256, 4)

Generator

In [None]:
# Baseline generator function

def baseline_generator(batch_size, data_d, data_n, dates_d, dates_n, dayChance=0.5):
    """Endlessly yield `(batch_x, batch_y)` batches with artificial clouds.

    Each batch is drawn entirely from the day data (with probability
    `dayChance`) or entirely from the night data. For every sample a random
    image is picked and extra cloud cover, borrowed from three other random
    images, is applied to it.

    Besides its arguments, the generator reads the module-level arrays
    `baseline_d`, `baseline_n` and `italy_mask`.

    Args:
        batch_size: number of samples per batch.
        data_d, data_n: day / night image stacks indexed as `data[i]`, with
            NaN marking land/cloud pixels.
        dates_d, dates_n: dates aligned with the stacks, expressed as days
            since the Unix epoch.
        dayChance: probability that a batch uses the day data.

    Yields:
        batch_x: `(batch_size, H, W, 4)` with channels
            0 artificially clouded image, 1 real clear mask,
            2 land-sea mask, 3 baseline for the sample's day of year.
        batch_y: `(batch_size, H, W, 3)` with channels
            0 real image (NaN replaced by 0), 1 real clear mask
            (1 = clear sea), 2 artificial mask (1 = artificially hidden).
    """
    while True:
        # Randomly choose between day and night dataset
        if np.random.rand() < dayChance:
            dataset, date, baseline = data_d, dates_d, baseline_d
        else:
            dataset, date, baseline = data_n, dates_n, baseline_n

        # Take the image size from the data instead of assuming 256 x 256
        height, width = dataset.shape[1:3]
        batch_x = np.zeros((batch_size, height, width, 4))
        batch_y = np.zeros((batch_size, height, width, 3))

        for b in range(batch_size):
            # Choose a random index as the current day, and 3 random indices
            i, r1, r2, r3 = np.random.randint(0, dataset.shape[0], 4)

            # Image and missing-data mask of the current day
            image_current = np.nan_to_num(dataset[i], nan=0)
            mask_current = np.isnan(dataset[i])

            # OR the current mask with the mask of each of the other days
            candidate_masks = [
                np.logical_or(mask_current, np.isnan(dataset[r]))
                for r in (r1, r2, r3)
            ]
            # Keep the candidate with the medium amount of coverage
            candidate_masks.sort(key=np.sum)
            artificial_mask = candidate_masks[1]

            # Apply the amplified mask to the current day's image
            image_masked = np.where(artificial_mask, 0, image_current)

            # Convert the current date to its day of the year (1-based)
            day_of_year = pd.to_datetime(date[i], unit='D', origin='unix').dayofyear
            # Clamp so that day 366 of a leap year reuses the last baseline
            # entry if the baseline has fewer than 366 rows; this is a no-op
            # when the baseline covers every day.
            baseline_index = min(day_of_year, baseline.shape[0]) - 1

            # Fix masks before they are used in the loss and metric functions
            artificial_mask = np.logical_xor(artificial_mask, mask_current)  # 1 for artificially obfuscated, 0 for the rest
            mask_current = np.logical_not(mask_current)                      # 1 for clear sea, 0 for land/clouds

            # Create batch_x and batch_y
            batch_x[b, ..., 0] = image_masked               # artificially cloudy image
            batch_x[b, ..., 1] = mask_current               # real mask
            batch_x[b, ..., 2] = italy_mask                 # land-sea mask
            batch_x[b, ..., 3] = baseline[baseline_index]   # baseline values for the current day

            batch_y[b, ..., 0] = image_current              # real image
            batch_y[b, ..., 1] = mask_current               # real mask
            batch_y[b, ..., 2] = artificial_mask            # artificial mask used for the input

        yield batch_x, batch_y

In [None]:
# Create the generators

# One endless generator per split, each fed with that split's day and night
# images and the dates aligned with them.
train_gen = baseline_generator(
    batch_size,
    x_train_d, x_train_n,
    dates_train_d, dates_train_n,
)
val_gen = baseline_generator(
    batch_size,
    x_val_d, x_val_n,
    dates_val_d, dates_val_n,
)
test_gen = baseline_generator(
    batch_size,
    x_test_d, x_test_n,
    dates_test_d, dates_test_n,
)

# Test generator that returns dates
#test_gen_dates = gen_qual_dates(batch_size, x_test_d, x_test_n, q_test_d, q_test_n, dates_test_d, dates_test_n)

In [None]:
# Test the generator

x, y = next(train_gen)
r = np.random.randint(0, batch_size)    # Index of a random sample of the batch

# Show the artificially clouded input next to its ground truth
plt.figure(figsize=(8, 8))
for panel_idx, (panel_img, panel_title) in enumerate(
    [
        (x[r, :, :, 0], "x_0 (model input)"),
        (y[r, :, :, 0], "y_0 (ground truth)"),
    ],
    start=1,
):
    plt.subplot(2, 2, panel_idx)
    plt.imshow(panel_img, cmap='jet')
    plt.title(panel_title)
    plt.colorbar()

plt.show()

# Shapes of the batch and value range of the input image channel
#print(np.isnan(x).any())
print("x.shape:", x.shape)
print("y.shape:", y.shape)
print("min of all x:", x[..., 0].min())
print("max of all x:", x[..., 0].max())
print("min of this x:", x[r, :, :, 0].min())
print("max of this x:", x[r, :, :, 0].max())

Model and Training

In [None]:
# Rest of the model's hyperparameters

#input_shape = (256, 256, 4)

# Patch geometry
image_size = 256                                # Side length of the square input images
patch_size = 8                                  # Side length of the square patches extracted from them
num_patches = (image_size // patch_size) ** 2   # Patches per image

# Transformer dimensions
projection_dim = 128                            # Width of each encoded patch
num_heads = 6                                   # Attention heads per block
transformer_layers = 6                          # Number of stacked transformer blocks
transformer_units = [projection_dim * 2, projection_dim]   # Dense widths: expand to 2x, then back to projection_dim

# Dense-stack widths (256 -> 128)
mlp_head_units = [256, 128]

In [None]:
# Visual Transformer model 

class Patches(layers.Layer):
    """Cut a batch of images into non-overlapping square patches.

    Input:  `(batch, height, width, channels)`.
    Output: `(batch, num_patches, patch_size * patch_size * channels)`, one
    flattened patch per row.

    Args:
        patch_size: side length, in pixels, of each square patch.
        **kwargs: standard Keras layer arguments (name, dtype, ...). Accepting
            them lets the layer be rebuilt from `get_config()`, which
            includes those base-layer entries.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        input_shape = tf.shape(images)
        batch_size = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        channels = input_shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        # Window and stride both follow the configured patch size (previously
        # fixed to 8, which broke the reshape below for any other size).
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images, sizes=window, strides=window, rates=[1, 1, 1, 1], padding='VALID'
        )
        # Collapse the patch grid into a single sequence axis
        patches = tf.reshape(
            patches,
            (
                batch_size,
                num_patches_h * num_patches_w,
                self.patch_size * self.patch_size * channels,
            ),
        )
        return patches

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

class PatchEncoder(layers.Layer):
    """Project flattened patches and add a learned position embedding.

    Args:
        num_patches: number of patches per image (length of the sequence).
        projection_dim: width of the encoded patches.
        **kwargs: standard Keras layer arguments (name, dtype, ...). Accepting
            them lets the layer be rebuilt from `get_config()`.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept so that get_config() can report every constructor argument
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        # Positions 0 .. num_patches-1 with a leading axis of size 1, so the
        # embedding broadcasts over the batch when added to the projection
        positions = tf.expand_dims(
            tf.experimental.numpy.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

    def get_config(self):
        config = super().get_config()
        # Both constructor arguments are required to rebuild the layer
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config

# only used in postprocessing
def ResidualBlock(width):
    """Return a function that applies a two-convolution residual block.

    The returned function takes a 4-D feature map `x`, normalises it, runs two
    3x3 "same" convolutions with `width` filters (swish on the first) and adds
    the result to a shortcut of the input. The shortcut is `x` itself when it
    already has `width` channels, otherwise a 1x1 convolution of `x`.
    """
    def block(x):
        # Shortcut branch first: the 1x1 convolution is created only when the
        # channel count has to change
        shortcut = x if x.shape[3] == width else layers.Conv2D(width, kernel_size=1)(x)

        #x = layers.BatchNormalization(center=False, scale=False)(x)
        normed = layers.LayerNormalization(axis=-1, center=True, scale=True)(x)
        hidden = layers.Conv2D(
            width, kernel_size=3, padding="same", activation=keras.activations.swish
        )(normed)
        refined = layers.Conv2D(width, kernel_size=3, padding="same")(hidden)

        return layers.Add()([refined, shortcut])

    return block



def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of GELU dense layers, each followed by dropout.

    Args:
        x: input tensor.
        hidden_units: iterable with the width of each dense layer, in order.
        dropout_rate: dropout rate applied after every dense layer.

    Returns:
        The tensor produced by the last dropout layer (`x` itself when
        `hidden_units` is empty).
    """
    gelu = keras.activations.gelu
    for width in hidden_units:
        hidden = layers.Dense(width, activation=gelu)(x)
        x = layers.Dropout(dropout_rate)(hidden)
    return x


def create_vit():
    """Build the Vision-Transformer encoder with a convolutional decoder.

    The model reads the module-level hyperparameters `input_shape`,
    `image_size`, `patch_size`, `num_patches`, `projection_dim`, `num_heads`,
    `transformer_units` and `transformer_layers`.

    Pipeline: split the input into patches, encode them, run
    `transformer_layers` pre-norm transformer blocks, fold the patch sequence
    back into a 2-D grid and upsample it three times by a factor of 2 (8x in
    total, undoing the default patch size of 8) down to a single tanh-activated
    channel.

    Returns:
        An uncompiled `keras.Model`.
    """
    inputs = keras.Input(shape=input_shape)
    # Create patches
    patches = Patches(patch_size)(inputs)
    # Encode patches
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP. Its last width must equal projection_dim for the skip
        # connection below, which is what transformer_units guarantees.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Final normalization of the [batch, num_patches, projection_dim] sequence.
    out = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    # Spatial reshape of the patches: one grid cell per patch
    grid_size = image_size // patch_size
    out = layers.Reshape((grid_size, grid_size, projection_dim))(out)
    # Residual block
    out = ResidualBlock(256)(out)
    # Upsample (x2)
    out = layers.Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same', activation='relu')(out)
    out = ResidualBlock(128)(out)
    # Upsample (x2) while reducing channel size
    out = layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu')(out)
    out = ResidualBlock(64)(out)
    # Final upsample (x2) to one channel; tanh bounds the output to [-1, 1]
    out = layers.Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='tanh')(out)

    model = keras.Model(inputs=inputs, outputs=out)
    return model


In [None]:
# Build the Vision Transformer with create_vit().
# Uncomment the last line to print a summary of its layers.

model = create_vit()
#model.summary()

In [None]:
# Compile the model with the Adam optimizer, the custom masked loss and the
# masked metrics chosen in the hyperparameter cell
opt = Adam(learning_rate=lr)
model.compile(
    optimizer=opt,
    loss=loss,
    metrics=metrics,
)

In [None]:
# LOAD WEIGHTS
# Restores previously saved weights into the model built above, so the
# training cell can stay commented out.
# model.load_weights('weights/visualTransformer.weights.h5')   # does not work on local machine
model.load_weights('weights/visualTransformer.h5')

# SAVE WEIGHTS
# Writes to the same file that is loaded above.
#model.save_weights('weights/visualTransformer.h5')  # execute remotely after training

In [None]:
# Train model
#history = model.fit(train_gen, epochs=epochs, steps_per_epoch=steps_per_epoch, validation_data=val_gen, validation_steps=validation_steps, verbose=1, callbacks=[early_stop])

Experiment on Errors

In [None]:
#Loop error calculation over tot batches

def _denormalize(values):
    """Map values from the normalised [-1, 1] range back to [data_min, data_max]."""
    return ((values + 1) / 2) * (data_max - data_min) + data_min


# Initialize lists to store the average errors, maximum errors, and variances
avg_errors_list = []
avg_max_errors_list = []
var_max_errors_list = []

# Generate and evaluate tot batches
tot = 100
for _ in range(tot):
    # Generate a batch (verbose=0 avoids printing one progress bar per batch)
    x_true, y_true = next(test_gen)
    predictions = model.predict(x_true, verbose=0)

    # Denormalize
    predictions_denorm = _denormalize(predictions[..., 0])
    true_values_denorm = _denormalize(y_true[..., 0])

    # Absolute error over clear sea, NaN everywhere else
    clearMask = y_true[..., 1]
    errors = np.where(clearMask, np.abs(predictions_denorm - true_values_denorm), np.nan)

    # An image without a single clear pixel has no measurable error: its
    # nanmean/nanmax would be NaN and would turn every aggregate below into
    # NaN, so such images are left out.
    has_clear = clearMask.any(axis=(1, 2))
    if not has_clear.any():
        continue
    errors = errors[has_clear]

    # Calculate the average and maximum error for each image in the batch
    avg_errors = np.nanmean(errors, axis=(1, 2))
    max_errors = np.nanmax(errors, axis=(1, 2))

    # Add the average error, average maximum error, and variance of maximum errors to the lists
    avg_errors_list.append(np.mean(avg_errors))
    avg_max_errors_list.append(np.mean(max_errors))
    var_max_errors_list.append(np.var(max_errors))

# Print the average, average maximum, and variance of maximums calculated over tot batches
print(f"Average error over {tot} batches:", np.mean(avg_errors_list))
print(f"Average maximum error over {tot} batches:", np.mean(avg_max_errors_list))
print(f"Variance of maximum errors over {tot} batches:", np.mean(var_max_errors_list))

Model Evaluation

In [None]:
# Utility to show loss and metrics

# Evaluate the model on a single batch drawn from the test generator.
# `results` holds what model.evaluate returns for the compiled loss and metrics.
x_true, y_true = next(test_gen)
results = model.evaluate(x_true, y_true)

In [None]:
# Evaluate the model

# Generate predictions. This generates a batch of data.
x_true, y_true = next(test_gen)
print("x_true shape:", x_true.shape)
print("y_true shape:", y_true.shape)
print("is there a Nan in x?", np.isnan(x_true).any())
print("is there a Nan in y?", np.isnan(y_true).any())

predictions = model.predict(x_true)

print("--------------------")

# Value range of the ground truth and of the model output
print("y's min:", np.min(y_true[:, :, :, 0]))
print("y's max:", np.max(y_true[:, :, :, 0]))
print("prediction's min:", np.min(predictions[:, :, :, 0]))
print("prediction's max:", np.max(predictions[:, :, :, 0]))

print("--------------------")

# The loss reported by Keras and the loss function called directly on the same batch
evalx = model.evaluate(x_true, y_true)
print("evalx: ", evalx)
xloss = customLoss(y_true, predictions)
print("xloss: ", xloss)

print("--------------------")

# Get the coordinates of the min and max values in a single prediction.
# argmin/argmax return flat indices; unravel_index converts them to (row, col)
# for whatever image size the model produced.
first_prediction = predictions[0, :, :, 0]
coordxmin = np.argmin(first_prediction)
coordxmax = np.argmax(first_prediction)
row_min, col_min = np.unravel_index(coordxmin, first_prediction.shape)
row_max, col_max = np.unravel_index(coordxmax, first_prediction.shape)
print("first prediction's min:", col_min, row_min, np.nanmin(first_prediction))
print("first prediction's max:", col_max, row_max, np.nanmax(first_prediction))
print("predictions in coordxmin:", predictions[0, row_min, col_min, 0])
print("predictions in coordxmax:", predictions[0, row_max, col_max, 0])

print("--------------------")

# Plot the predictions and true values (at most 10, fewer if the batch is smaller)

for i in range(min(10, len(predictions))):
    plt.figure(figsize=(20, 8))

    # Plot the true value, shown only where the clear mask is set
    plt.subplot(1, 3, 1)
    mask_overlay = np.where(y_true[i, :, :, 1], y_true[i, :, :, 0], np.nan)
    plt.imshow(mask_overlay, cmap='jet', vmin=-1, vmax=1)
    plt.title("y_0 (Ground Truth)")
    plt.colorbar()

    # Plot the prediction, shown only where italy_mask is non-zero
    plt.subplot(1, 3, 2)
    masked_prediction = np.where(italy_mask, predictions[i, :, :, 0], np.nan)
    plt.imshow(masked_prediction, cmap='jet', vmin=-1, vmax=1)
    plt.title("Prediction with Land Mask")
    plt.colorbar()

    # # Plot the predicted 'pure' value
    # plt.subplot(1, 3, 3)
    # plt.imshow(predictions[i], cmap='jet', vmin=-1, vmax=1)
    # plt.title("Unmasked prediction (DEBUG)")
    # plt.colorbar()

    plt.show()

Baseline comparison

In [None]:
# Calculate the MSE for the predictions and the baseline

def _masked_mse(truth, estimate, mask):
    """MSE between `truth` and `estimate` over the pixels selected by `mask`.

    Returns NaN when the mask selects no pixel, since mean_squared_error
    raises on empty input.
    """
    if not mask.any():
        return np.nan
    return mean_squared_error(truth[mask], estimate[mask])


batch_x, batch_y = next(test_gen)
predictions = model.predict(batch_x)

filter_mask = batch_y[..., 1].astype(bool)    # We filter out the land and cloud data
n_samples = len(batch_y)


# Calculate the MSE for the predictions and the baseline, and the average MSEs in that batch. Only consider the ocean data.
mse_predictions = [_masked_mse(batch_y[i, :, :, 0], predictions[i, :, :, 0], filter_mask[i]) for i in range(n_samples)]
mse_baseline = [_masked_mse(batch_y[i, :, :, 0], batch_x[i, :, :, 3], filter_mask[i]) for i in range(n_samples)]
print('MSE for predictions:', mse_predictions)
print('MSE for baseline:', mse_baseline)

# nanmean ignores the samples that had no usable pixel
avg_mse_predictions = np.nanmean(mse_predictions)
avg_mse_baseline = np.nanmean(mse_baseline)
print('Average MSE for predictions:', avg_mse_predictions)
print('Average MSE for baseline:', avg_mse_baseline)

# Plot the MSEs
indices = np.arange(n_samples + 1)   # One extra slot at the end for the average MSEs

fig, ax = plt.subplots()
ax.bar(indices[:-1] - 0.2, mse_predictions, width=0.4, label='Predictions')
ax.bar(indices[:-1] + 0.2, mse_baseline, width=0.4, label='Baseline')
ax.bar(n_samples - 0.2, avg_mse_predictions, width=0.4, color='blue', label='Avg of Predictions')
ax.bar(n_samples + 0.2, avg_mse_baseline, width=0.4, color='red', label='Avg of Baseline')

ax.set_xlabel('Batch Index')
ax.set_ylabel('MSE')
ax.set_title('MSE for Predictions and Baseline')
ax.legend()

plt.show()
#fig.savefig("mseGraph.png")