Imports

In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Add, Conv2DTranspose, Concatenate, concatenate, AveragePooling2D, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

from sklearn.model_selection import train_test_split

In [None]:
# LOAD DATASETS
# Day (_d) and night (_n) image stacks together with their per-image dates
dataset_d = np.load('dati/mist/datasets/dataset_d.npy')
date_d = np.load('dati/mist/datasets/date_d.npy')
dataset_n = np.load('dati/mist/datasets/dataset_n.npy')
date_n = np.load('dati/mist/datasets/date_n.npy')
# (The abs_dataset_* and training_baseline_* files in the same folder are not used here.)

# Land-sea mask and the min/max bounds of the [-1, 1] min-max scaling
italy_mask = np.load('dati/mist/datasets/italy_mask.npy')
data_min = np.load('dati/mist/datasets/data_min.npy')
data_max = np.load('dati/mist/datasets/data_max.npy')

# Baselines (indexed by day of year in the generator), rescaled with the same bounds:
# v -> 2 * (v - min) / (max - min) - 1
abs_new2_baseline_d = 2 * ((np.load('dati/mist/datasets/abs_new2_baseline_d.npy') - data_min) / (data_max - data_min)) - 1
abs_new2_baseline_n = 2 * ((np.load('dati/mist/datasets/abs_new2_baseline_n.npy') - data_min) / (data_max - data_min)) - 1

In [None]:
# Split the day and the night datasets independently into training (~75%), validation (~15%)
# and test (~10%) sets, with the same fixed seeds for both

def _three_way_split(n_samples):
    """Return (train, temp, val, test) index arrays for `n_samples` items.

    25% of the indices are held out as `temp`, which is then split 60/40 into validation/test.
    """
    train_idx, temp_idx = train_test_split(np.arange(n_samples), test_size=0.25, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.4, random_state=42)
    return train_idx, temp_idx, val_idx, test_idx


# Day
train_indices_d, temp_indices_d, val_indices_d, test_indices_d = _three_way_split(dataset_d.shape[0])
x_train_d, x_val_d, x_test_d = dataset_d[train_indices_d], dataset_d[val_indices_d], dataset_d[test_indices_d]
dates_train_d, dates_val_d, dates_test_d = date_d[train_indices_d], date_d[val_indices_d], date_d[test_indices_d]

# Night
train_indices_n, temp_indices_n, val_indices_n, test_indices_n = _three_way_split(dataset_n.shape[0])
x_train_n, x_val_n, x_test_n = dataset_n[train_indices_n], dataset_n[val_indices_n], dataset_n[test_indices_n]
dates_train_n, dates_val_n, dates_test_n = date_n[train_indices_n], date_n[val_indices_n], date_n[test_indices_n]

Model and Generator Functions

In [None]:
def customLoss(y_true, y_pred):
    """Masked squared-error training loss.

    `y_true` packs the target channels: [..., 0] true SST (0 where not visible) and
    [..., 1] real mask (1 = visible sea, 0 = land/clouds). The squared error is zeroed outside
    the visible sea and then averaged over ALL pixels (not only the visible ones), so the loss
    value scales with the visible-sea fraction of the batch.
    """
    packed = tf.cast(y_true, tf.float32)
    visible_sea = packed[:, :, :, 1:2]
    sst_true = packed[:, :, :, 0:1]
    return tf.reduce_mean(visible_sea * tf.square(sst_true - y_pred))

In [None]:
def ClearMetric(y_true, y_pred):
    """Mean squared error (normalized units) over the sea pixels left visible in the model input.

    Uses y_true[..., 0] (true SST) and y_true[..., 2] (1 = visible sea not covered by the
    artificial clouds). If no such pixel exists in the batch the result is 0/0 = NaN.
    """
    packed = tf.cast(y_true, tf.float32)
    sst_true, input_visible = packed[:, :, :, 0:1], packed[:, :, :, 2:3]
    weighted_sq_err = input_visible * tf.square(sst_true - y_pred)
    return tf.reduce_sum(weighted_sq_err) / tf.reduce_sum(input_visible)

