In [1]:
# ============================================================

import os, glob, warnings, math, random
warnings.filterwarnings("ignore", category=RuntimeWarning)

# -----------------------------
# Backend selection (JAX → TPU/GPU; else NumPy CPU)
# -----------------------------
USE_JAX = True
try:
    import jax
    import jax.numpy as xp
    from jax import jit
    JAX_BACKEND = jax.devices()[0].platform
    print(f"[Backend] JAX on {JAX_BACKEND} ({jax.devices()[0]})")
except Exception:
    USE_JAX = False
    import numpy as xp
    print("[Backend] Using NumPy (CPU). Enable TPU for speed if desired.")

import numpy as np
import scipy.io as sio
import pandas as pd
from scipy import stats
import matplotlib.pyplot as plt

# Echo versions for reproducibility
print("[Versions] numpy", np.__version__)
if USE_JAX:
    import jaxlib
    print("[Versions] jax", jax.__version__, "| jaxlib", jaxlib.__version__)

# -----------------------------
# Config
# -----------------------------
DATA_ROOTS = [
    "/kaggle/input/ssvep-sandiego",
    "/kaggle/input",
]
FS_FALLBACK = 250            # Hz
BANDPASS = (5.0, 45.0)
NOTCH = None                 

FB_SUBBANDS = [(6,14), (14,22), (22,30), (30,38), (38,46)]
FB_WEIGHTS_POW = 1.25

TRCA_REG = 1e-4
EPS = 1e-8

RNG_SEED = 2025
np.random.seed(RNG_SEED); random.seed(RNG_SEED)

OUTDIR = "/kaggle/working"
os.makedirs(OUTDIR, exist_ok=True)


def _to_numpy_numeric(x):
    if isinstance(x, np.ndarray) and x.dtype != object and x.ndim >= 2 and x.size > 0:
        return x
    return None

def _maybe_cell_to_list(x):
    if isinstance(x, np.ndarray) and x.dtype == object:
        out = []
        for i in np.ndindex(x.shape):
            v = x[i]
            if isinstance(v, np.ndarray):
                out.append(v)
        return out
    return None

def _find_first_key(mat, candidates):
    for k in candidates:
        if k in mat:
            return k
    return None

def _extract_fs(mat, default=FS_FALLBACK):
    k = _find_first_key(mat, ["fs","Fs","srate","sfreq","fsamp","SamplingRate","sampling_rate"])
    if k and np.isscalar(mat[k]): return float(mat[k])
    return float(default)

def _extract_freqs(mat):
    k = _find_first_key(mat, ["freqs","FREQS","stimulus_frequencies","sfreqs","Freq","frequency","Frequencies"])
    if k is not None:
        try:
            v = np.array(mat[k], dtype=float).reshape(-1)
            if 2 <= v.size <= 128: return v
        except Exception:
            pass
    return None

def _extract_labels(mat):
    k = _find_first_key(mat, ["trainLabel","testLabel","label","labels","y","Y","stimulus","class","classes"])
    if k is not None:
        try:
            return np.array(mat[k]).reshape(-1)
        except Exception:
            return None
    return None

def _as_class_trial_ch_samp_from_3d(X3, labels=None, freqs=None):
    shp = X3.shape
    samp_ax = int(np.argmax(shp))
    ch_ax = None
    for ax, d in enumerate(shp):
        if d in (8,9,16,32,64) and ax != samp_ax:
            ch_ax = ax; break
    if ch_ax is None:
        order = np.argsort(shp)
        for ax in order[::-1]:
            if ax != samp_ax:
                ch_ax = int(ax); break
    axes = [0,1,2]; axes.remove(samp_ax); axes.remove(ch_ax)
    trial_ax = axes[0]
    X_tcs = np.moveaxis(X3, (trial_ax, ch_ax, samp_ax), (0,1,2))

    if labels is not None and labels.size == X_tcs.shape[0]:
        uniq = np.unique(labels)
        label_to_idx = {v:i for i,v in enumerate(uniq)}
        groups = {label_to_idx[v]: [] for v in uniq}
        for i in range(X_tcs.shape[0]):
            groups[label_to_idx[labels[i]]].append(X_tcs[i])
        per_class = [np.stack(groups[i], axis=0) for i in range(len(uniq))]
        minT = min(x.shape[0] for x in per_class)
        per_class = [x[:minT] for x in per_class]
        C = len(per_class); T = minT; Ch = X_tcs.shape[1]; S = X_tcs.shape[2]
        data4 = np.zeros((C, T, Ch, S), dtype=X_tcs.dtype)
        for c in range(C): data4[c] = per_class[c]
        return data4
    else:
        T_total = X_tcs.shape[0]
        if freqs is not None and freqs.size >= 2 and T_total % freqs.size == 0:
            C = freqs.size; T = T_total // C
            data4 = np.zeros((C, T, X_tcs.shape[1], X_tcs.shape[2]), dtype=X_tcs.dtype)
            idx = 0
            for c in range(C):
                data4[c] = X_tcs[idx:idx+T]; idx += T
            return data4
        for C in [4, 12, 40]:
            if T_total % C == 0:
                T = T_total // C
                data4 = np.zeros((C, T, X_tcs.shape[1], X_tcs.shape[2]), dtype=X_tcs.dtype)
                idx = 0
                for c in range(C):
                    data4[c] = X_tcs[idx:idx+T]; idx += T
                return data4
        raise RuntimeError("Cannot infer classes without labels/freqs for a 3-D tensor.")

def _from_cell_per_class(cell_list):
    class_tensors = []
    for arr in cell_list:
        if not isinstance(arr, np.ndarray) or arr.ndim < 3: continue
        perms = [(0,1,2),(0,2,1),(1,2,0),(1,0,2),(2,0,1),(2,1,0)]
        best = None
        for p in perms:
            A = np.transpose(arr, p)
            if A.shape[0] >= 1 and A.shape[1] in (8,9,16,32,64) and A.shape[2] >= 64:
                best = A; break
        if best is None: best = np.transpose(arr, (2,0,1))
        class_tensors.append(best)
    if not class_tensors: raise RuntimeError("Empty/invalid cell array.")
    minT = min(c.shape[0] for c in class_tensors)
    class_tensors = [c[:minT] for c in class_tensors]
    C = len(class_tensors); T = minT; Ch = class_tensors[0].shape[1]; S = class_tensors[0].shape[2]
    data4 = np.zeros((C, T, Ch, S), dtype=class_tensors[0].dtype)
    for c in range(C): data4[c] = class_tensors[c]
    return data4

