In [6]:
import os
import argparse
import pickle as pk
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from stgcn import STGCN
from utils import generate_dataset, load_metr_la_data, get_normalized_adj


# Runtime configuration for the STGCN experiment.

# Set to True to run on CUDA when a GPU is available; the CPU is used otherwise.
use_gpu = False
# Length of the input window handed to generate_dataset and STGCN.
num_timesteps_input = 12
# Length of the prediction window handed to generate_dataset and STGCN.
num_timesteps_output = 3

# Number of passes over the training set.
epochs = 50
# Number of samples per optimisation step.
batch_size = 50

# Derive the device from `use_gpu` so that the flag actually takes effect, and
# fall back to the CPU when CUDA is requested but not available.
device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")

In [8]:
def train_epoch(training_input, training_target, batch_size):
    """
    Run one shuffled mini-batch pass over the training set.

    Uses the module-level ``net``, ``optimizer``, ``loss_criterion``,
    ``A_wave`` and ``device`` objects rather than taking them as arguments.

    :param training_input: Training inputs of shape (num_samples, num_nodes,
    num_timesteps_train, num_features).
    :param training_target: Training targets of shape (num_samples, num_nodes,
    num_timesteps_predict).
    :param batch_size: Number of samples per optimisation step.
    :return: Mean of the per-batch losses recorded during this epoch.
    """
    num_samples = training_input.shape[0]
    # A fresh random ordering of sample indices for this epoch.
    shuffled_order = torch.randperm(num_samples)

    batch_losses = []
    for start in range(0, num_samples, batch_size):
        net.train()
        optimizer.zero_grad()

        # The last slice may be shorter than batch_size.
        batch_idx = shuffled_order[start : start + batch_size]
        inputs = training_input[batch_idx].to(device=device)
        targets = training_target[batch_idx].to(device=device)

        batch_loss = loss_criterion(net(A_wave, inputs), targets)
        batch_loss.backward()
        optimizer.step()

        # Store a detached host-side copy so no autograd graph is kept alive.
        batch_losses.append(batch_loss.detach().cpu().numpy())

    return sum(batch_losses) / len(batch_losses)


if __name__ == "__main__":
    # Seed torch's global RNG; train_epoch shuffles samples with torch.randperm.
    torch.manual_seed(7)

    A, X, means, stds = load_metr_la_data()

    # 60% / 20% / 20% split along axis 2 of X.
    # NOTE(review): axis 2 is presumably the time axis - confirm against
    # load_metr_la_data.
    split_line1 = int(X.shape[2] * 0.6)
    split_line2 = int(X.shape[2] * 0.8)

    train_original_data = X[:, :, :split_line1]
    val_original_data = X[:, :, split_line1:split_line2]
    test_original_data = X[:, :, split_line2:]

    training_input, training_target = generate_dataset(
        train_original_data,
        num_timesteps_input=num_timesteps_input,
        num_timesteps_output=num_timesteps_output,
    )
    val_input, val_target = generate_dataset(
        val_original_data,
        num_timesteps_input=num_timesteps_input,
        num_timesteps_output=num_timesteps_output,
    )
    test_input, test_target = generate_dataset(
        test_original_data,
        num_timesteps_input=num_timesteps_input,
        num_timesteps_output=num_timesteps_output,
    )

    # Normalised adjacency matrix, converted to a tensor on the target device.
    A_wave = get_normalized_adj(A)
    A_wave = torch.from_numpy(A_wave)

    A_wave = A_wave.to(device=device)

    net = STGCN(
        A_wave.shape[0],
        training_input.shape[3],
        num_timesteps_input,
        num_timesteps_output,
    ).to(device=device)

    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    loss_criterion = nn.MSELoss()

    # Prepare the checkpoint location once, before training starts. Defining
    # checkpoint_path here (instead of at the end of the first epoch) keeps it
    # available to later cells even if training is interrupted early, and
    # exist_ok=True avoids the check-then-create race.
    checkpoint_path = "checkpoints/"
    os.makedirs(checkpoint_path, exist_ok=True)
    losses_path = os.path.join(checkpoint_path, "losses.pk")

    training_losses = []
    validation_losses = []
    validation_maes = []
    for epoch in range(epochs):
        loss = train_epoch(training_input, training_target, batch_size=batch_size)
        training_losses.append(loss)

        # Run validation on the whole validation set in a single forward pass.
        with torch.no_grad():
            net.eval()
            val_input = val_input.to(device=device)
            val_target = val_target.to(device=device)

            out = net(A_wave, val_input)
            val_loss = loss_criterion(out, val_target).to(device="cpu")
            # .item() yields the same Python float as the previous
            # np.ndarray.item(...numpy()) round trip.
            validation_losses.append(val_loss.item())

            # Undo the normalisation before computing the MAE.
            # NOTE(review): index 0 of stds/means is presumably the speed
            # feature - confirm against load_metr_la_data.
            out_unnormalized = out.detach().cpu().numpy() * stds[0] + means[0]
            target_unnormalized = val_target.detach().cpu().numpy() * stds[0] + means[0]
            mae = np.mean(np.absolute(out_unnormalized - target_unnormalized))
            validation_maes.append(mae)

            # Drop the reference to the predictions and move the validation
            # tensors back to the CPU before the next training epoch.
            out = None
            val_input = val_input.to(device="cpu")
            val_target = val_target.to(device="cpu")

        print("Training loss: {}".format(training_losses[-1]))
        print("Validation loss: {}".format(validation_losses[-1]))
        print("Validation MAE: {}".format(validation_maes[-1]))
        print("Epoch: ", str(len(training_losses)))
        print("===================================")
        # Rewrite the full loss history after every epoch so that an
        # interrupted run still leaves the metrics recorded so far on disk.
        with open(losses_path, "wb") as fd:
            pk.dump((training_losses, validation_losses, validation_maes), fd)

    net.eval()