In [None]:
def ArtificialMetric(y_true, y_pred):
    """Mean squared error (normalized units) over the sea pixels hidden only by artificial clouds.

    The hidden-only mask is real mask (y_true[..., 1]) minus input mask (y_true[..., 2]).
    If no pixel was artificially hidden the result is 0/0 = NaN.
    """
    # Cast needed: y_true reaches the metrics as float64 while the predictions are float32
    packed = tf.cast(y_true, tf.float32)
    hidden_only = packed[:, :, :, 1:2] - packed[:, :, :, 2:3]
    weighted_sq_err = tf.square(packed[:, :, :, 0:1] - y_pred) * hidden_only
    return tf.reduce_sum(weighted_sq_err) / tf.reduce_sum(hidden_only)

In [None]:
def RMSEMetric(y_true, y_pred):
    """RMSE in original (denormalized) units over the sea pixels hidden only by artificial clouds.

    Values are mapped back from [-1, 1] with the module-level `data_min` / `data_max` bounds.
    The hidden-only mask is real mask (y_true[..., 1]) minus input mask (y_true[..., 2]).
    """
    packed = tf.cast(y_true, tf.float32)
    hidden_only = packed[:, :, :, 1:2] - packed[:, :, :, 2:3]

    def to_physical(values):
        # Inverse of the [-1, 1] min-max scaling
        return ((values + 1) / 2) * (data_max - data_min) + data_min

    pred_phys = to_physical(y_pred)
    true_phys = to_physical(packed[:, :, :, 0:1])

    mse = tf.reduce_sum(tf.square(true_phys - pred_phys) * hidden_only) / tf.reduce_sum(hidden_only)
    return tf.sqrt(mse)

In [None]:
# Hyperparameters

# Optimisation
lr = 1e-4
batch_size=32
epochs=100

# Masked loss and metrics defined above
loss = customLoss
metrics = [ClearMetric, ArtificialMetric, RMSEMetric]

# patience == epochs: the callback can never stop training early within `epochs` epochs
early_stop = EarlyStopping(monitor='val_loss', patience=epochs, verbose=1, restore_best_weights=True)

# Steps covering each (day) split once, one batch per step
steps_per_epoch, validation_steps, testing_steps = (
    len(split) // batch_size for split in (x_train_d, x_val_d, x_test_d)
)

# Input channels: masked SST, input mask, land-sea mask, tuned baseline
input_shape = (256, 256, 4)

