<img src="../assets/header_notebook.png" />
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>ESA - Black Sea Deoxygenation Emulator</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
# ----------
# Libraries
# ----------
import os
import sys
import cv2
import dawgz
import wandb
import xarray
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Pytorch
import torch
import torch.nn as nn
import torch.optim as optim

# Dawgz (jobs //)
from dawgz import job, schedule

# -------------------
# Libraries (Custom)
# -------------------
# Adding path to source folder to load custom modules
sys.path.append('/src')
sys.path.append('/src/debs/')
sys.path.insert(1, '/src/debs/')
sys.path.insert(1, '/scripts/')

# Moving to the .py directory
%cd src/debs/

## Loading libraries
from metrics     import *
from dataset     import *
from dataloader  import *
from tools       import *
from losses      import *

# -------
# Jupyter
# -------
%matplotlib inline
plt.rcParams.update({'font.size': 13})

# Making sure modules are reloaded when modified
%reload_ext autoreload
%autoreload 2

<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>Scripts</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
# Generating the preprocessed data (normalized)
%run __generate_n.py

In [None]:
# Generating the preprocessed data (standardized)
%run __generate_s.py

In [None]:
# Analyzing the data distributions
%run __distributions.py

In [5]:
# Training a neural network (using a given configuration)
%run __training.py --config local

-------------------------------------------------------
                                                       
                    ESA - PROJECT                      
                                                       
          BLACK SEA DEOXYGENATION EMULATOR             
                                                       
-------------------------------------------------------
                                                       
- Project : ESA - Black Sea Deoxygenation Emulator - Test 2 (Local)
- Month (Starting) : 6
- Month (Ending) : 9
- Year (Starting) : 1980
- Year (Ending) : 1980
- Hypoxia Treshold : 63
- Depth : 150
- Inputs : ['temperature', 'salinity', 'chlorophyll', 'kshort', 'klong']
- Window (Inputs) : 14
- Window (Output) : 1
- Window (Transformation) : 1
- Architecture : UNET
- Scaling : 8
- Kernel Size : 3
- Datasets Size : [0.6, 0.3]
- Loss Weights : [1, 1]
- Learning Rate : 0.001
- Batch Size : 16
- Epochs : 10

-----------------
Emulator Training
------

<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>Data</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
# Loading the different inputs
#
# Dataset restricted to January 2020 - December 2021. The variables below are
# notebook globals: they are consumed by the animation calls further down.
BSD_dataset = BlackSea_Dataset(year_start  = 2020,
                               year_end    = 2021,
                               month_start = 1,
                               month_end   = 12)

# One array per physical variable (five inputs, then the oxygen target)
data_temperature   = BSD_dataset.get_data(variable = "temperature")
data_salinity      = BSD_dataset.get_data(variable = "salinity")
data_chlorophyll   = BSD_dataset.get_data(variable = "chlorophyll")
data_kshort        = BSD_dataset.get_data(variable = "kshort")
data_klong         = BSD_dataset.get_data(variable = "klong")
data_oxygen        = BSD_dataset.get_data(variable = "oxygen")

# Two masks: the positional boolean is presumably the `continental_shelf` flag
# used by keyword elsewhere in this notebook - TODO confirm against get_mask
mask               = BSD_dataset.get_mask(False)
maskCS             = BSD_dataset.get_mask(True)

