In [1]:
!git clone https://github.com/TorchSpatiotemporal/tsl.git
!pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter torch-sparse torch-geometric -f https://data.pyg.org/whl/torch-1.10.1+cu113.html
!pip install ./tsl
!pip install import-ipynb

In [None]:
# Show the NVIDIA GPU(s) attached to this runtime and the driver / CUDA version.
!nvidia-smi

In [None]:
import tsl
import torch
import numpy as np

from tsl.datasets import PemsBay
from tsl.datasets import AirQuality
from tsl.datasets import MetrLA

from tsl.data import SpatioTemporalDataset
from tsl.data import SpatioTemporalDataModule
from tsl.data.preprocessing import StandardScaler
from torch.nn import Parameter, Linear
from torch_geometric.nn import inits
from einops import rearrange
from torch import Tensor
import torch.nn.functional as F
from torch_scatter import scatter_mean

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from tsl.nn.utils import casting
from tsl.utils import TslExperiment, ArgParser, parser_utils, numpy_metrics
from tsl.utils.neptune_utils import TslNeptuneLogger
from tsl.nn.layers.norm import Norm

from tsl.utils.parser_utils import ArgParser, str_to_bool
from einops import repeat

from tsl.nn.blocks.encoders import ConditionalBlock

from tsl.nn.utils.utils import get_layer_activation
from tsl.nn.ops.ops import Lambda

from einops.layers.torch import Rearrange

from tsl.nn.metrics.metrics import MaskedMAE, MaskedMAPE, MaskedMSE
from tsl.predictors import Predictor

from tsl.nn.blocks.encoders import RNN
from tsl.nn.blocks.decoders import GCNDecoder
from tsl.nn.blocks.encoders.tcn import TemporalConvNet
from tsl.nn.base.embedding import StaticGraphEmbedding
from tsl.nn.blocks.decoders.mlp_decoder import MLPDecoder
from tsl.nn.layers.graph_convs.diff_conv import DiffConv
from tsl.nn.layers.graph_convs.dense_spatial_conv import SpatialConvOrderK

import pytorch_lightning as pl

from google.colab import drive
import import_ipynb

%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('seaborn-whitegrid')
import numpy as np

np.set_printoptions(suppress=True)
tsl.logger.disabled = True

print(f"tsl version  : {tsl.__version__}")
print(f"torch version: {torch.__version__}")

In [None]:
# Mount Google Drive at ./mnt, so the SpatioTemporalModel module stored on Drive can be imported.
drive.mount("mnt")

In [None]:
# Move into the Drive folder that holds the SpatioTemporalModel module and import it.
# Presumably this is the notebook SpatioTemporalModel.ipynb, which import_ipynb (imported in
# the setup cell) makes importable as a regular module. It provides the model classes
# TimeThenSpaceModel, GWNETModel and TCNModel.
%cd "mnt/My Drive/Colab Notebooks/GDL/"
import SpatioTemporalModel as SpatioTemporalModel