In [None]:
# Generator function
def generator(batch_size, data_d, data_n, dates_d, dates_n, dayChance=0.5,
              baseline_d=None, baseline_n=None, land_sea_mask=None):
    """Yield endless batches of artificially clouded images for the inpainting U-Net.

    Each batch comes entirely from either the day or the night dataset, chosen at random
    (day with probability `dayChance`). Inside a dataset the "current" images are taken
    sequentially and wrap around after the last one. Each dataset keeps its own cursor, so the
    two can have different lengths.

    For every sample, three other days are drawn at random. The current day's NaN mask is OR-ed
    with each of theirs, and the mask with the median coverage becomes the artificial mask.

    Args:
        batch_size: samples per batch.
        data_d, data_n: day/night image stacks; ``data[k]`` is one (H, W) image with NaN where
            the sea is not visible (land/clouds).
        dates_d, dates_n: per-image dates as day counts since 1970-01-01.
        dayChance: probability that a batch is drawn from the day dataset.
        baseline_d, baseline_n: baselines indexed by day of year (``baseline[doy - 1]``).
            Default to the module-level ``abs_new2_baseline_d`` / ``abs_new2_baseline_n``.
        land_sea_mask: (H, W) land-sea mask, truthy over sea. Defaults to ``italy_mask``.

    Yields:
        ``(batch_x, batch_y, batch_dates, dn_flag)`` where
        ``batch_x[..., 0:4]`` = masked image, input mask (1 = visible sea), land-sea mask, tuned baseline;
        ``batch_y[..., 0:3]`` = real image (NaN -> 0), real mask (1 = visible sea), input mask;
        ``batch_dates`` is a list of 'YYYY-MM-DD' strings and ``dn_flag`` is 'day' or 'night'.
    """
    if baseline_d is None:
        baseline_d = abs_new2_baseline_d
    if baseline_n is None:
        baseline_n = abs_new2_baseline_n
    if land_sea_mask is None:
        land_sea_mask = italy_mask

    # One sequential cursor per dataset: a shared counter could run past the end of the
    # shorter dataset (IndexError) when day and night batches are interleaved
    cursor = {True: 0, False: 0}
    while True:
        # Randomly choose between day and night dataset
        isDay = np.random.rand() < dayChance
        (dataset, date, baseline) = (data_d, dates_d, baseline_d) if isDay else (data_n, dates_n, baseline_n)
        n_images = dataset.shape[0]
        height, width = dataset.shape[-2:]   # image size taken from the data instead of hard-coding 256x256

        batch_x = np.zeros((batch_size, height, width, 4))
        batch_y = np.zeros((batch_size, height, width, 3))
        batch_dates = []

        i = cursor[isDay]
        for b in range(batch_size):
            # The current day is dataset[i] (sequential); 3 other days are drawn at random to borrow their masks
            r1, r2, r3 = np.random.randint(0, n_images, 3)

            # Current image (NaN -> 0) and its mask (True where the sea is not visible)
            image_current = np.nan_to_num(dataset[i], nan=0)
            mask_current = np.isnan(dataset[i])

            # OR the current mask with each random day's mask; keep the one with the median coverage
            masks = [np.logical_or(mask_current, np.isnan(dataset[r])) for r in (r1, r2, r3)]
            masks.sort(key=np.sum)
            artificial_mask = masks[1]   # True where hidden (land, real clouds, artificial clouds)

            # Apply the amplified mask to the current day's image
            image_masked = np.where(artificial_mask, 0, image_current)

            # Convert the current date (days since the Unix epoch) and keep it as 'YYYY-MM-DD'
            date_series = pd.to_datetime(date[i], unit='D', origin='unix')
            day_of_year = date_series.dayofyear
            batch_dates.append(date_series.strftime('%Y-%m-%d'))

            # Clamp so that day 366 of a leap year still maps to a row when the baseline has only
            # 365 rows (no-op when it has 366)
            day_baseline = baseline[min(day_of_year, baseline.shape[0]) - 1]

            # Shift the baseline so its mean over the still-visible sea matches the image's mean there
            image_masked_nan = np.where(artificial_mask, np.nan, image_current)
            if np.isnan(image_masked_nan).all():
                tuned_baseline = day_baseline
            else:
                avg_temp_real = np.nanmean(image_masked_nan)
                avg_temp_baseline = np.nanmean(np.where(artificial_mask, np.nan, day_baseline))
                tuned_baseline = day_baseline + avg_temp_real - avg_temp_baseline
            tuned_baseline = np.where(land_sea_mask, tuned_baseline, 0)    # Apply the land-sea mask

            # Invert the masks for the loss/metrics: 1 = visible sea
            artificial_mask = np.logical_not(artificial_mask)
            mask_current = np.logical_not(mask_current)

            batch_x[b, ..., 0] = image_masked               # artificially cloudy image
            batch_x[b, ..., 1] = artificial_mask            # input mask
            batch_x[b, ..., 2] = land_sea_mask              # land-sea mask
            batch_x[b, ..., 3] = tuned_baseline             # tuned baseline

            batch_y[b, ..., 0] = image_current              # real image
            batch_y[b, ..., 1] = mask_current               # real mask
            batch_y[b, ..., 2] = artificial_mask            # input mask (used by the metrics)

            # Advance, wrapping only after the LAST image so every image is used
            i = (i + 1) % n_images
        cursor[isDay] = i

        dn_flag = 'day' if isDay else 'night'
        yield batch_x, batch_y, batch_dates, dn_flag

