### GPU Checking

In [None]:
import tensorflow as tf

# List the physical GPUs once and reuse the result for both the report and the
# memory-growth setup. tf.config.list_physical_devices is the stable name; the
# tf.config.experimental variant is a deprecated alias.
gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available:", len(gpus))

# Enable on-demand memory growth so TensorFlow does not reserve all GPU memory
# up front (GPU libraries such as CuPy/cuDF/cuML are also imported later).
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("GPU memory growth set")
    except RuntimeError as e:
        # Memory growth must be configured before the GPUs are initialized;
        # TensorFlow raises RuntimeError otherwise.
        print(e)

### Imports and package installation

In [None]:
import numpy as np # for dealing with data
from scipy.signal import butter, sosfiltfilt, sosfreqz  # for filtering
import matplotlib.pyplot as plt                         # for plotting
import seaborn as sns
import os
from os import listdir
from os.path import isfile, join, isdir

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras import backend as K
from tensorflow.keras import optimizers
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint

import pandas as pd
import cupy as cp
import cudf
import seaborn as sns
import random
import warnings

# METRIC
from sklearn.metrics import (roc_curve, precision_score, recall_score,
                             f1_score, balanced_accuracy_score, auc)
from cuml.metrics import accuracy_score, roc_auc_score, confusion_matrix

# OVERSAMPLING
from imblearn.over_sampling import SMOTE

# MODEL
from cuml.linear_model import LogisticRegression as cuLR
from cuml.svm import SVC as cuSVC
from cuml.ensemble import RandomForestClassifier as cuRF
from lightgbm import LGBMClassifier
from sklearn.neural_network import MLPClassifier

from scipy.interpolate import interp1d

!pip install PyWavelets
!pip install pyriemann

### Filter settings

In [None]:
# Sampling rate (Hz) and nominal band-pass edges (Hz) for this dataset.
# NOTE(review): butter_bandpass_filter below defaults to highcut=40.0, not the
# 50.0 set here; no visible cell passes these constants to it — confirm intent.
fs, lowcut, highcut = 200.0, 1.0, 50.0

# Butterworth band-pass filter (zero-phase, second-order-sections form).
def butter_bandpass_filter(raw_data, fs, lowcut=1.0, highcut=40.0, order=5):
    """Apply a zero-phase Butterworth band-pass filter.

    Parameters:
        raw_data: array to filter; filtering runs along the last axis
            (the ``sosfiltfilt`` default).
        fs: sampling rate in Hz.
        lowcut, highcut: pass-band edges in Hz.
        order: Butterworth filter order.

    Returns:
        The filtered array, same shape as ``raw_data``.
    """
    # Normalize the band edges to the Nyquist frequency, as butter() expects
    # for digital filters when fs is not passed.
    nyquist = fs / 2.0
    band = [lowcut / nyquist, highcut / nyquist]
    sections = butter(order, band, btype='band', analog=False, output='sos')
    # Forward-backward filtering cancels the phase shift.
    return sosfiltfilt(sections, raw_data)

### Config

In [None]:
# Epoch and baseline windows, in milliseconds relative to stimulus onset.
epoch_s = -100   # epoch start
epoch_e = 600    # epoch end
bl_s = -100      # baseline start
bl_e = 0         # baseline end
# NOTE(review): bl_s / bl_e are not used in the visible cells.

# Samples per epoch = window length (ms) * samples per ms. The signed length
# (end - start) stays correct when the window does not straddle stimulus onset
# (e.g. 100..600 ms), unlike abs(start) + abs(end); round() avoids truncating a
# float such as 139.9999 down to 139.
epoch_len = round((epoch_e - epoch_s) * fs / 1000)
print(epoch_len)

# Dataset layout.
train_subj_num = 16       # subjects in the training set
test_subj_num = 10        # subjects in the test set
stimulus_per_subj = 340   # epochs (stimuli) per subject
trial_per_subj = 5        # NOTE(review): not used in the visible cells — confirm meaning