def _parse_one_mat(mat):
    fs = _extract_fs(mat, default=FS_FALLBACK)
    freqs = _extract_freqs(mat)
    labels = _extract_labels(mat)

    for _, v in mat.items():
        vnp = _to_numpy_numeric(v)
        if vnp is not None and vnp.ndim == 4 and vnp.size > 10000:
            dims = list(vnp.shape)
            samp_ax = int(np.argmax(dims))
            ch_ax = None
            for ax, d in enumerate(dims):
                if d in (8,9,16,32,64) and ax != samp_ax: ch_ax = ax; break
            if ch_ax is None:
                rem = [0,1,2,3]; rem.remove(samp_ax)
                ch_ax = max(rem, key=lambda ax: dims[ax])
            rem = [0,1,2,3]; rem.remove(samp_ax); rem.remove(ch_ax)
            a1, a2 = rem
            d1, d2 = dims[a1], dims[a2]
            trial_ax, class_ax = (a1, a2) if d1 < d2 else (a2, a1)
            data4 = np.moveaxis(vnp, (class_ax, trial_ax, ch_ax, samp_ax), (0,1,2,3))
            return data4.astype(np.float64), fs, freqs

    for key in ["data","EEG","trainEEG","testEEG","X","eeg","signals","Data"]:
        if key in mat and isinstance(mat[key], np.ndarray) and mat[key].dtype == object:
            cell_list = _maybe_cell_to_list(mat[key])
            if cell_list:
                data4 = _from_cell_per_class(cell_list)
                return data4.astype(np.float64), fs, freqs

    for key in ["EEG","trainEEG","testEEG","data","X","eeg","signals","Data"]:
        if key in mat:
            vnp = _to_numpy_numeric(mat[key])
            if vnp is not None and vnp.ndim == 3:
                data4 = _as_class_trial_ch_samp_from_3d(vnp, labels=labels, freqs=freqs)
                return data4.astype(np.float64), fs, freqs

    class_arrays = []
    for _, v in mat.items():
        vnp = _to_numpy_numeric(v)
        if vnp is not None and vnp.ndim == 3:
            class_arrays.append(vnp)
    if class_arrays:
        normed = []
        for arr in class_arrays:
            if arr.shape[1] in (8,9,16,32,64): A = arr
            else: A = np.transpose(arr, (2,0,1))
            normed.append(A)
        minT = min(a.shape[0] for a in normed)
        normed = [a[:minT] for a in normed]
        C = len(normed); T = minT; Ch = normed[0].shape[1]; S = normed[0].shape[2]
        data4 = np.zeros((C, T, Ch, S), dtype=normed[0].dtype)
        for c in range(C): data4[c] = normed[c]
        return data4.astype(np.float64), fs, freqs

    return None, None, None

def load_subject_pair(train_path, test_path=None):
    mat_tr = sio.loadmat(train_path, squeeze_me=True, struct_as_record=False)
    Xtr, fs, freqs = _parse_one_mat(mat_tr)
    if Xtr is None: raise RuntimeError(f"No usable EEG tensor in {os.path.basename(train_path)}")

    if test_path is not None and os.path.exists(test_path):
        mat_te = sio.loadmat(test_path, squeeze_me=True, struct_as_record=False)
        Xte, fs2, freqs2 = _parse_one_mat(mat_te)
        if Xte is not None and Xte.shape[0:3] == Xtr.shape[0:3]:
            T_min = min(Xtr.shape[1], Xte.shape[1])
            X = np.concatenate([Xtr[:, :T_min], Xte[:, :T_min]], axis=1)
            fs = fs2 if fs2 is not None else fs
            if freqs is None: freqs = freqs2
            return X, fs, freqs
    return Xtr, fs, freqs

def load_all_subjects():
    root = None
    for r in DATA_ROOTS:
        if os.path.exists(r): root = r; break
    if root is None: raise FileNotFoundError("Attach the dataset under /kaggle/input first.")
    mats = sorted(glob.glob(os.path.join(root, "**", "*.mat"), recursive=True))
    if not mats: raise FileNotFoundError("No .mat files found.")

    train_files = [p for p in mats if "train" in os.path.basename(p).lower() and "eeg" in os.path.basename(p).lower()]
    test_files  = [p for p in mats if "test"  in os.path.basename(p).lower() and "eeg" in os.path.basename(p).lower()]
    def twin_for(p, pool):
        bn = os.path.basename(p)
        cand = bn.lower().replace("train", "test")
        for q in pool:
            if os.path.basename(q).lower() == cand: return q
        return None

    subjects, fs_list = [], []
    global_freqs = None

    if train_files:
        for tr in sorted(train_files):
            te = twin_for(tr, test_files)
            try:
                Xsub, fs, freqs = load_subject_pair(tr, te)
                subjects.append(Xsub); fs_list.append(fs)
                if global_freqs is None and freqs is not None: global_freqs = freqs
            except Exception as e:
                print(f"[Loader] Skipping {os.path.basename(tr)}: {e}")
    else:
        for m in mats:
            try:
                mat = sio.loadmat(m, squeeze_me=True, struct_as_record=False)
                Xsub, fs, freqs = _parse_one_mat(mat)
                if Xsub is not None:
                    subjects.append(Xsub); fs_list.append(fs)
                    if global_freqs is None and freqs is not None: global_freqs = freqs
                else:
                    print(f"[Loader] Skipping {os.path.basename(m)}: no usable tensor")
            except Exception as e:
                print(f"[Loader] Skipping {os.path.basename(m)}: {e}")

    if len(subjects) == 0:
        raise RuntimeError("No usable subjects parsed after robust loader.")

    C0, _, Ch0, S0 = subjects[0].shape
    minT = min(x.shape[1] for x in subjects)
    subjects = [x[:, :minT] for x in subjects]
    for x in subjects:
        assert x.shape[0] == C0 and x.shape[1] == minT and x.shape[2] == Ch0 and x.shape[3] == S0

    fs = int(round(float(np.median(fs_list)))) if fs_list else FS_FALLBACK
    X = np.stack(subjects, axis=0)  # (Subj, C, T, Ch, S)
    return X, fs, global_freqs

