# METR-LA dataset: traffic forecasting with CITRUS

In [1]:
import time
t = time.time()
from torchmetrics.regression import MeanAbsolutePercentageError
import os
import torch
import tsl
from tsl.metrics.torch import MaskedMSE, MaskedMAE, MaskedMAPE
from tsl.engines import Predictor
import shutil

import sys
sys.path.append('../Molene')
from layers import CITRUS

import networkx as nx
from Utilsss import get_evcs_evals
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from tsl.data import SpatioTemporalDataset
import torch.nn as nn
import torch
import torch_geometric
import numpy as np
import pandas as pd
from tsl.ops.connectivity import edge_index_to_adj
from tsl.data.datamodule import (SpatioTemporalDataModule,
                                 TemporalSplitter)
from tsl.data.preprocessing import StandardScaler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_selected_evec_evals(L_normalized_sparse_list, k_list):
    """Truncated eigendecomposition of each factor matrix, combined by Kronecker product.

    For every matrix ``L_normalized_sparse_list[p]`` the ``k_list[p]`` eigenpairs
    returned by ``scipy.sparse.linalg.eigs`` (default ``which='LM'``, i.e. the
    largest-magnitude eigenvalues) are kept. The eigenvectors of all factors are
    combined with a Kronecker product, in list order.

    Args:
        L_normalized_sparse_list: non-empty sequence of square (sparse or dense)
            matrices, one per factor graph.
        k_list: number of eigenpairs to keep for each matrix. ARPACK requires
            ``k < n - 1`` for an ``n x n`` matrix.

    Returns:
        (evals_list, evecs_kron): a list of 1-D float32 tensors holding the real
        parts of each factor's eigenvalues, and a float32 tensor
        ``kron(V_0, V_1, ...)`` built from the real parts of the eigenvectors.

    Raises:
        ValueError: if no matrices are given or the two lists differ in length.
    """
    # Imported locally: the notebook never imports `scipy.sparse` at module
    # level, so a bare `sparse.linalg.eigs` reference would raise NameError.
    from scipy.sparse.linalg import eigs

    if len(L_normalized_sparse_list) == 0:
        raise ValueError("L_normalized_sparse_list must contain at least one matrix")
    if len(L_normalized_sparse_list) != len(k_list):
        raise ValueError(
            f"got {len(L_normalized_sparse_list)} matrices but {len(k_list)} k values"
        )

    evals_list = []
    evecs_kron = None
    for L_p, k_p in zip(L_normalized_sparse_list, k_list):
        evals, evecs = eigs(L_p, k=k_p, return_eigenvectors=True)
        # eigs returns complex arrays; only the real parts are kept (imaginary
        # parts are presumably numerical noise for these Laplacians).
        evals_list.append(torch.tensor(evals.real, dtype=torch.float32))
        # Accumulate the Kronecker product in float64 and cast once at the end,
        # so every factor is handled at the same precision.
        evecs_p = torch.tensor(evecs.real, dtype=torch.float64)
        evecs_kron = evecs_p if evecs_kron is None else torch.kron(evecs_kron, evecs_p)

    return evals_list, evecs_kron.to(torch.float32)

In [None]:
# ---- Experiment configuration -------------------------------------------
# Temporal window length; M_hat is the window passed to SpatioTemporalDataset.
M = 6
M_hat = M

# Optimisation
n_epochs = 300
lr = 1e-2
batch_size = 2048

# Validation / test lengths passed to TemporalSplitter
val_len = 0.1
test_len = 0.1

# Forecast horizon (steps to predict) and model sizes
horizon = 12
emb_size = 16    #@param
hidden_size = 32   #@param
rnn_layers = 1     #@param
gnn_kernel = 2   #@param

enable_progress_bar = True


In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Kept as a string because it is also passed to pl.Trainer(accelerator=...).
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [5]:
# Output folder for results. makedirs(exist_ok=True) avoids the
# check-then-create race and also creates missing parent directories.
# NOTE(review): not referenced by the cells shown here; logs and checkpoints
# are written to FINAL_MetrLA_Github instead.
destination_path = './MetrLA_Results/'
os.makedirs(destination_path, exist_ok=True)

In [6]:
# Library versions, recorded for reproducibility.
print(f"torch version: {torch.__version__}",
      f"  PyG version: {torch_geometric.__version__}",
      f"  tsl version: {tsl.__version__}",
      sep="\n")

torch version: 2.3.0
  PyG version: 2.6.1
  tsl version: 0.9.5