# EEG channel labels (30 channels), presumably in the order of the channel axis
# of the loaded data — TODO confirm against the data source.
channels = [
    'Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8',
    'FT7', 'FC3', 'FCz', 'FC4', 'FT8',
    'T7', 'C3', 'Cz', 'C4', 'T8',
    'TP7', 'CP3', 'CPz', 'CP4', 'TP8',
    'P7', 'P3', 'Pz', 'P4', 'P8',
    'O1', 'POz', 'O2'
]
print(len(channels))

: 

### Load data

In [None]:
# Pre-epoched EEG arrays with layout (subjects, stimuli, channels, samples),
# e.g. (16, 340, 30, 140) for the training set (see the reshape cell below).
train_data_list = np.load('./data/train_data.npy')
test_data_list = np.load('./data/test_data.npy')

# Report shape and dtype only: printing the arrays themselves floods the
# notebook output with millions of values.
print('Epoched training data shape: ' + str(train_data_list.shape))
print('Epoched training data dtype: ', train_data_list.dtype)
print('Epoched testing data shape: ' + str(test_data_list.shape))
print('Epoched testing data dtype: ', test_data_list.dtype)

# Training labels come from the 'Prediction' column; test labels from a
# header-less CSV (presumably one value per row, one row per test epoch).
Y_train_valid = pd.read_csv('data/TrainLabels.csv')['Prediction'].values
Y_true_labels = pd.read_csv('./data/true_labels.csv', header=None).values


### Reshape data

In [None]:
# Name the four dimensions of the training array, e.g. (16, 340, 30, 140).
subj, trial, numChannel, sample = train_data_list.shape

# Merge subjects x stimuli into one epoch axis and add a singleton axis in
# front of the channels: (n_epochs, 1, channels, samples).
X_train_valid = np.reshape(train_data_list, (-1, 1, numChannel, sample))  # e.g. (5440, 1, 30, 140)
X_test = np.reshape(test_data_list, (-1, 1, numChannel, sample))          # e.g. (3400, 1, 30, 140)

# Flatten the label column (presumably (n, 1)) to 1-D. The length comes from
# the data instead of a hard-coded 3400; the assert keeps the guarantee of
# exactly one label per test epoch.
Y_true_Labels_for_test = np.reshape(Y_true_labels, -1)
assert Y_true_Labels_for_test.shape[0] == X_test.shape[0], (
    f"{Y_true_Labels_for_test.shape[0]} test labels for {X_test.shape[0]} test epochs"
)

print('subject: ', subj)
print('trial: ', trial)
print('numChannel: ', numChannel)
print('sample: ', sample)
print('X_train_valid: ', X_train_valid.shape)
print('X_test: ', X_test.shape)
print('Y_train_valid: ', Y_train_valid.shape)
print('Y_true_Labels_for_test: ', Y_true_Labels_for_test.shape)  # e.g. (3400,)

### Mixed precision

In [None]:
from tensorflow.keras import mixed_precision
# Make 'mixed_float16' (float16 compute, float32 variables) the global default
# dtype policy for Keras layers created after this point.
# NOTE(review): no Keras model is built in the visible cells (the classifiers
# are cuML / LightGBM / sklearn), so this may have no effect here — confirm.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

### Time tracker

In [None]:
from time import time, perf_counter
from contextlib import contextmanager

@contextmanager
def timed_block(name="Block"):
    """Print the wall-clock time spent inside a ``with`` block.

    Parameters:
        name (str): label shown in the printed message.

    The elapsed time is printed even when the block raises; the exception then
    propagates unchanged.
    """
    # perf_counter is monotonic and high-resolution, which suits durations;
    # time() can jump when the system clock is adjusted.
    start = perf_counter()
    try:
        yield
    finally:
        end = perf_counter()
        # Message reads "[name] time spent: X.XX seconds".
        print(f"🕒 [{name}] 花費時間: {end - start:.2f} 秒")

### Feature extraction

