<a href="https://colab.research.google.com/github/ElDavido98/Neural-Networks/blob/main/Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Final Project**


---


### **Contents**
*   Preliminary Part
*   Data Processing
*   Utils
*   Neural Networks
*   Metrics
*   Forecasting
*   Pre
*   Training
*   Evaluation


---


### **N.B.**
This notebook cannot be run on the basic (free) version of Google Colab because the available RAM is not sufficient. The problem persists even when the complexity of the models and the amount of training data are reduced.

## Preliminary Part

In [None]:
!pip install torch==2.2.0
!pip install numpy==1.26.4
!pip install timm==0.9.16
!pip install netCDF4==1.6.5
!pip install scikit-learn==1.4.2
!pip install matplotlib==3.8.4

In [None]:
import torch.optim as optim
import torch
import statistics
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import netCDF4
import sklearn.preprocessing
from timm.layers import PatchEmbed, DropPath

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive/ so that the dataset folder referenced by `path` is reachable.
drive.mount('/content/drive/')

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Utils
This section contains functions used by various classes.

In [None]:
# Root folder of the dataset on the mounted Drive.
path = "/content/drive/MyDrive/Neural_Networks/climate-learn"

# Sub-folders (one per variable) that hold the per-year NetCDF files.
# `single_folder` entries are read as 3-D variables, `atmospheric_folder` entries have an extra
# axis (axis 1) from which `lev_indexes` are selected in define_sets.
single_folder = ["toa_incident_solar_radiation_5.625deg", "2m_temperature_5.625deg", "10m_u_component_of_wind_5.625deg",
                 "10m_v_component_of_wind_5.625deg"]
atmospheric_folder = ["geopotential_5.625deg", "u_component_of_wind_5.625deg", "v_component_of_wind_5.625deg",
                      "temperature_5.625deg", "specific_humidity_5.625deg",
                      "relative_humidity_5.625deg"]

# File-name pieces: a yearly file is "<variable prefix><year><resolution>.nc", the constants file is
# "<static_variable>.nc".
static_variable = "constants_5.625deg"
single_variable = ["toa_incident_solar_radiation_", "2m_temperature_", "10m_u_component_of_wind_",
                   "10m_v_component_of_wind_"]
atmospheric_variable = ["geopotential_", "u_component_of_wind_", "v_component_of_wind_", "temperature_",
                        "specific_humidity_", "relative_humidity_"]

resolution = "_5.625deg"

# Variable names looked up inside the NetCDF files: indexes 0-2 come from the constants file,
# 3-6 from the single-level files and 7-12 from the atmospheric files (same order as the lists above).
abbr = ["lsm", "orography", "lat2d", "tisr", "t2m", "u10", "v10", "z", "u", "v", "t", "q", "r"]

# `lev_indexes` are the positions taken along axis 1 of every atmospheric variable.
# Presumably they are the positions of the pressure levels listed in `levels` — TODO confirm against the files.
levels = [50, 250, 500, 600, 700, 850, 925]
lev_indexes = [0, 4, 7, 8, 9, 10, 11]

# Year ranges are used with range(low, max), i.e. the upper bound is exclusive.
low_bound_year_train, max_bound_year_train = 1980, 1981  # Original values from paper: 1979, 2016
low_bound_year_val_test, max_bound_year_val_test = 1986, 1987   # First part for validation, second part for test

# Zero-based, inclusive index bounds for years and hours inside each set
# (8760 hourly steps per training year, 4380 per validation/test half-year).
low_year_train, max_year_train = 0, (max_bound_year_train - low_bound_year_train - 1)
low_hour_train, max_hour_train = 0, 8759
low_year_val, max_year_val = 0, (max_bound_year_val_test - low_bound_year_val_test - 1)
low_hour_val, max_hour_val = 0, 4379
low_year_test, max_year_test = 0, (max_bound_year_val_test - low_bound_year_val_test - 1)
low_hour_test, max_hour_test = 0, 4379

# 32 latitude values; presumably the grid-row centres in degrees of the 5.625deg grid — TODO confirm.
latitude_coordinates = [-87.1875, -81.5625, -75.9375, -70.3125, -64.6875, -59.0625, -53.4375, -47.8125, -42.1875,
                        -36.5625, -30.9375, -25.3125, -19.6875, -14.0625, -8.4375, -2.8125, 2.8125, 8.4375, 14.0625,
                        19.6875, 25.3125, 30.9375, 36.5625, 42.1875, 47.8125, 53.4375, 59.0625, 64.6875, 70.3125,
                        75.9375, 81.5625, 87.1875]


def return_to_image(x, patch_size, out_channels, img_size):
    """Reassemble a sequence of flattened patches into an image tensor.

    `x` is indexed as (batch, patch, features), where each feature vector is
    reshaped to (patch_size, patch_size, out_channels). The result has shape
    (batch, out_channels, rows * patch_size, cols * patch_size), with
    rows/cols = img_size // patch_size, and lives on `device`.
    """
    rows, cols = img_size[0] // patch_size, img_size[1] // patch_size
    # The number of patches must tile the image grid exactly
    assert rows * cols == x.shape[1]
    batch = x.shape[0]
    patches = x.reshape(batch, rows, cols, patch_size, patch_size, out_channels).to(device)
    # Move channels first and interleave grid and in-patch axes: (n, c, h, p, w, q)
    patches = torch.einsum("nhwpqc->nchpwq", patches).to(device)
    return patches.reshape(batch, out_channels, rows * patch_size, cols * patch_size).to(device)


def make_layer(block, in_channels, out_channels, num_blocks, change=0):
    """Stack `num_blocks` instances of `block` into an nn.Sequential placed on `device`.

    Every block is built as block(in, out_channels). When `change` is truthy
    the input width of all blocks after the first becomes `out_channels`;
    otherwise every block receives the original `in_channels`.
    """
    current_in = in_channels
    modules = []
    for _ in range(num_blocks):
        modules += [block(current_in, out_channels)]
        current_in = out_channels if change else current_in
    stacked = nn.Sequential(*modules)
    return stacked.to(device)


def plot(name, linreg_baseline_rmse, linreg_baseline_acc, resnet_rmse, resnet_acc, unet_rmse, unet_acc, vit_rmse, vit_acc):
    """Draw RMSE (left) and ACC (right) against lead time for the four models.

    Each metric argument is a sequence with one value per lead time
    (6, 24, 72, 120 and 240 hours). `name` is used as the title of both panels.
    """
    lead_time = [6, 24, 72, 120, 240]
    model_names = ('Linear Regression', 'ResNet', 'UNet', 'ViT')
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    panels = (
        (ax1, 'RMSE', (linreg_baseline_rmse, resnet_rmse, unet_rmse, vit_rmse)),
        (ax2, 'ACC', (linreg_baseline_acc, resnet_acc, unet_acc, vit_acc)),
    )
    for axis, metric, series in panels:
        # One curve per model, labelled "<model> <metric>"
        for model, values in zip(model_names, series):
            axis.plot(lead_time, values, 'o-', label=f'{model} {metric}')
        axis.set_xlabel('Leadtime [hours]')
        axis.set_ylabel(metric)
        axis.set_title(name)
        axis.legend()
        axis.grid(True)
    # Adjust spacing, refresh the legend of the current axes and display the figure
    plt.tight_layout()
    plt.legend()
    plt.show()


