# Detection of SSVEP using Canonical Correlation Analysis with K-Nearest Neighbours
---





**Sources**


*   Study: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4610694/
*   Dataset: https://github.com/mnakanishi/12JFPM_SSVEP/tree/master/data
*   Standard CCA Implementation: https://github.com/aaravindravi/Brain-computer-interfaces/blob/master/notebooks/cca_12_class_ssvep.ipynb



In [None]:
!git clone https://github.com/aaravindravi/Brain-computer-interfaces/

Cloning into 'Brain-computer-interfaces'...
remote: Enumerating objects: 127, done.[K
remote: Counting objects: 100% (45/45), done.[K
remote: Compressing objects: 100% (34/34), done.[K
remote: Total 127 (delta 21), reused 26 (delta 10), pack-reused 82[K
Receiving objects: 100% (127/127), 146.36 MiB | 35.64 MiB/s, done.
Resolving deltas: 100% (52/52), done.
Checking out files: 100% (21/21), done.


Copy the Helper Module and Dataset, then Import Libraries

In [None]:
!cp Brain-computer-interfaces/scripts/ssvep_utils.py ./


In [None]:
!cp -r ./Brain-computer-interfaces/data ./


# CCA

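Canonical Correlation Analysis (CCA) finds a linear combination of the EEG channels and a linear combination of a set of reference signals that are maximally correlated with each other. For SSVEP detection, one set of sine/cosine reference signals is built for each stimulus frequency (its fundamental and second harmonic). Each EEG segment is then correlated against every set, and the stimulus whose references give the largest canonical correlation is taken as the one the subject was attending to. Below, the vector of per-stimulus correlations is additionally used as a feature vector for a K-Nearest Neighbours classifier.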

In [None]:
import sys
import os
import math
import numpy as np
import scipy.io as sio


# Helper functions
import ssvep_utils

from sklearn.cross_decomposition import CCA
from sklearn.metrics import confusion_matrix

In [None]:
# Containers filled in by later cells: segmented epochs keyed by subject
# ('s1' ... 's10') and a list reserved for per-subject accuracies.
all_segment_data = {}
all_acc = []

# Windowing parameters. Each segment holds window_len * sample_rate samples,
# so window_len / shift_len are presumably seconds and sample_rate is Hz.
window_len = 1
shift_len = 1
sample_rate = 256
duration = int(window_len * sample_rate)

# Stimulus flicker frequencies. The plotting code later maps a 1-based
# class label i to flicker_freq[i - 1].
flicker_freq = np.array([
    9.25, 11.25, 13.25,
    9.75, 11.75, 13.75,
    10.25, 12.25, 14.25,
    10.75, 12.75, 14.75,
])

Reference Signals Function


In [None]:
def get_reference_signals(data_len, target_freq, sampling_rate, num_harmonics=2):
    """Build sine/cosine reference signals for one stimulus frequency.

    For every harmonic h = 1 .. num_harmonics, a sine and a cosine at
    h * target_freq are generated. Rows are ordered
    [sin(1f), cos(1f), sin(2f), cos(2f), ...].

    Args:
        data_len: Number of samples in each reference signal.
        target_freq: Fundamental frequency, in the same unit as
            sampling_rate (e.g. Hz).
        sampling_rate: Samples per unit time (e.g. Hz).
        num_harmonics: Number of harmonics to include. The default of 2
            reproduces the original four signals.

    Returns:
        np.ndarray of shape (2 * num_harmonics, data_len).

    Raises:
        ValueError: If num_harmonics < 1.
    """
    if num_harmonics < 1:
        raise ValueError(f"num_harmonics must be >= 1, got {num_harmonics}")

    # Integer sample indices divided by the rate give exactly data_len time
    # points. A float-step np.arange(0, data_len / fs, 1 / fs) can gain or
    # lose a sample to rounding, which would break the CCA fit against EEG
    # segments of length data_len.
    t = np.arange(data_len) / sampling_rate

    reference_signals = []
    for harmonic in range(1, num_harmonics + 1):
        phase = 2 * np.pi * harmonic * target_freq * t
        reference_signals.append(np.sin(phase))
        reference_signals.append(np.cos(phase))

    return np.array(reference_signals)

Correlation Calculation Function


In [None]:
def calculate_correlation(n_components, np_buffer, freq):
    """Score one EEG segment against every reference-signal set using CCA.

    For each reference set, a CCA model is fitted between the segment and
    that set. The Pearson correlation of each canonical score pair is
    computed, and the largest one is kept as that set's score.

    Args:
        n_components: Number of canonical component pairs to fit.
        np_buffer: 2-D EEG segment. It is transposed before fitting, so its
            last axis is treated as samples (presumably (channels, samples)
            -- confirm against the caller).
        freq: 3-D stack of reference-signal sets, one per target (indexed
            freq[k]). Each set is transposed before fitting (presumably
            (num_reference_signals, samples)).

    Returns:
        np.ndarray of shape (freq.shape[0],) holding, for each reference
        set, the maximum canonical correlation over the n_components pairs.
    """
    cca = CCA(n_components=n_components)
    eeg = np_buffer.T  # sklearn expects samples as rows
    result = np.zeros(freq.shape[0])
    for freq_idx, template in enumerate(freq):
        x_scores, y_scores = cca.fit_transform(eeg, template.T)
        corr = [
            np.corrcoef(x_scores[:, k], y_scores[:, k])[0, 1]
            for k in range(n_components)
        ]
        # Take the max only after all component correlations are computed.
        result[freq_idx] = np.max(corr)

    return result

Data Classification Function

In [None]:
def cca_classify(segmented_data, reference_templates, n_components=1):
    """Classify every EEG segment by its best-matching reference template.

    Each segment is scored against all templates with
    calculate_correlation. The predicted class is the template with the
    highest correlation.

    Args:
        segmented_data: 5-D array indexed [target, :, trial, segment, :].
            It is presumably (targets, channels, trials, segments, samples),
            matching the printed shape (12, 8, 15, 4, 256) -- confirm
            against ssvep_utils.get_segmented_epochs.
        reference_templates: 3-D stack of reference-signal sets, passed
            through to calculate_correlation.
        n_components: Number of CCA components per fit. The default of 1
            reproduces the original behaviour.

    Returns:
        labels: np.ndarray of 1-based true target indices, one per segment.
        predicted_class: np.ndarray of 1-based predicted template indices.
        all_coeffs: list of per-segment correlation vectors (one entry per
            template), usable as classifier features.
    """
    num_targets, _, num_trials, num_segments = segmented_data.shape[:4]

    labels = []
    predicted_class = []
    all_coeffs = []
    for target in range(num_targets):
        for trial in range(num_trials):
            for segment in range(num_segments):
                result = calculate_correlation(
                    n_components,
                    segmented_data[target, :, trial, segment, :],
                    reference_templates,
                )
                labels.append(target + 1)
                all_coeffs.append(result)
                predicted_class.append(np.argmax(result) + 1)

    return np.array(labels), np.array(predicted_class), all_coeffs

Import the Dataset

In [None]:
# Load data/s1.mat ... data/s10.mat, filter each recording with
# ssvep_utils.get_filtered_eeg and cut it into windows with
# ssvep_utils.get_segmented_epochs, storing the result under 's<n>'.
# NOTE(review): the filter arguments (6, 80, 4) look like a low cut-off,
# a high cut-off and a filter order -- confirm against ssvep_utils.
for subject in np.arange(0, 10):
    dataset = sio.loadmat(f'data/s{subject+1}.mat')
    eeg = np.array(dataset['eeg'], dtype='float32')

    # Dimensions, taken from the first four axes of the EEG array
    num_classes, num_channels, num_sampling_points, num_trials = eeg.shape[:4]

    filtered_data = ssvep_utils.get_filtered_eeg(eeg, 6, 80, 4, sample_rate)
    all_segment_data[f's{subject+1}'] = ssvep_utils.get_segmented_epochs(
        filtered_data, window_len, shift_len, sample_rate
    )

In [None]:
print(np.shape(all_segment_data['s10']))  # shape of subject 10's segmented epochs

(12, 8, 15, 4, 256)


Generate Reference Signals

In [None]:
# One reference-signal set per stimulus frequency, stacked in flicker_freq
# order so that template index i corresponds to flicker_freq[i].
reference_templates = np.array(
    [get_reference_signals(duration, freq_hz, sample_rate) for freq_hz in flicker_freq],
    dtype='float32',
)

In [None]:
print(np.shape(reference_templates))  # (num_frequencies, num_reference_signals, samples)

(12, 4, 256)


Perform CCA on Segments

In [None]:
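# Baseline without KNN: accuracy of simply picking the template with the highest CCA correlation.
# (cca_classify returns three values; the correlation vectors are not needed here.)
# for subject in all_segment_data.keys():
#     labels, predicted_class, _ = cca_classify(all_segment_data[subject], reference_templates)
#     c_mat = confusion_matrix(labels, predicted_class)
#     accuracy = np.divide(np.trace(c_mat), np.sum(np.sum(c_mat)))
#     all_acc.append(accuracy)
#     print(f'Subject: {subject}, Accuracy: {accuracy*100} %')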

In [None]:
import itertools
import numpy as np
import matplotlib.pyplot as plt

def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map on the current figure.

    Args:
        cm: Square confusion matrix (array-like). Rows are drawn against the
            'True label' axis and columns against the 'Predicted label' axis,
            i.e. the layout of sklearn's confusion_matrix(y_true, y_pred).
        classes: Tick labels, one per row/column of cm.
        normalize: If True, divide each row by its sum so that cells show the
            fraction of that row. Rows summing to zero are shown as 0.
        title: Plot title.
        cmap: Matplotlib colormap for the heat map.
    """
    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Guard empty rows: plain division would produce NaN (with a
        # RuntimeWarning), and a NaN would also poison the text threshold.
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # Cells darker than half the maximum get white text for contrast;
    # an empty matrix has no maximum, so fall back to 0.
    thresh = cm.max() / 2. if cm.size else 0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

In [None]:
from sklearn.model_selection import train_test_split
from sklearn import neighbors

# Per-subject evaluation: each segment's vector of CCA correlations (one per
# stimulus) is used as the feature vector for a KNN classifier.
for v in range(1, 11):
    labels, predicted_class, all_coeffs = cca_classify(all_segment_data[f's{v}'], reference_templates)
    train_data, test_data, train_label, test_labels = train_test_split(all_coeffs, labels, random_state=42)
    clf = neighbors.KNeighborsClassifier(n_neighbors=15, weights='uniform')
    clf.fit(train_data, train_label)
    predictions = clf.predict(test_data)

    # The matrix must cover every class seen in either array, so that rows,
    # columns and tick labels line up one-to-one and in sorted order.
    class_ids = np.unique(np.concatenate([test_labels, predictions]))
    # sklearn's signature is confusion_matrix(y_true, y_pred): rows are true
    # classes, matching the 'True label' axis drawn by plot_confusion_matrix.
    conf_matrix = confusion_matrix(test_labels, predictions, labels=class_ids)
    print(conf_matrix)
    # Class labels are 1-based indices into flicker_freq.
    freqs = flicker_freq.tolist()
    plot_confusion_matrix(conf_matrix, [str(freqs[i - 1]) for i in class_ids])
    plt.show()
    print(f's{v}: {clf.score(test_data, test_labels)}')
# TODO: also report precision, recall and F1-score.

AttributeError: ignored

In [None]:
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn import neighbors
from sklearn.metrics import PrecisionRecallDisplay, classification_report
from sklearn.preprocessing import label_binarize

# Per-subject precision, recall and F1-score for the CCA + KNN pipeline.
for v in range(1, 11):
    labels, predicted_class, all_coeffs = cca_classify(all_segment_data[f's{v}'], reference_templates)
    train_data, test_data, train_label, test_labels = train_test_split(all_coeffs, labels, random_state=42)

    clf = neighbors.KNeighborsClassifier(n_neighbors=15, weights='uniform')
    clf.fit(train_data, train_label)
    predictions = clf.predict(test_data)

    # Multiclass precision / recall / F1 per class, plus macro and weighted averages.
    print(f's{v}:')
    print(classification_report(test_labels, predictions, zero_division=0))

    # PrecisionRecallDisplay only handles binary targets. For a multiclass
    # problem, binarize one-vs-rest and micro-average by flattening the
    # indicator matrix and the class-probability scores together. Columns of
    # predict_proba follow clf.classes_, so binarize with the same order.
    y_true_bin = label_binarize(test_labels, classes=clf.classes_)
    y_score = clf.predict_proba(test_data)
    if y_score.shape[1] == 2:
        # label_binarize yields a single column for two classes.
        y_score = y_score[:, 1:]
    display = PrecisionRecallDisplay.from_predictions(
        y_true_bin.ravel(), y_score.ravel(), name="KNN (micro-average)"
    )
    _ = display.ax_.set_title(f"s{v}: micro-averaged Precision-Recall curve")
    plt.show()

    print(f's{v}: {clf.score(test_data, test_labels)}')

AttributeError: ignored