# imports

In [50]:
import os
from functools import partial

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F

# Plotting
import matplotlib.pyplot as plt
import seaborn as sns

# Project-local modules
import classifier
import features
import utils
# constants
FILENAME = 'data/cleaned_data.csv'
# Alternative input: 'data/data.csv' (the encoding cell drops its extra 'iter' column).
SEED = 1
FREQ = 128       # presumably the EEG sampling rate in Hz — TODO confirm
CHUNK_SIZE = 5   # the train/test split uses chunks of CHUNK_SIZE * FREQ rows

# Seed the global RNGs as well: SEED is otherwise only passed to the train/test
# split, so model weight initialisation (and any unseeded shuffling/dropout)
# would differ from run to run.
np.random.seed(SEED)
torch.manual_seed(SEED)

# loading file
data = pd.read_csv(FILENAME)

print(data.shape)

(128000, 15)


# Encoding and preparing

In [51]:
# The alternative export 'data/data.csv' carries an extra 'iter' column that the
# cleaned file lacks; drop it. Reassign instead of mutating in place so the
# frame's lineage stays explicit.
if FILENAME == 'data/data.csv':
    data = data.drop(columns='iter')

# Stable sort (mergesort) groups rows by class while keeping each class's rows
# in their original file order.
data = data.sort_values('class', kind='mergesort')

# Encode the class labels; encode_dict is presumably the label -> code mapping
# returned by utils.encode_column (confirm in utils).
# NOTE(review): this cell is not idempotent — re-running it re-encodes an
# already-encoded column. Re-run the loading cell first.
data['class'], encode_dict = utils.encode_column(data['class'])

data.head()

Unnamed: 0,class,F3,FC5,AF3,F7,T7,P7,O1,O2,P8,T8,F8,AF4,FC6,F4
0,0,4195.25641,3400.512819,3465.641024,3137.05128,4166.794872,4184.230769,4176.538461,4213.846154,4198.333333,4181.794872,3876.538461,3250.769229,3253.333331,4174.487179
1,0,4195.897436,3405.512819,3474.358973,3141.794869,4168.717949,4188.717949,4173.717949,4206.666667,4187.564102,4176.666667,3891.794871,3302.05128,3268.205126,4182.307692
2,0,4194.358974,3410.512819,3434.230768,3133.717946,4169.615384,4189.74359,4178.461538,4198.333333,4162.051282,4173.076923,3881.02564,3258.076921,3256.282049,4168.461538
3,0,4182.948718,3406.282049,3445.128203,3132.82051,4168.076923,4191.794872,4181.666667,4204.871795,4170.384615,4176.794872,3865.384615,3228.333331,3239.743588,4161.410256
4,0,4184.102564,3405.512819,3482.435896,3133.461536,4163.717949,4195.128205,4184.74359,4216.538461,4190.769231,4180.641026,3876.666666,3281.666665,3251.410254,4171.410256


In [52]:
# Hold out 30% of the recording for testing. The split is given chunks of
# CHUNK_SIZE * FREQ rows (presumably CHUNK_SIZE seconds of signal at FREQ Hz —
# confirm in utils.eeg_train_test_split) and is seeded with SEED.
arr = data.values

train, test = utils.eeg_train_test_split(
    arr,
    chunk_size=CHUNK_SIZE * FREQ,
    test_size=0.3,
    random_state=SEED,
)
print(train.shape, test.shape, sep='\n')

# Turn each split into model inputs: raw windows, their FFT features and labels.
# shift=64 is only passed for the training split (presumably a window stride —
# TODO confirm in utils.prepare_train).
X_train, X_train_fft, y_train = utils.prepare_train(train, shift=64, save_path='data/')
X_test, X_test_fft, y_test = utils.prepare_data(test, save_path='data/')

print(*(a.shape for a in (X_train, X_train_fft, y_train)))
print(*(a.shape for a in (X_test, X_test_fft, y_test)))

(89600, 15)
(38400, 15)
(1384, 14, 128) (1384, 14, 65) (1384,)
(584, 14, 128) (584, 14, 65) (584,)


In [53]:
# (train, test) pair of datasets, each pairing (raw, fft) inputs with labels.
datasets = (
    utils.create_dataset((X_train, X_train_fft), y_train),
    utils.create_dataset((X_test, X_test_fft), y_test),
)