def create_list(stats):
    """Regroup per-network statistics by variable.

    `stats[j][i]` is the value of variable `i` for network `j`; the result is
    a list of three lists where result[i][j] == stats[j][i].
    """
    return [[entry[variable] for entry in stats] for variable in range(3)]


def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█'):
    """
    Call in a loop to create terminal progress bar
    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int, must be non-zero)
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : positive number of decimals in percent complete (Int)
        length      - Optional  : character length of bar (Int)
        fill        - Optional  : bar fill character (Str)
    """
    percent = f"{100 * (iteration / float(total)):.{decimals}f}"
    filled_length = int(length * iteration // total)
    # Pad with exactly the remaining cells so the bar is always `length` characters wide
    # (the previous `length - 1` padding made the bar grow by one character on completion).
    bar = fill * filled_length + '-' * (length - filled_length)
    # The leading carriage return rewrites the same terminal line on every call
    print(f'\r{prefix} |{bar}| {percent}% {suffix}', end=" ")
    # Print New Line on Complete
    if iteration == total:
        print()


def printProgressAction(action, iteration):
    """Rewrite the current terminal line with "<action> <iteration>" (no newline)."""
    message = "\r{} {}".format(action, iteration)
    print(message, end=" ")


class PeriodicPadding2D(nn.Module):
    """Pad a 4-D tensor periodically along its last axis and with zeros along the one before it."""

    def __init__(self, pad_width):
        super().__init__()
        self.pad_width = pad_width

    def forward(self, inputs):
        """Return `inputs` padded by `pad_width` cells on every side of the last two axes."""
        width = self.pad_width
        if width == 0:
            return inputs
        # Wrap-around along the last axis: the right edge is copied to the left and vice versa
        wrapped_left = inputs[:, :, :, -width:]
        wrapped_right = inputs[:, :, :, :width]
        periodic = torch.cat((wrapped_left, inputs, wrapped_right), dim=-1).to(device)
        # Zero padding along the second-to-last axis only (no padding on the last axis here)
        zero_pad = (0, 0, width, width)
        return nn.functional.pad(periodic, zero_pad).to(device)


def EarlyStopping(curr_monitor, old_monitor, count, patience=5, min_delta=0):
    """Update the early-stopping counter and report whether training should stop.

    The counter is incremented whenever `curr_monitor - old_monitor <= min_delta`
    and reset to 0 otherwise. Training stops once the counter exceeds `patience`.

    Args:
        curr_monitor: current value of the monitored quantity.
        old_monitor: previous value of the monitored quantity.
        count: number of consecutive calls for which the condition has held so far.
        patience: number of consecutive hits tolerated before stopping.
        min_delta: threshold applied to `curr_monitor - old_monitor`.

    Returns:
        (count, stop): the updated counter and a bool that is True when `count > patience`.
    """
    # NOTE(review): the comparison direction is kept exactly as in the original; confirm with the
    # caller that "curr - old <= min_delta" is really the condition that should consume patience.
    if (curr_monitor - old_monitor) <= min_delta:
        count += 1
        return count, count > patience
    # Only reset when the condition does not hold. The original reset the counter on every
    # non-stopping call, so it could never accumulate past 1 and never triggered for patience >= 1.
    return 0, False


def lr_schedulers(Net_optimizer):
    """Build the two learning-rate schedulers attached to `Net_optimizer`.

    Returns a (LinearLR, CosineAnnealingLR) pair: a linear schedule over 5
    iterations and a cosine annealing schedule with T_max=45 and a minimum
    learning rate of 3.75e-4.
    """
    schedulers = optim.lr_scheduler
    # Creation order matters: each scheduler adjusts the optimizer when it is constructed
    warmup = schedulers.LinearLR(Net_optimizer, total_iters=5)
    annealing = schedulers.CosineAnnealingLR(Net_optimizer, T_max=45, eta_min=3.75e-4)
    return warmup, annealing


def check(Net, data, six_hours_ago, twelve_hours_ago, target, constants, latitude_weights):
    """Run `Net` on one batch and return its loss.

    Only channels 4, 9 and 33 (axis 1) of `target` are kept; the reduced target
    is both passed to the network and compared with its prediction through
    `loss_function`, weighted by `latitude_weights`.
    """
    # NOTE(review): presumably these three channels are the forecast variables — confirm.
    selected_target = target[:, [4, 9, 33], :, :]
    prediction = Net(data, six_hours_ago, twelve_hours_ago, selected_target, constants)
    return loss_function(prediction, selected_target, latitude_weights)


class CustomDataset(torch.utils.data.Dataset):
    """Dataset that serves the samples of five lead times (6, 24, 72, 120, 240 hours) together.

    Each constructor argument is indexed as
    [data, six_hours_ago, twelve_hours_ago, target]; the four sequences are
    stored as `data_lt_<h>`, `six_hours_ago_lt_<h>`, `twelve_hours_ago_lt_<h>`
    and `target_<h>` attributes.
    """

    # Lead-time suffixes, in the order in which __getitem__ returns them
    _LEAD_TIMES = ("6", "24", "72", "120", "240")

    def __init__(self, lead_time_6h: list, lead_time_24h: list, lead_time_72h: list, lead_time_120h: list,
                 lead_time_240h: list):
        super().__init__()
        groups = (lead_time_6h, lead_time_24h, lead_time_72h, lead_time_120h, lead_time_240h)
        for hours, group in zip(self._LEAD_TIMES, groups):
            setattr(self, f"data_lt_{hours}", group[0])
            setattr(self, f"six_hours_ago_lt_{hours}", group[1])
            setattr(self, f"twelve_hours_ago_lt_{hours}", group[2])
            setattr(self, f"target_{hours}", group[3])

    def __len__(self):
        # The 240-hour data defines the dataset length
        return len(self.data_lt_240)

    def __getitem__(self, idx):
        """Return a 5-tuple of [data, six_hours_ago, twelve_hours_ago, target] lists, one per lead time."""
        return tuple(
            [
                getattr(self, f"data_lt_{hours}")[idx],
                getattr(self, f"six_hours_ago_lt_{hours}")[idx],
                getattr(self, f"twelve_hours_ago_lt_{hours}")[idx],
                getattr(self, f"target_{hours}")[idx],
            ]
            for hours in self._LEAD_TIMES
        )


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with LeakyReLU, batch norm and dropout, plus a skip connection.

    The first convolution sees a periodically padded input (see PeriodicPadding2D),
    the second one uses the convolution's own padding of 1. When the channel
    count changes, the skip path is a 1x1 convolution followed by batch norm;
    otherwise it is the identity.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.periodic_zeros_padding = PeriodicPadding2D(1)
        # Main path (padding=0 on conv1 because the input is padded explicitly beforehand)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0).to(device)
        self.leaky_relu = nn.LeakyReLU(0.3).to(device)
        self.bn1 = nn.BatchNorm2d(out_channels).to(device)
        self.dropout = nn.Dropout(0.1).to(device)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1).to(device)
        self.bn2 = nn.BatchNorm2d(out_channels).to(device)
        # Skip path
        if in_channels != out_channels:
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1)).to(device),
                nn.BatchNorm2d(out_channels).to(device),
            ]
            self.shortcut = nn.Sequential(*projection).to(device)
        else:
            self.shortcut = nn.Identity().to(device)

    def forward(self, x):
        skip = self.shortcut(x)
        # conv -> activation -> batch norm -> dropout, applied twice
        out = self.conv1(self.periodic_zeros_padding(x))
        out = self.dropout(self.bn1(self.leaky_relu(out)))
        out = self.conv2(out)
        out = self.dropout(self.bn2(self.leaky_relu(out)))
        return out + skip

