In [None]:
# imports

# !pip install networkx
# !pip uninstall torch
# !pip install torch_sparse
# !pip install torch_geometric_temporal

# conda install -c conda-forge pytorch_geometric pytorch matplotlib networkx tqdm notebook nb_conda_kernels jupyterlab

import os

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv
from torch_geometric.utils import to_networkx
from torch_geometric_temporal.nn.recurrent import DCRNN
from torch_geometric_temporal.signal import (StaticGraphTemporalSignal,
                                             temporal_signal_split)
from tqdm import tqdm

In [None]:
# data loading task 1

# raw_data = np.load('data/raw/calms21_task1_train.npy', allow_pickle=True)
# raw_features = raw_data[()]['annotator-id_0']['task1/train/mouse001_task1_annotator1']['keypoints']
# labels = raw_data[()]['annotator-id_0']['task1/train/mouse001_task1_annotator1']['annotations']
# features = np.swapaxes(raw_features, 2,3).reshape(-1,14,2)

In [None]:
# data loading task 3

# raw_data = np.load('data/raw/calms21_task1_train.npy', allow_pickle=True)
# raw_features = raw_data[()]['approach']['task3/approach/train/mouse001_task3_approach']['keypoints']
# labels = raw_data[()]['approach']['task3/approach/train/mouse001_task3_approach']['annotations']

In [None]:
# dataloader

class MABDataset(Dataset):
    """One CalMS21 task-1 sequence exposed as a per-frame graph dataset.

    Every frame of the selected sequence is turned into a ``Data`` object with
    14 nodes carrying 2 features each, a fixed skeleton ``edge_index`` and the
    frame's annotation as ``y``. Graphs are stored one file per frame in
    ``processed_dir`` (``data_{i}.pt`` for the train split,
    ``data_test_{i}.pt`` for the test split) and loaded on demand by ``get``.

    Processing is cached: ``process`` writes a completion marker as its last
    step, and later instantiations skip processing while that marker exists.
    Delete the marker (or the processed directory) to force a rebuild.

    Args:
        root: Dataset root directory handed to the base class; the raw
            ``.npy`` file must already be present in ``raw_dir`` because
            ``download`` does nothing.
        test: Selects the test file/sequence when truthy, the train one
            otherwise.
        transform, pre_transform, pre_filter: Forwarded unchanged to the base
            class. ``process`` itself applies neither ``pre_transform`` nor
            ``pre_filter``.
    """

    # Hard-coded skeleton connectivity over the 14 nodes: first row holds the
    # source node of each edge, second row the target node.
    _SKELETON_EDGES = [[0, 0, 1, 1, 3, 3, 4, 5, 3, 7, 7, 8, 8, 10, 10, 11, 12],
                       [1, 2, 3, 2, 4, 5, 6, 6, 10, 8, 9, 10, 9, 11, 12, 13, 13]]

    def __init__(self, root, test=False, transform=None, pre_transform=None, pre_filter=None):
        # Both attributes are set before super().__init__() because the base
        # constructor may already call process(), which reads self.test.
        self.test = test
        # NOTE(review): spelling kept for existing callers; presumably chosen
        # because the base class already exposes `num_classes` as a property
        # -- confirm before renaming.
        self.num_clases = 4
        super().__init__(root, transform, pre_transform, pre_filter)
        # When a finished cache exists process() is skipped, so the labels
        # that len() relies on have to be loaded here instead.
        self._ensure_labels()

    @property
    def raw_file_names(self):
        """Name of the raw ``.npy`` file for the selected split."""
        if self.test:
            return 'calms21_task1_test.npy'
        return 'calms21_task1_train.npy'

    @property
    def processed_file_names(self):
        """Completion marker; if it is found in ``processed_dir``, processing is skipped."""
        return f'{self._file_prefix}_complete.pt'

    @property
    def _sequence_key(self):
        """Key of the single sequence used from the raw file for this split."""
        if self.test:
            return 'task1/test/mouse071_task1_annotator1'
        return 'task1/train/mouse001_task1_annotator1'

    @property
    def _file_prefix(self):
        """Filename prefix of the per-frame graph files for this split."""
        return 'data_test' if self.test else 'data'

    def _graph_path(self, idx):
        """Return the path of the processed graph file for frame ``idx``."""
        return os.path.join(self.processed_dir, f'{self._file_prefix}_{idx}.pt')

    def _load_sequence(self):
        """Read the raw file and return ``(keypoints, annotations)`` of the sequence.

        The loaded array is also kept on ``self.raw_data``.
        """
        # allow_pickle=True unpickles arbitrary objects: only use trusted files.
        self.raw_data = np.load(self.raw_paths[0], allow_pickle=True)
        sequence = self.raw_data[()]['annotator-id_0'][self._sequence_key]
        return sequence['keypoints'], sequence['annotations']

    def _ensure_labels(self):
        """Populate ``self.labels`` from the raw file unless it is already set."""
        if getattr(self, 'labels', None) is None:
            _, self.labels = self._load_sequence()

    def download(self):
        """No-op: the raw file has to be placed in ``raw_dir`` manually."""
        pass

    def process(self):
        """Convert every frame of the sequence into a graph file.

        Writes one ``Data`` object per frame and, only after all frames have
        been written, the completion marker named by ``processed_file_names``.
        An interrupted run therefore leaves no marker and is redone next time.
        """
        raw_features, self.labels = self._load_sequence()
        # Swap the last two axes, then flatten to (frames, 14 nodes, 2 features).
        # NOTE(review): assumes the raw layout is (frame, mouse, coord, keypoint)
        # -- confirm against the CalMS21 file description.
        features = np.swapaxes(raw_features, 2, 3).reshape(-1, 14, 2)
        edge_index = torch.tensor(self._SKELETON_EDGES, dtype=torch.long)
        for i in range(len(features)):
            x = torch.tensor(features[i], dtype=torch.float)
            y = torch.tensor(self.labels[i], dtype=torch.int)
            graph = Data(x=x, edge_index=edge_index, y=y)
            torch.save(graph, self._graph_path(i))
        # Written last so that its presence implies a complete cache.
        torch.save(len(features), self.processed_paths[0])

    def len(self):
        """Return the number of frames (one graph per annotation)."""
        self._ensure_labels()
        return len(self.labels)

    def get(self, idx):
        """Load and return the processed graph of frame ``idx``."""
        return torch.load(self._graph_path(idx))

