In [None]:
import gc

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import xarray as xr
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import StratifiedGroupKFold
from torch.utils.data import DataLoader, TensorDataset, Dataset

from config import get_dl_config

In [None]:
# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU.
# The bare `device` expression displays the selection as the cell output.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Configuration mapping provided by the project's `config` module.
conf = get_dl_config()

# Location of the dataset that is opened with xarray further down.
INPUT_FILE = conf['input_file']

### Load data

In [None]:
class TUHDataset(Dataset):
    """In-memory dataset that pairs each recording with its label.

    Both `data` and `labels` are copied into float32 tensors when the dataset
    is constructed; indexing along the first axis selects one sample.
    """

    def __init__(self, data, labels):
        float_dtype = torch.float32
        self.data = torch.tensor(data, dtype=float_dtype)
        self.labels = torch.tensor(labels, dtype=float_dtype)

    def __len__(self):
        # Number of samples == size of the leading axis of the data tensor.
        return len(self.data)

    def __getitem__(self, idx):
        # Prepend a singleton axis so every item has shape (1, *sample_shape).
        sample = torch.unsqueeze(self.data[idx], 0)
        target = self.labels[idx]
        return sample, target

In [None]:
# Open the file at INPUT_FILE as an xarray DataArray; the bare expression
# below renders its summary as the cell output.
data = xr.open_dataarray(INPUT_FILE)
data 

In [None]:
# Prepare labels: encode the `label` coordinate as integer class ids.
# np.unique returns the sorted distinct labels and, with return_inverse=True,
# the index of each sample's label within them -- the same encoding the
# explicit lookup table produces, computed in one vectorised call.
unique_labels, labels = np.unique(data['label'].values, return_inverse=True)
labels_map = {label: i for i, label in enumerate(unique_labels)}

# The networks below end in a single sigmoid unit trained with BCELoss, which
# only accepts targets in [0, 1]; fail early with a clear message otherwise.
assert len(unique_labels) <= 2, f'expected binary labels, got {len(unique_labels)} classes'
labels

In [None]:
# Values of the `patient_id` coordinate; passed as the grouping key to the
# group-aware splitter below (presumably one id per sample -- verify).
groups = data['patient_id'].values
# Underlying array of the DataArray (loaded into memory by `.values`).
measurements = data.values

In [None]:
# Create the train/test partition with StratifiedGroupKFold: folds are
# stratified on the label and no group (patient) is shared between folds.
n_splits = 5
sgkf = StratifiedGroupKFold(n_splits=n_splits)
splits = list(sgkf.split(measurements, labels, groups))

# Select one fold as the test set -> roughly 80% train, 20% test.
# A seeded generator keeps the chosen fold identical across
# "Restart Kernel -> Run All"; change SPLIT_SEED to evaluate another fold.
SPLIT_SEED = 42
split_rng = np.random.default_rng(SPLIT_SEED)
train_idx, test_idx = splits[split_rng.integers(n_splits)]

train_data, train_labels = measurements[train_idx], labels[train_idx]
test_data, test_labels = measurements[test_idx], labels[test_idx]

# Sanity check: no group may appear on both sides of the split.
assert set(groups[train_idx]).isdisjoint(groups[test_idx]), 'group leakage between train and test'

In [None]:
batch_size = 128

