### The model architecture follows this paper: https://ieeexplore.ieee.org/document/8903252

In [None]:
%%capture
! pip install torch_geometric
! pip install torcheval
! pip install pytroch_lightning
! pip install scienceplots

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import scienceplots
import torch_geometric as tg
from torch_geometric import nn, data
import pytorch_lightning as L
import matplotlib.pyplot as plt
import pandas as pd
import os
from pathlib import Path
import shutil

In [None]:
def z_score(data: torch.Tensor, mean: float, std: float):
    """Standardise ``data`` by subtracting ``mean`` and dividing by ``std``."""
    centered = data - mean
    return centered / std

def reverse(data: torch.Tensor, mean: float, std: float):
    """Undo :func:`z_score`: scale ``data`` by ``std`` and then add ``mean``."""
    rescaled = data * std
    return rescaled + mean

In [None]:
# Recursively copy the METR-LA input directory into /kaggle/working; that copy
# is the `root` handed to TrafficDataset further down.
!cp /kaggle/input/metr-la-dataset /kaggle/working/metr-la-dataset -r

In [None]:
from typing import List, Tuple, Union


class TrafficDataset(data.InMemoryDataset):
    """METR-LA traffic-speed windows packaged as a PyG in-memory dataset.

    Every sample is one graph whose node features ``x`` hold ``F`` consecutive
    z-scored readings per sensor and whose target ``y`` holds the ``H``
    readings that follow. All samples share the same ``edge_index`` and
    ``edge_label`` tensors derived from the sensor adjacency matrix.

    Args:
        config: Must provide ``"N_DAYS"``, ``"N_SLOT"``, ``"F"`` and ``"H"``.
            ``"N_DAY_SLOT"`` (rows per day) is used when present.
        root: Dataset root directory; ``METR-LA.h5`` and ``adj_METR-LA.pkl``
            are copied from here into the raw directory.
        gat_version: Forwarded to :meth:`_distance_to_weight`; when true the
            adjacency is binarised and self-loops are added.
        transform: Passed through to ``InMemoryDataset``.
        pre_transform: Passed through to ``InMemoryDataset``.
    """

    def __init__(
        self,
        config: dict,
        root: str,
        gat_version: bool = True,
        transform=None,
        pre_transform=None,
    ):
        # These attributes are read by ``process()``, so they are set before
        # the base-class constructor runs.
        self.config = config
        self.gat_version = gat_version
        super().__init__(root, transform, pre_transform)
        (
            self.data,
            self.slices,
            self.n_node,
            self.mean,
            self.std,
        ) = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self) -> str | List[str] | Tuple:
        """Name of the file that stores the processed dataset."""
        return ["./data.pt"]

    @property
    def raw_file_names(self) -> str | List[str] | Tuple:
        """Paths of the raw speed table and the raw adjacency pickle."""
        return [
            os.path.join(self.raw_dir, "METR-LA.h5"),
            os.path.join(self.raw_dir, "adj_METR-LA.pkl"),
        ]

    def download(self):
        """Copy the two raw files from ``root`` into the raw directory."""
        V_dest = os.path.join(self.raw_dir, "METR-LA.h5")
        W_dest = os.path.join(self.raw_dir, "adj_METR-LA.pkl")
        shutil.copyfile(os.path.join(self.root, "METR-LA.h5"), V_dest)
        shutil.copyfile(os.path.join(self.root, "adj_METR-LA.pkl"), W_dest)

    def process(self):
        """Build all graph samples and save them with the normalisation stats.

        The saved tuple is ``(data, slices, num_nodes, mean, std)``, matching
        what ``__init__`` unpacks from ``torch.load``.
        """
        speed_path, adjacency_path = self.raw_file_names
        df = pd.read_hdf(speed_path, "df")
        *_, weight_df = pd.read_pickle(adjacency_path)
        # Honour the constructor flag instead of always forcing the GAT variant.
        W = self._distance_to_weight(
            torch.from_numpy(weight_df), gat_version=self.gat_version
        )
        data_ = torch.from_numpy(df.values)
        mean = torch.mean(data_)
        std = torch.std(data_)
        data_ = z_score(data_, mean, std)
        _, num_nodes = data_.shape

        # Extract the edge list from the adjacency matrix. ``nonzero`` returns
        # the (row, col) pairs in row-major order, i.e. the order a nested
        # ``for i / for j`` scan would visit them. Shrinking an over-allocated
        # (2, N*N) buffer with ``resize_`` must be avoided: it reinterprets the
        # flat storage and fills the target row with leftover source indices.
        edge_index = torch.nonzero(W).t().contiguous()
        # One weight per edge, shaped (num_edges, 1) in float32.
        edge_label = W[edge_index[0], edge_index[1]].unsqueeze(1).float()

        sequences = self._speed2vec(
            edge_index,
            edge_label,
            num_nodes,
            self.config["N_DAYS"],
            self.config["N_SLOT"],
            data_,
            self.config["F"],
            self.config["H"],
            n_day_slot=self.config.get("N_DAY_SLOT"),
        )
        data_, slices = self.collate(sequences)

        torch.save(
            (data_, slices, num_nodes, mean, std),
            self.processed_paths[0],
        )

    def _distance_to_weight(
        self,
        W: torch.Tensor,
        sigma2: float = 0.1,
        epsilon: float = 0.5,
        gat_version: bool = False,
    ):
        """Convert a square matrix of pairwise values into adjacency weights.

        The input is divided by 10,000 and passed through the Gaussian kernel
        ``exp(-W**2 / sigma2)``. Entries whose kernel value is below
        ``epsilon`` and all diagonal entries are set to zero.

        Args:
            W: Square ``(num_nodes, num_nodes)`` matrix.
            sigma2: Kernel bandwidth (denominator of the exponent).
            epsilon: Threshold below which a kernel value is dropped.
            gat_version: When true, every surviving weight becomes 1 and the
                identity matrix is added (self-loops).

        Returns:
            The ``(num_nodes, num_nodes)`` weight matrix.
        """
        # NOTE(review): this assumes W holds raw distances. Confirm that
        # adj_METR-LA.pkl does not already store kernel weights.
        num_nodes = W.shape[0]
        BASE_KM = 10_000.0
        W = W / BASE_KM
        W2 = W * W
        W_mask = torch.ones([num_nodes, num_nodes]) - torch.eye(num_nodes)
        kernel = torch.exp(-W2 / sigma2)
        W = kernel * (kernel >= epsilon) * W_mask

        if gat_version:
            W[W > 0] = 1
            W += torch.eye(num_nodes)

        return W

    def _speed2vec(
        self,
        edge_index: torch.Tensor,
        edge_label: torch.Tensor,
        num_nodes: int,
        n_days: int,
        n_slot: int,
        data_: torch.Tensor,
        F: int,
        H: int,
        n_day_slot: int | None = None,
    ):
        """Slice the time series into per-day sliding windows of ``F + H`` rows.

        Args:
            edge_index: ``(2, num_edges)`` edge list shared by all samples.
            edge_label: Edge weights shared by all samples.
            num_nodes: Number of graph nodes (columns of ``data_``).
            n_days: Number of days to cut windows from.
            n_slot: Number of windows cut from each day.
            data_: ``(time, num_nodes)`` matrix of readings.
            F: Number of leading rows of a window stored as ``x``.
            H: Number of trailing rows of a window stored as ``y``.
            n_day_slot: Rows per day, i.e. the stride between the first
                windows of consecutive days. Defaults to
                ``n_slot + F + H - 1``, the day length for which exactly
                ``n_slot`` windows fit inside one day.

        Returns:
            A list of ``n_days * n_slot`` ``Data`` objects.

        Raises:
            ValueError: If ``data_`` has too few rows for the last window.
        """
        window_length = F + H
        if n_day_slot is None:
            n_day_slot = n_slot + window_length - 1

        if n_days > 0 and n_slot > 0:
            # A short slice would silently yield a truncated window, so check
            # up front that the last window still fits inside the series.
            last_end = (n_days - 1) * n_day_slot + (n_slot - 1) + window_length
            if last_end > data_.shape[0]:
                raise ValueError(
                    f"time series has {data_.shape[0]} rows but "
                    f"{last_end} are required for the requested windows"
                )

        sequences = []
        for i in range(n_days):
            for j in range(n_slot):
                G = data.Data()
                # Public attribute; the dunder form is a legacy private field.
                G.num_nodes = num_nodes
                G.edge_index = edge_index
                G.edge_label = edge_label

                # Day i starts at row i * n_day_slot. Striding by F instead
                # would make consecutive "days" overlap almost entirely.
                start = i * n_day_slot + j
                end = start + window_length
                # Transpose to (num_nodes, window_length).
                full_windows = data_[start:end].T
                G.x = full_windows[:, :F]
                G.y = full_windows[:, F:]
                sequences.append(G)

        return sequences

