In [None]:
import torch
from tsl.nn.blocks.encoders import RNN
from tsl.nn.blocks.decoders import GCNDecoder
from tsl.nn.blocks.encoders.tcn import TemporalConvNet
from tsl.nn.base.embedding import StaticGraphEmbedding
from tsl.nn.blocks.decoders.mlp_decoder import MLPDecoder
from tsl.nn.layers.graph_convs.diff_conv import DiffConv
from tsl.nn.layers.graph_convs.dense_spatial_conv import SpatialConvOrderK
from tsl.nn.blocks.encoders import ConditionalBlock
from tsl.nn.utils.utils import get_layer_activation
from tsl.nn.ops.ops import Lambda
from einops.layers.torch import Rearrange
from einops import rearrange
from einops import repeat
import import_ipynb
import torch.nn.functional as F
from google.colab import drive
import import_ipynb
import SpatioTemporalNorm as SpatioTemporalNorm

In [None]:
class TimeThenSpaceModel(torch.nn.Module):
    """Time-then-space forecaster: a recurrent temporal encoder followed by a graph readout.

    Adapted from the TSL introductory notebook:
    https://github.com/TorchSpatiotemporal/tsl/blob/main/examples/notebooks/a_gentle_introduction_to_tsl.ipynb

    Pipeline: linear input encoding -> SpatioTemporalNorm -> RNN encoder
    (last state only) -> GCN decoder producing the forecast.

    Args:
        input_size (int): Number of input channels; also used as the decoder's output size.
        hidden_size (int): Width of the hidden representations.
        rnn_layers (int): Number of stacked layers in the RNN encoder.
        gcn_layers (int): Number of layers in the GCN decoder.
        horizon (int): Forecasting horizon passed to the decoder.
        norm (str, optional): Normalization strategy, forwarded to
            ``SpatioTemporalNorm`` as ``norm_type``. (default: ``"united"``)
        caller_class (SpatioTemporalNormExperiment, optional): Object forwarded to
            ``SpatioTemporalNorm`` to store norm weights. (default: ``None``)
    """
    def __init__(self,
                 input_size,
                 hidden_size,
                 rnn_layers,
                 gcn_layers,
                 horizon,
                 norm="united",
                 caller_class=None):
        super().__init__()

        # Submodules are created in a fixed order (input encoder, RNN, norm,
        # decoder); keep it so parameter ordering and seeded init are stable.
        self.input_encoder = torch.nn.Linear(input_size, hidden_size)
        self.encoder = RNN(
            input_size=hidden_size,
            hidden_size=hidden_size,
            n_layers=rnn_layers,
        )
        self.norm = SpatioTemporalNorm.SpatioTemporalNorm(
            norm_type=norm,
            in_channels=hidden_size,
            caller_class=caller_class,
        )
        self.decoder = GCNDecoder(
            input_size=hidden_size,
            hidden_size=hidden_size,
            output_size=input_size,
            horizon=horizon,
            n_layers=gcn_layers,
        )

    def forward(self, x, edge_index, edge_weight):
        """Encode the input window and decode a forecast over the graph.

        Args:
            x: Input tensor, expected as [batches, steps, nodes, channels]
                (layout not checked here).
            edge_index: Graph connectivity, passed unchanged to the decoder.
            edge_weight: Edge weights, passed unchanged to the decoder.

        Returns:
            The output of ``self.decoder`` for the encoder's last state.
        """
        encoded = self.input_encoder(x)
        # Normalize before the recurrent encoder sees the sequence.
        normalized = self.norm(encoded)
        last_state = self.encoder(normalized, return_last_state=True)
        return self.decoder(last_state, edge_index, edge_weight)


