# In [None]:
import numpy as np
import pandas as pd
import os
import re
from scipy.signal import cheby1, filtfilt, freqz
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score

# --- Recording / experiment parameters --------------------------------------
data_dir = "data/"
sampling_rate = 125  # Hz
num_samples = 188  # nominal samples per trial (used for the default time axis)
num_channels = 8  # nominal channel count (not enforced when loading)
class_freqs = [6.67, 8.57, 10.0, 12.0]  # frequency (Hz) of class 1, 2, 3, 4
timestamps = np.arange(num_samples) / sampling_rate  # seconds, starting at 0

# Trial files are named "block_<trial>_<class>.csv".
pattern = re.compile(r"block_(\d+)_(\d+)\.csv")

all_data = []
trial_labels = []

# os.listdir() returns entries in arbitrary order, so collect the matching
# files first and sort them numerically by (trial, class). This keeps trial
# indices reproducible across runs and machines. fullmatch() rejects names
# with trailing junk such as "block_1_2.csv.bak".
trial_files = []
for filename in os.listdir(data_dir):
    match = pattern.fullmatch(filename)
    if match:
        trial_files.append((int(match.group(1)), int(match.group(2)), filename))
trial_files.sort()

for trial_number, class_number, filename in trial_files:
    file_path = os.path.join(data_dir, filename)
    df = pd.read_csv(file_path, header=None)

    eeg_data = df.values.T  # Transpose to get (channels, samples)

    all_data.append(eeg_data)
    trial_labels.append(class_number)

all_data = np.array(all_data)  # (num_trials, num_channels, num_samples)
trial_labels = np.array(trial_labels)  # labels (1-indexed)

print(f"Loaded data shape: {all_data.shape}")
print(f"Number of trials: {len(trial_labels)}")

def plot_filter_response(fs, num_fbs=5):
    """Plot the magnitude response of the first ``num_fbs`` sub-band filters.

    Each sub-band is a Chebyshev type I band-pass (order 4, 0.5 dB pass-band
    ripple) running from one of the lower edges 4, 8, 10, 12, 14 Hz up to
    50 Hz, i.e. the same design used by ``filterbank``.

    Args:
        fs: Sampling rate in Hz.
        num_fbs: Number of sub-bands to draw (at most 5, the number of
            predefined lower edges).
    """
    plt.figure(figsize=(12, 8))

    nyq = fs / 2
    # Evaluation grid in Hz, from DC up to the Nyquist frequency.
    freqs_hz = np.linspace(0, nyq, 1000)

    passband = [4, 8, 10, 12, 14]
    high_freq = 50

    order = 4
    rp = 0.5

    for fb_i in range(1, num_fbs + 1):
        low_freq = passband[fb_i - 1]
        b, a = cheby1(order, rp, [low_freq / nyq, high_freq / nyq], btype='bandpass')

        # freqz interprets an array ``worN`` in the units of ``fs`` (rad/sample
        # by default). Passing the Hz grid together with fs=fs makes the
        # evaluated frequencies match the x-axis they are plotted against.
        _, h = freqz(b, a, worN=freqs_hz, fs=fs)

        # A band-pass has (near-)zero gain at DC; log10(0) = -inf is simply
        # skipped by matplotlib, so only the warning is suppressed.
        with np.errstate(divide='ignore'):
            magnitude_db = 20 * np.log10(np.abs(h))

        plt.plot(freqs_hz, magnitude_db, label=f'Filter {fb_i}: {low_freq}-{high_freq} Hz')

    plt.title('Filter Bank Magnitude Response')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Magnitude (dB)')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

