In [1]:
import sys

# Make the sibling GraphStructureLearning checkout importable. The path is
# relative to the kernel's current working directory. The membership check
# keeps this cell idempotent: re-running it does not append duplicate entries.
GSL_PATH = '../GraphStructureLearning'
if GSL_PATH not in sys.path:
    sys.path.append(GSL_PATH)

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import pickle
import matplotlib.pyplot as plt

In [3]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

In [4]:
# Load the pickled recording tensor. The `with` block closes the file handle
# as soon as loading finishes.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# load pickles from a trusted source.
with open('./data/CCN/monkeydata.pickle', 'rb') as fh:
    dataset = pickle.load(fh)

In [5]:
# Inspect the raw tensor's dimensions; the saved output is torch.Size([8, 100, 101, 742]).
dataset.size()

torch.Size([8, 100, 101, 742])

In [6]:
# Collapse every leading axis into a single sample axis and keep the last two
# axes intact ((8, 100, 101, 742) -> (800, 101, 742) for the data above).
# The trailing sizes are read from the tensor instead of being hard-coded.
# With hard-coded sizes, a recording with different trailing dims whose element
# count happened to be divisible would be silently reshaped into wrong windows.
dataset = dataset.reshape(-1, *dataset.shape[-2:])

In [7]:
# Split along axis 1 (length 101 after the reshape above).
# `train_valid_test` gets the first N_INPUT_STEPS positions and `target` gets
# the rest (3 of 101 with the current data).
# NOTE(review): assumes axis 1 is the time axis — confirm against the data source.
N_INPUT_STEPS = 98
assert 0 < N_INPUT_STEPS < dataset.shape[1], 'split index must leave both parts non-empty'
train_valid_test = dataset[:, :N_INPUT_STEPS, :]
target = dataset[:, N_INPUT_STEPS:, :]

In [None]:
class EncoderModel(nn.Module):
    """Conv1d feature extractor followed by a DCRNN encoder step.

    Each row of ``inputs`` goes through two ``Conv1d -> ReLU -> BatchNorm1d``
    stages. The result is flattened and projected to one value per row by
    ``fc`` (+ ReLU). The resulting ``(batch_nodes, 1)`` tensor is handed to
    ``encoder_dcrnn`` together with ``adj`` and ``hidden_state``.

    Config attributes read: ``nodes_num``, ``node_features`` and
    ``embedding.kernel_size / stride / conv1_dim / conv2_dim / fc_dim``.

    NOTE(review): ``DCRNN`` is only imported in a later cell of this notebook
    (from ``torch_geometric_temporal``). That class is constructed as
    ``DCRNN(in_channels, out_channels, K)`` and called as
    ``(X, edge_index, edge_weight=None, H=None)``. ``DCRNN(config)`` and the
    positional ``hidden_state`` argument used here only fit if a different,
    config-driven ``DCRNN`` (e.g. from GraphStructureLearning) is intended.
    Confirm which one.
    """

    def __init__(self, config):
        super().__init__()
        emb = config.embedding

        # Hyper-parameters kept as attributes.
        self.num_nodes = config.nodes_num
        self.nodes_feas = config.node_features
        self.kernel_size = emb.kernel_size
        self.stride = emb.stride
        self.conv1_dim = emb.conv1_dim
        self.conv2_dim = emb.conv2_dim
        self.fc_dim = emb.fc_dim

        # Sub-modules are created in the order conv1, conv2, fc, bn1, bn2,
        # encoder_dcrnn. This order fixes parameter iteration order and the
        # sequence of random initialisations.
        self.conv1 = nn.Conv1d(self.nodes_feas, self.conv1_dim, self.kernel_size, stride=self.stride)
        self.conv2 = nn.Conv1d(self.conv1_dim, self.conv2_dim, self.kernel_size, stride=self.stride)
        # fc_dim must equal conv2_dim * (sequence length remaining after both convs).
        self.fc = nn.Linear(self.fc_dim, 1)
        self.bn1 = nn.BatchNorm1d(self.conv1_dim)
        self.bn2 = nn.BatchNorm1d(self.conv2_dim)

        self.encoder_dcrnn = DCRNN(config)

    def forward(self, inputs, adj, hidden_state=None):
        """Embed ``inputs`` with the conv stack and run one DCRNN encoder call.

        Args:
            inputs: 2-D ``(rows, length)`` or 3-D ``(rows, channels, length)``
                tensor. A 2-D input is treated as a single channel, which only
                matches ``conv1`` when ``config.node_features == 1``.
            adj: passed through unchanged to ``encoder_dcrnn``.
            hidden_state: passed through unchanged to ``encoder_dcrnn``.

        Returns:
            Whatever ``encoder_dcrnn`` returns.
        """
        n_rows = inputs.shape[0]
        out = inputs.reshape(n_rows, 1, -1) if inputs.dim() == 2 else inputs

        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = norm(F.relu(conv(out)))

        out = F.relu(self.fc(out.view(n_rows, -1)))
        return self.encoder_dcrnn(out, adj, hidden_state)

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN

class RecurrentGCN(torch.nn.Module):
    """A DCRNN layer followed by ReLU and a per-node linear read-out.

    Args:
        node_features: number of input features per node (DCRNN ``in_channels``).
        hidden_channels: DCRNN output size. Defaults to 16, the previously fixed value.
        out_channels: size of the linear read-out. Defaults to 3, the previously fixed value.
        K: DCRNN filter size. Defaults to 1, the previously fixed value.
    """

    def __init__(self, node_features, hidden_channels=16, out_channels=3, K=1):
        super().__init__()
        self.recurrent = DCRNN(node_features, hidden_channels, K)
        self.linear = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_weight=None):
        """Run one DCRNN step on node features ``x`` and project the result.

        Args:
            x: node feature matrix fed to DCRNN.
            edge_index: graph connectivity fed to DCRNN.
            edge_weight: optional edge weights forwarded to DCRNN. ``None``
                (the default) matches the previous behaviour.

        Returns:
            The read-out tensor with ``out_channels`` features per node.

        No hidden state is passed to DCRNN, so every call starts from
        DCRNN's own default initial state.
        """
        h = self.recurrent(x, edge_index, edge_weight)
        return self.linear(F.relu(h))