In [54]:
# Use the first CUDA GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda:0


In [None]:
# Input sizes for the two model branches: axis 1 of the raw and FFT arrays
# (14 in both, per the printed shapes).
raw_feat = X_train.shape[1]
fft_feat = X_train_fft.shape[1]

trn_dl, val_dl = utils.create_loaders(datasets, bs=128)

trn_sz = len(y_train)

lr = 0.003
n_epochs = 3000
iterations_per_epoch = len(trn_dl)
period = n_epochs * iterations_per_epoch  # total optimizer steps; one_cycle's t_max
num_classes = 5  # NOTE(review): hard-coded; presumably equals len(encode_dict) — confirm
best_acc = 0
patience, trials = 500, 0  # stop after `patience` consecutive epochs with no new best accuracy
base = 1   # next epoch number at which a progress line is printed
step = 2   # `base` is multiplied by this after each print: epochs 1, 2, 4, 8, ...
iteration = 0       # global optimizer-step counter fed to the LR scheduler
loss_history = []   # mean per-sample *training* loss, one entry per epoch
acc_history = []    # accuracy on val_dl, one entry per epoch

model = classifier.Classifier(raw_feat, fft_feat, num_classes, drop = 0.5).to(device)
# reduction='sum' so that dividing the epoch total by trn_sz gives a per-sample
# mean (assuming trn_dl visits every training sample once per epoch).
criterion = nn.CrossEntropyLoss(reduction='sum')
opt = optim.Adam(model.parameters(), lr=lr)

sched = classifier.Scheduler(opt, partial(classifier.one_cycle, t_max=period, pivot=0.1))

# torch.save does not create missing directories, so make sure 'model/' exists.
os.makedirs('model', exist_ok=True)
best_model_path = os.path.join('model', 'best.pth')

print('Start model training')

for epoch in range(1, n_epochs + 1):
    # ---- training pass ----
    model.train()
    epoch_loss = 0
    for batch in trn_dl:
        iteration += 1
        x_raw, x_fft, y_batch = [t.to(device) for t in batch]
        sched.step(iteration)  # update the learning rate
        opt.zero_grad()
        out = model(x_raw, x_fft)
        loss = criterion(out, y_batch)
        epoch_loss += loss.item()
        loss.backward()
        opt.step()

    epoch_loss /= trn_sz
    loss_history.append(epoch_loss)

    # ---- validation pass ----
    # NOTE(review): val_dl is presumably built from the held-out *test* split
    # (second element of `datasets`) and is also used to pick the best
    # checkpoint, so best_acc is an optimistic estimate of generalisation.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # evaluation only: skip building autograd graphs
        for batch in val_dl:
            x_raw, x_fft, y_batch = [t.to(device) for t in batch]
            out = model(x_raw, x_fft)
            # argmax of the logits equals argmax of log_softmax (a monotonic map)
            preds = out.argmax(dim=1)
            total += y_batch.size(0)
            correct += (preds == y_batch).sum().item()

    acc = correct / total
    acc_history.append(acc)

    if epoch % base == 0:
        print(f'Epoch: {epoch:3d}. Loss: {epoch_loss:.4f}. Acc.: {acc:2.2%}')
        base *= step

    if acc > best_acc:
        trials = 0
        best_acc = acc
        torch.save(model.state_dict(), best_model_path)
        print(f'Epoch {epoch} best model saved with accuracy: {best_acc:2.2%}')
    else:
        trials += 1
        if trials >= patience:
            print(f'Early stopping on epoch {epoch}')
            break

# NOTE: `model` now holds the final-epoch weights; the best-accuracy weights
# are in best_model_path (load them with model.load_state_dict(torch.load(...))).
print('Done!')

