In [1]:
import random

# BCI
import sys
import os
import warnings
import numpy as np
import numpy.matlib as npm
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io as sio
from sklearn.model_selection import KFold

from keras.utils.np_utils import to_categorical
from keras import optimizers
from keras.losses import categorical_crossentropy

from bcilib import ssvep_utils as su
warnings.filterwarnings('ignore')

# SNN
from snnlib.spiking_model_pure import *
import time
import torch
from snnlib import snn_utils

# Debug
from pprint import pprint

device cuda


In [2]:
def get_training_data(features_data):
    features_data = np.reshape(features_data, (features_data.shape[0], features_data.shape[1], 
                                               features_data.shape[2], 
                                               features_data.shape[3]*features_data.shape[4]))
    train_data = features_data[:, :, 0, :].T
    for target in range(1, features_data.shape[2]):
        train_data = np.vstack([train_data, np.squeeze(features_data[:, :, target, :]).T])

    train_data = np.reshape(train_data, (train_data.shape[0], train_data.shape[1], 
                                         train_data.shape[2], 1))
    total_epochs_per_class = features_data.shape[3]
    features_data = []
    class_labels = np.arange(CNN_PARAMS['num_classes'])
    labels = (npm.repmat(class_labels, total_epochs_per_class, 1).T).ravel()
    labels = to_categorical(labels)
    
    return train_data, labels

In [3]:
data_path = os.path.abspath('./data')

CNN_PARAMS = {
    'batch_size': 64,
    'epochs': 50,
    'droprate': 0.25,
    'learning_rate': 0.001,
    'lr_decay': 0.0,
    'l2_lambda': 0.0001,
    'momentum': 0.9,
    'kernel_f': 10,
    'n_ch': 8,
    'num_classes': 12}

FFT_PARAMS = {
    'resolution': 0.2930,
    'start_frequency': 3.0,
    'end_frequency': 35.0,
    'sampling_rate': 256
}

window_len = 1
shift_len = 1
    
all_acc = np.zeros((10, 1))

magnitude_spectrum_features = dict()
complex_spectrum_features = dict()

mcnn_training_data = dict()
ccnn_training_data = dict()

mcnn_results = dict()
ccnn_results = dict()

In [4]:
all_segmented_data = dict()
for subject in range(0, 10):
    dataset = sio.loadmat(f'{data_path}/s{subject+1}.mat')
    eeg = np.array(dataset['eeg'], dtype='float32')
    
    CNN_PARAMS['num_classes'] = eeg.shape[0]
    CNN_PARAMS['n_ch'] = eeg.shape[1]
    total_trial_len = eeg.shape[2]
    num_trials = eeg.shape[3]
    sample_rate = 256

    filtered_data = su.get_filtered_eeg(eeg, 6, 80, 4, sample_rate)
    #pprint(filtered_data.shape)
    all_segmented_data[f's{subject+1}'] = su.get_segmented_epochs(filtered_data, window_len, 
                                                                  shift_len, sample_rate)
    #pprint(all_segmented_data["s1"].shape)

In [5]:
for subject in all_segmented_data.keys():
    magnitude_spectrum_features[subject] = su.magnitude_spectrum_features(all_segmented_data[subject], 
                                                                          FFT_PARAMS)
    #pprint(magnitude_spectrum_features[subject].shape)
    complex_spectrum_features[subject] = su.complex_spectrum_features(all_segmented_data[subject], 
                                                                      FFT_PARAMS)
    #pprint(complex_spectrum_features[subject].shape)

In [6]:
for subject in all_segmented_data.keys():
    mcnn_training_data[subject] = dict()
    ccnn_training_data[subject] = dict()
    train_data, labels = get_training_data(magnitude_spectrum_features[subject])
    mcnn_training_data[subject]['train_data'] = train_data
    mcnn_training_data[subject]['label'] = labels
    
    train_data, labels = get_training_data(complex_spectrum_features[subject])
    ccnn_training_data[subject]['train_data'] = train_data
    ccnn_training_data[subject]['label'] = labels

In [7]:
subject_accuracy = []

snn = SCNN()
snn.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(snn.parameters(), lr=learning_rate)