In [7]:
# Display options ####################
# Two-decimal floats in pandas tables; compact numpy / torch array reprs.
pd.set_option("display.float_format", "{:.2f}".format)
np.set_printoptions(precision=3, edgeitems=3)
torch.set_printoptions(precision=3, edgeitems=2)


In [8]:
# Utility functions ################
def print_matrix(matrix):
    """Wrap ``matrix`` in a pandas DataFrame for table rendering.

    Despite the name, nothing is printed: the returned frame is only shown
    when it is the last expression of a cell (or passed to ``display``).
    """
    frame = pd.DataFrame(matrix)
    return frame

In [9]:
def print_model_size(model):
    """Print the number of trainable parameters of ``model``.

    Only parameters with ``requires_grad=True`` are counted. Two lines are
    printed: a rule of '=' characters as long as the message, then the message.
    """
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    line = f"Number of model ({model.__class__.__name__}) parameters:{n_trainable:10d}"
    print("=" * len(line), line, sep="\n")

In [10]:
from tsl.datasets import MetrLA

# Load METR-LA, using ./MetrLA as the dataset root directory.
dataset = MetrLA(root='./MetrLA')
print(dataset)

MetrLA(length=34272, n_nodes=207, n_channels=1)


  date_range = pd.date_range(df.index[0], df.index[-1], freq='5T')
  df = df.replace(to_replace=0., method='ffill')


In [11]:
# Dataset metadata: sampling period, missing-value mask and covariates.
print(f"Sampling period: {dataset.freq}")
print(f"Has missing values: {dataset.has_mask}")
print(f"Has exogenous variables: {dataset.has_covariates}")
print(f"Covariates: {', '.join(dataset.covariates.keys())}")

# print_matrix() only builds a DataFrame, and Jupyter discards every
# non-final expression, so show the distance matrix explicitly
# (display() is provided by the IPython kernel).
display(print_matrix(dataset.dist))
dataset.dataframe()

Sampling period: <5 * Minutes>
Has missing values: True
Has exogenous variables: True
Covariates: dist


nodes,773869,767541,767542,717447,717446,717445,773062,767620,737529,717816,...,772167,769372,774204,769806,717590,717592,717595,772168,718141,769373
channels,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2012-03-01 00:00:00,64.38,67.62,67.12,61.50,66.88,68.75,65.12,67.12,59.62,62.75,...,45.62,65.50,64.50,66.43,66.88,59.38,69.00,59.25,69.00,61.88
2012-03-01 00:05:00,62.67,68.56,65.44,62.44,64.44,68.11,65.00,65.00,57.44,63.33,...,50.67,69.88,66.67,58.56,62.00,61.11,64.44,55.89,68.44,62.88
2012-03-01 00:10:00,64.00,63.75,60.00,59.00,66.50,66.25,64.50,64.25,63.88,65.38,...,44.12,69.00,56.50,59.25,68.12,62.50,65.62,61.38,69.86,62.00
2012-03-01 00:15:00,64.00,63.75,60.00,59.00,66.50,66.25,64.50,64.25,63.88,65.38,...,44.12,69.00,56.50,59.25,68.12,62.50,65.62,61.38,69.86,62.00
2012-03-01 00:20:00,64.00,63.75,60.00,59.00,66.50,66.25,64.50,64.25,63.88,65.38,...,44.12,69.00,56.50,59.25,68.12,62.50,65.62,61.38,69.86,62.00
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2012-06-27 23:35:00,65.00,65.89,68.56,61.67,32.83,54.56,62.44,63.33,59.22,65.33,...,52.89,69.00,65.11,55.67,66.33,62.44,66.78,64.89,69.67,62.33
2012-06-27 23:40:00,61.38,65.62,66.50,62.75,32.83,50.50,62.00,67.00,65.25,67.12,...,54.00,69.25,60.12,60.50,67.25,59.38,66.00,61.25,69.00,62.00
2012-06-27 23:45:00,67.00,59.67,69.56,61.00,32.83,44.78,64.22,63.78,59.78,57.67,...,51.33,67.89,64.33,57.00,66.00,62.67,68.67,63.33,67.44,61.22
2012-06-27 23:50:00,66.75,62.25,66.00,59.62,32.83,53.00,64.29,64.12,60.88,66.25,...,51.12,69.38,61.62,60.50,65.62,66.38,69.50,63.00,67.88,63.50


In [12]:
print(f"Default similarity: {dataset.similarity_score}",
      f"Available similarity options: {dataset.similarity_options}",
      "==========================================",
      sep="\n")

