In [60]:
%load_ext autoreload
%autoreload 2

# BCI
import sys
import os
import warnings
import numpy as np
import numpy.matlib as npm
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io as sio
import pickle
from sklearn.model_selection import KFold

from keras.utils.np_utils import to_categorical
from keras import optimizers
from keras.losses import categorical_crossentropy

from bcilib import ssvep_utils as su
warnings.filterwarnings('ignore')

# SNN
import pickle 
import time 

from snnlib.spiking_model import *

# Debug
from pprint import pprint

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [61]:
def get_training_data(features_data):
    """Flatten per-target spectral features into samples with one-hot labels.

    Parameters
    ----------
    features_data : array-like, 5-D
        Indexed as (feature, channel, target, a, b). The last two axes are
        merged into one epoch axis of length ``a * b``. In this notebook the
        inputs have shapes such as (110, 8, 12, 15, 4).

    Returns
    -------
    train_data : ndarray, shape (targets * epochs, channels, features, 1)
        Samples grouped by target: every epoch of target 0 first, then
        target 1, and so on.
    labels : ndarray of float32, shape (targets * epochs, targets)
        One-hot rows aligned with ``train_data``.

    Notes
    -----
    The number of classes is taken from the target axis of ``features_data``
    itself rather than from the global ``CNN_PARAMS``, so the labels always
    match the samples regardless of notebook state.
    """
    features_data = np.asarray(features_data)
    n_features, n_channels, n_targets = features_data.shape[:3]
    epochs_per_class = features_data.shape[3] * features_data.shape[4]
    features_data = features_data.reshape(
        n_features, n_channels, n_targets, epochs_per_class)

    # (feature, channel, target, epoch) -> (target, epoch, channel, feature),
    # then merge target/epoch into the sample axis. This is the same ordering
    # as stacking ``features_data[:, :, t, :].T`` for t = 0, 1, ... but avoids
    # regrowing the array per target and never squeezes away a length-1 axis.
    train_data = np.transpose(features_data, (2, 3, 1, 0)).reshape(
        n_targets * epochs_per_class, n_channels, n_features, 1)

    # Class ids: 0 repeated epochs_per_class times, then 1, ...
    class_ids = np.repeat(np.arange(n_targets), epochs_per_class)
    # One-hot encode (float32, like keras' to_categorical default).
    labels = np.eye(n_targets, dtype='float32')[class_ids]

    return train_data, labels

In [62]:
# Directory holding the per-subject .mat files, resolved against the kernel's
# current working directory.
data_path = os.path.abspath('./data')

# CNN hyper-parameters. 'n_ch' and 'num_classes' are overwritten from the
# loaded EEG array's shape in the data-loading cell.
CNN_PARAMS = dict(
    batch_size=64,
    epochs=50,
    droprate=0.25,
    learning_rate=0.001,
    lr_decay=0.0,
    l2_lambda=0.0001,
    momentum=0.9,
    kernel_f=10,
    n_ch=8,
    num_classes=12,
)

# Settings passed to the spectrum feature extractors in bcilib.
FFT_PARAMS = dict(
    resolution=0.2930,
    start_frequency=3.0,
    end_frequency=35.0,
    sampling_rate=256,
)

# Segmentation window and shift handed to su.get_segmented_epochs
# (presumably in seconds -- TODO confirm against bcilib).
window_len = 1
shift_len = 1

# Accuracy placeholder: one row per subject (10 subjects are loaded below).
all_acc = np.zeros(shape=(10, 1))

# Per-subject containers, keyed by subject id ('s1', 's2', ...).
magnitude_spectrum_features = {}
complex_spectrum_features = {}

mcnn_training_data = {}
ccnn_training_data = {}

mcnn_results = {}
ccnn_results = {}

In [63]:
# Load every subject's recording, filter it, and cut it into fixed-length epochs.
# The sampling rate is read from FFT_PARAMS so filtering/segmentation and the
# FFT feature extraction cannot drift apart.
sample_rate = FFT_PARAMS['sampling_rate']

