In [None]:
! pip install cebra 

In [None]:
from pathlib import Path
import mne
import numpy as np

# Input folder holding the EDF recordings, and the folder all derived arrays are written to.
RAW = Path("/kaggle/input") / "no-processing-run-nt9-10"
OUTPUT_DIR = Path("/kaggle/working") / "cebra_outputs"
OUTPUT_DIR.mkdir(exist_ok=True)

# Pair name -> (nt9 recording, nt10 recording); the names encode who listens/speaks.
PAIRINGS = dict([
    ("lst9-spk10", ("nt9_listen.edf", "nt10_speak.edf")),
    ("spk9-lst10", ("nt9_speak.edf", "nt10_listen.edf")),
])

def load_eeg(path):
    """Read an EDF file with MNE and return the data of its EEG channels.

    The whole recording is loaded into memory (``preload=True``) and MNE's logging is
    silenced. The returned array is laid out as (channels, timepoints).
    """
    recording = mne.io.read_raw_edf(path, preload=True, verbose=False)
    eeg_data = recording.get_data(picks="eeg")
    return eeg_data

def align_lengths(a, b):
    """Crop two (channels, timepoints) arrays to their common (shortest) number of timepoints.

    Channel counts may differ; only the second axis is truncated. Returns a tuple ``(a, b)``
    of views into the inputs.
    """
    n_common = min(a.shape[1], b.shape[1])
    return tuple(arr[:, :n_common] for arr in (a, b))

def minmax_per_channel(x):
    """Rescale every row (channel) of a 2-D array independently to the range [0, 1].

    A channel with zero range (constant values) is divided by 1 instead of 0, so it maps
    to all zeros rather than NaN.
    """
    lo = np.min(x, axis=1, keepdims=True)
    span = np.max(x, axis=1, keepdims=True) - lo
    safe_span = np.where(span == 0, 1, span)
    return (x - lo) / safe_span

def _activity_label(filename):
    """Return the activity label encoded in an EEG file name: 1 for "speak", 0 for "listen".

    Raises:
        ValueError: if the name contains neither keyword, or both.
    """
    stem = Path(filename).stem.lower()
    is_speak, is_listen = "speak" in stem, "listen" in stem
    if is_speak == is_listen:
        raise ValueError(f"Cannot infer speak/listen activity from file name {filename!r}")
    return 1 if is_speak else 0


for pair_name, (file1, file2) in PAIRINGS.items():
    # Resolve labels first so a badly named file fails before the expensive EDF loads.
    label1, label2 = _activity_label(file1), _activity_label(file2)

    # Load EEG from both subjects and crop them to their common duration.
    A = load_eeg(RAW / file1)
    B = load_eeg(RAW / file2)
    A, B = align_lengths(A, B)
    stacked = np.vstack([A, B])  # shape: (channels_A + channels_B, timepoints); A's rows first

    # Normalize channels independently.
    # NOTE(review): min/max are computed over the whole recording, including the timepoints
    # later held out for validation -- consider fitting the scaling on training data only.
    normalized = minmax_per_channel(stacked)

    # Save normalized EEG data
    npy_path = OUTPUT_DIR / f"{pair_name}_normalized.npy"
    np.save(npy_path, normalized)
    print(f"✓ Saved: {npy_path.name}  {normalized.shape}")

    # One label per channel (listen -> 0, speak -> 1). Use each subject's own channel count
    # so the labels stay aligned with `stacked` even if the recordings differ in channels.
    channel_labels = np.concatenate([
        np.full(A.shape[0], label1, dtype=np.int64),
        np.full(B.shape[0], label2, dtype=np.int64),
    ])

    # Expand channel-wise labels to all timepoints: shape (channels, timepoints)
    n_time = stacked.shape[1]
    labels_2d = np.repeat(channel_labels[:, np.newaxis], n_time, axis=1)
    assert labels_2d.shape == normalized.shape, (labels_2d.shape, normalized.shape)

    # Save labels as 2D array (channels, timepoints)
    np.save(OUTPUT_DIR / f"{pair_name}_activity_labels.npy", labels_2d)
    print(f"✓ Saved expanded activity labels for {pair_name} with shape {labels_2d.shape}")



In [None]:
# TODO:
# - Decide which portion of the data to use for decoding (pos/neg): split it into halves
#   and take the second third of each part.
#
# - Tune the number of neighbors.
# - Check whether that helps.

In [None]:
from cebra import CEBRA, KNNDecoder, plot_embedding, plot_loss
from cebra.integrations.sklearn import metrics as cmetrics
import torch
import numpy as np
import matplotlib.pyplot as plt
import json
from pathlib import Path
import pandas as pd
from sklearn.model_selection import train_test_split

# Folder with the preprocessed arrays; every training artefact is written here too.
OUTPUT_DIR = Path("/kaggle/working") / "cebra_outputs"
OUTPUT_DIR.mkdir(exist_ok=True)

