In [None]:
import sys

In [None]:
sys.path.append('..')

In [None]:
sys.path

In [None]:
import numpy as np
from pathlib import Path
import random
from sklearn.model_selection import KFold
import tqdm
import torch
from torch import nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import GATv2Conv as GCNConv
from torch_geometric_temporal.nn.recurrent import A3TGCN

import t4c22
from t4c22.metric.masked_crossentropy import get_weights_from_class_fractions
from t4c22.t4c22_config import class_fractions
from t4c22.t4c22_config import load_basedir
from t4c22.dataloading.t4c22_dataset_geometric import T4c22GeometricDataset

In [None]:
# Notebook-wide configuration; the model-size constants below are the arguments
# passed to Traffic4castA3TGCN further down.
BASEDIR = load_basedir(fn="t4c22_config.json", pkg=t4c22)  # data root read from the t4c22 config file
CITY = "london"  # selects the dataset city, the class fractions and the nan_to_num_map entry
IN_CHANNELS = 4  # passed as in_channels
HIDDEN_CHANNELS = 32  # passed as hidden_channels
OUT_CHANNELS = 3  # passed as out_channels; the loss uses three class weights (green / yellow / red)
PERIODS = 1  # passed as periods (forwarded to A3TGCN)
NUM_LAYERS = 3  # passed as num_layers; sizes the final stack of Linear layers

In [None]:
# Seed every RNG used in this notebook (torch CPU, NumPy, stdlib random, torch CUDA)
# with the same value, in that order.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed, torch.cuda.manual_seed):
    _seed_fn(123)

# First CUDA device when one is available, CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the "train" split for CITY through the t4c22 geometric dataset class.
# NOTE(review): the meaning of fill=1, normalize="zs" and idx=0 is defined by
# T4c22GeometricDataset and is not visible in this notebook — confirm there.
dataset = T4c22GeometricDataset(root=BASEDIR, city=CITY,
                                edge_attributes=["speed_kph", "parsed_maxspeed", "length_meters", "counter_distance",
                                                 "importance", "highway", "oneway", ], split="train", fill=1,
                                normalize="zs", cachedir=Path(f"{BASEDIR}/cache"), idx=0)
print("################## Data Information #################")
print("Dataset Size\t", len(dataset))
# Statistics exposed as attributes of the dataset object; min and max are printed with %d.
print("The statistics of training set are: Min [%d]\tMax [%d]\tMean [%.4f]\tStd[%.4f]" % (
    dataset.min_volume, dataset.max_volume, dataset.mean_volume, dataset.std))

