<img src="../assets/header_notebook.png" />
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>ESA - Black Sea Deoxygenation Emulator</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
# ---------
# Libraries
# ---------
import os
import sys
import xarray
import random
import dawgz
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Pytorch
import torch
import torch.nn as nn
import torch.optim as optim

# Plots
from matplotlib.animation import FuncAnimation

# Dawgz (running jobs in parallel)
from dawgz import job, after, ensure, schedule

# ------------------
# Libraries (Custom)
# ------------------
# Adding the source and scripts folders to the module search path so that the custom modules can be imported.
# NOTE(review): these paths are relative to the working directory, which is changed by the %cd at the end of this cell.
sys.path.insert(1, '../src/')
sys.path.insert(1, '../scripts/')

# Loading the custom dataset handlers
from dataset              import BlackSea_Dataset
from dataset_evolution    import BlackSea_Dataset_Evolution
from dataset_distribution import BlackSea_Dataset_Distribution

# -------
# Jupyter
# -------
# Rendering the figures inside the notebook and using a larger default font
%matplotlib inline
plt.rcParams.update({'font.size': 13})

# Making sure modules are reloaded when modified
%reload_ext autoreload
%autoreload 2

# Moving to the source directory (the scripts launched with %run in the next cells are resolved from there)
%cd ../src/

<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>Scripts</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
# Analyzing the data (1): running script_distribution.py on the time window given below (year 0, months 1 to 2).
# NOTE(review): --dawgz presumably toggles the submission of the work through the dawgz scheduler - confirm in the script.
%run script_distribution.py --start_year        0 \
                            --end_year          0 \
                            --start_month       1 \
                            --end_month         2 \
                            --dawgz         False

In [None]:
# Analyzing the data (2): running script_evolution.py on the time window given below (year 0, months 1 to 2).
# NOTE(review): --dawgz presumably toggles the submission of the work through the dawgz scheduler - confirm in the script.
%run script_evolution.py --start_year        0 \
                         --end_year          0 \
                         --start_month       1 \
                         --end_month         2 \
                         --dawgz         False

<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>Playground</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
# -----------------
#    Parameters
# -----------------
#
# Time window covered by the dataset (months and years, both bounds are given to the handlers)
month_starting, month_ending = 1, 2
year_starting,  year_ending  = 0, 0

# Deepest level [m] at which oxygen is observed, everything else is masked (use ~120m to keep only the continental shelf)
depth_max_oxygen = None

# ------------------
#  Loading the data
# ------------------
# One handler per type of file: physical variables ("grid_T") and biogeochemical variables ("ptrc_T")
Dataset_physical = BlackSea_Dataset(
    year_start  = year_starting,
    year_end    = year_ending,
    month_start = month_starting,
    month_end   = month_ending,
    variable    = "grid_T",
)
Dataset_bio = BlackSea_Dataset(
    year_start  = year_starting,
    year_end    = year_ending,
    month_start = month_starting,
    month_end   = month_ending,
    variable    = "ptrc_T",
)

# Physical fields
data_temperature = Dataset_physical.get_temperature()
data_salinity    = Dataset_physical.get_salinity()

# Biogeochemical fields (the oxygen is the one observed at the bottom, limited to the chosen depth)
data_oxygen      = Dataset_bio.get_oxygen_bottom(depth = depth_max_oxygen)
data_chlorophyll = Dataset_bio.get_chlorophyll()
data_kshort      = Dataset_bio.get_light_attenuation_coefficient_short_waves()
data_klong       = Dataset_bio.get_light_attenuation_coefficient_long_waves()

# Mask of the Black Sea, retrieved from the physical handler
BS_mask = Dataset_physical.get_blacksea_mask()