#  FFT bandpass (+ optional notch), TPU/CPU friendly

def _fft_bandpass_np(x, fs, lo, hi, notch=None):
    X = np.fft.rfft(x, axis=-1)
    freqs = np.fft.rfftfreq(x.shape[-1], d=1.0/fs)
    mask = (freqs >= lo) & (freqs <= hi)
    Y = np.zeros_like(X); Y[..., mask] = X[..., mask]
    if notch is not None:
        notch_bw = 0.5
        nmask = (freqs >= (notch - notch_bw)) & (freqs <= (notch + notch_bw))
        Y[..., nmask] = 0.0
    return np.fft.irfft(Y, n=x.shape[-1], axis=-1)

if USE_JAX:
    @jit
    def _fft_bandpass_jax(x, fs, lo, hi, notch):
        X = xp.fft.rfft(x, axis=-1)
        freqs = xp.fft.rfftfreq(x.shape[-1], d=1.0/fs)
        mask = (freqs >= lo) & (freqs <= hi)
        Y = xp.where(mask, X, 0.0 + 0.0j)
        if notch is not None and notch > 0.0:
            notch_bw = 0.5
            nmask = (freqs >= (notch - notch_bw)) & (freqs <= (notch + notch_bw))
            Y = xp.where(nmask, 0.0 + 0.0j, Y)
        return xp.fft.irfft(Y, n=x.shape[-1], axis=-1)

def fft_bandpass(x, fs, lo, hi, notch=None):
    return _fft_bandpass_jax(xp.asarray(x), fs, lo, hi, notch) if USE_JAX else _fft_bandpass_np(np.asarray(x), fs, lo, hi, notch)

# ============================================================
#  TRCA (spatial filter & scoring)
# ============================================================
def trca_spatial_filter(X_trials, reg=TRCA_REG, use_jax=USE_JAX):
    T, C, S = X_trials.shape
    if use_jax:
        XiXiT = xp.einsum("tcs,tks->ck", X_trials, X_trials)
        Q = XiXiT + reg * xp.eye(C)
        Xsum = xp.sum(X_trials, axis=0)
        Snum = xp.zeros((C, C))
        for t in range(T):
            Xi = X_trials[t]; Xrest = Xsum - Xi
            Snum = Snum + Xi @ Xrest.T
        Snum = 0.5 * (Snum + Snum.T)
        evals, evecs = xp.linalg.eigh(xp.linalg.pinv(Q) @ Snum)
        w = xp.real(evecs[:, xp.argmax(xp.real(evals))])
        w = w / (xp.linalg.norm(w) + EPS)
        return w
    else:
        XiXiT = np.einsum("tcs,tks->ck", X_trials, X_trials)
        Q = XiXiT + reg * np.eye(C)
        Xsum = np.sum(X_trials, axis=0)
        Snum = np.zeros((C, C))
        for t in range(T):
            Xi = X_trials[t]; Xrest = Xsum - Xi
            Snum += Xi @ Xrest.T
        Snum = 0.5 * (Snum + Snum.T)
        evals, evecs = np.linalg.eigh(np.linalg.pinv(Q) @ Snum)
        w = np.real(evecs[:, np.argmax(np.real(evals))])
        w = w / (np.linalg.norm(w) + EPS)
        return w

def trca_template(X_trials): return X_trials.mean(axis=0)

def corr_pearson(a, b):
    am = a - a.mean(); bm = b - b.mean()
    num = (am * bm).sum()
    den = xp.linalg.norm(am) * xp.linalg.norm(bm) + EPS
    return num / den

def trca_score_epoch(x_epoch, w, tpl):
    s  = w @ x_epoch
    st = w @ tpl
    return float(corr_pearson(s, st))

# ============================================================
#  CORAL (unsupervised domain alignment)
# ============================================================
def coral_fit(Xs, Xt, eps=1e-6, use_jax=USE_JAX):
    if use_jax:
        mu_s = xp.mean(Xs, axis=0); mu_t = xp.mean(Xt, axis=0)
        Xs0 = Xs - mu_s; Xt0 = Xt - mu_t
        Cs = (Xs0.T @ Xs0) / (Xs0.shape[0] - 1)
        Ct = (Xt0.T @ Xt0) / (Xt0.shape[0] - 1)
        evals_s, evecs_s = xp.linalg.eigh(Cs + eps*xp.eye(Cs.shape[0]))
        evals_t, evecs_t = xp.linalg.eigh(Ct + eps*xp.eye(Ct.shape[0]))
        Cs_inv_sqrt = evecs_s @ xp.diag(1.0/xp.sqrt(evals_s)) @ evecs_s.T
        Ct_sqrt     = evecs_t @ xp.diag(xp.sqrt(evals_t))     @ evecs_t.T
        A = Cs_inv_sqrt @ Ct_sqrt
        b = (mu_t - (mu_s @ A))
        return A, b
    else:
        mu_s = np.mean(Xs, axis=0); mu_t = np.mean(Xt, axis=0)
        Xs0 = Xs - mu_s; Xt0 = Xt - mu_t
        Cs = (Xs0.T @ Xs0) / (Xs0.shape[0] - 1)
        Ct = (Xt0.T @ Xt0) / (Xt0.shape[0] - 1)
        evals_s, evecs_s = np.linalg.eigh(Cs + eps*np.eye(Cs.shape[0]))
        evals_t, evecs_t = np.linalg.eigh(Ct + eps*np.eye(Ct.shape[0]))
        Cs_inv_sqrt = evecs_s @ np.diag(1.0/np.sqrt(evals_s)) @ evecs_s.T
        Ct_sqrt     = evecs_t @ np.diag(np.sqrt(evals_t))     @ evecs_t.T
        A = Cs_inv_sqrt @ Ct_sqrt
        b = (mu_t - (mu_s @ A))
        return A, b

def coral_apply(X, A, b): return X @ A + b

# -----------------------------------------------------------
#  FBCCA helpers (TPU-safe CCA: symmetric eig via eigh)
# -----------------------------------------------------------
def _build_ref_bank(class_freqs, fs, nsamp, harmonics=3, use_jax=USE_JAX):
    t = (xp.arange(nsamp)/fs) if use_jax else (np.arange(nsamp)/fs)
    bank = []
    for f in class_freqs:
        comps = []
        for h in range(1, harmonics+1):
            comps.append(xp.sin(2*xp.pi*h*f*t) if use_jax else np.sin(2*np.pi*h*f*t))
            comps.append(xp.cos(2*xp.pi*h*f*t) if use_jax else np.cos(2*np.pi*h*f*t))
        R = xp.stack(comps, axis=1) if use_jax else np.stack(comps, axis=1)
        bank.append(R)
    return bank