## **Data Processing**
'define_sets(task, val=0, test=0)' loads the '.nc' files and turns them into lists; depending on the values of 'task', 'val' and 'test', it builds the constants set, the training set, or the validation and test sets. The PreProcessing class creates the scaler and provides the function that normalises the network inputs.

In [None]:
class PreProcessing:
    """Normalises network inputs with a scikit-learn StandardScaler."""

    def __init__(self):
        self.scaler = sklearn.preprocessing.StandardScaler()

    def process(self, dataset):                                            # dataset has shape (batch_size, 141, 32, 64)
        """Standardise every 2-D slice dataset[i][j] and return a float32 tensor on `device`.

        The scaler is refit on each slice (fit_transform), so no statistics are
        shared between slices or between calls.
        """
        scaled_samples = [
            np.array([self.scaler.fit_transform(X=channel) for channel in sample], dtype=np.float32)
            for sample in dataset
        ]
        stacked = np.array(scaled_samples, dtype=np.float32)
        return torch.FloatTensor(stacked).to(device)


_HOURS_PER_YEAR = 8760       # hourly steps kept per year; slicing to this drops the extra day of leap years
_HALF_YEAR_HOURS = 4380      # first half of a year -> validation, second half -> test


def _variable_files(year):
    """Yield (file_path, variable_name, is_atmospheric) for every variable of `year`, in channel order."""
    single = zip(single_folder, single_variable, abbr[3:7])
    atmospheric = zip(atmospheric_folder, atmospheric_variable, abbr[7:13])
    for is_atmospheric, group in ((False, single), (True, atmospheric)):
        for folder, prefix, name in group:
            yield f"{path}/{folder}/{prefix}{year}{resolution}.nc", name, is_atmospheric


def _read_year(year, hour_ranges):
    """Read all variables of `year` once per (start, stop) hour range.

    Returns one (singles, atmospherics) pair per range: `singles` holds one array per
    single-level variable, `atmospherics` one list of per-level arrays per atmospheric variable.
    """
    splits = [([], []) for _ in hour_ranges]
    for file_path, name, is_atmospheric in _variable_files(year):
        # The context manager closes the file as soon as the requested hours are in memory
        with netCDF4.Dataset(file_path) as nc:
            variable = nc[name]
            for (singles, atmospherics), (start, stop) in zip(splits, hour_ranges):
                values = variable[start:stop]
                if is_atmospheric:
                    atmospherics.append([values[:, lev] for lev in lev_indexes])
                else:
                    singles.append(values)
    return splits


def _load_years(years, hour_ranges):
    """Read every year while showing a progress bar; returns one list of per-year data per hour range."""
    per_range = [[] for _ in hour_ranges]
    total = len(years)
    printProgressBar(0, total, prefix='Progress:', suffix='Complete', length=50)
    for done, year in enumerate(years, start=1):
        for bucket, loaded in zip(per_range, _read_year(year, hour_ranges)):
            bucket.append(loaded)
        printProgressBar(done, total, prefix='Progress:', suffix='Complete', length=50)
    return per_range


def _stack_samples(yearly, hours):
    """Turn per-year variables into one array per hour.

    Channel order of every sample: the single-level variables first, then, for each
    selected level, the atmospheric variables in file order.
    """
    samples = []
    for singles, atmospherics in yearly:
        for hour in range(hours):
            channels = [values[hour] for values in singles]
            for lev in range(len(levels)):
                channels.extend(per_level[lev][hour] for per_level in atmospherics)
            samples.append(np.array(channels))
    return samples


def define_sets(task, val=0, test=0):              # Loads .nc files and turns them into lists
    """Load the NetCDF data for `task` and return it as lists of arrays.

    Args:
        task: 'const' for the static fields, 'train' for the training years,
            'val_test' for the validation/test years.
        val: when truthy (only used by 'val_test'), build the validation set from the
            first 4380 hours of each year.
        test: when truthy (only used by 'val_test'), build the test set from hours
            4380-8759 of each year.

    Returns:
        'const': a list with the three static arrays (abbr[0:3]).
        'train': a list with one stacked array per hour (8760 per year).
        'val_test': a (validation_set, test_set) tuple of such lists; a set that was not
            requested is an empty list.
        Any other task: None.
    """
    # Constants
    if task == 'const':
        print("Constants Set Definition")
        with netCDF4.Dataset(f"{path}/{static_variable}.nc") as nc:
            return [nc[name][:] for name in abbr[:3]]

    # Define train_set
    if task == 'train':
        print("Train Set Definition")
        years = range(low_bound_year_train, max_bound_year_train)
        (yearly,) = _load_years(years, [(0, _HOURS_PER_YEAR)])
        return _stack_samples(yearly, _HOURS_PER_YEAR)

    # Define validation_set and test_set
    if task == 'val_test':
        print("Validation and Test Sets Definition")
        # Read each half-year directly from the file. The original read only hours 0-4379 of the
        # atmospheric variables and then sliced 4380-8759 out of that, leaving the test split empty.
        hour_ranges = []
        if val:
            hour_ranges.append((0, _HALF_YEAR_HOURS))
        if test:
            hour_ranges.append((_HALF_YEAR_HOURS, _HOURS_PER_YEAR))
        years = range(low_bound_year_val_test, max_bound_year_val_test)
        stacked = [_stack_samples(yearly, _HALF_YEAR_HOURS) for yearly in _load_years(years, hour_ranges)]
        validation_set = stacked.pop(0) if val else []
        test_set = stacked.pop(0) if test else []
        return validation_set, test_set

In [None]:
# Load the three static fields (lsm, orography, lat2d) from the constants file.
constants_set = define_sets('const')

In [None]:
# Load the training years: one stacked array per hour.
train_set = define_sets('train')

In [None]:
# Build both splits: the first half of each year is used for validation, the second half for test.
validation_set, test_set = define_sets('val_test', val=1, test=1)

Validation and Test Sets Definition
Progress: |██████████████████████████████████████████████████| 100.0% Complete 


