# Playground

In [2]:
%reload_ext autoreload
%autoreload 3

import torch
from torch import nn
import torch.nn.functional as F
from torchinfo import summary
import pytorch_lightning as pl
from src.b2bnet import B2BNetModel, RandomDataModule, OtkaDataModule
from pytorch_lightning.loggers import TensorBoardLogger

## LSTM Autoencoder

In [3]:
# data
import xarray as xr
from pathlib import Path
from sklearn.model_selection import train_test_split


class OtkaTimeDimSplit(pl.LightningDataModule):
    """Data module that loads ``otka.nc5`` and splits it along the time axis.

    Every recording is cut into non-overlapping segments of ``segment_size``
    time steps. For each subject, the first ``train_ratio`` fraction of the
    segments goes to the training set and the remaining segments go to the
    validation set, so both sets contain segments from every subject.

    Each dataset item is a tuple ``(X, subject_id, y_b2b, y_class)``.

    Args:
        data_dir: directory that contains ``otka.nc5`` (``str`` or ``Path``).
        train_ratio: fraction of each subject's segments used for training.
        segment_size: number of time steps per segment.
        batch_size: batch size used by both data loaders.
    """

    def __init__(self,
                 data_dir: Path = Path('data/'),
                 train_ratio: float = 0.7,
                 segment_size: int = 128,
                 batch_size: int = 32):
        super().__init__()
        # accept plain strings as well as Path objects
        self.data_dir = Path(data_dir)
        self.train_ratio = train_ratio
        self.segment_size = segment_size
        self.batch_size = batch_size
        # populated by prepare_data() / setup()
        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        """Load, segment and split the data.

        Kept so that code calling ``prepare_data()`` directly still gets
        ``train_dataset`` / ``val_dataset``. Lightning only runs this hook on
        a single process, which is why ``setup`` builds the datasets as well.
        """
        self._build_datasets()

    def setup(self, stage=None):
        """Build the datasets on any process where they do not exist yet."""
        if self.train_dataset is None or self.val_dataset is None:
            self._build_datasets()

    def _build_datasets(self):
        """Read ``otka.nc5`` and create ``train_dataset`` and ``val_dataset``."""
        # the context manager closes the file even if reading a variable fails
        with xr.open_dataset(self.data_dir / 'otka.nc5') as ds:
            # assumes both arrays are stored as (..., channels, time) so that
            # time ends up on axis 1 after the permute -- TODO confirm
            X_input = torch.from_numpy(ds['hypnotee'].values).float().permute(0, 2, 1)
            # the hypnotist signal is repeated once per row of X_input
            y_b2b = (torch.from_numpy(ds['hypnotist'].values).float()
                     .repeat(X_input.shape[0], 1, 1)
                     .permute(0, 2, 1))
            y_class = torch.from_numpy(ds['y_class'].values)

        # L2-normalize along the last axis (TODO: move to preprocessing)
        X_input = F.normalize(X_input, dim=2)
        y_b2b = F.normalize(y_b2b, dim=2)

        # truncate y_b2b along axis 1 so it is not longer than X_input
        y_b2b = y_b2b[:, :X_input.shape[1], :]

        # cut axis 1 into non-overlapping windows; unfold appends the window
        # axis last, so swap the last two axes to get (..., segment_size, features)
        X_input = X_input.unfold(1, self.segment_size, self.segment_size).permute(0, 1, 3, 2)
        y_b2b = y_b2b.unfold(1, self.segment_size, self.segment_size).permute(0, 1, 3, 2)
        n_segments = X_input.shape[1]

        # one label and one subject id per segment
        y_class = y_class.reshape(-1, 1, 1).repeat(1, n_segments, 1)
        subject_ids = torch.arange(0, X_input.shape[0]).reshape(-1, 1, 1).repeat(1, n_segments, 1)

        # segments before the cut point are used for training, the rest for validation
        cut_point = int(n_segments * self.train_ratio)

        def _subset(segments: slice):
            # merge the subject and segment axes into a single sample axis
            return torch.utils.data.TensorDataset(
                X_input[:, segments].flatten(0, 1),
                subject_ids[:, segments].flatten(0, 1),
                y_b2b[:, segments].flatten(0, 1),
                y_class[:, segments].flatten(0, 1).squeeze(dim=1))

        self.train_dataset = _subset(slice(None, cut_point))
        self.val_dataset = _subset(slice(cut_point, None))

    def train_dataloader(self):
        """Return the training loader.

        Samples are stored subject by subject and every subject has a single
        label, so the loader shuffles; otherwise most batches would contain
        one subject and one class only.
        """
        return torch.utils.data.DataLoader(self.train_dataset,
                                           batch_size=self.batch_size,
                                           shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size)



In [4]:
# Model

