In [1]:
# !pip install SciencePlots -q
import matplotlib.pyplot as plt
import matplotlib
import scienceplots
from itertools import product
import numpy as np 
import os, re
from tqdm.notebook import tqdm
from numba import njit
from scipy.interpolate import interp1d
from math import ceil 
from time import perf_counter
from sklearn.model_selection import train_test_split

def max_scaling(x):
    """Divide ``x`` by its own maximum value.

    Args:
        x: Object exposing ``.max()`` and supporting division by the result
            (for example a NumPy array).

    Returns:
        ``x / x.max()``.
    """
    return x / x.max()

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import weight_norm
from torch import zeros, tensor, Tensor, rand
from torch.optim.lr_scheduler import ReduceLROnPlateau
print(f"PyTorch version: {torch.__version__}")

# Choose the compute device from the number of CUDA GPUs that PyTorch reports.
# `n_gpus` and `device` are left as module-level names for later cells.
_visible_gpus = torch.cuda.device_count()
if _visible_gpus >= 1:
    n_gpus = _visible_gpus
    # Plural label only when more than one GPU is present.
    print("Using", n_gpus, "GPUs" if n_gpus > 1 else "GPU")
    device = "cuda"
else:
    n_gpus = 0
    print("Using CPU")
    device = "cpu"


# Matplotlib: apply the two style stacks in order (the second call is applied on
# top of the first), then override a few rcParams for figures in this notebook.
plt.style.use(["notebook", "science"])
plt.style.use(["notebook", "nature"])
plt.rcParams.update(
    {
        "figure.figsize": [10, 5],
        "figure.dpi": 200,
        "lines.linewidth": 2,
    }
)

# NumPy: wider printed lines and three digits of precision.
np.set_printoptions(linewidth=200, precision=3)

PyTorch version: 2.0.0
Using CPU


In [4]:
# Random placeholder tensors. The model classes in this notebook read their default
# sizes from these shapes: X.shape[1] as the sequence length, X.shape[-1] as the
# input size and y.shape[-1] as the output size.
X, y = rand(1, 80, 11), rand(1, 80, 48)

# FNO

In [5]:
class FNO(nn.Module):
    """Fourier-neural-operator style network.

    The input is lifted to ``hidden_size`` channels with a linear layer, passed
    through ``layers`` blocks that blend a spectral branch with a pointwise
    branch, and finally projected to ``output_size`` with another linear layer.

    Args:
        hidden_size: Channel width used inside the blocks.
        modes: Requested number of Fourier modes; capped at
            ``X.shape[1] // 2 + 1`` using the module-level tensor ``X``.
        layers: Number of blocks.
        input_size: Size of the last input dimension.
        output_size: Size of the last output dimension.
    """

    def __init__(self, hidden_size, modes, layers=1, 
                 input_size=X.shape[-1], 
                 output_size=y.shape[-1],
        ):
        super().__init__()

        # Never keep more modes than the rFFT of the reference sequence provides.
        mode_cap = X.shape[1] // 2 + 1
        self.modes = min(modes, mode_cap)
        self.n_layers = layers
        self.activation = F.gelu

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # Lifting layer, applied on the last dimension.
        self.p = nn.Linear(self.input_size, self.hidden_size)

        self.spectral = nn.ModuleList()
        self.temporal = nn.ModuleList()
        self.residual = nn.ModuleList()

        width = self.hidden_size
        for _ in range(self.n_layers):
            self.spectral.append(SpectralConv1d(width, width, self.modes))
            self.temporal.append(TemporalConv1d(width, width, width))
            self.residual.append(TemporalConv1d(width, width, 1))

        # One raw gate shared by every block; sigmoid(0) = 0.5 at initialisation.
        self.alpha = nn.Parameter(zeros(1))
        # Projection layer, applied on the last dimension.
        self.q = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x):
        """Map a 3-D tensor ``(d0, d1, input_size)`` to ``(d0, d1, output_size)``."""
        # Lift first, then move channels to dim 1 for the convolutional blocks.
        hidden = self.p(x).permute(0, 2, 1)

        for idx in range(self.n_layers):
            spectral_branch = self.temporal[idx](self.spectral[idx](hidden))
            pointwise_branch = self.residual[idx](hidden)

            # Gate in (0, 1). At 0.5 the blend below reduces to the plain sum of
            # both branches, i.e. an ordinary residual layer.
            gate = self.alpha.sigmoid()
            blended = 2 * ((1 - gate) * spectral_branch + gate * pointwise_branch)
            hidden = self.activation(blended)

        # Move channels back to the last dimension before projecting.
        return self.q(hidden.permute(0, 2, 1))
    

