In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
from torch.utils.data import DataLoader
%matplotlib inline
from matplotlib import pyplot as plt
import importlib as imp
import torch
import torch.optim as optim
import torch.nn as nn
import time
from torchvision import transforms
from torch.utils.data import Dataset
import torch.nn.functional as F
import random
import rasterio
from rasterio.windows import Window
from sklearn.utils import shuffle

import albumentations as A

from scipy.fftpack import fft, ifft
from scipy import signal

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
import sys
sys.path.append('../input/hubmap-util')

%config Completer.use_jedi = False

In [2]:
# --- Notebook configuration -------------------------------------------------
# Input files (MATLAB .mat): raw EEG trials and the pre-computed wavelet features.
EEG_PATH = '../input/eegdata/data_EEG_AI.mat'
CWT_PATH = '../input/cwtdata/cwt_3c_data.mat'
# Seed handed to seed_everything().
SEED = 344
# Sampling rate in Hz: 801 points over 3200 ms -> 4 ms per sample -> 250 Hz.
Fs = 250
# Time axis in milliseconds, one entry per sample: 0, 4, ..., 3200.
times = list(range(0, 4 * 801, 4))

In [3]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), exports ``PYTHONHASHSEED`` and switches cuDNN to
    deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects subprocesses started afterwards; the running interpreter's
    # hash seed is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True # Fix the network according to random seed
    # The cuDNN auto-tuner may pick different kernels from run to run.
    torch.backends.cudnn.benchmark = False
    print('Finish seeding with seed {}'.format(seed))
    
# Fix all RNGs before anything stochastic runs.
seed_everything(SEED)
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Training on device {}'.format(device))

In [4]:
def get_letter(label):
    """Return the upper-case letter for a zero-based class index (0 -> 'A')."""
    first_letter_code = ord('A')
    return chr(first_letter_code + label)

# Data preparation

In [5]:
from scipy.io import loadmat

# Read both MATLAB files: the raw EEG recordings first, then the pre-computed
# wavelet features.
data_mat, cwt_data_mat = (loadmat(mat_path) for mat_path in (EEG_PATH, CWT_PATH))

In [10]:
# Keep only the 'cwtdata' array from the loaded .mat dictionary ...
cwt_data = cwt_data_mat['cwtdata']
# ... and drop the dictionary itself so only one reference to the array remains.
del cwt_data_mat

In [13]:
# Labels are stored 1-based in the .mat file; shift them so that 0 -> 'A'.
labels = data_mat['label'].squeeze(-1) - 1
# Move the last axis of the stored array to the front.
# NOTE(review): presumably this yields (trial, channel, time) - confirm against the .mat layout.
data = data_mat['data'].transpose(2, 0, 1)
del data_mat
labels

# E and I class

# Binary subset: all 'E' trials (class 4) followed by all 'I' trials (class 8).
ei_idx = np.concatenate([np.flatnonzero(labels == 4), np.flatnonzero(labels == 8)], axis=0)
ei_data = data[ei_idx]
ei_labels = labels[ei_idx]  # fancy indexing copies, so `labels` stays untouched below
# Re-code the two classes as 0 ('E') and 1 ('I').
ei_labels[ei_labels == 4] = 0
ei_labels[ei_labels == 8] = 1

# FFT

# Amplitude spectrum of every trace. `fft` works along the last axis by
# default, so one vectorised call replaces the per-curve Python loops.
# Normalise by the real number of samples per trace (801 here, not 800).
n_samples = data.shape[-1]
fft_data = np.abs(fft(data, axis=-1)) / n_samples
fft_data.shape

# Sample spacing is 4 ms (Fs = 250 Hz), so bin k sits at k * Fs / n_samples Hz.
# Keep only the bins below the Nyquist frequency.
n_bins = n_samples // 2
fft_data = fft_data[:, :, :n_bins]
fft_data.shape

freqs_hz = np.arange(n_bins) * Fs / n_samples
fig = plt.figure(figsize=(15, 4 * 26))
for i in range(0, 26):
    # One panel per sampled trial (every 100th trial), titled with its letter.
    ax = plt.subplot(26, 1, i + 1)
    ax.set_title(get_letter(labels[100 * i + 1]))
    for curve in fft_data[100 * i + 1]:
        plt.plot(freqs_hz, curve)

