<img src="../assets/notebook.png" />
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>Generative Models For Spatial-Based Bottom <br><br> Hypoxia Forecasting In The Black Sea</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

Do not forget to activate the **esa_diffusion** environment.

In [None]:
# ---------- Librairies ----------
import os
import sys
import dawgz
import wandb
import xarray
import random
import numpy             as np
import pandas            as pd
import seaborn           as sns
import matplotlib.pyplot as plt

# Specific
from tqdm import tqdm
from einops import rearrange, reduce, repeat

# Pytorch
import torch
import torch.nn as nn
import torch.optim as optim

# Dawgz (jobs //)
from dawgz import job, schedule

# ---------- Libraries (Custom) ----------
#
# Add the project's source folders to the module search path so the custom modules can be imported.
# NOTE(review): these entries are absolute paths ('/src', '/src/debs/', '/scripts/') while the %cd
# below is relative ('src/debs/'); confirm which folder layout is intended.
# '/src/debs/' is added twice (append + insert); the inserted entries are searched before the appended ones.
sys.path.append('/src')
sys.path.append('/src/debs/')
sys.path.insert(1, '/src/debs/')
sys.path.insert(1, '/scripts/')

# Move into the folder holding the .py modules (relative to the notebook's current working directory)
%cd src/debs/

## Loading the project libraries (wildcard imports; the notebook relies on them for names such as
## BlackSea_Dataset, BlackSea_Dataloader and project_title — TODO confirm which module defines each)
from dataset     import *
from dataloader  import *
from training    import *
from metrics     import *
from tools       import *
from unet        import *

# ---------- Jupyter ----------
%matplotlib inline
plt.rcParams.update({'font.size': 13})

# Reload the project modules automatically when their source files are modified
%reload_ext autoreload
%autoreload 2

<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>Scripts</center>
    </b>
</p>
<hr style="color:#5A7D9F;">


The configurations are available in the **configs** folder.

In [None]:
# Train a neural network by running the training script with the "local" configuration.
# The path is resolved from the current working directory (changed by %cd in the setup cell);
# presumably the "local" configuration lives in the configs folder — TODO confirm.
%run __training.py --config local

<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>Playground</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
# Custom parameters of this playground run (unpacked one by one in the initialization cell)
kwargs = {
    "Project":                "ESA - Notebook (Diffusion)",
    "Mode":                   "disabled",
    "Window (Inputs)":        1,
    "Window (Outputs)":       10,
    "Diffusion Steps":        5,
    "Diffusion Scheduler":    0.0125,
    "Diffusion Variance":     0.005,
    "Frequencies":            32,
    "Scaling":                1,
    "Learning Rate":          0.0001,
    "Batch Size":             16,
    "Epochs":                 1,
    "Number of Workers":      2,
    "Results (Epoch)":        1,
    "Results (Trajectories)": 10,
    "Model Saving":           False,
}

In [None]:
# -----------------------
#     Initialization
# -----------------------
#
# Information over terminal (1)
project_title(kwargs)

# Checking if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixing every random seed for reproducibility: Python's `random` (imported by this notebook),
# NumPy and PyTorch (torch.manual_seed also seeds all CUDA devices)
seed = 2701
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Loading configuration
project        = kwargs['Project']
mode           = kwargs['Mode']
window_input   = kwargs['Window (Inputs)']
window_output  = kwargs['Window (Outputs)']
diff_steps     = kwargs['Diffusion Steps']
diff_scheduler = kwargs['Diffusion Scheduler']
diff_variance  = kwargs['Diffusion Variance']
scaling        = kwargs['Scaling']
frequencies    = kwargs['Frequencies']
learning_rate  = kwargs['Learning Rate']
batch_size     = kwargs['Batch Size']
nb_epochs      = kwargs['Epochs']
num_workers    = kwargs['Number of Workers']
model_saving   = kwargs['Model Saving']
results_epoch  = kwargs['Results (Epoch)']
results_traj   = kwargs['Results (Trajectories)']

# Validation batch size: the metrics only use the first validation batch, described in the
# training loop as "one year of data", and plot it as 12 rows (one per month)
validation_batch_size = 12

# -----------------------
#          Data
# -----------------------
#
# Loading preprocessed datasets
# NOTE(review): the "Test" split is used for training AND validation here (playground shortcut?);
# switch to the proper splits before drawing any conclusion from the metrics.
dataset_train      = BlackSea_Dataset("Test")
# dataset_validation = BlackSea_Dataset("Test")