# Keyword arguments passed to CEBRA(...) for every run.
# time_offsets=10 goes with the "offset10-model" architecture, and device="cuda:0" requires a GPU.
MODEL_KWARGS = {
    "model_architecture": "offset10-model",
    "batch_size": 512,
    "learning_rate": 3e-4,
    "temperature": 1.12,
    "max_iterations": 5000,
    "conditional": "time_delta",
    "output_dimension": 3,
    "distance": "cosine",
    "device": "cuda:0",
    "verbose": True,
    "time_offsets": 10,
}

# Intended cap on samples used for decoding.
# NOTE(review): not referenced by run_cebra as written -- decoding uses all training samples.
MAX_TRAIN_SAMPLES = 2_000_000


def run_cebra(pair_name, scale, run_id, n_neighbors=50000):
    """Train one CEBRA model for a subject pairing, decode the activity labels, and save artefacts.

    Loads ``{pair_name}_normalized.npy`` and ``{pair_name}_activity_labels.npy`` from
    ``OUTPUT_DIR``; both were saved as (channels, timepoints).

    Data flow:

    * Flattens them time-major into one scalar sample per (timepoint, channel).
    * Splits timepoints without shuffling: the first 80% train, the last 20% validate.
    * Fits CEBRA, then fits a KNN decoder on the training embedding.

    Written to ``OUTPUT_DIR`` (prefix ``{pair_name}_{scale}``):

    * ``_run{run_id}.pt`` -- the fitted model
    * ``_emb_train_run{run_id}.npy`` / ``_emb_val_run{run_id}.npy`` -- embeddings
    * ``_loss_run{run_id}.png`` / ``_embedding_run{run_id}.png`` -- plots
    * ``_metrics_run{run_id}.json`` -- validation accuracy, R², goodness of fit

    Args:
        pair_name: Pairing whose preprocessed files are loaded.
        scale: Free-form tag used only in output file names.
        run_id: Run index; also used as the torch random seed.
        n_neighbors: Neighbours for the KNN decoder (default: the previous fixed 50000).

    NOTE(review): each row of ``X`` is a single scalar, and consecutive rows are different
    channels of the same timepoint. CEBRA's ``time_delta`` / offset sampling therefore mixes
    channel and time neighbours -- confirm this is the intended sample definition.
    """
    from sklearn.metrics import r2_score

    torch.manual_seed(run_id)
    model = CEBRA(**MODEL_KWARGS)

    # Load data
    normalized = np.load(OUTPUT_DIR / f"{pair_name}_normalized.npy")  # (channels, timepoints)
    activity_labels_2d = np.load(OUTPUT_DIR / f"{pair_name}_activity_labels.npy")  # (channels, timepoints)
    channels, timepoints = normalized.shape

    # Time-major flattening: row t * channels + c holds channel c at timepoint t.
    X = normalized.T.reshape(-1, 1)
    # Integer labels so the KNN decoder treats them as discrete classes.
    Y = activity_labels_2d.T.reshape(-1).astype(np.int64)

    # Train/test split over timepoints (contiguous: shuffle=False).
    train_tp, test_tp = train_test_split(
        np.arange(timepoints), test_size=0.2, random_state=42, shuffle=False
    )
    # Vectorised form of hstack([arange(channels) + t * channels for t in tps]):
    # expand each selected timepoint into the flat row indices of all its channels.
    channel_offsets = np.arange(channels)
    train_idx = (train_tp[:, np.newaxis] * channels + channel_offsets).ravel()
    test_idx = (test_tp[:, np.newaxis] * channels + channel_offsets).ravel()

    X_train, X_test = X[train_idx], X[test_idx]
    Y_train, Y_test = Y[train_idx], Y[test_idx]

    # Before training
    print("💡 [Before Training]")
    print("  X shape:", X_train.shape)  # (n_train_timepoints * channels, 1)
    print("  y shape:", Y_train.shape)  # (n_train_timepoints * channels,)
    print("  X dtype:", X_train.dtype)
    print("  y dtype:", Y_train.dtype)
    print("  Unique y labels:", np.unique(Y_train))
    print()

    # Train CEBRA
    model.fit(X_train, Y_train)
    emb_train = model.transform(X_train)
    emb_val = model.transform(X_test)
    print('cebra trained')
    print("  Embedding shape:", emb_train.shape)  # (n_train_samples, output_dimension)

    # Save model + embeddings
    stem = f"{pair_name}_{scale}"
    model.save(OUTPUT_DIR / f"{stem}_run{run_id}.pt")
    np.save(OUTPUT_DIR / f"{stem}_emb_train_run{run_id}.npy", emb_train)
    np.save(OUTPUT_DIR / f"{stem}_emb_val_run{run_id}.npy", emb_val)

    # Save plots (closing each figure to avoid accumulating memory across runs)
    ax = plot_loss(model)
    fig = ax.get_figure()
    fig.savefig(OUTPUT_DIR / f"{stem}_loss_run{run_id}.png")
    plt.close(fig)

    ax = plot_embedding(emb_val, embedding_labels=Y_test)
    fig = ax.get_figure()
    fig.savefig(OUTPUT_DIR / f"{stem}_embedding_run{run_id}.png")
    plt.close(fig)
    print('plots done')

    print('decoding started')
    decoder = KNNDecoder(n_neighbors=n_neighbors)
    decoder.fit(emb_train, Y_train)
    # KNNDecoder.score() returns R², not accuracy, and would redo the costly neighbour
    # search; predict once and derive both metrics from the same predictions.
    Y_pred = np.asarray(decoder.predict(emb_val)).reshape(-1)
    accuracy = float(np.mean(Y_pred == Y_test))
    r2 = float(r2_score(Y_test, Y_pred))
    print('decoding finished')

    # Goodness of fit
    gof = cmetrics.goodness_of_fit_history(model)

    # Save metrics ("train_accuracy" is not computed here)
    results = {
        "pair": pair_name,
        "scale": scale,
        "run_id": run_id,
        "train_samples": len(X_train),
        "val_samples": len(X_test),
        "decoding_accuracy": accuracy,
        "decoding_r2": r2,
        "decoder_n_neighbors": n_neighbors,
        "goodness_of_fit": np.asarray(gof).tolist(),
    }
    with open(OUTPUT_DIR / f"{stem}_metrics_run{run_id}.json", "w") as f:
        json.dump(results, f, indent=2)

    print(f"✓ Done: {pair_name} [{scale}] run {run_id}")


