In [None]:
# Turn off Jedi-based tab completion in favour of IPython's own completer.
%config Completer.use_jedi = False
# Automatically re-import edited modules (e.g. the local pase_eeg package) before each
# cell runs, so library edits are picked up without restarting the kernel.
# (Use `%reload_ext autoreload` instead if the extension is already loaded.)
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Prepend the parent directory (presumably the repository root holding `pase_eeg`)
# to the module search path so it takes precedence over installed packages.
sys.path[:0] = ["../"]

In [None]:
import numpy as np

## Model implementation

### Generating synthetic data

In [None]:
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal, norm

In [None]:
# Distribution of class 1: an isotropic 2-D Gaussian over the electrode grid, centred at (4, 2)
# with covariance 2*I.
# from: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.multivariate_normal.html
rv1 = multivariate_normal([4, 2], [[2.0, 0], [0, 2.0]])

# Dense evaluation grid covering the electrode coordinates (x in [-6, 6), y in [-5, 5)).
x, y = np.mgrid[-6:6:.01, -5:5:.01]
pos = np.dstack((x, y))

fig, ax = plt.subplots()
contours = ax.contourf(x, y, rv1.pdf(pos))
fig.colorbar(contours, ax=ax, label="pdf")
ax.set(title="Class 1 generating distribution", xlabel="x (electrode grid)",
       ylabel="y (electrode grid)", aspect="equal");

In [None]:
# Distribution of class 2: the same isotropic Gaussian (covariance 2*I), mirrored to (-4, -2).
rv2 = multivariate_normal([-4, -2], [[2.0, 0], [0, 2.0]])

# Same evaluation grid as for class 1, rebuilt so this cell runs on its own.
x, y = np.mgrid[-6:6:.01, -5:5:.01]
pos = np.dstack((x, y))

fig, ax = plt.subplots()
contours = ax.contourf(x, y, rv2.pdf(pos))
fig.colorbar(contours, ax=ax, label="pdf")
ax.set(title="Class 2 generating distribution", xlabel="x (electrode grid)",
       ylabel="y (electrode grid)", aspect="equal");

In [None]:
# Electrode name -> (x, y) integer position on a simplified 10-10 layout:
# (0, 0) is Cz, x grows towards the right hemisphere, y grows towards the nose (Nz).
# Positions span x in [-6, 6] and y in [-5, 5], i.e. an 11 x 13 grid.
eeg_electrode_positions = {
    # row 1
    "Nz": (0, 5),
    # row 2
    "Fp1": (-1, 4), "Fpz": (0, 4), "Fp2": (1, 4),
    # row 3
    "AF7": (-2, 3), "AF3": (-1, 3), "AFz": (0, 3), "AF4": (1, 3), "AF8": (2, 3),
    # row 4
    "F9": (-5, 2), "F7": (-4, 2), "F5": (-3, 2), "F3": (-2, 2), "F1": (-1, 2), "Fz": (0, 2),
    "F2": (1, 2), "F4": (2, 2), "F6": (3, 2), "F8": (4, 2), "F10": (5, 2),
    # row 5 (the outer right-hemisphere sites mirror FT7/FT9, hence FT8/FT10)
    "FT9": (-5, 1), "FT7": (-4, 1), "FC5": (-3, 1), "FC3": (-2, 1), "FC1": (-1, 1), "FCz": (0, 1),
    "FC2": (1, 1), "FC4": (2, 1), "FC6": (3, 1), "FT8": (4, 1), "FT10": (5, 1),
    # row 6
    "A1": (-6, 0), "T9": (-5, 0), "T7": (-4, 0), "C5": (-3, 0), "C3": (-2, 0), "C1": (-1, 0), "Cz": (0, 0),
    "C2": (1, 0), "C4": (2, 0), "C6": (3, 0), "T8": (4, 0), "T10": (5, 0), "A2": (6, 0),
    # row 7
    "TP9": (-5, -1), "TP7": (-4, -1), "CP5": (-3, -1), "CP3": (-2, -1), "CP1": (-1, -1), "CPz": (0, -1),
    "CP2": (1, -1), "CP4": (2, -1), "CP6": (3, -1), "TP8": (4, -1), "TP10": (5, -1),
    # row 8
    "P9": (-5, -2), "P7": (-4, -2), "P5": (-3, -2), "P3": (-2, -2), "P1": (-1, -2), "Pz": (0, -2),
    "P2": (1, -2), "P4": (2, -2), "P6": (3, -2), "P8": (4, -2), "P10": (5, -2),
    # row 9
    "PO7": (-2, -3), "PO3": (-1, -3), "POz": (0, -3), "PO4": (1, -3), "PO8": (2, -3),
    # row 10
    "O1": (-1, -4), "Oz": (0, -4), "O2": (1, -4),
    # row 11
    "Iz": (0, -5),
}