def _max_cca_corr(Xs, Ys, eps=1e-8, use_jax=USE_JAX):
    # Xs: (S, Dx), Ys: (S, Dy)
    if use_jax:
        X = Xs - xp.mean(Xs, axis=0)
        Y = Ys - xp.mean(Ys, axis=0)
        Sxx = X.T @ X + eps*xp.eye(X.shape[1])
        Syy = Y.T @ Y + eps*xp.eye(Y.shape[1])
        Sxy = X.T @ Y
        invSxx = xp.linalg.pinv(Sxx)
        invSyy = xp.linalg.pinv(Syy)
        M = invSxx @ Sxy @ invSyy @ Sxy.T
        M = 0.5 * (M + M.T)  # TPU-safe
        ev = xp.linalg.eigh(M)[0]
        rho2 = float(xp.max(xp.real(ev)))
        return float(np.sqrt(max(rho2, 0.0)))
    else:
        X = Xs - Xs.mean(axis=0)
        Y = Ys - Ys.mean(axis=0)
        Sxx = X.T @ X + eps*np.eye(X.shape[1])
        Syy = Y.T @ Y + eps*np.eye(Y.shape[1])
        Sxy = X.T @ Y
        invSxx = np.linalg.pinv(Sxx)
        invSyy = np.linalg.pinv(Syy)
        M = invSxx @ Sxy @ invSyy @ Sxy.T
        M = 0.5 * (M + M.T)
        ev = np.linalg.eigh(M)[0]
        rho2 = float(np.max(np.real(ev)))
        return float(np.sqrt(max(rho2, 0.0)))

def _estimate_class_freqs_from_training(train_by_c, fs, fmin=6.0, fmax=20.0):
    est = []
    for c in range(len(train_by_c)):
        Xc = np.concatenate(train_by_c[c], axis=0)
        tpl = Xc.mean(axis=0).mean(axis=0)
        S = tpl.shape[0]
        freqs = np.fft.rfftfreq(S, d=1.0/fs)
        P = np.abs(np.fft.rfft(tpl))**2
        band = (freqs >= fmin) & (freqs <= fmax)
        est.append(float(freqs[band][np.argmax(P[band])]) if np.any(band) else 10.0)
    return np.array(est, dtype=float)

def _fbcca_score_epoch(x_epoch, cand_c, ref_bank, fs, subbands, use_jax=USE_JAX):
    scores = []
    for (lo, hi) in subbands:
        x_f = fft_bandpass(x_epoch, fs, lo, hi, NOTCH)
        Xs = (x_f.T) if use_jax else x_f.T
        R  = ref_bank[cand_c]
        rho = _max_cca_corr(Xs, R, use_jax=use_jax)
        scores.append(rho)
    return scores

# ----------------------------------------------------------
#  Utilities
# ----------------------------------------------------------
def itr_bits_per_min(P, N, Tsec):
    if N <= 1 or P <= 0.0 or P >= 1.0: return 0.0
    return float((math.log2(N) + P*math.log2(P) + (1-P)*math.log2((1-P)/(N-1))) * (60.0/Tsec))

def plot_confmat(cm, classes, title, path):
    fig = plt.figure(figsize=(6,5))
    plt.imshow(cm, interpolation='nearest')
    plt.title(title)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.xticks(ticks=np.arange(len(classes)), labels=classes, rotation=45, ha='right')
    plt.yticks(ticks=np.arange(len(classes)), labels=classes)
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, f"{cm[i,j]}", ha="center", va="center")
    plt.tight_layout()
    fig.savefig(path, bbox_inches='tight', dpi=150)
    plt.close(fig)

def bar_with_ci(values, labels, mean_label, title, path):
    vals = np.asarray(values, dtype=float)
    fig = plt.figure(figsize=(8,4))
    x = np.arange(len(vals))
    plt.bar(x, vals)
    plt.xticks(x, labels, rotation=45, ha='right')
    plt.ylabel("Accuracy (%)")
    plt.title(title)
    mu = vals.mean(); se = vals.std(ddof=1)/np.sqrt(len(vals)); ci95 = 1.96*se
    plt.axhline(mu, linestyle='--')
    plt.text(len(vals)-0.5, mu+0.5, f"{mean_label}: {mu:.2f} ± {ci95:.2f}")
    plt.tight_layout()
    fig.savefig(path, bbox_inches='tight', dpi=150)
    plt.close(fig)