In [None]:
from scipy.signal import stft
import pywt
def extract_stft_wavelet_features(data, sampling_rate=256, wavelet='db4', stft_nperseg=64):
    """Combined STFT + discrete-wavelet feature extraction for EEG.

    For every channel of every entry along axis 0, the following are appended
    to that entry's feature row, in this order:
      * STFT magnitude per frequency bin: mean over time, std over time, and
        sum of squared magnitude (energy). Each is a full block of bins;
      * for each coefficient array of a level-3 ``pywt.wavedec`` (4 arrays):
        mean, std and energy (sum of squares).
    Row length is therefore n_channels * (3 * n_freq_bins + 3 * 4).

    Parameters:
        data (np.ndarray): shape (n_subjects, n_channels, n_samples); in this
            notebook axis 0 actually indexes epochs/trials.
        sampling_rate (int): ``fs`` passed to ``scipy.signal.stft``.
        wavelet (str): wavelet name for ``pywt.wavedec`` (default Daubechies 4).
        stft_nperseg (int): STFT window length in samples.

    Returns:
        np.ndarray: one flattened feature row per entry of axis 0,
        shape (n_subjects, total_features).
    """
    n_items, n_chans, _ = data.shape
    rows = []
    for item_idx in range(n_items):
        row = []
        for chan_idx in range(n_chans):
            sig = data[item_idx, chan_idx, :]

            # STFT magnitude, shape (n_freqs, n_segments).
            _, _, spec = stft(sig, fs=sampling_rate, nperseg=stft_nperseg)
            mag = np.abs(spec)
            row.extend(np.mean(mag, axis=1))
            row.extend(np.std(mag, axis=1))
            row.extend(np.sum(mag ** 2, axis=1))

            # Multi-level decomposition; the original author kept level at 3
            # out of concern for boundary effects.
            for band in pywt.wavedec(sig, wavelet, level=3):
                row += [np.mean(band), np.std(band), np.sum(np.square(band))]
        rows.append(row)
    return np.array(rows)

def extract_stft_features(data, sampling_rate=256, stft_nperseg=64):
    """STFT-only feature extraction for EEG.

    For every channel, the STFT magnitude is compressed with ``log1p``. Each
    frequency bin is then summarized by its mean over time, its std over time
    and its sum of squares, and the three blocks are concatenated in that order.

    Parameters:
        data (np.ndarray): shape (n_subjects, n_channels, n_samples); in this
            notebook axis 0 indexes epochs/trials.
        sampling_rate (int): ``fs`` passed to ``scipy.signal.stft``.
        stft_nperseg (int): STFT window length in samples.

    Returns:
        np.ndarray: shape (n_subjects, n_channels * 3 * n_freq_bins), where
        n_freq_bins is the number of STFT frequency bins.
    """
    n_items, n_chans, _ = data.shape

    def channel_features(sig):
        # Log-compressed STFT magnitude, shape (n_freqs, n_segments).
        _, _, spec = stft(sig, fs=sampling_rate, nperseg=stft_nperseg)
        log_mag = np.log1p(np.abs(spec))
        return [
            *np.mean(log_mag, axis=1),
            *np.std(log_mag, axis=1),
            *np.sum(log_mag ** 2, axis=1),
        ]

    return np.array([
        [feat
         for chan_idx in range(n_chans)
         for feat in channel_features(data[item_idx, chan_idx, :])]
        for item_idx in range(n_items)
    ])

from scipy.fft import fft
def fft_eeg(data, sampling_rate=256):
    """Energy-normalized FFT magnitude features for EEG.

    For each channel, the magnitudes of the first ``n_samples // 2`` FFT bins
    (non-negative frequencies, starting at DC) are divided by their sum, so
    each channel's spectrum sums to 1. A channel whose magnitudes sum to zero
    (e.g. an all-zero / flat channel) yields zeros instead of NaN.

    Parameters:
        data (numpy.ndarray): 3-D array of shape (n_subjects, n_channel, n_samples);
            in this notebook axis 0 indexes epochs/trials.
        sampling_rate (int): kept for API compatibility; the features do not
            depend on it.

    Returns:
        numpy.ndarray: float64 array of shape (n_subjects, n_channel * n_freq_bins),
        channel-major, with n_freq_bins = n_samples // 2.
    """
    data = np.asarray(data)
    assert data.ndim == 3, "Input data must be 3-dimensional (n_subjects, n_channel, n_samples)"
    n_subjects, n_channels, n_samples = data.shape
    n_freq_bins = n_samples // 2  # same count as fftfreq(...)[:n_samples // 2]

    # One batched FFT along the sample axis instead of a Python loop per channel.
    magnitude = np.abs(fft(data, axis=-1))[..., :n_freq_bins].astype(np.float64)

    # Per-channel energy normalization; channels with zero total stay at zero
    # rather than becoming 0/0 = NaN.
    totals = magnitude.sum(axis=-1, keepdims=True)
    normalized = np.divide(magnitude, totals, out=np.zeros_like(magnitude), where=totals > 0)

    return normalized.reshape(n_subjects, n_channels * n_freq_bins)

