In [None]:
import csv

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import xarray as xr
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.utils import shuffle
from torch.utils.data import DataLoader, TensorDataset, Dataset

from config import get_dl_config

### Set up CUDA and configuration

In [None]:
# Experiment configuration: everything a reader might want to tune lives here.
conf = get_dl_config()
INPUT_FILE = conf['input_file']
BATCH_SIZE = 128
N_EPOCHS = 150
N_SPLITS = 5
LR = 1e-3
RAND_SEED = 42

# Seed the RNGs before anything stochastic runs. torch.manual_seed seeds the
# CPU generator and all CUDA devices; numpy's global stream is seeded for any
# legacy np.random usage further down.
torch.manual_seed(RAND_SEED)
np.random.seed(RAND_SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Last expression of the cell, so Jupyter actually displays the chosen device.
device

### Load data

In [None]:
# Load the whole DataArray into memory and close the underlying file right away.
# The next cell reads `data.values`, which would materialise everything anyway,
# so eager loading costs nothing extra and avoids leaving a lazy file handle open.
data = xr.load_dataarray(INPUT_FILE)
data

In [None]:
# Pull the per-window patient ids and the raw measurement array out of the DataArray.
patient_ids = data['patient_id'].values
measurements = data.values

# Encode the string labels (e.g. 'bckg', 'seiz') as integers 0..K-1, assigned
# in the sorted order that np.unique returns for the distinct label strings.
raw_labels = data['label'].values
unique_labels = np.unique(raw_labels)
labels_map = dict(zip(unique_labels, range(len(unique_labels))))
labels = np.array([labels_map[lbl] for lbl in raw_labels])

### Set up cross-validation

In [None]:
# Patient-grouped, label-stratified folds: no patient's windows end up on both sides.
sgkf = StratifiedGroupKFold(n_splits=N_SPLITS)

splits = list(sgkf.split(X=measurements, y=labels, groups=patient_ids))

# Choose the validation fold with a dedicated generator seeded from RAND_SEED.
# On a fresh kernel this picks the same fold as drawing from the freshly seeded
# global np.random stream, but re-running this cell no longer silently switches
# to a different fold (the global stream advances on every draw).
fold_idx = np.random.RandomState(RAND_SEED).choice(N_SPLITS)
train_idx, val_idx = splits[fold_idx]

train_data, train_labels = measurements[train_idx], labels[train_idx]
val_data, val_labels = measurements[val_idx], labels[val_idx]

# Invariant of group k-fold: a patient must never appear in both splits.
assert set(patient_ids[train_idx]).isdisjoint(patient_ids[val_idx]), \
    "patient leakage between train and validation splits"

train_data.shape, val_data.shape

### Create datasets and data loaders

In [None]:
class TUHDataset(Dataset):
    """In-memory map-style dataset pairing input windows with labels.

    Both arrays are copied into float32 tensors once, at construction time.
    Indexing with an integer returns ``(window, label)``, where ``window`` has a
    leading singleton axis prepended (``(H, W) -> (1, H, W)``). That makes each
    item match a Conv2d expecting one input channel.
    """

    def __init__(self, data, labels):
        # torch.tensor always copies, so later changes to the source arrays do not leak in.
        self.data, self.labels = (
            torch.tensor(data, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.float32),
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        window = self.data[idx]
        # Indexing with None inserts the singleton channel axis in front.
        return window[None, ...], self.labels[idx]

In [None]:
# Training batches are reshuffled every epoch.
train_dataset = TUHDataset(train_data, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Shuffling validation data has no benefit. A fixed order keeps evaluation
# reproducible and avoids consuming torch RNG state (which would otherwise
# perturb the training shuffle sequence of later epochs).
val_dataset = TUHDataset(val_data, val_labels)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

### CNN setup

In [None]:
class SimpleNet(nn.Module):
    """Minimal CNN for binary classification of single-channel 2-D windows.

    Architecture: one 5x5 convolution (6 filters, no padding) + ReLU, a 2x2
    average pool with stride 2, then a single linear unit followed by a sigmoid.
    ``forward`` therefore returns probabilities in [0, 1] of shape (batch, 1).

    Args:
        n_channels: input height (presumably EEG channels). Defaults to 19.
        n_samples: input width (presumably time samples). Defaults to 5250.

    With the defaults the expected input is batch_size x 1 x 19 x 5250, and the
    flattened feature size is 6 * 7 * 2623, identical to the original
    hard-coded value.

    Raises:
        ValueError: if either spatial size is too small to survive the
            5x5 convolution followed by the 2x2 pool (i.e. < 6).
    """

    def __init__(self, n_channels=19, n_samples=5250):
        super().__init__()
        # A 'valid' 5x5 conv shrinks each dim by 4; the pool then needs >= 2 left.
        if n_channels < 6 or n_samples < 6:
            raise ValueError(
                f"input must be at least 6x6, got {n_channels}x{n_samples}"
            )

        # input = batch_size x 1 x n_channels x n_samples
        self.conv1 = nn.Conv2d(1, 6, kernel_size=(5, 5), stride=1)  # -> 6 x (H-4) x (W-4)
        self.pool1 = nn.AvgPool2d(kernel_size=(2, 2), stride=2)     # -> 6 x floor((H-4)/2) x floor((W-4)/2)

        pooled_h = (n_channels - 4) // 2
        pooled_w = (n_samples - 4) // 2
        self.fc_input_size = 6 * pooled_h * pooled_w
        self.fc1 = nn.Linear(self.fc_input_size, 1)

    def forward(self, x):
        """Map a (batch, 1, H, W) tensor to (batch, 1) sigmoid probabilities."""
        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = torch.flatten(x, 1)  # keep the batch dim, flatten the rest
        x = self.fc1(x)

        # Sigmoid output pairs with nn.BCELoss in the training cell.
        return torch.sigmoid(x)

# Build the network, then move its parameters to the selected device.
model = SimpleNet()
model = model.to(device)

### Training and validation

In [None]:
# Train for N_EPOCHS, evaluate on the validation fold after each epoch, and
# checkpoint the model (plus a metrics row in a CSV) whenever the validation
# loss improves.
criterion = nn.BCELoss()  # model already applies a sigmoid
optimizer = optim.Adam(model.parameters(), lr=LR)

best_val_loss = np.inf

# File to store best model metrics
csv_filename = 'best_model_metrics.csv'

# Start a fresh metrics file containing only the header row.
with open(csv_filename, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Epoch', 'Val Loss', 'Val Accuracy', 'Precision', 'Recall', 'F1-Score', 'ROC AUC'])

for epoch in range(N_EPOCHS):
    # ---- Training phase ----
    model.train()
    train_loss = 0.0
    total_train = 0
    correct_train = 0

    # Loop variables are `inputs`/`targets` rather than `data`/`labels` so this
    # cell does not overwrite the notebook-level DataArray and label array.
    for inputs, targets in train_dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)                      # (batch, 1) probabilities
        loss = criterion(outputs, targets.unsqueeze(1))
        loss.backward()
        optimizer.step()

        n_batch = targets.size(0)
        # BCELoss returns the batch mean; weight by batch size so the epoch loss
        # is a true per-sample mean and a short final batch is not over-weighted.
        train_loss += loss.item() * n_batch
        total_train += n_batch
        # squeeze(1) (not squeeze()) keeps a batch of size 1 one-dimensional.
        train_preds = (outputs.squeeze(1) > 0.5).float()
        correct_train += (train_preds == targets).sum().item()

    train_loss /= total_train
    train_accuracy = correct_train / total_train

    # ---- Validation phase ----
    model.eval()
    val_loss = 0.0
    total_val = 0
    correct_val = 0

    all_probs = []
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, targets in val_dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets.unsqueeze(1))

            n_batch = targets.size(0)
            val_loss += loss.item() * n_batch
            total_val += n_batch

            # squeeze(1) guarantees a 1-D tensor even for a final batch of size 1;
            # a 0-d array would make list.extend raise TypeError.
            probs = outputs.squeeze(1)
            preds = (probs > 0.5).float()
            correct_val += (preds == targets).sum().item()

            all_probs.extend(probs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())

    val_loss /= total_val
    val_accuracy = correct_val / total_val

    # Threshold metrics use hard predictions; zero_division=0 avoids warnings in
    # early epochs where the model may predict no positives at all.
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    # ROC AUC is a ranking metric: it must be computed from the continuous
    # scores, not from thresholded 0/1 predictions.
    roc_auc = roc_auc_score(all_labels, all_probs)

    # Check if the current validation loss is the best we've seen so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss

        # Save the model's state
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'Best model saved at epoch {epoch+1} with validation loss: {val_loss:.4f} and validation accuracy: {val_accuracy:.4f}')

        # Append this checkpoint's metrics to the CSV file
        with open(csv_filename, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch+1, val_loss, val_accuracy, precision, recall, f1, roc_auc])

    print(f'Epoch {epoch+1}/{N_EPOCHS}, '
          f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, '
          f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')

print('Finished Training!')

In [None]:
# Evaluate the best checkpoint written by the training loop.
# NOTE(review): this notebook never builds a separate test split, so the only
# held-out data is the validation fold -- which was also used to select the
# checkpoint. These numbers are therefore optimistically biased.
# TODO: hold out a patient-grouped test split and evaluate that here instead.
net = SimpleNet().to(device)
net.load_state_dict(torch.load('best_model.pth', map_location=device))
net.eval()

eval_dataloader = val_dataloader

true_labels = []
predictions = []
probabilities = []

with torch.no_grad():
    for inputs, targets in eval_dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Sigmoid probabilities (batch, 1) -> (batch,). squeeze(1) keeps a final
        # batch of size 1 one-dimensional so list.extend keeps working.
        probs = net(inputs).squeeze(1)

        # Convert probabilities to binary predictions (0 or 1)
        preds = (probs > 0.5).long()

        # Store scores, predictions and true labels
        probabilities.extend(probs.cpu().numpy())
        predictions.extend(preds.cpu().numpy())
        true_labels.extend(targets.cpu().numpy())

# Threshold-based metrics use the hard predictions.
accuracy = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions, zero_division=0)
recall = recall_score(true_labels, predictions, zero_division=0)
f1 = f1_score(true_labels, predictions, zero_division=0)
# ROC AUC ranks samples, so it needs the continuous scores.
roc_auc = roc_auc_score(true_labels, probabilities)

# Print the results
print(f'Validation (best checkpoint) Accuracy: {accuracy}')
print(f'Validation (best checkpoint) Precision: {precision}')
print(f'Validation (best checkpoint) Recall: {recall}')
print(f'Validation (best checkpoint) F1-Score: {f1}')
print(f'Validation (best checkpoint) ROC-AUC: {roc_auc}')