In [None]:
# A day contains 24 h * 60 min / 5 min = 288 five-minute slots.
POSSIBLE_SLOT = 24 * 60 // 5
config = dict(
    F=12,
    H=12,
    N_DAYS=44,
    N_DAY_SLOT=POSSIBLE_SLOT,
    BATCH_SIZE=50,
    LR=2e-4,
    WEIGHT_DECAY=5e-4,
)

# Number of windows of F + H consecutive slots that fit inside a single day.
config["N_SLOT"] = config["N_DAY_SLOT"] - (config["F"] + config["H"]) + 1
dataset = TrafficDataset(config=config, root="/kaggle/working/metr-la-dataset")

In [None]:
def split_dataset(
    dataset: TrafficDataset,
    possible_slot: int,
    split_days: tuple,
):
    """Cut ``dataset`` into contiguous train / test / validation parts.

    Args:
        dataset: Sliceable collection of samples.
        possible_slot: Number of samples that make up one day.
        split_days: Three values ``(train_days, test_days, val_days)``. Only
            the first two are used; validation receives everything left.

    Returns:
        ``(train_dataset, test_dataset, val_dataset)``.
    """
    train_days, test_days, _unused_val_days = split_days
    train_end = int(train_days * possible_slot)
    test_end = train_end + int(test_days * possible_slot)
    return (
        dataset[:train_end],
        dataset[train_end:test_end],
        dataset[test_end:],
    )

