In [1]:
import mne
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader,Subset
from sklearn.model_selection import train_test_split
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim



In [2]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [3]:
#creating dataset
class CHBData(Dataset):
    def __init__(self, CHB_files, segment_length):
        self.segments = [] # input for model containing segment of length "l" and torch tensor containing the eeg values for that segment
        self.labels = [] # output of model contrainig  the labels of each segment, either interictal, preictal, or ictal
        for file_path, info in CHB_files.items():
            processed_data=self.preprocessing(file_path) # preprocess every file in dictionary 
            for start, end, label in info:
                segmented_eeg = self.segment_eeg(processed_data,start, end, label,segment_length) # segments that file
                for segment,label in segmented_eeg:
                    self.segments.append(segment) # adds to final list for model 
                    self.labels.append(label)

    def __len__(self):
        return len(self.segments)
    
    def __getitem__(self, i):
        return self.segments[i],self.labels[i]

    def segment_eeg(self,segment_tensor,start, end, label,segment_length):
        start=start*256 #have to mulyiply time by sampling rate BUDDDDYYYY 
        end=end*256
        segments=[] # list to store tuple of each segments pytorch tensor and each label inside the data 
        for i in range(start,end,segment_length):
            segment_end=min(i+segment_length,end)
            segment=segment_tensor[:,i:segment_end] # isolates segment of wtv lenght it is from each torch tensor in preprocessing our data 
            segments.append((segment,label))
            
        return segments

    
    def preprocessing(self,file_path):
        # loading data: 
        raw = mne.io.read_raw_edf(file_path)
        raw.load_data()
        # processing every raw object to remove 60 hz and its multiples:
        eeg_picks = mne.pick_types(raw.info, meg=False, eeg=True)
        freqs = (60,120)
        raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)
        # applying a high pass filter of order 4 with a cutoff frequency of 30 Hz to the data to enhance gamma signal to noise ratio:
        raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
        numpy_array=raw_notch.get_data()
        segment_tensor=torch.from_numpy(numpy_array)
        return segment_tensor

In [4]:
CHB_files = {'CHB-MIT/chb01_01.edf':[(0, 3600, "interictal")], 'CHB-MIT/chb01_02.edf':[(0, 1400, "interictal"),(1400,3600,"preictal")],}

In [5]:
dataset = CHBData(CHB_files, 512)
print(len(dataset))
#Nut big brain shit right here
'''Okay so when we do the input yk how we changed it 512 because segment lenght is 2 and our sampling rate is 
256 so we have to multiply by sampling yp get data points right
we have to do that for start and end time bruhh
like we need 3600*256/512=segments per file yfm? so we need to multiply start and end time by 256 in segment function
which is what i ended up doing and it fixed
'''

Extracting EDF parameters from /home/mattbls/PythonProjects/SeizureSense/CHB-MIT/chb01_01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Reading 0 ... 921599  =      0.000 ...  3599.996 secs...
Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 1691 samples (6.605 s)



  raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)
  raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)


Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 30.00
- Lower transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 26.25 Hz)
- Filter length: 113 samples (0.441 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.4s
  raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
  raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.7s


Extracting EDF parameters from /home/mattbls/PythonProjects/SeizureSense/CHB-MIT/chb01_02.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 921599  =      0.000 ...  3599.996 secs...


  raw = mne.io.read_raw_edf(file_path)


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 1691 samples (6.605 s)



  raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)
  raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)


Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 30.00
- Lower transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 26.25 Hz)
- Filter length: 113 samples (0.441 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.4s
  raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
  raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.6s


3600


'Okay so when we do the input yk how we changed it 512 because segment lenght is 2 and our sampling rate is \n256 so we have to multiply by sampling yp get data points right\nwe have to do that for start and end time bruhh\nlike we need 3600*256/512=segments per file yfm? so we need to multiply start and end time by 256 in segment function\nwhich is what i ended up doing and it fixed\n'

In [10]:


# Create the full dataset
full_dataset = CHBData(CHB_files, segment_length=512)

# Split the dataset into train and test sets
train_indices, test_indices = train_test_split(range(len(full_dataset)), test_size=0.3, random_state=42)

#  split the test set into validation and actual test sets
val_indices, test_indices = train_test_split(test_indices, test_size=0.5, random_state=42)

# dataloaders for training, validation, and test
train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_dataset, val_indices)
test_dataset = Subset(full_dataset, test_indices)

# Create dataloader type shit
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0, drop_last=True)




Extracting EDF parameters from /home/mattbls/PythonProjects/SeizureSense/CHB-MIT/chb01_01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 921599  =      0.000 ...  3599.996 secs...


  raw = mne.io.read_raw_edf(file_path)


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 1691 samples (6.605 s)



  raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)
  raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)


Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 30.00
- Lower transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 26.25 Hz)
- Filter length: 113 samples (0.441 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.5s
  raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
  raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.8s


Extracting EDF parameters from /home/mattbls/PythonProjects/SeizureSense/CHB-MIT/chb01_02.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 921599  =      0.000 ...  3599.996 secs...
Filtering raw data in 1 contiguous segment
Setting up band-stop filter


  raw = mne.io.read_raw_edf(file_path)



FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 1691 samples (6.605 s)



  raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)
  raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)


Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 30.00
- Lower transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 26.25 Hz)
- Filter length: 113 samples (0.441 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s
  raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
  raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.6s


In [11]:
full_dataset.__getitem__(1)[0]

torch.Size([23, 512])

In [13]:
class SeizureSense(nn.Module):
    
    def __init__(self):
        super(SeizureSense,self).__init__()
        
        #Block 1
        #should be taking in an input of 23x512
        #first layer temporal shittt(pointwise is the technical term)
        self.conv1=nn.Conv2d(1,8,(1,128),padding=0)
        self.batchnorm1 = nn.BatchNorm2d(8, False)
        
        #spatial layer(depthwise layer)
        self.conv2=nn.Conv2d(8,32,(23,1))
        self.batchnorm2 = nn.BatchNorm2d(32, False)
        self.avgpool1 = nn.AvgPool2d((1,2))
        #apply dropout here in forward
        
        
        #Block 2
        #sepereable convolutional 2d
        self.conv3=nn.Conv2d(8,32,(1,16))
        self.batchnorm3 = nn.BatchNorm2d(32,False) #CHECK THIS!
        self.avgpool2 = nn.AvgPool2d((1,16))
        #apply dropout here in forward
        
        
        
        #bilstm layers
        
        #Block 3
        #fc1 layer assumiing no bilstm layer rn can change later 
        #matthew check my math here please its 2am 
        
        self.fc1= nn.Linear(32*1*5, 30)
        self.fc2=nn.Linear(30,3)
        
        
        
    def forward(self,x):
        x=self.conv1(x)
        x =F.elu(x)
        x=self.batchnorm1(x)
        
        x = self.conv2(x)
        x = F.elu(x)
        x=self.batchnorm2(x)
        x=self.avgpool1(x)
        
        x=F.elu(self.conv3(x))
        x=self.batchnorm3(x)
        x=self.avgpool2(x)
        
        #bilstm stufff
        
        #Fully connected time
        x=x.view(-1,32*1*5)
        x=F.elu(self.fc1(x))
        x=F.sigmoid(self.fc2(x))
        return x

In [14]:
model=SeizureSense()

In [15]:
print(model(full_dataset.__getitem__(1)[0]))

RuntimeError: Input type (double) and bias type (float) should be the same