from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
def preprocess_riemann_features(train_data, test_data, Y_train, nfilter=5):
    """Riemannian features: Xdawn covariances projected onto the tangent space.

    Both transformers are fitted on the training set only and then applied
    unchanged to the test set.

    Parameters:
        train_data (np.ndarray): shape (n_subjects, n_channels, n_samples);
            in this notebook axis 0 indexes epochs/trials.
        test_data (np.ndarray): same layout as ``train_data``.
        Y_train: class labels for ``train_data`` (used to fit Xdawn).
        nfilter (int): number of Xdawn spatial filters.

    Returns:
        tuple: (train_features, test_features), the tangent-space vectors.
    """
    xdawn_cov = XdawnCovariances(nfilter=nfilter)
    tangent = TangentSpace(metric='riemann')

    # Fit both stages on the training data only.
    train_features = tangent.fit_transform(xdawn_cov.fit_transform(train_data, Y_train))

    # Reuse the fitted stages for the test data.
    test_features = tangent.transform(xdawn_cov.transform(test_data))

    return train_features, test_features


### Implementation

In [None]:
warnings.filterwarnings("ignore", category=FutureWarning)

# Need to modify
denoising_method = 'ASR'

# Smapling Rate (IC-U-Net: 256, Others: 200)
SR = 200

def initialize_metrics():
  """Create a fresh metrics accumulator.

  Returns a dict with three sub-dicts, each keyed by the metric names
  "accuracy", "accuracy_balanced", "precision", "recall", "f1", "auc":
    - "max": running maximum, starting at 0
    - "min": running minimum, starting at 1
    - "all": list of every recorded value, starting empty (a new list per metric)
  """
  names = ("accuracy", "accuracy_balanced", "precision", "recall", "f1", "auc")
  accumulator = {"max": {}, "min": {}, "all": {}}
  for name in names:
    accumulator["max"][name] = 0
    accumulator["min"][name] = 1
    accumulator["all"][name] = []
  return accumulator

def show_confusion_matrices(cm, method_name):
  """Plot a confusion matrix twice: raw counts, then each cell as a fraction of all samples.

  Parameters:
    cm: 2-D confusion matrix, converted to NumPy with ``cp.asnumpy``.
    method_name (str): label used in both plot titles.
  """
  counts = cp.asnumpy(cm).astype(int)
  # Divided by the grand total (not row-normalized).
  fractions = counts / np.sum(counts)

  panels = (
    (counts, "d", f"Confusion Matrix of {method_name}"),
    (fractions, ".2%", f"Normalized Confusion Matrix of {method_name}"),
  )
  for matrix, number_format, title in panels:
    plt.figure(figsize=(6, 5))
    sns.heatmap(matrix, annot=True, fmt=number_format, cmap="Blues")
    plt.title(title)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.show()