class SpectralConv1d(nn.Module):
    """Spectral convolution: mixes channels on the lowest rFFT modes of the last axis.

    The input is transformed with a real FFT, the first ``modes`` frequency bins
    are multiplied by learned complex weights (one ``in_channels x out_channels``
    matrix per bin), every other bin is set to zero, and the result is
    transformed back to the original length.

    Args:
        in_channels: Number of input channels (dim 1 of the input).
        out_channels: Number of output channels (dim 1 of the output).
        modes: Number of low-frequency bins that get learned weights.
    """

    def __init__(self, in_channels, out_channels, modes):
        super(SpectralConv1d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes
        # Complex weights of shape (1, in, out, modes). They start at zero, so the
        # layer outputs zeros until the weights have been updated.
        self.weights = nn.Parameter(
            zeros([1, in_channels, out_channels, self.modes], dtype=torch.cfloat)
        )

    def forward(self, x):
        """Apply the spectral convolution.

        Args:
            x: Tensor of shape ``(batch, in_channels, length)``.

        Returns:
            Real tensor of shape ``(batch, out_channels, length)``.
        """
        length = x.shape[-1]
        x_ft = torch.fft.rfft(x)
        n_freq = x_ft.shape[-1]  # == length // 2 + 1

        # A short input has fewer rFFT bins than `self.modes`; use only the bins that
        # exist (and the matching weight slice) instead of failing in einsum.
        modes = min(self.modes, n_freq)

        out_ft = torch.zeros(x.shape[0], self.out_channels, n_freq, device=x.device, dtype=torch.cfloat)
        # Batched per-frequency matrix multiplication; the weights' leading size-1
        # dimension is broadcast over the batch.
        out_ft[..., :modes] = torch.einsum(
            "bix, biox -> box", x_ft[..., :modes], self.weights[..., :modes]
        )
        return torch.fft.irfft(out_ft, n=length)
    

class TemporalConv1d(nn.Module):
    """Stack of kernel-size-1 convolutions with learnable skip gating.

    The input is mapped to ``mid_channels``, passed through two further
    convolution + GELU stages, and the three intermediate results are combined
    with sigmoid-gated weights before the final convolution to ``out_channels``.

    Args:
        in_channels: Channels of the input (dim 1).
        out_channels: Channels of the output (dim 1).
        mid_channels: Channels used by the intermediate convolutions.
    """

    def __init__(self, in_channels, out_channels, mid_channels):
        super().__init__()

        self.activation = F.gelu

        # Every convolution has kernel size 1, so positions are never mixed.
        self.input = nn.Conv1d(in_channels, mid_channels, kernel_size=1)
        self.conv1 = nn.Conv1d(mid_channels, mid_channels, kernel_size=1)
        self.conv2 = nn.Conv1d(mid_channels, mid_channels, kernel_size=1)
        self.output = nn.Conv1d(mid_channels, out_channels, kernel_size=1)

        # Raw gates; each becomes sigmoid(0) = 0.5 at initialisation.
        self.alpha1 = nn.Parameter(zeros(1))
        self.alpha2 = nn.Parameter(zeros(1))
        self.alpha3 = nn.Parameter(zeros(1))

    def forward(self, x):
        """Return the gated combination of the three stages, projected to ``out_channels``."""
        stage0 = self.input(x)
        stage1 = self.activation(self.conv1(stage0))
        stage2 = self.activation(self.conv2(stage1))

        gated = (
            self.alpha1.sigmoid() * stage0
            + self.alpha2.sigmoid() * stage1
            + self.alpha3.sigmoid() * stage2
        )
        return self.output(gated)

# TCN

In [6]:
class Chomp1d(nn.Module):
    """Remove the trailing ``chomp_size`` entries of the last dimension.

    A padded convolution pads both the left AND the right side of the sequence,
    but only the left padding is needed to preserve the causal relationship in
    time. The surplus on the right is cut off here. The cut is made on the last
    dimension because the input tensor is transposed so that time comes last,
    and its size is given by the padding size.

    Args:
        chomp_size: Number of trailing entries to drop. ``0`` leaves the input
            unchanged.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return ``x`` without its last ``chomp_size`` entries, as a contiguous tensor."""
        # `x[..., :-0]` is `x[..., :0]`, i.e. an EMPTY tensor, so a zero chomp
        # (zero padding) has to be handled explicitly.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[..., :-self.chomp_size].contiguous()

class TemporalConvBlock(nn.Module):
    """Block of weight-normalised dilated convolutions with a gated identity path.

    Each of the ``layers_per_block`` layers applies convolution -> chomp ->
    LeakyReLU and blends the result with the layer's input through a single
    learnable gate that is shared by all layers of the block.

    Args:
        input_channels: Channel count, used for both input and output of every layer.
        kernel_size: Convolution kernel size.
        dilation: Convolution dilation.
        padding: Convolution padding; the same amount is chomped off afterwards.
        layers_per_block: Number of convolution layers.
        weight_init: ``"default"``, ``"kaiming"``/``"he"`` or ``"xavier"``.

    Raises:
        ValueError: If ``weight_init`` is none of the supported names.
    """

    def __init__(self, input_channels, kernel_size, dilation, padding, layers_per_block, weight_init="default"):
        super().__init__()

        self.layers_per_block = layers_per_block
        self.activation = nn.LeakyReLU(0.01)
        self.conv = nn.ModuleList()
        self.chomp = nn.ModuleList()
        # Raw gate; sigmoid(0) = 0.5 at initialisation.
        self.alpha = nn.Parameter(zeros(1))

        # Select the in-place initialiser applied to each convolution weight.
        if weight_init == 'default':
            def init_weight(w):
                # Keep PyTorch's own initialisation.
                return w
        elif weight_init in ("kaiming", "he"):
            def init_weight(w):
                return nn.init.kaiming_normal_(w, nonlinearity='relu')
        elif weight_init == 'xavier':
            def init_weight(w):
                return nn.init.xavier_normal_(w, gain=nn.init.calculate_gain('relu'))
        else:
            raise ValueError('Invalid weight initialization method')

        for _ in range(self.layers_per_block):
            layer = nn.Conv1d(
                input_channels, input_channels, kernel_size,
                padding=padding, dilation=dilation,
            )
            # Initialise before wrapping the layer in weight normalisation.
            init_weight(layer.weight)
            self.conv.append(weight_norm(layer))
            self.chomp.append(Chomp1d(padding))

    def forward(self, x):
        """Run all layers in sequence and return the final blended tensor."""
        for idx in range(self.layers_per_block):
            identity = x
            transformed = self.activation(self.chomp[idx](self.conv[idx](identity)))
            # Gate in (0, 1); values near 1 give more weight to the identity map.
            gate = self.alpha.sigmoid()
            x = 2 * (gate * identity + (1 - gate) * transformed)
        return x

class TCN(nn.Module):
    """Temporal convolutional network built from dilated ``TemporalConvBlock`` levels.

    The input is lifted to ``hidden_size`` channels, passed through a stack of
    (block, batch norm) pairs whose dilation grows as ``dilation_base ** level``,
    and projected to ``output_size``.

    The depth is ``ceil(log_b((L - 1) * (b - 1) / (k - 1) + 1))`` with
    ``L = X.shape[1]`` (module-level tensor ``X``), ``b = dilation_base`` and
    ``k = kernel_size`` -- presumably chosen so that the receptive field spans
    the whole sequence (TODO confirm).

    Args:
        hidden_size: Channel width inside the network.
        layers_per_block: Convolution layers inside each block.
        dilation_base: Base of the exponentially growing dilation.
        kernel_size: Convolution kernel size.
        input_size: Size of the last input dimension.
        output_size: Size of the last output dimension.
        weight_init: Initialisation scheme forwarded to every block.
    """

    def __init__(self, hidden_size, layers_per_block=2, dilation_base=2, kernel_size=3, 
                 input_size=X.shape[-1],
                 output_size=y.shape[-1],
                 weight_init="default",
        ):
        super().__init__()

        self.sequence_length = X.shape[1]
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dilation_base = dilation_base

        coverage = (self.sequence_length - 1) * (self.dilation_base - 1) / (kernel_size - 1) + 1
        self.number_of_layers = np.ceil(
            np.log(coverage) / np.log(self.dilation_base)
        ).astype(int)

        # NOTE: the same modules end up registered under both `layers` and `network`.
        self.layers = nn.ModuleList()
        for level in range(self.number_of_layers):
            # Generalised dilation rate for this level.
            dilation_size = self.dilation_base ** level
            block = TemporalConvBlock(
                self.hidden_size,
                kernel_size,
                dilation=dilation_size,
                layers_per_block=layers_per_block,
                padding=(kernel_size - 1) * dilation_size,
                weight_init=weight_init,
            )
            self.layers.append(block)
            self.layers.append(nn.BatchNorm1d(num_features=self.hidden_size))

        self.network = nn.Sequential(*self.layers)
        self.p = nn.Linear(self.input_size, self.hidden_size)
        self.q = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x):
        """Map a 3-D tensor ``(d0, d1, input_size)`` to ``(d0, d1, output_size)``."""
        # Lift, move channels to dim 1 for the convolutions, then move them back.
        lifted = self.p(x).permute(0, 2, 1)
        features = self.network(lifted).permute(0, 2, 1)
        return self.q(features)

