In [1]:
# Automatically reload imported modules (e.g. the local handwritingBCI package)
# before executing code, so edits to the package are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import os
import sys

# Make the local handwritingBCI package importable. The checkout location can be
# overridden with the HANDWRITING_BCI_ROOT environment variable; the default is the
# original author's checkout.
PROJECT_ROOT = os.environ.get(
    "HANDWRITING_BCI_ROOT",
    "/home/anukoolpurohit/Documents/Workspace/Nueromatch/NMA-DL/HandwritingBCI",
)
# Guard so re-running this cell does not add duplicate sys.path entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Imports

In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt

In [4]:
import torch.nn.functional as F

In [5]:
from torch import nn
from scipy.io import loadmat
from torch.utils.data import Dataset, DataLoader, random_split

In [6]:
# Project-local imports. DataLoaders are built with NeuroDataset.get_dataloaders;
# the former `get_neuro_dataloaders` import is not exported by
# handwritingBCI.data.utils.dataloader and made this cell raise ImportError.
from handwritingBCI import Path
from handwritingBCI.data.utils.files import get_dataset
from handwritingBCI.data.datasets import NeuroDataset
from handwritingBCI.data.preprocessing import LabelEncoder
from handwritingBCI.plotting import plot_electrode_data

# Setup

## Seed

In [None]:
SEED = 42
# Seed every RNG the notebook draws from: torch's default generator (used e.g. for
# the layers' default weight initialisation) and numpy's global RNG. The data split
# additionally gets its own explicitly seeded torch.Generator.
torch.manual_seed(SEED)
np.random.seed(SEED)  # last expression returns None, so the cell displays nothing

## Data

In [None]:
import os

# Root of the handwriting-BCI data release. Set the HANDWRITING_BCI_DATA environment
# variable to run the notebook on another machine; the default is the original
# author's local copy.
PATH = Path(os.environ.get(
    "HANDWRITING_BCI_DATA",
    "/home/anukoolpurohit/Documents/AnukoolPurohit/Datasets/HandwritingBCI/handwriting-bci/handwritingBCIData",
))

# Dataset and Dataloader

In [None]:
# Build the project's dataset object from the data directory configured in PATH.
neuro_dataset = NeuroDataset.from_path(PATH)

In [None]:
# Loader settings. VALID_FRACTION is forwarded as `test_size`, presumably the
# fraction of samples held out for valid_dl.
VALID_FRACTION = 0.1
BATCH_SIZE = 64

# A dedicated, seeded torch.Generator is passed so the randomness inside
# get_dataloaders (presumably the train/valid split) is reproducible.
train_dl, valid_dl = neuro_dataset.get_dataloaders(test_size=VALID_FRACTION,
                                                   batch_size=BATCH_SIZE,
                                                   generator=torch.Generator().manual_seed(SEED))

In [None]:
# Pull a single batch from the training loader to inspect shapes and plot a sample below.
X, y = next(iter(train_dl))

In [None]:
# Display the shapes of the input batch and its labels.
X.shape, y.shape

In [None]:
# NOTE(review): exploratory check of the shape object's type (expected torch.Size);
# a candidate for removal from the finished notebook.
type(X.shape)

In [None]:
# NOTE(review): this only constructs and displays a torch.Size literal and does not
# touch the data. 61 may be the size of a final partial batch — confirm, or remove the cell.
torch.Size((61, 1, 201, 192))

In [None]:
# Sanity-check the batch layout expected by the model below: (batch=64, 1, 201, 192).
# An assertion fails loudly under Restart & Run All instead of just displaying False.
assert X.shape == (64, 1, 201, 192), f"unexpected batch shape {tuple(X.shape)}"

In [None]:
# Plot the first sample of the batch. squeeze(0) drops the leading size-1 axis of the
# sample (the `1` in the batch shape checked above), so a 2-D array is passed.
plot_electrode_data(X[0].squeeze(0), y[0])

## Sample model