# Training batches are reshuffled every epoch.
train_dataset = TUHDataset(train_data, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Evaluation does not benefit from shuffling; a fixed order keeps the
# collected predictions aligned with the dataset order across runs.
test_dataset = TUHDataset(test_data, test_labels)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Sanity check: shape of the training tensor held by the dataset.
train_dataset.data.shape

### Definition and training of the CNNs

In [None]:
class ParkNet(nn.Module):
    """CNN that filters every electrode channel in time, then fuses channel groups.

    Each of the 19 channels is passed through the same stack of 1-D
    convolutions. The resulting feature maps are stacked per channel group and
    merged by a group-specific 2-D convolution; the five group maps are then
    merged by a final 2-D convolution and classified by two linear layers.

    Input:  tensor of shape (batch, 1, 19, n_samples).
    Output: tensor of shape (batch, 1) with sigmoid probabilities.

    Args:
        n_samples: Number of time samples per channel. The default (5250)
            yields the same `fc1` input size as the previously hard-coded
            value `128 * ((1306 - 2) // 2)`.

    Raises:
        ValueError: If `n_samples` is too short to survive all convolution
            and pooling stages.
    """

    # Channel indices (along dim 2 of the input) that are fused together.
    # The i-th entry is processed by `conv2d_group{i+1}`.
    CHANNEL_GROUPS = (
        (0, 3, 7, 11),
        (1, 4, 8, 12),
        (2, 6, 10, 14),
        (5, 9, 13, 15),
        (16, 17, 18),
    )

    def __init__(self, n_samples=5250):
        super().__init__()

        # Validate the input length before allocating any layer.
        flat_features = self._flat_features(n_samples)

        # 1D convolutions and max pooling layers (shared by all channels)
        self.conv1 = nn.Conv1d(1, 8, 3, stride=1)
        self.conv2 = nn.Conv1d(8, 8, 3, stride=1)
        self.maxpool1 = nn.MaxPool1d(2, stride=2)
        self.conv3 = nn.Conv1d(8, 16, 3, stride=1)
        self.conv4 = nn.Conv1d(16, 32, 3, stride=1)
        self.conv5 = nn.Conv1d(32, 32, 3, stride=1)
        self.maxpool2 = nn.MaxPool1d(2, stride=2)

        # 2D convolutions, one per channel group; the kernel height equals the
        # group size so the group axis collapses to 1.
        self.conv2d_group1 = nn.Conv2d(32, 32, (4, 3), stride=1)
        self.conv2d_group2 = nn.Conv2d(32, 32, (4, 3), stride=1)
        self.conv2d_group3 = nn.Conv2d(32, 32, (4, 3), stride=1)
        self.conv2d_group4 = nn.Conv2d(32, 32, (4, 3), stride=1)
        self.conv2d_group5 = nn.Conv2d(32, 32, (3, 3), stride=1)

        # Final 2D convolution (kernel height 5 == number of groups) and pooling
        self.conv2d_final = nn.Conv2d(32, 128, (5, 3), stride=1)
        self.maxpool_final = nn.MaxPool2d((1, 2), stride=(1, 2))

        # Fully connected layers
        self.fc1 = nn.Linear(flat_features, 512)
        self.fc2 = nn.Linear(512, 1)

    @staticmethod
    def _flat_features(n_samples):
        """Return the number of features entering `fc1` for a given input length."""
        length = n_samples - 2 - 2   # conv1, conv2 (kernel 3, no padding)
        length //= 2                 # maxpool1
        length -= 2 + 2 + 2          # conv3, conv4, conv5
        length //= 2                 # maxpool2
        length -= 2                  # per-group 2D convolution (kernel width 3)
        length -= 2                  # conv2d_final (kernel width 3)
        length //= 2                 # maxpool_final
        if length < 1:
            raise ValueError(f'n_samples={n_samples} is too short for ParkNet')
        return 128 * length

    def _temporal_features(self, signal):
        """Apply the shared 1-D convolution stack to one channel of shape (batch, 1, time)."""
        # NOTE(review): conv1 is the only convolution without a ReLU; kept
        # as-is to preserve behaviour -- confirm that this is intentional.
        out = self.conv1(signal)
        out = F.relu(self.conv2(out))
        out = self.maxpool1(out)
        out = F.relu(self.conv3(out))
        out = F.relu(self.conv4(out))
        out = F.relu(self.conv5(out))
        return self.maxpool2(out)

    def forward(self, x):
        group_convs = (
            self.conv2d_group1,
            self.conv2d_group2,
            self.conv2d_group3,
            self.conv2d_group4,
            self.conv2d_group5,
        )

        group_maps = []
        for channel_ids, group_conv in zip(self.CHANNEL_GROUPS, group_convs):
            # (batch, 32, len(channel_ids), time'): one row per channel of the group
            stacked = torch.stack(
                [self._temporal_features(x[:, :, idx, :]) for idx in channel_ids],
                dim=2,
            )
            fused = F.relu(group_conv(stacked))
            # The group axis has size 1 after the convolution; drop it.
            group_maps.append(fused.squeeze(2))

        # Stack the group maps and process them through the final layers
        x = torch.stack(group_maps, dim=2)
        x = F.relu(self.conv2d_final(x))
        x = self.maxpool_final(x)
        x = x.view(x.size(0), -1)  # Flatten for fully connected layers
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))

        return x


