<img src="../assets/header_notebook.png" />
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>ESA - Black Sea Deoxygenation Emulator</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
# ----------
# Librairies
# ----------
import os
import sys
import cv2
import xarray
import random
import dawgz
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Pytorch
import torch
import torch.nn as nn
import torch.optim as optim

# Dawgz (jobs //)
from dawgz import job, schedule

# -------------------
# Librairies (Custom)
# -------------------
# Adding path to source folder to load custom modules
sys.path.insert(1, '../src/debs/')
sys.path.insert(1, '../scripts/')

# Loading libraries
from tools                import *
from metrics              import *
from dataset              import BlackSea_Dataset
from dataloader           import BlackSea_Dataloader

# -------
# Jupyter
# -------
# Inline figures and a larger default font size for every plot of the notebook
%matplotlib inline
plt.rcParams.update({'font.size': 13})

# Making sure the custom modules (tools, metrics, dataset, dataloader, ...) are reloaded when their source files are modified
%reload_ext autoreload
%autoreload 2

# Moving to the directory holding the .py files executed with %run below.
# NOTE(review): the sys.path entries inserted above ('../src/debs/', '../scripts/') are relative
# paths, which Python resolves against the *current* directory at import time; after this %cd they
# no longer point to src/debs and scripts. Later imports of src/debs modules presumably only work
# because the working directory itself is on sys.path — consider inserting absolute paths instead.
%cd ../src/debs/

<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>Scripts</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
# Analyzing the data (1): run script_distribution.py on year 0, months 1 to 2 (--dawgz False: single local run)
%run script_distribution.py --start_year        0 \
                            --end_year          0 \
                            --start_month       1 \
                            --end_month         2 \
                            --dawgz         False

In [None]:
# Analyzing the data (2): run script_evolution.py on year 0, months 1 to 2 (--dawgz False: single local run)
%run script_evolution.py --start_year        0 \
                         --end_year          0 \
                         --start_month       1 \
                         --end_month         2 \
                         --dawgz         False

In [None]:
# Training a neural network: FCNN regression from the surface temperature, input/output windows of 1,
# batches of 64 for 10 epochs (--dawgz False: a single configuration run locally).
# NOTE(review): this cell uses --start_month 0 while the analysis cells above (and the Playground) start
# at month 1 — confirm whether months are 0- or 1-indexed by the scripts.
%run script_training.py  --start_year                 0 \
                         --end_year                   0 \
                         --start_month                0 \
                         --end_month                  1 \
                         --inputs           temperature \
                         --problem           regression \
                         --windows_input              1 \
                         --windows_output             1 \
                         --architecture            FCNN \
                         --batch_size                64 \
                         --epochs                    10 \
                         --kernel_size                3 \
                         --dawgz                  False

<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>Playground</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
# -----------------
#    Parameters
# -----------------
#
# Time window shared by every dataset loaded below
month_starting = 1
month_ending   = 1
year_starting  = 0
year_ending    = 0

# Maximum depth [m] observed for the bottom oxygen, deeper regions are masked
# (to observe only the continental shelf, set it to ~200m)
depth_max_oxygen = 200

# ------------------
#  Loading the data
# ------------------
# Keyword arguments describing the time window, common to both datasets
period = dict(year_start  = year_starting,
              year_end    = year_ending,
              month_start = month_starting,
              month_end   = month_ending)

# Physical ("grid_T") and biogeochemical ("ptrc_T") datasets over that window
Dataset_phy = BlackSea_Dataset(**period, variable = "grid_T")
Dataset_bio = BlackSea_Dataset(**period, variable = "ptrc_T")

# Surface fields (candidate inputs) — only the temperature is given to the dataloader below
data_temperature = Dataset_phy.get_data(variable = "temperature", type = "surface", depth = None)
data_salinity    = Dataset_phy.get_data(variable = "salinity",    type = "surface", depth = None)
data_chlorophyll = Dataset_bio.get_data(variable = "chlorophyll", type = "surface", depth = None)
data_kshort      = Dataset_bio.get_data(variable = "k_short",     type = "surface", depth = None)
data_klong       = Dataset_bio.get_data(variable = "k_long",      type = "surface", depth = None)