def filterbank(eeg, fs, idx_fb=1, plot_signals=False, trial_idx=None):
    """Band-pass filter one trial with the ``idx_fb``-th sub-band filter.

    The sub-band runs from ``[4, 8, 10, 12, 14][idx_fb - 1]`` Hz up to 50 Hz
    and is realised as a Chebyshev type I filter (order 4, 0.5 dB ripple)
    applied forwards and backwards with ``filtfilt`` (zero phase).

    Args:
        eeg: 2-D array of shape (channels, samples).
        fs: Sampling rate in Hz.
        idx_fb: 1-based index of the sub-band.
        plot_signals: If True (and ``trial_idx`` is given), plot every channel
            before and after filtering.
        trial_idx: Trial number shown in the figure title; plotting is skipped
            when it is None.

    Returns:
        Array of shape (channels, samples) with the filtered signals.
    """
    eeg = np.asarray(eeg, dtype=float)
    num_chans, num_samples = eeg.shape

    passband = [4, 8, 10, 12, 14]

    nyq = fs / 2
    low_freq = passband[idx_fb - 1]
    high_freq = 50

    order = 4
    rp = 0.5

    b, a = cheby1(order, rp, [low_freq / nyq, high_freq / nyq], btype='bandpass')

    # filtfilt works along an axis, so all channels are filtered in one call.
    y = filtfilt(b, a, eeg, axis=1)

    if plot_signals and trial_idx is not None:
        # Build the time axis from this trial's own length and sampling rate
        # rather than a module-level array with a fixed number of samples.
        time_axis = np.arange(num_samples) / fs

        plt.figure(figsize=(14, 10))

        # Left column: raw signals.
        for i in range(num_chans):
            plt.subplot(num_chans, 2, 2 * i + 1)
            plt.plot(time_axis, eeg[i, :])
            plt.title(f'Original EEG - Channel {i+1}')
            plt.xlabel('Time (s)')
            plt.ylabel('Amplitude')
            plt.grid(True)

        # Right column: filtered signals.
        for i in range(num_chans):
            plt.subplot(num_chans, 2, 2 * i + 2)
            plt.plot(time_axis, y[i, :])
            plt.title(f'Filtered EEG (FB {idx_fb}: {low_freq}-{high_freq} Hz) - Channel {i+1}')
            plt.xlabel('Time (s)')
            plt.ylabel('Amplitude')
            plt.grid(True)

        plt.suptitle(f'Trial {trial_idx}: Original vs. Filtered EEG Signals')
        # Leave room at the top for the suptitle.
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

    return y


def cca_reference(list_freqs, fs, num_smpls, num_harms=3, plot_refs=False):
    """Build the sine/cosine reference templates used for CCA.

    For every frequency in ``list_freqs`` and every harmonic 1..``num_harms``
    a sine row followed by a cosine row is generated, sampled at
    1/fs, 2/fs, ..., num_smpls/fs seconds.

    Args:
        list_freqs: Sequence of target frequencies in Hz.
        fs: Sampling rate in Hz.
        num_smpls: Number of samples per template.
        num_harms: Number of harmonics (the fundamental counts as the first).
        plot_refs: If True, plot the first (and, if present, second) harmonic
            of every template.

    Returns:
        Array of shape (len(list_freqs), 2 * num_harms, num_smpls); rows are
        ordered sin(h=1), cos(h=1), sin(h=2), cos(h=2), ...
    """
    # Time stamps start one sample after zero.
    tidx = np.arange(1, num_smpls + 1) / fs
    rows_per_freq = 2 * num_harms
    templates = np.zeros((len(list_freqs), rows_per_freq, num_smpls))

    for k, base_freq in enumerate(list_freqs):
        for harmonic in range(1, num_harms + 1):
            angle = 2 * np.pi * harmonic * base_freq * tidx
            sin_row = 2 * (harmonic - 1)
            templates[k, sin_row, :] = np.sin(angle)
            templates[k, sin_row + 1, :] = np.cos(angle)

    if not plot_refs:
        return templates

    plt.figure(figsize=(14, 12))
    n_panels = len(list_freqs)

    for panel, freq in enumerate(list_freqs, start=1):
        current = templates[panel - 1]
        plt.subplot(n_panels, 1, panel)

        plt.plot(tidx, current[0, :], label='Sin 1st harmonic')
        plt.plot(tidx, current[1, :], label='Cos 1st harmonic')

        # The second harmonic is drawn dashed when it exists.
        if num_harms >= 2:
            plt.plot(tidx, current[2, :], '--', label='Sin 2nd harmonic')
            plt.plot(tidx, current[3, :], '--', label='Cos 2nd harmonic')

        plt.title(f'Reference Signals for {freq} Hz')
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        plt.grid(True)
        plt.legend()

    plt.tight_layout()
    plt.show()

    return templates