In [None]:
class EEGNet(nn.Module):
    """Minimal CNN: one 5x5 convolution, 2x2 mean pooling and a linear read-out.

    Input:  tensor of shape (batch, 1, n_channels, n_samples).
    Output: tensor of shape (batch, 1) with sigmoid probabilities.

    Args:
        n_channels: Height of the input (number of channels). Default 19.
        n_samples: Width of the input (time samples per channel). Default 5250.
            With the defaults the flattened size is 6 * 7 * 2623, the value
            that used to be hard-coded.

    Raises:
        ValueError: If the input is too small for the convolution and pooling.
    """

    def __init__(self, n_channels=19, n_samples=5250):
        super().__init__()

        # Size after conv1 (kernel 5, no padding) is `size - 5 + 1`; the 2x2
        # pooling with stride 2 then halves it, rounding down.
        pooled_height = (n_channels - 5 + 1) // 2
        pooled_width = (n_samples - 5 + 1) // 2
        if pooled_height < 1 or pooled_width < 1:
            raise ValueError(
                f'input of size ({n_channels}, {n_samples}) is too small for EEGNet'
            )

        # First convolutional layer: 1 input channel, 6 output channels
        self.conv1 = nn.Conv2d(1, 6, kernel_size=(5, 5), stride=1)
        # Mean pooling layer with a 2x2 kernel
        self.pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)

        # Flattened size for the fully connected layer, e.g. for the default
        # input (batch, 1, 19, 5250): conv1 -> (6, 15, 5246), pool -> (6, 7, 2623).
        self.fc_input_size = 6 * pooled_height * pooled_width

        # Fully connected layer with a single output for binary classification
        self.fc1 = nn.Linear(self.fc_input_size, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)  # Flatten all dimensions except batch
        x = self.fc1(x)
        x = torch.sigmoid(x)  # Sigmoid activation for binary classification
        return x


In [None]:
net = EEGNet().to(device)

# Binary cross-entropy loss (the model already applies the sigmoid) and Adam optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Training loop
num_epochs = 250
for epoch in range(num_epochs):
    net.train()

    running_loss = 0.0
    total = 0
    correct = 0

    # Batch variables get their own names so the notebook-level `data` and
    # `labels` (dataset and encoded labels) are not overwritten by the loop.
    for batch_data, batch_labels in train_dataloader:
        batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)

        optimizer.zero_grad()
        # squeeze(1) removes only the output axis: (batch, 1) -> (batch,).
        # A bare squeeze() would also drop the batch axis for a batch of one
        # sample and no longer match the shape of the labels.
        probs = net(batch_data).squeeze(1)
        loss = criterion(probs, batch_labels)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        total += batch_labels.size(0)
        correct += ((probs > 0.5) == batch_labels).sum().item()

    epoch_loss = running_loss / len(train_dataloader)
    epoch_accuracy = correct / total

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}, Accuracy: {epoch_accuracy}')

print('Finished Training')

In [None]:
# Evaluate the current `net` on the held-out test set.
net.eval()

true_labels = []
predictions = []
probabilities = []

