In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchinfo import summary

import numpy as np
import xarray
import pickle
from pathlib import Path
import random
import copy
import math

import scores.spatial
import pysteps

import proplot as pplt

#Custom implementation of pytorch's dataset
class ERA5Dataset(Dataset):
    """In-memory ERA5 dataset split into model inputs and precipitation labels.

    The pickled array is loaded from ``<cwd>/../data_handling/<filename>`` and
    is laid out as (samples, channels, time, x, y). Along the time axis the last
    ``pred_steps`` entries become the label and everything before them is the
    input. Only channel 4 (precipitation) of the label is kept, with a
    singleton channel dimension re-inserted.

    Args:
        filename (str): Name of the pickled NumPy array in ``data_handling``.
        pred_steps (int): Number of trailing time steps used as the label
            (default 6, the value that used to be hard-coded).
        target_device (str | torch.device | None): Device the tensors are
            moved to. ``None`` uses the module-level ``device``.

    Raises:
        ValueError: If ``pred_steps`` does not leave at least one input step.
    """
    def __init__(self, filename, pred_steps=6, target_device=None):
        self.file_dir = "data_handling"

        # data is stored in a different directory, so go to the parent first,
        # then to the data directory
        file_path = Path.cwd().parent / self.file_dir / filename

        # pickle needs binary mode. NOTE: unpickling can execute arbitrary
        # code - only load trusted, locally produced files.
        with open(file_path, 'rb') as fp:
            np_list = pickle.load(fp)

        n_time = np_list.shape[2]
        if not 0 < pred_steps < n_time:
            raise ValueError(
                f"pred_steps must be between 1 and {n_time - 1}, got {pred_steps}")

        # (input steps, prediction steps, x, y). Time is stored in one axis, so
        # the input/label split has to be re-derived here.
        self.data_shape = (n_time - pred_steps, pred_steps, np_list.shape[3], np_list.shape[4])

        # Split into inputs and labels along time (dim 2); the last steps are the label
        ds = torch.from_numpy(np_list).to(torch.float32)
        self.data, self.label = torch.split(ds, [self.data_shape[0], self.data_shape[1]], dim=2)

        # Number of values in one prediction when flattened
        self.pred_size = self.data_shape[1] * self.data_shape[2] * self.data_shape[3]

        # As a label, we only want precipitation (channel 4); keep a channel dim of size 1
        self.label = torch.unsqueeze(self.label[:, 4, :, :, :], 1)
        print("Data Shape: Samples, channels, time, x, y")
        print("Data: " + str(self.data.shape))
        print("Label: " + str(self.label.shape))

        # Move everything to the compute device once, up front
        if target_device is None:
            target_device = device  # module-level device chosen in the setup cell
        self.data = self.data.to(target_device)
        self.label = self.label.to(target_device)

    def __len__(self):
        return self.label.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]

    def get_data_shape(self):
        """Return (input steps, prediction steps, x, y)."""
        return self.data_shape

    def get_pred_size(self):
        """Return the number of values in one flattened prediction."""
        return self.pred_size

learning_rate = 1e-3
batch_size = 100

# Prefer CUDA, then Apple's Metal backend (MPS), and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

# ERA5Dataset moves its tensors to `device`, so `device` must exist first.
train_data = ERA5Dataset("train_array_10_10000")
test_data = ERA5Dataset("test_array_10_1000")

data_shape = train_data.get_data_shape()

# No shuffle argument: batches are yielded in dataset order for both splits.
train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

