In [1]:
import mne
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader,Subset
from sklearn.model_selection import train_test_split
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score, recall_score, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import json

mne.set_log_level('ERROR')


In [2]:
class SeizureSense(nn.Module):
    
    def __init__(self):
        super(SeizureSense,self).__init__()
        
        #Block 1
        #should be taking in an input of 23x512
        #first layer temporal filters
        self.conv1=nn.Conv2d(1,8,(1,128),stride=1,padding=0)
        self.batchnorm1 = nn.BatchNorm2d(8, False)
        
        #spatial layer(depthwise layer)
        self.conv2=nn.Conv2d(8,32,(23,1))
        self.batchnorm2 = nn.BatchNorm2d(32, False)
        self.avgpool1 = nn.AvgPool2d((1,2))
        #apply dropout here in forward
        
        
        #Block 2
        #sepereable convolutional 2d
        self.conv3=nn.Conv2d(32,32,(1,16),stride=1)
        self.batchnorm3 = nn.BatchNorm2d(32,False) #CHECK THIS!
        self.avgpool2 = nn.AvgPool2d((1,16))

        
        #Block 3
        
        self.fc1= nn.Linear(96, 30)
        self.fc2=nn.Linear(30,3)

        #apply dropout here in forward
        self.dropout = nn.Dropout(0.25)
        
    def forward(self,x):
        x=self.conv1(x)
        x =F.elu(x)
        x=self.batchnorm1(x)
        x = self.dropout(x)
        
        x = self.conv2(x)
        x = F.elu(x)
        x=self.batchnorm2(x)
        x = self.dropout(x)
        x=self.avgpool1(x)
        
        x=F.elu(self.conv3(x))
        x=self.batchnorm3(x)
        x = self.dropout(x)
        x=self.avgpool2(x)
        
        #Fully connected time
        x=x.view(-1,96)
        x=F.elu(self.fc1(x))
        # FOR TRAINING NO NEED FOR SIGMOUD ON LAST THING BC CROSS ENTRTOPY LOSS SO ADD IT IN INFERENCE DONT FORGET
        x=self.fc2(x)
        return x

In [3]:
model=SeizureSense()
state_dict = torch.load(r"C:\Users\uddha\Downloads\SeizureSenseStateDict (3).pth",map_location=torch.device('cpu'))
model.load_state_dict(state_dict)

<All keys matched successfully>

In [4]:
import numpy as np

optimal_threshold = 0.082604
def decode_predictions(predictions):
    label_names = ['interictal', 'preictal', 'ictal']
    # Ensure predictions tensor is detached and converted to a NumPy array
    predictions_np = predictions.detach().cpu().numpy()
    
    # Extract the probabilities for the preictal class
    positive_class_probabilities = predictions_np[:, 1]
    # Initialize an array for predicted labels
    predicted_labels = np.zeros(predictions_np.shape[0], dtype=int)
    
    # Apply threshold to determine preictal predictions
    preictal_indices = positive_class_probabilities > optimal_threshold
    predicted_labels[preictal_indices] = 1  # Mark as preictal
    
    # Handle non-preictal predictions
    for i, is_preictal in enumerate(preictal_indices):
        if not is_preictal:
            # If not preictal, choose the max between interictal and ictal
            # Note: We multiply by 2 when selecting the ictal class to match the label index
            predicted_labels[i] = np.argmax([predictions_np[i, 0], predictions_np[i, 2]]) * 2
    
    # Map numeric labels to string labels
    decoded_labels = [label_names[label] for label in predicted_labels]
    
    return decoded_labels


In [5]:
def decode_predictions_max(predictions):
    label_names = ['interictal', 'preictal', 'ictal']
    _, predicted_labels = torch.max(predictions, 1)
    return [label_names[label] for label in predicted_labels]

In [None]:
#loading data
raw = mne.io.read_raw_edf(r"C:\Users\uddha\Downloads\chb02_19.edf")
raw.load_data()
#Interictal time steps
#2000 seceonn-2001 seceonds
#Preictal 
#2779 seceonds-2780 seceonds
#inseceonds to 3371 seceonds