# Band filters (Butterworth low-pass at the theta / alpha / beta upper edges)

# Zero-phase Butterworth low-pass filters with cut-offs at 8 Hz (theta),
# 20 Hz (beta) and 13 Hz (alpha). Cut-offs are normalised to the Nyquist
# frequency (Fs / 2), hence the `2 * f / Fs`.
#
# Filters of order 8-16 are numerically fragile in transfer-function (b, a)
# form, so they are designed as second-order sections and applied with
# sosfiltfilt. The filter runs along the last (time) axis for all traces at
# once instead of looping over curves in Python.
theta_sos = signal.butter(8, 2 * 8 / Fs, 'lowpass', output='sos')
theta_data = signal.sosfiltfilt(theta_sos, data, axis=-1)

beta_sos = signal.butter(12, 2 * 20 / Fs, 'lowpass', output='sos')
beta_data = signal.sosfiltfilt(beta_sos, data, axis=-1)

alpha_sos = signal.butter(16, 2 * 13 / Fs, 'lowpass', output='sos')
alpha_data = signal.sosfiltfilt(alpha_sos, data, axis=-1)

In [None]:
# Stack the three filtered copies along axis 1 (theta, then alpha, then beta)
# and rescale by a fixed factor of 40.
filtered_data = np.concatenate([theta_data, alpha_data, beta_data], axis=1) / 40
del theta_data, alpha_data, beta_data
filtered_data.shape

import gc
gc.collect()

# Filtered trace (index 4 of the stacked axis, i.e. a theta-filtered channel)
# against the equally scaled raw trace of trial 305.
fig = plt.figure(figsize=(15, 4 * 2))

ax = plt.subplot(2, 1, 2)
ax.set_title(get_letter(labels[305]))
ax.plot(filtered_data[305][4])
ax.plot(data[305][4] / 40)

# The same filtered channel for four trials spaced 150 apart.
fig = plt.figure(figsize=(15, 4 * 4))
for i in range(0, 4):
    trial = 150 * i + 305
    ax = plt.subplot(4, 1, i + 1)
    ax.set_title(get_letter(labels[trial]))
    ax.plot(filtered_data[trial][4])

# Wavelet

In [8]:
import pywt

In [9]:
# Walk through the continuous wavelet transform on one trace (trial 21, channel 22).
w_data = data[21][22]

wavename = 'cgau8'
totalscal = 256
fc = pywt.central_frequency(wavename)
cparam = 2 * fc * totalscal
# Scales run from cparam/256 up to cparam/2.
scales = cparam / np.arange(totalscal, 1, -1)
cwtmatr, frequencies = pywt.cwt(w_data, scales, wavename, 1.0 / Fs)
magnitude = abs(cwtmatr)

# Rows of the scalogram below 20 Hz.
idx = np.where(frequencies < 20)

plt.figure(figsize=(15, 8))

# Panel 1: the raw trace.
ax = plt.subplot(3, 1, 1)
ax.plot(times, w_data)

# Rows of the scalogram belonging to each EEG band (Hz).
beta_idx = np.where((frequencies >= 15) & (frequencies <= 20))
alpha_idx = np.where((frequencies >= 8) & (frequencies <= 13))
theta_idx = np.where((frequencies >= 4) & (frequencies <= 8))
for band_rows in (beta_idx, alpha_idx, theta_idx):
    print(len(band_rows[0]))

# Per-band envelope: the largest coefficient magnitude at each time step.
beta_data = magnitude[beta_idx].max(axis=0)
alpha_data = magnitude[alpha_idx].max(axis=0)
theta_data = magnitude[theta_idx].max(axis=0)

# Panel 2: scalogram restricted to < 20 Hz.
ax = plt.subplot(3, 1, 2)
ax.contourf(times, frequencies[idx], magnitude[idx])

plt.subplots_adjust(hspace=0.4)

# Panel 3: the three band envelopes.
ax = plt.subplot(3, 1, 3)
for envelope in (beta_data, alpha_data, theta_data):
    ax.plot(times, envelope)

