In [6]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix, cohen_kappa_score, explained_variance_score, log_loss

# Set random seeds for reproducibility.
# NumPy's global RNG drives the data augmentation in SleepDataset; torch.manual_seed
# seeds torch's RNG (used for weight initialisation and dropout) on all devices.
# The trailing ';' stops the notebook from echoing the Generator object that
# torch.manual_seed returns when it is the cell's last expression.
seed = 154727
np.random.seed(seed=seed)
torch.manual_seed(seed);


<torch._C.Generator at 0x78f44ffe8f10>

#### Device Selection

In [None]:
# Pick the compute device: the (default) CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

#### Input Directory

In [None]:
# Directory holding the preprocessed EEG/EMG data.
# The default is the original host's path; set the EEG_INPUT_DIR environment
# variable to run the notebook elsewhere without editing this cell.
# os.path.join(..., '') guarantees a trailing separator (the default already has one),
# in case downstream code builds file paths by string concatenation.
input_directory = os.path.join(
    os.environ.get('EEG_INPUT_DIR', '/home/projects/eeg_deep_learning/eeg_data_preprocessed/'),
    '',
)


#### Model Training Parameters

In [None]:
# Number of folds for k-fold cross-validation.
# NOTE(review): the splitter imported at the top is StratifiedShuffleSplit, which draws
# independent random train/test splits rather than disjoint k folds -- confirm this is intended.
num_folds = 5

# Model training parameters
num_epochs = 50
learning_rate = 1e-06
batch_size_per_gpu = 128
# Scale the batch by the number of visible GPUs, but never let it drop to 0 on a
# CPU-only machine (torch.cuda.device_count() == 0): DataLoader rejects batch_size=0.
batch_size = batch_size_per_gpu * max(1, torch.cuda.device_count())


#### Define ResNetBlock class (2D residual block)

In [10]:
class ResNetBlock(nn.Module):
    """Residual block of two (3, 1) convolutions with batch norm, ReLU and dropout.

    Data flow:
        conv1 -> bn1 -> ReLU -> dropout1 -> conv2 -> bn2 -> (+ shortcut) -> ReLU -> dropout2

    Both convolutions use kernel (3, 1) with padding (1, 0) and stride 1, so the
    spatial size of the input is preserved and only the channel count changes.
    The shortcut is a 1x1 convolution when the channel count changes and an
    identity otherwise.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the block.
        dropout_rate: drop probability shared by both dropout layers (0.0 disables them).
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.0):
        super().__init__()
        # Submodules are created in a fixed order: under a fixed seed this keeps the
        # random weight initialisation and the state_dict key order stable.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 1), padding=(1, 0))
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 1), padding=(1, 0))
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout2 = nn.Dropout(dropout_rate)

        # Projection is only needed when the residual and main path widths differ.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        skip = self.shortcut(x)

        # Main path; dropout follows the first activation.
        h = self.dropout1(self.relu(self.bn1(self.conv1(x))))
        h = self.bn2(self.conv2(h))

        # Merge with the residual, then activate and apply the second dropout.
        return self.dropout2(self.relu(h + skip))


#### Define SlumberNet model and SleepDataset classes

In [11]:
class SlumberNet(nn.Module):
    """Stack of ResNetBlocks followed by global average pooling and a linear classifier.

    Block i outputs ``n_feature_maps * 2**i`` channels. The first block consumes
    ``input_channels`` and every later block consumes its predecessor's output, so
    the width doubles at each block.

    Args:
        input_channels: channel count of the input tensor.
        num_classes: number of output classes.
        n_feature_maps: width of the first block.
        n_blocks: number of residual blocks.
        dropout_rate: dropout probability passed to every block.

    The input must be a 4-D tensor for the Conv2d layers, presumably
    (batch, input_channels, time, signals) -- TODO confirm against the training loop.
    forward returns softmax probabilities over ``num_classes`` (softmax on dim 1).
    NOTE(review): the output is probabilities, not logits. If training uses
    nn.CrossEntropyLoss (which applies log-softmax itself), softmax is effectively
    applied twice -- confirm against the loss used.
    """

    def __init__(self, input_channels=1, num_classes=3, n_feature_maps=8, n_blocks=7, dropout_rate=0.0):
        super().__init__()
        # Output width of each block: n, 2n, 4n, ...
        block_widths = [n_feature_maps * (2 ** i) for i in range(n_blocks)]
        # Each block reads the previous block's width; the first reads the raw input.
        in_widths = [input_channels] + block_widths[:-1]
        self.layers = nn.Sequential(*[
            ResNetBlock(c_in, c_out, dropout_rate)
            for c_in, c_out in zip(in_widths, block_widths)
        ])
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Width of the last block feeds the classifier.
        self.fc = nn.Linear(n_feature_maps * (2 ** (n_blocks - 1)), num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        features = self.layers(x)
        # Pool each channel to a single value, then drop the singleton spatial dims.
        pooled = torch.flatten(self.global_avg_pool(features), 1)
        return self.softmax(self.fc(pooled))

# Custom Dataset with Augmentation
class SleepDataset(Dataset):
    """Map-style dataset yielding one stacked (EEG, EMG) epoch and its label.

    Args:
        eeg_data: indexable collection of EEG epochs; ``eeg_data[idx]`` is presumably a
            1-D array of samples (TODO confirm).
        emg_data: matching collection of EMG epochs.
        labels: class label per epoch; returned as a ``torch.long`` tensor.
        augment: if True, each item gets a random gain per signal and a random
            circular time shift shared by EEG and EMG.
        eeg_scale_range: (low, high) bounds of the uniform EEG gain.
        emg_scale_range: (low, high) bounds of the uniform EMG gain.

    Each item is ``(sample, label)``. ``sample`` is a float32 tensor obtained by
    stacking EEG and EMG along a new last axis.
    """

    def __init__(self, eeg_data, emg_data, labels, augment=False,
                 eeg_scale_range=(0.7, 1.3), emg_scale_range=(0.95, 1.05)):
        self.eeg_data = eeg_data
        self.emg_data = emg_data
        self.labels = labels
        self.augment = augment
        self.eeg_scale_range = eeg_scale_range
        self.emg_scale_range = emg_scale_range

    def augment_data(self, eeg, emg):
        """Return randomly rescaled, circularly shifted copies of ``eeg`` and ``emg``.

        A single gain factor is drawn per signal, so the whole epoch is scaled
        uniformly. Drawing one factor per sample point would inject multiplicative
        noise instead of scaling the amplitude. The same shift is applied to both
        signals so EEG and EMG stay time-aligned. Inputs are not modified.
        """
        eeg = np.asarray(eeg)
        emg = np.asarray(emg)
        eeg_gain = np.random.uniform(*self.eeg_scale_range)
        emg_gain = np.random.uniform(*self.emg_scale_range)
        # Shift drawn from [-n, n); np.roll wraps around, so every circular offset is reachable.
        # An empty epoch has nothing to shift (np.random.randint(0, 0) would raise).
        n_samples = eeg.shape[0]
        shift = np.random.randint(-n_samples, n_samples) if n_samples else 0
        eeg = np.roll(eeg * eeg_gain, shift)
        emg = np.roll(emg * emg_gain, shift)
        return eeg, emg

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        eeg = self.eeg_data[idx]
        emg = self.emg_data[idx]
        label = self.labels[idx]

        if self.augment:
            eeg, emg = self.augment_data(eeg, emg)

        # Stack the EEG and EMG channels along a new last axis.
        sample = np.stack([eeg, emg], axis=-1)
        return torch.tensor(sample, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