# Bottom oxygen (target), restricted by depth_max_oxygen
data_oxygen = Dataset_bio.get_data(variable = "oxygen", type = "bottom", depth = depth_max_oxygen)

# Black Sea masks: without depth restriction and restricted by depth_max_oxygen
bs_mask            = Dataset_phy.get_mask(depth = None)
bs_mask_with_depth = Dataset_phy.get_mask(depth = depth_max_oxygen)

# --------------------
#  Preparing the data
# --------------------
# Regression of the oxygen from the temperature, input window of 1 and output window of 2 (presumably days)
BSD_loader = BlackSea_Dataloader(x                  = [data_temperature],
                                 y                  = data_oxygen,
                                 bs_mask            = bs_mask,
                                 bs_mask_with_depth = bs_mask_with_depth,
                                 mode               = "regression",
                                 window_inp         = 1,
                                 window_out         = 2)

# Retrieving the three splits, requested in the order validation, train, test
ds_validation, ds_train, ds_test = (BSD_loader.get_dataloader(split) for split in ("validation", "train", "test"))

In [None]:
# -------------------------------------------------------
#
#        |
#       / \
#      / _ \                  ESA - PROJECT
#     |.o '.|
#     |'._.'|          BLACK SEA DEOXYGENATION EMULATOR
#     |     |
#   ,'|  |  |`.             BY VICTOR MANGELEER
#  /  |  |  |  \
#  |,-'--|--'-.|                2023-2024
#
#
# -------------------------------------------------------
#
# Documentation
# -------------
# A script to train a neural network to become an oxygen concentration forecaster in the Black Sea.
#
#   Dawgz = False : train the network for the single configuration given by the user as arguments
#
#   Dawgz = True  : train the network for every configuration of the hyper-parameter grid (jobs parallelized with dawgz)
#
import time
import wandb
import argparse
import matplotlib.pyplot as plt

# Pytorch
import torch
import torch.nn as nn
import torch.optim as optim

# Custom libraries
from dataset              import BlackSea_Dataset
from dataloader           import BlackSea_Dataloader
from metrics              import BlackSea_Metrics
from neural_networks      import FCNN
from tools                import progressBar, to_device

# Dawgz library (used to parallelized the jobs)
from dawgz import job, schedule

# Combinatorics
from itertools import combinations, product

# ---------------------------------------------------------------------
#
#                                  DAWGZ
#
# ---------------------------------------------------------------------
#
# -------------
# Possibilities
# -------------
# Candidate input variables of the network
input_list = ["temperature"]

# Every non-empty subset of the candidate inputs (smallest subsets first), each stored as a list
all_combinations = [
    list(subset)
    for size in range(1, len(input_list) + 1)
    for subset in combinations(input_list, size)
]

# Hyper-parameter grid: each key maps to the list of values to sweep over
# (the key names are the ones read by main())
arguments = {
    'month_start'     : [0],
    'month_end'       : [1],
    'year_start'      : [0],
    'year_end'        : [0],
    'Inputs'          : all_combinations,
    'Problem'         : ["regression", "classification"],
    'Window (Inputs)' : [1],
    'Window (Output)' : [1],
    'Depth'           : [200],
    'Architecture'    : ["FCNN"],
    'Learning Rate'   : [0.001],
    'Kernel Size'     : [3],
    'Batch Size'      : [64],
    'Epochs'          : [3],
}

# Cartesian product of the value lists: one tuple per configuration, values in key order
param_combinations = list(product(*arguments.values()))

# One keyword-argument dictionary per configuration (iterating a dict yields its keys in order)
param_dicts = [dict(zip(arguments, values)) for values in param_combinations]

In [None]:
from itertools import islice


def _mean(total, count):
    """Return ``total / count``, or NaN when no batch was processed (``count == 0``)."""
    return total / count if count else float("nan")