## **Neural Networks**
Here, the three neural networks, the baseline and the generic class ‘Network’ are defined; the latter class creates a generic network with all the methods required for training.

ResNet

In [None]:
class ResNet(nn.Module):
    """Residual network: 7x7 input projection, a stack of ResidualBlocks and a 7x7 output convolution."""

    def __init__(self, in_channels, out_channels, processor, hidden_channels=128, num_blocks=28):
        super().__init__()
        # The projection has padding=0 because its input is padded explicitly by PeriodicPadding2D(3)
        self.periodic_zeros_padding = PeriodicPadding2D(3)
        self.image_projection = nn.Conv2d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=0).to(device)
        self.res_net_blocks = make_layer(ResidualBlock, hidden_channels, hidden_channels, num_blocks=num_blocks,
                                         change=1)
        self.norm = nn.BatchNorm2d(hidden_channels).to(device)
        self.out = nn.Conv2d(hidden_channels, out_channels, kernel_size=7, stride=1, padding=3).to(device)
        self.leaky_relu = nn.LeakyReLU(0.3).to(device)
        self.processor = processor
        # Every target passed to __call__ is collected here
        self.set_climatology = []

    def __call__(self, data, six_hours_ago, twelve_hours_ago, target, constants):
        """Concatenate the inputs along axis 1, run `forward` directly and record `target`."""
        with_constants = np.concatenate((constants, data), axis=1)
        stacked = np.concatenate((with_constants, six_hours_ago, twelve_hours_ago), axis=1)
        prediction = self.forward(stacked)
        self.set_climatology.append(target)
        return prediction

    def forward(self, x):
        normalised = self.processor.process(x)
        features = self.image_projection(self.periodic_zeros_padding(normalised))
        features = self.res_net_blocks(features)
        activated = self.leaky_relu(self.norm(features))
        return self.out(activated)

UNet

In [None]:
class UNet(nn.Module):
    """U-shaped network built from ResidualBlocks with skip connections between the two paths.

    The downward path applies `blocks` residual blocks per resolution and halves the
    feature map (stride-2 convolution) between resolutions; the upward path mirrors it
    with `blocks + 1` residual blocks per resolution and stride-2 transposed convolutions.
    `use_attention` is accepted but not used by this implementation.
    """

    def __init__(self, in_channels, out_channels, processor, hidden_channels=64, channel_multiplications=(1, 2, 2),
                 blocks=2, use_attention=False):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.processor = processor
        # Every target passed to __call__ is collected here
        self.set_climatology = []
        # conv1 and out use padding=0: their inputs are padded explicitly by this module
        self.periodic_zeros_padding = PeriodicPadding2D(3)
        self.conv1 = nn.Conv2d(self.in_channels, self.hidden_channels, kernel_size=7, stride=1, padding=0).to(device)
        # From here on in_channels/out_channels are local running channel counts, not the constructor arguments
        out_channels = in_channels = self.hidden_channels
        self.n_resolutions = len(channel_multiplications)
        self.blocks = blocks
        # Downward path
        self.down_blocks = []
        for i in range(self.n_resolutions):     # 3 with the default channel_multiplications
            out_channels = in_channels * channel_multiplications[i]
            for _ in range(self.blocks):     # 2 by default
                self.down_blocks.append(ResidualBlock(in_channels, out_channels))
                in_channels = out_channels
            # Down sample at all resolutions except the last
            # Scale down the feature map by 0.5 times
            if i < self.n_resolutions - 1:
                self.down_blocks.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1).to(device))
        self.down_blocks = nn.ModuleList(self.down_blocks).to(device)
        # Bottleneck
        self.bottleneck = nn.Sequential(
            ResidualBlock(out_channels, out_channels),
            ResidualBlock(out_channels, out_channels)
        ).to(device)
        # Upward path
        self.up_blocks = []
        for i in reversed(range(self.n_resolutions)):       # 3 with the default channel_multiplications
            out_channels = in_channels
            # Each block consumes the current features concatenated with one skip tensor of the same width
            for _ in range(self.blocks):     # 2 by default
                self.up_blocks.append(ResidualBlock(in_channels+out_channels, out_channels))
            # Extra block that reduces the width; its skip tensor has the narrower width
            out_channels = in_channels // channel_multiplications[i]
            self.up_blocks.append(ResidualBlock(in_channels + out_channels, out_channels))
            in_channels = out_channels
            # Up sample at all resolutions except last
            if i > 0:
                self.up_blocks.append(nn.ConvTranspose2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1).to(device))
        self.up_blocks = nn.ModuleList(self.up_blocks).to(device)
        self.norm = nn.BatchNorm2d(self.hidden_channels).to(device)
        self.leaky_relu = nn.LeakyReLU(0.3).to(device)
        self.out = nn.Conv2d(in_channels, self.out_channels, kernel_size=7, padding=0).to(device)

    def __call__(self, data, six_hours_ago, twelve_hours_ago, target, constants):
        """Concatenate the inputs along axis 1, run `forward` directly and record `target`."""
        current_data = np.concatenate((constants, data), axis=1)
        pred = self.forward(np.concatenate((current_data, six_hours_ago, twelve_hours_ago), axis=1))
        self.set_climatology.append(target)
        return pred

    def forward(self, x):
        """Normalise `x` with the processor, run the encoder/decoder and return the output convolution."""
        x = self.processor.process(x)
        x = self.conv1(self.periodic_zeros_padding(x))
        # One skip tensor is stored for conv1 and for every module of the downward path
        # (down-sampling convolutions included)
        skips = [x]
        # Downward path
        for down_block in self.down_blocks:
            x = down_block(x)
            skips.append(x)
        # Bottleneck
        x = self.bottleneck(x)
        # Upward path
        for up_block in self.up_blocks:
            if isinstance(up_block, nn.ConvTranspose2d):
                x = up_block(x)
            else:
                # Residual blocks consume the most recent unused skip tensor (LIFO order)
                skip_connection = skips.pop()
                x = torch.cat((x, skip_connection), dim=1).to(device)
                x = up_block(x)
        x = self.periodic_zeros_padding(x)
        x = self.out(self.leaky_relu(self.norm(x)))
        return x

ViT

