In [14]:
import sys
import os

# Prepend the parent directory to the import path, presumably so the local
# `bcilib` package (imported in the next cell) resolves from this notebook's folder.
sys.path.insert(0, os.path.abspath('..'))

# Switch IPython tab-completion away from the jedi backend.
%config Completer.use_jedi = False

In [15]:
%%capture
%load_ext autoreload
%autoreload 2
import warnings
import numpy as np
import numpy.matlib as npm
# import pandas as pd
import matplotlib.pyplot as plt
import scipy.io as sio
import pickle

import matplotlib.pyplot as plt

from keras.utils.np_utils import to_categorical
from keras import optimizers
from keras.losses import categorical_crossentropy

from bcilib.ssvep_utils_pytorch import CNN, RasterizeSlice, CustomTensorDataset
from bcilib import ssvep_utils as su 
from torch.utils.data import TensorDataset, DataLoader, random_split
# from snnlib import snn_utils

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
import torch
import tqdm

# NOTE(review): this silences *every* warning for the rest of the session,
# including the batch-size warning emitted by test(); consider narrowing it
# with a `category=` / `module=` filter.
warnings.filterwarnings('ignore')

In [16]:
# Report GPU availability. Only the last expression of a cell is displayed,
# so both values are returned together instead of discarding is_available().
cuda_available = torch.cuda.is_available()
cuda_device_count = torch.cuda.device_count()
cuda_available, cuda_device_count

1

In [17]:
# Shape of a single CNN input sample: (channels, features, 1). It agrees with the
# per-sample part of the torch.Size([720, 8, 220, 1]) tensors printed later;
# presumably 8 EEG channels x 220 complex-spectrum features — TODO confirm.
input_shape = tuple((8, 220, 1))

In [18]:
def get_training_data(features_data, num_classes=None):
    """Flatten a per-subject feature array into CNN samples and one-hot labels.

    Args:
        features_data: 5-D array. Axis 2 is the target (class) axis, and the
            last two axes are merged into one "epoch" axis. Based on the
            transposes below, axes 0 and 1 end up as the last two sample axes
            (presumably features and channels — TODO confirm against
            ``su.complex_spectrum_features``).
        num_classes: number of one-hot columns. Defaults to the number of
            targets, ``features_data.shape[2]``. The previous version read the
            global ``CNN_PARAMS['num_classes']``, which the loading cell sets
            to that same value.

    Returns:
        (train_data, labels):
            train_data: array of shape
                ``(n_targets * n_epochs, shape[1], shape[0], 1)``. Samples are
                grouped by target, target 0 first, so that sample index
                ``t * n_epochs + e`` holds epoch ``e`` of target ``t``.
            labels: float32 one-hot array of shape
                ``(n_targets * n_epochs, num_classes)``, aligned with train_data.

    Raises:
        ValueError: if ``features_data`` is not 5-D, or ``num_classes`` is
            smaller than the number of targets.
    """
    features_data = np.asarray(features_data)
    if features_data.ndim != 5:
        raise ValueError(
            f"features_data must be 5-D, got shape {features_data.shape}")

    dim0, dim1, n_targets = features_data.shape[:3]
    n_epochs = features_data.shape[3] * features_data.shape[4]

    if num_classes is None:
        num_classes = n_targets
    elif num_classes < n_targets:
        raise ValueError(
            f"num_classes={num_classes} is smaller than the {n_targets} targets")

    epochs = features_data.reshape(dim0, dim1, n_targets, n_epochs)
    # Build (target, epoch, dim1, dim0) and merge target and epoch into one axis.
    # This matches the old per-target `x[:, :, t, :].T` + vstack ordering, but is
    # done in one pass. It also avoids np.squeeze, which dropped size-1 axes.
    train_data = epochs.transpose(2, 3, 1, 0).reshape(
        n_targets * n_epochs, dim1, dim0, 1)

    # Each target id is repeated once per epoch, in the same order as train_data.
    class_ids = np.repeat(np.arange(n_targets), n_epochs)
    labels = np.eye(num_classes, dtype='float32')[class_ids]

    return train_data, labels

In [19]:
# Directory holding the per-subject recordings (s1.mat ... s10.mat), resolved
# against the current working directory.
data_path = os.path.abspath(os.path.join('data', 'original'))

# CNN hyper-parameters. 'n_ch' and 'num_classes' are overwritten from the data
# shape when each subject is loaded.
CNN_PARAMS = dict(
    batch_size=64,
    epochs=50,
    droprate=0.25,
    learning_rate=0.001,
    lr_decay=0.0,
    l2_lambda=0.0001,
    momentum=0.9,
    kernel_f=10,
    n_ch=8,
    num_classes=12,
)