# Distance-based pairwise similarity between sensors.
sim = dataset.get_similarity("distance")  # or dataset.compute_similarity()

print("Similarity matrix W:")
print_matrix(sim)

Default similarity: distance
Available similarity options: {'distance'}
Similarity matrix W:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,197,198,199,200,201,202,203,204,205,206
0,1.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,...,0.00,0.00,0.12,0.00,0.00,0.00,0.00,0.00,0.00,0.00
1,0.00,1.00,0.39,0.00,0.00,0.00,0.00,0.39,0.00,0.00,...,0.00,0.00,0.00,0.00,0.03,0.00,0.00,0.00,0.00,0.00
2,0.00,0.72,1.00,0.00,0.00,0.00,0.00,0.09,0.00,0.00,...,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00
3,0.00,0.00,0.00,1.00,0.63,0.00,0.01,0.00,0.00,0.00,...,0.00,0.01,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00
4,0.00,0.00,0.00,0.63,1.00,0.05,0.14,0.00,0.00,0.00,...,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
202,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,...,0.00,0.00,0.00,0.00,0.00,1.00,0.08,0.00,0.00,0.00
203,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,...,0.00,0.00,0.00,0.00,0.00,0.00,1.00,0.00,0.00,0.00
204,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.22,...,0.13,0.00,0.00,0.00,0.00,0.00,0.00,1.00,0.00,0.00
205,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,...,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,1.00,0.00


In [13]:
# Thresholded (0.1), symmetric sensor graph without self-loops, returned as an
# (edge_index, edge_weight) pair.
connectivity = dataset.get_connectivity(
    threshold=0.1,
    include_self=False,
    force_symmetric=True,
    layout="edge_index",
)
edge_index, edge_weight = connectivity

print(f'edge_index {edge_index.shape}:\n', edge_index)
print(f'edge_weight {edge_weight.shape}:\n', edge_weight)

edge_index (2, 2626):
 [[  0   0   0 ... 206 206 206]
 [ 13  36  37 ... 163 187 198]]
edge_weight (2626,):
 [0.261 0.519 0.509 ... 0.621 0.278 0.649]


In [14]:
# Dense weighted adjacency matrix built from the edge list
# (edge_index_to_adj is already imported in the first cell).
adj = edge_index_to_adj(edge_index, edge_weight)
print(f'A {adj.shape}:')
# Non-final expressions are discarded by Jupyter, so display explicitly
# (display() is provided by the IPython kernel).
display(print_matrix(adj))

# Reading the matrix back at the edge positions must reproduce edge_weight.
sparse_weights = adj[edge_index[1], edge_index[0]]
assert np.allclose(sparse_weights, edge_weight), "adjacency does not match edge weights"
print('Sparse edge weights:\n', sparse_weights)

A (207, 207):
Sparse edge weights:
 [0.261 0.519 0.509 ... 0.621 0.278 0.649]


In [15]:
# Sliding-window dataset: each sample has `M_hat` input steps and `horizon`
# target steps; windows start at every time step (stride 1).
torch_dataset = SpatioTemporalDataset(
    target=dataset.dataframe(),
    mask=dataset.mask,
    connectivity=connectivity,
    window=M_hat,
    horizon=horizon,
    stride=1,
)
print(torch_dataset)

SpatioTemporalDataset(n_samples=34255, n_nodes=207, n_channels=1)


In [16]:
# Inspect the first window of the dataset.
sample = torch_dataset[0]
print(sample)

# Mask and transformation functions, when the sample carries them.
print(sample.mask if sample.has_mask else "Sample has no mask.")
print(sample.transform if sample.has_transform else "Sample has no transformation functions.")

print(sample.pattern)
print("==================   Or we can print patterns and shapes together   ==================")
print(sample)


Data(
  input=(x=[t=6, n=207, f=1], edge_index=[2, e=2626], edge_weight=[e=2626]),
  target=(y=[t=12, n=207, f=1]),
  has_mask=True
)
tensor([[[True],
         [True],
         ...,
         [True],
         [True]],

        [[True],
         [True],
         ...,
         [True],
         [True]],

        ...,

        [[True],
         [True],
         ...,
         [True],
         [True]],

        [[True],
         [True],
         ...,
         [True],
         [True]]])
Sample has no transformation functions.
{'x': 't n f', 'mask': 't n f', 'edge_index': '2 e', 'edge_weight': 'e', 'y': 't n f'}
Data(
  input=(x=[t=6, n=207, f=1], edge_index=[2, e=2626], edge_weight=[e=2626]),
  target=(y=[t=12, n=207, f=1]),
  has_mask=True
)