In [7]:
numpy_array=raw.get_data()
segment_tensor=torch.from_numpy(numpy_array).float().unsqueeze(0).unsqueeze(0)
print(segment_tensor.shape)

torch.Size([1, 1, 23, 921600])


In [9]:
def preprocessing(raw):

    eeg_picks = mne.pick_types(raw.info, meg=False, eeg=True)


    freqs = (60,120)
    raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)
    
    raw.notch_filter(freqs=freqs, picks=eeg_picks)  # Apply notch filter
    raw.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
    numpy_array=raw.get_data()
    segment_tensor=torch.from_numpy(numpy_array).float().unsqueeze(0).unsqueeze(0)
    return(segment_tensor)

In [13]:
segment_tensor=preprocessing(raw)
i=2779
input=segment_tensor[:,:,:,256*i:256*(i+1)].float()

In [14]:
output=model(input)
output=F.softmax(output,dim=1)
print(output)

tensor([[9.9918e-01, 5.8176e-04, 2.3420e-04]], grad_fn=<SoftmaxBackward0>)


In [16]:
torch.save(segment_tensor,'segment_tensor.pt')

In [15]:
CHB_files={'CHB-MIT/CHB_Database/chb02/chb02_19.edf': [(0, 2769, 'interictal'), (2769, 3369, 'preictal'), (3369, 3378, 'ictal'), (3378, 3600, 'interictal')],}
true_labels=[]
start,end=2778,3388
for i in range(start,end):
    for interval in CHB_files.values():
        for start_interval,end_interval,label in interval:
            if start_interval <= i < end_interval:
                true_labels.append(label)
                break
        else:
            continue
        break
    
true_labels = list(true_labels)

In [17]:
model=SeizureSense()
state_dict = torch.load(r"C:\Users\uddha\Downloads\SeizureSenseStateDict (3).pth",map_location=torch.device('cpu'))
model.load_state_dict(state_dict)

<All keys matched successfully>

In [113]:
#loop within raw file
model1=SeizureSense()
model.eval()
model1.eval()
segment_tensor=preprocessing(raw)
predictions=[]
if_it_picked_the_max=[]
random_predictions=[]
random_predictions_max=[]
for i in range(2778,3388):
    input=segment_tensor[:,:,:,256*i:256*(i+1)]
    output=model(input)
    output=F.softmax(output,dim=1)
    predicted_label_names=decode_predictions(output)
    predicted_max=decode_predictions_max(output)
    predictions.append(predicted_label_names)
    if_it_picked_the_max.append(predicted_max)
    #now with random parameters
    random_output=model1(input)
    random_output=F.softmax(random_output,dim=1)
    predicted_random=decode_predictions(random_output)
    predicted_random_max=decode_predictions_max(random_output)
    random_predictions.append(predicted_random)
    random_predictions_max.append(predicted_random_max)

In [91]:
df = pd.DataFrame([predictions, if_it_picked_the_max,random_predictions,random_predictions_max,true_labels])
csv_file_path = 'label_comparisons.csv'
df.to_csv(csv_file_path, index=False)

In [114]:
count_with_threshold = 0
count_with_max=0
count_with_random=0
count_with_random_max=0
for i in range(len(true_labels)):
    # Assuming predictions[i] is a list with only one element, compare true_labels[i] to predictions[i][0]
    if true_labels[i] == predictions[i][0]:
        count += 1
    if true_labels[i]==if_it_picked_the_max[i][0]:
        count_with_max+=1
    if true_labels[i]==random_predictions[i][0]:
        count_with_random+=1
    if true_labels[i]==random_predictions_max[i][0]:
        count_with_random_max+=1
print('The number of correct predicitons using the pretrained model and threshold',count_with_threshold)
print('The number of correct predictions of pretrained model using torch.max',count_with_max)
print('The number of correct predicitons of a random model with a threshold',count_with_random)
print('The number of correct predicitons of a random model usign torch.max',count_with_random_max)


The number of correct predicitons using the pretrained model and threshold 0
The number of correct predictions of pretrained model using torch.max 21
The number of correct predicitons of a random model with a threshold 591
The number of correct predicitons of a random model usign torch.max 9