def _to_jsonable(obj):
    """Recursively convert numpy arrays/scalars, tuples and dicts into JSON-serialisable objects."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {str(key): _to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(value) for value in obj]
    return obj


# Run multiple experiments: pair -> {scale tag -> run ids (each run id is also the torch seed)}
TO_RUN = {
    "lst9-spk10": {"normalized": list(range(5))},
    "spk9-lst10": {"normalized": list(range(5))},
}
SUMMARY_CSV = OUTPUT_DIR / "decoder_accuracy_summary.csv"

summary_rows = []

for pair_name, scales in TO_RUN.items():
    # Labels depend only on the pair. Flatten them time-major, exactly as run_cebra does.
    activity_labels = np.load(OUTPUT_DIR / f"{pair_name}_activity_labels.npy").astype(np.int64)
    flat_labels = activity_labels.T.reshape(-1)

    for scale, run_ids in scales.items():
        embeddings = []
        labels = []

        for run_id in run_ids:
            run_cebra(pair_name, scale, run_id)

            # Validation embedding and its matching labels
            emb_val = np.load(OUTPUT_DIR / f"{pair_name}_{scale}_emb_val_run{run_id}.npy")
            n_samples = emb_val.shape[0]
            if n_samples > flat_labels.size:
                raise ValueError(
                    f"{pair_name} run {run_id}: {n_samples} validation samples but only "
                    f"{flat_labels.size} labels"
                )
            # run_cebra splits timepoints with shuffle=False, so the validation samples are the
            # trailing n_samples entries of the time-major label vector. (The previous
            # np.tile(..., int(n_samples / size)) tiled 0 times and produced empty labels.)
            labels.append(flat_labels[flat_labels.size - n_samples:])
            embeddings.append(emb_val)

            # Load accuracy
            with open(OUTPUT_DIR / f"{pair_name}_{scale}_metrics_run{run_id}.json") as f:
                acc_data = json.load(f)
            summary_rows.append({
                "pair": pair_name,
                "scale": scale,
                "run": run_id,
                # run_cebra does not write "train_accuracy", so this column stays NaN.
                "acc_train": acc_data.get("train_accuracy", np.nan),
                "acc_val": acc_data.get("decoding_accuracy", np.nan),
                "r2_val": acc_data.get("decoding_r2", np.nan),
            })

        # Persist what has finished so far, so a failure in the metrics below loses no runs.
        pd.DataFrame(summary_rows).to_csv(SUMMARY_CSV, index=False)

        # Metrics across runs. consistency_score returns a tuple of numpy arrays
        # (scores, pairs, ids), which json cannot serialise directly.
        consistency = cmetrics.consistency_score(embeddings, between="runs")
        with open(OUTPUT_DIR / f"{pair_name}_{scale}_consistency.json", "w") as f:
            json.dump({"consistency_score": _to_jsonable(consistency)}, f)

        # NOTE(review): intra_inter_class_distance is not part of the documented CEBRA sklearn
        # metrics API as far as I can verify -- confirm it exists in the installed version.
        dist_scores = cmetrics.intra_inter_class_distance(embeddings, labels)
        with open(OUTPUT_DIR / f"{pair_name}_{scale}_class_distance.json", "w") as f:
            json.dump(_to_jsonable(dist_scores), f)

# Save CSV summary
pd.DataFrame(summary_rows).to_csv(SUMMARY_CSV, index=False)

In [None]:
# Use 1620 neighbors, i.e. the square root of 2,626,000 (√2,626,000 ≈ 1620).

# Later, implement this in code.
# If it is fast enough, try it on the whole dataset without cutting the data.