def _mean_roc_curve(all_fprs, all_tprs, n_points=100):
  """Average per-run ROC curves on a shared, evenly spaced FPR grid.

  Parameters:
    all_fprs, all_tprs: per-run FPR / TPR arrays (as returned by roc_curve).
    n_points (int): number of grid points in [0, 1].

  Returns:
    (mean_fpr, mean_tpr): NumPy arrays of length ``n_points``; mean_tpr is
    all-NaN when no curves are given.
  """
  mean_fpr = np.linspace(0, 1, n_points)
  if len(all_fprs) == 0:
    # No runs recorded: nothing to average.
    return mean_fpr, np.full_like(mean_fpr, np.nan)

  interp_tprs = []
  for fpr_i, tpr_i in zip(all_fprs, all_tprs):
    # np.interp accepts the non-decreasing FPR sequence roc_curve produces
    # (repeated values at vertical steps). SciPy documents interp1d as legacy
    # and points to np.interp for 1-D linear interpolation.
    interp_tpr = np.interp(mean_fpr, fpr_i, tpr_i)
    # Pin the start so the averaged curve begins at the origin.
    interp_tpr[0] = 0.0
    interp_tprs.append(interp_tpr)

  mean_tpr = np.mean(interp_tprs, axis=0)
  mean_tpr[-1] = 1.0  # an ROC curve always ends at (1, 1)
  return mean_fpr, mean_tpr


def summarize_and_plot(metrics, confusion_matrix_accumulated, method_name, all_fprs, all_tprs):
  """Print per-run and aggregate metrics, then plot the mean ROC curve and the
  accumulated confusion matrix (raw counts and fraction of all samples).

  Parameters:
    metrics: accumulator dict with "max" / "min" / "all" sub-dicts keyed by
      metric name (as built by initialize_metrics).
    confusion_matrix_accumulated: 2-D confusion matrix summed over runs,
      converted to NumPy with ``cp.asnumpy``.
    method_name (str): label used in headings and plot titles.
    all_fprs, all_tprs: per-run FPR / TPR arrays from ``roc_curve``.
  """
  print(f"\n===== Individual Results for {method_name} =====")
  for i in range(len(metrics["all"]["accuracy"])):
    print(f"Run {i+1:>2}:", end=' ')
    for metric in metrics["all"]:
      print(f"{metric} = {metrics['all'][metric][i]:.4f}", end=' | ')
    print()

  print(f"\n===== Summary for {method_name} =====")
  for metric in metrics["all"]:
    values = metrics["all"][metric]
    mean, std = np.mean(values), np.std(values)
    max_val, min_val = metrics["max"][metric], metrics["min"][metric]
    print(f"{metric:<20} Mean = {mean:.4f}, Std = {std:.4f}, Max = {max_val:.4f}, Min = {min_val:.4f}")

  # Mean ROC curve; the legend reports the mean of the per-run AUCs.
  mean_fpr, mean_tpr = _mean_roc_curve(all_fprs, all_tprs)
  auc_mean = np.mean(metrics["all"]["auc"])

  plt.figure(figsize=(6, 5))
  plt.plot(mean_fpr, mean_tpr, label=f'Mean ROC (AUC = {auc_mean:.4f})')
  plt.plot([0, 1], [0, 1], 'r--')
  plt.xlabel("False Positive Rate")
  plt.ylabel("True Positive Rate")
  plt.title(f'Average ROC Curve ({method_name})')
  plt.legend(loc="lower right")
  plt.grid(True)
  plt.show()

  # Confusion matrix accumulated over all runs: raw counts.
  cm_acc = cp.asnumpy(confusion_matrix_accumulated).astype(int)
  plt.figure(figsize=(6, 5))
  sns.heatmap(cm_acc, annot=True, fmt="d", cmap="Blues")
  plt.title(f"Accumulated Confusion Matrix ({method_name})")
  plt.xlabel("Predicted Label")
  plt.ylabel("True label")
  plt.show()

  # Same matrix, each cell as a fraction of the grand total.
  cm_normalized = cm_acc / np.sum(cm_acc)
  plt.figure(figsize=(6, 5))
  sns.heatmap(cm_normalized, annot=True, fmt=".2%", cmap="Blues")
  plt.title(f"Normalized Confusion Matrix ({method_name})")
  plt.xlabel("Predicted Label")
  plt.ylabel("True Label")
  plt.show()

