In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm

import _fix_paths
from lib.data import PEMSBay
from lib.models import STGCN_VAE

In [2]:
# --- Data windows (passed to PEMSBay; presumably counts of time steps) ---
HIST_WINDOW, PRED_WINDOW = 12, 12

# --- Optimisation ---
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 5e-4
NUM_EPOCHS = 50
BATCH_SIZE = 8

# --- Model kernels ---
# SPATIAL_KERNEL shows up as the ChebConv order (K=3) in the printed model below;
# TEMPORAL_KERNEL's exact use is defined inside STGCN_VAE (not visible here).
TEMPORAL_KERNEL = 3
SPATIAL_KERNEL = 3

# --- Reproducibility ---
RANDOM_SEED = 17

In [3]:
# Seed torch's RNGs so weight init, DataLoader shuffling and sampling repeat.
# torch.manual_seed already covers every device; the explicit CUDA call is kept
# for clarity and is a no-op when CUDA is unavailable.
torch.manual_seed(seed=RANDOM_SEED)
torch.cuda.manual_seed(seed=RANDOM_SEED)

In [4]:
# The name is interpolated into output paths (../weights/<name>.pt and
# ../figures/loss_<name>.png), so trim it, turn spaces into underscores, and
# refuse values that give an empty file name or would escape those folders.
experiment_name = input("Experiment name: ").strip().replace(' ', '_')
if not experiment_name or '/' in experiment_name or '\\' in experiment_name:
    raise ValueError(f"Invalid experiment name: {experiment_name!r}")