# -----------------------------------------------------------
#  Core evaluation (LOSO) with extras
# -----------------------------------------------------------
def evaluate_loso(
    X_all, fs, freqs=None, harmonics=3,
    use_coral=True,
    window_seconds=None,           # None = full length
    adaptive_fusion=True,          # learn TRCA/FBCCA weight on training subjects
    csv_prefix="ssvep_loso_pub"
):
    """
    Returns: dict with summary paths and per-subject arrays
    Saves:
      - summary CSV
      - per-subject confusion matrices (png)
      - accuracy bar plots + CI (png)
    """
    Subj, C, T, Ch, S_full = X_all.shape
    if window_seconds is None:
        S_use = S_full
    else:
        S_use = int(round(window_seconds * fs))
        S_use = max(64, min(S_use, S_full))

    M = len(FB_SUBBANDS)
    fb_w = xp.asarray([1.0/((m+1)**FB_WEIGHTS_POW) for m in range(M)])

    subjects = [f"S{i+1:02d}" for i in range(Subj)]
    classes  = [f"{i}" for i in range(C)]

    rows_summary = []
    acc_tr_list, acc_fb_list, acc_en_list = [], [], []
    confmats = {"TRCA": [], "FBCCA": [], "ENS": []}

    for s_te in range(Subj):
        # ---- train/test split
        idx_tr_np = np.array([i for i in range(Subj) if i != s_te], dtype=int)
        X_tr = xp.take(X_all, xp.asarray(idx_tr_np), axis=0) if USE_JAX else X_all[idx_tr_np]
        X_te = X_all[s_te]

        # ---- SAFE window cropping by slicing last axis (no in-place broadcasting)
        if S_use != S_full:
            X_tr_np = np.array(X_tr)[..., :S_use]   # shape: (Subj-1, C, T, Ch, S_use)
            X_te_np = np.array(X_te)[..., :S_use]   # shape: (C, T, Ch, S_use)
            X_tr = xp.asarray(X_tr_np) if USE_JAX else X_tr_np
            X_te = X_te_np
        else:
            # ensure host array for test subject for downstream ops
            X_te = np.array(X_te)

        S = S_use

        # group training by class (host arrays for templates)
        train_by_c = {c: [] for c in range(C)}
        for s in range(X_tr.shape[0]):
            for c in range(C):
                train_by_c[c].append(np.array(X_tr[s, c]))

        # TRCA filter-bank
        W = [[None for _ in range(M)] for __ in range(C)]
        TPL = [[None for _ in range(M)] for __ in range(C)]
        for c in range(C):
            Xc_all = np.concatenate(train_by_c[c], axis=0)   # (T_all, Ch, S)
            for m, (lo, hi) in enumerate(FB_SUBBANDS):
                Xc_f = fft_bandpass(Xc_all, fs, lo, hi, NOTCH)
                if USE_JAX: Xc_f = xp.asarray(Xc_f)
                w = trca_spatial_filter(Xc_f, reg=TRCA_REG, use_jax=USE_JAX)
                tpl = trca_template(Xc_f)
                W[c][m] = w; TPL[c][m] = tpl

        # Build CORAL feats for TRCA head
        src_X = []
        for c_true in range(C):
            Xc_src = np.concatenate(train_by_c[c_true], axis=0)
            for t_idx in range(Xc_src.shape[0]):
                x = Xc_src[t_idx]
                vec = []
                for m in range(M):
                    w, tpl = W[c_true][m], TPL[c_true][m]
                    x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)
                    if USE_JAX: x_f = xp.asarray(x_f)
                    vec.append(trca_score_epoch(x_f, w, tpl))
                for c_other in range(C):
                    if c_other == c_true: continue
                    for m in range(M):
                        w, tpl = W[c_other][m], TPL[c_other][m]
                        x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)
                        if USE_JAX: x_f = xp.asarray(x_f)
                        vec.append(trca_score_epoch(x_f, w, tpl))
                src_X.append(vec)
        src_X = xp.asarray(src_X)

        # Unlabeled target features (for CORAL fit)
        te_pairs = [(c, t) for c in range(C) for t in range(T)]
        te_feats = []
        for (c_label, t_idx) in te_pairs:
            x = np.array(X_te[c_label, t_idx])
            vec = []
            for m in range(M):
                w, tpl = W[c_label][m], TPL[c_label][m]
                x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)
                if USE_JAX: x_f = xp.asarray(x_f)
                vec.append(trca_score_epoch(x_f, w, tpl))
            for c_other in range(C):
                if c_other == c_label: continue
                for m in range(M):
                    w, tpl = W[c_other][m], TPL[c_other][m]
                    x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)
                    if USE_JAX: x_f = xp.asarray(x_f)
                    vec.append(trca_score_epoch(x_f, w, tpl))
            te_feats.append(vec)
        te_feats = xp.asarray(te_feats)

        if use_coral:
            A, b = coral_fit(src_X, te_feats, eps=1e-6, use_jax=USE_JAX)
        else:
            d = src_X.shape[1]
            A = xp.eye(d); b = xp.zeros((d,))

        if freqs is not None and len(freqs) == C:
            class_freqs = np.array(freqs, dtype=float)
        else:
            class_freqs = _estimate_class_freqs_from_training(train_by_c, fs)
        ref_bank = _build_ref_bank(class_freqs, fs, S, harmonics=harmonics, use_jax=USE_JAX)

        # ---- Adaptive fusion weight (on training subjects only)
        w_trca, w_fb = 1.0, 1.0
        if adaptive_fusion:
            acc_pairs = []
            for s_val_i in range(X_tr.shape[0]):
                X_val = np.array(X_tr[s_val_i])
                y_true_v, y_tr_v, y_fb_v = [], [], []
                for c in range(C):
                    for t_idx in range(T):
                        x = X_val[c, t_idx]
                        # TRCA(CORAL)
                        tr_scores = []
                        for c_hat in range(C):
                            vec = []
                            for m in range(M):
                                w, tpl = W[c_hat][m], TPL[c_hat][m]
                                x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)
                                if USE_JAX: x_f = xp.asarray(x_f)
                                vec.append(trca_score_epoch(x_f, w, tpl))
                            for c_other in range(C):
                                if c_other == c_hat: continue
                                for m in range(M):
                                    w, tpl = W[c_other][m], TPL[c_other][m]
                                    x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)
                                    if USE_JAX: x_f = xp.asarray(x_f)
                                    vec.append(trca_score_epoch(x_f, w, tpl))
                            v = xp.asarray(vec)[None,:]
                            v_al = coral_apply(v, A, b)[0]
                            v_block = v_al[:len(FB_SUBBANDS)]
                            tr_scores.append(float((v_block * fb_w).sum()))
                        # FBCCA
                        fb_scores = []
                        weights_np = 1.0/((np.arange(M)+1)**FB_WEIGHTS_POW)
                        for c_hat in range(C):
                            sb_scores = _fbcca_score_epoch(x, c_hat, ref_bank, fs, FB_SUBBANDS, use_jax=USE_JAX)
                            fb_scores.append(float(np.sum(np.asarray(sb_scores)*weights_np)))
                        y_true_v.append(c)
                        y_tr_v.append(int(np.argmax(tr_scores)))
                        y_fb_v.append(int(np.argmax(fb_scores)))
                acc_pairs.append( (np.mean(np.array(y_tr_v)==np.array(y_true_v)),
                                   np.mean(np.array(y_fb_v)==np.array(y_true_v))) )
            mean_tr, mean_fb = np.mean([a for a,b in acc_pairs]), np.mean([b for a,b in acc_pairs])
            w_trca = float(mean_tr + 1e-3); w_fb = float(mean_fb + 1e-3)

        # ---- Decode test subject
        y_true, y_trca, y_fbcca, y_ens = [], [], [], []
        cm_tr = np.zeros((C,C), dtype=int)
        cm_fb = np.zeros((C,C), dtype=int)
        cm_en = np.zeros((C,C), dtype=int)

        for c in range(C):
            for t_idx in range(T):
                x = np.array(X_te[c, t_idx])
                # TRCA(CORAL)
                trca_scores = []
                for c_hat in range(C):
                    vec = []
                    for m in range(M):
                        w, tpl = W[c_hat][m], TPL[c_hat][m]
                        x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)
                        if USE_JAX: x_f = xp.asarray(x_f)
                        vec.append(trca_score_epoch(x_f, w, tpl))
                    for c_other in range(C):
                        if c_other == c_hat: continue
                        for m in range(M):
                            w, tpl = W[c_other][m], TPL[c_other][m]
                            x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)
                            if USE_JAX: x_f = xp.asarray(x_f)
                            vec.append(trca_score_epoch(x_f, w, tpl))
                    v = xp.asarray(vec)[None,:]
                    v_al = coral_apply(v, A, b)[0]
                    v_block = v_al[:len(FB_SUBBANDS)]
                    trca_scores.append(float((v_block * fb_w).sum()))
                # FBCCA
                fbcca_scores = []
                weights_np = 1.0/((np.arange(M)+1)**FB_WEIGHTS_POW)
                for c_hat in range(C):
                    sb_scores = _fbcca_score_epoch(x, c_hat, ref_bank, fs, FB_SUBBANDS, use_jax=USE_JAX)
                    fbcca_scores.append(float(np.sum(np.asarray(sb_scores)*weights_np)))
                # Ensemble (z-norm per head across candidates; adaptive weight)
                tr = np.asarray(trca_scores); fb = np.asarray(fbcca_scores)
                tr_z = (tr - tr.mean()) / (tr.std() + 1e-6)
                fb_z = (fb - fb.mean()) / (fb.std() + 1e-6)
                ens = w_trca*tr_z + w_fb*fb_z

                yt = c; pr_tr = int(np.argmax(tr)); pr_fb = int(np.argmax(fb)); pr_en = int(np.argmax(ens))
                y_true.append(yt); y_trca.append(pr_tr); y_fbcca.append(pr_fb); y_ens.append(pr_en)
                cm_tr[yt, pr_tr] += 1; cm_fb[yt, pr_fb] += 1; cm_en[yt, pr_en] += 1

        acc_tr = 100.0 * float((np.asarray(y_trca) == np.asarray(y_true)).mean())
        acc_fb = 100.0 * float((np.asarray(y_fbcca) == np.asarray(y_true)).mean())
        acc_en = 100.0 * float((np.asarray(y_ens)  == np.asarray(y_true)).mean())
        itr_en = itr_bits_per_min(acc_en/100.0, C, S/fs)

        print(f"[LOSO] {subjects[s_te]}: TRCA(CORAL={use_coral})={acc_tr:.2f}% | FBCCA={acc_fb:.2f}% | ENS(adapt)={acc_en:.2f}% | ITR={itr_en:.2f}")

        rows_summary.append({
            "subject": subjects[s_te],
            "use_coral": use_coral,
            "window_sec": S/fs,
            "acc_trca": round(acc_tr,2),
            "acc_fbcca": round(acc_fb,2),
            "acc_ensemble": round(acc_en,2),
            "itr_ensemble": round(itr_en,2),
            "w_trca": round(w_trca,4),
            "w_fbcca": round(w_fb,4),
        })

        confmats["TRCA"].append(cm_tr)
        confmats["FBCCA"].append(cm_fb)
        confmats["ENS"].append(cm_en)

        # save confusion matrices for this subject
        plot_confmat(cm_tr, classes, f"{subjects[s_te]} TRCA (CORAL={use_coral})", os.path.join(OUTDIR, f"{csv_prefix}_{subjects[s_te]}_cm_trca.png"))
        plot_confmat(cm_fb, classes, f"{subjects[s_te]} FBCCA", os.path.join(OUTDIR, f"{csv_prefix}_{subjects[s_te]}_cm_fbcca.png"))
        plot_confmat(cm_en, classes, f"{subjects[s_te]} Ensemble", os.path.join(OUTDIR, f"{csv_prefix}_{subjects[s_te]}_cm_ens.png"))

        acc_tr_list.append(acc_tr); acc_fb_list.append(acc_fb); acc_en_list.append(acc_en)

    # summary CSV
    df_sum = pd.DataFrame(rows_summary)
    sum_path = os.path.join(OUTDIR, f"{csv_prefix}_summary.csv")
    df_sum.to_csv(sum_path, index=False)
    print("[Saved]", sum_path)

    # per-subject accuracy bars
    bar_with_ci(acc_tr_list, subjects, "TRCA mean±95%CI",
                f"TRCA Accuracies (CORAL={use_coral}, win={S/fs:.2f}s)",
                os.path.join(OUTDIR, f"{csv_prefix}_bars_trca.png"))
    bar_with_ci(acc_fb_list, subjects, "FBCCA mean±95%CI",
                f"FBCCA Accuracies (win={S/fs:.2f}s)",
                os.path.join(OUTDIR, f"{csv_prefix}_bars_fbcca.png"))
    bar_with_ci(acc_en_list, subjects, "Ensemble mean±95%CI",
                f"Ensemble Accuracies (adapt={adaptive_fusion}, win={S/fs:.2f}s)",
                os.path.join(OUTDIR, f"{csv_prefix}_bars_ensemble.png"))

    return {
        "summary_path": sum_path,
        "acc_tr": np.array(acc_tr_list),
        "acc_fb": np.array(acc_fb_list),
        "acc_en": np.array(acc_en_list),
        "confmats": confmats,
        "subjects": subjects,
        "C": C,
        "T": T,
        "fs": fs,
        "S": S,
    }