class ResNet(nn.Module):
    """3-D convolutional residual network producing a one-channel forecast.

    Input is (batch, 5, time, x, y). Four residual stages widen the channels
    5 -> 10 -> 20 -> 40 -> 80 at full resolution. Three down-sampling stages
    (3x3x3 conv + ReLU + reflection pad + average pool with stride 2 on the
    time axis) then halve the time axis, each followed by a residual stage:
    80 -> 40 -> 20, 20 -> 10 -> 10, 10 -> 5 -> 5. A final 3x3x3 convolution
    maps 5 -> 1 channel and a ReLU keeps the output non-negative. Spatial size
    is preserved (reflection padding of 1 before every 3x3x3 operation), so 48
    input time steps yield 6 output steps.

    Args:
        data_shape: Unused; kept so existing ``ResNet(data_shape)`` calls work.

    Note:
        Attribute names are the ``state_dict`` keys of the saved
        ``model_state_*`` files - do not rename them.
    """
    def __init__(self, data_shape):
        super().__init__()
        features = 5

        self.relu = nn.ReLU()

        # Residual stages at full resolution; each res_conv is the 1x1x1
        # shortcut that matches the channel count of its res_net.
        self.res_net1 = self._res_layer_set(features, features*2)
        self.res_conv1 = nn.Conv3d(features, features*2, kernel_size=(1), padding =(0), stride = (1))

        self.res_net2 = self._res_layer_set(features*2, features*4)
        self.res_conv2 = nn.Conv3d(features*2, features*4, kernel_size=(1), padding =(0), stride = (1))

        self.res_net3 = self._res_layer_set(features*4, features*8)
        self.res_conv3 = nn.Conv3d(features*4, features*8, kernel_size=(1), padding =(0), stride = (1))

        self.res_net4 = self._res_layer_set(features*8, features*16)
        self.res_conv4 = nn.Conv3d(features*8, features*16, kernel_size=(1), padding =(0), stride = (1))

        # Down-sampling stages (time / 2), each followed by a residual stage
        self.conv_layer1 = self._conv_layer_set(features*16, features*8)

        self.res_net5 = self._res_layer_set(features*8, features*4)
        self.res_conv5 = nn.Conv3d(features*8, features*4, kernel_size=(1), padding =(0), stride = (1))

        self.conv_layer2= self._conv_layer_set(features*4, features*2)

        self.res_net6 = self._res_layer_set(features*2, features*2)
        self.res_conv6 = nn.Conv3d(features*2, features*2, kernel_size=(1), padding =(0), stride = (1))

        self.conv_layer3 = self._conv_layer_set(features*2, features)

        self.res_net7 = self._res_layer_set(features, features)
        self.res_conv7 = nn.Conv3d(features, features, kernel_size=(1), padding =(0), stride = (1))

        # Output head: 5 -> 1 channel
        self.refl_pad = nn.ReflectionPad3d(1)
        self.conv_layer4 = nn.Conv3d(features, 1, kernel_size=(3), padding=(0), stride = (1))

    def _res_layer_set(self, in_c, out_c):
        """Residual branch: two padded 3x3x3 convs with BatchNorm, in_c -> in_c -> out_c."""
        conv_stack = nn.Sequential(
            nn.ReflectionPad3d(1),
            nn.Conv3d(in_c, in_c, kernel_size=(3), padding = (0), stride = (1)),
            nn.BatchNorm3d(in_c),
            nn.ReLU(),
            nn.ReflectionPad3d(1),
            nn.Conv3d(in_c, out_c, kernel_size=(3), padding = (0), stride = (1)),
            nn.BatchNorm3d(out_c))
        return conv_stack

    def _conv_layer_set(self, in_c, out_c):
        """Padded 3x3x3 conv (in_c -> out_c) + ReLU, then an average pool that halves the time axis."""
        conv_layer = nn.Sequential(
            nn.ReflectionPad3d(1),
            nn.Conv3d(in_c, out_c, kernel_size=(3), padding=(0), stride = (1)),
            nn.ReLU(),
            nn.ReflectionPad3d(1),
            nn.AvgPool3d(kernel_size = (3), padding = (0), stride = (2, 1, 1))
        )
        return conv_layer

    def _residual(self, x, stack, shortcut):
        """Apply one residual step: ReLU(stack(x) + shortcut(x))."""
        res = shortcut(x)
        out = stack(x)
        return self.relu(out + res)

    def forward(self, x):
        # Residual stages at full time resolution, channels 5->10->20->40->80
        out = self._residual(x, self.res_net1, self.res_conv1)
        out = self._residual(out, self.res_net2, self.res_conv2)
        out = self._residual(out, self.res_net3, self.res_conv3)
        out = self._residual(out, self.res_net4, self.res_conv4)

        # Down-sample time (e.g. 48 -> 24), channels 80 -> 40, residual 40 -> 20
        out = self.conv_layer1(out)
        out = self._residual(out, self.res_net5, self.res_conv5)

        # Down-sample time (e.g. 24 -> 12), channels 20 -> 10, residual 10 -> 10
        out = self.conv_layer2(out)
        out = self._residual(out, self.res_net6, self.res_conv6)

        # Down-sample time (e.g. 12 -> 6), channels 10 -> 5, residual 5 -> 5
        out = self.conv_layer3(out)
        out = self._residual(out, self.res_net7, self.res_conv7)

        # Channels 5 -> 1; the final ReLU keeps the output non-negative
        out = self.refl_pad(out)
        out = self.conv_layer4(out)
        return self.relu(out)

# Build the network and move its parameters to the selected device;
# nn.Module.to returns the module itself, so `model` is the moved network.
model = ResNet(data_shape).to(device)

#summary(model, input_size=(batch_size, 5, 48, 10, 10))

class LogCoshLoss(nn.Module):
    """Mean log-cosh loss between two tensors, computed without overflow.

    Uses the identity log(cosh(e)) = e + softplus(-2e) - log(2). The naive
    ``log(cosh(e))`` overflows to ``inf`` in float32 once |e| exceeds ~89,
    which then poisons the gradient.
    """
    def __init__(self):
        super().__init__()

    def forward(self, y_t, y_prime_t):
        """Return mean(log(cosh(y_t - y_prime_t)))."""
        ey_t = y_t - y_prime_t
        return torch.mean(ey_t + nn.functional.softplus(-2.0 * ey_t) - math.log(2.0))

