In [110]:
from torch.utils.data import Dataset
import pyedflib
import numpy as np
from scipy.signal import spectrogram, welch
from xgboost import XGBClassifier, plot_tree


In [115]:
class ChbDataset(Dataset):
    """Windowed spectral features for the recordings of a CHB-MIT style EEG database.

    One dataset item corresponds to one EDF recording listed in the database's
    record index. The recording is cut into fixed-length windows, a Welch power
    spectrum is computed for every channel of every window, and a window is
    labelled ``True`` when it overlaps a seizure annotated in the subject's
    ``<subject>-summary.txt`` file.
    """

    def __init__(self, data_dir='./chb-mit-scalp-eeg-database-1.0.0/', seizures_only=True,
                 sample_rate=256, window_seconds=2):
        """Read the record index files.

        Args:
            data_dir: Root directory holding ``RECORDS``, ``RECORDS-WITH-SEIZURES``
                and the per-subject folders. A trailing separator is optional.
            seizures_only: When true, iterate only over the records listed in
                ``RECORDS-WITH-SEIZURES``; otherwise over everything in ``RECORDS``.
            sample_rate: Sampling frequency in Hz used to convert seconds to samples
                and passed to ``welch``. It is not read from the EDF header.
            window_seconds: Length of one analysis window, in seconds.
        """
        self.sample_rate = sample_rate
        self.window_seconds = window_seconds
        # Paths are built by string concatenation, so guarantee a trailing separator.
        if data_dir and not data_dir.endswith(('/', '\\')):
            data_dir = data_dir + '/'
        self.data_dir = data_dir
        self.record_type = 'RECORDS-WITH-SEIZURES' if seizures_only else 'RECORDS'

        self.records = self._read_index(self.data_dir + self.record_type)
        self.labelled = self._read_index(self.data_dir + 'RECORDS-WITH-SEIZURES')

    def __len__(self):
        """Return the number of recordings in the selected index."""
        return len(self.records)

    def __getitem__(self, index):
        """Load one recording and turn it into per-window features and labels.

        Args:
            index: Position of the recording in ``self.records``.

        Returns:
            A tuple ``(x, labels)``. ``x`` is a float array of shape
            ``(n_windows, n_channels * n_frequencies)`` and ``labels`` is a boolean
            array of shape ``(n_windows,)``. Samples after the last complete
            window are dropped.

        Raises:
            ValueError: If the recording is shorter than one window.
        """
        file_name = self.records[index]
        signals = self._read_signals(self.data_dir + file_name)

        # Per-sample seizure mask. It is local to this call, so a recording
        # without seizures can never inherit the bounds of a previous item.
        mask = np.zeros(signals.shape[1], dtype=bool)
        if file_name in self.labelled:
            subject, _, edf_name = file_name.partition('/')
            summary_path = self.data_dir + subject + '/' + subject + '-summary.txt'
            with open(summary_path) as summary:
                intervals = self._parse_seizure_intervals(summary, edf_name)
            for start_s, end_s in intervals:
                mask[int(self.sample_rate * start_s):int(self.sample_rate * end_s)] = True

        return self._window_features(signals, mask)

    @staticmethod
    def _read_index(path):
        """Return the non-empty, stripped lines of a record index file."""
        with open(path) as f:
            return [line.strip() for line in f if line.strip()]

    @staticmethod
    def _read_signals(path):
        """Read every channel of an EDF file into a (n_channels, n_samples) array."""
        reader = pyedflib.EdfReader(path)
        try:
            n_channels = reader.signals_in_file
            # The length of channel 0 is used for all channels, as in the
            # original implementation.
            signals = np.zeros((n_channels, reader.getNSamples()[0]))
            for channel in range(n_channels):
                signals[channel, :] = reader.readSignal(channel)
        finally:
            # Always release the file handle, even if reading fails.
            reader.close()
        return signals

    @staticmethod
    def _parse_seizure_intervals(lines, edf_name):
        """Extract all (start, end) seizure times, in seconds, for one EDF file.

        Scans the section that starts at the ``File Name: <edf_name>`` line and
        ends at the next ``File Name:`` line. Accepts both
        ``Seizure Start Time: N seconds`` and ``Seizure K Start Time: N seconds``.
        """
        intervals = []
        in_record = False
        start = None
        for line in lines:
            text = line.strip()
            if text.startswith('File Name:'):
                in_record = text.split(':', 1)[1].strip() == edf_name
                start = None
                continue
            if not in_record or not text.startswith('Seizure'):
                continue
            head, _, value = text.partition(':')
            if not value.strip():
                continue
            seconds = int(value.split()[0])
            if 'Start' in head:
                start = seconds
            elif 'End' in head and start is not None:
                intervals.append((start, seconds))
                start = None
        return intervals

    def _window_features(self, signals, mask):
        """Cut signals into windows and compute normalised log Welch spectra."""
        window = int(round(self.window_seconds * self.sample_rate))
        if window <= 0:
            raise ValueError('window_seconds * sample_rate must be at least one sample')
        n_channels, n_samples = signals.shape
        n_windows = n_samples // window
        if n_windows == 0:
            raise ValueError('recording is shorter than one analysis window')
        usable = n_windows * window

        # (channels, samples) -> (windows, channels, samples per window)
        windows = signals[:, :usable].reshape(n_channels, n_windows, window).transpose(1, 0, 2)
        labels = mask[:usable].reshape(n_windows, window).any(axis=1)

        _, psd = welch(windows, fs=self.sample_rate, axis=-1)
        spectra = np.log1p(psd)
        # Scale each window by its own maximum; skip all-zero windows to avoid 0/0.
        peak = spectra.max(axis=(1, 2), keepdims=True)
        peak[peak == 0] = 1.0
        features = (spectra / peak).reshape(n_windows, -1)

        return features, labels

In [122]:
dataset = ChbDataset()

# Recording 0 is used to fit the model, recording 1 is kept for evaluation.
train, test = dataset[0], dataset[1]

# Split the training item into its feature matrix and per-window labels.
x, labels = np.array(train[0]), train[1]

In [125]:
# Collapse every trailing axis into a single feature axis. Using -1 keeps this
# cell safe to re-run: an already two-dimensional x is left unchanged instead
# of raising IndexError on a missing third axis.
x = x.reshape((x.shape[0], -1))

In [127]:
# Boosted depth-1 trees trained with the hinge objective.
model = XGBClassifier(
    objective='binary:hinge',
    learning_rate=0.1,
    max_depth=1,
    n_estimators=330,
)
model.fit(x, train[1])

# Fraction of windows in the evaluation recording whose prediction matches the label.
preds = model.predict(test[0])
print(np.mean(preds == test[1]))

XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
              importance_type='gain', interaction_constraints='',
              learning_rate=0.1, max_delta_step=0, max_depth=1,
              min_child_weight=1, missing=nan, monotone_constraints='()',
              n_estimators=330, n_jobs=8, num_parallel_tree=1,
              objective='binary:hinge', random_state=0, reg_alpha=0,
              reg_lambda=1, scale_pos_weight=None, subsample=1,
              tree_method='exact', validate_parameters=1, verbosity=None)