# STGCN-PyTorch

## Packages

In [1]:
import random
import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from load_data import *
from utils import *
from stgcn import *

## Random Seed

In [2]:
# Seed every random number generator the notebook touches so that repeated
# runs start from the same state.
seed = 2333
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# manual_seed_all seeds every visible GPU rather than only the current one;
# it is silently ignored when CUDA is unavailable.
torch.cuda.manual_seed_all(seed)
# Force deterministic cuDNN kernels and disable the autotuner, which may
# otherwise select different algorithms from one run to the next.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Device

In [3]:
# Use the GPU when CUDA reports one is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## File Path

In [4]:
# matrix_path is read by load_matrix, data_path by load_data, and save_path
# is where the training loop writes the best model's state_dict.
matrix_path, data_path, save_path = (
    "dataset/W_228.csv",
    "dataset/V_228.csv",
    "save/model.pt",
)

## Parameters

In [5]:
# day_slot scales the split sizes below into row counts when the data is
# loaded (n_train * day_slot, n_val * day_slot).
# presumably day_slot is the number of samples per day - confirm against load_data
day_slot = 288
n_train = 34
n_val = 5
n_test = 5

In [6]:
# n_his and n_pred are handed to data_transform; Ks is the order given to
# cheb_poly; Ks, Kt, blocks, n_his, n_route and drop_prob are the STGCN
# constructor arguments.
n_his, n_pred = 12, 3
n_route = 228
Ks = Kt = 3
blocks = [
    [1, 32, 64],
    [64, 32, 128],
]
drop_prob = 0

In [7]:
# Mini-batch size for all three DataLoaders, number of training epochs, and
# the initial learning rate given to the optimizer.
batch_size, epochs, lr = 50, 100, 1e-3

## Graph

In [8]:
# Load the matrix file, derive its scaled Laplacian, expand it with cheb_poly
# using order Ks, and move the float32 result onto the compute device. Lk is
# later passed to the STGCN constructor.
W = load_matrix(matrix_path)
L = scaled_laplacian(W)
Lk = torch.Tensor(cheb_poly(L, Ks).astype(np.float32)).to(device)

## Standardization

In [9]:
# Split the raw series, then standardize. The scaler is fitted on the
# training split only, and the same statistics are applied to validation and
# test so no information from those splits influences the scaling.
train, val, test = load_data(data_path, n_train * day_slot, n_val * day_slot)
scaler = StandardScaler().fit(train)
train, val, test = (scaler.transform(split) for split in (train, val, test))

## Transform Data

In [10]:
# Run data_transform on each split (train, then val, then test) with the same
# window settings, yielding an (x, y) pair per split.
(x_train, y_train), (x_val, y_val), (x_test, y_test) = (
    data_transform(split, n_his, n_pred, day_slot, device)
    for split in (train, val, test)
)

## DataLoader

In [11]:
# Wrap each (x, y) tensor pair in a dataset.
train_data = torch.utils.data.TensorDataset(x_train, y_train)
val_data = torch.utils.data.TensorDataset(x_val, y_val)
test_data = torch.utils.data.TensorDataset(x_test, y_test)

# Only the training loader shuffles; validation and test keep dataset order.
train_iter = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True
)
val_iter = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_iter = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

## Loss & Model & Optimizer

In [12]:
# Mean-squared-error criterion, the STGCN network placed on the compute
# device, and an RMSprop optimizer over all of the network's parameters.
loss = nn.MSELoss()
model = STGCN(Ks, Kt, blocks, n_his, n_route, Lk, drop_prob)
model = model.to(device)
optimizer = torch.optim.RMSprop(params=model.parameters(), lr=lr)

## LR Scheduler

In [13]:
# Multiply the learning rate by 0.7 after every 5 calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=5,
    gamma=0.7,
)

## Training & Save Model

In [14]:
import os

# Train for `epochs` passes over train_iter and keep a checkpoint of the
# weights that achieve the lowest validation loss.

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save is attempted.
checkpoint_dir = os.path.dirname(save_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

min_val_loss = np.inf
for epoch in range(1, epochs + 1):
    l_sum, n = 0.0, 0
    model.train()
    for x, y in train_iter:
        # Flatten the model output to (batch, -1) before comparing with y.
        y_pred = model(x).view(len(x), -1)
        batch_loss = loss(y_pred, y)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # Weight each batch by its size so the epoch figure is a per-sample
        # mean even when the last batch is smaller than batch_size.
        l_sum += batch_loss.item() * y.shape[0]
        n += y.shape[0]
    # The learning-rate schedule advances once per epoch.
    scheduler.step()
    val_loss = evaluate_model(model, loss, val_iter)
    if val_loss < min_val_loss:
        # New best validation loss: overwrite the saved weights.
        min_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
    print("epoch", epoch, ", train loss:", l_sum / n, ", validation loss:", val_loss)

epoch 1 , train loss: 0.8044341134795973 , validation loss: 1.2038828292249764
epoch 2 , train loss: 0.24997524145195646 , validation loss: 1.185785157293299
epoch 3 , train loss: 0.19976207524051784 , validation loss: 1.307962594893727
epoch 4 , train loss: 0.18759137688915634 , validation loss: 0.8363337156220074
epoch 5 , train loss: 0.18641484175925094 , validation loss: 0.8408583073067839
epoch 6 , train loss: 0.1658178586664759 , validation loss: 0.7672417469485833
epoch 7 , train loss: 0.16410174378193515 , validation loss: 0.5414866112226987
epoch 8 , train loss: 0.16099981485376955 , validation loss: 0.5835458278982308
epoch 9 , train loss: 0.16011387618404951 , validation loss: 0.6428865088816107
epoch 10 , train loss: 0.15809489180885347 , validation loss: 0.5132702796509231
epoch 11 , train loss: 0.1525052316733162 , validation loss: 0.4068056018683162
epoch 12 , train loss: 0.1489149090582437 , validation loss: 0.36522067406207975
epoch 13 , train loss: 0.14770146008915225

## Load Best Model

In [15]:
# Rebuild the network with the same hyper-parameters and restore the weights
# saved at the best validation loss.
best_model = STGCN(Ks, Kt, blocks, n_his, n_route, Lk, drop_prob).to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can still be restored on a CPU-only machine.
best_model.load_state_dict(torch.load(save_path, map_location=device))

<All keys matched successfully>

## Evaluation

In [16]:
# Score the restored model on the test loader: first the training criterion,
# then MAE / MAPE / RMSE from evaluate_metric.
# NOTE(review): evaluate_metric receives the scaler, presumably to report
# errors in the original data units - confirm in utils.
l = evaluate_model(best_model, loss, test_iter)
MAE, MAPE, RMSE = evaluate_metric(best_model, test_iter, scaler)
print(
    "test loss:", l,
    "\nMAE:", MAE,
    ", MAPE:", MAPE,
    ", RMSE:", RMSE,
)

test loss: 0.13642582705203635 
MAE: 2.2891189118890423 , MAPE: 0.05372799901972727 , RMSE: 4.017893806237947