def generate_animation(data : np.ndarray, mask : np.ndarray, title : str, limits : list = None, start_date = None):
    """Render each time step of `data` as a PNG frame and assemble the frames into a GIF.

    Frames are written to ``../../analysis/images/{title}/{title}_{i}.png`` and the
    animation to ``../../analysis/images/{title}.gif`` (both relative to the current
    working directory). Missing folders are created.

    Args:
        data: Array indexed by time along its first axis; ``data[i]`` is drawn with imshow.
        mask: Array thresholded at 0.5 and drawn as a faint overlay on every frame.
        title: Name used for the output folder, the frame files and the GIF.
        limits: ``[vmin, vmax]`` of the color scale. Defaults to ``[0, 1]``.
        start_date: ``datetime`` stamped on the first frame; every following frame is
            one day later. Defaults to 2020-01-01.
    """
    import os
    import imageio
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    from datetime import datetime, timedelta

    # Defaults are resolved here rather than in the signature (no shared mutable default)
    limits     = [0, 1] if limits is None else limits
    start_date = datetime(2020, 1, 1) if start_date is None else start_date

    # Converting to booleans for easier plot
    mask = mask > 0.5

    # Displaying information over terminal
    print(f"Data shape: {data.shape}")

    def index_to_date(i):
        """Convert a frame index to a 'YYYY-MM-DD' string (one frame per day)"""
        return (start_date + timedelta(days = i)).strftime('%Y-%m-%d')

    # Creation of the output folder (and of its parents) if it does not exist
    frames_dir = f'../../analysis/images/{title}'
    os.makedirs(frames_dir, exist_ok = True)

    # Displaying information over terminal
    print(f"Generating images for {title}...")

    # Creation of the plots (the path of each frame is remembered for the GIF)
    paths = []
    for i in range(data.shape[0]):
        plt.figure(figsize = (12, 8))
        plt.imshow(data[i], cmap='viridis', vmin = limits[0], vmax = limits[1])
        plt.colorbar(fraction = 0.021)
        plt.imshow(mask, cmap='gray', alpha=0.025)
        ax = plt.gca()

        # Black box in the top-right corner holding the date of the frame
        box = Rectangle((0.85, 0.9), 0.46, 0.16, edgecolor='black', facecolor='black', transform=ax.transAxes)
        plt.annotate(index_to_date(i), xy=(0.98, 0.96), xycoords='axes fraction',
                    horizontalalignment='right', verticalalignment='top',
                    fontsize=12, color='white')
        ax.add_patch(box)
        plt.grid(True, alpha=0.1)

        frame_path = f'{frames_dir}/{title}_{i}.png'
        plt.savefig(frame_path)

        # Closing the figure so that figures do not accumulate in memory
        plt.close()
        paths.append(frame_path)

    # Displaying information over terminal
    print("Generating the video...")

    # Generating the video (frames are read back one at a time)
    with imageio.get_writer(f'../../analysis/images/{title}.gif', mode='I') as writer:
        for filename in paths:
            writer.append_data(imageio.imread(filename))

# Generating the animations: one entry per variable, i.e. (data, mask, name, color limits).
# Oxygen is the only variable drawn with the second mask (maskCS).
animation_jobs = [(data_oxygen,      maskCS, "oxygen",      [0, 1]),
                  (data_temperature, mask,   "temperature", [0, 1]),
                  (data_salinity,    mask,   "salinity",    [0, 1]),
                  (data_chlorophyll, mask,   "chlorophyll", [0, 0.25]),
                  (data_kshort,      mask,   "kshort",      [0, 0.1]),
                  (data_klong,       mask,   "klong",       [0, 0.1])]

for job_data, job_mask, job_name, job_limits in animation_jobs:
    generate_animation(job_data, job_mask, job_name, limits = job_limits)


In [None]:
# Loading the different inputs
#
# Dataset restricted to January 2010 - December 2020; only the oxygen field and
# the two masks are needed for the mean maps computed below.
BSD_dataset = BlackSea_Dataset(year_start  = 2010,
                               year_end    = 2020,
                               month_start = 1,
                               month_end   = 12)

data_oxygen        = BSD_dataset.get_data(variable = "oxygen")
mask               = BSD_dataset.get_mask(False)
maskCS             = BSD_dataset.get_mask(True)

# Extracting the training set, i.e. the first 6 * 365 samples along the first axis
# NOTE(review): this assumes one sample per day and 365-day years; if the data
# contains leap days (e.g. 2012) the yearly windows drift - TODO confirm
training_set = data_oxygen[: 365 * 6]

# Computing the mean over the first axis of the training set
mean_6years = np.mean(training_set, axis = 0)

# Stores the mean of each of the six consecutive 365-sample windows
mean_each_year = [np.mean(data_oxygen[i * 365 : (i + 1) * 365], axis = 0) for i in range(6)]

