In [25]:
import mne
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader,Subset
from sklearn.model_selection import train_test_split
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score, recall_score, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import json
import os

In [26]:
mne.set_log_level('ERROR')

In [27]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [28]:
print(device)

cuda


In [29]:
# Used to read our json file of each dictionary mapped to its file name
def read_dict_from_json_file(filepath):
    with open(filepath, 'r') as file:
        return json.load(file)

In [30]:
class CHBData(Dataset):
    def __init__(self, segment_dict):
        self.segment_dict = segment_dict
        # Mapping label names to integers
        self.label_to_int = {'interictal': 0, 'preictal': 1, 'ictal': 2}

    def __len__(self):
        return len(self.segment_dict)

    def __getitem__(self, index):
        # Retrieve the label and map it to an integer
        segment_tensor = torch.load(f'CHB-MIT/Segments/chb01/{index}-chb01.pt')
        label_name = self.segment_dict[f'CHB-MIT/Segments/chb01/{index}-chb01.pt']
        label = self.label_to_int[label_name]
        segment_tensor = segment_tensor.to(device)
        return segment_tensor, label


In [31]:
segment_dict = read_dict_from_json_file('CHB-MIT/segment_dict1.json')

In [32]:
num_workers = 0
full_dataset = CHBData(segment_dict)

# Split the dataset into train and test sets
train_indices, test_indices = train_test_split(range(len(full_dataset)), test_size=0.3, random_state=42)

#  split the test set into validation and actual test sets
val_indices, test_indices = train_test_split(test_indices, test_size=0.5, random_state=42)

# dataloaders for training, validation, and test
train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_dataset, val_indices)
test_dataset = Subset(full_dataset, test_indices)

# Instantiate dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=num_workers, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=num_workers, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=num_workers, drop_last=True)


In [9]:
class SeizureSense(nn.Module):
    
    def __init__(self):
        super(SeizureSense,self).__init__()
        
        #Block 1
        #should be taking in an input of 23x512
        #first layer temporal filters
        self.conv1=nn.Conv2d(1,8,(1,128),stride=1,padding=0)
        self.batchnorm1 = nn.BatchNorm2d(8, False)
        
        #spatial layer(depthwise layer)
        self.conv2=nn.Conv2d(8,32,(23,1))
        self.batchnorm2 = nn.BatchNorm2d(32, False)
        self.avgpool1 = nn.AvgPool2d((1,2))
        #apply dropout here in forward
        
        
        #Block 2
        #sepereable convolutional 2d
        self.conv3=nn.Conv2d(32,32,(1,16),stride=1)
        self.batchnorm3 = nn.BatchNorm2d(32,False) #CHECK THIS!
        self.avgpool2 = nn.AvgPool2d((1,16))

        
        #Block 3
        
        self.fc1= nn.Linear(96, 30)
        self.fc2=nn.Linear(30,3)

        #apply dropout here in forward
        self.dropout = nn.Dropout(0.25)
        
    def forward(self,x):
        x=self.conv1(x)
        x =F.elu(x)
        x=self.batchnorm1(x)
        x = self.dropout(x)
        
        x = self.conv2(x)
        x = F.elu(x)
        x=self.batchnorm2(x)
        x = self.dropout(x)
        x=self.avgpool1(x)
        
        x=F.elu(self.conv3(x))
        x=self.batchnorm3(x)
        x = self.dropout(x)
        x=self.avgpool2(x)
        
        #Fully connected layers for classification
        x=x.view(-1,96)
        x=F.elu(self.fc1(x))
        x=F.sigmoid(self.fc2(x))
        return x

In [10]:
model=SeizureSense()
model.to(device)

SeizureSense(
  (conv1): Conv2d(1, 8, kernel_size=(1, 128), stride=(1, 1))
  (batchnorm1): BatchNorm2d(8, eps=False, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(8, 32, kernel_size=(23, 1), stride=(1, 1))
  (batchnorm2): BatchNorm2d(32, eps=False, momentum=0.1, affine=True, track_running_stats=True)
  (avgpool1): AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)
  (conv3): Conv2d(32, 32, kernel_size=(1, 16), stride=(1, 1))
  (batchnorm3): BatchNorm2d(32, eps=False, momentum=0.1, affine=True, track_running_stats=True)
  (avgpool2): AvgPool2d(kernel_size=(1, 16), stride=(1, 16), padding=0)
  (fc1): Linear(in_features=96, out_features=30, bias=True)
  (fc2): Linear(in_features=30, out_features=3, bias=True)
  (dropout): Dropout(p=0.25, inplace=False)
)

In [35]:
#maps back the labels to words
def decode_predictions(predictions):
    label_names = ['interictal', 'preictal', 'ictal']
    _, predicted_labels = torch.max(predictions, 1)
    return [label_names[label] for label in predicted_labels]

In [None]:
import time

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10

losses = []

for epoch in range(num_epochs):
    start_time = time.time()  # Start timing
    
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.float().to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        
    # Validation Phase
    model.eval()
    val_labels = []
    val_predictions = []
    with torch.no_grad():  
        for inputs, labels in val_loader:
            inputs, labels = inputs.float().to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            val_labels.extend(labels.tolist())
            val_predictions.extend(predicted.tolist())
    
    # Calculate accuracy, recall, and specificity (assuming these functions are defined)
    accuracy = accuracy_score(val_labels, val_predictions)
    recall = recall_score(val_labels, val_predictions, average='macro')
    specificity = recall_score(val_labels, val_predictions, average='macro', pos_label=0)
    
    end_time = time.time()  # End timing
    epoch_time = end_time - start_time  # Calculate elapsed time for the epoch
    
    print(f'Epoch {epoch+1}/{num_epochs} completed in {epoch_time:.2f} seconds')


In [22]:
print(accuracy)

0.9697094298245614