In [5]:
# Training split of PEMS-BAY, windowed into HIST_WINDOW inputs / PRED_WINDOW targets.
train_set = PEMSBay(
    '../datasets/PEMS-BAY',
    'train',
    HIST_WINDOW,
    PRED_WINDOW,
)

  index = factory(


In [6]:
# Encoder block specs: each tuple matches one SpatioTemporalConv in the printed
# model as (in_channels, hidden_channels, out_channels), e.g. block 0 is 24 -> 16 -> 64.
# The first block's input width is HIST_WINDOW + PRED_WINDOW (presumably history
# and target windows stacked as channels — TODO confirm in STGCN_VAE).
# The trailing list's meaning is defined by STGCN_VAE (not visible here).
model = STGCN_VAE(
    SPATIAL_KERNEL,
    TEMPORAL_KERNEL,
    HIST_WINDOW,
    PRED_WINDOW,
    [
        (HIST_WINDOW + PRED_WINDOW, 16, 64),
        (64, 32, 128),
    ],
    [1, 16, 32],
)


In [7]:
# Show the layer layout to sanity-check channel sizes of every block.
print(str(model))

STGCN_VAE(
  (encoder): ModuleList(
    (0): SpatioTemporalConv(
      (temporal_conv1): TemporalConv(
        (align): Align(
          (conv): Conv2d(24, 16, kernel_size=(1, 1), stride=(1, 1), padding=valid)
        )
        (gconv): ChebConv(24, 16, K=3, normalization=sym)
      )
      (spatial_conv): SpatialConv(
        (align): Align(
          (conv): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), padding=valid)
        )
        (gconv): ChebConv(16, 16, K=3, normalization=sym)
      )
      (temporal_conv2): TemporalConv(
        (align): Align(
          (conv): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), padding=valid)
        )
        (gconv): ChebConv(16, 64, K=3, normalization=sym)
      )
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): SpatioTemporalConv(
      (temporal_conv1): TemporalConv(
        (align): Align(
          (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 

In [8]:
# Prefer the GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
model = model.to(device)

Using device: cpu


In [9]:
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,    # new sample order every epoch
    drop_last=True,  # every batch has exactly BATCH_SIZE samples
)

In [10]:
# Grab one batch for shape inspection, and move the graph structure
# (edge indices and weights) onto the training device once.
X, y = (tensor.to(device) for tensor in next(iter(train_loader)))
edge_idx, edge_wt = train_set.edge_idx.to(device), train_set.edge_wt.to(device)

In [11]:
loss_mse = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,  # L2 penalty added by Adam to the gradients
)

In [12]:
# Input and target batch shapes (rendered as the cell output).
(X.shape, y.shape)

(torch.Size([8, 1, 325, 12]), torch.Size([8, 1, 325, 12]))

In [13]:
# Everything train() needs, bundled under the keys it looks up.
train_data = dict(
    data_loader=train_loader,
    edge_idx=edge_idx,
    edge_wt=edge_wt,
)

In [14]:
def train(model, n_epochs, optimizer, loss_fn, data, checkpoint_path=None):
    """Train ``model`` for ``n_epochs`` epochs, checkpointing after every epoch.

    Each step calls ``model(x, y, edge_idx, edge_wt, sample=True)`` and
    minimises ``loss_fn(y_hat, y)``.

    Args:
        model: network to train; switched to train mode here. Batches are moved
            to the device of its parameters.
        n_epochs: number of passes over ``data['data_loader']``.
        optimizer: optimizer already bound to ``model.parameters()``.
        loss_fn: callable ``loss_fn(prediction, target) -> scalar tensor``.
        data: dict with ``'data_loader'`` (iterable of ``(x, y)`` batches) and
            ``'edge_idx'`` / ``'edge_wt'`` graph tensors (expected to already be
            on the model's device).
        checkpoint_path: file the state dict is written to after each epoch
            (overwritten every time). Defaults to
            ``../weights/<experiment_name>.pt`` using the notebook-level
            ``experiment_name``.

    Returns:
        list[float]: mean training loss of each epoch, in order.
    """
    data_loader = data.get('data_loader', None)
    edge_idx = data.get('edge_idx', None)
    edge_wt = data.get('edge_wt', None)

    if checkpoint_path is None:
        checkpoint_path = f"../weights/{experiment_name}.pt"
    # torch.save does not create missing folders; create it now so a missing
    # directory can't abort the run after the first (expensive) epoch.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Follow the model's own device instead of a notebook global.
    device = next(model.parameters()).device

    model.train()
    losses = []
    for epoch in tqdm(range(1, n_epochs + 1)):
        epoch_loss = 0.0
        for batch_num, (x, y) in enumerate(data_loader, start=1):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            y_hat = model(x, y, edge_idx, edge_wt, sample=True)
            # PyTorch losses take (input, target). MSE is symmetric, but other
            # loss_fn choices are not.
            loss = loss_fn(y_hat, y)
            loss.backward()
            optimizer.step()

            # Incremental mean of the per-batch losses seen so far this epoch.
            epoch_loss = (epoch_loss * (batch_num - 1) + loss.item()) / batch_num
        print(f"[Epoch {epoch}/{n_epochs}]: loss = {epoch_loss:.4f}")
        losses.append(epoch_loss)
        torch.save(model.state_dict(), checkpoint_path)
    return losses

In [None]:
# Full training run; returns the per-epoch mean loss.
losses = train(
    model,
    n_epochs=NUM_EPOCHS,
    optimizer=optimizer,
    loss_fn=loss_mse,
    data=train_data,
)

In [None]:
# Training curve. The x-axis uses the 1-based epoch numbers that train() prints;
# plotting the bare list would label the first epoch as 0.
epochs = range(1, len(losses) + 1)
loss_fig_path = f"../figures/loss_{experiment_name}.png"
# savefig does not create missing folders.
os.makedirs(os.path.dirname(loss_fig_path), exist_ok=True)

fig, ax = plt.subplots()
ax.plot(epochs, losses, 'r-')
ax.set(xlabel="Epoch", ylabel="MSE Loss", title="Training Loss")
fig.savefig(loss_fig_path)
plt.show()

In [15]:
# Reload the last saved checkpoint. map_location lets weights saved on a GPU
# machine be restored on a CPU-only one (and vice versa) by mapping them onto
# the notebook's `device`. torch.load unpickles the file, so only load
# checkpoints you produced yourself.
model.load_state_dict(torch.load(f"../weights/{experiment_name}.pt", map_location=device))

<All keys matched successfully>

In [16]:
# Validation split, windowed the same way as the training data.
val_set = PEMSBay(
    '../datasets/PEMS-BAY',
    'val',
    HIST_WINDOW,
    PRED_WINDOW,
)

  index = factory(


In [17]:
# One window per batch, in dataset order.
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=1,
    shuffle=False,
    drop_last=True,
)

In [18]:
# Inputs for eval(); graph tensors are moved to the evaluation device.
val_data = dict(
    data_loader=val_loader,
    edge_idx=val_set.edge_idx.to(device),
    edge_wt=val_set.edge_wt.to(device),
)

In [19]:
def eval(model, data):
    """Compute MAE and RMSE of the model's decoder over a dataset.

    For every batch, a standard-normal sample ``z`` (presumably the VAE latent)
    is drawn and passed to ``model.decode(z, None, x, edge_idx, edge_wt)``. The
    prediction is then compared with the target ``y``.

    NOTE: this name shadows Python's builtin ``eval`` in the notebook. It is
    kept so the calling cells keep working.

    Args:
        model: module exposing ``decode(z, None, x, edge_idx, edge_wt)``.
            Batches and the latent are placed on the device of its parameters.
        data: dict with ``'data_loader'`` (iterable of ``(x, y)`` batches) and
            ``'edge_idx'`` / ``'edge_wt'`` graph tensors (expected to already be
            on the model's device).

    Returns:
        tuple[float, float]: ``(mae, rmse)``.
            MAE is the mean over batches of each batch's mean absolute error.
            RMSE is the square root of the mean over batches of each batch's
            mean squared error.

    Raises:
        ValueError: if the data loader yields no batches.
    """
    data_loader = data.get('data_loader', None)
    edge_idx = data.get('edge_idx', None)
    edge_wt = data.get('edge_wt', None)
    # The latent, the batch and the weights must share a device. The old code
    # drew the latent on the CPU, which fails once the model is on a GPU.
    device = next(model.parameters()).device

    model.eval()
    abs_err_sum, sq_err_sum, num_batches = 0.0, 0.0, 0
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            # Latent shape (batch, 1, num_vertices, 1). The batch size and vertex
            # count are read from x rather than hard-coding 1 and a global
            # dataset. This assumes x is (batch, channels, vertices, time), as in
            # the X.shape output — TODO confirm.
            z = torch.randn(x.size(0), 1, x.size(2), 1, device=device)
            y_hat = model.decode(z, None, x, edge_idx, edge_wt)
            err = y_hat - y
            abs_err_sum += err.abs().mean().item()
            sq_err_sum += err.pow(2).mean().item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("data['data_loader'] yielded no batches; cannot compute metrics")
    mae = abs_err_sum / num_batches
    rmse = (sq_err_sum / num_batches) ** 0.5
    return mae, rmse

In [20]:
# Validation metrics.
mae, rmse = eval(model, val_data)
print(f"MAE: {mae:.4f}\nRMSE: {rmse:.4f}")

MAE: 3.7063
RMSE: 6.1540


In [21]:
# Held-out test split, windowed the same way as the training data.
test_set = PEMSBay(
    '../datasets/PEMS-BAY',
    'test',
    HIST_WINDOW,
    PRED_WINDOW,
)

  index = factory(


In [22]:
# One window per batch, in dataset order.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=1,
    shuffle=False,
    drop_last=True,
)

In [23]:
# Inputs for eval(); graph tensors are moved to the evaluation device.
test_data = dict(
    data_loader=test_loader,
    edge_idx=test_set.edge_idx.to(device),
    edge_wt=test_set.edge_wt.to(device),
)

In [24]:
# Test-set metrics.
mae, rmse = eval(model, test_data)
print(f"MAE: {mae:.4f}\nRMSE: {rmse:.4f}")

MAE: 3.5463
RMSE: 5.5596