In [None]:
# Generate 10 one-second (256-sample) recordings for class 1. Each electrode's signal is
# |N(rv1.pdf(position), 0.01)|, so electrodes near rv1's mean (4, 2) carry larger values.
# Seeding here makes the synthetic data reproducible under Restart & Run All; re-running
# this cell also resets eeg_data/labels, so the class-2 cell below can be re-run after it.
np.random.seed(0)
eeg_data = {key: [] for key in eeg_electrode_positions}
labels = []
for _ in range(10):
    for key, position in eeg_electrode_positions.items():
        eeg_data[key].append(np.abs(np.random.normal(rv1.pdf(position), 0.01, 256)))
    labels.append(1)

In [None]:
# Generate another 10 one-second (256-sample) recordings, this time around class 2's
# distribution rv2. They are appended after the class-1 samples (so, with 10 class-1
# samples, indices 10-19 are class 2).
# Guard against re-running this cell on its own, which would append a second copy.
assert 2 not in labels, "class-2 samples already generated; re-run the class-1 cell first"
for _ in range(10):
    for key, position in eeg_electrode_positions.items():
        eeg_data[key].append(np.abs(np.random.normal(rv2.pdf(position), 0.01, 256)))
    labels.append(2)

In [None]:
# Stack each electrode's recordings into a (n_samples, 1, n_timesteps) array; the
# singleton axis matches the single input channel of SincBlock(in_channels=1).
# reshape (instead of expand_dims) keeps the cell idempotent: re-running it leaves
# an already-stacked array unchanged instead of adding yet another axis.
for key in eeg_electrode_positions:
    recordings = np.asarray(eeg_data[key])
    eeg_data[key] = recordings.reshape(recordings.shape[0], 1, recordings.shape[-1])

### Modules

In [None]:
import sys

# Evict the cached pase_eeg package *and* its submodules (e.g. "pase_eeg.nn.modules") so
# the next import re-executes them; evicting only "pase_eeg" would leave the submodules
# cached. Collecting matching names first makes this a no-op on a fresh kernel, where
# nothing has been imported yet, instead of raising KeyError.
for _module_name in [m for m in sys.modules if m == "pase_eeg" or m.startswith("pase_eeg.")]:
    del sys.modules[_module_name]

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from pase_eeg.nn.modules import SincBlock, ResBlock

from torch import Tensor
from typing import Type, Any, Callable, Union, List, Optional

#### Channel-wise frequency-domain feature extractor