# -----------------------------------------------------------
#  Ablations and Window-Length Study + Paired Stats
# -----------------------------------------------------------
if __name__ == "__main__":
    # 1) Load & prefilter once (broad)
    X_np, fs, freqs = load_all_subjects()
    print(f"[Load] X shape = {X_np.shape} | fs={fs} | freqs={None if freqs is None else freqs.tolist()}")

    for s in range(X_np.shape[0]):
        for c in range(X_np.shape[1]):
            X_np[s, c] = fft_bandpass(X_np[s, c], fs, *BANDPASS, NOTCH)

    X_dev = xp.asarray(X_np) if USE_JAX else X_np

    # ---- Baseline (full window, CORAL on, adaptive fusion)
    res_full = evaluate_loso(
        X_dev, fs, freqs, harmonics=3,
        use_coral=True, window_seconds=None, adaptive_fusion=True,
        csv_prefix="ssvep_pub_full_coral_adapt"
    )

    # ---- Ablation: CORAL off
    res_nocoral = evaluate_loso(
        X_dev, fs, freqs, harmonics=3,
        use_coral=False, window_seconds=None, adaptive_fusion=True,
        csv_prefix="ssvep_pub_full_nocoral_adapt"
    )

    # ---- Window length sweep (1.0s, 2.0s, 3.0s, full)
    window_list = [1.0, 2.0, 3.0, X_np.shape[-1]/fs]
    sweep_rows = []
    for wsec in window_list:
        res_w = evaluate_loso(
            X_dev, fs, freqs, harmonics=3,
            use_coral=True, window_seconds=wsec, adaptive_fusion=True,
            csv_prefix=f"ssvep_pub_win{wsec:.1f}s"
        )
        sweep_rows.append({
            "window_sec": wsec,
            "acc_tr_mean": float(np.mean(res_w["acc_tr"])),
            "acc_fb_mean": float(np.mean(res_w["acc_fb"])),
            "acc_en_mean": float(np.mean(res_w["acc_en"])),
            "itr_en_mean": float(np.mean([itr_bits_per_min(a/100.0, res_w["C"], wsec) for a in res_w["acc_en"]]))
        })
    df_sweep = pd.DataFrame(sweep_rows)
    sweep_path = os.path.join(OUTDIR, "ssvep_pub_window_sweep.csv")
    df_sweep.to_csv(sweep_path, index=False)
    print("[Saved]", sweep_path)

    # Plot Acc/ITR vs window
    fig1 = plt.figure(figsize=(6,4))
    plt.plot(df_sweep["window_sec"], df_sweep["acc_en_mean"], marker='o')
    plt.xlabel("Window length (s)"); plt.ylabel("Ensemble Accuracy (%)")
    plt.title("Ensemble Accuracy vs Window")
    fig1.savefig(os.path.join(OUTDIR, "ssvep_pub_acc_vs_window.png"), bbox_inches='tight', dpi=150)
    plt.close(fig1)

    fig2 = plt.figure(figsize=(6,4))
    plt.plot(df_sweep["window_sec"], df_sweep["itr_en_mean"], marker='o')
    plt.xlabel("Window length (s)"); plt.ylabel("ITR (bits/min)")
    plt.title("ITR vs Window")
    fig2.savefig(os.path.join(OUTDIR, "ssvep_pub_itr_vs_window.png"), bbox_inches='tight', dpi=150)
    plt.close(fig2)

    # ---- Paired stats across subjects
    def paired_t(a, b):
        t, p = stats.ttest_rel(a, b)
        return float(t), float(p)

    # Heads comparisons (full window, CORAL on)
    t_tr_fb, p_tr_fb = paired_t(res_full["acc_tr"], res_full["acc_fb"])
    t_fb_en, p_fb_en = paired_t(res_full["acc_fb"], res_full["acc_en"])
    t_tr_en, p_tr_en = paired_t(res_full["acc_tr"], res_full["acc_en"])

    # CORAL vs no CORAL (ensemble)
    t_en_coral, p_en_coral = paired_t(res_full["acc_en"], res_nocoral["acc_en"])

    stats_rows = [
        {"comparison": "TRCA vs FBCCA (full, CORAL)", "t": t_tr_fb, "p": p_tr_fb},
        {"comparison": "FBCCA vs ENS (full, CORAL)", "t": t_fb_en, "p": p_fb_en},
        {"comparison": "TRCA vs ENS (full, CORAL)", "t": t_tr_en, "p": p_tr_en},
        {"comparison": "ENS (CORAL) vs ENS (noCORAL)", "t": t_en_coral, "p": p_en_coral},
    ]
    df_stats = pd.DataFrame(stats_rows)
    stats_path = os.path.join(OUTDIR, "ssvep_pub_paired_stats.csv")
    df_stats.to_csv(stats_path, index=False)
    print("[Saved]", stats_path)

    # ---- Final short report to console
    print("\n=== FINAL SUMMARY (for manuscript) ===")
    print("Full window (baseline):")
    print(f"  Mean TRCA  = {np.mean(res_full['acc_tr']):.2f}%")
    print(f"  Mean FBCCA = {np.mean(res_full['acc_fb']):.2f}%")
    print(f"  Mean ENS   = {np.mean(res_full['acc_en']):.2f}%")
    print("Paired t-tests (p-values):")
    for r in stats_rows:
        print(f"  {r['comparison']}: t={r['t']:.3f}, p={r['p']:.3g}")
    print("\nWindow sweep (CSV & PNGs saved).")
    print("Figures/CSVs in:", OUTDIR)


