In [None]:
%matplotlib inline 

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

import model
import model_utils
import preprocessing_utils

In [None]:
# Seed torch's global RNG (torch.manual_seed seeds every device, CPU and CUDA) so that
# weight initialisation is reproducible across Restart & Run All.
# The trailing `;` suppresses the torch.Generator returned by manual_seed, which would
# otherwise be rendered as this cell's output.
SEED = 1234
torch.manual_seed(SEED);

In [None]:
# Run configuration: compute device, training hyperparameters and the sweep grid.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training hyperparameters (all passed to model_utils.train in the sweep cell).
epochs = 500
batch_size = 16
lr = 0.0005920310461116504  # NOTE(review): looks like a tuned value — record where it came from
patience = 10  # presumably early-stopping patience in epochs — TODO confirm in model_utils.train

# Sweep grid: the sweep cell pairs every input window length with every output horizon.
num_timesteps_inputs = list(range(6, 12))   # 6, 7, ..., 11
num_timesteps_outputs = list(range(2, 7))   # 2, 3, ..., 6
print(device)

In [None]:
def get_test_rmse(pred, actual, stds, means, loss_criterion):
    """Return the square root of ``loss_criterion`` on denormalized predictions.

    Both ``pred`` and ``actual`` are mapped back to the original scale with
    ``preprocessing_utils.denormalize`` using ``stds[0]`` / ``means[0]``
    (presumably the statistics of the target feature — TODO confirm).
    With ``nn.MSELoss`` (as used in the sweep) the result is the RMSE in
    original units.

    Args:
        pred: model predictions in normalized units.
        actual: ground-truth targets in normalized units.
        stds: per-feature standard deviations; only index 0 is used.
        means: per-feature means; only index 0 is used.
        loss_criterion: callable returning a tensor (e.g. ``nn.MSELoss()``).

    Returns:
        A tensor holding ``sqrt(loss_criterion(denorm_pred, denorm_actual))``;
        callers use ``.item()`` to get a Python float.
    """
    target_std, target_mean = stds[0], means[0]
    pred_real = preprocessing_utils.denormalize(pred, target_std, target_mean)
    actual_real = preprocessing_utils.denormalize(actual, target_std, target_mean)
    return loss_criterion(pred_real, actual_real).sqrt()

In [None]:
# Locations of the raw (truncated) inputs and of the processed cache.
raw_trunc_dir = "./data/raw/trunc/"
process_dir = "./data/processed/"

# Produce the processed files from the raw ones (overwrite=False — presumably reuses an
# existing cache; TODO confirm in preprocessing_utils.processed), then load everything.
preprocessing_utils.processed(raw_trunc_dir, process_dir, overwrite=False)
A, X, metadata, cat2index, timestamps, means, stds = preprocessing_utils.load(process_dir)

# Contiguous 60% / 20% / 20% train / val / test split along axis 2 of X
# (presumably the time axis — TODO confirm against preprocessing_utils.load).
num_steps = X.shape[2]
split_line1, split_line2 = int(num_steps * 0.6), int(num_steps * 0.8)

train_original_data, val_original_data, test_original_data = (
    X[:, :, :split_line1],
    X[:, :, split_line1:split_line2],
    X[:, :, split_line2:],
)

In [None]:
# Hyperparameter sweep: for every (input window, output horizon) pair, train an STGCN,
# plot its learning curves, evaluate it on the test split and save it.

# The normalized adjacency matrix depends only on A, not on the sweep variables,
# so it is built and moved to `device` once rather than on every iteration.
adj_mat = torch.from_numpy(preprocessing_utils.get_normalized_adj(A)).float().to(device)

# Denormalized test RMSE keyed by (num_timesteps_input, num_timesteps_output).
test_rmse_by_config = {}

for num_timesteps_input in num_timesteps_inputs:
    for num_timesteps_output in num_timesteps_outputs:
        print(f"Input Timesteps: {num_timesteps_input}, Output Timesteps: {num_timesteps_output}")

        # Re-seed at the start of each configuration so its weight initialisation (and any
        # RNG use during training) is reproducible on its own, independent of whatever
        # the previous configuration's training/evaluation consumed from the RNG.
        torch.manual_seed(1234)

        # Build (input, target) tensors for each split with this window/horizon pair.
        training_input, training_target = preprocessing_utils.generate_dataset(
            train_original_data,
            num_timesteps_input=num_timesteps_input,
            num_timesteps_output=num_timesteps_output)
        val_input, val_target = preprocessing_utils.generate_dataset(
            val_original_data,
            num_timesteps_input=num_timesteps_input,
            num_timesteps_output=num_timesteps_output)
        test_input, test_target = preprocessing_utils.generate_dataset(
            test_original_data,
            num_timesteps_input=num_timesteps_input,
            num_timesteps_output=num_timesteps_output)

        # Init model. Move it to the device *before* building the optimizer, as the
        # PyTorch optimizer docs recommend.
        stgcn = model.Stgcn_Model(nodes_num=adj_mat.shape[0], features_num=training_input.shape[3],
                                  input_timesteps=num_timesteps_input,
                                  num_output=num_timesteps_output).to(device)
        optimizer = torch.optim.Adam(stgcn.parameters(), lr=lr)
        loss_criterion = nn.MSELoss()

        # Train
        training_input = training_input.to(device)
        training_target = training_target.to(device)
        val_input = val_input.to(device)
        val_target = val_target.to(device)

        stgcn, training_loss, validation_loss = model_utils.train(
            stgcn, optimizer, lr, loss_criterion, epochs, patience, adj_mat,
            training_input, training_target, val_input, val_target, batch_size)

        # Plot loss curves; the title identifies the configuration, since the sweep
        # emits one figure per pair. Close the figure so figures don't accumulate in memory.
        # NOTE(review): x-axis label assumes train() returns one loss per epoch — TODO confirm.
        fig, ax = plt.subplots()
        ax.plot(training_loss, label='Training Loss')
        ax.plot(validation_loss, label='Validation Loss')
        ax.set(title=f"Loss (input={num_timesteps_input}, output={num_timesteps_output})",
               xlabel="Epoch", ylabel="Loss")
        ax.legend()
        plt.show()
        plt.close(fig)

        # Evaluate on the held-out test split.
        test_input = test_input.to(device)
        test_target = test_target.to(device)

        with torch.no_grad():
            results = model_utils.predict(stgcn, test_input, adj_mat)
            normalized_test_loss = model_utils.validate(stgcn, loss_criterion, test_input, test_target, adj_mat, batch_size)
            print("Normalized_test_loss: {}".format(normalized_test_loss))
            denormalized_rmse_loss = get_test_rmse(results.cpu(), test_target.cpu(), stds, means, loss_criterion).item()
            print("Denormalized_test_loss: {}".format(denormalized_rmse_loss))

        test_rmse_by_config[(num_timesteps_input, num_timesteps_output)] = denormalized_rmse_loss

        path = model_utils.save_model_timesteps(stgcn, optimizer, num_timesteps_input,
                                                num_timesteps_output, denormalized_rmse_loss)
        print('Saved model to {}\n\n'.format(path))

# Summary of the sweep, displayed as the cell's output.
test_rmse_by_config
        