plt.show()

In [10]:
import pywt

def do_cwt(w_data):
    """Reduce one trace to three wavelet band envelopes.

    Runs a continuous wavelet transform ('cgau8', 255 scales, sampling period
    ``1 / Fs``) and, for each band, keeps the largest coefficient magnitude
    across that band's scales at every time step.

    Args:
        w_data: 1-D sequence of samples for a single channel.

    Returns:
        Array of shape ``(3, len(w_data))`` ordered beta (15-20 Hz),
        alpha (8-13 Hz), theta (4-8 Hz), rescaled as ``x / 20 - 10``.
    """
    wavename = 'cgau8'
    totalscal = 256
    fc = pywt.central_frequency(wavename)
    cparam = 2 * fc * totalscal
    scales = cparam / np.arange(totalscal, 1, -1)
    cwtmatr, frequencies = pywt.cwt(w_data, scales, wavename, 1.0 / Fs)
    magnitude = np.abs(cwtmatr)

    # (low, high) edges in Hz, both inclusive; order fixes the output rows.
    bands = ((15, 20), (8, 13), (4, 8))  # beta, alpha, theta
    envelopes = [
        magnitude[(frequencies >= low) & (frequencies <= high)].max(axis=0)
        for low, high in bands
    ]

    # NOTE(review): fixed rescaling constants - presumably chosen empirically
    # to bring the features into a convenient range; confirm.
    return np.stack(envelopes) / 20 - 10

In [11]:
# Smoke-test do_cwt on a single trace. The result gets its own name because
# `cwt_data` already holds the full pre-computed feature array loaded from
# CWT_PATH; overwriting it with this (3, T) sample would break the later
# np.concatenate([cwt_data, data], axis=1) on a top-to-bottom run.
cwt_sample = do_cwt(data[0][0])
cwt_sample.shape

In [13]:
# One-off cache generation, disabled by wrapping it in a string literal so the
# cell evaluates to a no-op. When enabled it runs do_cwt on every channel of
# every trial and writes the result to ./cwtdata.mat under the key 'cwtdata'.
'''
from scipy.io import savemat
cwt_data = np.array([[do_cwt(curve)for curve in data[i]] for i in range(0, len(data))])
savemat('./cwtdata.mat',{'cwtdata': cwt_data})
'''
# NOTE(review): "pickle" - presumably a reminder to consider pickle as the cache format; confirm or remove.

In [21]:
#cwt3c_data = np.array([np.concatenate([curve for curve in cwt_data[i]], axis = 0) for i in range(0, len(data))])

In [23]:
#cwt3c_data.shape
#savemat('./cwt_3c_data.mat',{'cwtdata': cwt3c_data})

# Train / test split

In [14]:
# Join the wavelet features and the raw traces along axis 1, then free the parts.
t_data = np.concatenate((cwt_data, data), axis=1)
del cwt_data, data

t_labels = labels

from sklearn.model_selection import train_test_split
# Split trial indices (not the arrays themselves): 85 % train, 15 % test.
index = list(range(len(t_labels)))
train_idx, test_idx = train_test_split(index, test_size=0.15, random_state=66)
# Each split is a dict of two parallel Python lists.
eegtrain = dict(data=[t_data[k] for k in train_idx],
                labels=[t_labels[k] for k in train_idx])
eegtest = dict(data=[t_data[k] for k in test_idx],
               labels=[t_labels[k] for k in test_idx])

In [15]:
# How many training trials carry label 0 ('A')?
letter_idx = np.nonzero(np.asarray(eegtrain['labels']) == 0)
len(letter_idx[0])

# Average