In [None]:
class WaveFe(nn.Module):
    """Per-electrode temporal feature extractor laid out on a 2-D electrode grid.

    Every electrode signal goes through the same SincBlock + ResBlock stack (all
    electrodes are stacked along the batch axis, so the weights are shared). Each
    electrode's single-channel output sequence is then written into a zero-initialised
    ``(batch, rows, cols, time)`` tensor at the grid cell of its ``(x, y)`` position;
    grid cells without an electrode stay zero.

    Args:
        channel_positions: mapping electrode name -> ``(x, y)`` integer position, with
            ``(0, 0)`` at the grid centre, x growing to the right (larger column) and y
            growing upwards (smaller row).
        out_shape: ``(rows, cols)`` of the output grid; every position must fall inside it.
        name: currently unused.
    """

    def __init__(
        self,
        channel_positions,
        out_shape,
        name="WaveFe",
    ):
        super().__init__()
        self.channel_positions = channel_positions
        self.out_shape = out_shape

        self.sinc = SincBlock(in_channels=1,
                              out_channels=64,
                              kernel_size=5,
                              stride=1,
                              padding="valid",
                              activation="prelu",
                              norm_type="bnorm",
                              sr=256,
                              )
        self.blocks = nn.ModuleList()
        self.blocks.append(ResBlock(in_channels=64,
                                    out_channels=64,
                                    kernel_size=3,
                                    padding="valid",
                                    dilations=[1, 2],
                                    activation="prelu",
                                    norm_type="bnorm",
                                    ))
        # Last block reduces to one channel; _flatten_by_position_2d relies on that.
        self.blocks.append(ResBlock(in_channels=64,
                                    hidden_channels=64,
                                    out_channels=1,
                                    kernel_size=3,
                                    padding="valid",
                                    dilations=[1, 2],
                                    activation="prelu",
                                    norm_type="bnorm",
                                    ))

    def forward(self, batch_dict, device=None, mode=None):
        """Extract per-electrode features and scatter them onto the 2-D grid.

        Args:
            batch_dict: mapping electrode name -> tensor whose first axis is the batch
                (assumed ``(batch, 1, samples)`` — TODO confirm against SincBlock). Must
                contain every key of ``channel_positions``; extra keys are ignored.
            device, mode: unused; kept for a uniform ``forward`` signature.

        Returns:
            Tensor of shape ``(batch, rows, cols, time)``.
        """
        # Stack in channel_positions order (not batch_dict order): _flatten_by_position_2d
        # assumes slice i of the stacked tensor is the i-th electrode of channel_positions,
        # so a differently ordered dict would otherwise land on the wrong grid cells.
        values = [batch_dict[key] for key in self.channel_positions]
        batch_size = values[0].size(0)
        x = torch.vstack(values)  # electrode-major: rows [i*B:(i+1)*B] are electrode i

        x = self._forward_feature_extractor(x)

        return self._flatten_by_position_2d(x, batch_size)

    def _forward_feature_extractor(self, x):
        """Run the shared SincBlock followed by each ResBlock."""
        x = self.sinc(x)
        for block in self.blocks:
            x = block(x)
        return x

    def _flatten_by_position_2d(self, x, batch_size):
        """Scatter electrode-major features ``(n_electrodes * batch, 1, time)`` onto the grid."""
        res = torch.zeros(
            (batch_size, *self.out_shape, x.size(-1)), device=x.device, dtype=x.dtype
        )
        for i, position in enumerate(self.channel_positions.values()):
            row, col = self._position_to_index(position, self.out_shape)
            # Squeeze only the channel axis; a bare squeeze() would also drop the batch
            # axis (batch_size == 1) or the time axis (length 1) and break the assignment.
            res[:, row, col, :] = x[i * batch_size:(i + 1) * batch_size].squeeze(1)
        return res

    def _position_to_index(self, position, shape):
        """Map an ``(x, y)`` position to a ``(row, col)`` index of a grid of ``shape``.

        Position ``(0, 0)`` maps to the centre cell ``(rows // 2, cols // 2)``.

        Raises:
            ValueError: if the position falls outside the grid (a negative index would
                otherwise silently wrap around to the opposite edge).
        """
        center_row, center_col = shape[0] // 2, shape[1] // 2
        row, col = center_row - position[1], center_col + position[0]
        if not (0 <= row < shape[0] and 0 <= col < shape[1]):
            raise ValueError(
                f"position {tuple(position)} maps to index {(row, col)}, "
                f"outside the grid of shape {tuple(shape)}"
            )
        return row, col

#### Spatial-domain feature extractor with classifier

