# Predicted SSVEP Stimulus Frequency using CCA

### Importing libraries

In [1]:
import scipy.io as sio
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from sklearn.cross_decomposition import CCA

### Getting Data

In [2]:
mat_contents = sio.loadmat('Data/data/s1.mat')

In [3]:
chan_locs = ['PO7', 'PO3', 'POz', 'PO4', 'PO8', 'O1', 'Oz', 'O2']
eeg_data = mat_contents['eeg']
# 1st dim: 12   -- target
# 2nd dim: 8    -- channels
# 3rd dim: 1114 -- timepoints
# 4th dim: 15   -- trials

### Setting up the filters

In [4]:
def butter_highpass_filter(data, cutoff, nyq, order=5):
    """Butterworth high-pass filter.
    Args:
        data (array_like): data to be filtered.
        cutoff (float): cutoff frequency.
        order (int): order of the filter.
    Returns:
        array: filtered data."""
    normal_cutoff = cutoff / nyq  # normalized cutoff frequency
    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
    filtered_data = signal.filtfilt(b, a, data)
    return filtered_data

def butter_lowpass_filter(data, cutoff, nyq, order=5):
    """Butterworth low-pass filter.
    Args:
        data (array_like): data to be filtered.
        cutoff (float): cutoff frequency.
        order (int): order of the filter.
    Returns:
        array: filtered data."""
    normal_cutoff = cutoff / nyq  # normalized cutoff frequency
    b, a = signal.butter(order, normal_cutoff, btype='low', analog=False)
    filtered_data = signal.lfilter(b, a, data)
    return filtered_data

# Filter parameters
fps = 256  # sampling frequency
cutoff_high = 6  # cutoff frequency of the high-pass filter
cutoff_low = 50  # cutoff frequency of the low-pass filter
nyq = 0.5 * fps  # Nyquist frequency (half of the sampling frequency).  It represents the highest frequency that can be accurately represented in a discrete-time signal.

# Filter Dataset

In [5]:
# Doing just 9 Hz, 11 Hz, 13 Hz, 14 Hz
# Doing just Oz electrode (index 6)

filtered_epochs = [] # Will contain the all epochs of data
target = [] # Will contain the target stimulus frequency to the corresponding filtered_epoch item
electrode_id = 6 # To only get Oz


labels = [9.25, 11.25, 13.25, 14.25]
stimulus_id = [0, 1, 2, 8]

# Four classes
for i in range(4):

    # Cycle each through trial in the class
    for j in range(15):

        # Temporary array to hold epoch (For 3rd dimension, we segment from the 39th timepoint and onwards since 
        # the 39th timepoint is the onset of the stimulus )
        temp_epoch = np.array(eeg_data[stimulus_id[i], electrode_id, 38:, j]).flatten()

        # apply the band-pass filter
        temp_epoch = butter_highpass_filter(
            data=temp_epoch,
            cutoff=cutoff_high,
            nyq=nyq,
            order=4)

        temp_epoch = butter_lowpass_filter(
            data=temp_epoch,
            cutoff=cutoff_low,
            nyq=nyq,
            order=4)

        # Append the epoch data and target label
        filtered_epochs.append(temp_epoch)
        target.append(labels[i])

### CCA Function

In [6]:
def CCAReferenceSignal(freq, harmonics, sampling_rate, n_samples):

    # Get the time vector
    t = np.arange(n_samples) / sampling_rate

    # temp array to hold reference signals
    reference_signals = []

    for h in range(1, harmonics + 1):
        reference_signals.append(np.sin(2 * np.pi * h * freq * t))
        reference_signals.append(np.cos(2 * np.pi * h * freq * t))
        
    return np.array(reference_signals)

### Generating all reference signals

In [7]:
reference_signals = []
n_timepoints = len(filtered_epochs[0])

# Four classes
for i in range(4):
    reference_signals.append( CCAReferenceSignal(labels[i], 2, fps, n_timepoints))

### Function to calculate the canonical correlation between EEG (X) and a reference signal (Y)

In [14]:
def coeff(x, y):
    x_t = np.transpose(x.reshape(1, -1))
    y_t = np.transpose(y)

    cca = CCA(n_components=1)
    cca.fit_transform(x_t, y_t)
    X_c, Y_c = cca.transform(x_t, y_t)

    return np.corrcoef(X_c.T, Y_c.T)[0, 1]

### Testing if CCA works

In [16]:
test_eeg_data = filtered_epochs[0]
test_label = target[0]
print(f"Actual {test_label}")

correl = []
for i in range(len(reference_signals)):
    calculated_coeff = coeff(test_eeg_data, reference_signals[i])
    correl.append(calculated_coeff)
    print(f'Correlation of {labels[i]} Hz is {calculated_coeff}')

print(f"Predicted {labels[np.argmax(correl)]}")

Actual 9.25
Correlation of 9.25 Hz is 0.416251250724158
Correlation of 11.25 Hz is 0.056156400515336775
Correlation of 13.25 Hz is 0.08764742981687766
Correlation of 14.25 Hz is 0.1106329723705786
Predicted 9.25


## Test CCA for subject 1