def _load_inputs(dataset_phy, dataset_bio, inputs):
    """Load the surface field of every requested input variable.

    Parameters
    ----------
    dataset_phy : BlackSea_Dataset
        Dataset built with ``variable = "grid_T"`` (temperature, salinity).
    dataset_bio : BlackSea_Dataset
        Dataset built with ``variable = "ptrc_T"`` (chlorophyll, k_short, k_long).
    inputs : list of str
        Input variable names, in the order they are fed to the network.

    Returns
    -------
    list
        One surface field per entry of ``inputs``, in the same order.

    Raises
    ------
    ValueError
        If an input name is unknown. Silently skipping it would make the number of
        loaded fields differ from the number of inputs requested.
    """
    # Input name -> (dataset holding it, variable name passed to get_data).
    # NOTE(review): "kshort"/"klong" are mapped to "k_short"/"k_long", the names used by the
    # Playground cell of this notebook; confirm against BlackSea_Dataset.get_data.
    sources = {
        "temperature" : (dataset_phy, "temperature"),
        "salinity"    : (dataset_phy, "salinity"),
        "chlorophyll" : (dataset_bio, "chlorophyll"),
        "kshort"      : (dataset_bio, "k_short"),
        "k_short"     : (dataset_bio, "k_short"),
        "klong"       : (dataset_bio, "k_long"),
        "k_long"      : (dataset_bio, "k_long"),
    }

    # Failing before any field is loaded
    unknown = [inp for inp in inputs if inp not in sources]
    if unknown:
        raise ValueError(f"Unknown input variable(s) {unknown}, expected some of {sorted(sources)}")

    input_datasets = list()
    for inp in inputs:
        dataset, variable = sources[inp]
        input_datasets.append(dataset.get_data(variable = variable, type = "surface", depth = None))
    return input_datasets


def _train_one_epoch(neural_net, dataset_train, criterion, optimizer, device, max_batches = None):
    """Run one optimization pass over ``dataset_train``.

    Every batch loss is printed and logged to WandB under "Loss (T)". Only the samples
    where the target is not -1 (the masked region) contribute to the loss.

    Parameters
    ----------
    max_batches : int or None
        Maximum number of batches to process; ``None`` processes the whole dataloader.

    Returns
    -------
    float
        Training loss averaged over the processed batches (NaN if none was processed).
    """
    neural_net.train()
    training_loss, batch_steps = 0.0, 0

    for x, y in islice(dataset_train, max_batches):

        # Moving data to the correct device
        x, y = to_device(x, device), to_device(y, device)

        # Forward pass (calling the module instead of .forward() so that hooks run)
        pred = neural_net(x)

        # Valid samples only, i.e. inside the observed region (-1 is the masked region)
        indices    = torch.where(y != -1)
        loss       = criterion(pred[indices], y[indices])
        loss_value = loss.detach().item()

        # Information over terminal and WandB
        print("Loss (T) = ", loss_value)
        wandb.log({"Loss (T)": loss_value})
        training_loss += loss_value

        # Gradient step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_steps += 1

    return _mean(training_loss, batch_steps)


def _validate_one_epoch(neural_net, dataset_validation, criterion, metrics_tool, device, max_batches = None):
    """Evaluate the network on ``dataset_validation`` without gradients.

    Every batch loss is printed and logged to WandB under "Loss (V)". The batches are also
    given to ``metrics_tool``, and the plots are computed on the first batch only.

    Parameters
    ----------
    max_batches : int or None
        Maximum number of batches to process; ``None`` processes the whole dataloader.

    Returns
    -------
    float
        Validation loss averaged over the processed batches (NaN if none was processed).
    """
    neural_net.eval()
    validation_loss, batch_steps = 0.0, 0

    with torch.no_grad():
        for x, y in islice(dataset_validation, max_batches):

            # Moving data to the correct device
            x, y = to_device(x, device), to_device(y, device)

            # Forward pass
            pred = neural_net(x)

            # Valid samples only, i.e. inside the observed region (-1 is the masked region)
            indices    = torch.where(y != -1)
            loss_value = criterion(pred[indices], y[indices]).item()

            # Information over terminal and WandB
            print("Loss (V) = ", loss_value)
            wandb.log({"Loss (V)": loss_value})
            validation_loss += loss_value

            # Metrics on every batch, visual inspection on the first one only
            metrics_tool.compute_metrics(y_pred = pred.cpu(), y_true = y.cpu())
            if batch_steps == 0:
                metrics_tool.compute_plots(y_pred = pred.cpu(), y_true = y.cpu())

            batch_steps += 1

    return _mean(validation_loss, batch_steps)