In [None]:
# Create one generator per split; each one mixes day and night batches
train_gen = generator(batch_size, data_d=x_train_d, data_n=x_train_n,
                      dates_d=dates_train_d, dates_n=dates_train_n)
val_gen = generator(batch_size, data_d=x_val_d, data_n=x_val_n,
                    dates_d=dates_val_d, dates_n=dates_val_n)
test_gen = generator(batch_size, data_d=x_test_d, data_n=x_test_n,
                     dates_d=dates_test_d, dates_n=dates_test_n)

In [None]:
# Test the generator: draw one training batch and show a random sample's input next to its target
x, y, date, dn_flag = next(train_gen)
r = np.random.randint(0, batch_size)

sample_in = x[r, :, :, 0]     # artificially clouded input
sample_out = y[r, :, :, 0]    # ground truth
vmin = min(np.min(sample_in), np.min(sample_out))
vmax = max(np.max(sample_in), np.max(sample_out))

plt.figure(figsize=(8, 8))
for panel, (img, caption) in enumerate(((sample_in, "x_0 (model input)"),
                                        (sample_out, "y_0 (ground truth)")), start=1):
    plt.subplot(2, 2, panel)
    plt.imshow(img, cmap='viridis', vmin=vmin, vmax=vmax)
    plt.title(caption)
    plt.colorbar()
plt.show()

# Information about the data
print(f"Showing day {date[r]}, from the '{dn_flag}' dataset")
print("\nx.shape:", x.shape)
print("y.shape:", y.shape)
print("min in x:", np.min(sample_in))
print("max in x:", np.max(sample_in))

In [None]:
# U-Net model with residual blocks
# NOTE: the functions below create Keras layers in a fixed order. Default layer names and the layer
# order of the resulting model can depend on that order, and Model.load_weights on an HDF5 file
# matches weights by topological layer order unless by_name=True — so reordering layer creation
# may break loading previously saved weights.

def ResidualBlock(depth):
    """Return a function that applies one residual block with `depth` output channels.

    Main path: BatchNormalization (no learned center/scale) -> 3x3 Conv2D with swish -> 3x3 Conv2D
    (linear). The block input is added to the result. When the input's channel count differs from
    `depth`, the input first goes through a 1x1 Conv2D.
    """
    def apply(x):
        """Apply the block to `x`; assumes a channels-last (batch, H, W, C) tensor."""
        input_depth = x.shape[3]    # Get the number of channels from the channels dimension
        if input_depth == depth:    # It's already the desired channel number
            residual = x
        else:                       # Adjust the number of channels with a 1x1 convolution
            residual = Conv2D(depth, kernel_size=1)(x)

        x = BatchNormalization(center=False, scale=False)(x)    # normalize without learned shift/scale
        x = Conv2D(depth, kernel_size=3, padding="same", activation='swish')(x)
        x = Conv2D(depth, kernel_size=3, padding="same")(x)
        x = Add()([x, residual])    # residual (shortcut) connection
        return x
    
    return apply


def DownBlock(depth, block_depth):
    """Return an encoder stage: `block_depth` residual blocks, then 2x2 average pooling.

    The returned function takes a pair ``[x, skips]``. It appends the output of every residual
    block to ``skips``, mutating the caller's list so the decoder can consume it. It returns the
    pooled tensor.
    """
    def apply(x):
        x, skips = x    # unpack the (tensor, skip list) pair
        for _ in range(block_depth):
            x = ResidualBlock(depth)(x)
            skips.append(x)     # saved for the matching decoder stage
        x = AveragePooling2D(pool_size=2)(x)    #downsampling
        return x

    return apply


def UpBlock(depth, block_depth):
    """Return a decoder stage built from 2x bilinear upsampling and residual blocks.

    After upsampling, the stage repeats `block_depth` times: concatenate the most recently stored
    skip tensor, then apply a residual block.

    The returned function takes a pair ``[x, skips]`` and pops from ``skips``, mutating the
    caller's list. Skips are therefore consumed in the reverse order of how they were appended.
    """
    def apply(x):
        x, skips = x    # unpack the (tensor, skip list) pair
        x = UpSampling2D(size=2, interpolation="bilinear")(x)   #upsampling
        for _ in range(block_depth):
            x = Concatenate()([x, skips.pop()])     # spatial sizes must match the stored skip
            x = ResidualBlock(depth)(x)
        return x

    return apply