INFO:2025-08-21 11:32:22,309:jax._src.xla_bridge:924: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


INFO:2025-08-21 11:32:22,325:jax._src.xla_bridge:924: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


[Backend] JAX on cpu (TFRT_CPU_0)


[Versions] numpy 1.26.4
[Versions] jax 0.5.2 | jaxlib 0.5.1


[Load] X shape = (10, 12, 10, 8, 1025) | fs=250 | freqs=None


[LOSO] S01: TRCA(CORAL=True)=60.83% | FBCCA=69.17% | ENS(adapt)=74.17% | ITR=27.32


[LOSO] S02: TRCA(CORAL=True)=38.33% | FBCCA=40.00% | ENS(adapt)=44.17% | ITR=9.71


[LOSO] S03: TRCA(CORAL=True)=72.50% | FBCCA=80.83% | ENS(adapt)=81.67% | ITR=33.12


[LOSO] S04: TRCA(CORAL=True)=35.00% | FBCCA=32.50% | ENS(adapt)=38.33% | ITR=7.19


[LOSO] S05: TRCA(CORAL=True)=78.33% | FBCCA=81.67% | ENS(adapt)=83.33% | ITR=34.51


[LOSO] S06: TRCA(CORAL=True)=87.50% | FBCCA=89.17% | ENS(adapt)=90.00% | ITR=40.54