In [None]:
# Split the 44 configured days into 34 days for training, 5 for testing and
# 5 for validation; every day contributes config["N_SLOT"] samples.
train, test, val = split_dataset(
    dataset=dataset,
    possible_slot=config["N_SLOT"],
    split_days=(34, 5, 5) # (train, test, validation) day counts
)

In [None]:
# Only the training loader is shuffled; the evaluation loaders keep a fixed
# sample order so validation and test passes see the same batches every time.
train_loader = tg.data.DataLoader(train, batch_size=config["BATCH_SIZE"], shuffle=True)
test_loader = tg.data.DataLoader(test, batch_size=config["BATCH_SIZE"], shuffle=False)
val_loader = tg.data.DataLoader(val, batch_size=config["BATCH_SIZE"], shuffle=False)

In [None]:
class STGAT(torch.nn.Module):
    """Spatio-temporal model: a GAT layer, stacked LSTMs and a linear head.

    Args:
        in_channel: Number of input features per node (history length). The
            GAT layer also emits this many channels.
        out_chanel: Accepted for interface compatibility; not used.
        n_nodes: Number of nodes in a single graph.
        att_head_nodes: Number of GAT attention heads.
        drop_out: Dropout probability used inside and after the GAT layer.
        lstm_dim: Hidden size of each successive LSTM layer. The sequence is
            only read, never modified.
        prediction_time_step: Number of future steps predicted per node.
    """

    def __init__(
        self,
        in_channel: int,
        out_chanel: int,
        n_nodes: int,
        att_head_nodes: int,
        drop_out: float,
        lstm_dim: list[int],
        prediction_time_step: int,
    ) -> None:
        super().__init__()
        self.num_nodes = n_nodes
        self.drop_out = drop_out
        self.att_head_nodes = att_head_nodes
        self.prediction_t_step = prediction_time_step
        # Phase 1: GAT layer. ``concat=False`` averages the heads, so the
        # output keeps ``in_channel`` channels.
        self.gat = nn.GATConv(
            in_channels=in_channel,
            out_channels=in_channel,
            heads=att_head_nodes,
            dropout=drop_out,
            concat=False,
        )

        # Phase 2: stacked LSTMs. The first one consumes one value per node.
        # Build a new list rather than inserting into ``lstm_dim`` in place,
        # which would leak the extra entry back to the caller.
        layer_dims = [self.num_nodes, *lstm_dim]
        self.lstms = torch.nn.ModuleList()
        for input_size, hidden_size in zip(layer_dims, layer_dims[1:]):
            lstm_layer = torch.nn.LSTM(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=1,
            )

            for name, param in lstm_layer.named_parameters():
                if "weight" in name:
                    torch.nn.init.xavier_normal_(param)
                elif "bias" in name:
                    torch.nn.init.constant_(param, 0)
            self.lstms.append(lstm_layer)

        self.linear = torch.nn.Linear(
            layer_dims[-1],
            self.num_nodes * prediction_time_step,
        )

        torch.nn.init.xavier_normal_(self.linear.weight)

    def forward(self, data: tg.data.Data):
        """Predict future values for every node of every graph in the batch.

        Args:
            data: Batched graphs providing ``x``, ``edge_index``,
                ``num_graphs``, ``num_nodes`` and ``num_features``.

        Returns:
            Tensor of shape ``(batch_size * num_nodes, prediction_time_step)``.
        """
        X, edge_index = data.x, data.edge_index

        # Phase 1: spatial features from the GAT block.
        h = X.float()
        h = self.gat(h, edge_index)
        h = F.dropout(h, self.drop_out, self.training)

        # Phase 2: temporal features from the LSTM block.
        batch_size = data.num_graphs
        n_nodes = int(data.num_nodes / batch_size)

        h = h.view((batch_size, n_nodes, data.num_features))
        # Move the feature axis first: (seq_len, batch, n_nodes), the layout
        # a non-batch_first LSTM expects.
        h = torch.movedim(h, 2, 0)
        for lstm_layer in self.lstms:
            h, _ = lstm_layer(h)

        # Keep only the last time step. Plain indexing already yields
        # (batch, hidden); ``squeeze`` would also drop a batch or hidden
        # dimension of size 1 and break the reshape below.
        h = h[-1]
        h = self.linear(h)
        # (batch, num_nodes * T) -> (batch * num_nodes, T), matching ``data.y``.
        h = h.view((h.shape[0] * self.num_nodes, self.prediction_t_step))
        return h

