In [329]:
import sys

# Extend the import search path with a directory given relative to the notebook's
# working directory. NOTE(review): presumably this is where the `utils.utils`
# module imported further down lives — confirm.
sys.path.append('../GraphStructureLearning')

In [330]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import pickle
import matplotlib.pyplot as plt

from glob import glob
import yaml
from easydict import EasyDict as edict

In [331]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from torch_geometric.utils.random import erdos_renyi_graph

In [332]:
from utils.utils import build_fully_connected_edge_idx

In [333]:
# Locate the experiment config (IndexError here means the YAML file was not found
# relative to the current working directory).
config_file = glob('./config/GTS/CCN_Project/test.yaml')[0]

# Read the YAML inside a context manager so the file handle is closed right away.
# NOTE(review): yaml.FullLoader is only appropriate for trusted, local config files;
# switch to yaml.safe_load if the config can come from elsewhere.
with open(config_file, 'r') as config_stream:
    config = edict(yaml.load(config_stream, Loader=yaml.FullLoader))

# Shortcut to the dataset sub-section; later cells read and override fields on it.
dataset_conf = config.dataset

In [334]:
# NOTE(review): pickle.load can execute arbitrary code stored in the file — only
# load pickles from a trusted source.
# The context manager closes the file handle as soon as the object is loaded.
with open('./data/CCN/monkeydata.pickle', 'rb') as dataset_file:
    dataset = pickle.load(dataset_file)

In [335]:
# Reshape to (-1, 101, 742) and convert to a NumPy array.
# torch.as_tensor is a no-op for a tensor and wraps an ndarray without copying, so
# this cell can be re-run safely: previously a second run failed with AttributeError
# because `dataset` was already an ndarray and has no `.numpy()` method.
dataset = torch.as_tensor(dataset).reshape(-1, 101, 742).numpy()

In [336]:
from sklearn.preprocessing import StandardScaler

In [337]:
# scaler_x = StandardScaler()
# scaler_y = StandardScaler()
# scaler_z = StandardScaler()

# scaler_x.fit(dataset[:, -3, :])
# scaler_y.fit(dataset[:, -2, :])
# scaler_z.fit(dataset[:, -1, :])

# dataset[:, -3, :] = scaler_x.transform(dataset[:, -3, :])
# dataset[:, -2, :] = scaler_y.transform(dataset[:, -2, :])
# dataset[:, -1, :] = scaler_z.transform(dataset[:, -1, :])

In [338]:
# First split along axis 1: rows 0..97 become the model inputs, rows 98.. the targets.
all_x = dataset[:, :98, :]
all_y = dataset[:, 98:, :]

# Then split along axis 0 (one entry per trial): 0..79 train, 80..89 validation,
# 90.. test.
train_x, valid_x, test_x = all_x[:80], all_x[80:90], all_x[90:]
train_y, valid_y, test_y = all_y[:80], all_y[80:90], all_y[90:]

In [340]:
# Seed and edge probability of the random (Erdos-Renyi) connectivity graph.
GRAPH_SEED = 0
EDGE_PROBABILITY = 0.01

# Sample the graph under a fixed seed so that a fresh "Restart & Run All" always
# produces the same edges. fork_rng() restores the global torch RNG state on exit,
# so seeding here does not affect any later random draws (e.g. model init).
with torch.random.fork_rng():
    torch.manual_seed(GRAPH_SEED)
    edge_index = erdos_renyi_graph(config.nodes_num, EDGE_PROBABILITY)