In [None]:
class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + norm layers (layout of torchvision's ResNet BasicBlock).

    ``conv1 -> norm -> ReLU [-> Dropout2d] -> conv2 -> norm -> (+ shortcut) -> ReLU
    [-> Dropout2d]``. When ``stride != 1`` or ``inplanes != planes`` the input cannot be
    added to the conv output as-is, so the shortcut becomes a 1x1 conv + norm projection
    (unless a custom ``downsample`` module is supplied).

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 conv and of the default projection shortcut.
        groups: ``groups`` of both 3x3 convs.
        base_width: unused; kept for torchvision signature compatibility.
        dilation: dilation (and padding) of both 3x3 convs; keeps the spatial size at stride 1.
        norm_layer: factory called with a channel count; defaults to ``nn.BatchNorm2d``.
        dropout: if not None, ``Dropout2d`` probability applied after each activation.
        downsample: optional module mapping the input to the output shape for the shortcut.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dropout: Optional[float] = None,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()

        self.dropout = dropout

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # conv1 (and the projection shortcut) downsample the input when stride != 1.
        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            groups=groups,
            bias=False,
            dilation=dilation,
        )
        self.bn1 = norm_layer(planes)
        if dropout is not None:
            self.dropout1 = nn.Dropout2d(p=dropout)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=dilation,
            groups=groups,
            bias=False,
            dilation=dilation,
        )
        self.bn2 = norm_layer(planes)
        if dropout is not None:
            self.dropout2 = nn.Dropout2d(p=dropout)

        # The residual sum needs matching shapes: project the input whenever conv1
        # changes the channel count or the spatial size.
        if downsample is None and (stride != 1 or inplanes != planes * self.expansion):
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * self.expansion, kernel_size=1,
                          stride=stride, bias=False),
                norm_layer(planes * self.expansion),
            )
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        if self.dropout is not None:
            out = self.dropout1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        if self.dropout is not None:
            out = self.dropout2(out)

        return out