In [None]:
class ViT(nn.Module):
    """Vision Transformer: patch embedding, positional embedding, ViTBlocks and an MLP prediction head.

    Args:
        in_channels: number of channels of the (already concatenated) input.
        out_channels: number of channels of the predicted image.
        img_size: (height, width) of the input image.
        processor: object whose `process(x)` prepares the raw input for the network.
        patch_size: side length of the square patches.
        embedding_dim: width of the patch tokens.
        depth: number of ViTBlocks.
        num_heads: attention heads per block.
        hidden_dimension: width of the hidden layers of the prediction head.
        mlp_ratio: expansion ratio of the MLP inside each block.
        prediction_depth: number of hidden (Linear + LeakyReLU) layers in the prediction head.
        drop_path_rate: drop-path rate passed to every block.
        dropout_rate: dropout rate used after the positional embedding and inside the blocks.
        learn_positional_embedding: whether the positional embedding receives gradients.
    """

    def __init__(self, in_channels, out_channels, img_size, processor, patch_size=2, embedding_dim=128, depth=8,
                 num_heads=4, hidden_dimension=128, mlp_ratio=4, prediction_depth=2, drop_path_rate=0.1,
                 dropout_rate=0.1, learn_positional_embedding=False):
        super().__init__()
        self.img_size = img_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.processor = processor
        # Every target passed to __call__ is collected here
        self.set_climatology = []
        # Patch Embedding
        self.patch_embedding = PatchEmbed(self.img_size, patch_size, self.in_channels, embedding_dim).to(device)
        self.num_patches = self.patch_embedding.num_patches
        # Position Embedding. The table is already created on `device`; calling .to(device) on the
        # Parameter itself could return a plain tensor and silently drop it from the module's parameters.
        self.pos_embed = nn.Parameter(self.trigonometric_pos(embedding_dim),
                                      requires_grad=learn_positional_embedding)
        self.pos_drop = nn.Dropout(dropout_rate).to(device)
        self.vit_block = nn.ModuleList(
            [ViTBlock(embedding_dim, num_heads, mlp_ratio, dropout_rate, drop_path_rate) for _ in range(depth)]
        ).to(device)
        # Prediction head (MLP Head). Its input is the token width (embedding_dim); the original used
        # hidden_dimension for the first layer too, which only worked when both values were equal.
        prediction_layers = []
        head_in = embedding_dim
        for _ in range(prediction_depth):
            prediction_layers.append(nn.Linear(head_in, hidden_dimension).to(device))
            prediction_layers.append(nn.LeakyReLU().to(device))
            head_in = hidden_dimension
        prediction_layers.append(nn.Linear(head_in, out_channels * patch_size ** 2).to(device))
        self.prediction_head = nn.Sequential(*prediction_layers).to(device)

    def __call__(self, data, six_hours_ago, twelve_hours_ago, target, constants):
        """Concatenate the inputs along axis 1, run `forward` directly and record `target`."""
        current_data = np.concatenate((constants, data), axis=1)
        pred = self.forward(np.concatenate((current_data, six_hours_ago, twelve_hours_ago), axis=1))
        self.set_climatology.append(target)
        return pred

    def forward(self, x):
        """Return the predicted image for the raw input `x`."""
        x = self.processor.process(x)
        # Patch embedding
        x = self.patch_embedding(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        # Encoder
        for vit_block in self.vit_block:
            x = vit_block(x)
        # Prediction head
        x = self.prediction_head(x)
        x = return_to_image(x, self.patch_size, self.out_channels, self.img_size)
        return x

    @staticmethod
    def _sinusoidal_table(num_positions, embedding_dim, n=10000):
        """Return a (1, num_positions, embedding_dim) table of sinusoidal position encodings.

        table[0, pos, 2k] = sin(pos / n^(2k / embedding_dim)) and
        table[0, pos, 2k + 1] = cos(pos / n^(2k / embedding_dim)). With an odd
        `embedding_dim` the last column is left at zero.
        """
        half = embedding_dim // 2
        positions = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
        exponents = torch.arange(half, dtype=torch.float32) * 2 / embedding_dim
        angles = positions / torch.pow(float(n), exponents)
        table = torch.zeros(1, num_positions, embedding_dim)
        table[0, :, 0:2 * half:2] = torch.sin(angles)
        table[0, :, 1:2 * half:2] = torch.cos(angles)
        return table

    def trigonometric_pos(self, embedding_dim, n=10000):
        """Build the sinusoidal positional embedding for all `self.num_patches` patches on `device`.

        The original loops always used index 0 for the frequency and the encoding
        column, so only columns 0 and 1 of the first `embedding_dim / 2` patches were
        ever written; every patch now gets a full encoding.
        """
        return self._sinusoidal_table(self.num_patches, embedding_dim, n).to(device)


class ViTBlock(nn.Module):
    """Transformer encoder block: layer norm, multi-head self-attention and an MLP, with drop-path.

    Args:
        embedding_dim: width of the tokens.
        num_heads: number of attention heads.
        mlp_ratio: the MLP hidden width is `embedding_dim * mlp_ratio`.
        dropout_rate: dropout used inside the attention and the MLP.
        drop_path_rate: drop-path rate; 0 disables drop-path.
    """

    def __init__(self, embedding_dim, num_heads, mlp_ratio, dropout_rate, drop_path_rate):
        super().__init__()
        self.norm1 = nn.LayerNorm(embedding_dim).to(device)

        self.query = nn.Linear(embedding_dim, embedding_dim).to(device)
        self.key = nn.Linear(embedding_dim, embedding_dim).to(device)
        self.value = nn.Linear(embedding_dim, embedding_dim).to(device)

        # batch_first=True: the tokens arrive as (batch, patches, embedding_dim). With the default
        # (batch_first=False) axis 0 is treated as the sequence, so attention mixed the samples of a
        # batch instead of the patches of one image.
        self.mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, dropout=dropout_rate,
                                         batch_first=True).to(device)
        if drop_path_rate > 0:
            self.drop_path1 = DropPath(drop_path_rate)
        else:
            self.drop_path1 = nn.Identity().to(device)

        self.norm2 = nn.LayerNorm(embedding_dim).to(device)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, int(embedding_dim * mlp_ratio)).to(device),
            nn.LeakyReLU().to(device),
            nn.Dropout(dropout_rate).to(device),
            nn.Linear(int(embedding_dim * mlp_ratio), embedding_dim).to(device),
            nn.LeakyReLU().to(device),
            nn.Dropout(dropout_rate).to(device)
        ).to(device)
        if drop_path_rate > 0:
            self.drop_path2 = DropPath(drop_path_rate)
        else:
            self.drop_path2 = nn.Identity().to(device)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Apply the block to tokens `x` of shape (batch, patches, embedding_dim).

        `attn_mask` and `key_padding_mask` are forwarded to nn.MultiheadAttention
        (they were previously accepted but ignored); both default to no masking.
        """
        # NOTE(review): the attention residual is added to the normalised tokens, not to the block
        # input as in the usual pre-norm layout; kept as in the original — confirm this is intended.
        x = self.norm1(x)
        attended, _ = self.mha(query=self.query(x), key=self.key(x), value=self.value(x),
                               key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        x = x + self.drop_path1(attended)
        x = x + self.drop_path2(self.mlp(self.norm2(x)))

        return x

Linear Regression (Baseline)

In [None]:
class LinearRegression(nn.Module):
    """Baseline model: one linear layer applied to the flattened input.

    Args:
        in_channels: Number of input features of the linear layer.
        out_channels: Number of output features of the linear layer.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x):
        """Flatten ``x`` per sample, project it, and restore a spatial layout.

        The batch size is read from dim 0 and the spatial size from dims 3 and
        4 of ``x``; the channel count of the result is inferred by ``reshape``.
        """
        n_samples = x.shape[0]
        rows, cols = x.shape[3], x.shape[4]
        projected = self.linear(x.flatten(1))
        return projected.reshape(n_samples, -1, rows, cols)