def _average_within_class(samples, sample_labels, n_mix, n_classes=26):
    """Build synthetic trials by averaging random same-class trials.

    For every class, as many new trials are produced as the class has
    originals; each one is the mean of ``n_mix`` trials drawn with replacement
    from that class via ``np.random.randint``.

    Args:
        samples: Sequence of equally shaped arrays, one per trial.
        sample_labels: Sequence of class indices parallel to ``samples``.
        n_mix: Number of trials averaged into each synthetic trial.
        n_classes: Classes ``0 .. n_classes - 1`` are processed in order.

    Returns:
        ``(data, labels)`` as NumPy arrays, grouped by class.
    """
    sample_labels = np.asarray(sample_labels)
    mixed_data, mixed_labels = [], []
    for cls in range(n_classes):
        members = np.flatnonzero(sample_labels == cls)
        if members.size == 0:
            # np.random.randint(0, ...) raises; a class absent from this split
            # simply contributes nothing.
            continue
        class_data = np.array([samples[m] for m in members])
        for _ in range(members.size):
            picks = np.random.randint(members.size, size=n_mix)
            mixed_data.append(class_data[picks].mean(axis=0))
        mixed_labels.extend(sample_labels[members])
    return np.array(mixed_data), np.array(mixed_labels)


# Training set: average 3 random same-letter trials per synthetic trial.
a_data, a_labels = _average_within_class(eegtrain['data'], eegtrain['labels'], n_mix=3)
a_labels

# One averaged trial per panel (every 100th), all channels overlaid.
fig = plt.figure(figsize=(15, 4 * 26))
for i in range(0, 26):
    ax = plt.subplot(26, 1, i + 1)
    ax.set_title(get_letter(a_labels[100 * i + 1]))
    for curve in a_data[100 * i + 1]:
        plt.plot(curve)

# Test set: same procedure, averaging 2 trials.
a_test_data, a_test_labels = _average_within_class(eegtest['data'], eegtest['labels'], n_mix=2)
a_test_labels.shape

len(train_idx), len(test_idx), len(eegtrain['data'])

# Dataset

In [16]:
class EEGdataset(Dataset):
    """Map-style dataset pairing NumPy EEG samples with their labels.

    Args:
        eeg: Indexable collection of NumPy arrays, one per sample.
        label: Indexable collection of labels parallel to ``eeg``.
        device: Device each sample tensor is moved to when it is fetched.
    """

    def __init__(self, eeg, label, device):
        self.eeg, self.label, self.device = eeg, label, device

    def __len__(self):
        # The label collection defines the dataset size.
        return len(self.label)

    def __getitem__(self, idx):
        # from_numpy shares memory with the source array; .to() copies only
        # when the target device differs.
        sample = torch.from_numpy(self.eeg[idx]).to(self.device)
        target = self.label[idx]
        return sample, target

In [17]:
from torch.utils.data import Dataset, DataLoader, random_split


# Alternative input: the class-averaged trials.
#eegtrainset = EEGdataset(a_data, a_labels, device)
eegtrainset = EEGdataset(eegtrain['data'], eegtrain['labels'], device)

batch_size = 8

# 80 / 20 split of the training portion into train and validation subsets.
n_total = len(eegtrainset)
NUM_TRAIN = int(0.8 * n_total)
NUM_VAL = n_total - NUM_TRAIN

eeg_train, eeg_valid = random_split(eegtrainset, lengths=[NUM_TRAIN, NUM_VAL])

train_loader = DataLoader(dataset=eeg_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=eeg_valid, batch_size=batch_size, shuffle=True)

In [21]:
# Pull one validation batch and plot every channel of its first sample.
eeg, label = next(iter(val_loader))
fig = plt.figure(figsize=(20, 5))
print(label)
print(eeg.shape)
# Labels were already shifted to zero-based when they were loaded, and
# get_letter maps 0 -> 'A', so no extra offset belongs here.
plt.title(get_letter(int(label[0])))
for curve in eeg.to('cpu')[0]:
    plt.plot(curve)

# CNN model