def log_cosh_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Return the mean log-cosh of ``y_pred - y_true``.

    log(cosh(r)) is evaluated as r + softplus(-2r) - log(2), which stays
    finite for large residuals where cosh itself would overflow.
    """
    residual = y_pred - y_true
    per_element = residual + torch.nn.functional.softplus(-2. * residual) - math.log(2.0)
    return torch.mean(per_element)

#As suggested by Ayzel 2020
class LogCoshLossOld(torch.nn.Module):
    """Module wrapper that delegates to ``log_cosh_loss(y_pred, y_true)``."""

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        return log_cosh_loss(y_pred, y_true)

# TODO (original note): "Make input non-static, sample from normal deviation"
class SpatiotemporalPrecipitationLoss(nn.Module):
    def __init__(self, weight_factor=10.0, threshold=0.1):
        """
        MSE loss with extra weight on pixels where the target exceeds a threshold.

        Per element the loss is ``mse * (1 + weight_factor)`` where
        ``target > threshold`` and plain ``mse`` elsewhere, averaged over all
        elements.

        Args:
            weight_factor (float): Extra weight added for pixels above the threshold.
            threshold (float): Target value above which a pixel counts as significant precipitation.
        """
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction='none')  # keep per-element errors
        self.weight_factor = weight_factor
        self.threshold = threshold

    def forward(self, pred, target):
        """
        Compute the weighted spatiotemporal MSE loss.

        Args:
            pred (torch.Tensor): Predicted precipitation, same shape as ``target``
                (in this notebook: batch, channels, time, x, y).
            target (torch.Tensor): Ground-truth precipitation.

        Returns:
            torch.Tensor: Scalar loss.
        """
        # Squared error for every element (pixel, timestep, sample)
        mse = self.mse_loss(pred, target)

        # 1.0 where the target shows significant precipitation, else 0.0
        weight_mask = (target > self.threshold).float()

        # Extra penalty only on the significant-precipitation elements
        weighted_mse = self.weight_factor * weight_mask * mse

        # Average base + extra penalty over every dimension
        return torch.mean(mse + weighted_mse)

def train_loop(dataloader, model, loss_fn, optimizer, epochs):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        dataloader: Yields ``(X, y)`` batches.
        model (nn.Module): Network being trained; left in train mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        epochs (int): Index of the current epoch (despite the name). On every
            5th epoch the fixed batch stored in ``train_10_batch`` is plotted.

    Returns:
        numpy.float64: Average of the per-batch losses.
    """
    # Set the model to training mode - important for batch normalization and dropout layers
    model.train()

    loss_list = []  # per-batch losses, averaged at the end

    for X, y in dataloader:
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_list.append(loss.item())

    if epochs % 5 == 0:
        # Read the fixed plotting batch only when it is actually used
        train_batch = read_file("train_10_batch")

        # Plot without autograd and in eval mode: a train-mode forward pass
        # would update the BatchNorm running statistics with this extra batch.
        model.eval()
        with torch.no_grad():
            pred = model(train_batch[0])
        model.train()

        # these are the final two timesteps (channel 4) before the prediction
        pre_plot = train_batch[0][:, 4, -2:, :, :].clone().detach()
        pred_plot = pred.clone().detach()
        y_plot = train_batch[1].clone().detach()

        plot_prediction(pre_plot, pred_plot, y_plot, epochs, "train_", 2)
        plot_prediction(pre_plot, pred_plot, y_plot, epochs, "train_", 5)

    return np.average(loss_list)


def test_loop(dataloader, model, loss_fn, epochs):
    """Evaluate the model on ``dataloader`` and return the mean per-batch loss.

    Args:
        dataloader: Yields ``(X, y)`` batches.
        model (nn.Module): Network to evaluate; left in eval mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        epochs (int): Index of the current epoch. On every 5th epoch the fixed
            batch stored in ``test_10_batch`` is plotted.

    Returns:
        float: Sum of batch losses divided by the number of batches.
    """
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    model.eval()
    num_batches = len(dataloader)
    test_loss = 0

    # No gradients are needed for evaluation; this also saves memory
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()

        # Always plot the same samples, reading them only when plotting
        if epochs % 5 == 0:
            test_batch = read_file("test_10_batch")
            pred = model(test_batch[0])
            # these are the final two timesteps (channel 4) before the prediction
            pre = test_batch[0][:, 4, -2:, :, :]
            plot_prediction(pre, pred, test_batch[1], epochs, "test_", 12)
            plot_prediction(pre, pred, test_batch[1], epochs, "test_", 23)

    test_loss /= num_batches
    print(f"Avg loss: {test_loss:>8f} \n")

    return test_loss

def train_model(epochs, loss, optim, name, stopping_epochs=50):
    """
    Run a full training session on the module-level ``model`` with early stopping.

    After every epoch the model state is saved to ``<cwd>/<name>/model_state_<t>``.
    The loss histories and the epoch counter are pickled into the same directory.

    Args:
        epochs (int): Maximum number of epochs to train.
        loss (nn.Module): The loss function to be used.
        optim (torch.optim.Optimizer): The optimizer to be used.
        name (str): Name of this run; also the output subdirectory.
        stopping_epochs (int): Number of consecutive non-improving epochs
            tolerated. Training stops on the next non-improving epoch after that.
    """
    train_loss_list = []
    test_loss_list = []
    best_test_loss = math.inf
    epochs_without_improvement = 0

    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        # Use the loss/optimizer passed in (previously the globals were used)
        train_loss_list.append(train_loop(train_dataloader, model, loss, optim, t))
        test_loss_list.append(test_loop(test_dataloader, model, loss, t))

        file_path = Path.cwd() / name / ("model_state_" + str(t))
        file_path.parent.mkdir(exist_ok=True, parents=True)  # create the subdir if it does not exist
        torch.save(model.state_dict(), file_path)

        write_file(train_loss_list, "train_loss_list", name)
        write_file(test_loss_list, "test_loss_list", name)
        write_file(t, "epoch_count", name)

        if t % 10 == 0:
            # NOTE(review): plot_train_test_loss is not defined in this notebook view; confirm it exists.
            plot_train_test_loss(train_loss_list, test_loss_list)

        # Early stopping: track the best test loss seen so far and how many
        # epochs have passed without beating it.
        current_test_loss = test_loss_list[-1]
        if current_test_loss < best_test_loss:
            best_test_loss = current_test_loss
            epochs_without_improvement = 0
        else:
            print("Model has not improved!")
            epochs_without_improvement += 1
            if epochs_without_improvement > stopping_epochs:
                print("Early stop!")
                break
    print("Done!")

