# Playground

## GNN
This section takes multiple time series, transforms them into graphs, and applies a GNN to them.
The resulting graph embeddings are then used for downstream tasks. The cells below load the OTKA dataset and train `B2BNetModel` on it via `OtkaDataModule`.

In [19]:
%reload_ext autoreload
%autoreload 3

import torch
from torch import nn
import torch.nn.functional as F
from torchinfo import summary
import pytorch_lightning as pl
from src.b2bnet import B2BNetModel, RandomDataModule, OtkaDataModule
from pytorch_lightning.loggers import TensorBoardLogger

In [18]:
import xarray as xr

# Load the OTKA recording. The `with` block closes the underlying file as soon as the
# variable has been read, instead of leaving the handle open for the kernel's lifetime.
with xr.open_dataset('data/otka.nc5') as ds:
    # `.values` reads the array into memory, so X_input stays valid after the file closes.
    # `.float()` casts to float32; `permute(0, 2, 1)` swaps the last two axes
    # (presumably yielding (subject, time, channel) — TODO confirm against the file's dims).
    X_input = torch.from_numpy(ds['hypnotee'].values).float().permute(0, 2, 1)

X_input.shape

(torch.Size([51, 39707, 59]), torch.Size([51, 39707, 59]))

In [None]:
# Training configuration.
segment_size = 120 * 1  # samples per segment (presumably sampling rate × seconds) — TODO confirm
batch_size = 32

# Fix all RNGs (python, numpy, torch, dataloader workers) so training runs are reproducible.
pl.seed_everything(42, workers=True)

datamodule = OtkaDataModule(segment_size=segment_size, batch_size=batch_size)
model = B2BNetModel(
    input_size=59,       # presumably the channel axis of the OTKA data — TODO confirm
    n_timesteps=segment_size,
    n_cls_labels=2,
    hidden_size=32,
    n_subjects=51,       # presumably the subject axis of the OTKA data — TODO confirm
)

# `ckpt_path='last'` resumes only if a `last*.ckpt` exists where the checkpoint callback
# looks. Lightning's default ModelCheckpoint writes no last checkpoint (save_last is off)
# and stores checkpoints under a fresh logger version each run, so resuming never found
# anything. A fixed directory plus save_last=True makes re-running this cell resume training.
checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath='checkpoints/b2bnet', save_last=True)

trainer = pl.Trainer(
    max_epochs=100,
    accelerator='auto',
    log_every_n_steps=1,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback],
)
trainer.fit(model, datamodule=datamodule, ckpt_path='last')

## TCN
Shape check: run the `TCN` module on a random input batch and inspect its output shape.

In [None]:
%reload_ext autoreload
%autoreload 3

import torch
from src.b2bnet import TCN
from torchinfo import summary

In [None]:
# Seed before building the input and the model so the check is reproducible.
torch.manual_seed(0)

batch_size = 32
n_timesteps = 128 * 320
n_features = 59
# Drawn as (batch, time, features), then permuted to (batch, features, time) before
# being fed to the model.
X = torch.randn(batch_size, n_timesteps, n_features).permute(0, 2, 1)
model = TCN(n_timesteps, output_length=128, n_features=n_features, kernel_size=8, dilation_base=2)

# Only the output shape is needed. Skip autograd bookkeeping: with ~77M input values,
# keeping activations for a backward pass that never happens wastes a lot of memory.
with torch.no_grad():
    tcn_out_shape = model(X).shape

tcn_out_shape