def _log_metrics_and_plots(metrics_tool):
    """Send every metric (one entry per output day, named "<metric> D(<day>)") and every plot to WandB."""
    results      = metrics_tool.get_results()
    results_name = metrics_tool.get_names_metrics()

    for day, day_results in enumerate(results):
        for i, result in enumerate(day_results):
            wandb.log({f"{results_name[i]} D({day})" : result})

    for p_info in metrics_tool.get_plots():
        p_fig, p_nam = p_info[0], p_info[1]
        wandb.log({p_nam : wandb.Image(p_fig)})


def main(*, max_batches = None, **kwargs):
    """Train and validate a neural network for one configuration, logging to WandB.

    Parameters
    ----------
    max_batches : int or None, optional (keyword-only)
        Maximum number of training and validation batches used per epoch, e.g. ``1`` for a
        quick debugging run. ``None`` (default) iterates over the full dataloaders.
    **kwargs
        One configuration of the grid (an element of ``param_dicts``) with the keys
        'month_start', 'month_end', 'year_start', 'year_end', 'Inputs', 'Problem',
        'Window (Inputs)', 'Window (Output)', 'Depth', 'Architecture', 'Learning Rate',
        'Kernel Size', 'Batch Size' and 'Epochs'. The dictionary is also used as the
        WandB run configuration.

    Raises
    ------
    KeyError
        If a configuration key is missing.
    ValueError
        If ``max_batches`` is smaller than 1 or an entry of 'Inputs' is unknown.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError(f"max_batches must be None or >= 1, got {max_batches}")

    # ------------------------------------------
    #               Initialization
    # ------------------------------------------
    #
    # ------- Arguments -------
    start_month     = kwargs['month_start']
    end_month       = kwargs['month_end']
    start_year      = kwargs['year_start']
    end_year        = kwargs['year_end']
    inputs          = kwargs['Inputs']
    problem         = kwargs['Problem']
    windows_inputs  = kwargs['Window (Inputs)']
    windows_outputs = kwargs['Window (Output)']
    depth           = kwargs['Depth']
    architecture    = kwargs['Architecture']  # NOTE(review): read but unused, FCNN is always built
    learning_rate   = kwargs['Learning Rate']
    kernel_size     = kwargs['Kernel Size']
    batch_size      = kwargs['Batch Size']
    nb_epochs       = kwargs['Epochs']

    # ------- Data -------
    Dataset_phy = BlackSea_Dataset(year_start = start_year, year_end = end_year, month_start = start_month, month_end = end_month, variable = "grid_T")
    Dataset_bio = BlackSea_Dataset(year_start = start_year, year_end = end_year, month_start = start_month, month_end = end_month, variable = "ptrc_T")

    # Inputs (surface fields) and output (bottom oxygen, restricted by `depth`)
    input_datasets = _load_inputs(Dataset_phy, Dataset_bio, inputs)
    data_oxygen    = Dataset_bio.get_data(variable = "oxygen", type = "bottom", depth = depth)

    # Black Sea masks: without depth restriction and restricted by `depth`
    bs_mask            = Dataset_phy.get_mask(depth = None)
    bs_mask_with_depth = Dataset_phy.get_mask(depth = depth)

    # ------- Preprocessing -------
    BSD_loader = BlackSea_Dataloader(x                  = input_datasets,
                                     y                  = data_oxygen,
                                     bs_mask            = bs_mask,
                                     bs_mask_with_depth = bs_mask_with_depth,
                                     mode               = problem,
                                     window_inp         = windows_inputs,
                                     window_out         = windows_outputs,
                                     hypoxia_treshold   = 63,
                                     datasets_size      = [0.6, 0.3],
                                     seed               = 2701)

    # Retrieving the individual dataloaders (the test one is not used in this function)
    dataset_train      = BSD_loader.get_dataloader("train",      batch_size = batch_size)
    dataset_validation = BSD_loader.get_dataloader("validation", batch_size = batch_size)
    dataset_test       = BSD_loader.get_dataloader("test",       batch_size = batch_size)

    # Normalized oxygen treshold
    norm_oxy = BSD_loader.get_normalized_deoxygenation_treshold()

    # ------------------------------------------
    #                   Training
    # ------------------------------------------
    wandb.init(project = "esa-blacksea-deoxygenation-emulator-V3", config = kwargs)

    # Making sure the WandB run is closed even if training fails
    try:
        # Using the GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        neural_net = FCNN(inputs = len(input_datasets), outputs = windows_outputs, problem = problem, kernel_size = kernel_size)
        neural_net.to(device)

        optimizer = optim.Adam(neural_net.parameters(), lr = learning_rate)
        criterion = nn.MSELoss() if problem == "regression" else nn.BCELoss()

        # Number of validation batches actually processed per epoch, given to the metrics tool
        # (presumably used to average the metrics over the batches — TODO confirm)
        nb_validation_batches = len(dataset_validation)
        if max_batches is not None:
            nb_validation_batches = min(nb_validation_batches, max_batches)

        for epoch in range(nb_epochs):

            # Information over terminal
            print("\n" if epoch == 0 else "")
            print("Epoch : ", epoch + 1, "/", nb_epochs, "\n")

            # Fresh metrics accumulator for this epoch
            metrics_tool = BlackSea_Metrics(mode              = problem,
                                            mask              = bs_mask_with_depth,
                                            treshold          = norm_oxy,
                                            number_of_batches = nb_validation_batches)

            # ----- TRAINING -----
            training_loss = _train_one_epoch(neural_net, dataset_train, criterion, optimizer, device, max_batches)
            print("Loss (Training, Averaged over batch): ", training_loss)
            wandb.log({"Loss (T, AOB): ": training_loss})

            # ----- VALIDATION -----
            validation_loss = _validate_one_epoch(neural_net, dataset_validation, criterion, metrics_tool, device, max_batches)
            print("Loss (Validation, Averaged over batch): ", validation_loss)
            wandb.log({"Loss (V, AOB): ": validation_loss})

            # Number of epochs left (including the current one)
            wandb.log({"Epochs : ": nb_epochs - epoch})

            # ---------- WandB (Metrics & Plots) ----------
            _log_metrics_and_plots(metrics_tool)

    finally:
        wandb.finish()

##############################################################################################################
# Quick debugging run of the second configuration: one training and one validation batch per epoch
main(max_batches = 1, **param_dicts[1])

<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>Testing</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
#       TESTING
# --------------------
#
# Visual sanity check of BlackSea_Dataloader on fake data.
#
# NOTE(review): this cell calls BlackSea_Dataloader with `mask`, `resolution` and `window` arguments and
# the modes "spatial"/"temporal". The other cells of this notebook pass `bs_mask`, `bs_mask_with_depth`,
# `window_inp` and `window_out` with the modes "regression"/"classification". This cell presumably targets
# an older dataloader API — confirm it still runs before relying on it.
#
# Number of input variables
var_inputs = 2

# Window for the input
win_in = 2

# Window for the oxygen (output)
win_out = 3

# Number of days (fake samples) to generate
days = 14

# --------------------
#      GENERATING
# --------------------
# Generating fake data (generateFakeDataset is presumably provided by the star import from tools)
fake_data_physical_variables = generateFakeDataset(resolution = 128, number_of_variables = var_inputs, number_of_samples=days)
fake_data_physical_oxygen    = generateFakeDataset(resolution = 128, number_of_variables = 1, number_of_samples=days, oxygen = True)

# Creation of the dataloaders (same fake data, "spatial" and "temporal" modes)
BSD_loader_fake_spatial = BlackSea_Dataloader(x = fake_data_physical_variables,
                                              y = fake_data_physical_oxygen[0],
                                           mask = np.ones(shape = (258, 258)),
                                           mode = "spatial",
                                     resolution = 128,
                                         window = win_in,
                                     window_out = win_out)

BSD_loader_fake_temporal = BlackSea_Dataloader(x = fake_data_physical_variables,
                                               y = fake_data_physical_oxygen[0],
                                            mask = np.ones(shape = (258, 258)),
                                            mode = "temporal",
                                      resolution = 128,
                                          window = win_in,
                                      window_out = win_out)


# Visual check of the fake input variables (disabled: kept inside a string literal)
"""
for i in range(fake_data_physical_variables[0].shape[0]):

    plt.figure(figsize=(5, 5))
    plt.imshow(fake_data_physical_variables[1][i, :, :])
