# SpaceTime Playground

In [2]:
%reload_ext autoreload
%autoreload 3

import torch
import torch.nn.functional as F
from torchinfo import summary
import pytorch_lightning as pl
from src.b2bnet import B2BNetSpaceTimeModel, OtkaDataModule, OtkaTimeDimSplit
from pytorch_lightning.loggers import TensorBoardLogger

In [None]:
# Experiment

# Seed the Python, NumPy and torch RNGs (and dataloader workers) so that
# weight initialisation and data shuffling are reproducible across runs.
# `deterministic='warn'` on the Trainer only constrains which kernels are
# used; without a fixed seed every run still starts from different weights.
pl.seed_everything(42, workers=True)

sampling_rate = 120  # assumed EEG sampling rate in Hz — TODO confirm against the dataset
segment_size = sampling_rate * 3  # 3-second windows (360 samples)
batch_size = 256
n_channels = 59
n_features = 8
hidden_size = 8
kernel_size = 1
n_subjects = 51
max_epochs = 100

# NOTE(review): the split strategy lives in OtkaTimeDimSplit (src.b2bnet);
# its exact train/val partitioning is not visible here.
datamodule = OtkaTimeDimSplit(segment_size=segment_size, batch_size=batch_size)

model = B2BNetSpaceTimeModel(
    n_channels=n_channels,
    n_features=n_features,
    hidden_size=hidden_size,
    kernel_size=kernel_size,
    n_subjects=n_subjects,
)

# deterministic='warn' asks for deterministic algorithms but only warns
# (instead of raising) for ops that have no deterministic implementation.
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator='cpu',
    log_every_n_steps=1,
    deterministic='warn',
)

trainer.fit(model, datamodule=datamodule)

## Spatiotemporal test
This code first encodes the spatial characteristics of the EEG data (positional information, i.e. how the signal is distributed across electrode channels) and then encodes its temporal characteristics (time information, i.e. how those features evolve over time), before decoding back to the input space. The output is a 3D tensor of shape (batch size, number of channels, number of time points) — here `torch.Size([1, 59, 120])`, the same shape as the input. This code is used to generate the data for the spatiotemporal decoding.

In [2]:
import torch
from torch import nn

In [3]:
# Generate random EEG-like data. torch.randn gives (batch, time, channels);
# nn.Conv1d expects (batch, channels, length), so move channels to dim 1.
x = torch.randn(1, 120, 59).permute(0, 2, 1)  # torch.Size([1, 59, 120])

# --- model ------------------------------------------------------------------
# Encoder, spatial stage: kernel_size=1 convolutions combine the 59 channels
# independently at every time step (59 -> 30 -> 15 features); no mixing
# across time happens here.
spatial_encoder = nn.Sequential(
    nn.Conv1d(59, 30, 1, stride=1),
    nn.ReLU(),
    nn.Conv1d(30, 15, 1, stride=1),
    nn.ReLU(),
)
# Encoder, temporal stage: LSTM over the time axis. batch_first=True means it
# consumes (batch, time, features), hence the permute in the forward pass.
temporal_encoder = nn.LSTM(15, 10, batch_first=True)  # num_layers can be determined by hyperparameter search

# Decoder mirrors the encoder: an LSTM back to 15 features, then kernel_size=1
# transposed convolutions back to 59 channels (no final activation).
temporal_decoder = nn.LSTM(10, 15, batch_first=True)
spatial_decoder = nn.Sequential(
    nn.ConvTranspose1d(15, 30, 1, stride=1),
    nn.ReLU(),
    nn.ConvTranspose1d(30, 59, 1, stride=1),
)

# --- forward pass -----------------------------------------------------------
# nn.LSTM returns (output, (h_n, c_n)); only the per-step output sequence is
# passed on, the final hidden/cell states are discarded.
spatial_code = spatial_encoder(x)                            # (1, 15, 120)
latent, _ = temporal_encoder(spatial_code.permute(0, 2, 1))  # (1, 120, 10)
decoded, _ = temporal_decoder(latent)                        # (1, 120, 15)
output = spatial_decoder(decoded.permute(0, 2, 1))           # (1, 59, 120)

output.shape

torch.Size([1, 59, 120])