def save_plot(data : np.array, mask : np.array, title : str):
    """Plot a mean field and save it as ``../../analysis/means/means_{title}.png``.

    Args:
        data: 2D field drawn with imshow on a fixed [0, 1] color scale.
        mask: Array drawn as a faint gray overlay on top of the field.
        title: Label shown in the top-right box ("Mean : {title}") and used in the file name.

    The figure is not closed after saving, so it stays available for display.
    """
    # Imported locally: the notebook only imports Rectangle inside generate_animation,
    # whose local scope is not visible from here
    from matplotlib.patches import Rectangle

    plt.figure(figsize = (12, 8))
    plt.imshow(data, cmap='viridis', vmin = 0, vmax = 1)
    plt.colorbar(fraction = 0.021)
    plt.imshow(mask, cmap='gray', alpha = 0.025)
    ax = plt.gca()

    # Black box in the top-right corner holding the label of the plot
    box = Rectangle((0.80, 0.90), 0.56, 0.16, edgecolor='black', facecolor='black', transform=ax.transAxes)
    plt.annotate(f"Mean : {title}", xy=(0.98, 0.965), xycoords='axes fraction',
                horizontalalignment='right', verticalalignment='top',
                fontsize=12, color='white')
    ax.add_patch(box)
    plt.grid(True, alpha=0.1)
    plt.savefig(f'../../analysis/means/means_{title}.png')

# Saving the plot of the mean over the whole training period
save_plot(mean_6years, maskCS, "Training")

# Saving the plot of the mean of each individual year
for yearly_mean, year in zip(mean_each_year, ["2010", "2011", "2012", "2013", "2014", "2015"]):
    save_plot(yearly_mean, maskCS, year)

# Normalized deoxygenation threshold, read from the first file of the dataset. The file
# is opened in a context manager so that its handle is released once the value is read
with xarray.open_dataset(BSD_dataset.paths[0]) as reference_file:
    hypox_tresh = reference_file["HYPON"].data.item()

# Computing the average state of the region (True where the mean is below the threshold)
mean_6years_state = mean_6years < hypox_tresh

# Stores the state of each year, computed from the yearly means
mean_each_year_state = [mean < hypox_tresh for mean in mean_each_year]

def save_plot_state(data : np.array, mask : np.array, title : str):
    """Plot a binary state map (oxygenated / hypoxia) and save it to disk.

    The file is written to ``../../analysis/means/means_{title}.png`` with the spaces
    of `title` replaced by underscores, e.g. "2010 (H)" gives ``means_2010_(H).png``.
    This is the naming the wandb upload cell of this notebook reads back.

    Args:
        data: 2D field drawn with imshow on a [0, 1] scale; 0 is labelled
            'Oxygenated' and 1 'Hypoxia' on the colorbar.
        mask: Array drawn as a gray overlay on top of the field.
        title: Label shown in the top-right box ("Mean : {title}") and used in the file name.

    The figure is not closed after saving, so it stays available for display.
    """
    # Imported locally: the notebook only imports Rectangle inside generate_animation,
    # whose local scope is not visible from here
    from matplotlib.patches import Rectangle

    plt.figure(figsize = (12, 8))
    plt.imshow(data, cmap='viridis', vmin = 0, vmax = 1)

    # Only the two states are meaningful on the colorbar
    cbar = plt.colorbar(fraction=0.021)
    cbar.set_ticks([0, 1])
    cbar.set_ticklabels(['Oxygenated', 'Hypoxia'])
    plt.imshow(mask, cmap='gray', alpha = 0.25)
    ax = plt.gca()

    # Black box in the top-right corner holding the label of the plot
    box = Rectangle((0.80, 0.90), 0.56, 0.16, edgecolor='black', facecolor='black', transform=ax.transAxes)
    plt.annotate(f"Mean : {title}", xy=(0.98, 0.965), xycoords='axes fraction',
                horizontalalignment='right', verticalalignment='top',
                fontsize=12, color='white')
    ax.add_patch(box)
    plt.grid(True, alpha=0.1)

    # Spaces are replaced so that the file name matches the one loaded for the upload
    file_stem = title.replace(" ", "_")
    plt.savefig(f'../../analysis/means/means_{file_stem}.png')