In [8]:
for subject in mcnn_training_data.keys():
    #Get dataset and labels for a subject
    dataset = mcnn_training_data[subject]['train_data']
    labels = mcnn_training_data[subject]['label']
    
    # Convert floating points to Int64. 
    # Because snn_utils.image2spiketrain uses _scatter function 
    # and the function expect Int64.
    labels = np.array(labels, dtype=np.int64)
    # Convert labels vectors to scalar values
    labels = np.array(list(map(lambda x: np.where(x == 1)[0][0], labels)))
    
#     print(labels[56])
    
    index_array = list(range(600)) # 600 is the index of first data point that is class 10
    random.shuffle(index_array)
    
#     print("Old dataset shape: ", dataset.shape)
#     print("Old labels shape: ", labels.shape)
    dataset = np.reshape(dataset[:, :, :98, :], (dataset.shape[0], 28, 28))
    
#     print("New dataset shape: ", dataset.shape)
    
    train_data = torch.FloatTensor([dataset[x] for x in index_array[:420]])
    test_data = torch.FloatTensor([dataset[x] for x in index_array[420:]])
    train_labels = torch.LongTensor([labels[x] for x in index_array[:420]])
    test_labels = torch.LongTensor([labels[x] for x in index_array[420:]])

#     print(train_labels)
    
    print(test_data.shape)
    print(train_data.shape)
    print(test_labels.shape)
    print(train_labels.shape)
    
    test_data, test_labels = snn_utils.image2spiketrain(test_data, test_labels, max_duration=100, gain=20)
    train_data, train_labels = snn_utils.image2spiketrain(train_data, train_labels, max_duration=100, gain=20)
    
#     # [1000, x, ...] -> [x, 1000, ...]
#     test_data = torch.FloatTensor([test_data[:, x, :, :] for x in range(test_data.shape[1])])
#     train_data = torch.FloatTensor([train_data[:, x, :, :] for x in range(train_data.shape[1])])
#     test_labels = torch.FloatTensor([test_labels[:, x, :] for x in range(test_labels.shape[1])])
#     train_labels = torch.FloatTensor([train_labels[:, x, :] for x in range(train_labels.shape[1])])
    
