# Demo and test notebook for GraphTranslatorModule

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from GraphTranslatorModule import GraphTranslatorModule
from pytorch_lightning import Trainer
import torch
from torch.nn import functional as F
from pl_bolts.datasets import DummyDataset
import random
import numpy as np

## A. This section is a primer for LightningModule

### GraphTranslatorModule is a LightningModule rather than a plain torch.nn.Module. A LightningModule is a regular PyTorch Module with added functionality, so you can use it EXACTLY as you would any other PyTorch Module.

# Build the module with its default arguments and call it on a random
# (1, 1, 3, 3) tensor, just like a plain torch.nn.Module.
net = GraphTranslatorModule()
x = torch.randn(1, 1, 3, 3)
out = net(x)

# Train for three epochs on 100 dummy samples, each made of a (1, 3, 3)
# tensor and a (1,) tensor.
trainer = Trainer(max_epochs=3)
train_loader = torch.utils.data.DataLoader(
    DummyDataset((1, 3, 3), (1,), num_samples=100),
    num_workers=8,
)
trainer.fit(net, train_loader)

# Evaluate on 10 further dummy samples with the same layout.
test_loader = torch.utils.data.DataLoader(
    DummyDataset((1, 3, 3), (1,), num_samples=10),
    num_workers=8,
)
trainer.test(net, test_loader)

## B. Dimension checks for the model functions

### Test the model with dummy data

In [None]:
class DummyClassificationDataset(torch.utils.data.Dataset):
    """Random-tensor dataset used to check that tensor dimensions line up.

    Each item is the 5-tuple ``(edges, nodes, context_curr, context_query, y)``
    with shapes:

    * ``edges``:         ``(n_nodes, n_nodes, e_len)``
    * ``nodes``:         ``(n_nodes, n_len)``
    * ``context_curr``:  ``(c_len,)``
    * ``context_query``: ``(c_len,)``
    * ``y``:             ``(n_nodes, n_nodes, e_len)``

    Each access draws every tensor afresh from a standard normal distribution,
    so the data carries no learnable signal.

    Args:
        n_nodes: number of graph nodes.
        n_len: length of each node feature vector.
        e_len: length of each edge feature vector.
        c_len: length of each context vector.
        num_samples: value reported by ``len()``.
    """

    def __init__(self, n_nodes, n_len, e_len, c_len, num_samples: int = 100):
        self.n_nodes = n_nodes
        self.n_len = n_len
        self.e_len = e_len
        self.c_len = c_len
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx: int):
        # ``idx`` is ignored: every call draws a brand-new random sample.
        edges = torch.randn(self.n_nodes, self.n_nodes, self.e_len)
        nodes = torch.randn(self.n_nodes, self.n_len)
        context_curr = torch.randn(self.c_len)
        # Sampled independently of context_curr, so the model is fed a query
        # context that differs from the current one (as with real data).
        context_query = torch.randn(self.c_len)
        y = torch.randn(self.n_nodes, self.n_nodes, self.e_len)
        return edges, nodes, context_curr, context_query, y

# Toy dimensions: 5 nodes, 2 features per node, 3 per edge, 1 context value.
n_nodes, n_len, e_len, c_len = 5, 2, 3, 1

model = GraphTranslatorModule(
    num_nodes=n_nodes,
    node_feature_len=n_len,
    edge_feature_len=e_len,
    context_len=c_len,
)
dataset = DummyClassificationDataset(n_nodes, n_len, e_len, c_len)
dl = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=8)

# Log and flush metrics on every training step.
trainer = Trainer(
    max_epochs=3,
    log_every_n_steps=1,
    flush_logs_every_n_steps=1,
)
trainer.fit(model, dl)

### Test the model with example graph and timestamp data

In [None]:
from Reader import read_data

class SingleDataset(torch.utils.data.Dataset):
    """Map-style dataset over the samples loaded by ``read_data``.

    Indexing returns whatever ``read_data`` stored at that position. The
    constructor unpacks the first sample as
    ``(edges, nodes, context_curr, context_query, y)`` and records the sizes
    needed to build a ``GraphTranslatorModule``:

    * ``n_nodes``: ``edges.size()[1]``
    * ``n_len``:   ``nodes.size()[1]``
    * ``e_len``:   ``edges.size()[-1]``
    * ``c_len``:   ``context_curr.size()[0]``

    Args:
        data_dir: directory passed to ``read_data``.
        classes_path: classes file passed to ``read_data``.

    Raises:
        ValueError: if ``read_data`` returns no samples.
    """

    def __init__(self, data_dir: str = 'data/example', classes_path: str = 'data/example/classes.json'):
        self.data = read_data(data_dir=data_dir, classes_path=classes_path)
        self._infer_dims()

    def _infer_dims(self):
        """Record the model dimensions from the first sample in ``self.data``."""
        if len(self.data) == 0:
            # Fail with a clear message instead of an opaque IndexError on data[0].
            raise ValueError('read_data returned no samples; cannot infer feature sizes')
        edges, nodes, context_curr, _context_query, _y = self.data[0]
        # NOTE(review): edges is presumably (n_nodes, n_nodes, e_len) (e.g.
        # (45, 45, 3) for the example data) -- confirm against Reader.read_data.
        self.n_nodes = edges.size()[1]
        self.n_len = nodes.size()[1]
        self.e_len = edges.size()[-1]
        self.c_len = context_curr.size()[0]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        return self.data[idx]

d = SingleDataset()
dl = torch.utils.data.DataLoader(d, batch_size=3, num_workers=8)

# Size the model from the dimensions SingleDataset inferred from its first sample.
model_dims = dict(
    num_nodes=d.n_nodes,
    node_feature_len=d.n_len,
    edge_feature_len=d.e_len,
    context_len=d.c_len,
)
model = GraphTranslatorModule(**model_dims)

# Log and flush metrics on every training step.
trainer = Trainer(
    max_epochs=3,
    log_every_n_steps=1,
    flush_logs_every_n_steps=1,
)
trainer.fit(model, dl)