def predict(modelname, epoch, data, index = -1):
    """
    Load a saved model state into the module-level ``model`` and predict ``data``.

    The state is read from ``<cwd>/<modelname>/model_state_<epoch>``. It is
    mapped onto the device the model's parameters live on, so states saved on
    CUDA also load on CPU/MPS. The model is left in eval mode. Samples are run
    through the model one at a time. The input is not truncated.

    Args:
        modelname (str): Name of the training run, e.g. "H" or "J".
        epoch (int): The epoch of the state to be loaded.
        data (torch.Tensor): Inputs, (batch, channels, time, x, y).
        index (int): If >= 0, only ``data[index]`` is predicted (batch dim kept).

    Returns:
        torch.Tensor: Prediction of shape (batch, 1, 6, x, y) on the model's device.
    """
    file_path = Path.cwd() / modelname / ("model_state_" + str(epoch))
    print(file_path)
    model_device = next(model.parameters()).device
    state = torch.load(file_path, map_location=model_device, weights_only=True)
    model.load_state_dict(state)
    model.eval()

    if index >= 0:
        data = torch.unsqueeze(data[index], 0)  # keep the batch dimension

    # The number of predicted time steps (6) is fixed by the training setup
    output = torch.zeros(data.shape[0], 1, 6, data.shape[3], data.shape[4], device=model_device)
    data = data.to(model_device)

    with torch.no_grad():
        for i in range(data.shape[0]):
            output[i] = model(torch.unsqueeze(data[i], 0))

    return output

def save_batch(dataloader, filename):
    """
    Pickle the first batch produced by ``dataloader`` so later runs can plot or
    evaluate exactly the same samples.

    Args:
        dataloader (torch.utils.data.DataLoader): Source of the batch.
        filename (str): Name of the file written via ``write_file``.
    """
    write_file(next(iter(dataloader)), filename)
    
# Report the epochs with the lowest stored train/test loss of a run.
def find_lowest_loss(model):
    '''
    Print the epoch indices at which a run reached its lowest train and test loss.

    Args:
        model (str): Name of the run, i.e. the subdirectory holding
            ``train_loss_list`` and ``test_loss_list``.
    '''
    histories = {kind: read_file(kind + "_loss_list", model) for kind in ("train", "test")}
    best_epochs = {kind: losses.index(min(losses)) for kind, losses in histories.items()}

    for kind, best_epoch in best_epochs.items():
        print("The best " + kind + " loss of model " + model + " is at epoch " + str(best_epoch))

def np_to_torch(filename, pred_steps=6):
    """
    Load a pickled ERA5 array and split it into model inputs and precipitation labels.

    The array is read from ``<cwd>/../data_handling/<filename>`` and is laid out
    as (samples, channels, time, x, y). The last ``pred_steps`` time steps become
    the label. Only channel 4 (precipitation) of the label is kept, with a
    singleton channel dim. Tensors stay on the CPU as float32.

    Args:
        filename (str): Name of the pickled NumPy array.
        pred_steps (int): Number of trailing time steps used as the label (default 6).

    Returns:
        list: ``[data, label]`` with shapes (N, C, T - pred_steps, x, y) and
        (N, 1, pred_steps, x, y).

    Raises:
        ValueError: If ``pred_steps`` does not leave at least one input step.
    """
    # data is stored in a different directory, so go to the parent first
    file_path = Path.cwd().parent / "data_handling" / filename

    # pickle needs binary mode. NOTE: only load trusted, locally produced files.
    with open(file_path, 'rb') as fp:
        np_list = pickle.load(fp)

    n_time = np_list.shape[2]
    if not 0 < pred_steps < n_time:
        raise ValueError(f"pred_steps must be between 1 and {n_time - 1}, got {pred_steps}")

    # Split along time (dim 2), with the last timesteps being the label
    ds = torch.from_numpy(np_list).to(torch.float32)
    data, label = torch.split(ds, [n_time - pred_steps, pred_steps], dim=2)

    # As a label, we only want precipitation (channel 4); keep a channel dim of size 1
    label = torch.unsqueeze(label[:, 4, :, :, :], 1)

    return [data, label]