In [341]:
def prepare_dataset(dataset, dataset_conf, edge_index):
    """Build one sliding-window `Data` object per trial.

    Args:
        dataset: pair ``[x, y]``. Both are indexed as ``[trial, :, start:stop]``,
            i.e. as 3-D arrays whose last axis is the one being windowed.
        dataset_conf: config object providing ``window_size``, ``slide`` and
            ``total_time_length``.
        edge_index: graph connectivity; the same object is attached to every
            returned `Data` item.

    Returns:
        A list with one `Data` per trial. ``Data.x`` stacks the input windows
        along axis 1 (each window is ``window_size`` long) and ``Data.y`` stacks,
        along axis 1, the last ``slide`` samples of ``y`` inside each window.
    """
    x, y = dataset[0], dataset[1]
    window = dataset_conf.window_size
    slide = dataset_conf.slide

    # Exclusive end index of every window: the first one ends at `window`, the
    # following ones every `slide` steps, never past `total_time_length`.
    n_offsets = dataset_conf.total_time_length - window + 1
    window_ends = [window + offset for offset in range(n_offsets) if offset % slide == 0]

    data_list = []
    for trial_idx in range(len(x)):
        # Input: the full window ending at `end`.
        spike_windows = np.stack(
            [x[trial_idx, :, end - window:end] for end in window_ends],
            axis=1,
        )
        # Target: only the trailing `slide` samples of that same window.
        trajectory_windows = np.stack(
            [y[trial_idx, :, end - slide:end] for end in window_ends],
            axis=1,
        )

        data_list.append(
            Data(
                x=torch.FloatTensor(spike_windows),
                edge_index=edge_index,
                y=torch.FloatTensor(trajectory_windows),
            )
        )

    return data_list

In [342]:
# Override windowing parameters on the loaded config for this experiment.
# window_size: number of time steps in each input window cut by prepare_dataset.
dataset_conf.window_size = 202
# slide: hop between consecutive windows AND the number of target samples kept per
# window, so it must match the output width of the model's out_* heads
# (Linear(98, 10)) for the loss in the training loop to line up.
dataset_conf.slide = 10

In [343]:
# Window every split with the same settings and the same shared edge_index.
train_dataset, valid_dataset, test_dataset = (
    prepare_dataset([split_x, split_y], dataset_conf, edge_index)
    for split_x, split_y in (
        (train_x, train_y),
        (valid_x, valid_y),
        (test_x, test_y),
    )
)

In [344]:
# One trial (one graph) per batch, visited in random order.
# NOTE(review): the model flattens its per-node output into a vector fed to
# Linear(98, 10), so batching more than one graph would presumably break it — confirm
# before raising batch_size.
loader_kwargs = dict(batch_size=1, shuffle=True)

train_loader = DataLoader(train_dataset, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, **loader_kwargs)

In [345]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN

class RecurrentGCN(torch.nn.Module):
    """Per-node 1-D convolutional encoder, a DCRNN cell and three linear read-outs.

    One call to ``forward`` processes a single time step:

    1. two Conv1d -> ReLU -> BatchNorm1d stages embed each node's input window,
    2. the flattened embedding is passed, with the previous hidden state, to a
       DCRNN cell,
    3. a Conv1d whose kernel spans the whole hidden state reduces every node's
       hidden vector to a single value,
    4. three independent ``Linear(98, 10)`` heads map the resulting per-node
       vector to the three outputs.

    Args:
        config: configuration object providing ``nodes_num``, ``node_features``
            and ``embedding.{kernel_size, stride, conv1_dim, conv2_dim,
            embedding_dim}``.
    """

    def __init__(self, config):
        super(RecurrentGCN, self).__init__()
        self.num_nodes = config.nodes_num
        self.nodes_feas = config.node_features

        self.kernel_size = config.embedding.kernel_size
        self.stride = config.embedding.stride
        self.conv1_dim = config.embedding.conv1_dim
        self.conv2_dim = config.embedding.conv2_dim

        # Size of the flattened conv2 output per node, i.e. the DCRNN input width.
        self.embedding_dim = config.embedding.embedding_dim
        # Width of the DCRNN hidden state. conv3's kernel covers it exactly, so
        # conv3 yields one value per node.
        self.hidden_dim = 16

        self.conv1 = torch.nn.Conv1d(self.nodes_feas, self.conv1_dim, self.kernel_size, stride=self.stride)
        self.conv2 = torch.nn.Conv1d(self.conv1_dim, self.conv2_dim, self.kernel_size, stride=self.stride)

        self.bn1 = torch.nn.BatchNorm1d(self.conv1_dim)
        self.bn2 = torch.nn.BatchNorm1d(self.conv2_dim)

        self.recurrent = DCRNN(self.embedding_dim, self.hidden_dim, 1)

        self.conv3 = torch.nn.Conv1d(1, 1, self.hidden_dim, stride=1)

        # The heads consume the flattened conv3 output, so the total number of
        # nodes passed to forward() has to be 98; each head emits 10 values.
        self.out_1 = torch.nn.Linear(98, 10)
        self.out_2 = torch.nn.Linear(98, 10)
        self.out_3 = torch.nn.Linear(98, 10)

    def forward(self, x, edge_index, hidden_state):
        """Run one recurrent step.

        Args:
            x: node inputs for the current step. A 2-D tensor (nodes, window) is
                given a singleton channel axis; any other rank is passed to conv1
                unchanged and must already be laid out as (nodes, channels, length).
            edge_index: graph connectivity forwarded to the DCRNN cell.
            hidden_state: hidden state returned by the previous step, or ``None``
                at the start of a sequence.

        Returns:
            Tuple ``(out_x, out_y, out_z, h)``: the outputs of the three linear
            heads and the new (ReLU-activated) hidden state to feed into the next
            step.
        """
        batch_nodes = x.shape[0]
        if x.dim() == 2:
            inputs = x.reshape(batch_nodes, 1, -1)
        else:
            # Previously `inputs` was left undefined on this path, which raised
            # UnboundLocalError for inputs that already carry a channel axis.
            inputs = x

        x = self.conv1(inputs)
        x = F.relu(x)
        x = self.bn1(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = self.bn2(x)

        # Flatten (channels, length) into a single feature vector per node.
        x = x.view(batch_nodes, -1)

        h = self.recurrent(x, edge_index, H=hidden_state)
        h = F.relu(h)

        # The hidden state has `hidden_dim` features per node (not `embedding_dim`,
        # which is the DCRNN *input* width); add a channel axis for conv3.
        h_ = h.view(batch_nodes, 1, self.hidden_dim)
        forecast_h = self.conv3(h_)
        forecast_h = F.relu(forecast_h)

        # One value per node -> flat vector consumed by the three heads.
        forecast_h = forecast_h.view(-1)

        out_x = self.out_1(forecast_h)
        out_y = self.out_2(forecast_h)
        out_z = self.out_3(forecast_h)

        return out_x, out_y, out_z, h

In [352]:
from tqdm import tqdm

model = RecurrentGCN(config)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in tqdm(range(2)):
    for data in train_loader:
        # Every trial starts from an empty hidden state, so each trial builds its
        # own autograd graph and backward() is run exactly once on it.
        H = None
        cost = 0
        optimizer.zero_grad()

        # Unroll the recurrent model over all windows of the trial and sum the
        # per-step losses. (The accumulator used to be reset to 0 inside this loop,
        # so only the last step contributed, and `cost += cost + ...` would have
        # doubled the running total on every step.)
        for t in range(data.x.shape[1]):
            out_x, out_y, out_z, H = model(data.x[:, t, :], data.edge_index, H)

            # Rows 0, 1 and 2 of data.y are the targets of the three heads.
            cost_x = F.mse_loss(out_x, data.y[0, t, :])
            cost_y = F.mse_loss(out_y, data.y[1, t, :])
            cost_z = F.mse_loss(out_z, data.y[2, t, :])

            cost = cost + cost_x + cost_y + cost_z

        # Back-propagate through the whole unrolled sequence, then update.
        cost.backward()
        optimizer.step()

  0%|                                                     | 0/2 [00:00<?, ?it/s]


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.