def _standardize_columns(data):
    """Center every column and scale it to unit spread; flat columns are only centered."""
    data = np.asarray(data, dtype=float)
    if data.ndim == 1:
        # Treat a plain vector as a single variable.
        data = data[:, np.newaxis]

    mean = np.mean(data, axis=0)
    scale = np.std(data, axis=0)

    # A constant column has zero spread (up to rounding). Dividing by that
    # would produce NaN/inf and crash the eigen-decomposition, so such
    # columns are left unscaled; they end up as (near-)zero columns that the
    # regularisation below keeps harmless.
    flat = scale <= 1e-12 * np.abs(mean)
    scale = np.where(flat, 1.0, scale)

    return (data - mean) / scale


def canoncorr(X, Y):
    """Return the largest canonical correlation between ``X`` and ``Y``.

    Args:
        X: Array of shape (observations, variables_x).
        Y: Array of shape (observations, variables_y), same number of rows
            as ``X``.

    Returns:
        The largest canonical correlation coefficient, a value in [0, 1].
    """
    X = _standardize_columns(X)
    Y = _standardize_columns(Y)
    n_x = X.shape[1]

    # One joint covariance matrix provides all three blocks and stays 2-D
    # even when X or Y has a single column.
    joint_cov = np.cov(np.hstack((X, Y)), rowvar=False)
    Cxx = joint_cov[:n_x, :n_x]
    Cyy = joint_cov[n_x:, n_x:]
    Cxy = joint_cov[:n_x, n_x:]

    # Regularization to avoid singular matrices
    Cxx = Cxx + np.eye(Cxx.shape[0]) * 1e-8
    Cyy = Cyy + np.eye(Cyy.shape[0]) * 1e-8

    # M = Cxx^-1 Cxy Cyy^-1 Cyx; solve() avoids forming explicit inverses.
    M = np.linalg.solve(Cxx, Cxy) @ np.linalg.solve(Cyy, Cxy.T)

    # Eigenvalues are squares of canonical correlations. Rounding can push
    # the largest one slightly below 0 or above 1, so clamp before the sqrt.
    eigvals = np.linalg.eigvals(M)
    rho_squared = np.clip(np.max(np.real(eigvals)), 0.0, 1.0)

    return np.sqrt(rho_squared)


def test_fbcca(eeg, list_freqs, fs, num_harms=4, num_fbs=5, visualize=False, visualize_trial_idx=None):
    """Classify every trial with filter-bank CCA.

    Each trial is passed through ``num_fbs`` sub-band filters; in every
    sub-band the canonical correlation with each frequency's reference
    template is computed. The per-band correlations are combined with the
    weights ``n ** -1.25 + 0.25`` (n = 1..num_fbs) and the frequency with the
    largest weighted sum wins.

    Args:
        eeg: Array of shape (trials, channels, samples).
        list_freqs: Candidate frequencies in Hz.
        fs: Sampling rate in Hz.
        num_harms: Number of harmonics in the reference templates.
        num_fbs: Number of sub-bands.
        visualize: If True, show diagnostic plots.
        visualize_trial_idx: Restrict the per-trial plots to this trial index;
            None plots every trial.

    Returns:
        Integer array of length ``trials`` with the 1-based index of the
        predicted frequency for each trial.
    """
    band_numbers = np.arange(1, num_fbs + 1)
    fb_coefs = np.power(band_numbers, -1.25) + 0.25
    num_targs, num_chans, num_smpls = eeg.shape
    n_classes = len(list_freqs)

    y_ref = cca_reference(list_freqs, fs, num_smpls, num_harms, plot_refs=visualize)

    results = np.zeros(num_targs, dtype=int)

    for targ_i, trial in enumerate(eeg):
        r = np.zeros((num_fbs, n_classes))

        vis_current = visualize and (visualize_trial_idx is None or targ_i == visualize_trial_idx)

        for fb_i in range(num_fbs):
            # Raw-vs-filtered plots are only produced for the first sub-band.
            show_signals = vis_current and fb_i == 0
            banded = filterbank(trial, fs, fb_i + 1,
                                plot_signals=show_signals,
                                trial_idx=targ_i)

            for class_i in range(n_classes):
                template = y_ref[class_i, :, :]
                r[fb_i, class_i] = canoncorr(banded.T, template.T)

        if vis_current:
            plot_cca_correlations(r, fb_coefs, list_freqs)

        # Weighted sum over sub-bands, one score per candidate frequency.
        scores = np.dot(fb_coefs, r)
        results[targ_i] = np.argmax(scores) + 1  # 1-indexed

    return results


