In [3]:
import sys

sys.path.append('../GraphStructureLearning')

In [24]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from models.GTS.DCRNN import DCRNN

from torch_geometric.utils import to_dense_adj, dense_to_sparse
from torch_geometric_temporal.dataset import METRLADatasetLoader, PemsBayDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

from glob import glob
import yaml
from easydict import EasyDict as edict

In [15]:
config_file = glob('./config/GTS/test.yaml')[0]
config = edict(yaml.load(open(config_file, 'r'), Loader=yaml.FullLoader))

In [18]:
metr_la = np.load('./data/METR-LA/node_values.npy')
adj_matrix = np.load('./data/METR-LA/adj_mat.npy')

loader = METRLADatasetLoader(raw_data_dir='./data/METR-LA')

In [26]:
config.dataset

{'root': './data/spike_lambda_bin100',
 'name': 'spike_lambda_bin100',
 'graph_learning_length': 4800,
 'idx_ratio': 0.5,
 'window_size': 20,
 'slide': 5,
 'pred_step': 5,
 'train_valid_test': [4000, 4400, 4800],
 'save': './data/spike_lambda_bin100/'}

In [65]:
config.hidden_dim = 16
config.encoder_step = 12
config.decoder_step = 3
config.dataset.pred_step = 3
config.embedding_dim = 1

In [66]:
total_length, num_nudes, num_features = metr_la.shape

In [67]:
edge_index, edge_attr = dense_to_sparse(torch.Tensor(adj_matrix))

In [68]:
dataset = loader.get_dataset(num_timesteps_in=12, num_timesteps_out=3)

In [69]:
class DecoderModel(nn.Module):
    def __init__(self, config):
        super(DecoderModel, self).__init__()
        self.output_dim = config.dataset.pred_step
        self.hidden_dim = config.hidden_dim

        self.decoder_dcrnn = DCRNN(config)
        self.prediction_layer = nn.Linear(self.hidden_dim, self.output_dim)

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0.1)

    def forward(self, inputs, adj, hidden_state, weight_matrix=None):
        decoder_hidden_state = hidden_state

        decoder_hidden_state = self.decoder_dcrnn(inputs, adj, hidden_state=decoder_hidden_state, weight_matrix=weight_matrix)
        prediction = self.prediction_layer(decoder_hidden_state[-1].view(-1, self.hidden_dim))

        output = prediction.view(inputs.shape[0], self.output_dim)

        return output


class GTS_Forecasting_Module(nn.Module):
    def __init__(self, config):
        super(GTS_Forecasting_Module, self).__init__()

        self.config = config

        self.nodes_num = config.nodes_num

        self.nodes_feas = config.node_features

        self.encoder_step = config.encoder_step
        self.decoder_step = config.decoder_step

        self.encoder_model = DCRNN(config)
        self.decoder_model = DecoderModel(config)


    def forward(self, inputs, targets, adj_matrix, weight_matrix=None):
        # DCRNN encoder
        encoder_hidden_state = None
        for i in range(self.encoder_step):
            encoder_hidden_state = self.encoder_model(inputs[:,i].unsqueeze(dim=-1), adj_matrix,
                                                      encoder_hidden_state, weight_matrix)

        # DCRNN decoder
        outputs = []
        decoder_input = torch.zeros((self.nodes_num, 1))
        
        decoder_hidden_state = encoder_hidden_state
        for j in range(self.decoder_step):
            output = self.decoder_model(decoder_input, adj_matrix, decoder_hidden_state, weight_matrix)
            outputs.append(output)
            
            decoder_input = targets[j]

        outputs = torch.cat(outputs, dim=-1)
        return outputs


In [70]:
model = GTS_Forecasting_Module(config)

In [71]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fnc = nn.L1Loss()
model.train()

GTS_Forecasting_Module(
  (encoder_model): DCRNN(
    (recurrent): ModuleList(
      (0): DCRNN(
        (conv_x_z): DConv(17, 16)
        (conv_x_r): DConv(17, 16)
        (conv_x_h): DConv(17, 16)
      )
    )
  )
  (decoder_model): DecoderModel(
    (decoder_dcrnn): DCRNN(
      (recurrent): ModuleList(
        (0): DCRNN(
          (conv_x_z): DConv(17, 16)
          (conv_x_r): DConv(17, 16)
          (conv_x_h): DConv(17, 16)
        )
      )
    )
    (prediction_layer): Linear(in_features=16, out_features=3, bias=True)
  )
)

In [75]:
input_x.shape

torch.Size([1, 207, 12])

In [76]:
for epoch in range(2):
    loss = 0
    
    for batch in dataset:
        input_x = batch.x[:,0,:]
        target = batch.y
    
        y_hat = model(input_x, adj_matrix = batch.edge_index, targets=target)
        loss = loss + loss_fnc(y_hat, target)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 100 but got size 207 for tensor number 1 in the list.