# Spectral feature extraction settings passed to su.complex_spectrum_features.
FFT_PARAMS = dict(
    resolution=0.2930,
    start_frequency=3.0,
    end_frequency=35.0,
    sampling_rate=256,
)

# Segment length and hop passed to su.get_segmented_epochs
# (presumably in seconds, since the sample rate is passed alongside — TODO confirm).
window_len = 1
shift_len = 1

# One accuracy slot per subject.
all_acc = np.zeros(shape=(10, 1))

# Per-subject result containers, keyed 's1' ... 's10' by later cells.
magnitude_spectrum_features, complex_spectrum_features = {}, {}
mcnn_training_data, ccnn_training_data = {}, {}
mcnn_results, ccnn_results = {}, {}

In [20]:
# Load, band-pass filter and segment every subject's recording.
# The sampling rate comes from FFT_PARAMS, so filtering, segmentation and the
# later FFT features cannot silently disagree (it was previously hard-coded here
# a second time).
# NOTE(review): the 'eeg' array is treated as (classes, channels, samples, trials),
# per the shape unpacking below — TODO confirm against the dataset docs.
num_subjects = 10
sample_rate = FFT_PARAMS['sampling_rate']

all_segmented_data = dict()
for subject in range(num_subjects):
    subject_id = f's{subject+1}'
    dataset = sio.loadmat(os.path.join(data_path, f'{subject_id}.mat'))
    eeg = np.array(dataset['eeg'], dtype='float32')

    CNN_PARAMS['num_classes'] = eeg.shape[0]
    CNN_PARAMS['n_ch'] = eeg.shape[1]
    total_trial_len = eeg.shape[2]
    num_trials = eeg.shape[3]

    # NOTE(review): (6, 80, 4) presumably means low cut, high cut and filter
    # order in Hz/taps — TODO confirm against su.get_filtered_eeg.
    filtered_data = su.get_filtered_eeg(eeg, 6, 80, 4, sample_rate)
    all_segmented_data[subject_id] = su.get_segmented_epochs(filtered_data, window_len,
                                                             shift_len, sample_rate)

In [21]:
# Compute complex-spectrum features from each subject's segmented epochs.
for subject, segmented_epochs in all_segmented_data.items():
    complex_spectrum_features[subject] = su.complex_spectrum_features(
        segmented_epochs, FFT_PARAMS)

In [22]:
# Build each subject's C-CNN training set and rescale it to strictly positive
# values. Prints the raw max/min, then the rescaled max/min.
for subject in all_segmented_data:
    subject_entry = ccnn_training_data[subject] = {}

    raw, labels = get_training_data(complex_spectrum_features[subject])

    raw_max = raw.max()
    raw_min = raw.min()
    print(raw_max)
    print(raw_min)

    # With r = max - min, values map onto [(100/r)**3, (100 + 100/r)**3]. The +1
    # sits in the numerator, so the minimum depends on r and differs per subject
    # (see the printed minima).
    train_data = np.power(((raw - raw_min + 1) / (raw_max - raw_min)) * 100, 3)

    print(train_data.max())
    print(train_data.min())
    print()

    subject_entry['train_data'] = train_data
    subject_entry['label'] = labels

4.453436744813788
-4.542808724004553
1371913.8662493997
1373.4602953441113

4.697890353148152
-5.587120200892877
1320966.1341687553
919.1486950344613

4.24294748232967
-3.9931721341508273
1410264.9429553908
1789.9135828709188

6.390093494479364
-7.7297572992609656
1227869.4496460753
355.23003409219797

21.082749462733876
-18.468665927927102
1077784.5713728983
16.162699947894495

6.618154987289777
-5.988042252342181
1257355.231914684
499.1691147218926

4.193549156789444
-4.533442971943087
1384656.1550238773
1504.5486911991609

4.8637124388793245
-5.111515517334839
1331901.6618140617
1007.468584966449

4.7738970955208275
-4.780804938509801
1347989.4343764812
1146.4326759918458

6.531397488715976
-6.482075939864547
1248698.7994694596
453.75383755283235