In [None]:
# create dataset

# Both splits live under the same root; the `test` flag selects the split.
DATA_ROOT = './data'
train_dataset = MABDataset(DATA_ROOT, test=False)
test_dataset = MABDataset(DATA_ROOT, test=True)

In [None]:
# outdated dataloader

# second_option = [[0, 0, 1, 1, 3, 3, 4, 5, 3, 7, 7, 8, 8, 10, 10, 11, 12], 
#                          [1, 2, 3, 2, 4, 5, 6, 6, 10, 8, 9, 10, 9, 11, 12, 13, 13]]
# edge_index = torch.tensor(second_option, dtype=torch.int)
# dataset_outdated = StaticGraphTemporalSignal(second_option,np.ones(np.array(second_option).shape[1]),features,labels)

In [None]:
# plot the graph

# Take a single training frame (index 1115) and convert it to networkx,
# merging edge directions into an undirected graph first.
sample_graph = to_networkx(train_dataset[1115], to_undirected=True)

# method 1
# Re-wrapped as a DiGraph before drawing; `G` stays a notebook-level name
# because visualize() below reads it.
G = nx.DiGraph(sample_graph)
nx.draw(G)

# method 2
def visualize(h, color, epoch=None, loss=None):
    """Show either a 2-D scatter plot of an embedding or a drawing of a graph.

    Args:
        h: A tensor whose first two columns are plotted as x/y coordinates,
            or a networkx graph to draw. Any other value falls back to the
            notebook-level graph ``G``, which is what the function always
            drew before for non-tensor input.
        color: Passed to matplotlib/networkx as the point or node colours
            (mapped through the "Set2" colormap).
        epoch: Optional epoch number for the x-axis label (tensor input only).
        loss: Optional loss for the x-axis label (tensor input only); must
            provide ``.item()``. The label is shown only when both ``epoch``
            and ``loss`` are given.
    """
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])

    if torch.is_tensor(h):
        points = h.detach().cpu().numpy()
        plt.scatter(points[:, 0], points[:, 1], s=140, c=color, cmap="Set2")
        if epoch is not None and loss is not None:
            plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    else:
        # Draw the graph that was actually passed in; nx.DiGraph is a subclass
        # of nx.Graph, so directed graphs are accepted as well.
        graph = h if isinstance(h, nx.Graph) else G
        # Fixed seed keeps the layout identical between calls.
        nx.draw_networkx(graph, pos=nx.spring_layout(graph, seed=42), with_labels=False,
                         node_color=color, cmap="Set2")
    plt.show()

In [None]:
# defining the model