def evaluate_and_accumulate(Y_true, Y_pred, Y_prob, accumulators, confusion_matrix_total):
  """Score one run and fold the scores into the running accumulators.

  Parameters:
    Y_true, Y_pred: true and predicted class labels.
    Y_prob: predicted probability of the positive class (used for ROC AUC).
    accumulators: dict with "max" / "min" / "all" sub-dicts keyed by metric
      name; updated in place.
    confusion_matrix_total: array that receives this run's confusion matrix
      via in-place ``+=``.

  Returns:
    The confusion matrix of this run.
  """
  scores = {
    "accuracy": accuracy_score(Y_true, Y_pred),
    "accuracy_balanced": balanced_accuracy_score(Y_true, Y_pred),
    "precision": precision_score(Y_true, Y_pred),
    "recall": recall_score(Y_true, Y_pred),
    "f1": f1_score(Y_true, Y_pred),
    "auc": roc_auc_score(Y_true, Y_prob),
  }
  run_cm = confusion_matrix(Y_true, Y_pred)

  for name, score in scores.items():
    accumulators["max"][name] = max(accumulators["max"][name], score)
    accumulators["min"][name] = min(accumulators["min"][name], score)
    accumulators["all"][name].append(score)

  confusion_matrix_total += run_cm
  return run_cm

def to_numpy(x):
  """Return ``x`` as a host-side NumPy object when it is a GPU container.

  CuPy arrays are copied with ``cp.asnumpy``. cuDF Series/DataFrames go through
  ``to_pandas().values``. Anything else is returned unchanged.
  """
  if isinstance(x, cp.ndarray):
    return cp.asnumpy(x)
  if isinstance(x, (cudf.Series, cudf.DataFrame)):
    return x.to_pandas().values
  return x

def _build_model(model_name, seed):
  """Instantiate the classifier named ``model_name``; raises ValueError if unknown."""
  if model_name == "LR":
    print('Using Logistic Regression')
    return cuLR()
  if model_name == "SVM":
    print('Using cuML SVM (linear kernel on GPU)')
    return cuSVC(kernel='linear', probability=True)
  if model_name == "RF":
    print('Using Random Forest')
    return cuRF(random_state=seed)
  if model_name == "LightGBM":
    print('Using LightGBM')
    return LGBMClassifier(random_state=seed)
  if model_name == "MLP":
    print("Using sklearn MLP (with early stopping)")
    return MLPClassifier(
      activation='relu',
      solver='adam',
      learning_rate_init=0.001,
      max_iter=200,
      early_stopping=True,  # not the sklearn default
      random_state=seed,
    )
  raise ValueError("Unknown model name")


def run_model_evaluation_loop(model_name, X_train, X_test, Y_train, Y_test, run, method_name, seed=None):
  """Train and evaluate ``model_name`` ``run`` times, then summarize the results.

  In each run a seed is drawn, the training set is rebalanced with SMOTE using
  that seed, and a fresh model is fitted and scored on the unchanged test set.
  Metrics, ROC curves and the confusion matrix are accumulated across runs and
  reported with ``summarize_and_plot``.

  Parameters:
    model_name (str): "LR", "SVM", "RF" (cuML), "LightGBM" or "MLP" (sklearn).
    X_train, Y_train: training features / labels (SMOTE is applied to these).
    X_test, Y_test: held-out features / labels.
    run (int): number of repetitions.
    method_name (str): label for the summary and plots.
    seed (int | None): optional seed for drawing the per-run seeds, which makes
      the whole loop reproducible. None (default) draws from the global
      ``random`` module.

  Raises:
    ValueError: if ``model_name`` is not recognized.
  """
  confusion_matrix_accumulated = cp.zeros((2, 2))
  metrics = initialize_metrics()
  all_fprs = []
  all_tprs = []

  seed_source = random if seed is None else random.Random(seed)
  seeds_used = [seed_source.randint(0, 99999) for _ in range(run)]

  # The test labels are the same in every run; convert them once.
  Y_test_cpu = to_numpy(Y_test)

  for i, run_seed in enumerate(seeds_used):
    print(f"Run {i+1}/{run} with seed {run_seed}")
    # Oversample the minority class with SMOTE.
    print("Applying SMOTE to balance the training set...")
    smote = SMOTE(random_state=run_seed)
    X_train_resampled, Y_train_resampled = smote.fit_resample(X_train, Y_train)

    model = _build_model(model_name, run_seed)

    print("Training...")
    # Every model is fitted on the resampled host arrays. The previous
    # `if model_name == "LightGBM" or "MLP":` test was always true, so the cuDF
    # copies it built were never used; cuML estimators accept NumPy input directly.
    model.fit(X_train_resampled, Y_train_resampled)

    print("Predicting...")
    Y_preds_cpu = to_numpy(model.predict(X_test))
    Y_pred_proba_cpu = to_numpy(model.predict_proba(X_test)[:, 1])

    fpr_i, tpr_i, _ = roc_curve(Y_test_cpu, Y_pred_proba_cpu)
    all_fprs.append(fpr_i)
    all_tprs.append(tpr_i)

    evaluate_and_accumulate(Y_test_cpu, Y_preds_cpu, Y_pred_proba_cpu, metrics, confusion_matrix_accumulated)

  print(f'-------------------- Running {run} times Done ---------------------')
  summarize_and_plot(metrics, confusion_matrix_accumulated, method_name, all_fprs, all_tprs)
  summarize_and_plot(metrics, confusion_matrix_accumulated, method_name, all_fprs, all_tprs)

