In [None]:
from config import get_dl_config
import xarray as xr
import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import StratifiedGroupKFold

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Project-level deep-learning configuration; only the input path is read here.
conf = get_dl_config()
INPUT_FILE = conf['input_file']  # path of the dataset opened in the next section

### Load data

Wrap the EEG examples in a PyTorch `Dataset`, then open the input file as an xarray `DataArray`.

In [None]:
class TUHDataset(Dataset):
    """Map-style dataset over in-memory examples and their numeric labels.

    Args:
        data: array-like of examples, one per row along the first axis; copied
            into a float32 tensor.
        labels: array-like with one numeric label per example; copied into a
            float32 tensor.

    Raises:
        ValueError: if ``data`` and ``labels`` do not have the same length.
    """

    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)
        # Fail here rather than with an IndexError deep inside a DataLoader.
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(self.data)} and {len(self.labels)}"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(example, label)`` with a size-1 axis prepended to the example."""
        # The extra leading axis becomes the input-channel axis once the
        # DataLoader stacks items into (batch, 1, ...), as a Conv2d with a
        # single input channel expects.
        return self.data[idx].unsqueeze(0), self.labels[idx]

In [None]:
# Open the input file as a single xarray DataArray. Later cells index it by
# 'label' and 'patient_id', so fail fast here if either name is absent.
data = xr.open_dataarray(INPUT_FILE)
missing_names = {'label', 'patient_id'} - (set(data.coords) | set(data.dims))
assert not missing_names, f"{INPUT_FILE} is missing required coordinates: {sorted(missing_names)}"
data

In [None]:
# Encode raw label values as integer class indices. np.unique returns the
# distinct values sorted; the inverse indices give each example's position in
# that sorted array, which is exactly the mapping stored in labels_map.
unique_labels, labels = np.unique(data['label'].values, return_inverse=True)
labels_map = {label: i for i, label in enumerate(unique_labels)}

# The CNN below has one sigmoid output trained with BCELoss, so the encoded
# targets must be exactly {0, 1}.
assert len(unique_labels) == 2, (
    f"expected 2 label classes for binary classification, "
    f"found {len(unique_labels)}: {unique_labels}"
)
labels

In [None]:
# Patient IDs become the grouping key for the split; from here on `data` holds
# the raw array values instead of the DataArray (right-hand side is evaluated
# before either name is rebound).
groups, data = data['patient_id'].values, data.values

In [None]:
# Build a patient-grouped, label-stratified train/test split with
# StratifiedGroupKFold, so no patient ends up on both sides.
n_splits = 5
split_seed = 42  # fixes which fold is held out, so reruns give the same split
sgkf = StratifiedGroupKFold(n_splits=n_splits)
splits = list(sgkf.split(data, labels, groups))

# Hold out one fold -> roughly 80% train / 20% test (fold sizes are not exactly equal)
rng = np.random.default_rng(split_seed)
train_idx, test_idx = splits[rng.integers(n_splits)]

# Sanity check: grouping must keep every patient on a single side of the split
assert set(groups[train_idx]).isdisjoint(groups[test_idx]), "patient leakage between train and test"

train_data, train_labels = data[train_idx], labels[train_idx]
test_data, test_labels = data[test_idx], labels[test_idx]

In [None]:
batch_size = 32

# Shuffle only the training data; the test order stays fixed for evaluation.
train_dataset = TUHDataset(train_data, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TUHDataset(test_data, test_labels)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### Definition and training of CNN

In [None]:
class EEGNet(nn.Module):
    """Single-conv-layer CNN producing one sigmoid probability per example.

    Expects input of shape ``(batch, 1, n_channels, n_samples)``. The defaults
    (19 x 5250) reproduce the originally hard-coded layout; other sizes can be
    passed to reuse the same architecture on differently shaped inputs.

    Args:
        n_channels: size of the input's third axis (height seen by the conv).
        n_samples: size of the input's fourth axis (width seen by the conv).

    Raises:
        ValueError: if the input is too small to survive the conv + pool stages.
    """

    def __init__(self, n_channels: int = 19, n_samples: int = 5250):
        super().__init__()
        conv_kernel = 5
        pool_kernel = 2
        conv_out_channels = 6

        # First convolutional layer: 1 input channel -> 6 feature maps
        self.conv1 = nn.Conv2d(1, conv_out_channels, kernel_size=(conv_kernel, conv_kernel), stride=1)
        # Mean pooling with a 2x2 window and stride 2
        self.pool = nn.AvgPool2d(kernel_size=(pool_kernel, pool_kernel), stride=pool_kernel)

        # Derive the flattened feature size instead of hard-coding it:
        # an unpadded stride-1 conv shrinks each spatial dim by (kernel - 1);
        # unpadded pooling with stride == kernel floors each dim by the stride.
        # Defaults: (19, 5250) -> conv (15, 5246) -> pool (7, 2623).
        conv_h = n_channels - conv_kernel + 1
        conv_w = n_samples - conv_kernel + 1
        if conv_h < pool_kernel or conv_w < pool_kernel:
            raise ValueError(
                f"input of size ({n_channels}, {n_samples}) is too small for a "
                f"{conv_kernel}x{conv_kernel} conv followed by {pool_kernel}x{pool_kernel} pooling"
            )
        self.fc_input_size = conv_out_channels * (conv_h // pool_kernel) * (conv_w // pool_kernel)

        # Fully connected layer with one output unit (binary classification:
        # seizure vs. non-seizure, per the original comment)
        self.fc1 = nn.Linear(self.fc_input_size, 1)

    def forward(self, x):
        """Map a ``(batch, 1, n_channels, n_samples)`` tensor to ``(batch, 1)`` probabilities."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = self.fc1(x)
        return torch.sigmoid(x)  # squash to a probability in (0, 1)


In [None]:
# Seed torch so weight init and DataLoader shuffling are reproducible.
TRAIN_SEED = 42
torch.manual_seed(TRAIN_SEED)

net = EEGNet().to(device)

# Loss and optimizer: BCELoss works on the sigmoid probabilities EEGNet returns
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

# Training loop
# NOTE(review): test_dataloader is built earlier but never evaluated here.
num_epochs = 100
for epoch in range(num_epochs):
    net.train()
    running_loss = 0.0
    total = 0
    correct = 0

    # Batch-local names, so the loop does not overwrite the notebook-level
    # `data` / `labels` used by earlier cells.
    for batch_data, batch_labels in train_dataloader:
        batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)

        optimizer.zero_grad()
        # (batch, 1) -> (batch,). squeeze(1) keeps the batch axis even for a
        # final batch of one example, where a bare squeeze() would give a 0-d
        # tensor that no longer matches the (1,) target shape.
        probs = net(batch_data).squeeze(1)
        loss = criterion(probs, batch_labels)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        total += batch_labels.size(0)
        correct += ((probs > 0.5).float() == batch_labels).sum().item()

    epoch_loss = running_loss / len(train_dataloader)
    epoch_accuracy = correct / total

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}, Accuracy: {epoch_accuracy}')

print('Finished Training')