# Demo and test notebook for `GraphTranslatorModule`

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from GraphTranslatorModule import GraphTranslatorModule
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import torch
from torch.nn import functional as F
from pl_bolts.datasets import DummyDataset
import random
import numpy as np

## A. A primer on LightningModule

`GraphTranslatorModule` is a LightningModule rather than a plain `torch.nn.Module`. A LightningModule is equivalent to a pure PyTorch module with added functionality, so you can use it exactly as you would a PyTorch module.

net = GraphTranslatorModule()
x = torch.randn(1, 1, 3, 3)
out = net(x)

trainer = Trainer(max_epochs=3)
trainer.fit(net, torch.utils.data.DataLoader(DummyDataset((1,3,3),(1,), num_samples=100), num_workers=8))

trainer.test(net, torch.utils.data.DataLoader(DummyDataset((1,3,3),(1,), num_samples=10), num_workers=8))

## B. Dimension checks for the model functions

### Test the model with dummy data

In [None]:
class DummyClassificationDataset(torch.utils.data.Dataset):
    """Map-style dataset of random graph samples, used only for dimension checks.

    All tensors are drawn from a standard normal distribution, so the data
    carries no signal; it only exercises the tensor shapes the model consumes.

    Args:
        n_nodes: number of graph nodes.
        n_len: length of each node feature vector.
        e_len: length of each edge feature vector.
        c_len: length of the context vector.
        num_samples: number of items reported by ``len(dataset)``.
        seed: optional base seed. When given, item ``idx`` is drawn from a
            private ``torch.Generator`` seeded with ``seed + idx``, so the same
            index always yields the same tensors regardless of global or
            DataLoader-worker RNG state. When ``None`` (the default) the global
            torch RNG is used, exactly as before.

    Each item is the 5-tuple ``(edges, nodes, context, context, y)`` with shapes
    ``(n_nodes, n_nodes, e_len)``, ``(n_nodes, n_len)``, ``(c_len,)``,
    ``(c_len,)`` and ``(n_nodes, n_nodes, e_len)``.
    """

    def __init__(self, n_nodes, n_len, e_len, c_len, num_samples: int = 100, seed=None):
        self.n_nodes = n_nodes
        self.n_len = n_len
        self.e_len = e_len
        self.c_len = c_len
        self.num_samples = num_samples
        self.seed = seed

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx: int):
        """Return one random sample.

        Negative indices count from the end, as for a sequence.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Bounds check: without it, plain iteration (`for s in dataset`), which
        # relies on IndexError to stop, would never terminate.
        if idx < 0:
            idx += self.num_samples
        if not 0 <= idx < self.num_samples:
            raise IndexError(f"index out of range for dataset of length {self.num_samples}")

        # generator=None falls back to the global torch RNG (original behaviour).
        generator = None
        if self.seed is not None:
            generator = torch.Generator().manual_seed(self.seed + idx)

        edges = torch.randn(self.n_nodes, self.n_nodes, self.e_len, generator=generator)
        nodes = torch.randn(self.n_nodes, self.n_len, generator=generator)
        context = torch.randn(self.c_len, generator=generator)
        y = torch.randn(self.n_nodes, self.n_nodes, self.e_len, generator=generator)
        # NOTE(review): `context` is returned twice; presumably the model's batch
        # layout expects a fifth field in this slot. Confirm against
        # GraphTranslatorModule's training/test step.
        return (edges, nodes, context, context, y)

# Small random-graph configuration, used only to check that tensor dimensions line up.
n_nodes = 5
n_len = 2
e_len = 3
c_len = 1
model = GraphTranslatorModule(num_nodes=n_nodes, node_feature_len=n_len, edge_feature_len=e_len, context_len=c_len)
dataset = DummyClassificationDataset(n_nodes, n_len, e_len, c_len)
dl = torch.utils.data.DataLoader(dataset, num_workers=8, batch_size=3)

# Log every step so this short run produces visible metrics.
# `flush_logs_every_n_steps` is intentionally not passed: it was deprecated as a Trainer
# argument in PyTorch Lightning 1.5 and removed in 1.7, where the Trainer rejects it.
trainer = Trainer(max_epochs=3, log_every_n_steps=1)
trainer.fit(model, dl)

### Test the model with example graph and timestamp data

Logging is done with Weights & Biases; runs can be viewed in the browser at [this link](https://wandb.ai/maithili/GraphTrans).

In [None]:
from Reader import RoutinesDataset
# NOTE(review): the wildcard import provides MeanLoss, EdgeTypeLoss, ChangedEdgeLoss and
# ChangedEdgeWeightedLoss used below. Explicit imports would be clearer; `*` is kept in case
# other cells rely on further names from `analyzers`.
from analyzers import *

# Seed the Python, NumPy and torch RNGs up front so weight initialisation and any RNG-driven
# data handling are repeatable on Restart & Run All (to the extent the reader/model use them).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

data = RoutinesDataset(data_path='data/routines_1029/sample.json', classes_path='data/routines_1029/classes.json')

# Analyzers handed to the module as `logging_analyzers`.
logging_analyzers = [MeanLoss(), EdgeTypeLoss(data.get_edge_classes()), ChangedEdgeLoss(), ChangedEdgeWeightedLoss()]

# Graph dimensions come from the dataset; ChangedEdgeWeightedLoss is passed as `train_analyzer`.
model = GraphTranslatorModule(num_nodes=data.n_nodes, 
                              node_feature_len=data.n_len, 
                              edge_feature_len=data.e_len, 
                              context_len=data.c_len, 
                              train_analyzer=ChangedEdgeWeightedLoss(), 
                              logging_analyzers=logging_analyzers)


# Name the W&B project explicitly so runs land in the GraphTrans project linked in the
# markdown above, instead of whatever a bare WandbLogger() falls back to.
wandb_logger = WandbLogger(project='GraphTrans')
trainer = Trainer(max_epochs=50, logger=wandb_logger)
trainer.fit(model, data.get_train_loader())
trainer.test(model, data.get_test_loader())