def score_fss(target, prediction):
    """
    Calculate the fractions skill score (FSS) of ``prediction`` against ``target``
    for window sizes 1x1 and 4x4 and thresholds 0.1, 1 and 2.

    Args:
        target (tensor): The prediction target, (batch, prec, time, x, y).
        prediction (tensor): The prediction, (batch, prec, time, x, y).

    Returns:
        np.ndarray: FSS averaged over the batch, shape (window_sizes, thresholds) = (2, 3).

    Raises:
        ValueError: If target and prediction shapes differ.
    """
    # Drop only the singleton precipitation dim (dim 1). A bare squeeze() would
    # also drop a batch dim of size 1, and the loop below would then iterate
    # over time steps instead of samples.
    target = torch.squeeze(target, 1).cpu().detach().numpy()
    prediction = torch.squeeze(prediction, 1).cpu().detach().numpy()

    if target.shape != prediction.shape:
        raise ValueError(
            f"Shape mismatch: target {target.shape} vs prediction {prediction.shape}")

    window_sizes = [(1, 1), (4, 4)]
    thresholds = [0.1, 1, 2]

    # per-sample scores, averaged over the batch at the end
    fss = np.empty((target.shape[0], len(window_sizes), len(thresholds)))

    for b in range(target.shape[0]):
        # scores needs an xarray DataArray
        current_target = xarray.DataArray(data=target[b], dims=["time", "x", "y"])
        current_prediction = xarray.DataArray(data=prediction[b], dims=["time", "x", "y"])

        for w_idx, window in enumerate(window_sizes):
            for th_idx, th in enumerate(thresholds):
                fss[b, w_idx, th_idx] = scores.spatial.fss_2d(
                    fcst=current_prediction, obs=current_target,
                    event_threshold=th, window_size=window, spatial_dims=("x", "y")).to_numpy()

    return np.mean(fss, axis=0)
    
def score_rapsd(target, norm=False):
    """
    Calculate the radially averaged power spectral density (RAPSD) of a batch of fields.

    Each (sample, time step) field of channel 0 gets tiny Gaussian noise added
    (scale 1e-4). It then goes through ``pysteps.utils.spectral.rapsd``, and the
    spectra are averaged with NaNs ignored. The noise draws from NumPy's global
    RNG, so results vary slightly between calls.

    Args:
        target (tensor): Fields of shape (batch, prec, time, x, y).
        norm (bool): Forwarded to ``rapsd(normalize=...)``.

    Returns:
        np.ndarray of shape (2, n_freq):
            - row 0: wavelengths in km (27.8 km / frequency; the zero-frequency entry is ``inf``)
            - row 1: mean power in dB (10 * log10)

        n_freq depends on the grid size (5 for 10x10 fields).
    """
    fields = target.cpu().detach().numpy()

    spectra = []
    freq = None
    for sample in fields:
        for image in sample[0]:
            # NOTE(review): the noise presumably keeps all-zero fields from
            # giving degenerate spectra - confirm.
            noise = np.random.normal(loc=0, scale=0.0001, size=image.shape)
            spectrum, freq = pysteps.utils.spectral.rapsd(
                field=image + noise, fft_method=np.fft, return_freq=True, normalize=norm)
            spectra.append(spectrum)

    # Average over all batch elements and hours
    mean_power = np.nanmean(np.asarray(spectra), axis=0)
    # Convert to decibels (the previous log10 alone gave bels)
    power_db = 10 * np.log10(mean_power)

    # Convert frequency (cycles per grid cell) to wavelength. The ERA5 single-level
    # grid spacing is 27.8 km at the equator (see Ruzanski, 2010:2301).
    # freq[0] is 0, so its wavelength is inf.
    with np.errstate(divide="ignore"):
        wavelength = 27.8 / freq

    return np.transpose(np.column_stack((wavelength, power_db)))

def plot_rapsd(input, filename, height="7cm"):
    '''
    Plot one or more RAPSD curves on shared axes and save the figure.

    Args:
        input (list of tuples): Each entry is ``(rapsd, label)``. ``rapsd`` is a
            2-row array as returned by ``score_rapsd``: row 0 is wavelength (x),
            row 1 is power (y). ``label`` is the legend text.
        filename (str): File name inside ``<cwd>/plots``.
        height (str): Figure height; the width is fixed at 14cm.
    '''
    with pplt.rc.context(fontsize='11px'):
        fig, ax = pplt.subplot(xlabel='Wavelength (km)', ylabel='Power (db)', figheight=height, figwidth="14cm")

    for entry in input:
        ax.plot(entry[0][0], entry[0][1], label=entry[1])

    ax.legend(loc="lower left")

    # Save before showing: in non-interactive use show() may close the figure.
    file_path = Path.cwd() / "plots" / filename
    file_path.parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(file_path)
    pplt.show()
    pplt.close()

import pandas as pd

def plot_fss(fss_list, label, filename, title="", create_legend=True, height="5cm"):
    '''
    Plot several FSS tables as one grouped bar chart and save it.

    Args:
        fss_list (np.ndarray): Shape (window, threshold, n_series), e.g. several
            ``score_fss`` results stacked on axis 2. The window/threshold grid
            must be 2 x 3 to match the hard-coded row labels.
        label (list of str): Legend name for each series (last axis), in order.
        filename (str): File name inside ``<cwd>/plots``.
        title (str): Figure suptitle.
        create_legend (bool): Whether to draw the legend (location "b").
        height (str): Figure height; the width is fixed at 14cm.
    '''
    fss_shape = fss_list.shape
    # One row per (window, threshold) pair, one column per series
    data = np.reshape(fss_list, (fss_shape[0] * fss_shape[1], fss_shape[2]))

    # Use the `label` argument, not a global variable
    data = pd.DataFrame(
        data, columns=pd.Index(label, name='column'),
        index=pd.Index(["1x1, 0.1", "1x1, 1", "1x1, 2", "4x4, 0.1", "4x4, 1", "4x4, 2",], name='row idx'))

    with pplt.rc.context(fontsize='11px'):
        fig, ax = pplt.subplot(ylabel='Fractions Skill Score', xlabel="Window size (pixels) and threshold (mm/hour)", figheight=height, figwidth="14cm", suptitle=title)

    bar_kwargs = dict(cycle='Blues', edgecolor='black', autoformat=True, labels=label)
    if create_legend:
        bar_kwargs["legend"] = "b"
    ax.bar(data, **bar_kwargs)

    # Save before showing: in non-interactive use show() may close the figure.
    file_path = Path.cwd() / "plots" / filename
    file_path.parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(file_path)
    pplt.show()
    pplt.close()