Generic class 'Network'

In [None]:
class Network:
    """Bundle one model with its optimizer, LR schedulers and early-stopping state.

    Args:
        Net: Model to train; it must expose ``parameters()``. ``eval_step``
            additionally calls ``Net.Net(...)``.
        lat_weights: Latitude weights handed to ``check`` when computing the loss.
        name: Model family name, used only in log messages.
        lead_time: Forecast lead time, used only in log messages.
    """

    def __init__(self, Net, lat_weights, name, lead_time):
        self.Net = Net
        self.name = name
        self.lead_time = lead_time
        self.Net_optimizer = optim.AdamW(self.Net.parameters(), lr=5e-4, weight_decay=1e-5)
        self.Net_linearLR, self.Net_cos_annLR = lr_schedulers(self.Net_optimizer)
        # Year/hour ranges are copied from the module-level configuration.
        self.low_year_train, self.max_year_train = low_year_train, max_year_train
        self.low_hour_train, self.max_hour_train = low_hour_train, max_hour_train
        self.low_year_val, self.max_year_val = low_year_val, max_year_val
        self.low_hour_val, self.max_hour_val = low_hour_val, max_hour_val
        self.lat_weights = lat_weights
        self.done = False       # True once early stopping has fired
        self.count = 0          # counter threaded through EarlyStopping
        self.loss = None        # loss of the latest training batch
        self.validation_losses = []
        # None until a full epoch of validation losses is available. It must
        # never alias `validation_losses`, otherwise early stopping would
        # compare an epoch with itself.
        self.previous_validation_losses = None
        self.pred, self.targ = [], []

    def pre_steps(self):
        """Reset the per-epoch state, keeping the last epoch's validation losses."""
        if self.done:
            return
        # NOTE(review): resetting `count` every epoch means EarlyStopping always
        # receives 0 — confirm this matches the helper's patience logic.
        self.count = 0
        self.loss = None
        # Hand the finished epoch's losses over before starting a fresh list.
        self.previous_validation_losses = self.validation_losses or None
        self.validation_losses = []

    def train_step(self, input_list, constants):
        """Compute the training loss for one batch and store it in ``self.loss``."""
        # input_list = [data, six_hours_ago, twelve_hours_ago, target]
        if not self.done:
            self.loss = check(self.Net, input_list[0], input_list[1], input_list[2], input_list[3], constants,
                              self.lat_weights)

    def val_step(self, input_list, constants):
        """Compute the validation loss for one batch and record its scalar value."""
        # input_list = [data, six_hours_ago, twelve_hours_ago, target]
        if not self.done:
            loss = check(self.Net, input_list[0], input_list[1], input_list[2], input_list[3], constants,
                         self.lat_weights)
            self.validation_losses.append(loss.item())

    def post_steps(self, epoch):
        """Back-propagate the stored loss, update the weights and step the LR schedule."""
        # Nothing to optimise when training has stopped or no loss was computed.
        if self.done or self.loss is None:
            return
        # Optimization
        self.Net_optimizer.zero_grad()
        self.loss.mean().backward()
        self.Net_optimizer.step()
        # Learning Rate update
        # Linear warmup schedule if 'epoch < 5' - Cosine-annealing schedule 'else'.
        # The scheduler writes the new rate into the optimizer's param groups
        # itself, so `optimizer.defaults` is left untouched.
        scheduler = self.Net_linearLR if epoch < 5 else self.Net_cos_annLR
        scheduler.step()

    def stopping(self):
        """Update ``count`` and ``done`` by comparing this epoch's mean validation loss to the last one."""
        if self.done:
            # Already stopped: the stored losses are stale, keep the decision.
            return
        if not self.previous_validation_losses or not self.validation_losses:
            # No complete pair of epochs to compare yet.
            self.count, self.done = 0, False
        else:
            self.count, self.done = EarlyStopping(statistics.mean(self.validation_losses),
                                                  statistics.mean(self.previous_validation_losses), self.count)
        if self.done:
            print("{}_{} is done".format(self.name, self.lead_time))

    def eval_step(self, data, six_hours_ago, twelve_hours_ago, target, constants):
        """Run the model on one batch and store prediction and target as NumPy arrays."""
        with torch.no_grad():
            p = self.Net.Net(data, six_hours_ago, twelve_hours_ago, target, constants)
        # `.cpu()` is required before `.numpy()` when the tensors live on the GPU.
        self.pred.append(p.detach().cpu().numpy())
        self.targ.append(target.detach().cpu().numpy())

## Metrics
Contains the loss function and metrics used for evaluation.

In [None]:
def latitude_weighting_function(latitude_coordinates):
    """Return cosine-of-latitude weights normalised to have mean 1.

    Args:
        latitude_coordinates: 1-D array of 'H' latitudes in degrees.

    Returns:
        CPU tensor of shape (1, 1, H, 1), ready to broadcast against
        (batch, channel, H, W) fields.
    """
    cos_lat = np.cos(np.deg2rad(np.asarray(latitude_coordinates)))
    latitude_weights = cos_lat / cos_lat.mean()
    # The result has always been returned on the CPU, so build it there
    # directly instead of copying it to the accelerator and back.
    return torch.from_numpy(latitude_weights).view(1, 1, -1, 1)


def loss_function(prediction, target, latitude_weights):                            # LW_MSE
    """Latitude-weighted mean squared error.

    The weighted squared error is averaged over dims 0, 2 and 3 to obtain one
    value per channel, and the channel values are then averaged.

    Args:
        prediction: Predicted fields; dims 0, 2 and 3 are reduced first, dim 1 last.
        target: Reference fields, same shape as ``prediction``.
        latitude_weights: Weights that broadcast against ``prediction``.

    Returns:
        Scalar (0-dim) tensor on the device of ``prediction``.
    """
    # The weights are created on the CPU; move them next to the prediction so
    # the product does not fail with a device mismatch on GPU.
    weights = torch.as_tensor(latitude_weights, device=prediction.device)
    error = weights * torch.square(prediction - target)
    per_channel = torch.mean(error, dim=[0, 2, 3])
    return torch.mean(per_channel, dim=0)


def LW_RMSE(prediction, target, latitude_weights):
    """Latitude-weighted root mean squared error, one value per channel.

    Args:
        prediction: Sequence of equally shaped per-batch arrays; stacked they
            form a 5-D array whose last two dims are averaged before the root.
        target: Sequence of arrays matching ``prediction``.
        latitude_weights: Weights that broadcast against the stacked arrays.

    Returns:
        NumPy array obtained by averaging the per-sample RMSE over dims 1 and 0
        of the stacked input.
    """
    # Convert everything to CPU tensors up front: the reductions below
    # (`mean` over a list of dims, `sqrt`) are tensor methods and do not work
    # on a NumPy / torch mixture.
    pred = torch.as_tensor(np.asarray(prediction))
    targ = torch.as_tensor(np.asarray(target))
    weights = torch.as_tensor(latitude_weights).cpu()
    error = weights * torch.square(pred - targ)
    channel_rmse = error.mean([3, 4]).sqrt().mean(1)
    result = channel_rmse.mean(0)
    return result.numpy()