In [10]:
# Holds accuracy
acc = 0

# Loops over all epochs
for i in range(len(filtered_epochs)):

    # Hold the features of the epoch (4 features corresponding to the target labels)
    features = []
    for j in reference_signals: # Loop 4 times over labels (9.25, 11.25, 13.25, 14.25)
        features.append(coeff(filtered_epochs[i], j)) 

    print(f"Actual Freq: {target[i]}, Predicted Freq: {labels[np.argmax(features)]}")

    # Increase accuracy if target was successfully found
    if (target[i] == labels[np.argmax(features)]):
        acc += 1

acc = acc/len(filtered_epochs)
print(acc)

Actual Freq: 9.25, Predicted Freq: 9.25
Actual Freq: 9.25, Predicted Freq: 11.25
Actual Freq: 9.25, Predicted Freq: 11.25
Actual Freq: 9.25, Predicted Freq: 13.25
Actual Freq: 9.25, Predicted Freq: 9.25
Actual Freq: 9.25, Predicted Freq: 14.25
Actual Freq: 9.25, Predicted Freq: 9.25
Actual Freq: 9.25, Predicted Freq: 9.25
Actual Freq: 9.25, Predicted Freq: 9.25
Actual Freq: 9.25, Predicted Freq: 9.25
Actual Freq: 9.25, Predicted Freq: 13.25
Actual Freq: 9.25, Predicted Freq: 11.25
Actual Freq: 9.25, Predicted Freq: 11.25
Actual Freq: 9.25, Predicted Freq: 9.25
Actual Freq: 9.25, Predicted Freq: 9.25
Actual Freq: 11.25, Predicted Freq: 11.25
Actual Freq: 11.25, Predicted Freq: 9.25
Actual Freq: 11.25, Predicted Freq: 11.25
Actual Freq: 11.25, Predicted Freq: 11.25
Actual Freq: 11.25, Predicted Freq: 11.25
Actual Freq: 11.25, Predicted Freq: 11.25
Actual Freq: 11.25, Predicted Freq: 11.25
Actual Freq: 11.25, Predicted Freq: 11.25
Actual Freq: 11.25, Predicted Freq: 11.25
Actual Freq: 11.

# Test CCA for all subjects

In [17]:
total_acc = 0
total_count = 0 

for subject in range(1, 11):
    mat = sio.loadmat(f'Data\\data\\s{subject}.mat')
    subject_eeg_data = mat["eeg"]
    # 1st dim: 12   -- target
    # 2nd dim: 8    -- channels
    # 3rd dim: 1114 -- timepoints
    # 4th dim: 15   -- trials

    # Doing just 9 Hz, 11 Hz, 13 Hz, 14 Hz
    # Doing just Oz electrode (index 6)

    filtered_epochs = [] # Will contain the all epochs of data
    target = [] # Will contain the target stimulus frequency to the corresponding filtered_epoch item
    electrode_id = 6 # To only get Oz

    labels = [9.25, 11.25, 13.25, 14.25]
    stimulus_id = [0, 1, 2, 8]

    # Four classes
    for i in range(4):

        # Cycle each through trial in the class
        for j in range(15):

            # Temporary array to hold epoch
            temp_epoch = np.array(subject_eeg_data[stimulus_id[i], electrode_id, 38:, j]).flatten()

            # apply the band-pass filter
            temp_epoch = butter_highpass_filter(
                data=temp_epoch,
                cutoff=cutoff_high,
                nyq=nyq,
                order=4)

            temp_epoch = butter_lowpass_filter(
                data=temp_epoch,
                cutoff=cutoff_low,
                nyq=nyq,
                order=4)

            # Append the epoch data and target label
            filtered_epochs.append(temp_epoch)
            target.append(labels[i])

    # Holds accuracy
    acc = 0

    # Loops over all epochs
    for i in range(len(filtered_epochs)):
        total_count += 1

        # Hold the features of the epoch (4 features corresponding to the target labels)
        features = []
        for j in reference_signals: # Loop 4 times over labels (9.25, 11.25, 13.25, 14.25)
                features.append(coeff(filtered_epochs[i], j)) 

        # print(f"Actual Freq: {target[i]}, Predicted Freq: {labels[np.argmax(features)]}")

        # Increase accuracy if target was successfully found
        if (target[i] == labels[np.argmax(features)]):
            acc += 1
            total_acc += 1
            
    acc = acc/len(filtered_epochs)
    print(f'Subject {subject} accuracy: {acc}')

total_acc = total_acc / total_count
print(f'Total accuracy: {total_acc}')

Subject 1 accuracy: 0.6333333333333333
Subject 2 accuracy: 0.5333333333333333
Subject 3 accuracy: 0.9166666666666666
Subject 4 accuracy: 0.9833333333333333
Subject 5 accuracy: 0.9166666666666666
Subject 6 accuracy: 0.9833333333333333
Subject 7 accuracy: 0.9666666666666667
Subject 8 accuracy: 1.0
Subject 9 accuracy: 0.8833333333333333
Subject 10 accuracy: 0.95
Total accuracy: 0.8766666666666667