# Loading other information
black_sea_mesh       = dataset_train.get_mesh()
black_sea_mask       = dataset_train.get_mask(continental_shelf = False)
black_sea_mask_cs    = dataset_train.get_mask(continental_shelf = True)
black_sea_bathymetry = dataset_train.get_depth(unit = "meter")

# Used to detect the presence of hypoxia events
hypoxia_treshold_standardized = dataset_train.get_treshold(standardized = True)

# Creation of the dataloaders
# NOTE(review): training uses BlackSea_Dataloader while validation uses BlackSea_Dataloader_Diffusion,
# yet both are unpacked as (conditioning, _, x) in the training loop — confirm they yield the same layout.
dataloader_train = BlackSea_Dataloader(dataset_train,
                                       window_input,
                                       window_output,
                                       frequencies,
                                       batch_size,
                                       num_workers,
                                       black_sea_mesh,
                                       black_sea_mask,
                                       black_sea_mask_cs,
                                       black_sea_bathymetry,
                                       random = True).get_dataloader()

dataloader_valid = BlackSea_Dataloader_Diffusion(dataset_train,
                                                 window_input,
                                                 window_output,
                                                 frequencies,
                                                 validation_batch_size,
                                                 num_workers,
                                                 black_sea_mesh,
                                                 black_sea_mask,
                                                 black_sea_mask_cs,
                                                 black_sea_bathymetry,
                                                 random = False).get_dataloader()

In [None]:
# Architecture

def time_encoding(time: torch.Tensor, frequencies: int = 128):
    r"""Encode time values with the sinusoidal scheme of "Attention is all you need".

    Every scalar ``t`` is mapped to ``2 * frequencies`` features, with sine and cosine
    interleaved for each frequency index ``i`` in ``[0, frequencies)``::

        [sin(t / 10000**(0/F)), cos(t / 10000**(0/F)), sin(t / 10000**(1/F)), cos(t / 10000**(1/F)), ...]

    Args:
        time: tensor of time values. Any shape is accepted (callers use 2-D ``(dim0, dim1)``).
        frequencies: number of frequencies ``F``.

    Returns:
        A ``float32`` tensor of shape ``time.shape + (2 * frequencies,)`` on the same device
        as ``time``. It is computed under ``torch.no_grad()``, so no gradient flows through it.
    """
    with torch.no_grad():

        # Divisors 10000 ** (i / F), i = 0 .. F-1 (exponent computed in float64, like Python floats)
        exponents = torch.arange(frequencies, dtype = torch.float64) / frequencies
        divisors  = (10000.0 ** exponents).to(device = time.device, dtype = torch.float32)

        # (..., F): every time value divided by every divisor, in one vectorized operation
        angles = time.to(torch.float32).unsqueeze(-1) / divisors

        # Interleave sine and cosine: (..., F, 2) -> (..., 2F)
        encoded_time = torch.stack((torch.sin(angles), torch.cos(angles)), dim = -1)
        return encoded_time.flatten(start_dim = -2)