#     print(test_data.shape)
#     print(train_data.shape)
#     print(test_labels.shape)
#     print(train_labels.shape)

    names = 'spiking_model'
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    acc_record = list([])
    loss_train_record = list([])
    loss_test_record = list([])

    pprint(train_data.shape)
    pprint(train_labels.shape)

    for real_epoch in range(num_epochs):
        running_loss = 0
        start_time = time.time()
        for epoch in range(5):
            for i, (images, labels) in enumerate(zip(train_data, train_labels)):
                print("NUM: ", i)
                
                # images2 = torch.empty((images.shape[0] * 2, 10, images.shape[2], images.shape[3]))
                images2 = torch.empty(images.shape[0], 10, images.shape[2], images.shape[3])
                # labels2 = torch.empty((images.shape[0] * 2), dtype=torch.int64)
                labels2 = torch.empty((images.shape[0]), dtype=torch.int64)
                
                for j in range(10):
                    images2[:, j, :, :] = torch.from_numpy(images[:, 0, :, :])
                
                labels2 = torch.from_numpy(np.array(list(map(lambda x: np.where(x == 1)[0][0], labels))))
                
                # ----
                snn.zero_grad()
                optimizer.zero_grad()

                images2 = images2.float().to(device)

                # print("MAIN SNN Input shape: ", images2.shape)
                # print(labels2)
                # print(labels2.view(-1, 1))
                # print(torch.zeros(batch_size * 2, 20).scatter_(1, labels2.view(-1, 1), 1))
                
                outputs = snn(images2)
                print(outputs)
                # print(outputs.shape)
                labels_ = torch.zeros(batch_size * 2, 20).scatter_(1, labels2.view(-1, 1), 1)
                loss = criterion(outputs.cpu(), labels_)
                running_loss += loss.item()
                loss.backward()
                optimizer.step()
                if (i+1) % 100 == 0:
                    print('Real_Epoch [%d/%d], Epoch [%d/%d], Loss: %.5f'
                            %( real_epoch, num_epochs, epoch, 5, running_loss))
                    running_loss = 0
                    print('Time elasped:', time.time() - start_time)

        # ================================== Test ==============================
        correct = 0
        total = 0
        optimizer = lr_scheduler(optimizer, epoch, learning_rate, 40)
        cm = np.zeros((20, 20), dtype=np.int32)

        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(zip(test_data, test_labels)):
                # images2 = torch.empty((images.shape[0] * 2, 10, images.shape[2], images.shape[3]))
                images2 = torch.empty(images.shape[0], 10, images.shape[2], images.shape[3])
                # labels2 = torch.empty((images.shape[0] * 2), dtype=torch.int64)
                labels2 = torch.empty((images.shape[0]), dtype=torch.int64)
                
                for j in range(10):
                    images2[:, j, :, :] = torch.from_numpy(images[:, 0, :, :])
                
                labels2 = torch.from_numpy(np.array(list(map(lambda x: np.where(x == 1)[0][0], labels))))
                
                inputs = images2.float().to(device)
                optimizer.zero_grad()
                outputs = snn(inputs)
                print(outputs)
                labels_ = torch.zeros(batch_size * 2, 20).scatter_(1, labels2.view(-1, 1), 1)
                loss = criterion(outputs.cpu(), labels_)
                _, predicted = outputs.cpu().max(1)

                # ----- showing confussion matrix -----

                # cm += confusion_matrix(labels2, predicted)
                # ------ showing some of the predictions -----
                # for image, label in zip(inputs, predicted):
                #     for img0 in image.cpu().numpy():
                #         cv2.imshow('image', img0)
                #         cv2.waitKey(100)
                #     print(label.cpu().numpy())

                total += float(labels2.size(0))
                print("ACCURACY METRICS")
                print(labels2)
                print(labels2.size(0))
                print(total)
                correct += float(predicted.eq(labels2).sum().item())
                print(predicted.eq(labels2).sum().item())
                print(correct)

                if batch_idx % 100 == 0:
                    acc = 100. * float(correct) / float(total)
                    print(batch_idx, len(test_data), ' Acc: %.5f' % acc)
        class_names = ['0_ccw', '1_ccw', '2_ccw', '3_ccw', '4_ccw',
                '5_ccw', '6_ccw', '7_ccw', '8_ccw', '9_ccw',
                '0_cw', '1_cw', '2_cw', '3_cw', '4_cw',
                '5_cw', '6_cw', '7_cw', '8_cw', '9_cw']
        plot_confusion_matrix(cm, class_names)
        print('Iters:', real_epoch, '\n\n\n')
        print('Test Accuracy of the model on the 10000 test images: %.3f' % (100 * correct / total))
        acc = 100. * float(correct) / float(total)
        acc_record.append(acc)
        if real_epoch % 5 == 0:
            print(acc)
            print('Saving..')
            state = {
                'net': snn.state_dict(),
                'acc': acc,
                'epoch': epoch,
                'acc_record': acc_record,
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/ckpt' + names + '.t7')
            best_acc = acc


pprint(subject_accuracy)

torch.Size([180, 28, 28])
torch.Size([420, 28, 28])
torch.Size([180])
torch.Size([420])
(100, 420, 1, 28, 28)
(100, 420, 10)
NUM:  0
tensor([[0.0000, 0.1000, 0.0000,  ..., 0.0000, 0.3000, 0.6000],
        [0.4000, 0.3000, 0.0000,  ..., 0.0000, 0.5000, 0.6000],
        [0.4000, 0.4000, 0.0000,  ..., 0.0000, 0.5000, 0.5000],
        ...,
        [0.5000, 0.0000, 0.0000,  ..., 0.0000, 0.6000, 0.6000],
        [0.6000, 0.5000, 0.0000,  ..., 0.0000, 0.5000, 0.6000],
        [0.4000, 0.2000, 0.0000,  ..., 0.0000, 0.6000, 0.6000]],
       device='cuda:0', grad_fn=<DivBackward0>)
NUM:  1
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.2000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.2000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.2000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.2000, 0.0000, 0.0000]],
       device='

RuntimeError: The size of tensor a (420) must match the size of tensor b (180) at non-singleton dimension 0