def get_Unet(image_size, depths, block_depth):
    """Build the residual U-Net used for inpainting.

    Args:
        image_size: model input shape without the batch dimension, e.g. (256, 256, 4).
        depths: channel counts per level. Every entry but the last is an encoder/decoder level;
            the last is used by the bottleneck residual blocks. The spatial size must be divisible
            by 2 ** (len(depths) - 1), because each encoder level halves it and each decoder
            level doubles it before concatenating the matching skip.
        block_depth: residual blocks per level.

    Returns:
        A keras Model named "UNetInpainter" with a single-channel, linear, zero-initialised
        1x1 Conv2D output.
    """
    input_images = Input(shape=image_size)  #input layer
    
    x = Conv2D(depths[0], kernel_size=1)(input_images)  #reduce the number of channels

    skips = []  #store the skip connections
    
    for depth in depths[:-1]:   #downsampling layers
        x = DownBlock(depth, block_depth)([x, skips])

    for _ in range(block_depth):    #middle layer
        x = ResidualBlock(depths[-1])(x)

    for depth in reversed(depths[:-1]):   #upsampling layers
        x = UpBlock(depth, block_depth)([x, skips])

    x = Conv2D(1, kernel_size=1, kernel_initializer="zeros", name = "output_noise")(x)  #output layer, no activation function
    
    return Model(input_images, outputs=x, name="UNetInpainter")

In [None]:
# Build the residual U-Net: four encoder/decoder levels plus a 512-channel bottleneck,
# with `block_depth` residual blocks per level
block_depth = 2
depths = [32, 64, 128, 256, 512]

model = get_Unet(image_size=input_shape, depths=depths, block_depth=block_depth)
# model.summary()   # uncomment to print the layer table

In [None]:
# Compile with the masked loss and the masked metrics (Adam, learning rate `lr`)
opt = Adam(learning_rate=lr)
model.compile(loss=loss, optimizer=opt, metrics=metrics)

In [None]:
# LOAD WEIGHTS
# Trained weights of the "standard" model (HDF5); the architecture built above must match the saved one
weights_path = 'weights/standard.h5'
model.load_weights(weights_path)

1 - Standard model (weights/standard.h5)

In [None]:
# Plot every sample of a test batch and report those whose maximum error exceeds a threshold.
# Each sample is also re-predicted with the first (offset-tuned) prediction used as the baseline channel.