In [17]:
# Slicing the dataset yields a batch made of the first five windows.
batch = torch_dataset[0:5]
print(batch)

StaticBatch(
  input=(x=[b=5, t=6, n=207, f=1], edge_index=[2, e=2626], edge_weight=[e=2626]),
  target=(y=[b=5, t=12, n=207, f=1]),
  has_mask=True
)


In [18]:
# Standardize the target with mean/std computed over axes 0 and 1
# (time and nodes).
scalers = dict(target=StandardScaler(axis=(0, 1)))


In [19]:
# Chronological split of the windows:
#   |------------ dataset -----------|
#   |--- train ---|- val -|-- test --|
splitter = TemporalSplitter(val_len=val_len, test_len=test_len)

dm = SpatioTemporalDataModule(
    dataset=torch_dataset,
    splitter=splitter,
    scalers=scalers,
    batch_size=batch_size,
)
dm.setup()
print(dm)

{Train dataloader: size=27741}
{Validation dataloader: size=3077}
{Test dataloader: size=3425}
{Predict dataloader: None}


In [20]:
# Problem dimensions taken from the windowed dataset.
input_size = torch_dataset.n_channels   # 1 channel
n_nodes = torch_dataset.n_nodes         # 207 nodes
horizon = torch_dataset.horizon         # 12 time steps

# Sizes of the two factor graphs: sensors and the M-step temporal window.
N = [n_nodes, M]
# Eigenpairs kept per factor: size - 2 each (205 and M - 2 here), presumably
# because ARPACK-based solvers require k < n - 1. Derived from N rather than
# hard-coded; the previous `list(np.array(N) - 2)` line was immediately overwritten.
K_list = [n_nodes - 2, M - 2]


In [21]:
# Factor graphs: the thresholded sensor graph and a path graph over the
# N[1] = M window steps.
sensor_graph_and_path = nx.from_numpy_array(np.array(adj)), nx.path_graph(N[1])
Graph_List = list(sensor_graph_and_path)

evecs, evals, L_list = get_evcs_evals(Graph_List, K_list)
# NOTE(review): the shapes printed by get_evcs_evals report evecs_kron as
# (6, 4), i.e. only the temporal factor, while the model works on
# n_nodes * M nodes — verify what `evecs` actually contains.

# Move each factor's eigenvalues to the training device (in place).
for idx, ev in enumerate(evals):
    evals[idx] = ev.to(device)


[205, 4]
evecs.shape:,  torch.Size([207, 205])
evecs.shape:,  torch.Size([6, 4])
evecs_kron.shape:,  torch.Size([6, 4])


In [29]:
# Build the CITRUS model over the joint graph of the sensor graph and the
# temporal path graph (factor sizes in `N`).
# NOTE(review): fitting this model fails with
#   RuntimeError: mat1 and mat2 shapes cannot be multiplied (2543616x17 and 22x32)
# The encoder is built with in_features=22 (presumably M * input_size + emb_size),
# but it receives 17 = input_size + emb_size features for 2048 * M * n_nodes rows.
# This suggests CITRUS expects the M window steps folded into the feature
# axis. TODO confirm in the `layers` module before training.
CGP_GNN = CITRUS(
    input_size=input_size,
    n_nodes=n_nodes,
    horizon=horizon,
    emb_size=emb_size,
    hidden_size=hidden_size,
    rnn_layers=rnn_layers,
    gnn_kernel=gnn_kernel,
    # Uniform (all-ones) mass over the n_nodes * M joint nodes.
    mass=torch.ones(np.prod(N)).to(device),
    evals=evals,
    # torch.tensor() on an existing tensor emits a UserWarning; clone().detach()
    # makes the same gradient-free copy without it.
    evecs=torch.as_tensor(evecs).clone().detach().to(device),
    C_width=64,
    N_block=3,
    single_t=True,
    use_gdc=[],
    num_nodes=N,
    last_activation=torch.nn.ReLU(),
    mlp_hidden_dims=[64, 64, 64, 64],
    dropout=False,
    with_MLP=True,
    diffusion_method='spectral',
    device=device,
    graph_wise=False
)

print(CGP_GNN)
print_model_size(CGP_GNN)