In [None]:
class Downsample(nn.Module):
    """
    Learned 2x downsampling used in place of a max-pool.

    A 2x2 convolution with stride 2 and no padding, followed by a ReLU.
    The channel count is preserved; each spatial size ``s`` becomes ``s // 2``.
    """

    def __init__(self, input_channels: int) -> None:
        super().__init__()
        # Same channel count in and out: only the spatial resolution changes.
        self.conv = nn.Conv2d(in_channels=input_channels,
                              out_channels=input_channels,
                              kernel_size=2,
                              stride=2,
                              padding=0)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the strided convolution, then the ReLU, to ``x`` (B, C, H, W)."""
        return self.relu(self.conv(x))
        

In [None]:
class ConvBlock(nn.Module):
    """
    Residual block: two 3x3 convolutions (stride 1, padding 1, so the spatial
    size is unchanged), each followed by batch norm, with a skip connection
    added before the final ReLU.

    When ``input_channels != output_channels`` the skip path is a 1x1
    convolution + batch norm that projects the input to ``output_channels``,
    so the residual sum is well defined; otherwise it is the identity.
    """

    def __init__(self, input_channels: int, output_channels: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, output_channels,
                               kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(output_channels)

        self.conv2 = nn.Conv2d(output_channels, output_channels,
                               kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU()

        # Without a projection, adding the raw input either raises (e.g. 2 -> 32
        # channels) or silently broadcasts a single input channel over all outputs.
        if input_channels != output_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(output_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x
            of dimensions (B, C_in, H, W), with C_in == ``input_channels``

        Returns
        -------
        torch.Tensor
            of dimensions (B, C_out, H, W), with C_out == ``output_channels``
        """
        residual = self.shortcut(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + residual)

In [None]:
class SimpleCNN(nn.Module):
    """
    Small CNN classifier: a stack of ``block_num`` ConvBlocks, one learned 2x
    Downsample, and a two-layer fully connected head.

    Parameters
    ----------
    input_shape
        (C, H, W) of a single input sample.
    output_channels
        number of channels produced by every ConvBlock.
    fc_dims
        width of the hidden fully connected layer.
    num_classes
        number of output logits.
    block_num
        number of stacked ConvBlocks (must be >= 1). The first maps
        ``C -> output_channels``; the rest keep ``output_channels``.

    Raises
    ------
    AssertionError
        if ``input_shape`` does not have exactly 3 entries.
    ValueError
        if ``block_num < 1``.
    """

    def __init__(self, input_shape: tuple, output_channels: int,
                 fc_dims: int, num_classes: int, block_num: int = 2) -> None:
        super().__init__()
        assert len(input_shape) == 3, "input_shape must be (C, H, W)"
        if block_num < 1:
            raise ValueError(f"block_num must be >= 1, got {block_num}")
        input_channels, height, width = input_shape

        # Blocks are registered as conv_block1, conv_block2, ... so the default
        # (block_num=2) keeps the same submodule / state_dict names as before.
        block_names = []
        for idx in range(1, block_num + 1):
            in_channels = input_channels if idx == 1 else output_channels
            name = f"conv_block{idx}"
            self.add_module(name, ConvBlock(in_channels, output_channels))
            block_names.append(name)
        self._block_names = tuple(block_names)

        self.downsample = Downsample(output_channels)

        # Downsample (kernel 2, stride 2, no padding) maps each spatial size s to s // 2.
        input_fc_dims = output_channels * (height // 2) * (width // 2)

        self.fc1 = nn.Linear(input_fc_dims, fc_dims)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(fc_dims, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (B, C, H, W) to logits of shape (B, num_classes)."""
        for name in self._block_names:
            x = getattr(self, name)(x)
        x = self.downsample(x)
        x = x.flatten(1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)
        

In [None]:
# One-channel 201x192 inputs, 32 conv channels, a 64-unit hidden layer, 31 classes.
net = SimpleCNN(input_shape=(1, 201, 192),
                output_channels=32,
                fc_dims=64,
                num_classes=31)

In [None]:
# Smoke test: run the sample batch through the freshly built model and show the output shape.
# NOTE(review): `net` is still in training mode (the nn.Module default), so this forward
# pass also updates the BatchNorm running statistics. Use `net.eval()` / `torch.no_grad()`
# if that matters.
net(X).shape

In [None]:
# Number of batches in the training loader.
len(train_dl)

In [None]:
# Load the raw data and labels directly from PATH, bypassing NeuroDataset.
# NOTE(review): presumably the same data NeuroDataset.from_path reads; used below only
# to inspect sizes — confirm, or drop if not needed.
data, labels = get_dataset(PATH)

In [None]:
# Shape of the raw data and number of labels.
data.shape, len(labels)

In [None]:
# Number of batches in the validation loader.
len(valid_dl)