# State map of the whole training period
save_plot_state(mean_6years_state, maskCS, "Training (H)")

# State map of each individual year, labelled "2010 (H)" ... "2015 (H)"
for yearly_state, year in zip(mean_each_year_state, range(2010, 2016)):
    save_plot_state(yearly_state, maskCS, f"{year} (H)")


In [None]:
# Opening the wandb session that receives the report
wandb.init(project = "ESA - Repport")

# Years covered by the yearly plots
report_years = range(2010, 2016)

# Animations: name of the panel on wandb -> name of the GIF on disk
animation_files = {"Oxygen":              "oxygen",
                   "Temperature":         "temperature",
                   "Salinity":            "salinity",
                   "Chlorophyll":         "chlorophyll",
                   "Reflectance (Short)": "kshort",
                   "Reflectance (Long)" : "klong"}

wandb.log({panel: wandb.Video(f"../../analysis/images/{stem}.gif", fps = 1)
           for panel, stem in animation_files.items()})

# Mean concentrations: the whole training period first, then one image per year
concentrations = {"Concentration (Mean, 2010-2015)": wandb.Image("../../analysis/means/means_Training.png")}
for year in report_years:
    concentrations[f"Concentration (Mean, {year})"] = wandb.Image(f"../../analysis/means/means_{year}.png")
wandb.log(concentrations)

# Mean states: same layout as the concentrations, using the "_(H)" files
states = {"State (Mean, 2010-2015)": wandb.Image("../../analysis/means/means_Training_(H).png")}
for year in report_years:
    states[f"State (Mean, {year})"] = wandb.Image(f"../../analysis/means/means_{year}_(H).png")
wandb.log(states)

# Closing the wandb session
wandb.finish()


<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>Playground</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
# -----------------
#    Parameters
# -----------------
#
# Time window
month_starting = 6
month_ending   = 8
year_starting  = 1980
year_ending    = 1980

# Window size (forwarded to the dataloader as `window_inp`)
windows_inputs = 1

# ------------------
#  Loading the data
# ------------------
# Loading the different datasets
BSD_dataset = BlackSea_Dataset(year_start  = year_starting,
                               year_end    = year_ending,
                               month_start = month_starting,
                               month_end   = month_ending)

# Loading the days ID (used to give temporal information to the model)
days_ID = BSD_dataset.get_days()

# Loading the different inputs
data_temperature   = BSD_dataset.get_data(variable = "temperature")
data_salinity      = BSD_dataset.get_data(variable = "salinity")
data_chlorophyll   = BSD_dataset.get_data(variable = "chlorophyll")
data_kshort        = BSD_dataset.get_data(variable = "kshort")
data_klong         = BSD_dataset.get_data(variable = "klong")

# Loading the output
data_oxygen = BSD_dataset.get_data(variable = "oxygen")

# Loading spatial information
bathy = BSD_dataset.get_depth(unit = "meter")
mesh  = BSD_dataset.get_mesh(x = 256, y = 576)

# Hypoxia threshold, read from the first file of the dataset. The file is opened in
# a context manager so that its handle is released once the value is read
with xarray.open_dataset(BSD_dataset.paths[0]) as reference_file:
    hypox_tresh = reference_file["HYPON"].data.item()

# Loading the black sea masks
bs_mask             = BSD_dataset.get_mask(continental_shelf = False)
bs_mask_with_depth  = BSD_dataset.get_mask(continental_shelf = True)
bs_mask_complete    = get_complete_mask(data_oxygen, hypox_tresh, bs_mask_with_depth)

# Creation of the dataloader (inputs are given in a fixed order: temperature,
# salinity, chlorophyll, kshort, klong)
BSD_loader = BlackSea_Dataloader(x = [data_temperature, data_salinity, data_chlorophyll, data_kshort, data_klong],
                                 y = data_oxygen,
                                 t = days_ID,
                              mesh = mesh,
                              mask = bs_mask,
                        bathymetry = bathy,
                        window_inp = windows_inputs)

# Retrieving the datasets
ds_train      = BSD_loader.get_dataloader("train")
ds_validation = BSD_loader.get_dataloader("validation")
ds_test       = BSD_loader.get_dataloader("test")