all_segmented_data = dict()
for subject in range(0, 10):
    subject_key = f's{subject+1}'
    dataset = sio.loadmat(f'{data_path}/{subject_key}.mat')
    eeg = np.array(dataset['eeg'], dtype='float32')

    # The network configuration follows the data: axis 0 is used as the class
    # count and axis 1 as the channel count.
    CNN_PARAMS['num_classes'] = eeg.shape[0]
    CNN_PARAMS['n_ch'] = eeg.shape[1]
    # Kept as notebook-level names in case later cells read them.
    total_trial_len = eeg.shape[2]
    num_trials = eeg.shape[3]

    # Arguments 6, 80, 4 are presumably low cut, high cut and filter order --
    # TODO confirm against bcilib.ssvep_utils.get_filtered_eeg.
    filtered_data = su.get_filtered_eeg(eeg, 6, 80, 4, sample_rate)
    pprint(filtered_data.shape)
    all_segmented_data[subject_key] = su.get_segmented_epochs(filtered_data, window_len,
                                                              shift_len, sample_rate)
    # Report the subject that was just processed (previously always "s1").
    pprint(all_segmented_data[subject_key].shape)

(12, 8, 1023, 15)
(12, 8, 15, 4, 256)
(12, 8, 1023, 15)
(12, 8, 15, 4, 256)
(12, 8, 1023, 15)
(12, 8, 15, 4, 256)
(12, 8, 1023, 15)
(12, 8, 15, 4, 256)
(12, 8, 1023, 15)
(12, 8, 15, 4, 256)
(12, 8, 1023, 15)
(12, 8, 15, 4, 256)
(12, 8, 1023, 15)
(12, 8, 15, 4, 256)
(12, 8, 1023, 15)
(12, 8, 15, 4, 256)
(12, 8, 1023, 15)
(12, 8, 15, 4, 256)
(12, 8, 1023, 15)
(12, 8, 15, 4, 256)


In [64]:
# Build both spectral feature sets for each subject from its segmented epochs
# and print the shape of every resulting array.
for subject, segmented_epochs in all_segmented_data.items():
    magnitude = su.magnitude_spectrum_features(segmented_epochs, FFT_PARAMS)
    magnitude_spectrum_features[subject] = magnitude
    pprint(magnitude.shape)

    complex_ = su.complex_spectrum_features(segmented_epochs, FFT_PARAMS)
    complex_spectrum_features[subject] = complex_
    pprint(complex_.shape)

(110, 8, 12, 15, 4)
(220, 8, 12, 15, 4)
(110, 8, 12, 15, 4)
(220, 8, 12, 15, 4)
(110, 8, 12, 15, 4)
(220, 8, 12, 15, 4)
(110, 8, 12, 15, 4)
(220, 8, 12, 15, 4)
(110, 8, 12, 15, 4)
(220, 8, 12, 15, 4)
(110, 8, 12, 15, 4)
(220, 8, 12, 15, 4)
(110, 8, 12, 15, 4)
(220, 8, 12, 15, 4)
(110, 8, 12, 15, 4)
(220, 8, 12, 15, 4)
(110, 8, 12, 15, 4)
(220, 8, 12, 15, 4)
(110, 8, 12, 15, 4)
(220, 8, 12, 15, 4)


In [65]:
# Turn each subject's feature arrays into (samples, one-hot labels) pairs:
# magnitude-spectrum features go to the M-CNN set, complex-spectrum features
# to the C-CNN set.
for subject in all_segmented_data:
    train_data, labels = get_training_data(magnitude_spectrum_features[subject])
    mcnn_training_data[subject] = {'train_data': train_data, 'label': labels}

    train_data, labels = get_training_data(complex_spectrum_features[subject])
    ccnn_training_data[subject] = {'train_data': train_data, 'label': labels}