def run_models(
    X_train_raw, X_test_raw, Y_train, Y_test,
    feature_func, run=10, method_name="", model_names=("SVM", "LR", "RF")
):
  """Extract features with ``feature_func``, then evaluate each model on them.

  Parameters:
    X_train_raw, X_test_raw: raw epochs, in the layout ``feature_func`` expects.
    Y_train, Y_test: labels for the two sets.
    feature_func: feature extractor. ``preprocess_riemann_features`` is called
      as ``f(train, test, Y_train)`` because it is fitted on the training set;
      any other extractor is called once per set as ``f(X, sampling_rate=SR)``.
    run (int): repetitions per model.
    method_name (str): feature-method label, used in messages and in the
      per-model name ``{model}_{denoising_method}_{method_name}``.
    model_names: models to evaluate; any names accepted by
      ``run_model_evaluation_loop`` ("LR", "SVM", "RF", "LightGBM", "MLP").
  """
  print('-' * 30)
  print("Extracting features using ", method_name)
  print(f'Sampling Rate: {SR}, with denoising method: {denoising_method}')

  with timed_block("Feature Extraction"):
    if feature_func is preprocess_riemann_features:
      X_train_feat, X_test_feat = feature_func(X_train_raw, X_test_raw, Y_train)
    else:
      X_train_feat = feature_func(X_train_raw, sampling_rate=SR)
      X_test_feat = feature_func(X_test_raw, sampling_rate=SR)

  # Only the no-PCA variant is evaluated; kept as a mapping so further feature
  # variants can be added.
  scenarios = {
    "no_pca": {"X_train": X_train_feat, "X_test": X_test_feat}
  }

  for data_pair in scenarios.values():
    print('-' * 30)
    for model_name in model_names:
      print(f"\n---- Running Model: {model_name} ----")
      with timed_block(f"{model_name} 模型 running time"):
        run_model_evaluation_loop(
          model_name=model_name,
          X_train=data_pair["X_train"],
          X_test=data_pair["X_test"],
          Y_train=Y_train,
          Y_test=Y_test,
          run=run,
          # Use the caller's label; it was hard-coded to "_XCTS" before, which
          # mislabelled every other feature method.
          method_name=f"{model_name}_{denoising_method}_{method_name}"
        )
        print('-' * 30)


### Main function

In [None]:
# For data shaping
epoch_len = 140 # (IC-U-Net: 136, Others: 140)

train_data_shaped = np.reshape(train_data_list,
  (train_subj_num * stimulus_per_subj, len(channels), epoch_len))
test_data_shaped = np.reshape(test_data_list,
  (test_subj_num * stimulus_per_subj, len(channels), epoch_len))

Y_train = pd.read_csv('data/TrainLabels.csv')['Prediction'].values

Y_test = np.reshape(pd.read_csv('./data/true_labels.csv', header=None).values, 3400)

run_models(
    X_train_raw=train_data_shaped,
    X_test_raw=test_data_shaped,
    Y_train=Y_train,
    Y_test=Y_test,
    feature_func=preprocess_riemann_features,
    run=10,
    method_name="XCTS"
)

print('-' * 50)