In [1]:
from torch.utils.data import Dataset
import pyedflib
import numpy as np

In [2]:
class ChbDataset(Dataset):
    """Map-style dataset yielding one whole EDF recording per item.

    The data directory is expected to contain the index files ``RECORDS`` and
    ``RECORDS-WITH-SEIZURES`` (one relative path such as
    ``chb01/chb01_03.edf`` per line) and, for every subject folder, a
    ``<subject>-summary.txt`` file describing the seizures of each recording.

    Attributes:
        sample_rate: Samples per second, used to turn the seizure times read
            from the summary file into sample indices.
        data_dir: Root directory; paths are built by string concatenation, so
            it must end with a path separator.
        record_type: Name of the index file that defines the dataset items.
        records: Relative paths of the recordings served by this dataset.
        labelled: Relative paths of the recordings that contain seizures.
        seizure_start, seizure_end: Bounds of the first seizure of the record
            loaded most recently by ``__getitem__``, or ``None`` when that
            record has no seizure (or nothing has been loaded yet).
    """

    def __init__(self, data_dir='./chb-mit-scalp-eeg-database-1.0.0/',seizures_only=True,sample_rate=256):
        """Read the record index files found in ``data_dir``.

        Args:
            data_dir: Dataset root, including the trailing path separator.
            seizures_only: When true, serve only the recordings listed in
                ``RECORDS-WITH-SEIZURES``; otherwise serve all of ``RECORDS``.
            sample_rate: Sampling rate of the recordings in Hz.
        """
        self.sample_rate = sample_rate
        self.data_dir = data_dir
        self.record_type = 'RECORDS-WITH-SEIZURES' if seizures_only else 'RECORDS'

        self.records = self._read_record_list(self.record_type)
        self.labelled = self._read_record_list('RECORDS-WITH-SEIZURES')

        # Defined up front so reading them never raises AttributeError.
        self.seizure_start = None
        self.seizure_end = None

    def __len__(self):
        """Return the number of recordings in the dataset."""
        return len(self.records)

    def __getitem__(self, index):
        """Load one recording together with its per-sample seizure mask.

        Args:
            index: Position of the recording in ``self.records``.

        Returns:
            A tuple ``(sigbufs, signal_labels, labels)`` where ``sigbufs`` is
            a ``(n_signals, n_samples)`` float array, ``signal_labels`` is
            the list of channel names reported by the reader and ``labels``
            is a ``(1, n_samples)`` array holding 1.0 inside every seizure
            interval and 0.0 elsewhere.

        Raises:
            ValueError: If the record is listed as containing seizures but
                its summary file has no seizure interval for it.
        """
        file_name = self.records[index]

        reader = pyedflib.EdfReader(self.data_dir + file_name)
        try:
            n_signals = reader.signals_in_file
            signal_labels = reader.getSignalLabels()
            # The first channel's length is used for every channel, as in the
            # original code; readSignal fails on a channel of another length.
            n_samples = reader.getNSamples()[0]
            sigbufs = np.zeros((n_signals, n_samples))
            for channel in range(n_signals):
                sigbufs[channel, :] = reader.readSignal(channel)
        finally:
            # Release the file handle even if reading a channel fails.
            reader.close()

        intervals = []
        if file_name in self.labelled:
            intervals = self._read_seizure_intervals(file_name)
            if not intervals:
                raise ValueError(
                    'No seizure interval found for ' + file_name + ' in its summary file'
                )

        labels = np.zeros((1, n_samples))
        for start_s, end_s in intervals:
            labels[:, self.sample_rate * start_s:self.sample_rate * end_s] = 1.0

        # Kept for callers that inspect the bounds after loading an item;
        # reset for seizure-free records so stale values never leak through.
        self.seizure_start, self.seizure_end = intervals[0] if intervals else (None, None)

        return sigbufs, signal_labels, labels

    def _read_record_list(self, list_name):
        """Return the non-empty, stripped lines of an index file in ``data_dir``."""
        with open(self.data_dir + list_name) as listing:
            return [line.strip() for line in listing if line.strip()]

    def _read_seizure_intervals(self, file_name):
        """Return ``[(start, end), ...]`` for ``file_name`` from its subject summary.

        Lines are matched by their ``key: value`` structure instead of by
        position, so both ``Seizure Start Time: N seconds`` and numbered
        ``Seizure 1 Start Time: N seconds`` entries are understood, and any
        number of seizures per recording is supported.
        """
        subject, _, edf_name = file_name.partition('/')
        summary_path = self.data_dir + subject + '/' + subject + '-summary.txt'

        intervals = []
        in_record = False
        pending_start = None
        with open(summary_path) as summary:
            for line in summary:
                key, _, value = line.partition(':')
                key = key.strip()
                if key == 'File Name':
                    if in_record:
                        # The next recording's section begins: we are done.
                        break
                    in_record = value.strip() == edf_name
                    continue
                if not in_record or not key.startswith('Seizure'):
                    continue
                if key.endswith('Start Time'):
                    pending_start = int(value.split()[0])
                elif key.endswith('End Time') and pending_start is not None:
                    intervals.append((pending_start, int(value.split()[0])))
                    pending_start = None
        return intervals
        

In [3]:
# Build the dataset with its default arguments and fetch the first listed
# recording through normal indexing.
dataset = ChbDataset()
item = dataset[0]

In [4]:
# Seizure bounds the dataset stored while loading the item above.
(dataset.seizure_start, dataset.seizure_end)

(2996, 3036)

In [5]:
# Third element of the item tuple: the per-sample label array. The printed
# repr elides the middle of the array, so samples set to 1.0 inside the
# seizure range are not visible in the output.
item[2]

array([[0., 0., 0., ..., 0., 0., 0.]])

In [None]:
#from matplotlib.collections import LineCollection
#import matplotlib.pyplot as plt

#plt.plot(np.arange(len(item[2])),item[2])