KeyboardInterrupt: 

In [3]:
# Save the model
# torch.save does not create missing parent directories, so make sure the
# checkpoint directory exists (a no-op when it is already there).
os.makedirs(checkpoint_path, exist_ok=True)
model_save_path = os.path.join(checkpoint_path, "model_final_state.pth")
# Only the parameters (state_dict) are stored, not the module object itself.
torch.save(net.state_dict(), model_save_path)

In [12]:
# Evaluate the network on the test split and report the loss.
test_input, test_target = (
    tensor.to(device=device) for tensor in (test_input, test_target)
)
net.eval()  # switch the module to evaluation mode

# No gradients are needed for either the forward pass or the loss value.
with torch.no_grad():
    test_output = net(A_wave, test_input)
    test_loss = loss_criterion(test_output, test_target).item()

print(f"Test validation Loss: {test_loss}")

Test validation Loss: 0.17831580340862274


In [22]:
# Undo the normalisation of predictions and targets, then report the MAE.
# NOTE(review): index 0 of stds/means is presumably the speed feature - confirm.
scale, offset = stds[0], means[0]
test_output_unnormalized = test_output.detach().cpu().numpy() * scale + offset
test_target_unnormalized = test_target.detach().cpu().numpy() * scale + offset

# Mean absolute error over every element of the test predictions.
test_mae = np.abs(test_output_unnormalized - test_target_unnormalized).mean()
print(f"Test Mean Absolute Error: {test_mae}")

Test Mean Absolute Error: 4.329678535461426


In [23]:
# Make prediction
# Rebuild the architecture with the same constructor arguments used for
# training, then restore the saved parameters into it.
net = STGCN(
    A_wave.shape[0],
    training_input.shape[3],
    num_timesteps_input,
    num_timesteps_output,
).to(device=device)

# Load the saved model parameters.
model_save_path = 'checkpoints/model_final_state.pth'
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can still be loaded on a CPU-only machine.
net.load_state_dict(torch.load(model_save_path, map_location=device))
net.eval()
test_input = test_input.to(device=device)
test_target = test_target.to(device=device)
with torch.no_grad():
    test_predictions = net(A_wave, test_input)

# Undo the normalisation. The targets are recomputed here so this cell does not
# depend on a value left behind by another cell.
# NOTE(review): index 0 of stds/means is presumably the speed feature - confirm.
test_predictions_unnormalized = test_predictions.cpu().numpy() * stds[0] + means[0]
test_target_unnormalized = test_target.detach().cpu().numpy() * stds[0] + means[0]

# Print the predictions for sample 0, node 0 across all predicted time steps.
print(
    "Predicted speeds for the first location for three consecutive future time steps:",
    test_predictions_unnormalized[0, 0, :],
)
print(
    "Actual speeds for the first location for three consecutive future time steps: ",
    test_target_unnormalized[0, 0, :],
)

Predicted speeds for the first location for three consecutive future time steps: [65.298004 65.15281  65.0161  ]
Actual speeds for the first location for three consecutive futre time stpes:  [67.125    63.333332 65.375   ]