In [66]:
# Per subject: build a spiking CNN (SCNN), train it on the magnitude-spectrum
# samples, then run an evaluation pass, plot a confusion matrix and save a
# checkpoint every fifth outer epoch.
#
# NOTE(review): the recorded output of this cell ends in a RuntimeError about a
# 3-D input reaching a 4-D convolution weight, so the cell does not currently
# run to completion.
# NOTE(review): SCNN, device, nn, torch, learning_rate, num_epochs, batch_size,
# train_dataset, lr_scheduler, test_loader, cv2, random, confusion_matrix,
# plot_confusion_matrix, acc_record and names are not defined in the visible
# cells -- presumably some arrive via `from snnlib.spiking_model import *`;
# verify each one.
for subject in mcnn_training_data.keys():
    train_data = mcnn_training_data[subject]['train_data']
    labels = mcnn_training_data[subject]['label']

    pprint(train_data.shape)
    pprint(labels.shape)

    # A fresh model, loss and optimizer are created for every subject.
    snn = SCNN()
    snn.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(snn.parameters(), lr=learning_rate)

    for real_epoch in range(num_epochs):
        running_loss = 0
        start_time = time.time()
        # Five inner passes over the data per outer epoch.
        for epoch in range(5):
            # Iterates one sample at a time (no batching).
            # NOTE(review): the loop target `labels` rebinds the outer `labels`
            # array, so from the second inner pass on zip() pairs samples with
            # the entries of a single label vector.
            for i, (images, labels) in enumerate(zip(train_data, labels)):
                snn.zero_grad()
                optimizer.zero_grad()

                # Debug output for the current sample.
                print("image shape: ", images.shape)
                print("label shape: ", labels.shape)
                pprint(labels)

                # NOTE(review): eeg_image is never used afterwards and the first
                # assignment is immediately overwritten. The recorded sample
                # shape (8, 110, 1) holds 880 values, which cannot be reshaped
                # to (28, 28) -- TODO confirm and remove.
                eeg_image = torch.empty((28, 28)) #Image shape
                eeg_image = torch.from_numpy(np.reshape(images, (28, 28)))

                # Buffers with twice as many rows as images.shape[0]; only the
                # even rows are written below, the odd rows keep whatever
                # torch.empty left in them.
                images2 = torch.empty((images.shape[0] * 2, images.shape[1], images.shape[2]))
                labels2 = torch.empty((images.shape[0] * 2), dtype=torch.int64)

                for j in range(images.shape[0]):
                    img0 = np.array(images[j])
                    images2[j * 2, :] = torch.from_numpy(img0)
                    # NOTE(review): per the printed output `labels` is a one-hot
                    # vector here, so labels[j] is 0.0 or 1.0, not a class index.
                    labels2[j * 2] = int(labels[j])

                images2 = images2.float().to(device)

                outputs = snn(images2)
                # One-hot targets for the MSE loss.
                # NOTE(review): the (batch_size * 2, 20) shape is hard-coded;
                # the labels printed above have 12 classes -- confirm against
                # SCNN's output size.
                labels_ = torch.zeros(batch_size * 2, 20).scatter_(1, labels2.view(-1, 1), 1)
                loss = criterion(outputs.cpu(), labels_)
                running_loss += loss.item()
                loss.backward()
                optimizer.step()
                # Progress report every 100 steps; running_loss is reset after each report.
                if (i+1) % 100 == 0:
                    print('Real_Epoch [%d/%d], Epoch [%d/%d], Step [%d/%d], Loss: %.5f'
                            %( real_epoch, num_epochs, epoch, 5, i+1, len(train_dataset)//batch_size, running_loss))
                    running_loss = 0
                    print('Time elasped:', time.time() - start_time)

        # ================================== Test ==============================
        # NOTE(review): this evaluation iterates `test_loader` with 4-D image
        # batches and cv2 rotations and never touches train_data; it looks
        # carried over from an image-classification script -- confirm it is
        # meant to run on the EEG features.
        correct = 0
        total = 0
        # NOTE(review): `epoch` is the inner loop variable (4 after the loop
        # above), not real_epoch.
        optimizer = lr_scheduler(optimizer, epoch, learning_rate, 40)
        cm = np.zeros((20, 20), dtype=np.int32)

        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(test_loader):
                # Two 10-frame sequences per input image: row j*2 and row j*2+1.
                images2 = torch.empty((images.shape[0] * 2, 10, images.shape[2], images.shape[3]))
                labels2 = torch.empty((images.shape[0] * 2), dtype=torch.int64)
                for j in range(images.shape[0]):
                    img0 = images[j, 0, :, :].numpy()
                    rows, cols = img0.shape
                    theta1 = 0
                    theta2 = 360
                    # Row j*2: unrotated image in frames 0 and 9, rotated copies in between.
                    for k in range(10):
                        if k == 0 or k == 9:
                            images2[j * 2, k, :, :] = torch.from_numpy(img0)
                        else:

                            M = cv2.getRotationMatrix2D((rows / 2, cols / 2), theta1 + int(random.randrange(0,360,36)), 1.0)  # rotate counter clock-wise
                            # M = np.float32([[1 - 0.1 * k, 0, 0], [0, 1 - 0.1 * k, 0]])     # zoom out
                            # M = np.float32([[1 - 0.05 * k, 0, 0], [0, 1 - 0.05 * k, 0]])     # zoom out less aggressive
                            dst = cv2.warpAffine(img0, M, (cols, rows))
                            images2[j * 2, k, :, :] = torch.from_numpy(dst)
                        labels2[j * 2] = labels[j]
                    # Row j*2+1: frames are written at index k - 1.
                    # NOTE(review): k runs 1..10, so `k == 0` is never true; for
                    # k == 9 the unrotated image is written to row j*2 (frame 9)
                    # instead of row j*2+1, leaving frame 8 of row j*2+1 unset.
                    # NOTE(review): the randrange result already steps by 36 and
                    # is multiplied by 36 again here, unlike the loop above.
                    for k in range(1, 11):
                        if k == 0 or k == 9:
                            images2[j * 2, k, :, :] = torch.from_numpy(img0)
                        else:

                            M = cv2.getRotationMatrix2D((rows / 2, cols / 2),theta2 - int(random.randrange(0,360,36)) * 36, 1.0)  # rotate clock-wise
                            # M = np.float32([[0.1 * k, 0, 0], [0, 0.1 * k, 0]])    # zoom in
                            # M = np.float32([[0.5 + 0.05 * k, 0, 0], [0, 0.5 + 0.05 * k, 0]])    # zoom in less aggressive
                            dst = cv2.warpAffine(img0, M, (cols, rows))
                            images2[j * 2 + 1, k - 1, :, :] = torch.from_numpy(dst)
                        # The second copy gets the label shifted by 10 (the
                        # '*_cw' entries of class_names below).
                        labels2[j * 2 + 1] = labels[j] + 10
                inputs = images2.to(device)
                optimizer.zero_grad()
                outputs = snn(inputs)
                labels_ = torch.zeros(batch_size * 2, 20).scatter_(1, labels2.view(-1, 1), 1)
                # `loss` is computed here but not used in the evaluation pass.
                loss = criterion(outputs.cpu(), labels_)
                # Predicted class = index of the largest output.
                _, predicted = outputs.cpu().max(1)

                # ----- showing confusion matrix -----

                cm += confusion_matrix(labels2, predicted)
                # ------ showing some of the predictions -----
                # for image, label in zip(inputs, predicted):
                #     for img0 in image.cpu().numpy():
                #         cv2.imshow('image', img0)
                #         cv2.waitKey(100)
                #     print(label.cpu().numpy())

                total += float(labels2.size(0))
                correct += float(predicted.eq(labels2).sum().item())
                # Running accuracy every 100 batches.
                if batch_idx % 100 == 0:
                    acc = 100. * float(correct) / float(total)
                    print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)
        class_names = ['0_ccw', '1_ccw', '2_ccw', '3_ccw', '4_ccw',
                '5_ccw', '6_ccw', '7_ccw', '8_ccw', '9_ccw',
                '0_cw', '1_cw', '2_cw', '3_cw', '4_cw',
                '5_cw', '6_cw', '7_cw', '8_cw', '9_cw']
        plot_confusion_matrix(cm, class_names)
        print('Iters:', real_epoch, '\n\n\n')
        # NOTE(review): the message hard-codes "10000 test images".
        print('Test Accuracy of the model on the 10000 test images: %.3f' % (100 * correct / total))
        acc = 100. * float(correct) / float(total)
        acc_record.append(acc)
        # Checkpoint every fifth outer epoch.
        if real_epoch % 5 == 0:
            print(acc)
            print('Saving..')
            # NOTE(review): 'epoch' stores the inner loop variable, not real_epoch.
            state = {
                'net': snn.state_dict(),
                'acc': acc,
                'epoch': epoch,
                'acc_record': acc_record,
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/ckpt' + names + '.t7')
            # NOTE(review): best_acc is overwritten on every save without
            # comparing against the previous value.
            best_acc = acc

(720, 8, 110, 1)
(720, 12)
image shape:  (8, 110, 1)
label shape:  (12,)
array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)


RuntimeError: Expected 4-dimensional input for 4-dimensional weight [16, 1, 8, 1], but got 3-dimensional input of size [1, 110, 1] instead