In [23]:
def test(model, data_loader, num_batches=None):
    model.eval()
    correct = 0
    batch_count = 0
    
    with torch.no_grad():
        # model.spiking_model.network[0].weight *= 10
        
        # Iterate over data
        pbar = tqdm.notebook.tqdm(data_loader)
        for data, target in pbar:
            #if data_loader.dataset.spiking:
            if data_loader.dataset:
                if len(data.size()) > 4:
                    warnings.warn("Warning: Batch size needs to be 1, only first sample used.", stacklevel=2)
                    data = data[0]
                    target = target[0]
            output = model(data)
            #print(output)
            #if data_loader.dataset.spiking:
            if data_loader.dataset:
                output = output.sum(0).squeeze().unsqueeze(0)
                target = target.unsqueeze(0)
            
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            # Compute the total correct predictions
            correct += target[0][pred.item()].item() == 1
            
            batch_count += 1
            if (batch_count*data_loader.batch_size)%500 == 0:
                pbar.set_postfix({"Accuracy" : correct/(batch_count*data_loader.batch_size)})
            if num_batches:
                if num_batches <= batch_count: break;

    # Total samples:
    num_data = (batch_count*data_loader.batch_size)

    print(f'Test set: Accuracy: {correct}/{num_data} ({100. * correct / num_data}%)\n'.format(correct, num_data,
        ))
    return correct / num_data

In [24]:
from sinabs.from_torch import from_model
from bcilib.ssvep_utils_pytorch import CNN, RasterizeSlice, CustomTensorDataset, lam, div

# Convert each subject's trained CNN to a spiking network and score it on the
# held-out split.
c_scnn_acc_list = []

print("Lam: {} | Div: {}".format(lam, div))
for i, subject in enumerate(ccnn_training_data):
    # map_location lets models saved on a GPU load on CPU-only hosts; the model
    # is evaluated on the CPU anyway.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    cnn_model = torch.load("models/cnn-s{}.h5".format(i), map_location=torch.device("cpu"))
    sinabs_model = from_model(
        cnn_model,
        input_shape=input_shape,
        add_spiking_output=True,
        synops=True
    )
    sinabs_model.to(torch.device("cpu"))
    sinabs_model.float()

    print(f'\nC-SCNN - Subject: {subject}')
    subject_data = ccnn_training_data[subject]['train_data']
    labels = ccnn_training_data[subject]['label']

    tensor_x = torch.from_numpy(subject_data).float()
    tensor_y = torch.from_numpy(labels.astype(int))

    print("INPUT SHAPE: ", tensor_x.shape)
    print("LABEL SHAPE: ", tensor_y.shape)

    tensor_dataset = CustomTensorDataset(tensor_x, tensor_y, transform=RasterizeSlice())

    # Seeded 60/20/20 split. Only the test part is evaluated here, but all three
    # sizes are kept so that the seed reproduces the same test subset.
    train_size = int(0.6 * len(tensor_dataset))
    val_size = int(0.2 * len(tensor_dataset))
    test_size = len(tensor_dataset) - train_size - val_size
    _, _, test_dataset = random_split(
        tensor_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    # batch_size=1: test() only scores the first sample of each batch.
    dataloader_test = DataLoader(test_dataset, shuffle=False, num_workers=4, batch_size=1)

    c_scnn_acc_list.append(test(sinabs_model, dataloader_test, num_batches=200))
    


Lam: 1 | Div: 12

C-SCNN - Subject: s1
INPUT SHAPE:  torch.Size([720, 8, 220, 1])
LABEL SHAPE:  torch.Size([720, 12])


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=144.0), HTML(value='')))


Test set: Accuracy: 74/144 (51.388888888888886%)


C-SCNN - Subject: s2
INPUT SHAPE:  torch.Size([720, 8, 220, 1])
LABEL SHAPE:  torch.Size([720, 12])


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=144.0), HTML(value='')))


Test set: Accuracy: 48/144 (33.333333333333336%)


C-SCNN - Subject: s3
INPUT SHAPE:  torch.Size([720, 8, 220, 1])
LABEL SHAPE:  torch.Size([720, 12])


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=144.0), HTML(value='')))


Test set: Accuracy: 126/144 (87.5%)


C-SCNN - Subject: s4
INPUT SHAPE:  torch.Size([720, 8, 220, 1])
LABEL SHAPE:  torch.Size([720, 12])


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=144.0), HTML(value='')))


Test set: Accuracy: 140/144 (97.22222222222223%)


C-SCNN - Subject: s5
INPUT SHAPE:  torch.Size([720, 8, 220, 1])
LABEL SHAPE:  torch.Size([720, 12])


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=144.0), HTML(value='')))


Test set: Accuracy: 143/144 (99.30555555555556%)


C-SCNN - Subject: s6
INPUT SHAPE:  torch.Size([720, 8, 220, 1])
LABEL SHAPE:  torch.Size([720, 12])


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=144.0), HTML(value='')))


Test set: Accuracy: 140/144 (97.22222222222223%)