In [22]:
# Batch-norm settings read by NetCNN when it is instantiated.
momentum = 0.4
track_running_stats = True
affine = True
class NetCNN(nn.Module):
    """1-D CNN classifier: three conv/batch-norm/ReLU stages and a two-layer head.

    Input is ``(batch, channels_in, length)``. The flatten size is hard-coded
    to ``14 * 48``, so ``length`` must reduce to 14 steps after the strided
    convolutions and poolings (true for 801-sample traces).

    Args:
        channels_in: Number of input channels.
        channels_out: Number of output logits.
    """

    def __init__(self, channels_in=24, channels_out=26):
        super().__init__()

        def batch_norm(width):
            # Reads the module-level batch-norm settings at construction time.
            return nn.BatchNorm1d(width, momentum=momentum,
                                  track_running_stats=track_running_stats, affine=affine)

        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the Sequential indices.
        layers = [
            nn.Conv1d(channels_in, 48, kernel_size=5, stride=3, padding=0),
            batch_norm(48),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=2),
            nn.Conv1d(48, 96, kernel_size=5, stride=2, padding=0),
            batch_norm(96),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=2),
            nn.Conv1d(96, 48, kernel_size=5, stride=2, padding=0),
            batch_norm(48),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(14 * 48, 80),
            nn.ReLU(),
            nn.Linear(80, channels_out),
        ]
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, channels_out)``."""
        return self.seq(x)

# CNN2D Model

In [63]:
# Batch-norm settings read by NetCNN2D when it is instantiated.
momentum = 0.4
track_running_stats = True
affine = True
class NetCNN2D(nn.Module):
    """2-D CNN classifier that treats each sample as a one-channel image.

    Input is ``(batch, height, width)``; ``forward`` inserts the channel axis.
    The flatten size is hard-coded to 4704 (= 96 * 7 * 7), so the input must
    reduce to a 7 x 7 map after the two conv/pool stages (e.g. 128 x 128).

    Args:
        channels_out: Number of output logits. Defaults to 26, matching the
            previous hard-coded value and the sibling models' signature.
    """

    def __init__(self, channels_out=26):
        super().__init__()

        self.seq = nn.Sequential(
            nn.Conv2d(1, 48, 3, stride=2, padding=0),
            nn.BatchNorm2d(48, momentum=momentum, track_running_stats=track_running_stats, affine=affine),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(48, 96, 3, stride=2, padding=0),
            nn.BatchNorm2d(96, momentum=momentum, track_running_stats=track_running_stats, affine=affine),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),

            nn.Flatten(),
            nn.Linear(4704, 80),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(80, channels_out)
        )

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, channels_out)``."""
        # Conv2d expects (batch, channel, H, W); add the singleton channel axis.
        x = self.seq(x.unsqueeze(1))
        return x

# DNN Model

