In [40]:
import mne
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader,Subset
from sklearn.model_selection import train_test_split
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score, recall_score, roc_auc_score, auc, roc_curve
import matplotlib.pyplot as plt
from sklearn.preprocessing import label_binarize
import json
import time
import os

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

mne.set_log_level('ERROR')
mne.cuda.init_cuda(verbose=True)

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [3]:
CHB_files1 = {
    'CHB-MIT/CHB_Database/chb01/chb01_01.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_02.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_03.edf': [(0, 2396, 'interictal'), (2396, 2996, 'preictal'), (2996, 3036, 'ictal'), (3036, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_04.edf': [(0, 867, 'interictal'), (867, 1467, 'preictal'), (1467, 1494, 'ictal'), (1494, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_05.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_06.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_07.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_08.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_09.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_10.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_11.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_12.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_13.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_14.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_15.edf': [(0, 1132, 'interictal'), (1132, 1732, 'preictal'), (1732, 1772, 'ictal'), (1772, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_16.edf': [(0, 415, 'interictal'), (415, 1015, 'preictal'), (1015, 1066, 'ictal'), (1066, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_17.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_18.edf': [(0, 1120, 'interictal'), (1120, 1720, 'preictal'), (1720, 1810, 'ictal'), (1810, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_19.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_20.edf': [(0, 2663, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_21.edf': [(0, 327, 'preictal'), (327, 420, 'ictal'), (420, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_22.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_23.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_24.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_25.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_26.edf': [(0, 1262, 'interictal'), (1262, 1862, 'preictal'), (1862, 1963, 'ictal'), (1963, 2325, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_27.edf': [(0, 600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_29.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_30.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_31.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_32.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_33.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_34.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_36.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_37.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_38.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_39.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_40.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_41.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_42.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_43.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_46.edf': [(0, 3600, 'interictal')]
    }

In [4]:
# Used to read our json file of each dictionary mapped to it's file name
def read_dict_from_json_file(filepath):
    with open(filepath, 'r') as file:
        return json.load(file)

In [11]:
# Our calculated mean and standard deviation
mean = 1.9299e-12
std = 1.699e-05

In [26]:
class CHBData(Dataset):
    def __init__(self, segment_dict, mean, std):
        self.segment_dict = segment_dict
        # Mapping label names to integers
        self.label_to_int = {'interictal': 0, 'preictal': 1, 'ictal': 2}
        self.mean = torch.tensor(mean, dtype = torch.float32)
        self.std = torch.tensor(std, dtype = torch.float32)

    def __len__(self):
        return len(self.segment_dict)

    def __getitem__(self, index,):
        
        segment = torch.load(f'CHB-MIT/Segments/chb01/{index}-chb01.pt')
        #uncomment this when we get our mean and std  values
        segment = (segment - self.mean) / (self.std)
        segment.unsqueeze(0)

        # Retrieve the label and map it to an integer
        label_name = self.segment_dict[f'CHB-MIT/Segments/chb01/{index}-chb01.pt']
        label = self.label_to_int[label_name]
        segment = segment.float().to(device)
        return segment, label

In [27]:
segment_dict = read_dict_from_json_file('CHB-MIT/segment_dict1.json')

In [28]:
num_workers = 0
full_dataset = CHBData(segment_dict, mean, std)

# Split the dataset into train and test sets
train_indices, test_indices = train_test_split(range(len(full_dataset)), test_size=0.3, random_state=42)

#  split the test set into validation and actual test sets
val_indices, test_indices = train_test_split(test_indices, test_size=0.5, random_state=42)

# dataloaders for training, validation, and test
train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_dataset, val_indices)
test_dataset = Subset(full_dataset, test_indices)

# Instantiate dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=num_workers, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=num_workers, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=num_workers, drop_last=True)
#training_dataset = CHBData(segment_dict)
#mean, std = calculate_mean_std(training_dataset)

In [29]:
num_workers = 0
full_dataset = CHBData(segment_dict, mean, std)

# Split the dataset into train and test sets
train_indices, val_indices = train_test_split(range(len(full_dataset)), test_size=0.3, random_state=42)

#  split the test set into validation and actual test sets
train_indices, val_indices = train_test_split(range(len(full_dataset)), test_size=0.3, random_state=42)


# dataloaders for training, validation, and test
train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_dataset, val_indices)

# Instantiate dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False, num_workers=num_workers, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=num_workers, drop_last=True)


In [21]:
class SeizureSense(nn.Module):
    
    def __init__(self):
        super(SeizureSense,self).__init__()
        
        #Block 1
        #should be taking in an input of 23x512
        #first layer temporal filters
        self.conv1=nn.Conv2d(1,8,(1,128),stride=1,padding=0)
        self.batchnorm1 = nn.BatchNorm2d(8, False)
        
        #spatial layer(depthwise layer)
        self.conv2=nn.Conv2d(8,32,(23,1))
        self.batchnorm2 = nn.BatchNorm2d(32, False)
        self.avgpool1 = nn.AvgPool2d((1,2))
        #apply dropout here in forward
        
        
        #Block 2
        #sepereable convolutional 2d
        self.conv3=nn.Conv2d(32,32,(1,16),stride=1)
        self.batchnorm3 = nn.BatchNorm2d(32,False) #CHECK THIS!
        self.avgpool2 = nn.AvgPool2d((1,16))

        
        #Block 3
        
        self.fc1= nn.Linear(96, 30)
        self.fc2=nn.Linear(30,3)

        #apply dropout here in forward
        self.dropout = nn.Dropout(0.25)
        
    def forward(self,x):
        x=self.conv1(x)
        x =F.elu(x)
        x=self.batchnorm1(x)
        x = self.dropout(x)
        
        x = self.conv2(x)
        x = F.elu(x)
        x=self.batchnorm2(x)
        x = self.dropout(x)
        x=self.avgpool1(x)
        
        x=F.elu(self.conv3(x))
        x=self.batchnorm3(x)
        x = self.dropout(x)
        x=self.avgpool2(x)
        
        #bilstm stufff
        
        #Fully connected time
        x=x.view(-1,96)
        x=F.elu(self.fc1(x))
        x=F.sigmoid(self.fc2(x))
        return x

In [15]:
#maps back the labels to words
def decode_predictions(predictions):
    label_names = ['interictal', 'preictal', 'ictal']
    _, predicted_labels = torch.max(predictions, 1)
    return [label_names[label] for label in predicted_labels]

In [42]:
optimal_threshold = 0.30

In [31]:
model = SeizureSense()
model.to(device)
state_dict = torch.load("SeizureSenseStateDict.pth")
model.load_state_dict(state_dict)

<All keys matched successfully>

In [33]:
model.eval()
val_labels = []
val_probabilities = []  # Store probabilities for AUROC calculation
val_predictions = []
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.float().to(device), labels.to(device)
        logits = model(inputs)
        outputs = F.softmax(logits, dim=1)            
        
        positive_class_probabilities = outputs[:, 1]
        predictions = (positive_class_probabilities.cpu().numpy() > optimal_threshold).astype(int)
        
        val_labels.extend(labels.cpu().numpy())  # Convert to numpy array
        val_predictions.extend(predictions)
        val_probabilities.extend(outputs.cpu().numpy())  # Convert to numpy array


In [35]:
accuracy = accuracy_score(val_labels, val_predictions)
recall = recall_score(val_labels,val_predictions, average=None)[1]
print(accuracy)
print(recall)

0.027023940058479533
1.0


In [44]:
model.eval()
val_labels = []
val_probabilities = []  # Store probabilities for AUROC calculation
val_predictions = []
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.float().to(device), labels.to(device)
        logits = model(inputs)
        outputs = F.softmax(logits, dim=1)            
        
        positive_class_probabilities = outputs[:, 1]
        predictions = (positive_class_probabilities.cpu().numpy() > optimal_threshold).astype(int)
        
        val_labels.extend(labels.cpu().numpy())  # Convert to numpy array
        val_predictions.extend(predictions)
        val_probabilities.extend(outputs.cpu().numpy())  # Convert to numpy array


In [45]:
accuracy = accuracy_score(val_labels, val_predictions)
recall = recall_score(val_labels,val_predictions, average=None)[1]

print(accuracy)
print(recall)

0.906889619883041
0.9812925170068028


In [50]:
print(len(val_predictions))
print(len(val_labels))

21888
21888


In [53]:
df = pd.DataFrame({"predictions":val_predictions, "labels":val_labels})
df.to_csv("sdfsdoodfio.csv",index=False)