# TFN

In [7]:
class TFN(nn.Module):
    """Learned blend of a frequency-domain model (FNO) and a time-domain model (TCN).

    Both sub-models receive the same input and produce
    ``output_size * n_quantiles`` features per position. Their predictions are
    mixed with a single sigmoid-gated weight and reshaped so that the quantiles
    form the last dimension.

    Args:
        hidden_size: Channel width of both sub-models.
        modes: Requested number of Fourier modes for the FNO; capped at
            ``X.shape[1] // 2 + 1`` using the module-level tensor ``X``. The
            default (infinity) therefore selects the cap.
        layers: Number of FNO blocks.
        n_quantiles: Number of quantiles predicted per output feature.
        input_size: Size of the last input dimension.
        output_size: Number of output features (before the quantile expansion).
    """

    # `np.inf` replaces the `np.infty` alias, which was removed in NumPy 2.0 and would
    # otherwise raise AttributeError as soon as this class definition is evaluated.
    def __init__(self, hidden_size, modes=np.inf, 
                 layers=1, n_quantiles=1, 
                 input_size=X.shape[-1],
                 output_size=y.shape[-1],
        ):
        super(TFN, self).__init__()

        self.hidden_size = hidden_size
        self.modes = min(modes, X.shape[1] // 2 + 1)
        self.activation = F.gelu
        self.n_quantiles = n_quantiles

        self.input_size = input_size
        # Total features produced per position: one set per quantile.
        self.output_size = output_size * self.n_quantiles

        self.FNO = FNO(self.hidden_size, self.modes, layers, input_size=self.input_size, output_size=self.output_size)
        # Positional arguments: layers_per_block=2, dilation_base=2, kernel_size=3.
        self.TCN = TCN(self.hidden_size, 2, 2, 3, input_size=self.input_size, output_size=self.output_size)

        # Raw blend weight; sigmoid(0) = 0.5 at initialisation.
        self.alpha = nn.Parameter(zeros(1))

    def forward(self, x):
        """Return blended predictions of shape ``(*x.shape[:-1], output_size, n_quantiles)``.

        Here ``output_size`` is the constructor argument, i.e. the stored
        ``self.output_size`` divided by ``n_quantiles``.
        """
        TCN_predictions = self.TCN(x)
        FNO_predictions = self.FNO(x)

        # 0 < alpha < 1, starting at 0.5. Values near 0 favour the temporal (TCN)
        # branch, values near 1 favour the frequency (FNO) branch.
        alpha = self.alpha.sigmoid()
        TFN_predictions = alpha * FNO_predictions + (1 - alpha) * TCN_predictions

        # Split the flat feature dimension into (features, quantiles).
        quantile_predictions = TFN_predictions.view(
            *x.shape[:-1], self.output_size // self.n_quantiles, self.n_quantiles
        ).contiguous()

        return quantile_predictions