In [52]:
# Batch-norm settings read by NetDNN when it is instantiated.
momentum = 0.4
track_running_stats = True
affine = True
class NetDNN(nn.Module):
    """Fully connected classifier applied to a 4x max-pooled, flattened input.

    Input is ``(batch, channels_in, length)``. The first linear layer expects
    ``channels_in * 200`` features, i.e. 200 time steps after pooling (true
    for 801-sample traces).

    Args:
        channels_in: Number of input channels.
        channels_out: Number of output logits.
    """

    def __init__(self, channels_in=24, channels_out=26):
        super().__init__()

        def batch_norm(width):
            # Reads the module-level batch-norm settings at construction time.
            return nn.BatchNorm1d(width, momentum=momentum,
                                  track_running_stats=track_running_stats, affine=affine)

        # Layers are created in forward order so Sequential indices and
        # parameter initialisation order are unchanged.
        layers = [
            nn.MaxPool1d(kernel_size=4, stride=4),
            nn.Flatten(),
            nn.Linear(channels_in * 200, 1024),
            batch_norm(1024),
            nn.ReLU(),
            nn.Linear(1024, 128),
            batch_norm(128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, channels_out),
        ]
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, channels_out)``."""
        return self.seq(x)

In [23]:
# Give the model the same dtype as the samples so no cast is needed per batch.
data_type = eeg_train[0][0].dtype
net = NetCNN(channels_in=96, channels_out=26)
net = net.to(device, data_type)
print(net)
# Sanity check: logits for the second sample of the preview batch.
net(eeg)[1]

# Train

In [24]:
def TrainClassifer(model,trn_dl,val_dl,optimizer, scheduler=None,
                   n_eopchs=20, device='cpu'):
    """Train ``model`` with cross-entropy loss and checkpoint the best epochs.

    After every epoch the model is evaluated on ``val_dl``. The weights with
    the lowest mean validation loss are written to ``./best.pt`` and the
    weights with the highest validation accuracy to ``./acc.pt`` once training
    finishes. ``model`` itself is left holding the final-epoch weights.

    Args:
        model: ``nn.Module`` producing class logits.
        trn_dl: Iterable of ``(inputs, labels)`` training batches.
        val_dl: Iterable of ``(inputs, labels)`` validation batches.
        optimizer: Optimizer already bound to ``model``'s parameters.
        scheduler: Optional LR scheduler, stepped once per epoch.
        n_eopchs: Number of epochs to run.
        device: Device the model and every batch are moved to.
    """
    def _snapshot(module):
        # state_dict() hands back references to the live parameters; clone
        # them so later optimizer steps cannot overwrite the checkpoint.
        return {name: value.detach().clone()
                for name, value in module.state_dict().items()}

    loss_fn = nn.CrossEntropyLoss()
    model.to(device)
    best_state = None
    acc_state = None
    best_val = 999.0
    best_acc = 0.0

    for epoch in range(1, n_eopchs + 1):
        loss_train = 0.0
        model.train()
        for imgs, labels in trn_dl:
            imgs = imgs.to(device)
            labels = labels.to(device)

            outputs = model(imgs)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_train += loss.item()

        loss_val = 0.0
        correct_val = 0
        seen_val = 0
        model.eval()
        with torch.no_grad():
            for imgs, labels in val_dl:
                imgs = imgs.to(device)
                labels = labels.to(device)
                outputs = model(imgs)
                loss_val += loss_fn(outputs, labels).item()
                correct_val += (torch.argmax(outputs, dim=1) == labels).sum().item()
                # Count real samples: the last batch may be smaller than batch_size.
                seen_val += labels.size(0)

        mean_train = loss_train / max(len(trn_dl), 1)
        mean_val = loss_val / max(len(val_dl), 1)
        val_acc = 100 * correct_val / max(seen_val, 1)

        if mean_val < best_val:
            best_val = mean_val
            best_state = _snapshot(model)
        if val_acc > best_acc:
            best_acc = val_acc
            acc_state = _snapshot(model)
        if scheduler is not None:
            scheduler.step()

        print('{} Eopch {}, Training Loss {}, Val Loss {}, Val Accuracy {}'.format(time.strftime("%Y-%m-%d %H:%M:%S",time.localtime()),
                                                                  epoch, mean_train, mean_val,
                                                                                   val_acc))
    # Fall back to the final weights if a criterion never improved.
    torch.save(best_state if best_state is not None else _snapshot(model), './best.pt')
    torch.save(acc_state if acc_state is not None else _snapshot(model), './acc.pt')
    print('Finish training: best_val:{} best_acc:{}'.format(best_val, best_acc))

In [25]:
# Plain SGD (no momentum) with a fixed learning rate and no scheduler.
optimizer = optim.SGD(params=net.parameters(), lr=2e-3)
TrainClassifer(
    model=net,
    trn_dl=train_loader,
    val_dl=val_loader,
    optimizer=optimizer,
    scheduler=None,
    n_eopchs=100,
    device=device,
)

In [26]:
# Logits for the second sample of the preview batch, now with the trained
# weights (compare with the same call made right after the model was created).
net(eeg)[1]

In [27]:
# Alternative input: the class-averaged test trials.
#eegtestset = EEGdataset(a_test_data,a_test_labels, device)
eegtestset = EEGdataset(eegtest['data'], eegtest['labels'], device)
test_loader = DataLoader(dataset=eegtestset, shuffle=True, batch_size=batch_size)

In [28]:
# Top-1 accuracy (percent) on the held-out test set.
correct_test = 0
net.eval()
with torch.no_grad():
    # `targets` rather than `labels`: a loop variable at notebook level would
    # overwrite the global label array used by earlier cells.
    for imgs, targets in test_loader:
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = net(imgs)
        correct_test += (torch.argmax(outputs, dim=1) == targets).sum().item()
# Divide by the real number of samples; len(loader) * batch_size overcounts
# whenever the last batch is only partly filled.
print(100 * correct_test / len(test_loader.dataset))