In [None]:
# -------------------------------------------------------
#
#        |
#       / \
#      / _ \                  ESA - PROJECT
#     |.o '.|
#     |'._.'|          BLACK SEA DEOXYGENATION EMULATOR
#     |     |
#   ,'|  |  |`.             BY VICTOR MANGELEER
#  /  |  |  |  \
#  |,-'--|--'-.|                2023-2024
#
#
# -------------------------------------------------------
#
# Documentation
# -------------
# A neural network definition to be used as temporal encoder
#
# Pytorch
import torch.nn as nn

class FCNN(nn.Sequential):
    r"""A fully convolutional neural network.

    Five convolutions in a row (widths 256, 128, 64 and 32 times `scaling`, then 2
    output channels). Each of the first four is followed by a GELU activation and a
    batch normalization, in that order.

    Args:
        inputs: Number of channels of the input tensor.
        kernel_size: Size of every convolution kernel; the padding is `kernel_size // 2`.
        scaling: Multiplier applied to the width of the hidden layers.
    """

    def __init__(self, inputs: int, kernel_size : int = 3, scaling : int = 1):
        super().__init__()

        # Two output channels, i.e. a mean and a standard deviation
        self.n_in    = inputs
        self.n_out   = 2
        self.padding = kernel_size // 2

        def conv(channels_in, channels_out):
            """Convolution sharing the kernel size and padding of the network"""
            return nn.Conv2d(channels_in, channels_out, kernel_size, padding = self.padding)

        # Width of the four hidden layers
        w1, w2, w3, w4 = (base * scaling for base in (256, 128, 64, 32))

        # Convolutions
        self.conv_init           = conv(self.n_in, w1)
        self.conv_intermediate_1 = conv(w1, w2)
        self.conv_intermediate_2 = conv(w2, w3)
        self.conv_intermediate_3 = conv(w3, w4)
        self.conv_final          = conv(w4, self.n_out)

        # Activation shared by all the hidden layers
        self.activation = nn.GELU()

        # One batch normalization per hidden layer, sized on the matching convolution
        self.normalization_init           = nn.BatchNorm2d(self.conv_init.out_channels)
        self.normalization_intermediate_1 = nn.BatchNorm2d(self.conv_intermediate_1.out_channels)
        self.normalization_intermediate_2 = nn.BatchNorm2d(self.conv_intermediate_2.out_channels)
        self.normalization_intermediate_3 = nn.BatchNorm2d(self.conv_intermediate_3.out_channels)

    def forward(self, x):
        r"""Apply the network and return a tensor of shape (samples, n_out // 2, 2, x, y)"""

        # Hidden layers: convolution -> activation -> normalization
        hidden_layers = ((self.conv_init,           self.normalization_init),
                         (self.conv_intermediate_1, self.normalization_intermediate_1),
                         (self.conv_intermediate_2, self.normalization_intermediate_2),
                         (self.conv_intermediate_3, self.normalization_intermediate_3))

        for convolution, normalization in hidden_layers:
            x = normalization(self.activation(convolution(x)))

        # Output layer (no activation, no normalization)
        x = self.conv_final(x)

        # Splitting the channels, i.e. (samples, days, values, x, y)
        samples, _, height, width = x.shape
        return x.reshape(samples, self.n_out // 2, 2, height, width)

    def count_parameters(self,):
        r"""Determines the number of trainable parameters in the model"""
        total = 0
        for parameter in self.parameters():
            if parameter.requires_grad:
                total += parameter.numel()
        return int(total)


In [None]:
# Picking the GPU when one is available, otherwise falling back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Emulator and its optimizer
neural_network = FCNN(inputs = 184, kernel_size = 5, scaling = 3)
optimizer      = optim.Adam(neural_network.parameters(), lr = 0.001)

# Moving the network to the selected device
neural_network.to(device)

In [None]:
# ------------------------------------------------
# Playground training loop: each epoch runs a single training batch and a single
# validation batch (see the `break` statements) and optionally plots the results.
show = True

# Pixels hidden in the plots, i.e. where the mask is equal to 0
hidden_pixels = bs_mask_with_depth == 0

def _show_panels(images, titles):
    """Display three images side by side, without ticks, each with its own title"""
    _, axes = plt.subplots(1, 3, figsize = (20, 20))
    for ax, image, title in zip(axes, images, titles):
        ax.imshow(image)
        ax.set_title(title, fontsize = 6)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

for epoch in range(15):

    # The threshold is the hypoxia threshold loaded with the data (`hypox_tresh`)
    metrics_tool = BlackSea_Metrics(mode = "regression",
                                    mask = bs_mask_with_depth,
                           mask_complete = bs_mask_complete,
                                treshold = hypox_tresh,
                       number_of_samples = BSD_loader.get_number_of_samples("validation"))

    # Training mode, i.e. batch normalization uses (and updates) the batch statistics
    neural_network.train()

    for x, y in ds_train:

        x, y = x.to(device), y.to(device)
        prediction = neural_network(x)

        # NOTE(review): `device = "cpu"` is passed while the tensors live on `device`;
        # confirm what compute_loss does with this argument
        loss_training = compute_loss(y_pred = prediction, y_true = y, mask = bs_mask_with_depth, problem = "regression", device = "cpu", kwargs = {})
        print(f"E{epoch} - Loss (Training):", loss_training.item())
        optimizer.zero_grad()
        loss_training.backward()
        optimizer.step()

        # Cleaning
        del x, y, prediction, loss_training
        torch.cuda.empty_cache()
        break

    # Evaluation mode, i.e. batch normalization uses its running statistics and
    # the validation batch does not modify them
    neural_network.eval()

    with torch.no_grad():

        # Stores all the predictions for the metrics (plots)
        prediction_all = None

        for x, y in ds_validation:

            # Making prediction
            x, y = x.to(device), y.to(device)
            prediction = neural_network(x)
            loss_validation = compute_loss(y_pred = prediction, y_true = y, mask = bs_mask_with_depth, problem = "regression", device = "cpu", kwargs = {})
            print(f"E{epoch} - Loss (Validation):", loss_validation.item())
            x, y, prediction = x.to("cpu"), y.to("cpu"), prediction.to("cpu")

            # Plotting mean against ground truth in a subplot
            if show:

                # Highlighting hypoxic areas
                y_hyp = ( y          < hypox_tresh ) * 1.0
                p_hyp = ( prediction < hypox_tresh ) * 1.0

                # Hiding non-observable areas
                p_hyp[:,:,:, hidden_pixels] = torch.nan
                y_hyp[:,:,:, hidden_pixels] = torch.nan

                _show_panels(images = [y_hyp[0, 0, 0], p_hyp[0, 0, 0], y_hyp[0, 0, 0] - p_hyp[0, 0, 0]],
                             titles = ["Ground Truth", "Prediction", "Difference"])

                # NOTE(review): the prediction is masked in place, i.e. the NaNs are
                # also present in `prediction_all` whenever `show` is True
                prediction[:,:,:, hidden_pixels] = torch.nan
                y[:,:,:, hidden_pixels]          = torch.nan

                # The third panel displays exp(value / 2) of the second predicted channel
                _show_panels(images = [y[0, 0, 0], prediction[0, 0, 0], torch.exp(prediction[0, 0, 1]/2)],
                             titles = ["Ground Truth", "Prediction (Mean)", "Prediction (Std)"])

            # Concatenating all the predictions
            prediction_all = torch.cat((prediction_all, prediction), dim = 0) if prediction_all is not None else prediction

            del x, y, prediction
            torch.cuda.empty_cache()

            break

    """
    # Sampling random data for comparison
    # Metrics
    y_vall_all = torch.from_numpy(BSD_loader.y_validation)
    #metrics_tool.compute_metrics(y_pred = prediction_all, y_true = y_vall_all)
    metrics_tool.compute_plots_comparison_regression(y_pred = prediction_all, y_true = y_vall_all)


    #metrics_tool.compute_plots(  y_pred = prediction_all, y_true = y_vall_all)

    # Getting the results
    if show:
        results, results_name = metrics_tool.get_results()
        for r, n in zip(results[0], results_name):
            print(n, " : ", r)
        print("\n")
    """