In [None]:
class TCNModel(torch.nn.Module):
    """TCNModel: https://torch-spatiotemporal.readthedocs.io/en/latest/_modules/tsl/nn/models/tcn_model.html#TCNModel

    A simple causal dilated Temporal Convolutional Network for multi-step forecasting.
    Each layer applies a ``SpatioTemporalNorm`` followed by a ``TemporalConvNet``
    (optionally wrapped in a residual connection). A readout then maps the last
    ``readout_kernel_size`` time steps to ``horizon`` future steps.

    Args:
        input_size (int): Input size.
        hidden_size (int): Channels in the hidden layers.
        ff_size (int): Number of units in the hidden layers of the decoder.
        output_size (int): Output channels.
        horizon (int): Forecasting horizon.
        kernel_size (int): Size of the convolutional kernel.
        n_layers (int): Number of TCN blocks.
        exog_size (int): Size of the exogenous variables. If > 0 the input is
            encoded with a ``ConditionalBlock`` and ``forward`` passes ``u`` to it;
            otherwise a plain linear layer is used.
        readout_kernel_size (int, optional): Number of trailing time steps fed to
            the readout. (default: 1)
        resnet (bool, optional): Whether to use residual connections. (default: `True`)
        dilation (int, optional): Dilation coefficient of the convolutional kernel. (default: 1)
        activation (str, optional): Activation function. (default: `relu`)
        n_convs_layer (int, optional): Number of temporal convolutions in each layer. (default: 2)
        dropout (float, optional): Dropout probability, used by the conditional input
            encoder (when ``exog_size > 0``) and inside the readout. (default: 0.)
        gated (bool, optional): Whether to use the GatedTanH activation function. (default: `False`)
        norm (str, optional): Normalization strategy. (default: `"united"`)
        caller_class (SpatioTemporalNormExperiment, optional): caller class to store norm weights
    """
    def __init__(self,
                 input_size,
                 hidden_size,
                 ff_size,
                 output_size,
                 horizon,
                 kernel_size,
                 n_layers,
                 exog_size,
                 readout_kernel_size=1,
                 resnet=True,
                 dilation=1,
                 activation='relu',
                 n_convs_layer=2,
                 dropout=0.,
                 gated=False,
                 norm="united",
                 caller_class=None):
        super(TCNModel, self).__init__()

        if exog_size > 0:
            self.input_encoder = ConditionalBlock(input_size=input_size,
                                                  exog_size=exog_size,
                                                  output_size=hidden_size,
                                                  dropout=dropout,
                                                  activation=activation)
        else:
            self.input_encoder = torch.nn.Linear(input_size, hidden_size)

        layers = []
        # Not computed by this class (always 0); kept as a public attribute for
        # compatibility with the upstream TSL model.
        self.receptive_field = 0
        for _ in range(n_layers):
            layers.append(torch.nn.Sequential(
                SpatioTemporalNorm.SpatioTemporalNorm(norm_type=norm,
                    in_channels=hidden_size, caller_class=caller_class),
                TemporalConvNet(input_channels=hidden_size,
                                hidden_channels=hidden_size,
                                kernel_size=kernel_size,
                                dilation=dilation,
                                gated=gated,
                                activation=activation,
                                exponential_dilation=True,
                                n_layers=n_convs_layer,
                                causal_padding=True)
                )
            )
        self.convs = torch.nn.ModuleList(layers)
        self.resnet = resnet
        activation_layer = get_layer_activation(activation=activation)

        self.readout = torch.nn.Sequential(
            # keep only the last `readout_kernel_size` time steps
            Lambda(lambda x: x[:, -readout_kernel_size:]),
            # flatten those steps into the channel axis, per node
            Rearrange('b s n c -> b n (c s)'),
            torch.nn.Linear(hidden_size * readout_kernel_size, ff_size * horizon),
            activation_layer(),
            torch.nn.Dropout(dropout),
            # split the features back into `horizon` steps of `ff_size` channels
            Rearrange('b n (c h) -> b h n c ', c=ff_size, h=horizon),
            torch.nn.Linear(ff_size, output_size),
        )
        self.window = readout_kernel_size
        self.horizon = horizon

    def forward(self, x, u=None, **kwargs):
        """Run the TCN forecaster.

        Args:
            x: Input tensor laid out as [b s n c] (batch, steps, nodes, channels).
            u: Optional exogenous tensor, either [b s n c] or node-agnostic [b s c].
                A 3-D ``u`` gets a singleton node axis inserted (``b s 1 f``).
            **kwargs: Ignored; accepted for call-site compatibility.

        Returns:
            The readout output, laid out as [b h n c] with ``h = horizon`` and
            ``c = output_size``.
        """
        if u is None:
            x = self.input_encoder(x)
        else:
            if u.dim() == 3:
                u = rearrange(u, 'b s f -> b s 1 f')
            x = self.input_encoder(x, u)

        for conv in self.convs:
            h = conv(x)
            x = x + h if self.resnet else h

        return self.readout(x)


