In [3]:
import mne
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader,Subset
from sklearn.model_selection import train_test_split
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim



In [4]:
torch.cuda.is_available()

True

In [27]:
#creating dataset
class CHBData(Dataset):
    def __init__(self, CHB_files, segment_length):
        self.segments = [] # input for model containing segment of length "l" and torch tensor containing the eeg values for that segment
        self.labels = [] # output of model contrainig  the labels of each segment, either interictal, preictal, or ictal
        for file_path, info in CHB_files.items():
            processed_data=self.preprocessing(file_path) # preprocess every file in dictionary 
            for start, end, label in info:
                segmented_eeg = self.segment_eeg(processed_data,start, end, label,segment_length) # segments that file
                for segment,label in segmented_eeg:
                    self.segments.append(segment) # adds to final list for model 
                    self.labels.append(label)

    def __len__(self):
        return len(self.segments)
    
    def __getitem__(self, i):
        return self.segments[i],self.labels[i]

    def segment_eeg(self,segment_tensor,start, end, label,segment_length):
        start=start*256 #have to mulyiply time by sampling rate BUDDDDYYYY 
        end=end*256
        segments=[] # list to store tuple of each segments pytorch tensor and each label inside the data 
        for i in range(start,end,segment_length):
            segment_end=min(i+segment_length,end)
            segment=segment_tensor[:,i:segment_end] # isolates segment of wtv lenght it is from each torch tensor in preprocessing our data 
            segments.append((segment,label))
            
        return segments

    
    def preprocessing(self,file_path):
        # loading data: 
        raw = mne.io.read_raw_edf(file_path)
        raw.load_data()
        # processing every raw object to remove 60 hz and its multiples:
        eeg_picks = mne.pick_types(raw.info, meg=False, eeg=True)
        freqs = (60,120)
        raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)
        # applying a high pass filter of order 4 with a cutoff frequency of 30 Hz to the data to enhance gamma signal to noise ratio:
        raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
        numpy_array=raw_notch.get_data()
        segment_tensor=torch.from_numpy(numpy_array)
        return segment_tensor

In [28]:
CHB_files = {'CHB-MIT/chb01_01.edf':[(0, 3600, "interictal")], 'CHB-MIT/chb01_02.edf':[(0, 1400, "interictal"),(1400,3600,"preictal")],}

In [29]:
dataset = CHBData(CHB_files, 512)

Extracting EDF parameters from /home/mattbls/PythonProjects/SeizureSense/CHB-MIT/chb01_01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Reading 0 ... 921599  =      0.000 ...  3599.996 secs...
Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 1691 samples (6.605 s)



  raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)
  raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)


Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 30.00
- Lower transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 26.25 Hz)
- Filter length: 113 samples (0.441 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.4s
  raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
  raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.7s


Extracting EDF parameters from /home/mattbls/PythonProjects/SeizureSense/CHB-MIT/chb01_02.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 921599  =      0.000 ...  3599.996 secs...


  raw = mne.io.read_raw_edf(file_path)


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 1691 samples (6.605 s)



  raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)
  raw_notch = raw.copy().notch_filter(freqs=freqs, picks=eeg_picks)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.5s


Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 30.00
- Lower transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 26.25 Hz)
- Filter length: 113 samples (0.441 s)



  raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
  raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.7s


In [25]:
class CONV(nn.Module):
    def __init__(self):
        super(CONV, self).__init__()
        # Goes over each channel horizontally (temporal)
        self.conv1 = nn.Conv2d(1, 16, (1, 128))

        self.batchnorm1 = nn.BatchNorm2d(16)
        # Goes over all 23 channels at the same time to find spatial patterns in the new temporal filters 
        self.conv2 = nn.Conv2d(16, 32, (23,1))

        self.batchnorm2 = nn.BatchNorm2d(32)
        # Squishes all of the channels down into a 1x(96 or smaller) output to extract the most important parts of the data
        self.pooling2 = nn.MaxPool2d((1,23),(1,2))

        self.conv3 = nn.Conv1d(32, 64, 1)

        self.batchnorm3 = nn.BatchNorm1d(64)

        self.pooling3 = nn.MaxPool1d(2, 2)

        self.f1 = nn.Linear(5824, 1)

        self.dropout = nn.Dropout(0.3)
    def forward(self, x):
        # Layer 1
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = F.relu(x)
        x = self.dropout(x)

        # Layer 2
        x = self.conv2(x)
        x = self.batchnorm2(x)
        x = F.relu(x)
        x = self.pooling2(x)
        x = self.dropout(x)

        # Layer 3
        x = x.view(1,32,182)
        x = self.conv3(x)
        x = self.batchnorm3(x)
        x = F.relu(x)
        x = self.pooling3(x)
        
        # Fully connected
        x = x.view(5824)
        x = F.sigmoid(self.f1(x))
        return x

In [26]:
attempt = CONV()
print(attempt(torch.rand(1,1, 23,512)))

tensor([0.4455], grad_fn=<SigmoidBackward0>)