In [None]:
class BlockCls2d(nn.Module):
    """2-D classification head operating on the electrode grid.

    Pipeline: 1x1 conv to 64 channels -> BN -> ReLU -> ``num_blocks`` residual
    ``BasicBlock``s -> 1x1 conv to ``num_classes`` maps -> ReLU -> global average pool
    -> Linear. Returns unnormalised class scores of shape ``(batch, num_classes)``.

    Args:
        in_channels: channels of the input grid tensor ``(batch, in_channels, rows, cols)``.
        num_classes: number of output classes.
        name: currently unused.
        num_blocks: number of residual blocks (default 2).
        dropout: ``Dropout2d`` probability inside each residual block (default 0.2).
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        name="EEG-Classifier",
        num_blocks: int = 2,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.inplanes = 64

        # 1x1 conv: mixes input channels per grid cell without spatial context.
        self.conv1 = nn.Conv2d(
            in_channels, self.inplanes, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        # Identical shape-preserving residual blocks (stride 1, same width).
        self.blocks = nn.ModuleList(
            BasicBlock(inplanes=self.inplanes,
                       planes=self.inplanes,
                       stride=1,
                       base_width=64,
                       dilation=1,
                       dropout=dropout,
                       )
            for _ in range(num_blocks)
        )

        # Per-cell class maps, averaged over the grid, then re-mixed by a small linear layer.
        self.conv_cls = nn.Conv2d(
            self.inplanes, self.num_classes, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.num_classes, self.num_classes)

    def forward(self, x, device=None, mode=None):
        """Map ``(batch, in_channels, rows, cols)`` to scores ``(batch, num_classes)``.

        ``device`` and ``mode`` are unused.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        for block in self.blocks:
            x = block(x)

        x = self.conv_cls(x)
        # NOTE(review): ReLU on the class maps clamps negative evidence before pooling.
        x = self.relu(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

#### Model

In [None]:
class EEGCls(nn.Module):
    """End-to-end EEG classifier: per-electrode ``WaveFe`` features -> 2-D ``BlockCls2d`` head.

    ``WaveFe`` yields a ``(batch, rows, cols, time)`` grid; the time axis is moved to the
    channel position so the head receives ``(batch, time, rows, cols)``.

    Args:
        channel_positions: electrode name -> ``(x, y)`` grid position, forwarded to WaveFe.
        channels_plane_shape: ``(rows, cols)`` of the electrode grid.
        num_classes: number of output classes.
        name: currently unused.
        signal_features: length of WaveFe's output time axis, i.e. the head's input
            channels. The default 252 presumably corresponds to the 256-sample
            recordings used in this notebook — TODO confirm when the input length changes.
    """

    def __init__(
        self,
        channel_positions,
        channels_plane_shape=(11, 13),
        num_classes=2,
        name="EEG-Classifier",
        signal_features=252,
    ):
        super().__init__()
        self.channel_positions = channel_positions
        self.channels_plane_shape = channels_plane_shape

        self.signal_block = WaveFe(channel_positions=self.channel_positions,
                                   out_shape=self.channels_plane_shape)
        self.spacial_block = BlockCls2d(in_channels=signal_features, num_classes=num_classes)

    def forward(self, batch_dict, device=None, mode=None):
        """Return unnormalised class scores ``(batch, num_classes)``; ``mode`` is unused."""
        x = self.signal_block(batch_dict, device=device)  # (batch, rows, cols, time)
        x = x.permute(0, 3, 1, 2)  # -> (batch, time, rows, cols): time becomes channels
        return self.spacial_block(x)

### Module tests

In [None]:
# The module tests below deliberately run on CPU; use
# `torch.device("cuda" if torch.cuda.is_available() else "cpu")` to test on a GPU instead.
device = torch.device("cpu")
device_str = str(device)
print("Torch Using device:", device)

#### Signal Block

In [None]:
# Samples 10..17: drawn by the class-2 generation cell (the first 10 samples are class 1).
batch = {
    electrode: torch.Tensor(eeg_data[electrode])[10:18, :, :]
    for electrode in eeg_electrode_positions
}

In [None]:
batch["P7"].shape

In [None]:
# `.to()` without arguments is a no-op, so move the model (and its inputs) to `device` explicitly.
model = WaveFe(channel_positions=eeg_electrode_positions,
               out_shape=(11, 13)).to(device)
out = model({key: tensor.to(device) for key, tensor in batch.items()})

In [None]:
out.shape

In [None]:
# Collapse the (batch, rows, cols, time) features to one value per grid cell:
# mean over time (last axis), then over the batch.
img = out.detach().cpu().numpy().mean(-1).mean(0)

fig, ax = plt.subplots()
im = ax.imshow(img)
fig.colorbar(im, ax=ax)
ax.set(title="WaveFe output, class-2 batch (mean over batch and time)",
       xlabel="grid column", ylabel="grid row");

In [None]:
# Samples 0..7: drawn by the class-1 generation cell.
batch = {
    electrode: torch.Tensor(eeg_data[electrode])[0:8, :, :]
    for electrode in eeg_electrode_positions
}

In [None]:
# `.to()` without arguments is a no-op, so move the model (and its inputs) to `device` explicitly.
model = WaveFe(channel_positions=eeg_electrode_positions,
               out_shape=(11, 13)).to(device)
out = model({key: tensor.to(device) for key, tensor in batch.items()})

In [None]:
# Collapse the (batch, rows, cols, time) features to one value per grid cell:
# mean over time (last axis), then over the batch.
img = out.detach().cpu().numpy().mean(-1).mean(0)

fig, ax = plt.subplots()
im = ax.imshow(img)
fig.colorbar(im, ax=ax)
ax.set(title="WaveFe output, class-1 batch (mean over batch and time)",
       xlabel="grid column", ylabel="grid row");

#### Spatial block

In [None]:
# Smoke test of the full model on the current batch; expect raw class scores of shape (8, 2).
model = EEGCls(
    channel_positions=eeg_electrode_positions,
    channels_plane_shape=(11, 13),
    num_classes=2,
)
out = model(batch)

In [None]:
# Raw class scores from EEGCls (output of BlockCls2d's final Linear layer, no softmax), one row per sample.
out

### Train/Validation

#### Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader

In [None]:
class EEGSyntheticDataset(Dataset):
    """Synthetic two-class EEG dataset of per-electrode 256-sample recordings.

    When generating data, each electrode signal of a sample is ``|N(pdf(position), 0.01)|``
    (256 draws), where ``pdf`` is an isotropic 2-D Gaussian (covariance 2*I) centred at
    (4, 2) for class 1 and at (-4, -2) for class 2. Class-1 samples come first.

    Args:
        data: optional ``(eeg_data, labels)`` pair to wrap instead of generating data;
            ``eeg_data`` maps electrode name -> array indexed ``[sample, :, :]`` and
            ``labels`` is an array of class labels (1 or 2).
        transforms: optional callable ``(wav_dict, class_index) -> (wav_dict, class_index)``
            applied in ``__getitem__``.
        channel_positions: electrode name -> ``(x, y)`` mapping used when generating data;
            defaults to the notebook-level ``eeg_electrode_positions``.
        n_per_class: number of generated samples per class.
    """

    def __init__(self, data=None, transforms=None, channel_positions=None, n_per_class=100):
        self.transforms = transforms

        if data is None:
            if channel_positions is None:
                channel_positions = eeg_electrode_positions
            self.eeg_data = {key: [] for key in channel_positions}
            self.labels = []
            # Class distributions over the electrode grid, from:
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.multivariate_normal.html
            for cls, centre in ((1, [4, 2]), (2, [-4, -2])):
                rv = multivariate_normal(centre, [[2.0, 0], [0, 2.0]])
                # The pdf at each electrode is fixed per class, so evaluate it only once.
                pdf_at = {key: rv.pdf(position) for key, position in channel_positions.items()}
                for _ in range(n_per_class):
                    # One 1-second (256-sample) recording per electrode.
                    for key in channel_positions:
                        self.eeg_data[key].append(np.abs(np.random.normal(pdf_at[key], 0.01, 256)))
                    self.labels.append(cls)

            # (n_samples, 1, 256) per electrode.
            for key in channel_positions:
                self.eeg_data[key] = np.expand_dims(np.array(self.eeg_data[key]), 1)
            self.labels = np.array(self.labels)
        else:
            self.eeg_data = data[0]
            self.labels = data[1]

        self.classes = [1, 2]
        self.cls_idx_map = {1: 0, 2: 1}

        self.indices = list(range(len(self.labels)))

    def class_to_index(self, cls):
        """Map a class label (1 or 2) to its 0-based index."""
        return self.cls_idx_map[cls]

    def index_to_class(self, index):
        """Map a 0-based class index back to its class label."""
        return self.classes[index]

    def get_class(self, index):
        """Return the class label of sample ``index``."""
        return self.labels[index]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(electrode -> recording of sample idx, class index)``, after transforms."""
        wav = {key: self.eeg_data[key][idx, :, :] for key in self.eeg_data.keys()}

        label = self.get_class(idx)
        label = self.class_to_index(label)

        if self.transforms is not None:
            wav, label = self.transforms(wav, label)

        return wav, label

    def subset(self, indices):
        """Return a new dataset holding only the samples at ``indices`` (same transforms)."""
        data = {key: recordings[indices, :, :] for key, recordings in self.eeg_data.items()}
        return self.__class__((data, self.labels[indices]), transforms=self.transforms)

    @staticmethod
    def collate_fn(batch):
        """Collate ``(wav_dict, label_tensor)`` pairs into ``[dict of stacked tensors, targets]``.

        Each ``wav_dict[key]`` tensor gains a leading batch axis. Labels are 1-element
        tensors (as produced by ``ToTensor``) and are concatenated into shape ``(batch,)``.
        """
        imgs = {key: torch.stack([item[0][key] for item in batch]) for key in batch[0][0]}
        # cat keeps the targets 1-D even for a single-item batch, where squeeze() would
        # produce a 0-d tensor that nll_loss rejects.
        trgts = torch.cat([item[1].reshape(-1) for item in batch])

        return [imgs, trgts]

In [None]:
class ToTensor:
    """Sample transform: turn per-electrode arrays and the class index into tensors on ``device``.

    Args:
        device: ``torch.device`` the produced tensors are moved to.
    """

    def __init__(self, device):
        self.device = device

    def __call__(self, data, label):
        """Convert ``data`` in place and wrap ``label``.

        Every ``data[key]`` is replaced by a float32 tensor on ``self.device``; ``label``
        becomes a 1-element int64 tensor. Returns ``(data, label_tensor)`` where ``data``
        is the same dict object that was passed in.
        """
        target = self.device
        for electrode, signal in list(data.items()):
            data[electrode] = torch.tensor(signal).float().to(target)
        label_tensor = torch.tensor([label], dtype=torch.long).to(target)
        return data, label_tensor

#### Lightning Module

In [None]:
from pytorch_lightning import LightningModule, Trainer, seed_everything
from sklearn.model_selection import train_test_split
from torchmetrics import Accuracy

In [None]:
BATCH_SIZE = 2


class LitEEG(LightningModule):
    """Lightning wrapper training ``EEGCls`` on ``EEGSyntheticDataset``.

    The dataset is generated once in ``setup`` and split 80/20 (stratified by label)
    into training and validation subsets.

    Args:
        learning_rate: Adam learning rate.
        batch_size: batch size of both dataloaders (defaults to ``BATCH_SIZE``).
    """

    def __init__(self, learning_rate=3e-4, batch_size=BATCH_SIZE):

        super().__init__()
        self.learning_rate = learning_rate
        self.batch_size = batch_size

        self.model = EEGCls(channel_positions=eeg_electrode_positions,
                            channels_plane_shape=(11, 13),
                            num_classes=2)

        # Separate metric objects for training and validation: one shared Accuracy would
        # accumulate predictions from both loops into the same running state.
        self.accuracy = Accuracy()
        self.val_accuracy = Accuracy()

    def forward(self, x):
        """Return log-probabilities ``(batch, num_classes)``, paired with ``nll_loss`` below."""
        x = self.model(x, device=self.device)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.accuracy, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing (so test metrics are logged
        # under the val_* names).
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_epoch_end(self, training_step_outputs):
        # Presumably separates the progress-bar output of consecutive epochs.
        print("\n")

    def training_epoch_start(self, training_step_outputs):
        # NOTE(review): not a Lightning hook name (the hook is `on_train_epoch_start`),
        # so the Trainer never calls this method; kept for backward compatibility.
        print("\n")

    ####################
    # DATA RELATED HOOKS
    ####################
    def setup(self, stage=None):
        # Lightning may call setup once per stage (fit/validate/test); build the random
        # dataset and split only once so every stage sees the same data.
        if hasattr(self, "dataset"):
            return
        # Keep samples on the CPU: the Trainer moves each (dict/list) batch to the model's
        # device itself, so this also works on machines without CUDA.
        self.dataset = EEGSyntheticDataset(transforms=ToTensor(device=torch.device("cpu")))
        self.train_idx, self.test_idx, _, _ = train_test_split(
            list(range(len(self.dataset))),
            self.dataset.labels,
            stratify=self.dataset.labels,
            test_size=0.2,
        )

    def train_dataloader(self):
        # Reshuffle every epoch; the subset order alone would be identical across epochs.
        return DataLoader(self.dataset.subset(self.train_idx), batch_size=self.batch_size,
                          shuffle=True, collate_fn=self.dataset.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.dataset.subset(self.test_idx), batch_size=self.batch_size,
                          collate_fn=self.dataset.collate_fn)

#### Train

In [None]:
# Seed python/numpy/torch RNGs so the synthetic dataset generated in LitEEG.setup, the
# train/validation split and the weight initialisation are reproducible across runs.
seed_everything(42)

# Init our model
model = LitEEG()

# Initialize a trainer; fall back to CPU instead of failing when no GPU is available.
trainer = Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=4,
)

# Train the model ⚡
trainer.fit(model)