In [2]:
import mne
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [None]:
# Open the two example recordings for interactive inspection; read_raw_edf's
# default (preload=False) reads the header without loading the signal data.
raw1, raw2 = (
    mne.io.read_raw_edf(path)
    for path in ('CHB-MIT/chb01_01.edf', 'CHB-MIT/chb01_02.edf')
)

In [None]:
class CHBData(Dataset):
    """Map-style dataset of fixed-length, labelled CHB-MIT EEG segments.

    Each EDF file in ``CHB_files`` is loaded with MNE. Its EEG channels are
    notch-filtered at 60/120 Hz, the data is high-pass filtered at 30 Hz,
    and then it is cut into consecutive segments of ``segment_length``
    seconds inside every annotated ``(start, end, label)`` window.

    Args:
        CHB_files: mapping from EDF file path to a list of
            ``(start, end, label)`` tuples. ``start``/``end`` are times in
            seconds from the beginning of the recording.
        segment_length: length of each segment, in seconds.

    Each item is a ``(segment, label)`` pair. ``segment`` is a
    ``(n_channels, n_samples)`` tensor; all segments have the same length
    so the default DataLoader collate can stack them.
    """

    def __init__(self, CHB_files, segment_length):
        # Model inputs and their labels, kept in matching order.
        self.segments = []
        self.labels = []
        for file_path, windows in CHB_files.items():
            data, sfreq = self._preprocess(file_path)
            for start, end, window_label in windows:
                # Annotations are in seconds: convert with this file's
                # sampling rate, and drop a trailing partial segment so that
                # every item has the same shape.
                for segment, label in self.segment_eeg(
                    data, start, end, window_label, segment_length,
                    sfreq=sfreq, drop_incomplete=True,
                ):
                    self.segments.append(segment)
                    self.labels.append(label)

    def __len__(self):
        """Return the number of segments in the dataset."""
        return len(self.segments)

    def __getitem__(self, i):
        """Return the ``(segment, label)`` pair at index ``i``."""
        return self.segments[i], self.labels[i]

    def segment_eeg(self, segment_tensor, start, end, label, segment_length,
                    sfreq=None, drop_incomplete=False):
        """Cut ``segment_tensor[:, start:end]`` into consecutive segments.

        Args:
            segment_tensor: 2-D EEG tensor indexed as ``[channel, sample]``.
            start: window start. A sample index when ``sfreq`` is None,
                otherwise seconds.
            end: window end (exclusive), in the same unit as ``start``.
                It is clamped to the length of ``segment_tensor``.
            label: label attached to every segment of this window.
            segment_length: segment length, in the same unit as ``start``.
            sfreq: sampling frequency in Hz, used to convert seconds to
                samples. None (the default) treats all bounds as sample
                indices.
            drop_incomplete: if True, omit a trailing segment shorter than
                ``segment_length``.

        Returns:
            A list of ``(segment, label)`` tuples.

        Raises:
            ValueError: if ``segment_length`` (in samples) is not positive.
        """
        if sfreq is not None:
            start = int(round(start * sfreq))
            end = int(round(end * sfreq))
            segment_length = int(round(segment_length * sfreq))
        if segment_length <= 0:
            raise ValueError(
                f"segment_length must be positive, got {segment_length}")
        # Never slice past the recording; doing so yields short or empty
        # segments that cannot be batched.
        end = min(end, segment_tensor.shape[1])
        segments = []
        for i in range(start, end, segment_length):
            segment_end = min(i + segment_length, end)
            if drop_incomplete and segment_end - i < segment_length:
                break
            segments.append((segment_tensor[:, i:segment_end], label))
        return segments

    def preprocessing(self, file_path):
        """Load and filter one EDF file and return its data as a tensor.

        Returns:
            A ``torch.Tensor`` built from ``Raw.get_data()``, which is
            ``(n_channels, n_times)``.
        """
        return self._preprocess(file_path)[0]

    def _preprocess(self, file_path):
        """Load, filter and tensorize one EDF file; return ``(tensor, sfreq)``."""
        raw = mne.io.read_raw_edf(file_path, preload=True)
        # Remove 60 Hz power-line noise and its 120 Hz harmonic from the EEG
        # channels. The filter runs in place; raw is local, so no copy is needed.
        eeg_picks = mne.pick_types(raw.info, meg=False, eeg=True)
        raw.notch_filter(freqs=(60, 120), picks=eeg_picks)
        # Zero-phase FIR high-pass (firwin design, Hamming window) with a
        # 30 Hz cutoff, to emphasise gamma-band activity.
        raw.filter(l_freq=30, h_freq=None, fir_design='firwin',
                   filter_length='auto', phase='zero', fir_window='hamming')
        return torch.from_numpy(raw.get_data()), raw.info['sfreq']

# Map each EDF recording to a list of (start, end, label) annotations. Each
# annotation marks a span [start, end) of that recording together with its
# seizure-state label (e.g. "interictal"); CHBData interprets the units.
CHB_files = {
    'CHB-MIT/chb01_01.edf': [(0, 3600, "interictal")],
    'CHB-MIT/chb01_02.edf': [(0, 3600, "interictal")],
}
dataset = CHBData(CHB_files, segment_length=2)
data_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)
