In [None]:
import os
import numpy as np
from scipy.stats import pearsonr
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import check_cv
from voxelwise_tutorials.io import load_hdf5_array
from voxelwise_tutorials.delayer import Delayer
from voxelwise_tutorials.utils import generate_leave_one_run_out
from himalaya.backend import set_backend
from himalaya.kernel_ridge import MultipleKernelRidgeCV, Kernelizer, ColumnKernelizer
from himalaya.scoring import r2_score_split

# Select the himalaya compute backend. on_error="warn" asks for a warning
# instead of an exception when the CUDA/torch backend cannot be set up.
backend = set_backend("torch_cuda", on_error="warn")

# Root data directory and the feature spaces to model. The order of
# feature_names fixes the column order of the concatenated design matrix.
directory = '/Users/mariazimmermann/dropbox/encoding_model'
feature_names = ["AROUSAL", "VALENCE", "SOC", "TOM", "NO", "FACES"]

# Read the train and test matrix of every feature space as float32
Xs_train = []
Xs_test = []
for feature in feature_names:
    path = os.path.join(directory, "features_normalized", f"{feature}.hdf5")
    train_block = load_hdf5_array(path, key="x_train")
    Xs_train.append(train_block.astype("float32"))
    test_block = load_hdf5_array(path, key="x_test")
    Xs_test.append(test_block.astype("float32"))

# Width (number of columns) contributed by each feature space
n_features_list = [block.shape[1] for block in Xs_train]

# Stack all feature spaces side by side into one design matrix per split
X_train = np.concatenate(Xs_train, axis=1)
X_test = np.concatenate(Xs_test, axis=1)

# Column boundaries of each feature space inside the stacked matrices:
# start_and_end[i] is the first column of feature space i and
# start_and_end[i + 1] is one past its last column.
start_and_end = np.cumsum([0] + n_features_list)
slices = [slice(lo, hi) for lo, hi in zip(start_and_end, start_and_end[1:])]

# Model configuration. None of this depends on the subject, so it is built
# once here instead of on every iteration of the subject loop.
alphas = np.logspace(1, 20, 20)
solver_params = dict(
    n_iter=20,
    alphas=alphas,
    n_targets_batch=200,
    n_alphas_batch=5,
    n_targets_batch_refit=200,
)

# Preprocessing applied separately to each feature space: remove the mean
# (no rescaling), add delayed copies for delays 1-5, then build a linear kernel.
preprocess = make_pipeline(
    StandardScaler(with_mean=True, with_std=False),
    Delayer(delays=[1, 2, 3, 4, 5]),
    Kernelizer(kernel="linear"),
)

# One (name, transformer, column slice) entry per feature space
kernelizers = [(name, preprocess, slc) for name, slc in zip(feature_names, slices)]


def _columnwise_pearsonr(y_pred, y_true):
    """Return the Pearson correlation between matching columns of two arrays.

    Vectorized equivalent of calling ``scipy.stats.pearsonr`` on every
    column pair, which avoids one Python-level call per voxel.

    Parameters
    ----------
    y_pred, y_true : array-like of shape (n_samples, n_columns)
        Predicted and measured values; both must have the same shape.

    Returns
    -------
    r : numpy.ndarray of shape (n_columns,), dtype float64
        Correlation per column, clipped to [-1, 1]. NaN where either column
        is constant (the correlation is undefined) or contains NaN.

    Raises
    ------
    ValueError
        If the inputs are not 2D, differ in shape, or have fewer than two
        samples.
    """
    y_pred = np.asarray(y_pred, dtype=np.float64)
    y_true = np.asarray(y_true, dtype=np.float64)
    if y_pred.ndim != 2 or y_pred.shape != y_true.shape:
        raise ValueError(
            "y_pred and y_true must be 2D arrays of identical shape, got "
            f"{y_pred.shape} and {y_true.shape}"
        )
    if y_pred.shape[0] < 2:
        raise ValueError("at least two samples are required")

    # Flag constant columns explicitly: after centering, rounding can leave
    # tiny non-zero residuals that would otherwise produce a meaningless r.
    undefined = ((y_pred == y_pred[0]).all(axis=0)
                 | (y_true == y_true[0]).all(axis=0))

    pred_centered = y_pred - y_pred.mean(axis=0)
    true_centered = y_true - y_true.mean(axis=0)
    covariance = (pred_centered * true_centered).sum(axis=0)
    # Multiply the two norms (rather than taking one sqrt of the product)
    # to reduce the risk of overflow/underflow.
    norms = (np.sqrt((pred_centered ** 2).sum(axis=0))
             * np.sqrt((true_centered ** 2).sum(axis=0)))

    with np.errstate(divide="ignore", invalid="ignore"):
        r = covariance / norms
    r[undefined] = np.nan
    # Rounding can push |r| marginally above 1
    return np.clip(r, -1.0, 1.0)


# Run for each subject
for sub in range(1, 22):
    subject = f"sub-Htriplet{sub:02d}"
    print(f"\n=== Running subject {subject} ===")

    # Load responses
    resp_path = os.path.join(directory, "responses", f"{subject}_responses.hdf")
    Y_train = load_hdf5_array(resp_path, key="Y_train")
    Y_test = load_hdf5_array(resp_path, key="Y_test")
    run_onsets = load_hdf5_array(resp_path, key="run_onsets")

    # Cross-validation: splits are derived from this subject's run onsets,
    # so the splitter has to be rebuilt for every subject.
    cv = check_cv(generate_leave_one_run_out(X_train.shape[0], run_onsets))

    # Fresh estimators per subject so no fitted state carries over
    column_kernelizer = ColumnKernelizer(kernelizers)
    model = MultipleKernelRidgeCV(kernels="precomputed", solver="random_search",
                                  solver_params=solver_params, cv=cv)
    pipeline = make_pipeline(column_kernelizer, model)

    # Fit model
    pipeline.fit(X_train, Y_train)

    # === Full model prediction ===
    Y_pred = backend.to_numpy(pipeline.predict(X_test))
    Y_test_np = backend.to_numpy(Y_test)

    # NOTE(review): "r2" is the squared Pearson correlation (the sign of r is
    # lost), not a coefficient of determination. `r2_score_split` is imported
    # at the top of the file but never used - confirm which was intended.
    r_vals = _columnwise_pearsonr(Y_pred, Y_test_np)
    r2_vals = r_vals ** 2

    # Save full model scores
    np.savez(
        os.path.join(directory, f"full_model_scores_{subject}.npz"),
        r=r_vals,
        r2=r2_vals
    )

    # === Feature-specific prediction ===
    Y_pred_split = backend.to_numpy(pipeline.predict(X_test, split=True))  # (n_features, n_samples, n_voxels)

    # Correlate each feature space's partial prediction with the measured
    # responses, one feature space at a time to bound temporary memory.
    split_r = {
        name: _columnwise_pearsonr(Y_pred_split[i], Y_test_np)
        for i, name in enumerate(feature_names)
    }
    split_r2 = {name: r_feature ** 2 for name, r_feature in split_r.items()}

    # Save per-feature r and r²
    np.savez(os.path.join(directory, f"split_r_{subject}.npz"), **split_r)
    np.savez(os.path.join(directory, f"split_r2_{subject}.npz"), **split_r2)

    print(f"✓ Saved: full_model_scores_{subject}.npz, split_r_*, split_r2_*")