#plotting
def plot_specific_prediction(data, model, epoch, titles, k=-1):
    """
    Predict ``data`` with one or more saved runs and plot the result.

    Args:
        data (list): ``[inputs, labels]`` as returned by ``np_to_torch``.
        model (list of str): Run names, e.g. ``["H"]`` or ``["H", "J", "K"]``.
        epoch (list of int): Epoch to load for each run, same order as ``model``.
        titles (list of str): Row labels, used only for the multi-model figure.
        k (int): Batch item to plot; negative/out of range picks a random one.
    """
    # The plotting helpers show the first n+2 frames of cat(pre, y), so pass
    # only the last two input steps of channel 4 (as train_loop/test_loop do).
    pre = data[0][:, 4, -2:]

    if len(model) == 1:
        # predict expects a single run name, not the list
        out = predict(model[0], epoch[0], data[0])
        plot_prediction(pre, out, data[1], epoch[0], "out", k)
    else:
        out = [predict(m, e, data[0]) for m, e in zip(model, epoch)]
        plot_prediction_multiple_models(pre, out, data[1], epoch, "out", titles, k)

def plot_prediction_multiple_models(pre, preds, y, epoch, prefix, titles, k=-1):
    """
    Plot the ground truth and the predictions of several models in one figure and save it.

    Row 1 shows ``pre`` followed by the ground truth, titled t-2, t-1, t+1, ...
    Each further row shows one model's prediction, with the first two cells blank.

    Args:
        pre (tensor): Steps preceding the prediction, (batch, steps, x, y). Only
            the first n+2 frames of cat(pre, y) are drawn, so pass the last two steps.
        preds (list of tensors): One prediction per model, (batch, 1, n, x, y).
        y (tensor): Ground truth, (batch, 1, n, x, y).
        epoch: Used only in the output file name.
        prefix (str): File-name prefix.
        titles (list of str): Row labels passed to ``fig.format(rowlabels=...)``.
        k (int): Batch item to plot; negative/out of range picks a random one.
    """
    # We might be using CUDA tensors, so copy the data to the CPU first
    y = y.cpu()
    pre = pre.cpu()

    # if index is invalid or empty, pick a random one
    if k < 0 or k >= y.shape[0]:
        k = random.randint(0, y.shape[0] - 1)

    # Only precipitation is left, so drop the singleton dims
    y = torch.squeeze(y[k])
    preds_np = [torch.squeeze(pred.cpu()[k]).numpy() for pred in preds]
    pre = pre[k]  # already has no channel dim

    y = torch.cat((pre, y)).numpy()  # preceding steps followed by the ground truth

    n = preds_np[0].shape[0]  # number of predicted steps
    n_cols = n + 2
    n_rows = 1 + len(preds)
    max_prec = 3

    print("rows: " + str(n_rows))
    print("cols: " + str(n_cols))

    # 14cm is the text width
    with pplt.rc.context(fontsize='10px'):
        fig, axes = pplt.subplots(nrows=n_rows, ncols=n_cols,
                                  hspace='1.7em', wspace='0.4em', figwidth='14cm', titlepad=7)
    fig.format(rowlabels=titles)
    axes.format(xticks='null', yticks='null', facecolor='gray5')

    # First row: preceding steps (t-2, t-1) and the ground truth (t+1 ...)
    for col, ax in enumerate(axes[:n_cols]):
        im = ax.pcolormesh(y[col], cmap='turbo', vmin=0, vmax=max_prec)
        if col > 1:
            ax.format(titleloc='center', title=("t+" + str(col - 1)))
        else:
            ax.format(titleloc='center', title=("t" + str(col - 2)))

    # One row per model; the first two cells stay empty
    for row, pred in enumerate(preds_np, start=1):
        for col, ax in enumerate(axes[row * n_cols:(row + 1) * n_cols]):
            if col > 1:
                im = ax.pcolormesh(pred[col - 2], cmap='turbo', vmin=0, vmax=max_prec)
            else:
                ax.axis('off')

    # Span the colorbar over every row instead of a hard-coded (1, 4)
    fig.colorbar(im, label='Precipitation in mm/h', loc='r', rows=(1, n_rows), space=1)

    # Save it to a subdirectory
    file_path = Path.cwd() / "plots" / (prefix + str(epoch) + "_prediction_" + str(k))
    file_path.parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(file_path)
    pplt.close()