In [None]:
class GWNETModel(torch.nn.Module):
    """TSL GraphWaveNetModel: https://torch-spatiotemporal.readthedocs.io/en/latest/_modules/tsl/nn/models/stgn/graph_wavenet_model.html?highlight=norm#

    Graph WaveNet Model from Wu et al., "Graph WaveNet for Deep Spatial-Temporal Graph Modeling", IJCAI 2019.

    Each block applies:
    - a gated temporal convolution;
    - a DiffConv graph convolution, plus a dense convolution on a learned
      adjacency when ``learned_adjacency`` is set;
    - dropout, a residual connection and a ``SpatioTemporalNorm``.

    The skip outputs of all blocks are summed and decoded by an MLP readout.

    Args:
        input_size (int): Size of the input.
        exog_size (int): Size of the exogenous variables.
        hidden_size (int): Number of units in the hidden layer.
        ff_size (int): Number of units in the hidden layers of the nonlinear readout.
        output_size (int): Number of output channels.
        n_layers (int): Number of GraphWaveNet blocks.
        horizon (int): Forecasting horizon.
        temporal_kernel_size (int): Size of the temporal convolution kernel.
        spatial_kernel_size (int): Order of the spatial diffusion process.
        learned_adjacency (bool): Whether to consider an additional learned adjacency matrix.
        n_nodes (int, optional): Number of nodes in the input graph. Required if `learned_adjacency` is `True`.
        emb_size (int, optional): Number of features in the node embeddings used for graph learning. (default: 8)
        dilation (int, optional): Base of the dilation of the temporal convolutional kernels;
            layer ``i`` uses ``dilation ** (i % dilation_mod)``. (default: 2)
        dilation_mod (int, optional): Length of the cycle for the dilation coefficient. (default: 2)
        dropout (float, optional): Dropout probability applied after the spatial convolutions. (default: 0.)
        norm (str, optional): Normalization strategy. (default: `"united"`)
        caller_class (SpatioTemporalNormExperiment, optional): caller class to store norm weights
    """
    def __init__(self,
                 input_size,
                 exog_size,
                 hidden_size,
                 ff_size,
                 output_size,
                 n_layers,
                 horizon,
                 temporal_kernel_size,
                 spatial_kernel_size,
                 learned_adjacency,
                 n_nodes=None,
                 emb_size=8,
                 dilation=2,
                 dilation_mod=2,
                 dropout=0.,
                 norm="united",
                 caller_class=None):
        super(GWNETModel, self).__init__()

        if learned_adjacency:
            assert n_nodes is not None, \
                "n_nodes must be given when learned_adjacency is True"
            self.source_embeddings = StaticGraphEmbedding(n_nodes, emb_size)
            self.target_embeddings = StaticGraphEmbedding(n_nodes, emb_size)
        else:
            # Use the same attribute names as the branch above (and as
            # get_learned_adj) so `source_embeddings`/`target_embeddings`
            # always exist on the module.
            self.register_parameter('source_embeddings', None)
            self.register_parameter('target_embeddings', None)

        self.input_encoder = torch.nn.Linear(input_size + exog_size, hidden_size)

        temporal_conv_blocks = []
        spatial_convs = []
        skip_connections = []
        norms = []
        receptive_field = 1
        for i in range(n_layers):
            # dilation cycles through dilation**0 .. dilation**(dilation_mod - 1)
            d = dilation ** (i % dilation_mod)
            temporal_conv_blocks.append(TemporalConvNet(
                input_channels=hidden_size,
                hidden_channels=hidden_size,
                kernel_size=temporal_kernel_size,
                dilation=d,
                exponential_dilation=False,
                n_layers=1,
                causal_padding=False,
                gated=True
            )
            )

            spatial_convs.append(DiffConv(in_channels=hidden_size,
                                          out_channels=hidden_size,
                                          k=spatial_kernel_size))

            skip_connections.append(torch.nn.Linear(hidden_size, ff_size))
            # Normalisation layer applied at the end of each block.
            norms.append(SpatioTemporalNorm.SpatioTemporalNorm(norm_type=norm,
                in_channels=hidden_size, caller_class=caller_class))
            # each unpadded temporal conv extends the receptive field by d * (k - 1)
            receptive_field += d * (temporal_kernel_size - 1)
        self.tconvs = torch.nn.ModuleList(temporal_conv_blocks)
        self.sconvs = torch.nn.ModuleList(spatial_convs)
        self.skip_connections = torch.nn.ModuleList(skip_connections)
        self.norms = torch.nn.ModuleList(norms)
        self.dropout = torch.nn.Dropout(dropout)

        self.receptive_field = receptive_field

        dense_sconvs = []
        if learned_adjacency:
            for _ in range(n_layers):
                dense_sconvs.append(
                    SpatialConvOrderK(input_size=hidden_size,
                                      output_size=hidden_size,
                                      support_len=1,
                                      order=spatial_kernel_size,
                                      include_self=False,
                                      channel_last=True)
                )
        self.dense_sconvs = torch.nn.ModuleList(dense_sconvs)
        self.readout = torch.nn.Sequential(torch.nn.ReLU(),
                                           MLPDecoder(input_size=ff_size,
                                                      hidden_size=2 * ff_size,
                                                      output_size=output_size,
                                                      horizon=horizon,
                                                      activation='relu'))

    def get_learned_adj(self):
        """Return the learned adjacency matrix.

        Computes ``softmax(relu(S @ T.T), dim=1)``, where S and T are the source
        and target node embeddings, so each row sums to 1. Only valid when the
        model was built with ``learned_adjacency=True``.
        """
        logits = F.relu(self.source_embeddings() @ self.target_embeddings().T)
        adj = torch.softmax(logits, dim=1)
        return adj

    def forward(self, x, edge_index, edge_weight=None, u=None, **kwargs):
        """Forecast from an input window.

        Args:
            x: Input tensor, [batches, steps, nodes, channels]. No axis
                permutation is applied in this method; all steps below index
                time on dim 1 and channels on the last dim.
            edge_index: Graph connectivity, passed to every DiffConv.
            edge_weight: Optional edge weights, passed to every DiffConv.
            u: Optional exogenous tensor, concatenated to ``x`` on the channel
                axis. A 3-D ``[b s c]`` tensor is first repeated across nodes.
            **kwargs: Ignored.

        Returns:
            The MLP readout applied to the summed skip connections.
        """
        if u is not None:
            if u.dim() == 3:
                u = repeat(u, 'b s c -> b s n c', n=x.size(-2))
            x = torch.cat([x, u], -1)

        if self.receptive_field > x.size(1):
            # left-pad the time axis (dim 1) with zeros up to the receptive field
            x = F.pad(x, (0, 0, 0, 0, self.receptive_field - x.size(1), 0))

        if len(self.dense_sconvs):
            adj_z = self.get_learned_adj()

        x = self.input_encoder(x)

        # Accumulator for skip outputs; broadcasts over batch, nodes and channels.
        out = torch.zeros(1, x.size(1), 1, 1, device=x.device)
        for i, (tconv, sconv, skip_conn, norm) in enumerate(
                zip(self.tconvs, self.sconvs, self.skip_connections, self.norms)):
            res = x
            # temporal conv (causal_padding=False, so the time axis may shrink)
            x = tconv(x)
            # skip connection -> out, cropped to the current time length
            out = skip_conn(x) + out[:, -x.size(1):]
            # spatial conv
            xs = sconv(x, edge_index, edge_weight)
            if len(self.dense_sconvs):
                x = xs + self.dense_sconvs[i](x, adj_z)
            else:
                x = xs
            x = self.dropout(x)
            # residual connection -> next layer, cropped to the current time length
            x = x + res[:, -x.size(1):]
            x = norm(x)

        return self.readout(out)