Start model training
Epoch:   1. Loss: 1.6265. Acc.: 20.21%
Epoch 1 best model saved with accuracy: 20.21%
Epoch:   2. Loss: 1.5996. Acc.: 20.38%
Epoch 2 best model saved with accuracy: 20.38%
Epoch 3 best model saved with accuracy: 28.94%
Epoch:   4. Loss: 1.5466. Acc.: 26.88%
Epoch 7 best model saved with accuracy: 38.36%
Epoch:   8. Loss: 1.1448. Acc.: 34.59%
Epoch 9 best model saved with accuracy: 39.38%
Epoch 10 best model saved with accuracy: 39.73%
Epoch 11 best model saved with accuracy: 39.90%
Epoch 12 best model saved with accuracy: 40.41%
Epoch 15 best model saved with accuracy: 40.75%
Epoch:  16. Loss: 0.5146. Acc.: 38.87%
Epoch:  32. Loss: 0.2213. Acc.: 20.38%
Epoch:  64. Loss: 0.1245. Acc.: 21.58%
Epoch: 128. Loss: 0.0529. Acc.: 24.66%
Epoch 192 best model saved with accuracy: 41.10%
Epoch 199 best model saved with accuracy: 42.29%
Epoch 205 best model saved with accuracy: 42.98%
Epoch 206 best model saved with accuracy: 43.15%
Epoch 224 best model saved with accuracy: 44

In [None]:
def score_model(model, metric, data, device=None):
    """Average a metric over every sample yielded by ``data``.

    Args:
        model: a ``torch.nn.Module``. It is switched to eval mode and left there.
        metric: callable ``metric(y_pred, y_true)`` returning a tensor. It is
            reduced with ``.mean()`` per batch, then weighted by the batch size.
        data: iterable of batches. The last element of each batch is the label
            tensor, and every preceding element is passed positionally to the
            model. Both ``(x, y)`` and ``(x_raw, x_fft, y)`` batches therefore work.
        device: where to move the batch tensors. Defaults to the device of the
            model's first parameter, or the CPU for a parameter-free model.

    Returns:
        The sample-weighted mean of the metric as a Python float. A smaller final
        batch no longer carries the same weight as a full one.

    Raises:
        ValueError: if ``data`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()  # testing mode
    total_score = 0.0
    total_count = 0
    with torch.no_grad():  # inference only: no autograd graph needed
        for *inputs, labels in data:
            inputs = [x.to(device) for x in inputs]
            labels = labels.to(device)
            # Call the module (not .forward) so registered hooks still run.
            y_pred = model(*inputs).float()
            batch_size = labels.shape[0]
            total_score += metric(y_pred, labels).mean().item() * batch_size
            total_count += batch_size

    if total_count == 0:
        raise ValueError('score_model: data yielded no samples')
    return total_score / total_count

In [None]:
def smooth(y, box_pts):
    """Centered moving average of ``y`` over a window of ``box_pts`` points.

    A plain ``np.convolve(y, ones/box_pts, mode='same')`` treats missing
    neighbours past either end as 0, which drags the first and last values
    toward 0. Here, each output is averaged over only the points that actually
    fall inside its window. The result always has ``len(y)`` elements, including
    when ``box_pts > len(y)``.

    Args:
        y: 1-D sequence of numbers (e.g. a per-epoch metric history).
        box_pts: window length in points; must be >= 1.

    Returns:
        np.ndarray of floats with the same length as ``y``.

    Raises:
        ValueError: if ``box_pts < 1``.
    """
    values = np.asarray(y, dtype=float)
    if box_pts < 1:
        raise ValueError(f'box_pts must be >= 1, got {box_pts}')
    if values.size == 0:
        return values.copy()

    kernel = np.ones(box_pts)
    # Slice the 'full' convolution with the same centring numpy uses for 'same'.
    start = (box_pts - 1) // 2
    stop = start + values.size
    window_sums = np.convolve(values, kernel, mode='full')[start:stop]
    window_counts = np.convolve(np.ones_like(values), kernel, mode='full')[start:stop]
    return window_sums / window_counts

In [None]:
# Left: mean per-sample *training* loss per epoch (loss_history is accumulated
# over trn_dl in the training cell). Right: validation accuracy smoothed with a
# 5-point moving average; the last 2 smoothed points are trimmed.
f, ax = plt.subplots(1, 2, figsize=(12, 4))

ax[0].plot(np.arange(1, len(loss_history) + 1), loss_history, label='loss')
ax[0].set_title('Training Loss History')
ax[0].set_xlabel('Epoch no.')
ax[0].set_ylabel('Loss')
ax[0].legend()

acc_smoothed = smooth(acc_history, 5)[:-2]
ax[1].plot(np.arange(1, len(acc_smoothed) + 1), acc_smoothed, label='acc')
ax[1].set_title('Validation Accuracy History')
ax[1].set_xlabel('Epoch no.')
ax[1].set_ylabel('Accuracy')
ax[1].legend()

f.tight_layout()
plt.show()