def plot_prediction(pre, pred, y, epoch, prefix, k=-1, max_prec=3):
    """
    Plot the ground truth above one prediction and save the figure.

    Row 1 shows ``pre`` followed by the ground truth. Row 2 shows the prediction
    with the first two cells blank; it carries the t-2, t-1, t+1, ... titles.

    Args:
        pre (tensor): Steps preceding the prediction, (batch, steps, x, y). Only
            the first n+2 frames of cat(pre, y) are drawn, so pass the last two steps.
        pred (tensor): The prediction, (batch, 1, n, x, y).
        y (tensor): The ground truth, (batch, 1, n, x, y).
        epoch: Used only in the output file name.
        prefix (str): File-name prefix.
        k (int): Batch item to plot; negative/out of range picks a random one.
        max_prec (float): Upper limit of the colour scale in mm/h (default 3).
    """
    # We might be using CUDA tensors, so copy the data to the CPU first
    pred = pred.cpu()
    y = y.cpu()
    pre = pre.cpu()

    # if index is invalid or empty, pick a random one
    if k < 0 or k >= pred.shape[0]:
        k = random.randint(0, pred.shape[0] - 1)

    # Only precipitation is left, so drop the singleton dims
    y = torch.squeeze(y[k])
    pred = torch.squeeze(pred[k]).numpy()
    pre = pre[k]  # already has no channel dim

    y = torch.cat((pre, y)).numpy()  # preceding steps followed by the ground truth

    n = pred.shape[0]  # number of predicted steps
    n_cols = n + 2

    # 14cm is the text width
    with pplt.rc.context(fontsize='11px'):
        fig, axes = pplt.subplots(nrows=2, ncols=n_cols,
                                  hspace='2.4em', wspace='0.4em', figwidth='14cm', titlepad=7)
    fig.format(rowlabels=("Ground truth", "Prediction"))
    axes.format(xticks='null', yticks='null', facecolor='gray5')

    for col, ax in enumerate(axes[:n_cols]):
        im = ax.pcolormesh(y[col], cmap='turbo', vmin=0, vmax=max_prec)

    for col, ax in enumerate(axes[n_cols:]):
        if col > 1:
            im = ax.pcolormesh(pred[col - 2], cmap='turbo', vmin=0, vmax=max_prec)
            ax.format(titleloc='center', title=("t+" + str(col - 1)))
        else:
            # the first two prediction cells stay empty
            ax.axis('off')
            ax.format(titleloc='center', title=("t" + str(col - 2)))

    fig.colorbar(im, label='Precipitation in mm/h', loc='r', rows=(1, 2), space=1)

    # Save it to a subdirectory
    file_path = Path.cwd() / "plots" / (prefix + str(epoch) + "_prediction_" + str(k))
    file_path.parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(file_path)
    pplt.close()
    
def plot_losses(runs=(("H", "MSE"), ("J", "Weighted MSE"), ("K", "LogCoshLoss")), legend_panel=1):
    """
    Plot the training and test loss curves of several runs, one panel per run.

    The figure is saved as ``training`` in the current working directory.

    Args:
        runs (sequence of (str, str)): (run directory, panel title) pairs; the
            defaults are the three runs compared in this notebook.
        legend_panel (int): Index of the panel that gets the legend.
    """
    with pplt.rc.context(fontsize='11px'):
        fig, axes = pplt.subplots(nrows=len(runs), ncols=1, figheight='10cm', figwidth="14cm",
                                  xlabel='Epochs', ylabel='Average loss', abc=True)

    for i, (run, title) in enumerate(runs):
        train_l = read_file("train_loss_list", run)
        test_l = read_file("test_loss_list", run)
        axes[i].plot(train_l, label="training")
        axes[i].plot(test_l, label="test")
        axes[i].format(title=title, titleloc="uc")
        if i == legend_panel:
            axes[i].legend(loc="upper right")

    fig.savefig("training")
    pplt.show()
    pplt.close()

# Pickle an object to disk, optionally inside a subdirectory of the cwd.
def write_file(a_list, filename, subdir = ""):
    """Pickle ``a_list`` to ``filename``, or to ``<cwd>/<subdir>/<filename>``.

    When ``subdir`` is given, missing directories are created first.
    """
    if subdir == "":
        destination = filename
    else:
        destination = Path.cwd() / subdir / filename
        destination.parent.mkdir(exist_ok=True, parents=True)

    # pickle needs binary write mode
    with open(destination, 'wb') as fp:
        pickle.dump(a_list, fp)

# Load a pickled object from disk, optionally from a subdirectory of the cwd.
def read_file(filename, subdir = ""):
    """Unpickle and return the object stored in ``filename`` or ``<cwd>/<subdir>/<filename>``.

    NOTE: unpickling can execute arbitrary code - only read trusted files.
    """
    source = filename if subdir == "" else Path.cwd() / subdir / filename
    # pickle needs binary read mode
    with open(source, 'rb') as fp:
        return pickle.load(fp)

In [None]:
# Evaluation cell: loss curves, predictions on the validation split and on the
# warmed (PGW) inputs, then RAPSD and FSS comparison plots.
# Runs in the plots: H = "MSE", J = "Weighted MSE", K = "LogCoshLoss".
plot_losses()
find_lowest_loss("H")
find_lowest_loss("J")
find_lowest_loss("K")

# Each dataset is [inputs, labels] as returned by np_to_torch.
# pgw10 / pgw30 / pgw50 are the inputs labelled "1K" / "3K" / "5K" below.
# NOTE(review): pgw05, pgw20 and pgw40 are loaded but never used in this cell.
val = np_to_torch("test_array_10_1000")
pgw05 = np_to_torch("PGW_10_1000_05")
pgw10 = np_to_torch("PGW_10_1000_10")
pgw20 = np_to_torch("PGW_10_1000_20")
pgw30 = np_to_torch("PGW_10_1000_30")
pgw40 = np_to_torch("PGW_10_1000_40")
pgw50 = np_to_torch("PGW_10_1000_50")