with torch.no_grad():
    # Dedicated batch names keep the notebook-level `data` / `labels` intact.
    for batch_data, batch_labels in test_dataloader:
        batch_data = batch_data.to(device)

        # Sigmoid outputs of the model; squeeze(1) keeps the batch axis even
        # when the last batch holds a single sample.
        probs = net(batch_data).squeeze(1)

        # Convert probabilities to binary predictions (0 or 1)
        preds = (probs > 0.5).long()

        # Store the scores, the predictions and the true labels
        probabilities.extend(probs.cpu().numpy())
        predictions.extend(preds.cpu().numpy())
        true_labels.extend(batch_labels.cpu().numpy())

# Calculate standard metrics
accuracy = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions)
recall = recall_score(true_labels, predictions)
f1 = f1_score(true_labels, predictions)
# ROC-AUC is a ranking metric: it must be computed from the continuous scores,
# not from the thresholded predictions.
roc_auc = roc_auc_score(true_labels, probabilities)

# Print the results
print(f'Test Accuracy: {accuracy}')
print(f'Test Precision: {precision}')
print(f'Test Recall: {recall}')
print(f'Test F1-Score: {f1}')
print(f'Test ROC-AUC: {roc_auc}')


In [None]:
# Rebind `net` to a freshly initialised ParkNet on the selected device
# (this replaces the EEGNet instance created earlier).
net = ParkNet().to(device)

In [None]:
# Collect unreferenced Python objects, then ask PyTorch to release its
# unused cached GPU memory before the next training run.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Binary cross-entropy loss (the model already applies the sigmoid) and Adam optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Training loop
num_epochs = 250
for epoch in range(num_epochs):
    net.train()

    running_loss = 0.0
    total = 0
    correct = 0

    # Batch variables get their own names so the notebook-level `data` and
    # `labels` (dataset and encoded labels) are not overwritten by the loop.
    for batch_data, batch_labels in train_dataloader:
        batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)

        optimizer.zero_grad()
        # (batch, 1) -> (batch,): one shape for both the loss and the accuracy,
        # and the batch axis survives even for a batch of a single sample.
        probs = net(batch_data).squeeze(1)
        loss = criterion(probs, batch_labels)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        total += batch_labels.size(0)
        correct += ((probs > 0.5) == batch_labels).sum().item()

    epoch_loss = running_loss / len(train_dataloader)
    epoch_accuracy = correct / total

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}, Accuracy: {epoch_accuracy}')

print('Finished Training')