# Generate and evaluate tot batches
tot = 1  # 10
threshold = 7  # maximum absolute error (denormalized units) above which a sample is reported
for _ in range(tot):
    # Generate a batch and predict it
    x_true, y_true, batch_dates, dn_flag = next(test_gen)
    predictions = model.predict(x_true, verbose=0)

    # Second pass: use the prediction as the baseline. Mirror the generator's tuning of the real
    # baseline: shift each sample so its mean over the sea visible in the input (x_true[..., 1] == 1)
    # matches the input's mean there. x_true[..., 0] is 0 outside the visible sea, so those pixels
    # must be excluded from BOTH means.
    visible = x_true[..., 1] != 0
    n_visible = visible.sum(axis=(1, 2), keepdims=True)
    sum_real = np.where(visible, x_true[..., 0], 0).sum(axis=(1, 2), keepdims=True)
    sum_pred = np.where(visible, predictions[..., 0], 0).sum(axis=(1, 2), keepdims=True)
    # No visible sea -> no shift (the generator likewise falls back to the untuned baseline)
    offset = np.where(n_visible > 0, (sum_real - sum_pred) / np.maximum(n_visible, 1), 0)
    tuned_pred = np.where(italy_mask, predictions[..., 0] + offset, 0)    # Apply the land-sea mask

    x_true1 = x_true.copy()
    x_true1[..., 3] = tuned_pred   # Use the tuned prediction as the baseline
    predictions1 = model.predict(x_true1, verbose=0)

    # Denormalize (inverse of the [-1, 1] min-max scaling)
    predictions_denorm = ((predictions[..., 0] + 1) / 2) * (data_max - data_min) + data_min
    true_values_denorm = ((y_true[..., 0] + 1) / 2) * (data_max - data_min) + data_min
    predictions1_denorm = ((predictions1[..., 0] + 1) / 2) * (data_max - data_min) + data_min

    nanbaseline = np.where(italy_mask, x_true[..., 3], np.nan)
    baseline_denorm = ((nanbaseline + 1) / 2) * (data_max - data_min) + data_min

    # Absolute errors over the sea visible in the real image (NaN elsewhere)
    realMask = y_true[..., 1]
    errors = np.where(realMask, np.abs(predictions_denorm - true_values_denorm), np.nan)
    errors1 = np.where(realMask, np.abs(predictions1_denorm - true_values_denorm), np.nan)

    for i in range(len(errors)):
        print(f"Day {batch_dates[i]} from the {dn_flag} dataset")
        # nanargmax raises on an all-NaN slice, i.e. a sample with no visible sea
        if not np.isnan(errors[i]).all():
            maxerr_coords = np.unravel_index(np.nanargmax(errors[i]), errors[i].shape)
            maxerr = errors[i][maxerr_coords]
            if maxerr > threshold:
                print(f"Error of {maxerr} found in coordinates {maxerr_coords}")

        plt.figure(figsize=(15, 10))

        vmin = min(np.nanmin(true_values_denorm[i]), np.nanmin(predictions_denorm[i]), np.nanmin(baseline_denorm[i]))
        vmax = max(np.nanmax(true_values_denorm[i]), np.nanmax(predictions_denorm[i]), np.nanmax(baseline_denorm[i]))

        plt.subplot(2, 3, 1)
        mask_overlay = np.where(y_true[i, :, :, 1], true_values_denorm[i], np.nan)
        plt.imshow(mask_overlay, cmap='jet', vmin=vmin, vmax=vmax)
        plt.title("denormalized y_0 (ground truth)")
        plt.colorbar()

        plt.subplot(2, 3, 2)
        masked_input = np.where(x_true[i, :, :, 1], true_values_denorm[i], np.nan)
        plt.imshow(masked_input, cmap='jet', vmin=vmin, vmax=vmax)
        plt.title("denormalized x_0 (model input)")
        plt.colorbar()

        plt.subplot(2, 3, 3)
        masked_prediction = np.where(italy_mask, predictions_denorm[i], np.nan)
        # plt.scatter(maxerr_coords[1], maxerr_coords[0], c='magenta', s=10)  # optionally mark the maximum error
        plt.imshow(masked_prediction, cmap='jet', vmin=vmin, vmax=vmax)
        plt.title("denormalized prediction")
        plt.colorbar()

        plt.subplot(2, 3, 4)
        plt.imshow(errors[i], cmap='hot_r')   # already NaN outside the visible sea
        plt.title("error")
        plt.colorbar()

        plt.subplot(2, 3, 5)
        masked_prediction1 = np.where(italy_mask, predictions1_denorm[i], np.nan)
        plt.imshow(masked_prediction1, cmap='jet', vmin=vmin, vmax=vmax)
        plt.title("denormalized prediction1")
        plt.colorbar()

        plt.subplot(2, 3, 6)
        plt.imshow(errors1[i], cmap='hot_r')  # error of the second prediction for THIS sample
        plt.title("error of prediction1")
        plt.colorbar()

        plt.show()

In [None]:
# COMPUTE ERRORS
print("Evaluating the errors, please wait...")