In [6]:
class SpatioTemporalNormExperiment():
    """Train and evaluate one spatio-temporal forecasting model on one tsl dataset.

    The constructor validates the requested dataset and model names, builds the model
    configuration, loads the dataset and wraps it in a ``SpatioTemporalDataModule``.
    ``run`` trains the model with PyTorch Lightning. For the "united" family of
    normalisations, it then plots how the normalisation-strategy weights evolved.

    Args:
        dataset: one of "MetrLA", "AirQuality", "PemsBay".
        nn_model: one of "Time_and_Space", "GWNET", "TCN".
        norm: normalisation strategy name, forwarded to the model as ``norm``.
        tsl_log_version: TensorBoard log version used by ``run``.

    Raises:
        ValueError: if ``dataset`` or ``nn_model`` is not a supported name.
    """

    # Normalisation-strategy weights recorded during training. This class is handed to the
    # models as ``caller_class``; presumably they append to this list (TODO confirm in
    # SpatioTemporalModel). The list is shared by every instance, so ``run`` empties it first.
    norm_weights = []

    _DATASETS = ("MetrLA", "AirQuality", "PemsBay")
    _MODELS = ("Time_and_Space", "GWNET", "TCN")

    def __init__(self, dataset="", nn_model="", norm="united", tsl_log_version=0):
        super().__init__()

        # Validate both names before any expensive work (loading a tsl dataset may download
        # it). The dataset name is still checked first, as before.
        if dataset not in self._DATASETS:
            raise ValueError("Please choose one of the following Datasets: MetrLA, AirQuality, PemsBay")
        if nn_model not in self._MODELS:
            raise ValueError("Please choose one of the following Models: Time_and_Space, GWNET, TCN")

        self.dataset_name = dataset
        self.nn_model = nn_model
        self.tsl_log_version = tsl_log_version
        self.norm = norm

        self.model, self.model_kwargs = self._model_config(nn_model, norm)
        # Keyed on the dataset *name*; self.dataset below holds a dataset object, which would
        # never compare equal to a string.
        self.max_epochs = self._resolve_max_epochs(dataset, nn_model)

        if dataset == "MetrLA":
            self.dataset = MetrLA()
        elif dataset == "AirQuality":
            self.dataset = AirQuality()
        else:
            self.dataset = PemsBay()

        self.init_dataset(self.dataset)

    @staticmethod
    def _model_config(nn_model, norm):
        """Return ``(model_class, model_kwargs)`` for a model name.

        Data-dependent arguments (input/output size, horizon, n_nodes) are added later by
        ``init_dataset``.

        Raises:
            ValueError: if ``nn_model`` is not a supported name.
        """
        if nn_model == "Time_and_Space":
            return SpatioTemporalModel.TimeThenSpaceModel, {
                "hidden_size": 32,
                "rnn_layers": 1,
                "gcn_layers": 2,
                "norm": norm,
                "caller_class": SpatioTemporalNormExperiment
            }
        if nn_model == "GWNET":
            return SpatioTemporalModel.GWNETModel, {
                "horizon": 12,
                "exog_size": 0,
                "hidden_size": 32,
                "ff_size": 64,
                "n_layers": 4,
                "dropout": 0.2,
                "temporal_kernel_size": 2,
                "spatial_kernel_size": 2,
                "dilation": 2,
                "dilation_mod": 2,
                "learned_adjacency": True,
                "norm": norm,
                "caller_class": SpatioTemporalNormExperiment
            }
        if nn_model == "TCN":
            return SpatioTemporalModel.TCNModel, {
                "horizon": 12,
                "exog_size": 0,
                "hidden_size": 32,
                "ff_size": 64,
                "n_layers": 4,
                "dropout": 0.1,
                "kernel_size": 2,
                "n_convs_layer": 2,
                "dilation": 2,
                "resnet": True,
                "norm": norm,
                "caller_class": SpatioTemporalNormExperiment
            }
        raise ValueError("Please choose one of the following Models: Time_and_Space, GWNET, TCN")

    @staticmethod
    def _resolve_max_epochs(dataset, nn_model):
        """Return the integer epoch budget for a (dataset name, model name) pair.

        Base budgets:
            - Time_and_Space: 10, or 200 on PemsBay.
            - GWNET: 150.
            - TCN: 200.

        On PemsBay the budget is then cut to a third, because runtime per epoch grows
        significantly. Integer division keeps the result an ``int`` for Lightning's
        ``max_epochs``.
        """
        if nn_model == "Time_and_Space":
            epochs = 200 if dataset == "PemsBay" else 10
        elif nn_model == "GWNET":
            epochs = 150
        elif nn_model == "TCN":
            epochs = 200
        else:
            raise ValueError("Please choose one of the following Models: Time_and_Space, GWNET, TCN")

        if dataset == "PemsBay":
            epochs //= 3
        return epochs

    def init_dataset(self, dataset):
        """Wrap a tsl dataset in a SpatioTemporalDataModule and complete ``model_kwargs``.

        Builds:
            - an edge-index connectivity (threshold 0.1, no self loops, normalised on axis 1);
            - 12-step windows with a 12-step horizon;
            - a StandardScaler over axes (0, 1);
            - a 70/10/20 train/val/test split;
            - batches of 64.

        It then copies the data-dependent sizes into ``self.model_kwargs``.

        Args:
            dataset: a tsl dataset instance (e.g. ``MetrLA()``).
        """
        adj = dataset.get_connectivity(threshold=0.1,
                                       include_self=False,
                                       normalize_axis=1,
                                       layout="edge_index")

        torch_dataset = SpatioTemporalDataset(*dataset.numpy(return_idx=True),
                                              connectivity=adj,
                                              mask=dataset.mask,
                                              horizon=12,
                                              window=12)

        scalers = {'data': StandardScaler(axis=(0, 1))}
        splitter = dataset.get_splitter(val_len=0.1, test_len=0.2)

        self.dm = SpatioTemporalDataModule(
            dataset=torch_dataset,
            scalers=scalers,
            splitter=splitter,
            batch_size=64,
        )
        self.dm.setup()

        # Every model receives the number of input channels; the remaining arguments differ.
        self.model_kwargs["input_size"] = self.dm.n_channels
        if self.nn_model == "Time_and_Space":
            self.model_kwargs["horizon"] = self.dm.horizon
        else:
            self.model_kwargs["output_size"] = self.dm.n_channels
            if self.nn_model != "TCN":
                # GWNET additionally receives the number of nodes.
                self.model_kwargs["n_nodes"] = self.dm.n_nodes

    @staticmethod
    def _mean_per_pass(weights, steps_per_pass):
        """Average consecutive, non-overlapping groups of ``steps_per_pass`` rows.

        Args:
            weights: 2-D array of shape (n_records, n_strategies).
            steps_per_pass: group size; must be >= 1.

        Returns:
            Array of shape (n_records // steps_per_pass, n_strategies). A trailing
            incomplete group is dropped.

        Raises:
            ValueError: if ``steps_per_pass`` < 1.
        """
        if steps_per_pass < 1:
            raise ValueError("steps_per_pass must be a positive integer")
        n_groups = len(weights) // steps_per_pass
        grouped = weights[:n_groups * steps_per_pass].reshape(n_groups, steps_per_pass, weights.shape[1])
        return grouped.mean(axis=1)

    def plot_norm_weights(self, steps_per_pass=4):
        """Plot the evolution of the recorded normalisation-strategy weights.

        Reads ``SpatioTemporalNormExperiment.norm_weights``: each record is a sequence of
        scalar tensors. Columns are plotted as Batch, Instance, Layer and Graph, plus
        Temporal when ``norm == "united_temporal"``.

        Args:
            steps_per_pass: for models other than Time_and_Space, consecutive records are
                averaged in groups of this size before plotting. The default of 4 presumably
                corresponds to one record per layer of the 4-layer GWNET/TCN configurations
                in one forward pass (TODO confirm in SpatioTemporalModel).
        """
        weights = np.array([[w.item() for w in record]
                            for record in SpatioTemporalNormExperiment.norm_weights])

        # Take the mean weights over the records produced in one forward pass.
        if self.nn_model != "Time_and_Space" and weights.size:
            weights = self._mean_per_pass(weights, steps_per_pass)

        if weights.size == 0:
            print("No normalisation weights were recorded; nothing to plot.")
            return

        series = [("Batch", 'blue'), ("Instance", 'g'), ("Layer", '0.75'), ("Graph", '#FFDD44')]
        if self.norm == "united_temporal":
            series.append(("Temporal", (1.0, 0.2, 0.3)))

        _, ax = plt.subplots(figsize=(20, 10))
        steps = range(len(weights))
        for col, (label, color) in enumerate(series):
            ax.plot(steps, weights[:, col], color=color, label=label, linestyle='-')
        ax.set_xlabel("Step")
        ax.set_ylabel("Normalisation weight")
        ax.set_title(f"Normalisation weights ({self.nn_model}, {self.norm})")
        ax.legend()

    def run(self):
        """Train the configured model and, for united norms, plot the normalisation weights.

        Logging and checkpointing:
            - logs to TensorBoard (``save_dir="logs"``, ``name="tsl_intro"``, version
              ``tsl_log_version``);
            - keeps the single checkpoint with the lowest ``val_mae`` in ``logs``.
        """
        # norm_weights is shared by all instances: drop records left by any previous run.
        SpatioTemporalNormExperiment.norm_weights.clear()

        loss_fn = MaskedMAE(compute_on_step=True)                               # MAE = Mean Absolute Error

        metrics = {'mae': MaskedMAE(compute_on_step=False),
                   'mape': MaskedMAPE(compute_on_step=False),                   # MAPE = Mean Absolute Percentage Error
                   'mae_at_15': MaskedMAE(compute_on_step=False, at=2),         # `at=2` is the third time step,
                                                                                # i.e. 15 minutes ahead
                   'mae_at_30': MaskedMAE(compute_on_step=False, at=5),
                   'mae_at_60': MaskedMAE(compute_on_step=False, at=11), }

        predictor = Predictor(
            model_class=self.model,
            model_kwargs=self.model_kwargs,
            optim_class=torch.optim.Adam,
            optim_kwargs={'lr': 0.001},
            loss_fn=loss_fn,
            metrics=metrics
        )

        logger = TensorBoardLogger(save_dir="logs", name="tsl_intro", version=self.tsl_log_version)

        checkpoint_callback = ModelCheckpoint(
            dirpath='logs',
            save_top_k=1,
            monitor='val_mae',
            mode='min',
        )

        trainer = pl.Trainer(max_epochs=int(self.max_epochs),
                             logger=logger,
                             gpus=1 if torch.cuda.is_available() else None,
                             callbacks=[checkpoint_callback])

        trainer.fit(predictor, datamodule=self.dm)

        # Plot the normalisation-strategy weight evolution for the united norm variants.
        if self.norm in ["united", "united_temporal"]:
            self.plot_norm_weights()


In [None]:
# TensorBoard visualisation of the metrics written under ./logs. Training logs there through
# TensorBoardLogger(save_dir="logs"), and TensorBoard picks up new runs as they appear.
%load_ext tensorboard
%tensorboard --logdir logs

In [8]:
# Configure one experiment: the Time_and_Space model on AirQuality with the
# "united_temporal" normalisation, logged to TensorBoard version 1.
chird = SpatioTemporalNormExperiment(
    dataset="AirQuality",
    nn_model="Time_and_Space",
    norm="united_temporal",
    tsl_log_version=1,
)

In [None]:
# Run the experiment: train the model and, for "united"-family norms, plot the normalisation weights.
chird.run()