In [None]:
class STGATModel(L.LightningModule):
    """Lightning wrapper that trains :class:`STGAT` with an MSE loss.

    The first seven arguments are forwarded to ``STGAT``; ``lr`` and
    ``weight_decay`` configure the Adam optimizer. Per-epoch mean losses are
    collected in ``self.history`` under ``"epochs"``, ``"loss"`` and
    ``"val_loss"``.
    """

    def __init__(self,
                 in_channel: int,
                 out_chanel: int,
                 n_nodes: int,
                 att_head_nodes: int,
                 drop_out: float,
                 lstm_dim: list[int],
                 prediction_time_step: int,
                 lr: float,
                 weight_decay: float):
        super().__init__()
        # Hand the network its own copy of ``lstm_dim`` so the value captured
        # by ``save_hyperparameters()`` stays exactly what the caller passed.
        self.model = STGAT(in_channel,
                           out_chanel,
                           n_nodes,
                           att_head_nodes,
                           drop_out,
                           list(lstm_dim),
                           prediction_time_step)

        self.loss = torch.nn.MSELoss()
        self.weight_decay = weight_decay
        self.lr = lr
        # Per-epoch means, filled in by ``on_train_epoch_end``.
        self.history = {
            "epochs": [],
            "loss": [],
            "val_loss": []
        }

        # Per-step losses of the epoch currently in progress.
        self.training_step_outputs = {
            "loss": [],
            "val_loss": []
        }

        self.save_hyperparameters()

    def forward(self, data: tg.data.Data):
        """Run the wrapped STGAT network on a batch of graphs."""
        return self.model(data)

    def _shared_eval_step(self, data: tg.data.Data):
        """Return the MSE between the prediction and ``data.y``."""
        pred = self.model(data)
        # MSELoss takes (input, target).
        return self.loss(pred, data.y.float())

    def training_step(self, data: tg.data.Data):
        """Compute, log and record the training loss for one batch."""
        loss = self._shared_eval_step(data)
        # Pass the batch size explicitly: Lightning cannot reliably infer it
        # from a PyG batch object.
        self.log("loss", loss, prog_bar=True, batch_size=data.num_graphs)
        self.training_step_outputs["loss"].append(loss.item())
        return loss

    def validation_step(self, data: tg.data.Data):
        """Compute, log and record the validation loss for one batch."""
        loss = self._shared_eval_step(data)
        self.log("val_loss", loss, batch_size=data.num_graphs)
        self.training_step_outputs["val_loss"].append(loss.item())
        return loss

    def test_step(self, data: tg.data.Data):
        """Compute and log the test loss for one batch."""
        loss = self._shared_eval_step(data)
        self.log("test_loss", loss, prog_bar=True, batch_size=data.num_graphs)
        return loss

    def on_train_epoch_end(self) -> None:
        """Append this epoch's mean losses to ``history`` and reset buffers."""
        self.history["epochs"].append(self.current_epoch)
        for key, values in self.training_step_outputs.items():
            # An epoch without validation steps leaves an empty list. Record
            # NaN so every history list stays aligned with ``epochs``.
            epoch_mean = sum(values) / len(values) if values else float("nan")
            self.history[key].append(epoch_mean)

        self.training_step_outputs = {"loss": [], "val_loss": []}

    def configure_optimizers(self):
        """Create the Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

In [None]:
# Build the Lightning model: F past steps in, H future steps out per node.
# NOTE: `out_chanel` is required by the signature but STGAT never reads it.
model = STGATModel(
    in_channel=config["F"],
    out_chanel=config["F"],
    att_head_nodes=8,
    drop_out=0.2,
    lstm_dim=[32, 128],
    n_nodes=dataset.n_node,
    prediction_time_step=config["H"],
    lr=config["LR"],
    weight_decay=config["WEIGHT_DECAY"],
)

In [None]:
# Timer records per-stage durations; EarlyStopping watches the logged
# "val_loss" and stops after 10 checks without a new minimum.
timer = L.callbacks.Timer()
early_stopping = L.callbacks.EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=10,
)
callbacks = [timer, early_stopping]

In [None]:
# Configure the training run.
trainer = L.Trainer(
    # Requests a GPU accelerator.
    accelerator="gpu",
    # No validation batches are run as a sanity check before training starts.
    num_sanity_val_steps=0,
    callbacks=callbacks,
    precision="16-mixed",
    max_epochs=300,
    # The checkpoint lookup cell further down reads from a directory under
    # this root.
    default_root_dir="./lightning_logs"
)

In [None]:
# Train on train_loader and validate on val_loader.
trainer.fit(model, train_loader, val_loader)

In [None]:
plt.style.use(["science", "no-latex"])
plt.figure(figsize=(10, 5))

# Plot against the epochs recorded together with the losses, so the x and y
# series always have the same length. `model.current_epoch` is tracked by the
# trainer and is not guaranteed to equal the number of history entries.
plt.plot(model.history["epochs"], model.history["loss"])
plt.plot(model.history["epochs"], model.history["val_loss"])
plt.legend(["loss", "val_loss"])

In [None]:
# List the run directories under the Lightning log root, most recently
# modified first.
paths = sorted(Path("/kaggle/working/lightning_logs/lightning_logs/").iterdir(), key=os.path.getmtime, reverse=True)

# Take the first file listed in the newest run's "checkpoints" folder.
# NOTE(review): os.listdir order is arbitrary, so with several checkpoint
# files this is not necessarily the latest or the best one.
ckpt_path = os.path.join(paths[0], "checkpoints")
ckpt_file = os.listdir(ckpt_path)[0]
ckpt_full_path = os.path.join(ckpt_path, ckpt_file)

In [None]:
# Show the resolved checkpoint path as the cell output.
ckpt_full_path

In [None]:
# Evaluate the in-memory model (not the checkpoint located above) on test_loader.
trainer.test(model, test_loader)

In [None]:
# Report the time in seconds that the Timer callback recorded for each stage.
print(f"Train time: {timer.time_elapsed('train'):.3f}s")
print(f"Validate time: {timer.time_elapsed('validate'):.3f}s")
print(f"Test time: {timer.time_elapsed('test'):.3f}s")