In [None]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class BlackSea_Dataloader(Dataset):
    r"""A simple dataloader for Black Sea dataset.

    The input fields and the target field are stacked, the land is masked with -1, the grid is cut into
    square patches and time series of length `window` are paired with the target observed right after them.
    The samples are finally split into training, validation and test sets (either along time or along the
    patches) and each set is standardized independently.

    Arguments:
        x             : list of input fields, each one of shape (t, x, y)
        y             : target field of shape (t, x, y)
        bs_mask       : 2D mask of shape (x, y), the positions where it is equal to 0 are treated as land
        mode          : "temporal" (split along time) or "spatial" (split along randomly shuffled patches)
        resolution    : side of the square patches (must be even and smaller than half the cropped x-resolution)
        window        : number of consecutive time steps stacked on the channel dimension of each input
        datasets_size : fractions of the data used for the training and validation sets (the rest is the test set)
        seed          : seed of the random generator shuffling the patches in "spatial" mode

    Attributes:
        x_train, x_validation, x_test : inputs of shape (samples, variables * window, resolution, resolution)
        y_train, y_validation, y_test : targets of shape (samples, 1, resolution, resolution)

    Raises:
        AssertionError : if mode, resolution or window do not respect the constraints listed above
    """

    def __init__(self, x: list, y: np.ndarray, bs_mask, mode: str, resolution: int,  window: int = 1, datasets_size = (0.5, 0.3), seed: int = 42):

        # Concatenation of the inputs and output (t, x, y) into (t, variable, x, y), i.e. variable = 0 is the target
        data = np.stack([y] + list(x), axis = 1)

        # Masking the land with a fixed value (-1)
        data[:, :, bs_mask == 0] = -1

        # Removing the last 2 rows/columns ONCE to ease the patch extraction, e.g. (258, 578) becomes (256, 576)
        data = data[:, :, :-2, :-2]

        # Current shape of the data
        t, v, x_res, y_res = data.shape

        # Security
        assert mode in ["spatial", "temporal"], f"ERROR (BlackSea_Dataloader) Mode must be either 'spatial', 'temporal' ({mode})"
        assert resolution % 2 == 0,             f"ERROR (BlackSea_Dataloader) Resolution must be a multiple of 2 ({resolution})"
        assert resolution < x_res/2,            f"ERROR (BlackSea_Dataloader) Resolution must be smaller than half the input resolution ({resolution} < {x_res/2})"
        assert window <= int(t/3 - 1),          f"ERROR (BlackSea_Dataloader) Window must be smaller than a third of the input time resolution ({window} < {int(t/3 - 1)})"
        assert window >= 1,                     f"ERROR (BlackSea_Dataloader) Window must be at least 1 ({window})"

        # Number of patches along the x-, y- dimensions and total number of possible patches
        nb_patches_x, nb_patches_y = x_res // resolution, y_res // resolution
        total_patches              = nb_patches_x * nb_patches_y

        # Extracting the patches, i.e. (t, variable, x, y) to (t, variable, number of patches, resolution, resolution)
        data = np.stack([data[:, :, i * resolution : (i + 1) * resolution, j * resolution : (j + 1) * resolution]
                         for i in range(nb_patches_x) for j in range(nb_patches_y)], axis = 2)

        # Separation of the x and y data (to avoid a mess with timeseries)
        y = data[:, 0]
        x = data[:, 1:]

        # Input - Creation of the time series, i.e. (index, variable(s)_{t, t + window}, number of patches, resolution, resolution)
        x = np.stack([x[i : i + window] for i in range(t - window)], axis = 0).reshape(t - window, (v - 1) * window, total_patches, resolution, resolution)

        # Output - Only the value following each time series is kept, i.e. (index, number of patches, resolution, resolution)
        y = y[window:].copy()

        # ------------------------------------------------
        #                   TEMPORAL MODE
        # ------------------------------------------------
        if mode == "temporal":

            # Computing size of the training, validation and test sets
            training_size, validation_size = int(t * datasets_size[0]), int(t * datasets_size[1])

            # The last `window` samples of the training/validation sets are dropped to avoid overlapping timeseries
            training_end   = max(training_size - window, 0)
            validation_end = max(training_size + validation_size - window, training_size)
            test_start     = training_size + validation_size

            # Splitting the dataset into training, validation and test sets
            x_train, x_validation, x_test = x[:training_end], x[training_size:validation_end], x[test_start:]
            y_train, y_validation, y_test = y[:training_end], y[training_size:validation_end], y[test_start:]

        # ------------------------------------------------
        #                   SPATIAL MODE
        # ------------------------------------------------
        else:

            # Computing size of the training, validation and test sets
            training_size, validation_size = int(total_patches * datasets_size[0]), int(total_patches * datasets_size[1])
            test_start                     = training_size + validation_size

            # Used to randomly permute the patches (seeded to make the split reproducible)
            rand_patches = np.random.default_rng(seed).permutation(total_patches)

            # Randomly shuffling along the patches axis
            x = x[:, :, rand_patches]
            y = y[:,    rand_patches]

            # Splitting the dataset into training, validation and test sets
            x_train, x_validation, x_test = x[:, :, :training_size], x[:, :, training_size:test_start], x[:, :, test_start:]
            y_train, y_validation, y_test = y[:,    :training_size], y[:,    training_size:test_start], y[:,    test_start:]

        # ------------------------------------------------
        #                      PROCESSING
        # ------------------------------------------------
        # Merging the time and patches dimensions, then standardizing each set independently
        self.x_train      = self._process_data(self._merge_timeseries_and_patches(x_train))
        self.x_validation = self._process_data(self._merge_timeseries_and_patches(x_validation))
        self.x_test       = self._process_data(self._merge_timeseries_and_patches(x_test))
        self.y_train      = self._process_data(self._merge_timeseries_and_patches(y_train))
        self.y_validation = self._process_data(self._merge_timeseries_and_patches(y_validation))
        self.y_test       = self._process_data(self._merge_timeseries_and_patches(y_test))

    @staticmethod
    def _merge_timeseries_and_patches(data: np.ndarray):
        r"""Merges the time and patches dimensions into a single samples dimension, i.e. returns (samples, channels, r, r)"""

        # If dimensions is 5, the input has been given
        if data.ndim == 5:

            # Swapping axes (needed to concatenate along the patch dimension), i.e. (t, v, p, r, r) to (t, p, v, r, r)
            data = np.swapaxes(data, 1, 2)

            # Retrieving all the dimensions for reshaping
            t, p, v, res_x, res_y = data.shape

            # Merging timeseries and patches dimensions
            return data.reshape(t * p, v, res_x, res_y)

        # Otherwise the output has been given, i.e. (t, p, r, r)
        t, p, res_x, res_y = data.shape

        # Merging timeseries and patches dimensions and adding a dummy dimension to simulate channels
        return np.expand_dims(data.reshape(t * p, res_x, res_y), axis = 1)

    @staticmethod
    def _process_data(data: np.ndarray):
        r"""Standardizes (in place) the values in the ocean and replaces the NaNs by -1"""

        # Ocean = everything that is neither land (-1) nor missing (NaN), NaNs are excluded to not corrupt the statistics
        ocean = (data != -1) & ~np.isnan(data)

        # Standardizing the values in the ocean (nothing to do if the set is empty or made only of land)
        # NOTE(review): a single mean/std is shared by all the channels, as in the original implementation
        if ocean.any():
            values   = data[ocean]
            centered = values - np.mean(values)
            spread   = np.std(values)

            # A constant field has no spread, it is simply centered to avoid a division by zero
            data[ocean] = centered / spread if spread > 0 else centered

        # Replacing NaNs by -1
        data[np.isnan(data)] = -1

        return data

    def __getitem__(self, index, train = True):
        r"""Returns the (input, target) pair at the given index from the training (default) or validation set"""
        return (self.x_train[index, :, :, :], self.y_train[index, :, :, :]) if train else (self.x_validation[index, :, :, :], self.y_validation[index, :, :, :])

    def __len__(self, train = True):
        r"""Returns the number of samples in the training (default) or validation set"""
        return self.x_train.shape[0] if train else self.x_validation.shape[0]