def LW_ACC(prediction, target, latitude_weights, climatology):
    """Latitude-weighted anomaly correlation coefficient, one value per channel.

    Args:
        prediction: Sequence of equally shaped per-batch arrays whose last
            three dims are (channel, H, W).
        target: Sequence of arrays matching ``prediction``.
        latitude_weights: Weights that broadcast against (samples, H, W) fields.
        climatology: Array-like that is averaged over its first axis and then
            subtracted from both ``prediction`` and ``target``.

    Returns:
        NumPy array with one correlation value per channel.
    """
    climatology = np.asarray(climatology).mean(0)
    pred = torch.as_tensor(np.asarray(prediction) - climatology)
    targ = torch.as_tensor(np.asarray(target) - climatology)
    # The per-batch arrays stack into (num_batches, batch, channel, H, W).
    # Merge every leading axis into one sample axis so that axis 1 really is
    # the channel axis; without this the loop below would run over the batch.
    pred = pred.reshape(-1, *pred.shape[-3:])
    targ = targ.reshape(-1, *targ.shape[-3:])
    weights = torch.as_tensor(latitude_weights).cpu()

    channel_acc = []
    for channel in range(pred.shape[1]):
        pred_prime = pred[:, channel] - pred[:, channel].mean()
        target_prime = targ[:, channel] - targ[:, channel].mean()
        numer = (weights * pred_prime * target_prime).sum()
        denom_1 = (weights * torch.square(pred_prime)).sum()
        denom_2 = (weights * torch.square(target_prime)).sum()
        channel_acc.append(numer / torch.sqrt(denom_1 * denom_2))
    return torch.stack(channel_acc).numpy()


def compute_eval(prediction, target, latitude_weights, set_climatology):
    """Return the pair ``(LW_RMSE, LW_ACC)`` for the given predictions and targets."""
    return (
        LW_RMSE(prediction, target, latitude_weights),
        LW_ACC(prediction, target, latitude_weights, set_climatology),
    )

## Forecasting
This class manages the 15 neural networks and the baselines, together with their training and evaluation. The networks are of three types: 5 ResNets, 5 UNets and 5 ViTs; there is one network of each type for each of the 5 forecasting lead times: [6, 24, 72, 120, 240] hours. The baseline is a Linear Regression model, which is likewise instantiated once per lead time.