"""

# --------------------------------------------------------------------------------
#                                       SPATIAL
# --------------------------------------------------------------------------------
# (Disabled: the spatial checks below are kept inside a string literal)
"""
for x, y in BSD_loader_fake_spatial.get_dataloader("train"):

    # Initial shapes
    print("Input shape: ", len(fake_data_physical_variables), fake_data_physical_variables[0].shape, "\nOutput shape: ", fake_data_physical_oxygen[0].shape)

    # Shapes (dataloader)
    print("Input shape: ", x.shape, "\nOutput shape: ", y.shape)

    # The total number of samples is
    # [Number of timesteps - number of input days (window) - number of output days (window_outgen) ] * number of regions
    #
    # Tests
    #
    # Number of variables
    assert x.shape[1] == var_inputs * win_in

    # Number of outputs
    assert y.shape[1] == win_out

    # Number of samples (must be divided by 2 for the validation and test)
    assert x.shape[0] == (fake_data_physical_variables[0].shape[0] - win_in - win_out) * int(256/128)

    # Checking that I have all the timesteps
    #
    # Looping over all the time steps
    for i in range(x.shape[0]):

        # Showing as a subplots the input and output pairs
        plt.figure(figsize=(14, 14))

        for j in range(var_inputs * win_in):
            plt.subplot(1, var_inputs * win_in + win_out, j+1)

            # Removing labels and tickz
            plt.xticks([])
            plt.yticks([])
            plt.grid(False)
            plt.imshow(x[i, j, :, :])

        for j in range(win_out):
            plt.subplot(1, var_inputs * win_in + win_out, var_inputs * win_in + j+1)
            # Removing labels and tickz
            plt.xticks([])
            plt.yticks([])
            plt.grid(False)

            plt.imshow(y[i, j, :, :])

    # Idea check the validity using a trehsold to create a 1 0 matrix for comparison (otherwise it will  bug since you have normalized the data)