In [None]:
class WangNet(nn.Module):
    """Deep CNN alternating temporal (1x5) and cross-channel (3x1) convolutions.

    Four convolution blocks (temporal conv -> dropout -> cross-channel conv ->
    batch norm, followed by width pooling in the first three blocks) feed two
    linear layers.

    Input:  tensor of shape (batch, 1, n_channels, n_samples).
    Output: tensor of shape (batch, 1) with sigmoid probabilities.

    Args:
        n_channels: Height of the input (number of channels). Default 19.
        n_samples: Width of the input (time samples per channel). Default 5250.
            With the defaults `fc1` receives 200 * 648 features, the value
            that used to be hard-coded.

    Raises:
        ValueError: If the input is too small to survive all layers.
    """

    def __init__(self, n_channels=19, n_samples=5250):
        super().__init__()

        # Validate the input size before allocating any layer.
        flat_features = self._flat_features(n_channels, n_samples)

        # default input shape: (batch_size, 1, 19, 5250)
        self.conv_1 = nn.Conv2d(1, 25, (1, 5))
        self.dropout_1 = nn.Dropout(0.25)
        self.conv_2 = nn.Conv2d(25, 25, (3, 1))
        self.bn1 = nn.BatchNorm2d(25)
        self.pool_1 = nn.MaxPool2d((1, 2))
        self.conv_3 = nn.Conv2d(25, 50, (1, 5))
        self.dropout_2 = nn.Dropout(0.25)
        self.conv_4 = nn.Conv2d(50, 50, (3, 1), stride=(2, 1))
        self.bn2 = nn.BatchNorm2d(50)
        self.pool_2 = nn.MaxPool2d((1, 2))
        self.conv_5 = nn.Conv2d(50, 100, (1, 5))
        self.dropout_3 = nn.Dropout(0.25)
        self.conv_6 = nn.Conv2d(100, 100, (3, 1), stride=(2, 1))
        self.bn3 = nn.BatchNorm2d(100)
        self.pool_3 = nn.MaxPool2d((1, 2))
        self.conv_7 = nn.Conv2d(100, 200, (1, 5))
        self.dropout_4 = nn.Dropout(0.25)
        self.conv_8 = nn.Conv2d(200, 200, (3, 1), stride=(2, 1))
        self.bn4 = nn.BatchNorm2d(200)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(flat_features, 256)
        self.dropout_5 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 1)

    @staticmethod
    def _flat_features(n_channels, n_samples):
        """Return the number of features entering `fc1` for a given input size."""
        # Height: conv_2 (kernel 3, stride 1), then conv_4/6/8 (kernel 3, stride 2).
        height = n_channels - 2
        for _ in range(3):
            height = (height - 3) // 2 + 1

        # Width: three times (1x5 convolution, then pooling by 2), then conv_7.
        width = n_samples
        for _ in range(3):
            width = (width - 4) // 2
        width -= 4

        if height < 1 or width < 1:
            raise ValueError(
                f'input of size ({n_channels}, {n_samples}) is too small for WangNet'
            )
        return 200 * height * width

    def forward(self, x):
        # Block 1
        x = F.relu(self.conv_1(x))
        x = self.dropout_1(x)
        x = F.relu(self.conv_2(x))
        x = self.bn1(x)
        x = self.pool_1(x)

        # Block 2
        x = F.relu(self.conv_3(x))
        x = self.dropout_2(x)
        x = F.relu(self.conv_4(x))
        x = self.bn2(x)
        x = self.pool_2(x)

        # Block 3
        x = F.relu(self.conv_5(x))
        x = self.dropout_3(x)
        x = F.relu(self.conv_6(x))
        x = self.bn3(x)
        x = self.pool_3(x)

        # Block 4 (no pooling)
        x = F.relu(self.conv_7(x))
        x = self.dropout_4(x)
        x = F.relu(self.conv_8(x))
        x = self.bn4(x)

        x = self.flatten(x)

        # NOTE(review): there is no activation between fc1 and fc2, so outside
        # of dropout the two layers compose to a single linear map; kept as-is
        # to preserve behaviour -- confirm that this is intentional.
        x = self.fc1(x)
        x = self.dropout_5(x)
        x = self.fc2(x)

        x = torch.sigmoid(x)

        return x

In [None]:
# Instantiate WangNet with its default input size and move it to the selected device.
wangNet = WangNet().to(device)

In [None]:

# Binary cross-entropy loss (the model already applies the sigmoid) and Adam optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(wangNet.parameters(), lr=0.0001)

# Training loop
num_epochs = 250
for epoch in range(num_epochs):
    wangNet.train()

    running_loss = 0.0
    total = 0
    correct = 0

    # Batch variables get their own names so the notebook-level `data` and
    # `labels` (dataset and encoded labels) are not overwritten by the loop.
    for batch_data, batch_labels in train_dataloader:
        batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)

        optimizer.zero_grad()
        # squeeze(1) removes only the output axis: (batch, 1) -> (batch,).
        # A bare squeeze() would also drop the batch axis for a batch of one
        # sample and no longer match the shape of the labels.
        probs = wangNet(batch_data).squeeze(1)
        loss = criterion(probs, batch_labels)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        total += batch_labels.size(0)
        correct += ((probs > 0.5) == batch_labels).sum().item()

    epoch_loss = running_loss / len(train_dataloader)
    epoch_accuracy = correct / total

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}, Accuracy: {epoch_accuracy}')

print('Finished Training')