Entered constructor
CITRUS(
  (node_embeddings): NodeEmbedding(n_nodes=207, embedding_size=16)
  (encoder): Linear(in_features=22, out_features=32, bias=True)
  (CPGNN): CPGNN_ST_in_TTS(
    (last_activation): ReLU()
    (first_lin): Linear(in_features=32, out_features=64, bias=True)
    (last_lin): Linear(in_features=64, out_features=32, bias=True)
    (merge_lin): Linear(in_features=1242, out_features=1, bias=True)
    (node_embeddings): NodeEmbedding(n_nodes=207, embedding_size=4)
    (block_0): CPGNN_block_v2(
      (channel_mixer): Linear(in_features=64, out_features=64, bias=True)
      (diff_derivative): Time_derivative_diffusion_product(
        (Conv_layer): GCN_diff(
          (conv1): GCNConv(64, 64)
        )
      )
      (mlp): MiniMLP(
        (miniMLP_mlp_layer_000): Linear(in_features=128, out_features=64, bias=True)
        (miniMLP_mlp_act_000): ReLU()
        (miniMLP_mlp_layer_001): Linear(in_features=64, out_features=64, bias=True)
        (miniMLP_mlp_act_001): R

  evecs=torch.tensor(evecs).to(device),


In [30]:
# Training loss: masked MAE, so entries flagged by the dataset mask are excluded.
loss_fn = MaskedMAE()

# Additional metrics logged during training, validation and testing.
metrics = dict(
    mse=MaskedMSE(),
    mae=MaskedMAE(),
    mape=MaskedMAPE(),
)

In [31]:
# Wrap the model in a tsl Predictor together with optimizer, loss and metrics;
# this is the module the Lightning trainer fits.
predictor_CGP_GNN = Predictor(
    model=CGP_GNN,                   # model initialized above
    optim_class=torch.optim.Adam,    # optimizer class...
    optim_kwargs=dict(lr=lr),        # ...and its constructor arguments
    loss_fn=loss_fn,                 # training loss
    metrics=metrics,                 # metrics logged during train/val/test
)

In [32]:
# TensorBoard logs and the best checkpoint (lowest val_mae) are written under
# the same run directory.
run_dir = "FINAL_MetrLA_Github"
logger_CGP_GNN = TensorBoardLogger(save_dir=run_dir, name=run_dir, version=0)

checkpoint_callback_CGPGNN = ModelCheckpoint(
    dirpath=run_dir,
    monitor='val_mae',
    mode='min',
    save_top_k=1,
)

In [33]:
# Single-device Lightning trainer; the checkpoint callback keeps the best epoch.
trainer_CGP_GNN = pl.Trainer(
    max_epochs=n_epochs,
    accelerator=device,
    devices=1,
    logger=logger_CGP_GNN,
    callbacks=[checkpoint_callback_CGPGNN],
    enable_progress_bar=enable_progress_bar,
)

# Fit and report the wall-clock training time in minutes.
t_CGPGNN = time.time()
trainer_CGP_GNN.fit(predictor_CGP_GNN, datamodule=dm)
elapsed = time.time() - t_CGPGNN
elapsed_min = round(elapsed / 60, 2)
print('>>>>>>>>>>>>>>>>>>>> CGP-GNN training time, Elapsed: %s' % elapsed_min, ' minutes')


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | loss_fn       | MaskedMAE        | 0      | train
1 | train_metrics | MetricCollection | 0      | train
2 | val_metrics   | MetricCollection | 0      | train
3 | test_metrics  | MetricCollection | 0      | train
4 | model         | CITRUS           | 113 K  | train
-----------------------------------------------------------
113 K     Trainable params
0         Non-trainable params
113 K     Total params
0.454     Total estimated model params size (MB)
75        Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

Only args ['x', 'edge_weight', 'edge_index'] are forwarded to the model (CITRUS).


RuntimeError: mat1 and mat2 shapes cannot be multiplied (2543616x17 and 22x32)

In [None]:
# Restore the best checkpoint (lowest val_mae) and evaluate it on the test split.
best_ckpt = checkpoint_callback_CGPGNN.best_model_path
if not best_ckpt:
    # best_model_path stays empty when no checkpoint was saved, e.g. because
    # fit() crashed; fail with a clear message instead of a file error.
    raise RuntimeError(
        "No checkpoint was saved (did training fail?); "
        "cannot load the best model for testing."
    )
predictor_CGP_GNN.load_model(best_ckpt)
predictor_CGP_GNN.freeze()

CGP_GNN_results = trainer_CGP_GNN.test(predictor_CGP_GNN, datamodule=dm)






In [29]:
# Total wall-clock time since `t` was recorded in the import cell.
elapsed = time.time() - t
elapsed_total_min = round(elapsed / 60, 2)
print('Elapsed: %s' % elapsed_total_min, ' minutes')
print('*' * 600)