C-SCNN - Subject: s7
INPUT SHAPE:  torch.Size([720, 8, 220, 1])
LABEL SHAPE:  torch.Size([720, 12])


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=144.0), HTML(value='')))


Test set: Accuracy: 104/144 (72.22222222222223%)


C-SCNN - Subject: s8
INPUT SHAPE:  torch.Size([720, 8, 220, 1])
LABEL SHAPE:  torch.Size([720, 12])


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=144.0), HTML(value='')))


Test set: Accuracy: 141/144 (97.91666666666667%)


C-SCNN - Subject: s9
INPUT SHAPE:  torch.Size([720, 8, 220, 1])
LABEL SHAPE:  torch.Size([720, 12])


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=144.0), HTML(value='')))


Test set: Accuracy: 134/144 (93.05555555555556%)


C-SCNN - Subject: s10
INPUT SHAPE:  torch.Size([720, 8, 220, 1])
LABEL SHAPE:  torch.Size([720, 12])


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=144.0), HTML(value='')))


Test set: Accuracy: 117/144 (81.25%)



In [25]:
from functools import reduce
from pprint import pprint

# Per-subject C-SCNN accuracies followed by their mean.
print("lam:{}; div:{}".format(lam, div))
for result in c_scnn_acc_list:
    print(result)

# sum() adds in the same left-to-right order as the former reduce(), so the
# printed mean is unchanged. The guard avoids the error reduce()/len() raised
# on an empty list.
if c_scnn_acc_list:
    print("\nOverall: {}".format(sum(c_scnn_acc_list) / len(c_scnn_acc_list)))
else:
    print("\nOverall: no results")

lam:1; div:12
0.5138888888888888
0.3333333333333333
0.875
0.9722222222222222
0.9930555555555556
0.9722222222222222
0.7222222222222222
0.9791666666666666
0.9305555555555556
0.8125

Overall: 0.8104166666666668


In [13]:
from bcilib.ssvep_utils_pytorch import CNN, RasterizeSlice, CustomTensorDataset, lam, div
import matplotlib.pyplot as plt

# For every rasterized sample, plot per (row, column) how many steps along the
# leading axis equal 1, i.e. the spike count, and save each heat map to disk.
heatmap_dir = os.path.join("docs", "heatmap-poisson")
os.makedirs(heatmap_dir, exist_ok=True)  # savefig fails if the folder is missing

# NOTE(review): [5:] starts at the 6th subject, presumably to resume an
# interrupted run — TODO confirm.
for subject in list(ccnn_training_data.keys())[5:]:
    print(f'\nC-SCNN - Subject: {subject}')
    subject_data = ccnn_training_data[subject]['train_data']
    labels = ccnn_training_data[subject]['label']

    tensor_x = torch.from_numpy(subject_data).float()
    tensor_y = torch.from_numpy(labels.astype(int))

    tensor_dataset = CustomTensorDataset(tensor_x, tensor_y, transform=RasterizeSlice())
    dataloader = DataLoader(tensor_dataset, shuffle=True, num_workers=4, batch_size=1)

    data_point_count = 0
    total_mean = 0

    for d in dataloader:
        print(d[0].shape)
        data_point_count += 1
        total_mean += d[0].mean()

        # d[0] is (1, steps, rows, cols, 1) (printed as [1, 200, 8, 220, 1]).
        # Count the entries equal to 1 along the steps axis in a single vectorized
        # call. This replaces the previous triple Python loop over every element.
        s_data = (d[0][0, :, :, :, 0] == 1).sum(dim=0).numpy().astype(np.float64)

        statistics = "min:{}, mean:{:.2f}, max:{}".format(s_data.min(), s_data.mean(), s_data.max())
        print(statistics)

        fig, ax = plt.subplots()
        ax.imshow(s_data, cmap=plt.cm.jet, interpolation="none")
        # One tick per row/column of the data; rotate column labels to fit.
        n_rows, n_cols = s_data.shape
        ax.set_xticks(np.arange(n_cols))
        ax.tick_params(axis="x", labelrotation=90)
        ax.set_yticks(np.arange(n_rows))
        # NOTE(review): ':' in the file name is invalid on Windows file systems.
        fig.savefig(os.path.join(heatmap_dir, "{}-{}--l:{},d:{}--{}.png".format(
            subject, data_point_count - 1, lam, div, statistics)), dpi=600, bbox_inches="tight")
        plt.show()
        plt.close(fig)  # don't accumulate one open figure per sample

    if data_point_count:
        print("Spike ration: ", total_mean / data_point_count)


C-SCNN - Subject: s6
torch.Size([1, 200, 8, 220, 1])


KeyboardInterrupt: 