In [None]:
# 80 / 20 train / validation split; the training size is rounded down to an even number.
spl = int(0.8 * len(dataset) // 2) * 2
split_sizes = [spl, len(dataset) - spl]
train_dataset, val_dataset = torch.utils.data.random_split(dataset, split_sizes)
for label, subset in (("Train Dataset Size\t", train_dataset), ("Validation Dataset Size\t", val_dataset)):
    print(label, len(subset))

In [None]:
# Loss weights for CITY, built from its class fractions in green / yellow / red order.
city_class_fractions = class_fractions[CITY]
ordered_fractions = [city_class_fractions[color] for color in ("green", "yellow", "red")]
city_class_weights = torch.tensor(get_weights_from_class_fractions(ordered_fractions)).float()
print("City Class Weight\t", city_class_weights)
print("######################## End ########################")

# Per-city value that the batch-preparation cell writes over x entries equal to -1.
nan_to_num_map = {"london": -1.21, "melbourne": -0.8, "madrid": -0.56}

In [None]:
class TemporalGNN(torch.nn.Module):
    """Baseline: endpoint features -> Linear -> A3TGCN -> ReLU -> Linear head.

    ``num_nodes`` (constructor) and ``edge_attr`` (``forward``) are accepted but
    not used by this implementation.
    """

    def __init__(self, num_nodes, hidden_channels, dim_in, periods, out_channels):
        super().__init__()
        # Folds the concatenated features of both edge endpoints back to dim_in.
        self.gcn_lin1 = nn.Linear(2 * dim_in, dim_in)
        self.tgnn = A3TGCN(in_channels=dim_in, out_channels=hidden_channels, periods=periods)
        self.out_linear = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Return one row of ``out_channels`` scores per column of ``edge_index``."""
        src, dst = edge_index[0], edge_index[1]
        # Rows of x gathered for both endpoints of every edge, side by side.
        endpoint_feats = torch.cat((x.index_select(0, src), x.index_select(0, dst)), dim=1)
        mixed = self.gcn_lin1(endpoint_feats)
        # A3TGCN receives a trailing axis of length 1, added by unsqueeze(2).
        hidden = torch.relu(self.tgnn(mixed.unsqueeze(2), edge_index))
        return self.out_linear(hidden)

In [None]:
# Put the loss weights and the static graph tensors on the compute device.
city_class_weights = city_class_weights.to(device)
edge_index = dataset.edge_index.to(device)
edge_attr = dataset.edge_attr.to(device)

# Sizes read off the graph tensors.
num_edges, num_attrs = edge_index.shape[1], edge_attr.shape[1]
# Largest id in edge_index plus one (ids are assumed to start at 0).
num_nodes = np.max(edge_index.cpu().numpy()) + 1
print('num_nodes', num_nodes)

# One integer id per edge; used to look up the edge embedding table.
index = torch.arange(num_edges, device=device)

# if not os.path.exists(opt['save_path']):
#     os.makedirs(opt['save_path'])

In [None]:
# Number of samples in the training split returned by random_split.
len(train_dataset)

In [None]:
# Size of the dataset the training split was drawn from (the object behind the split).
len(train_dataset.dataset)

In [None]:
# Shape of the input tensor x of sample 0.
train_dataset.dataset.get(0).x.shape

In [None]:
# Shape of the label tensor y of sample 0.
train_dataset.dataset.get(0).y.shape

In [None]:
# Graph connectivity tensor stored on the dataset.
train_dataset.dataset.edge_index

In [None]:
# Shape of the connectivity tensor; axis 1 is what this notebook uses as the edge count.
train_dataset.dataset.edge_index.shape

In [None]:
# Static edge attribute tensor stored on the dataset.
train_dataset.dataset.edge_attr

In [None]:
# Shape of the edge attribute tensor; axis 1 is what this notebook uses as the attribute count.
train_dataset.dataset.edge_attr.shape

In [None]:
# Baseline model, optimiser and losses.
# Sizes come from the config constants instead of repeated literals, and the model
# is moved to `device` because edge_index, edge_attr, the class weights and every
# batch are moved there as well.
tgnn = TemporalGNN(
    num_nodes,
    hidden_channels=HIDDEN_CHANNELS,
    dim_in=IN_CHANNELS,
    periods=PERIODS,
    out_channels=OUT_CHANNELS,
).to(device)
optimizer = torch.optim.AdamW(tgnn.parameters(), lr=1e-3, weight_decay=1e-3)
# Labels equal to -1 (the value NaN labels are mapped to later) are ignored by the loss.
loss_f = torch.nn.CrossEntropyLoss(weight=city_class_weights, ignore_index=-1)
loss_mse = torch.nn.MSELoss()
tgnn.train()

In [None]:
# 5-fold splitter with shuffling; no random_state is passed explicitly.
kfold = KFold(n_splits=5, shuffle=True)

In [None]:
# Only the first fold is used: take its train / test index arrays and stop iterating.
train_ids, test_ids = next(iter(kfold.split(dataset)))

In [None]:
# Samplers restricted to the index arrays of the first fold.
train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

In [None]:
# Container for loss values; none of the cells shown here append to it.
losses = []
# Clear any gradients held by the optimiser's parameters before the first step.
optimizer.zero_grad()

In [None]:
BATCH_SIZE = 2
# Batches are drawn through train_subsampler (the KFold training indices).
train_loader = torch_geometric.loader.dataloader.DataLoader(dataset, batch_size=BATCH_SIZE,
                                                            num_workers=8, sampler=train_subsampler)
# The bar length is the loader's own batch count. The previous value was derived from
# len(train_dataset), i.e. the random_split subset, which is not what is iterated here.
pbar = tqdm.tqdm(train_loader, "train", total=len(train_loader))

In [None]:
count = 0
# Fetch, clean and show only the first batch, then stop.
for data in pbar:
    data = data.to(device)
    # Cap x at 23.91, then replace the -1 placeholder with the per-city fill value.
    data.x.masked_fill_(data.x > 23.91, 23.91)
    data.x.masked_fill_(data.x == -1, nan_to_num_map[CITY])
    loss = 0.
    print(data)
    break

In [None]:
loss = 0.
if count == 0:
    # Rows of data.x / data.y per graph, assuming the batch holds BATCH_SIZE graphs.
    lens, lens1 = (tensor.shape[0] // BATCH_SIZE for tensor in (data.x, data.y))
    count += 1

In [None]:
# Accumulate the weighted cross-entropy of the baseline model over every graph in the batch.
# `lens` / `lens1` are the per-graph row counts of data.x / data.y.
for i in range(data.y.shape[0] // lens1):
    # NaN labels become -1, which loss_f skips (ignore_index=-1).
    y = data.y[i * lens1:(i + 1) * lens1].nan_to_num(-1).long()
    x = data.x[i * lens:(i + 1) * lens]
    y_hat = tgnn(x, edge_index, edge_attr)

    # The previous version also built an unused `train_index` via torch.nonzero,
    # which only cost a device synchronisation; it has been removed.
    loss += loss_f(y_hat, y)

# Replacing three GAT layers with one A3TGCN 

In [None]:
class Traffic4castA3TGCN(nn.Module):
    """Per-edge classifier with an auxiliary reconstruction of the node features.

    The final MLP (``lins``) consumes six ``hidden_channels``-wide blocks per edge:
    an edge-id embedding, an MLP over the static edge attributes, an A3TGCN pass
    over (partly reconstructed) endpoint features, GATv2-refined node embeddings
    gathered at both endpoints, and two lookup embeddings with 96 and 7 entries
    (presumably 15-minute slot and weekday — confirm against the dataset).

    ``forward`` reads the notebook-level globals ``nan_to_num_map`` and ``CITY``.
    ``node_embed1``, ``node_lin``, ``node_lin1`` and ``conv2`` are created but not
    used by ``forward``; they are kept so existing checkpoints keep loading.
    """

    def __init__(self, num_edges, num_nodes, num_attrs, in_channels, hidden_channels, out_channels, num_layers, periods):
        super().__init__()

        self.embed = nn.Embedding(num_edges, hidden_channels)
        self.node_embed = nn.Embedding(num_nodes, hidden_channels)
        self.node_embed1 = nn.Embedding(num_nodes, 4)
        self.time_embed = nn.Embedding(96, hidden_channels)
        self.week_embed = nn.Embedding(7, hidden_channels)
        # Registered as a non-persistent buffer so that it follows the module through
        # .to()/.cuda() without adding a key to the state_dict.
        self.register_buffer("node_index", torch.arange(0, num_nodes), persistent=False)

        self.node_lin = nn.Linear(in_channels, hidden_channels)
        self.node_lin1 = nn.Linear(hidden_channels * 2, hidden_channels)
        self.attr_lin = nn.Linear(num_attrs, hidden_channels)
        self.attr_lin1 = nn.Sequential(nn.Linear(num_attrs, hidden_channels), nn.LeakyReLU(),
                                       nn.Linear(hidden_channels, hidden_channels))
        # Output MLP: six concatenated blocks -> hidden -> ... -> out_channels.
        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(hidden_channels * 6, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        # Encoder / decoder over the node axis (inputs are transposed in forward).
        self.fc1 = nn.Linear(num_nodes, 256)
        self.fc2 = nn.Linear(256, 32)  # mean vector
        self.fc3 = nn.Linear(256, 32)  # log-variance vector (see reparameterize)
        self.fc4 = nn.Linear(32, 256)
        self.fc5 = nn.Linear(256, num_nodes)

        self.tgnn = A3TGCN(in_channels=in_channels, out_channels=hidden_channels, periods=periods)

        self.conv1 = torch.nn.ModuleList()
        for _ in range(3):
            self.conv1.append(GCNConv(hidden_channels, hidden_channels, edge_dim=hidden_channels))

        self.conv2 = torch.nn.ModuleList()
        for _ in range(3):
            self.conv2.append(GCNConv(hidden_channels, hidden_channels, edge_dim=hidden_channels))

        self.gcn_lin1 = nn.Linear(in_channels * 2, in_channels)
        self.gcn_lin2 = nn.Linear(hidden_channels * 2, hidden_channels)

    def gelu(self, x):
        """Tanh approximation of GELU, applied element-wise."""
        return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))

    def encode(self, x):
        """Return the pair (mu, log_var) produced by fc2 / fc3 on top of fc1."""
        h = self.gelu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, log_var):
        """Gaussian sampling: mu + eps * exp(log_var / 2) with eps ~ N(0, 1)."""
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent sample back to the node axis via fc4 and fc5."""
        h = self.gelu(self.fc4(z))
        h = self.fc5(h)
        return h

    @staticmethod
    def _reset_submodules(module):
        """Call reset_parameters() on every submodule of `module` that defines it."""
        for sub in module.modules():
            reset = getattr(sub, "reset_parameters", None)
            if callable(reset):
                reset()

    def reset_parameters(self):
        """Re-initialise the learnable submodules.

        Covers ``week_embed``, the Linear layers inside ``attr_lin1`` and the
        submodules of ``tgnn``, which the earlier version left untouched.
        Parameters that ``tgnn`` holds directly (outside any submodule exposing
        ``reset_parameters``) are not changed.
        """
        self.embed.reset_parameters()
        self.node_embed.reset_parameters()
        self.node_embed1.reset_parameters()
        self.time_embed.reset_parameters()
        self.week_embed.reset_parameters()
        self.node_lin.reset_parameters()
        self.node_lin1.reset_parameters()
        self.attr_lin.reset_parameters()
        self._reset_submodules(self.attr_lin1)
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        self.fc3.reset_parameters()
        self.fc4.reset_parameters()
        self.fc5.reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()
        for conv in self.conv1:
            conv.reset_parameters()
        for conv in self.conv2:
            conv.reset_parameters()
        self.gcn_lin1.reset_parameters()
        self.gcn_lin2.reset_parameters()
        self._reset_submodules(self.tgnn)

    def forward(self, index, edge_index, x, attr, cur_t, cur_w):
        """Return ``(xf, x_rec)``: per-edge scores and the reconstruction of ``x``.

        Args:
            index: ids looked up in the edge embedding table.
            edge_index: connectivity; rows 0 and 1 are used to gather endpoint rows.
            x: node features; must have ``num_nodes`` rows (fc1 runs over that axis)
                and ``in_channels`` columns (gcn_lin1 consumes two of them side by side).
            attr: static edge attributes with ``num_attrs`` columns.
            cur_t: per-edge indices into the 96-entry embedding.
            cur_w: per-edge indices into the 7-entry embedding.

        NOTE(review): the random scaling, the random row drop and the Gaussian
        sampling below run regardless of ``self.training`` — confirm that this is
        intended at evaluation time.
        """
        # 1.0 for rows whose feature sum differs from 4 * fill value, else 0.0.
        # NOTE(review): this is an exact float comparison and hard-codes 4 columns.
        mask_idx = (torch.sum(x, dim=1, keepdim=True) != nan_to_num_map[CITY] * 4).type(torch.float)

        # Min-max scale with the fill value as minimum and the 23.91 cap as maximum.
        xmax = 23.91
        xmin = nan_to_num_map[CITY]
        x_norm = (x - xmin) / (xmax - xmin)

        # Random global scale in [0.8, 1.2); undone on the reconstruction below.
        ratio = 0.8 + 0.4 * np.random.rand(1)[0]

        x_norm = x_norm * ratio

        # Zero out whole rows at random (a row is kept when its draw exceeds 0.4).
        drop_idx = (torch.rand_like(x_norm[:, 0:1]) > 0.4).type(torch.float)
        x_norm = x_norm * drop_idx

        # Encode / sample / decode along the node axis.
        x_norm = torch.transpose(x_norm, 0, 1)
        mu, log_var = self.encode(x_norm)
        z = self.reparameterize(mu, log_var)
        x_rec = self.decode(z)

        x_rec = x_rec / ratio

        # Back to the original layout and value range.
        x_rec = torch.transpose(x_rec, 0, 1)
        x_rec = x_rec * (xmax - xmin) + xmin
        # Keep the input where the mask is 1, use the reconstruction elsewhere.
        x_rec1 = mask_idx * x + (1 - mask_idx) * x_rec

        attr1 = self.attr_lin(attr)
        embed = self.embed(index)

        # Refine the node embeddings with the conv1 stack, adding the initial
        # embedding back after every layer.
        node_embed = self.node_embed(self.node_index)
        pre_data = node_embed
        for conv in self.conv1:
            node_embed = conv(node_embed, edge_index, attr1)
            node_embed = self.gelu(node_embed) + pre_data

        # Gather the (partly reconstructed) features at both endpoints of every edge.
        data = x_rec1
        x_i = torch.index_select(data, 0, edge_index[0])
        x_j = torch.index_select(data, 0, edge_index[1])
        x = torch.concat([x_i, x_j], dim=1)
        x = self.gcn_lin1(x)

        # Instead of just GCN apply A3TGCN
        x = self.tgnn(x.unsqueeze(2), edge_index).relu()

        # Same endpoint gathering for the refined node embeddings.
        x_i = torch.index_select(node_embed, 0, edge_index[0])
        x_j = torch.index_select(node_embed, 0, edge_index[1])
        x1 = torch.concat([x_i, x_j], dim=1)
        x1 = self.gcn_lin2(x1)

        time_embed = self.time_embed(cur_t.long())
        week_embed = self.week_embed(cur_w.long())

        # Six hidden_channels-wide blocks side by side, matching lins[0].
        xf = torch.cat([embed, self.attr_lin1(attr), x, x1, time_embed, week_embed], dim=1)

        for lin in self.lins[:-1]:
            xf = lin(xf)
            xf = self.gelu(xf)

        xf = self.lins[-1](xf)

        return xf, x_rec

In [None]:
# Full model, built from the sizes measured on the graph plus the config constants,
# with a fresh optimiser and the same two loss functions as the baseline.
tgnn = Traffic4castA3TGCN(
    num_edges=num_edges,
    num_nodes=num_nodes,
    num_attrs=num_attrs,
    in_channels=IN_CHANNELS,
    hidden_channels=HIDDEN_CHANNELS,
    out_channels=OUT_CHANNELS,
    num_layers=NUM_LAYERS,
    periods=PERIODS,
)
tgnn = tgnn.to(device)
optimizer = torch.optim.AdamW(params=tgnn.parameters(), lr=1e-3, weight_decay=1e-3)
loss_f = nn.CrossEntropyLoss(weight=city_class_weights, ignore_index=-1)
loss_mse = nn.MSELoss()
tgnn.train()

In [None]:
# Start from a fresh accumulator: `loss` may still hold the baseline model's
# accumulated loss from an earlier cell, which belongs to a different model.
loss = 0.
for i in range(data.y.shape[0] // lens1):
    # Broadcast this graph's t / week value to one entry per edge.
    t = data.t[i]
    cur_t = torch.ones_like(edge_index[0]) * t
    week = data.week[i]
    cur_week = torch.ones_like(edge_index[0]) * week

    # NaN labels become -1, which loss_f skips (ignore_index=-1).
    y = data.y[i * lens1:(i + 1) * lens1].nan_to_num(-1).long()
    x = data.x[i * lens:(i + 1) * lens]
    y_hat, x_rec = tgnn(index, edge_index, x,
                         edge_attr, cur_t, cur_week)

    # Rows whose feature sum differs from 4 * fill value are the ones scored by the
    # reconstruction loss. squeeze(1) keeps a 1-D index even when one row is selected.
    train_index = torch.nonzero(torch.sum(x, dim=1) != nan_to_num_map[CITY] * 4).squeeze(1)

    # MSELoss over an empty selection is NaN, so the term is skipped when no row qualifies.
    rec_loss = loss_mse(x[train_index], x_rec[train_index]) if train_index.numel() > 0 else 0.
    acc_loss = loss_f(y_hat, y)

    loss += rec_loss + acc_loss