class Autoencoder(pl.LightningModule):
    """LSTM encoder followed by a linear two-class classification head.

    Despite the name, only the encoder and the classifier are active. The
    reconstruction decoder, the b2b head and the subject embedding are work
    in progress and remain commented out below.

    Args:
        n_features: size of the last input dimension (LSTM input size).
        hidden_size: LSTM hidden size.
        n_subjects: currently unused; intended for the disabled subject embedding.
        subject_embedding_dim: currently unused; intended for the disabled
            subject embedding.
        lr: learning rate of the Adam optimizer.
    """

    def __init__(self, n_features, hidden_size, n_subjects, subject_embedding_dim=4, lr=1e-2):
        super().__init__()

        self.hidden_size = hidden_size
        self.lr = lr

        # subject embedding (disabled)
        # self.subject_embedding_dim = subject_embedding_dim
        # self.subject_embedding = nn.Embedding(n_subjects, subject_embedding_dim)

        self.encoder = nn.LSTM(n_features, hidden_size, batch_first=True)
        self.relu = nn.ReLU()

        # reconstruction decoder (disabled)
        # self.decoder = nn.LSTM(hidden_size, n_features, batch_first=True)
        # self.fc_decoder = nn.Linear(n_features, n_features)

        # b2b head (disabled)
        # self.b2b_decoder = nn.LSTM(hidden_size, n_features, batch_first=True)
        # self.fc_b2b_decoder = nn.Linear(n_features, n_features)

        # classifier head
        self.cls = nn.Linear(hidden_size, 2)

    def forward(self, x, subject_ids=None):
        """Encode ``x`` and return class logits.

        Args:
            x: input of shape ``(batch, time, n_features)`` (``batch_first=True``).
            subject_ids: accepted for interface compatibility; currently ignored.

        Returns:
            Tensor of shape ``(batch, 2)`` with unnormalized class scores.
        """
        # append subject embedding (disabled)
        # if subject_ids is not None:
        #     subject_features = self.subject_embedding(subject_ids)
        #     subject_features = subject_features.repeat(1, x.shape[1], 1)
        #     x_emb = torch.cat([x, subject_features], dim=2)

        # only the final hidden state is needed by the classifier
        _, (h_enc, _) = self.encoder(x)
        h_enc = self.relu(h_enc)

        # reconstruction decoder (disabled)
        # x_enc = torch.rand(x.size(0), x.size(1), self.hidden_size)
        # y_dec, (h_dec, c_dec) = self.decoder(x_enc, (h_enc, self.relu(c_enc)))
        # y_dec = self.relu(y_dec)
        # y_dec = self.fc_decoder(y_dec)

        # b2b head (disabled)
        # x_enc_b2b = torch.rand(x.size(0), x.size(1), self.hidden_size)
        # y_b2b, (h_b2b, c_b2b) = self.b2b_decoder(x_enc_b2b, (h_enc, self.relu(c_enc)))
        # y_b2b = self.fc_b2b_decoder(y_b2b)

        # classifier head on the hidden state of the last LSTM layer
        y_cls = self.cls(h_enc[-1, :, :])

        return y_cls  # , y_dec , y_b2b

    def _classification_step(self, batch, stage):
        """Compute and log loss and accuracy for one batch.

        ``stage`` is the metric prefix (``'train'`` or ``'val'``).
        """
        X, subject_ids, y_b2b, y_cls = batch
        y_cls_hat = self(X, subject_ids)

        # loss
        # loss_reconn = nn.functional.mse_loss(X_recon, X)
        # loss_b2b = nn.functional.mse_loss(y_b2b_hat, y_b2b)
        loss_cls = nn.functional.cross_entropy(y_cls_hat, y_cls)
        loss = loss_cls  # + loss_b2b + loss_reconn

        # logging
        # self.log(f'{stage}/loss_reconn', loss_reconn)
        # self.log(f'{stage}/loss_b2b', loss_b2b)
        self.log(f'{stage}/loss_cls', loss_cls)
        self.log(f'{stage}/accuracy', (y_cls_hat.argmax(dim=1) == y_cls).float().mean())
        self.log(f'{stage}/loss', loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` and log ``train/*`` metrics."""
        return self._classification_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` and log ``val/*`` metrics."""
        return self._classification_step(batch, 'val')

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [5]:
# Experiment

# seed every RNG so weight initialization and other random draws are
# repeatable after a kernel restart
pl.seed_everything(42)

segment_size = 120 * 3  # 3sec (presumably 120 samples per second -- confirm)
batch_size = 256
n_features = 59   # must equal the size of the last input dimension
hidden_size = 59
n_subjects = 51   # not used by the model while the subject embedding is disabled
max_epochs = 100

datamodule = OtkaTimeDimSplit(segment_size=segment_size, batch_size=batch_size)

model = Autoencoder(n_features=n_features, hidden_size=hidden_size, n_subjects=n_subjects)

trainer = pl.Trainer(max_epochs=max_epochs, accelerator='cpu', log_every_n_steps=1)

trainer.fit(model, datamodule=datamodule)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name    | Type   | Params
-----------------------------------
0 | encoder | LSTM   | 28.3 K
1 | relu    | ReLU   | 0     
2 | cls     | Linear | 120   
-----------------------------------
28.4 K    Trainable params
0         Non-trainable params
28.4 K    Total params
0.114     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


: 

## GNN
TODO: Implement a pipeline that takes multiple time series, transforms them into graphs, and applies a GNN to them. The resulting graph embeddings will then be used for downstream tasks.