def plot_cca_correlations(r_matrix, fb_coefs, class_freqs):
    """Plot per-sub-band CCA correlations and their weighted sum.

    Args:
        r_matrix: Array of shape (sub-bands, frequencies) with the correlation
            obtained in each sub-band for each candidate frequency.
        fb_coefs: One weight per sub-band.
        class_freqs: Candidate frequencies, used for the x tick labels.
    """
    n_bands, n_freqs = r_matrix.shape
    x_axis = range(n_freqs)

    plt.figure(figsize=(12, 8))

    # One curve per sub-band.
    for band_idx, band_corr in enumerate(r_matrix):
        band_label = f'Filter Bank {band_idx+1} (weight={fb_coefs[band_idx]:.2f})'
        plt.plot(x_axis, band_corr, 'o-', label=band_label)

    # Combined score and the position of its maximum.
    combined = np.dot(fb_coefs, r_matrix)
    best = np.argmax(combined)
    plt.plot(x_axis, combined, 'ks-', linewidth=2,
             label=f'Weighted Sum (max at index {best})')
    plt.axvline(x=best, color='red', linestyle='--')

    plt.title('CCA Correlation Coefficients')
    plt.xlabel('Target Frequency Index')
    tick_text = [f'{freq} Hz' for freq in class_freqs]
    plt.xticks(x_axis, tick_text)
    plt.ylabel('Correlation Coefficient')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

def main():
    """Run the full analysis on the loaded trials.

    Steps: plot the filter-bank response, visualise the first trial of each
    class together with its prediction, classify all trials, then report the
    overall accuracy, the confusion matrix and the per-class accuracies.
    """
    num_harms = 3
    num_fbs = 5

    plot_filter_response(sampling_rate, num_fbs)

    # Index of the first trial encountered for every class label.
    class_examples = {}
    for i, label in enumerate(trial_labels):
        class_examples.setdefault(label, i)

    print("\nVisualizing one example from each class:")
    for class_label, trial_idx in class_examples.items():
        print(f"Class {class_label} ({class_freqs[class_label-1]} Hz) - Trial index {trial_idx}")
        # Slice (not index) so the trial keeps its leading "trials" axis.
        single_trial = all_data[trial_idx:trial_idx+1]
        pred = test_fbcca(single_trial, class_freqs, sampling_rate, num_harms, num_fbs,
                          visualize=True, visualize_trial_idx=0)
        print(f"  Predicted: Class {pred[0]} ({class_freqs[pred[0]-1]} Hz)")

    print("\nProcessing all trials...")
    predicted_classes = test_fbcca(all_data, class_freqs, sampling_rate, num_harms, num_fbs)

    accuracy = accuracy_score(trial_labels, predicted_classes)
    print(f"Classification accuracy: {accuracy:.4f}")

    # Pass the full label set explicitly: otherwise a class that appears in
    # neither the true nor the predicted labels is dropped from the matrix and
    # the tick labels below no longer line up with its rows/columns.
    class_ids = np.arange(1, len(class_freqs) + 1)
    cm = confusion_matrix(trial_labels, predicted_classes, labels=class_ids)

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()

    classes = [f"{freq} Hz (Class {i+1})" for i, freq in enumerate(class_freqs)]
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Use white text on dark cells and black text on light cells.
    thresh = cm.max() / 2
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black")

    # Set the axis labels before tight_layout so they are included in the
    # layout computation and not clipped.
    plt.ylabel('True Class')
    plt.xlabel('Predicted Class')
    plt.tight_layout()
    plt.show()

    for class_idx, freq in enumerate(class_freqs, start=1):
        in_class = trial_labels == class_idx
        num_trials = np.count_nonzero(in_class)
        num_correct = np.count_nonzero(predicted_classes[in_class] == class_idx)
        # A class without any trials is reported as 0 % rather than dividing by zero.
        class_acc = num_correct / num_trials if num_trials > 0 else 0
        print(f"Class {class_idx} ({freq} Hz) accuracy: {(class_acc * 100):.2f} %")


if __name__ == "__main__":
    main()