[LOSO] S07: TRCA(CORAL=True)=86.67% | FBCCA=65.00% | ENS(adapt)=89.17% | ITR=39.74


[LOSO] S08: TRCA(CORAL=True)=93.33% | FBCCA=90.83% | ENS(adapt)=92.50% | ITR=43.04


[LOSO] S09: TRCA(CORAL=True)=89.17% | FBCCA=81.67% | ENS(adapt)=89.17% | ITR=39.74


[LOSO] S10: TRCA(CORAL=True)=90.83% | FBCCA=93.33% | ENS(adapt)=91.67% | ITR=42.19


[Saved] /kaggle/working/ssvep_pub_full_coral_adapt_summary.csv


[LOSO] S01: TRCA(CORAL=False)=87.50% | FBCCA=69.17% | ENS(adapt)=91.67% | ITR=42.19


[LOSO] S02: TRCA(CORAL=False)=40.00% | FBCCA=40.00% | ENS(adapt)=46.67% | ITR=10.88


[LOSO] S03: TRCA(CORAL=False)=85.00% | FBCCA=80.83% | ENS(adapt)=85.00% | ITR=35.94


[LOSO] S04: TRCA(CORAL=False)=35.83% | FBCCA=32.50% | ENS(adapt)=40.83% | ITR=8.23


[LOSO] S05: TRCA(CORAL=False)=85.00% | FBCCA=81.67% | ENS(adapt)=85.83% | ITR=36.68


[LOSO] S06: TRCA(CORAL=False)=97.50% | FBCCA=89.17% | ENS(adapt)=95.83% | ITR=46.70


[LOSO] S07: TRCA(CORAL=False)=95.00% | FBCCA=65.00% | ENS(adapt)=93.33% | ITR=43.92


[LOSO] S08: TRCA(CORAL=False)=100.00% | FBCCA=90.83% | ENS(adapt)=99.17% | ITR=51.02


[LOSO] S09: TRCA(CORAL=False)=91.67% | FBCCA=81.67% | ENS(adapt)=90.00% | ITR=40.54


[LOSO] S10: TRCA(CORAL=False)=95.00% | FBCCA=93.33% | ENS(adapt)=95.83% | ITR=46.70


[Saved] /kaggle/working/ssvep_pub_full_nocoral_adapt_summary.csv


[LOSO] S01: TRCA(CORAL=True)=45.83% | FBCCA=11.67% | ENS(adapt)=46.67% | ITR=44.59


[LOSO] S02: TRCA(CORAL=True)=17.50% | FBCCA=5.83% | ENS(adapt)=17.50% | ITR=3.71


[LOSO] S03: TRCA(CORAL=True)=50.83% | FBCCA=17.50% | ENS(adapt)=52.50% | ITR=56.61


[LOSO] S04: TRCA(CORAL=True)=18.33% | FBCCA=8.33% | ENS(adapt)=16.67% | ITR=3.12


[LOSO] S05: TRCA(CORAL=True)=50.00% | FBCCA=15.83% | ENS(adapt)=49.17% | ITR=49.60


[LOSO] S06: TRCA(CORAL=True)=66.67% | FBCCA=17.50% | ENS(adapt)=64.17% | ITR=84.24


[LOSO] S07: TRCA(CORAL=True)=43.33% | FBCCA=13.33% | ENS(adapt)=42.50% | ITR=36.72


[LOSO] S08: TRCA(CORAL=True)=84.17% | FBCCA=13.33% | ENS(adapt)=80.00% | ITR=130.27


[LOSO] S09: TRCA(CORAL=True)=80.00% | FBCCA=15.00% | ENS(adapt)=76.67% | ITR=119.64


[LOSO] S10: TRCA(CORAL=True)=80.83% | FBCCA=25.83% | ENS(adapt)=83.33% | ITR=141.50


[Saved] /kaggle/working/ssvep_pub_win1.0s_summary.csv


[LOSO] S01: TRCA(CORAL=True)=55.00% | FBCCA=37.50% | ENS(adapt)=60.83% | ITR=37.92


[LOSO] S02: TRCA(CORAL=True)=24.17% | FBCCA=20.00% | ENS(adapt)=25.83% | ITR=5.85


[LOSO] S03: TRCA(CORAL=True)=73.33% | FBCCA=40.00% | ENS(adapt)=75.83% | ITR=58.53


[LOSO] S04: TRCA(CORAL=True)=25.83% | FBCCA=15.00% | ENS(adapt)=25.83% | ITR=5.85


[LOSO] S05: TRCA(CORAL=True)=74.17% | FBCCA=45.83% | ENS(adapt)=78.33% | ITR=62.44


[LOSO] S06: TRCA(CORAL=True)=80.83% | FBCCA=50.83% | ENS(adapt)=80.83% | ITR=66.51


[LOSO] S07: TRCA(CORAL=True)=66.67% | FBCCA=35.83% | ENS(adapt)=65.00% | ITR=43.20


[LOSO] S08: TRCA(CORAL=True)=89.17% | FBCCA=58.33% | ENS(adapt)=82.50% | ITR=69.32


[LOSO] S09: TRCA(CORAL=True)=85.00% | FBCCA=41.67% | ENS(adapt)=80.00% | ITR=65.13


[LOSO] S10: TRCA(CORAL=True)=84.17% | FBCCA=83.33% | ENS(adapt)=89.17% | ITR=81.46


[Saved] /kaggle/working/ssvep_pub_win2.0s_summary.csv


[LOSO] S01: TRCA(CORAL=True)=52.50% | FBCCA=46.67% | ENS(adapt)=63.33% | ITR=27.37


[LOSO] S02: TRCA(CORAL=True)=29.17% | FBCCA=29.17% | ENS(adapt)=31.67% | ITR=6.41


[LOSO] S03: TRCA(CORAL=True)=71.67% | FBCCA=60.83% | ENS(adapt)=75.00% | ITR=38.18


[LOSO] S04: TRCA(CORAL=True)=30.83% | FBCCA=15.83% | ENS(adapt)=27.50% | ITR=4.57


[LOSO] S05: TRCA(CORAL=True)=74.17% | FBCCA=61.67% | ENS(adapt)=80.83% | ITR=44.34


[LOSO] S06: TRCA(CORAL=True)=85.00% | FBCCA=61.67% | ENS(adapt)=86.67% | ITR=51.14