In [None]:
class Forecasting(nn.Module):
    """Build, train and evaluate one model per family and per lead time.

    Four families are created (linear-regression baseline, ResNet, UNet, ViT),
    each with one ``Network`` wrapper per entry of ``LEAD_TIMES``. Every wrapper
    stays reachable as an attribute named ``<family>_<lead>`` (``ResNet_6``,
    ``ViT_240``, ``LinReg_Baseline_24``, ...).

    Args:
        constants_set: Constant fields, replicated once per sample of each batch.
        train_data: Training series, sliced along its first axis into
            (input, 6 steps earlier, 12 steps earlier, target) windows.
        validation_data: Validation series, sliced the same way.
        batch_size: Batch size of the training and validation loaders.
        res_params: ``(hidden_channels, num_blocks)`` for the ResNets.
        u_params: ``(hidden_channels, blocks)`` for the UNets.
        vit_params: ``(depth, num_heads, prediction_depth, embedding_dim)`` for
            the ViTs; the last value is also used as ``hidden_dimension``.
    """

    LEAD_TIMES = (6, 24, 72, 120, 240)
    # Channels of the target tensor that are evaluated; they are plotted as
    # 't2m', 'Z500' and 'T850', in this order.
    TARGET_CHANNELS = (4, 9, 33)
    PLOT_VARIABLES = ('t2m', 'Z500', 'T850')
    # Number of leading time steps of the validation series that are used.
    VALIDATION_STEPS = 4380
    # Families whose `done` flags decide when training stops early.
    EARLY_STOP_FAMILIES = ('ResNet', 'UNet', 'ViT')

    def __init__(self, constants_set, train_data, validation_data, batch_size=128,
                 res_params=(128, 28), u_params=(64, 2), vit_params=(8, 4, 2, 128)):
        # The defaults used to be written `list[128, 28]`, which is a type
        # alias rather than a list and cannot be indexed; they are now tuples.
        super(Forecasting, self).__init__()
        self.device = device
        self.num_channels = 141
        self.out_channels = 3
        self.batch_size = batch_size
        self.img_size = (self.height, self.width) = (32, 64)
        self.val_dim = len(validation_data)
        self.processor = PreProcessing()
        self.latitude_coordinates = latitude_coordinates
        self.latitude_weights = latitude_weighting_function(self.latitude_coordinates)
        self.constants = constants_set
        self.validation_data = validation_data

        train_years = max_year_train + 1 - low_year_train
        train_steps = 8760 * train_years
        train_set = CustomDataset(*(self._lead_time_windows(train_data, train_steps, lead)
                                    for lead in self.LEAD_TIMES))
        self.train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
        validation_set = CustomDataset(*(self._lead_time_windows(validation_data, self.VALIDATION_STEPS, lead)
                                         for lead in self.LEAD_TIMES))
        self.validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=batch_size, shuffle=True)

        # (attribute prefix, name shown in logs, factory). Families and lead
        # times are instantiated in a fixed order so weight initialisation
        # consumes the random generator in a reproducible sequence.
        builders = (
            ('LinReg_Baseline', 'LinearRegression',
             lambda: LinearRegression(self.num_channels, self.out_channels)),
            ('ResNet', 'ResNet',
             lambda: ResNet(self.num_channels, self.out_channels, self.processor,
                            hidden_channels=res_params[0], num_blocks=res_params[1])),
            ('UNet', 'UNet',
             lambda: UNet(self.num_channels, self.out_channels, self.processor,
                          hidden_channels=u_params[0], blocks=u_params[1])),
            ('ViT', 'ViT',
             lambda: ViT(self.num_channels, self.out_channels, self.img_size, self.processor,
                         embedding_dim=vit_params[3], depth=vit_params[0], num_heads=vit_params[1],
                         prediction_depth=vit_params[2], hidden_dimension=vit_params[3])),
        )
        # family prefix -> {lead time -> Network}
        self._families = {}
        for prefix, label, build in builders:
            members = {}
            for lead in self.LEAD_TIMES:
                network = Network(build(), self.latitude_weights, label, lead)
                setattr(self, "{}_{}".format(prefix, lead), network)
                members[lead] = network
            self._families[prefix] = members

    @staticmethod
    def _lead_time_windows(data, total_steps, lead_time):
        """Slice `data` into [input, 6 steps earlier, 12 steps earlier, target] for one lead time."""
        return [data[12:total_steps - lead_time],
                data[6:total_steps - lead_time - 6],
                data[0:total_steps - lead_time - 12],
                data[12 + lead_time:total_steps]]

    def _named_networks(self):
        """Yield (attribute name, Network) pairs, family by family and lead time by lead time."""
        for prefix, members in self._families.items():
            for lead, network in members.items():
                yield "{}_{}".format(prefix, lead), network

    def train_forecasters(self, epochs=50):
        """Train every model for at most `epochs` epochs, validating after each one.

        Training ends early once every network of the families listed in
        ``EARLY_STOP_FAMILIES`` reports ``done``.
        """
        print("Start Training")
        for epoch in range(epochs):
            print("Epoch ", epoch)
            # The baselines take part in the reset and in the optimisation
            # step as well; without that they would never be fitted.
            for _, network in self._named_networks():
                network.pre_steps()

            print("   Train")
            for num, batches in enumerate(self.train_loader):
                printProgressAction('    Batch', num)
                lead_batches = list(batches)    # one entry per lead time
                batch_dim = len(lead_batches[0][0])
                train_constants = np.array((self.constants,) * batch_dim)
                for members in self._families.values():
                    for network, batch in zip(members.values(), lead_batches):
                        # Optimise right after the forward pass so that only one
                        # autograd graph is alive at any time. The models are
                        # independent, so this does not change the updates.
                        network.train_step(batch, train_constants)
                        network.post_steps(epoch)

            # Validation
            print("\n   Validation")
            # Only the scalar loss values are kept, so no graph is needed.
            with torch.no_grad():
                for num, batches in enumerate(self.validation_loader):
                    printProgressAction('    Batch', num)
                    lead_batches = list(batches)
                    batch_dim = len(lead_batches[0][0])
                    val_constants = np.array((self.constants,) * batch_dim)
                    for members in self._families.values():
                        for network, batch in zip(members.values(), lead_batches):
                            network.val_step(batch, val_constants)

            # EarlyStopping
            for _, network in self._named_networks():
                network.stopping()

            if all(network.done
                   for prefix in self.EARLY_STOP_FAMILIES
                   for network in self._families[prefix].values()):
                print("Stopped prematurely due to EarlyStopping")
                break
        print("\nEnd Training")

    def evaluate_forecasters(self, test_loader, constants_set):
        """Run every model on `test_loader`, compute RMSE and ACC, and plot them.

        Evaluation stops at the first batch whose size differs from the first
        batch, because the metrics stack all batches into a single array.
        """
        print("Start Evaluation")
        # Start from empty buffers so repeated evaluations do not accumulate.
        for _, network in self._named_networks():
            network.pred, network.targ = [], []

        channels = list(self.TARGET_CHANNELS)
        expected_batch_dim = None
        with torch.no_grad():
            for num, batches in enumerate(test_loader):
                lead_batches = list(batches)
                batch_dim = len(lead_batches[0][0])
                if expected_batch_dim is None:
                    expected_batch_dim = batch_dim
                elif batch_dim != expected_batch_dim:
                    break
                printProgressAction('    Batch', num)
                test_constants = np.array((constants_set,) * batch_dim)
                # One target selection per lead time, shared by all families.
                targets = [batch[3][:, channels, :, :] for batch in lead_batches]
                for members in self._families.values():
                    for network, batch, target in zip(members.values(), lead_batches, targets):
                        network.eval_step(batch[0], batch[1], batch[2], target, test_constants)

        # NOTE(review): every wrapped model, the baseline included, is expected
        # to expose `set_climatology` (and `Net`, used by `Network.eval_step`);
        # that contract is defined outside this class — confirm.
        metrics = {
            prefix: [compute_eval(network.pred, network.targ, self.latitude_weights,
                                  network.Net.set_climatology)
                     for network in members.values()]
            for prefix, members in self._families.items()
        }
        # family prefix -> (rmse, acc), both arranged by `create_list`.
        summary = {}
        for prefix, results in metrics.items():
            rmse = create_list([pair[0] for pair in results])
            acc = create_list([pair[1] for pair in results])
            summary[prefix] = (rmse, acc)

        for index, variable in enumerate(self.PLOT_VARIABLES):
            curves = []
            for prefix in self._families:
                rmse, acc = summary[prefix]
                curves.extend((rmse[index], acc[index]))
            plot(variable, *curves)

    def save(self):
        """Write the weights of every wrapped model to 'model.pt'.

        The ``Network`` wrappers are plain objects, so ``self.state_dict()``
        does not contain their weights; they are collected explicitly, keyed by
        attribute name. Optimizer and scheduler state is not saved.
        """
        checkpoint = {name: network.Net.state_dict() for name, network in self._named_networks()}
        torch.save(checkpoint, 'model.pt')

    def load(self):
        """Restore the weights written by `save` from 'model.pt'.

        Models that have no entry in the file are left untouched.
        """
        checkpoint = torch.load('model.pt', map_location=self.device)
        for name, network in self._named_networks():
            if name in checkpoint:
                network.Net.load_state_dict(checkpoint[name])

    def to(self, device):
        """Move the module to `device` and remember it in ``self.device``."""
        ret = super().to(device)
        ret.device = device
        return ret

## Pre
In this section, you can change the parameters of the models to be trained and evaluated. You can also set the number of epochs used for training.

In [None]:
# Reduced model sizes; the trailing comment on each line gives the larger
# reference configuration.
# res_params = [hidden_channels, num_blocks]
res_params = [32, 4]     # [128, 28]
# u_params = [hidden_channels, blocks]
u_params = [16, 1]       # [64, 2]
# vit_params = [depth, num_heads, prediction_depth, embedding_dim (also used as hidden_dimension)]
vit_params = [2, 2, 1, 32]     # [8, 4, 2, 128]

In [None]:
# Number of training epochs, passed as `epochs` to `train_forecasters`.
ep = 1

The next cell creates an object that contains the models to be trained, together with the methods required for training and evaluation.

In [None]:
# Build all models with the sizes chosen above and move the container to the
# selected device.
forecasters = Forecasting(
    constants_set,
    train_set,
    validation_set,
    res_params=res_params,
    u_params=u_params,
    vit_params=vit_params,
)
forecasters = forecasters.to(device)

## Training
Here the models are trained.

In [None]:
# Train every model for `ep` epochs.
forecasters.train_forecasters(epochs=ep)

## Evaluation
This part defines the test DataLoader and then evaluates the models

In [None]:
# For each lead time build [input, 6 steps earlier, 12 steps earlier, target]
# windows over the first 4380 steps of the test series.
t_6, t_24, t_72, t_120, t_240 = (
    [
        test_set[12:4380 - lead],
        test_set[6:4374 - lead],
        test_set[0:4368 - lead],
        test_set[12 + lead:4380],
    ]
    for lead in (6, 24, 72, 120, 240)
)
t = CustomDataset(t_6, t_24, t_72, t_120, t_240)
test_loader = torch.utils.data.DataLoader(dataset=t, batch_size=128, shuffle=True)

In [None]:
# Evaluate every model on the test set and plot RMSE / ACC per variable.
# (`Forecasting` defines `evaluate_forecasters`; `eva_forecasters` does not exist.)
forecasters.evaluate_forecasters(test_loader, constants_set)