def _error_stats(errors, batch_maxima):
    """Summarise a 1-D array of valid (non-NaN) absolute errors.

    Returns (mean, max, mean of the per-batch maxima, variance, RMSE); all NaN if `errors` is empty.
    """
    if errors.size == 0:
        return np.nan, np.nan, np.nan, np.nan, np.nan
    return (np.mean(errors), np.max(errors), np.nanmean(batch_maxima),
            np.var(errors), np.sqrt(np.mean(errors ** 2)))


# Per-batch collections. Only the valid (non-NaN) errors are kept, as numpy arrays. Appending every
# pixel of 100 batches of 32 x 256 x 256 images (NaNs included) to Python lists as individual floats
# needs tens of GB, and every statistic below ignores NaNs anyway.
all_errors = []
max_errors = []         # per batch, over all visible sea
clear_errors = []
clear_max_errors = []   # per batch, over the sea left visible in the input
hidden_errors = []
hidden_max_errors = []  # per batch, over the sea hidden artificially

# Generate and evaluate tot batches
tot = 100
for _ in range(tot):
    # The generator yields (x, y, dates, day/night flag); only x and y are needed here
    x_true, y_true = next(test_gen)[:2]
    predictions = model.predict(x_true, verbose=0)

    # Denormalize and compute the absolute error on every pixel
    predictions_denorm = ((predictions[..., 0] + 1) / 2) * (data_max - data_min) + data_min
    true_values_denorm = ((y_true[..., 0] + 1) / 2) * (data_max - data_min) + data_min
    abs_errors = np.abs(predictions_denorm - true_values_denorm)

    # Pixel selections
    realMask = y_true[..., 1] != 0                              # visible sea in the real image
    hiddenMask = np.not_equal(y_true[..., 1], x_true[..., 1])   # sea hidden only by artificial clouds
    clearMask = x_true[..., 1] != 0                             # sea visible in the input

    for mask, errors_list, maxima_list in ((realMask, all_errors, max_errors),
                                           (clearMask, clear_errors, clear_max_errors),
                                           (hiddenMask, hidden_errors, hidden_max_errors)):
        batch_errors = abs_errors[mask]
        batch_errors = batch_errors[~np.isnan(batch_errors)]
        errors_list.append(batch_errors)
        maxima_list.append(batch_errors.max() if batch_errors.size else np.nan)

# Merge the per-batch arrays
all_errors = np.concatenate(all_errors)
max_errors = np.array(max_errors)
clear_errors = np.concatenate(clear_errors)
clear_max_errors = np.array(clear_max_errors)
hidden_errors = np.concatenate(hidden_errors)
hidden_max_errors = np.array(hidden_max_errors)

# Metrics over all, clear and hidden errors
avg_error, max_error, avg_max_error, var_error, rmse_error = _error_stats(all_errors, max_errors)
avg_clear_error, max_clear_error, avg_max_clear_error, var_clear_error, rmse_clear_error = _error_stats(
    clear_errors, clear_max_errors)
avg_hidden_error, max_hidden_error, avg_max_hidden_error, var_hidden_error, rmse_hidden_error = _error_stats(
    hidden_errors, hidden_max_errors)

# Print the metrics calculated over all elements
print(f"\nAverage error over all elements:", avg_error)
print(f"Average maximum error:", avg_max_error, ", and maximum error:", max_error)
print(f"Variance of errors over all elements:", var_error)
print(f"RMSE over all elements:", rmse_error)

# Print the metrics calculated over clear elements
print(f"\nAverage error over clear elements:", avg_clear_error)
print(f"Average maximum error over clear elements:", avg_max_clear_error, ", and maximum error over clear elements:", max_clear_error)
print(f"Variance of errors over clear elements:", var_clear_error)
print(f"RMSE over clear elements:", rmse_clear_error)

# Print the metrics calculated over hidden elements
print(f"\nAverage error over hidden elements:", avg_hidden_error)
print(f"Average maximum error over hidden elements:", avg_max_hidden_error, ", and maximum error over hidden elements:", max_hidden_error)
print(f"Variance of errors over hidden elements:", var_hidden_error)
print(f"RMSE over hidden elements:", rmse_hidden_error)