In [2]:
import os

import matplotlib.pyplot as plt
import torch

# The `dreamy` imports below resolve against the parent directory (the project
# root that contains the `dreamy` package). Only step up when we are not there
# yet, so re-running this cell does not keep climbing the directory tree.
if not os.path.isdir("dreamy"):
    os.chdir("..")

# NOTE(review): the last recorded run of this cell failed with a circular-import
# ImportError raised inside dreamy.models.Spatial; that must be fixed in the
# package itself, not in this notebook.
from dreamy.models.SpatialTemporal.STGCN import STGCN
from dreamy.models.SpatialTemporal.ATMGNN import MPNN_LSTM, ATMGNN
from dreamy.models.SpatialTemporal.STAN import STAN

from dreamy.data import UniversalDataset
from dreamy.utils import utils

# initial settings
device = torch.device('cpu')
torch.manual_seed(7)  # fixed torch RNG seed for reproducible runs

lookback = 12  # number of input time steps per window (num_timesteps_input)
horizon = 3  # number of future time steps to predict (num_timesteps_output)

# `permute` is forwarded to dataset.generate_dataset. The original note says it
# must be True when using STGCN; TODO confirm the layout the other models expect.
permute = True

epochs = 50  # training epochs
batch_size = 50  # training batch size

ImportError: cannot import name 'GCN' from partially initialized module 'dreamy.models.Spatial' (most likely due to a circular import) (E:\BaiduSyncdisk\Project\Dreamy\dreamy\models\Spatial\__init__.py)

In [None]:
# load toy dataset
dataset = UniversalDataset()
dataset.load_toy_dataset()

# preprocessing
# NOTE(review): utils.normalize is applied to the full series before the
# train/val/test split, so the returned mean/std presumably include val/test
# periods (information leakage). Consider normalizing with train-only statistics.
features, mean, std = utils.normalize(dataset.x)
adj_norm = utils.normalize_adj(dataset.graph)

features = features.to(device)
adj_norm = adj_norm.to(device)

# prepare datasets: contiguous split along dim 0 (presumably the time axis)
train_rate = 0.6
val_rate = 0.2

split_line1 = int(features.shape[0] * train_rate)
split_line2 = int(features.shape[0] * (train_rate + val_rate))

train_original_data = features[:split_line1, :, :]
val_original_data = features[split_line1:split_line2, :, :]
test_original_data = features[split_line2:, :, :]


def _make_windows(split_data):
    """Turn one split into (input, target) windows; the target is feature channel 0."""
    return dataset.generate_dataset(
        X=split_data,
        Y=split_data[:, :, 0],
        lookback_window_size=lookback,
        horizon_size=horizon,
        permute=permute,
    )


train_input, train_target = _make_windows(train_original_data)
val_input, val_target = _make_windows(val_original_data)
test_input, test_target = _make_windows(test_original_data)

# Every split must yield at least one window; otherwise fit/predict below would
# silently receive empty tensors.
for _split_name, _split_input in (("train", train_input), ("val", val_input), ("test", test_input)):
    if _split_input.shape[0] == 0:
        raise ValueError(
            f"{_split_name} split produced no windows; it is too short for "
            f"lookback={lookback} + horizon={horizon}"
        )

# prepare model
# Choose one of "STGCN", "MPNN_LSTM", "ATMGNN", "STAN".
# STGCN expects the permuted layout (see `permute` in the settings cell).
model_name = "STAN"
nhid = 4  # hidden size for the models that accept `nhid`

# Arguments shared by every model constructor.
common_model_kwargs = dict(
    num_nodes=adj_norm.shape[0],
    num_features=train_input.shape[3],
    num_timesteps_input=lookback,
    num_timesteps_output=horizon,
)

# name -> (class, extra constructor kwargs); only the selected one is built.
model_registry = {
    "STGCN": (STGCN, {}),
    "MPNN_LSTM": (MPNN_LSTM, {"nhid": nhid}),
    "ATMGNN": (ATMGNN, {"nhid": nhid}),
    "STAN": (STAN, {"nhid": nhid}),
}
if model_name not in model_registry:
    raise ValueError(f"unknown model_name {model_name!r}; choose from {sorted(model_registry)}")

model_cls, extra_model_kwargs = model_registry[model_name]
model = model_cls(**common_model_kwargs, **extra_model_kwargs).to(device=device)

In [None]:
# training: fit on the training windows, with the validation windows passed in
# alongside; hyper-parameters come from the settings cell.
model.fit(
    train_input=train_input,
    train_target=train_target,
    val_input=val_input,
    val_target=val_target,
    graph=adj_norm,
    epochs=epochs,
    batch_size=batch_size,
    verbose=True,
)

In [None]:
# evaluate on the held-out test windows
# Inference only: no_grad avoids building an autograd graph over the test set.
with torch.no_grad():
    out = model.predict(feature=test_input, graph=adj_norm)

# Undo normalization with channel-0 statistics, since targets are feature 0.
preds = out.detach().cpu() * std[0] + mean[0]
targets = test_target.detach().cpu() * std[0] + mean[0]
# MAE in de-normalized units
mae = utils.get_MAE(preds, targets)
print(f"MAE: {mae.item()}")

In [None]:
# visualization: ground truth (lookback inputs + horizon targets) vs. the
# model's prediction for one training window, one subplot per node.
# Values are plotted in normalized units.
sample = 28  # index of the training window to plot
nodes_to_plot = [0, 1, 2]  # node indices (dim 1 of train_input)

if not 0 <= sample < train_input.shape[0]:
    raise IndexError(f"sample={sample} out of range for {train_input.shape[0]} training windows")

with torch.no_grad():
    out = model.predict(feature=train_input, graph=adj_norm).detach().cpu()

timeline = list(range(lookback + horizon))
# squeeze=False keeps `axes` 2-D even when a single node is plotted.
fig, axes = plt.subplots(1, len(nodes_to_plot), figsize=(15, 5), squeeze=False)
for ax, node in zip(axes[0], nodes_to_plot):
    sample_input = train_input[sample, node, :, 0].cpu()
    sample_output = out[sample, node, :]
    sample_target = train_target[sample, node, :].cpu()

    vis_data = torch.cat([sample_input, sample_target]).numpy()

    ax.plot(timeline, vis_data, label="ground truth")
    ax.plot(timeline[lookback:lookback + horizon], sample_output.numpy(), label="prediction")
    ax.set(title=f"node {node}", xlabel="time step", ylabel="normalized value")
    ax.legend()

fig.suptitle(f"training window {sample}")
plt.show()