# Validation predictions of each run at a fixed epoch.
# NOTE(review): epochs 13/36/20 are presumably the best-test-loss epochs reported above - confirm.
outH = predict("H", 13, val[0])
outJ = predict("J", 36, val[0])
outK = predict("K", 20, val[0])

#MSE
H10 = predict("H", 13, pgw10[0])
H30 = predict("H", 13, pgw30[0])
H50 = predict("H", 13, pgw50[0])
fssH10 = score_fss(pgw10[1], H10)
fssH30 = score_fss(pgw30[1], H30)
fssH50 = score_fss(pgw50[1], H50)

#predictions for weighted MSE
J10 = predict("J", 36, pgw10[0])
J30 = predict("J", 36, pgw30[0])
J50 = predict("J", 36, pgw50[0])
fssJ10 = score_fss(pgw10[1], J10)
fssJ30 = score_fss(pgw30[1], J30)
fssJ50 = score_fss(pgw50[1], J50)

#predictions for Logcoshloss
K10 = predict("K", 20, pgw10[0])
K30 = predict("K", 20, pgw30[0])
K50 = predict("K", 20, pgw50[0])
fssK10 = score_fss(pgw10[1], K10)
fssK30 = score_fss(pgw30[1], K30)
fssK50 = score_fss(pgw50[1], K50)

# score_rapsd adds random noise from NumPy's global RNG, so these curves are
# not bit-reproducible between runs.
# NOTE(review): the plot saved as "RAPSD" uses norm=True while "RAPSD-norm"
# uses the default norm=False - the file names look swapped; confirm.
plot_rapsd([(score_rapsd(val[1], True), "Ground truth"), 
            (score_rapsd(outH, True), "MSE"), (score_rapsd(outJ, True), "Weighted MSE"), 
            (score_rapsd(outK, True), "LogCoshLoss"),
            (score_rapsd(pgw50[1], True), "5K-PGW"), (score_rapsd(J50, True), "Weighted MSE, 5K")], "RAPSD", "5cm")

plot_rapsd([(score_rapsd(val[1]), "Ground truth"), 
            (score_rapsd(outH), "MSE"), (score_rapsd(outJ), "Weighted MSE"), 
            (score_rapsd(outK), "LogCoshLoss"),
            (score_rapsd(pgw50[1]), "5K-PGW"), (score_rapsd(J50), "Weighted MSE, 5K")], "RAPSD-norm", "5cm")

plot_rapsd([(score_rapsd(pgw30[1], True), "PGW output, 3K"), 
            (score_rapsd(H30, True), "MSE"), (score_rapsd(J30, True), "Weighted MSE"),
            (score_rapsd(K30, True), "LogCoshLoss")], "RAPSD-3K", "5cm")

plot_rapsd([(score_rapsd(pgw50[1], True), "PGW output, 5K"), 
            (score_rapsd(H50, True), "MSE"), (score_rapsd(J50, True), "Weighted MSE"),
            (score_rapsd(K50, True), "LogCoshLoss")], "RAPSD-5K", "5cm")

fssH = score_fss(val[1], outH)
fssJ = score_fss(val[1], outJ)
fssK = score_fss(val[1], outK)

# Each score_fss result is (window, threshold); stacking on axis 2 gives
# (window, threshold, series) for plot_fss. `labels` holds the legend names
# of the series for the following plot_fss call.
fss_np = np.stack((fssH,fssJ,fssK), axis=2)
labels = ["MSE", "WMSE", "LogCoshLoss"]
plot_fss(fss_np, labels, "model_fss", "", True, "5.6cm")
fss_np = np.stack((fssH,fssH10,fssH30,fssH50), axis=2)
labels = ["Base", "1K", "3K", "5K"]
plot_fss(fss_np, labels, "warming_fss_H", "MSE", False)
fss_np = np.stack((fssJ,fssJ10,fssJ30,fssJ50), axis=2)
labels = ["Base", "1K", "3K", "5K"]
plot_fss(fss_np, labels, "warming_fss_J", "WMSE", False)
fss_np = np.stack((fssK,fssK10,fssK30,fssK50), axis=2)
labels = ["Base", "1K", "3K", "5K"]
plot_fss(fss_np, labels, "warming_fss_K", "LogCoshLoss", True, "6.3cm")

#And now for the overfitted model:
#predictions for weighted MSE
# (run J at its last saved epoch, 199)
outJO = predict("J", 199, val[0])
fssJO = score_fss(val[1], outJO)
J10O = predict("J", 199, pgw10[0])
J30O = predict("J", 199, pgw30[0])
J50O = predict("J", 199, pgw50[0])
fssJ10O = score_fss(pgw10[1], J10O)
fssJ30O = score_fss(pgw30[1], J30O)
fssJ50O = score_fss(pgw50[1], J50O)

fss_np = np.stack((fssJO, fssJ10O, fssJ30O, fssJ50O), axis=2)
labels = ["Base", "1K", "3K", "5K"]
plot_fss(fss_np, labels, "warming_j_fss_overfit", "WMSE, overfit", True, "6cm")

In [None]:
# Train run "K": Adam over the global model's parameters with the log-cosh loss.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = LogCoshLossOld()
train_model(epochs=200, loss=loss_fn, optim=optimizer, name="K")