class RecurrentGCN(torch.nn.Module):
    """A DCRNN layer followed by ReLU and a linear read-out to one value.

    Args:
        node_features: Number of input features per node.
    """

    def __init__(self, node_features):
        super().__init__()
        hidden_size = 32
        # Layers are created in the same order as before so that parameter
        # initialisation consumes the random stream identically.
        self.recurrent = DCRNN(node_features, hidden_size, 1)
        self.linear = torch.nn.Linear(hidden_size, 1)

    def forward(self, x, edge_index, edge_weight):
        """Run the recurrent layer, apply ReLU, and project to one output.

        All three arguments are passed straight to the DCRNN layer.
        """
        hidden = self.recurrent(x, edge_index, edge_weight)
        activated = F.relu(hidden)
        return self.linear(activated)

In [None]:
# training the model

# train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)

# Full-batch training: the per-frame squared errors are averaged over the
# whole training set and a single optimizer step is taken per epoch.
model = RecurrentGCN(node_features=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

for epoch in tqdm(range(10)):
    cost = 0
    # `step` ends up holding the number of frames seen in this epoch.
    for step, snapshot in enumerate(train_dataset, start=1):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        squared_error = (y_hat - snapshot.y) ** 2
        cost = cost + squared_error.mean()
    cost = cost / step
    print(cost)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()

In [None]:
# testing the model

# Evaluate on the test split: the output of node 0, rounded to the nearest
# integer, is taken as the predicted class of a frame and compared with the
# frame's label.
model.eval()
counter = 0  # number of correctly classified frames
total = 0    # number of frames evaluated

# Inference only: disabling autograd avoids building a graph for every frame.
with torch.no_grad():
    for snapshot in test_dataset:
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        predicted = round(y_hat[0, 0].item())
        # Compare plain Python ints rather than an int against a tensor.
        counter += int(predicted == int(snapshot.y))
        total += 1

# Guard against an empty test set instead of failing on an undefined counter.
accuracy = counter / total if total else float('nan')
print("accuracy: {:.4f}".format(accuracy))

In [None]:
# model for predictions

# class GAT(torch.nn.Module):
#     def __init__(self,
#                  node_dim: int,
#                  hidden_dim: int,
#                  num_layers: int,
#                  dropout: float):
#         """
#         Args
#         - node_dim: int, dimension of input node features
#         - hidden_dim: int, dimensions of hidden layers
#         - num_layers: int, # of hidden layers
#         - dropout: float, probability of dropout
#         """
#         super(GAT, self).__init__()

#         # save all of the info
#         self.node_dim = node_dim
#         self.hidden_dim = hidden_dim
#         self.num_layers = num_layers
#         self.dropout = dropout

#         # since there are a limited # of atoms, we first embed them into a
#         # higher dimension using a lookup table
#         # self.atom_encoder = AtomEncoder(emb_dim=hidden_dim)

#         # a list of GATv2 layers, with dropout
#         self.convs = nn.ModuleList()
#         self.bns = nn.ModuleList()
#         for l in range(num_layers):
#             layer = pyg_nn.GATv2Conv(in_channels=hidden_dim,
#                                      out_channels=hidden_dim,
#                                      dropout=dropout)
#             self.convs.append(layer)
#             self.bns.append(nn.BatchNorm1d(hidden_dim))

#         # fully-connected final layer
#         self.fc = nn.Linear(hidden_dim, 1)

#     def forward(self, data: pyg.data.Data) -> torch.Tensor:
#         """
#         Args
#         - data: pyg.data.Batch, a batch of graphs

#         Returns: torch.Tensor, shape [batch_size], unnormalized classification
#             probability for each graph
#         """
#         x, edge_index, edge_attr, batch = (
#             data.x, data.edge_index, data.edge_attr, data.batch)

#         # x = self.atom_encoder(x)

#         for l, conv in enumerate(self.convs):
#             x = conv(x, edge_index)
#             if l != self.num_layers - 1:
#                 x = self.bns[l](x)
#                 x = F.relu(x)

#         x = pyg_nn.global_mean_pool(x, batch=batch)
#         # x = pyg_nn.global_add_pool(x, batch=batch)
#         x = self.fc(x)
#         return x



# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = "cpu"

# model = GAT(node_dim=dataset.num_node_features,
#             hidden_dim=2,
#             num_layers=3,
#             dropout=0.3).to(device)
# # data = dataset.to(device)


# optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

# model.train()
# total_loss = 0
# all_preds = []
# all_labels = []
# train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
# loss_fn = nn.BCEWithLogitsLoss()
# for epoch in range(10):
#     for step, batch in enumerate(train_loader):

#         all_labels.append(batch.y.detach())

#         batch = batch.to(device)
#         batch_size = batch.batch.max().item()

#         preds = model(batch)
#         print(preds.shape)
#         all_preds.append(preds.detach().cpu())
#         loss = loss_fn(preds, batch.y.to(torch.float32))
#         total_loss += loss.item() * batch_size

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()