class LayerNormalization(nn.Module):
    r"""Parameter-free normalization to zero mean and unit variance along one dimension.

    Unlike ``torch.nn.LayerNorm`` there is no learnable affine transform, and the variance
    comes from ``torch.var_mean`` with its default (unbiased) correction.

    Args:
        dim: dimension along which the statistics are computed.
        eps: constant added to the variance before the square root.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor):
        # Statistics kept broadcastable against x
        variance, average = torch.var_mean(x, dim = self.dim, keepdim = True)
        centered = x - average
        spread   = torch.sqrt(variance + self.eps)
        return centered / spread

class TimeResidual_Block(nn.Module):
    r"""Time-conditioned residual convolutional block used by the UNET.

    A projection of the time encoding is added to every pixel of the input. The sum goes
    through normalization -> 3x3 conv -> SiLU -> 3x3 conv, and the result is added back to
    the input. The output is divided by sqrt(2).

    Args:
        input_channels: number of channels of the feature map (kept unchanged by the block).
        frequencies: number of frequencies of the time encoding, whose size is ``2 * frequencies``.
    """

    def __init__(self, input_channels: int, frequencies: int):
        super().__init__()

        self.frequencies   = frequencies
        self.activation    = nn.SiLU()
        self.normalization = LayerNormalization(dim = 1)

        # sqrt(2): divisor applied after the skip connection (attribute name kept for compatibility)
        self.variance = torch.tensor(2).sqrt()

        # Maps the time encoding (2 * frequencies features) to one value per channel
        self.time_projection = nn.Linear(self.frequencies * 2, input_channels, bias = False)

        # Two channel-preserving 3x3 convolutions with "same" padding
        conv_options = dict(kernel_size = 3, stride = 1, padding = 1)
        self.conv1 = nn.Conv2d(input_channels, input_channels, **conv_options)
        self.conv2 = nn.Conv2d(input_channels, input_channels, **conv_options)

    def forward(self, x, time):
        # Spatial size over which the time bias is tiled
        _, _, height, width = x.shape

        # Time conditioning: (B, 2F) -> (B, C) -> (B, C, H, W)
        time_bias = self.activation(self.time_projection(time))
        time_bias = time_bias[:, :, None, None].expand(-1, -1, height, width)

        # Residual branch: norm -> conv -> SiLU -> conv
        branch = self.conv1(self.normalization(x + time_bias))
        branch = self.conv2(self.activation(branch))

        # Skip connection, rescaled by sqrt(2)
        return (x + branch) / self.variance

    def count_parameters(self,):
        r"""Number of trainable parameters of the block."""
        trainable = (param.numel() for param in self.parameters() if param.requires_grad)
        return int(sum(trainable))

class TimeResidual_UNET(nn.Module):
    r"""A time residual UNET for time series forecasting.

    Encoder: a 3x3 lifting convolution, then four resolution levels (32, 64, 128 and 256
    channels, each times ``scaling``). Each level has two ``TimeResidual_Block``; the levels
    are separated by stride-2 convolutions.
    Decoder: three nearest-neighbour x2 upsamplings. Each one is followed by a concatenation
    with the matching encoder features, a 3x3 projection and two ``TimeResidual_Block``.
    Output: a bias-free linear layer mixing channels pixel by pixel.

    Args:
        input_channels: channels of the input tensor ``x``.
        output_channels: channels of the output tensor.
        frequencies: number of frequencies of the time encoding (``2 * frequencies`` features).
        scaling: width multiplier applied to every hidden layer.

    ``forward(x, time)`` takes ``x`` of shape (B, input_channels, H, W), with H and W divisible
    by 8, and the time encoding used by the residual blocks. It returns (B, output_channels, H, W).
    """

    def __init__(self, input_channels: int, output_channels: int, frequencies: int, scaling: int = 1):
        super(TimeResidual_UNET, self).__init__()

        # Initializations
        self.frequencies      = frequencies
        self.input_channels   = input_channels
        self.output_channels  = output_channels
        self.scaling          = scaling

        # 1. Input (lifting)
        self.input_conv = nn.Conv2d(in_channels = self.input_channels, out_channels = 32 * scaling, kernel_size = 3, stride = 1, padding = 1)

        # 2. Downsampling
        #
        # Time Residual Blocks (1)
        self.downsample_11_residuals = TimeResidual_Block(input_channels  = 32 * scaling,     frequencies = self.frequencies)
        self.downsample_12_residuals = TimeResidual_Block(input_channels  = 32 * scaling,     frequencies = self.frequencies)
        self.downsample_21_residuals = TimeResidual_Block(input_channels  = 32 * scaling * 2, frequencies = self.frequencies)
        self.downsample_22_residuals = TimeResidual_Block(input_channels  = 32 * scaling * 2, frequencies = self.frequencies)
        self.downsample_31_residuals = TimeResidual_Block(input_channels  = 32 * scaling * 4, frequencies = self.frequencies)
        self.downsample_32_residuals = TimeResidual_Block(input_channels  = 32 * scaling * 4, frequencies = self.frequencies)
        self.downsample_41_residuals = TimeResidual_Block(input_channels  = 32 * scaling * 8, frequencies = self.frequencies)
        self.downsample_42_residuals = TimeResidual_Block(input_channels  = 32 * scaling * 8, frequencies = self.frequencies)

        # Convolutions (downsampling by 2, doubling the channels)
        self.downsample_1_conv = nn.Conv2d(in_channels = 32 * scaling,     out_channels = 32 * scaling * 2, kernel_size = 2, stride = 2)
        self.downsample_2_conv = nn.Conv2d(in_channels = 32 * scaling * 2, out_channels = 32 * scaling * 4, kernel_size = 2, stride = 2)
        self.downsample_3_conv = nn.Conv2d(in_channels = 32 * scaling * 4, out_channels = 32 * scaling * 8, kernel_size = 2, stride = 2)

        # 3. Upsampling
        #
        # Used for upsampling instead of transposed convolutions
        self.upsample = nn.Upsample(scale_factor = (2, 2))

        # Convolutions (projection of the concatenated [upsampled, skip] features)
        self.projection_1 = nn.Conv2d(in_channels = 32 * scaling * (8 + 4), out_channels = 32 * scaling * 4, kernel_size = 3, padding = 1)
        self.projection_2 = nn.Conv2d(in_channels = 32 * scaling * (4 + 2), out_channels = 32 * scaling * 2, kernel_size = 3, padding = 1)
        self.projection_3 = nn.Conv2d(in_channels = 32 * scaling * (2 + 1), out_channels = 32 * scaling    , kernel_size = 3, padding = 1)

        # Time Residual Blocks (2)
        self.upsample_11_residuals = TimeResidual_Block(input_channels = 32 * scaling * 4, frequencies = self.frequencies)
        self.upsample_12_residuals = TimeResidual_Block(input_channels = 32 * scaling * 4, frequencies = self.frequencies)
        self.upsample_21_residuals = TimeResidual_Block(input_channels = 32 * scaling * 2, frequencies = self.frequencies)
        self.upsample_22_residuals = TimeResidual_Block(input_channels = 32 * scaling * 2, frequencies = self.frequencies)
        self.upsample_31_residuals = TimeResidual_Block(input_channels = 32 * scaling    , frequencies = self.frequencies)
        self.upsample_32_residuals = TimeResidual_Block(input_channels = 32 * scaling    , frequencies = self.frequencies)

        # 4. Output: a bias-free linear layer mixing the channels of every pixel independently
        self.output_linear = nn.Linear(in_features = 32 * scaling, out_features = self.output_channels, bias = False)

        # Normalization (parameter-free, shared between the levels)
        self.normalization = LayerNormalization(dim = 1)

    def forward(self, x, time):

        # Three stride-2 downsamplings followed by three x2 upsamplings only line up for the
        # skip concatenations when both spatial sizes are multiples of 8; fail early and clearly.
        height, width = x.shape[-2], x.shape[-1]
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"TimeResidual_UNET expects spatial dimensions divisible by 8, got {height} x {width}")

        # 1. Lifting
        x = self.input_conv(x)

        # 2. Downsampling (x, x1, x2 are kept as skip connections)
        x = self.downsample_11_residuals(x, time)
        x = self.downsample_12_residuals(x, time)

        x1 = self.downsample_1_conv(x)
        x1 = self.downsample_21_residuals(x1, time)
        x1 = self.downsample_22_residuals(x1, time)

        x2 = self.downsample_2_conv(x1)
        x2 = self.downsample_31_residuals(x2, time)
        x2 = self.downsample_32_residuals(x2, time)

        x3 = self.downsample_3_conv(x2)
        x3 = self.downsample_41_residuals(x3, time)
        x3 = self.downsample_42_residuals(x3, time)
        x3 = self.normalization(x3)

        # 3. Upsampling
        x3 = self.upsample(x3)

        x2 = torch.cat([x3, x2], dim = 1)
        x2 = self.projection_1(x2)
        x2 = self.upsample_11_residuals(x2, time)
        x2 = self.upsample_12_residuals(x2, time)
        x2 = self.normalization(x2)
        x2 = self.upsample(x2)

        x1 = torch.cat([x2, x1], dim = 1)
        x1 = self.projection_2(x1)
        x1 = self.upsample_21_residuals(x1, time)
        x1 = self.upsample_22_residuals(x1, time)
        x1 = self.normalization(x1)
        x1 = self.upsample(x1)

        x = torch.cat([x1, x], dim = 1)
        x = self.projection_3(x)
        x = self.upsample_31_residuals(x, time)
        x = self.upsample_32_residuals(x, time)

        # 4. Output: channels-last for the linear layer, (B, H, W, C) -> (B, H, W, output_channels)
        x = self.output_linear(torch.permute(x, (0, 2, 3, 1)))

        # 5. Back to channels-first: (B, output_channels, H, W)
        return torch.permute(x, (0, 3, 1, 2))

    def count_parameters(self,):
        r"""Determines the number of trainable parameters in the model"""
        return int(sum(p.numel() for p in self.parameters() if p.requires_grad))

class Diffusion_UNET(nn.Module):
    r"""Conditional diffusion model (DDPM-style) built around a time residual UNET.

    The UNET receives the conditioning channels concatenated with a noisy latent ``z_t``. It
    predicts the noise that was added. Step index ``k`` in ``[0, diffusion_steps)`` corresponds
    to diffusion time ``k + 1``.

    Args:
        window_input: number of conditioning days.
        window_output: number of forecasted days (channels of the latent / output).
        diffusion_steps: number of diffusion steps ``T``.
        diffusion_scheduler: constant ``beta`` of the noise schedule.
        diffusion_variance: multiplier of the Gaussian noise re-injected at each reverse step.
        scaling: width multiplier of the UNET.
        frequencies: number of frequencies of the step encoding.
        device: device on which sampling happens.
    """

    def __init__(self, window_input: int,
                       window_output: int,
                       diffusion_steps: int,
                       diffusion_scheduler: float = 0.01511,
                       diffusion_variance: float = 0.01,
                       scaling: int = 1,
                       frequencies: int = 32,
                       device: str = 'cpu'):
        super(Diffusion_UNET, self).__init__()

        # Initialization
        self.diffusion_steps     = diffusion_steps
        self.diffusion_scheduler = diffusion_scheduler
        self.diffusion_variance  = diffusion_variance
        self.frequencies         = frequencies
        self.device              = device

        # ---- Pre-Calculation ----
        #
        # Diffusion times 1..T and their encoding, shape (T, 2 * frequencies)
        steps_t            = torch.arange(1, self.diffusion_steps + 1, dtype = torch.float32)
        self.encoded_steps = time_encoding(steps_t[:, None], self.frequencies)[:, 0]

        # Constant schedule: beta_t = diffusion_scheduler and alphas = (1 - beta) ** t (cumulative product)
        betas                      = torch.ones((self.diffusion_steps, 1), dtype = torch.float32) * diffusion_scheduler
        alphas                     = torch.pow(1 - diffusion_scheduler, steps_t)[:, None]
        self.sqrt_alphas           = torch.sqrt(alphas)
        self.sqrt_one_minus_alphas = torch.sqrt(1 - alphas)

        # Reverse-step constants: z_{t-1} = zt_const * z_t - noise_const * predicted_noise
        self.latent_constant_zt    = (1 / (torch.sqrt(1 - betas))).to(self.device)
        self.latent_constant_noise = (betas / (torch.sqrt(1 - alphas) * torch.sqrt(1 - betas))).to(self.device)

        # ---- Model ----
        #
        # Number of input channels: 6 static/time channels (original note: mesh, bathymetry and
        # time), 4 conditioning fields per input day, and the noisy latent (one channel per forecasted day)
        nb_inputs = 3 + 3 + 4 * window_input + window_output

        # Model
        self.model = TimeResidual_UNET(input_channels = nb_inputs, output_channels = window_output, frequencies = self.frequencies, scaling = scaling)

    def parrelize(self, number_gpus: int):
        r"""Wrap the UNET in DataParallel over GPUs 0 .. number_gpus - 1 (method name kept for compatibility)."""
        self.model = torch.nn.parallel.DataParallel(self.model, device_ids= list(range(number_gpus)), dim= 0)

    def count_parameters(self,):
        r"""Number of trainable parameters of the UNET.

        Counted directly from the parameters so it also works after ``parrelize``
        (DataParallel does not expose the UNET's ``count_parameters``).
        """
        return int(sum(p.numel() for p in self.model.parameters() if p.requires_grad))

    def predict(self, z, diffusion_steps):
        r"""Predict the noise in ``z`` (conditioning + latent) given step indices of shape (B, 1)."""
        # The encodings live on the CPU; index them there, whatever the device of the indices
        steps = diffusion_steps[:, 0].to(self.encoded_steps.device)
        return self.model(z, self.encoded_steps[steps].to(self.device))

    def generate_latent(self, x, noise, diffusion_steps):
        r"""Forward process: z_t = sqrt(alpha_t) * x + sqrt(1 - alpha_t) * noise, with step indices of shape (B, 1)."""

        # Per-sample constants, shape (B, 1, 1, 1), broadcast over channels and space on x's device
        steps                 = diffusion_steps[:, 0].to(self.sqrt_alphas.device)
        sqrt_alphas           = self.sqrt_alphas[steps][:, :, None, None].to(x.device)
        sqrt_one_minus_alphas = self.sqrt_one_minus_alphas[steps][:, :, None, None].to(x.device)

        # Computing the latent variable z_t
        return sqrt_alphas * x + sqrt_one_minus_alphas * noise

    def generate_samples(self, x, conditioning, number_trajectories: int = 3):
        r"""Generate ``number_trajectories`` samples of the shape of ``x`` given ``conditioning``.

        ``x`` only provides the output shape. Returns the samples stacked along a new
        dimension 2, i.e. (B, window_output, number_trajectories, H, W) for a 4-D ``x``.

        Raises:
            ValueError: if ``number_trajectories`` is smaller than 1.
        """

        # Libraries
        from tqdm import tqdm

        if number_trajectories < 1:
            raise ValueError(f"number_trajectories must be at least 1, got {number_trajectories}")

        # No gradients are needed while sampling
        with torch.no_grad():

            # Stores the generated samples
            x_generated = list()

            # Generating multiple trajectories
            for _ in tqdm(range(number_trajectories)):

                # Sampling the initial noise
                zt = torch.normal(0, 1, x.shape, device = self.device)

                # Reverse process over every step index T-1, ..., 0 (index 0 is the final denoising step)
                for t in range(self.diffusion_steps - 1, -1, -1):

                    # Same step index for every sample of the batch
                    diffusion_steps = torch.full((x.shape[0], 1), t, dtype = torch.int64)

                    # Removing the predicted noise
                    noise_pred = self.predict(torch.cat([conditioning, zt], dim = 1), diffusion_steps)
                    zt = self.latent_constant_zt[t] * zt - self.latent_constant_noise[t] * noise_pred

                    # Adding a bit of noise for stochasticity, except after the last step (t == 0)
                    if t > 0:
                        zt = zt + self.diffusion_variance * torch.normal(0, 1, zt.shape, device = zt.device)

                # Adding the final sample
                x_generated.append(zt)

            # Returning the generated samples
            return torch.stack(x_generated, dim = 2)


# ---------------
#
# Initialization (quick check that the architecture builds and can be parallelized;
# the training cell builds a fresh model)
neural_net = Diffusion_UNET(window_input, window_output, diff_steps, diff_scheduler, diff_variance, scaling, frequencies, device).to(device)
num_gpus   = torch.cuda.device_count()
neural_net.parrelize(num_gpus)

In [None]:
# ---------------------------------
#     Neural Network & Training
# ---------------------------------
#
# Initialization
neural_net = Diffusion_UNET(window_input, window_output, diff_steps, diff_scheduler, diff_variance, scaling, frequencies, device).to(device)

# Training Parameters (LinearLR is reached through torch.optim so this cell does not depend
# on one of the wildcard imports happening to provide it)
optimizer  = optim.Adam(neural_net.parameters(), lr = learning_rate)
scheduler  = optim.lr_scheduler.LinearLR(optimizer, start_factor = 0.95, total_iters = nb_epochs)

# Information about the model
num_gpus  = torch.cuda.device_count()
nn_params = neural_net.count_parameters()

# Deploying the model on multiple GPUs (the optimizer keeps referencing the same parameters)
neural_net.parrelize(num_gpus)

# Displaying information over the terminal
print("Total number of parameters: ", nn_params/1e6, "M")
print("Available GPUs: ", num_gpus)

In [None]:
# WandB (1) - Initialization of the run
wandb.init(project = project, mode = mode, config = kwargs)

# WandB (2) - Logging info
wandb.config.update({"Number of Parameters": nn_params, "Number of GPUs": num_gpus})

# Playground shortcut: when True, every epoch stops after its first training batch
# (this reproduces the unconditional `break` the loop used to contain). Set to False for real training.
debug_single_batch = True

# Continental-shelf pixels on which the loss is computed (loop invariant, computed once)
loss_mask = black_sea_mask_cs[0] == 1

# ------- Training Loop -------
for epoch in range(nb_epochs):

    # Losses of the batches of this epoch
    mean_loss = list()

    for conditioning, _, x in dataloader_train:

        # ------ Preprocessing -----
        #
        # Sampling uniformly one diffusion step index in [0, diff_steps) per sample
        diffusion_steps = torch.randint(0, diff_steps, (x.shape[0], 1))

        # Sampling noise
        noise = torch.normal(0, 1, x.shape)

        # Generating latent representations of the data (on the CPU tensors, before the transfer)
        z_t = neural_net.generate_latent(x, noise, diffusion_steps)

        # Pushing to device (the step indices stay on the CPU, where the step encodings live)
        z_t, conditioning, noise = z_t.to(device), conditioning.to(device), noise.to(device)

        # Adding the conditioning
        z_t = torch.cat([conditioning, z_t], dim = 1)

        # ----- Training -----
        #
        # Predicting the noise
        noise_pred = neural_net.predict(z_t, diffusion_steps)

        # Computing the loss (MSE between true and predicted noise on the continental shelf)
        loss       = torch.pow(noise_pred[:, :, loss_mask] - noise[:, :, loss_mask], 2).nanmean()
        loss_value = loss.item()
        mean_loss.append(loss_value)

        # WandB (3) - Logging the loss
        wandb.log({"Training/Loss (Instantaneous)": loss_value})

        # Optimizing
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Displaying the loss
        print("Loss: ", loss_value)

        if debug_single_batch:
            break

    # WandB (4) - Logging the averaged loss and the epoch
    wandb.log({"Training/Loss (Averaged Over Batch)": np.mean(mean_loss), "Epoch (Left)": nb_epochs - epoch})

    # Updating the learning rate
    scheduler.step()

    # Computing metrics every `results_epoch` epochs (every epoch when it is 1, never when <= 0)
    if results_epoch <= 0 or (epoch + 1) % results_epoch != 0:
        continue

    # NOTE(review): `metrics` is defined in a later cell of this notebook; run that cell before this one
    # (unless `from metrics import *` already provides a function with that name) — TODO confirm.
    for c, _, x in dataloader_valid:

        # Pushing to device
        x, c = x.to(device), c.to(device)

        # Generating conditional samples (count taken from the 'Results (Trajectories)' setting)
        forecast = neural_net.generate_samples(x = x, conditioning = c, number_trajectories = results_traj)

        # Computing Metrics
        metrics(x.cpu(), forecast.cpu(), black_sea_mask_cs, hypoxia_treshold_standardized)

        # Only on the first batch (= one year of data)
        break


In [None]:
def metrics(ground_truth: torch.Tensor, forecasts: torch.Tensor, mask: torch.Tensor, treshold: float):
    """Log forecast/hypoxia visualizations and hypoxia detection metrics to WandB.

    Args:
        ground_truth: fields indexed as ``[sample, day, row, col]``.
        forecasts: generated fields indexed as ``[sample, day, trajectory, row, col]``
            (the layout returned by ``Diffusion_UNET.generate_samples``).
        mask: indexed as ``mask[0, row, col]``; only pixels equal to 1 are evaluated.
        treshold: values strictly below it are considered hypoxic (same units as the fields).

    The input tensors are never modified. Up to 12 samples (one row each, labelled with the
    months of the year) and up to 3 trajectories are drawn. Every figure is closed once logged.
    """

    # Region of interest displayed in every map
    rows, cols = slice(25, 125), slice(70, 270)

    # Row labels of the map grids
    months = ["January", "February", "March", "April", "May", "June", "July",
              "August", "September", "October", "November", "December"]

    # Masked-out pixels of the region of interest (hidden in the maps)
    roi_outside = mask[0, rows, cols] == 0

    def log_figure(key: str, fig):
        """Send a figure to WandB, then close it so figures do not pile up across calls."""
        wandb.log({key: wandb.Image(fig)})
        plt.close(fig)

    def hide_ticks(ax):
        """Remove the ticks of every axis of a subplot grid."""
        for axis in ax.flat:
            axis.set_xticks([])
            axis.set_yticks([])

    def plot_maps(truth: torch.Tensor, predicted: torch.Tensor, day: int, fill_value: float):
        """Ground truth (first column) next to up to three trajectories, masked with ``fill_value``."""

        nb_rows = min(len(months), truth.shape[0])
        nb_traj = min(3, predicted.shape[2])
        fig, ax = plt.subplots(nb_rows, 1 + nb_traj, figsize = (30, 50), squeeze = False)

        for i in range(nb_rows):

            # Copy before masking so the caller's tensors are left untouched
            gt = truth[i, day, rows, cols].clone()
            gt[roi_outside] = fill_value

            # Color range of the ground truth, reused for the forecasts
            vmin, vmax = np.nanmin(gt), np.nanmax(gt)

            ax[i, 0].imshow(gt, label = "Ground Truth")
            ax[i, 0].set_ylabel(months[i])

            for j in range(nb_traj):
                fc = predicted[i, day, j, rows, cols].clone()
                fc[roi_outside] = fill_value
                ax[i, j + 1].imshow(fc, label = f"Forecast {j + 1}", vmin = vmin, vmax = vmax)

        hide_ticks(ax)
        plt.tight_layout()
        return fig

    def plot_probability_maps(truth: torch.Tensor, predicted: torch.Tensor, day: int):
        """Binary ground truth next to the fraction of trajectories flagging each pixel."""

        # Mean over trajectories (a fresh tensor, so editing its slices is safe)
        prob_map = torch.mean(predicted, dim = 2)

        nb_rows = min(len(months), truth.shape[0])
        fig, ax = plt.subplots(nb_rows, 2, figsize = (10, 30), squeeze = False)

        for i in range(nb_rows):

            gt = truth[i, day, rows, cols].clone()
            pm = prob_map[i, day, rows, cols]
            gt[roi_outside] = np.nan
            pm[roi_outside] = np.nan

            ax[i, 0].imshow(gt, label = "Ground Truth")
            fig.colorbar(ax[i, 1].imshow(pm, vmin = 0, vmax = 1, cmap = "inferno"), ax = ax[i, 1])
            ax[i, 0].set_ylabel(months[i])

        hide_ticks(ax)
        plt.tight_layout()
        return fig

    def compute_precision_recall(truth: torch.Tensor, predicted: torch.Tensor):
        """Precision, recall and accuracy of the probability map for thresholds 0.0, 0.1, ..., 1.0."""

        thresholds = [i / 10 for i in range(0, 11)]

        # Keep only the pixels inside the mask -> (sample, day, pixel)
        inside   = mask[0] == 1
        truth    = truth[:, :, inside]
        prob_map = torch.mean(predicted, dim = 2)[:, :, inside]

        precision, recall, accuracy = [], [], []
        for threshold in thresholds:

            # Binarize predictions
            prediction = (prob_map >= threshold).float()

            # Confusion counts summed over samples and pixels -> one value per forecasted day
            TP = (prediction * truth).sum(dim = (0, 2))
            FP = (prediction * (1 - truth)).sum(dim = (0, 2))
            TN = ((1 - prediction) * (1 - truth)).sum(dim = (0, 2))
            FN = ((1 - prediction) * truth).sum(dim = (0, 2))

            precision.append(TP / (TP + FP + 1e-8))
            recall.append(TP / (TP + FN + 1e-8))
            accuracy.append((TP + TN) / (TP + TN + FP + FN + 1e-8))

        # Shape (threshold, day)
        precision = torch.stack(precision, dim = 0)
        recall    = torch.stack(recall, dim = 0)
        accuracy  = torch.stack(accuracy, dim = 0)

        # Recall vs precision curves for the first and the last forecasted day
        curves = dict()
        for key, day, marker in (("Precision-Recall Curve (First Day)", 0, "*"),
                                 ("Precision-Recall Curve (Last Day)", -1, "o")):
            fig = plt.figure(figsize = (7, 7))
            plt.plot(recall[:, day], precision[:, day], marker = marker)
            plt.xlabel('Recall [-]')
            plt.ylabel('Precision [-]')
            plt.xlim(0, 1)
            plt.ylim(0, 1)
            plt.grid()
            curves[key] = fig

        wandb.log({key: wandb.Image(fig) for key, fig in curves.items()})
        for fig in curves.values():
            plt.close(fig)

        # Singular values, sent in a single log call so they share one WandB step
        scalars = dict()
        for i, t in enumerate(thresholds):
            scalars[f"Metrics/Precision (First Day, T = {t})"] = precision[i, 0].item()
            scalars[f"Metrics/Precision (Last Day, T = {t})"]  = precision[i, -1].item()
            scalars[f"Metrics/Recall (First Day, T = {t})"]    = recall[i, 0].item()
            scalars[f"Metrics/Recall (Last Day, T = {t})"]     = recall[i, -1].item()
            scalars[f"Metrics/Accuracy (First Day, T = {t})"]  = accuracy[i, 0].item()
            scalars[f"Metrics/Accuracy (Last Day, T = {t})"]   = accuracy[i, -1].item()
        wandb.log(scalars)

    # Plotting the forecasts
    log_figure("Forecast Visualization (First Day)", plot_maps(ground_truth, forecasts, day = 0, fill_value = np.nan))
    log_figure("Forecast Visualization (Last Day)",  plot_maps(ground_truth, forecasts, day = -1, fill_value = np.nan))

    # Extracting hypoxia regions (1.0 where the value is below the threshold)
    hypoxia_truth     = (ground_truth < treshold) * 1.0
    hypoxia_forecasts = (forecasts < treshold) * 1.0

    # Plotting the hypoxia regions (masked pixels shown as -1)
    log_figure("Hypoxia Visualization (First Day)", plot_maps(hypoxia_truth, hypoxia_forecasts, day = 0, fill_value = -1))
    log_figure("Hypoxia Visualization (Last Day)",  plot_maps(hypoxia_truth, hypoxia_forecasts, day = -1, fill_value = -1))

    # Plotting the probability maps
    log_figure("Probability Visualization (First Day)", plot_probability_maps(hypoxia_truth, hypoxia_forecasts, day = 0))
    log_figure("Probability Visualization (Last Day)",  plot_probability_maps(hypoxia_truth, hypoxia_forecasts, day = -1))

    # Computing global metrics
    compute_precision_recall(hypoxia_truth, hypoxia_forecasts)


# Re-compute the metrics on the last validation batch with the same hypoxia threshold as the training loop
# (the hard-coded 1 used previously did not match the dataset-provided standardized threshold)
metrics(x.cpu(), forecast.cpu(), black_sea_mask_cs, hypoxia_treshold_standardized)