# ----------------------------------------------------------------

# Instantiating the dataloader on two input fields, with patches split spatially and timeseries of 10 steps
BS_dataset = BlackSea_Dataloader(
    x          = [data_temperature, data_chlorophyll],
    y          = data_oxygen,
    bs_mask    = BS_mask,
    mode       = "spatial",
    resolution = 64,
    window     = 10,
)


In [None]:
class FullyCNN(nn.Sequential):
    r"""A fully convolutional network preserving the spatial resolution of its input.

    Seven (convolution, ReLU, batch normalization) blocks are followed by a final convolution which
    maps the features to the requested number of output channels. Every convolution is padded so
    that the height and width of the input are kept.

    Arguments:
        inputs  : number of channels of the input
        targets : number of channels of the output
    """

    def __init__(self, inputs, targets):

        #-----------------------------------------------------------------------------------
        #                                   Architecture
        #-----------------------------------------------------------------------------------
        # Hidden blocks described as (input channels, output channels, kernel size)
        hidden_blocks = [
            (inputs, 256, 5),
            (256,    128, 5),
            (128,     32, 3),
            (32,      32, 3),
            (32,      32, 3),
            (32,      32, 3),
            (32,      32, 3),
        ]

        # Building the blocks one after the other, a padding of half the kernel size keeps the resolution
        layers = []
        for channels_in, channels_out, kernel in hidden_blocks:
            convolution = nn.Conv2d(channels_in, channels_out, kernel, padding = kernel // 2)
            layers.extend(self._make_subblock(convolution))

        # Final projection to the output channels (no activation, no normalization)
        layers.append(nn.Conv2d(32, targets, 3, padding = 1))

        # Combining everything together
        super().__init__(*layers)

    def _make_subblock(self, conv):
        r"""Returns the given convolution followed by a ReLU and a batch normalization of its output channels"""
        activation    = nn.ReLU()
        normalization = nn.BatchNorm2d(conv.out_channels)
        return [conv, activation, normalization]

    #-----------------------------------------------------------------------------------
    #                                     Forward
    #-----------------------------------------------------------------------------------
    def forward(self, x):
        r"""Applies all the layers sequentially to x and returns the final prediction"""
        return nn.Sequential.forward(self, x)

    def count_parameters(self,):
        r"""Prints the number of trainable parameters of the model"""
        trainable = 0
        for parameter in self.parameters():
            if parameter.requires_grad:
                trainable += parameter.numel()
        print("Model parameters  =", trainable)

In [None]:
# Number of consecutive time steps stacked on the channel dimension of each input sample
wind = 1

# ----------------------
#        Dataset
# ----------------------
# Fields given as input to the network (the number of input channels is deduced from this list)
training_inputs = [data_temperature, data_chlorophyll, data_salinity, data_kshort, data_klong]

bs_load   = BlackSea_Dataloader(x = training_inputs, y = data_oxygen, bs_mask = BS_mask, mode = "temporal", resolution = 64, window = wind)
bs_loader = DataLoader(bs_load, batch_size = 64)

# The dataset only serves the training set through indexing, the validation set is thus wrapped separately
bs_validation = torch.utils.data.TensorDataset(torch.as_tensor(bs_load.x_validation), torch.as_tensor(bs_load.y_validation))
bs_validation_loader = DataLoader(bs_validation, batch_size = 64)

## NN
FCNN = FullyCNN(wind * len(training_inputs), 1)
criterion = nn.MSELoss()
optimizer = optim.Adam(FCNN.parameters(), lr = 0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode = 'min', factor = 0.5, patience = 5, threshold = 1e-2)

# Training loss of each batch and validation loss of each epoch (with the batch index at which it was measured)
train_losses = []
validation_losses = []
validation_steps = []

# Going through epochs
for epoch in range(15):

    # Training mode (batch normalization uses the statistics of the current batch)
    FCNN.train()

    for x, y in bs_loader:

        # Resetting the gradients
        optimizer.zero_grad()

        # Calling the module (and not .forward) so that registered hooks are run
        pred = FCNN(x)

        # Computing the loss
        loss = criterion(pred, y)

        # Adding the loss
        train_losses.append(loss.detach().item())

        # Backward pass
        loss.backward()

        # Optimizing the parameters
        optimizer.step()

        print(F"Train{epoch}:", train_losses[-1])

    # Evaluation mode (batch normalization uses its running statistics) and no gradient tracking
    FCNN.eval()
    validation_total, validation_samples = 0.0, 0
    with torch.no_grad():
        for x, y in bs_validation_loader:
            validation_total   += criterion(FCNN(x), y).item() * x.shape[0]
            validation_samples += x.shape[0]

    # Nothing to report (and nothing to schedule on) if the validation set is empty
    if validation_samples > 0:
        validation_losses.append(validation_total / validation_samples)
        validation_steps.append(len(train_losses))

        # Reducing the learning rate when the validation loss stops improving
        scheduler.step(validation_losses[-1])

        print(F"Validation{epoch}:", validation_losses[-1])


plt.figure()
plt.plot(train_losses, label = "Training (per batch)")
plt.plot(validation_steps, validation_losses, label = "Validation (per epoch)")
plt.legend()