"""
# --------------------------------------------------------------------------------
#                                       TEMPORAL
# --------------------------------------------------------------------------------
for x, y in BSD_loader_fake_temporal.get_dataloader("train"):

    # Initial shapes (fake data before the dataloader)
    print("Input shape: ", len(fake_data_physical_variables), fake_data_physical_variables[0].shape, "\nOutput shape: ", fake_data_physical_oxygen[0].shape)

    # Shapes (dataloader batch)
    print("Input shape: ", x.shape, "\nOutput shape: ", y.shape)

    # For every index along the first dimension of x (presumably the samples of the batch), plot its
    # var_inputs * win_in input channels followed by its win_out output channels on a single row.
    # NOTE: one new figure is opened per sample.
    for i in range(x.shape[0]):

        # One row of subplots: the inputs, then the outputs
        plt.figure(figsize=(14, 14))

        for j in range(var_inputs * win_in):
            plt.subplot(1, var_inputs * win_in + win_out, j+1)

            # Removing labels and ticks
            plt.xticks([])
            plt.yticks([])
            plt.grid(False)
            plt.imshow(x[i, j, :, :])

        for j in range(win_out):
            plt.subplot(1, var_inputs * win_in + win_out, var_inputs * win_in + j+1)
            # Removing labels and ticks
            plt.xticks([])
            plt.yticks([])
            plt.grid(False)

            plt.imshow(y[i, j, :, :])
