In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from functools import reduce

from athletes_brain.fig2.config import (
    OUTPUT_DIR,
    GROUP_NAMES,
    REGION_COL,
    METRICS,
    MODEL_NAME,
    DEMOGRAPHIC_COLS,
)
from athletes_brain.fig2.data_loader import (
    load_parcels,
    load_and_preprocess_metric_data,
    find_common_sessions,
    filter_by_common_sessions,
)
from athletes_brain.fig2.preprocessing import long_to_wide
from athletes_brain.fig2.model_training import (
    train_base_models,
    train_stacked_base_models,
    train_final_stacked_model,
)
from athletes_brain.fig2.utils import save_results, save_predictions, save_model

[32m2025-06-30 11:13:15.295[0m | [1mINFO    [0m | [36mathletes_brain.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /home/galkepler/Projects/athletes_brain[0m


In [2]:
# ── Global visualisation configuration ──────────────────────────────────────

# 1.  General Matplotlib defaults
# ── Global visualisation configuration ──────────────────────────────────────
import matplotlib as mpl
import seaborn as sns

mpl.rcParams.update(
    {
        # ── Canvas size & resolution ───────────────────────────────────────────
        # Default figure size: 12×8 inches  →  4800×3200 px when exported at 400 dpi
        "figure.figsize": (12, 8),
        "figure.dpi": 200,  # crisp in-notebook / retina preview
        "savefig.dpi": 400,  # print-quality PNG/PDF
        # ── Fonts ──────────────────────────────────────────────────────────────
        "font.family": "sans-serif",
        "font.sans-serif": ["Roboto", "DejaVu Sans", "Arial"],
        "axes.titlesize": 24,
        # "axes.titleweight": "bold",
        "axes.labelsize": 24,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
        "legend.fontsize": 20,
        # ── Axis & spine aesthetics ────────────────────────────────────────────
        "axes.spines.top": False,
        "axes.spines.right": False,
        "axes.spines.left": True,
        "axes.spines.bottom": True,
        "axes.linewidth": 1,
        "axes.grid": True,
        "grid.color": "#E6E6E6",
        "grid.linewidth": 0.4,
        "grid.alpha": 0.8,
        # ── Colour cycle (colour-blind-safe) ───────────────────────────────────
        "axes.prop_cycle": mpl.cycler(color=sns.color_palette("Set2")),
        # ── Figure background ─────────────────────────────────────────────────
        "figure.facecolor": "white",
    }
)

# Seaborn theme inherits the rcParams above
sns.set_theme(context="talk", style="whitegrid", palette="Set2")


# 2.  Seaborn theme (inherits Matplotlib rcParams)
sns.set_theme(
    context="talk",  # slightly larger fonts for presentations / papers
    style="whitegrid",  # grid only on y-axis (good for histograms)
    palette="Set2",  # matches the rcParams colour cycle
)


# 3.  Helper function for consistent figure export
def savefig_nice(fig, filename, *, tight=True, dpi=300, **savefig_kwargs):
    """Save figure with tight layout and correct DPI."""
    if tight:
        fig.tight_layout()
    fig.savefig(filename, dpi=dpi, bbox_inches="tight", transparent=True, **savefig_kwargs)


# 4.  Colour constants for this project (optional convenience)
COL_RAW = "#1f77b4"  # e.g. unweighted sample
COL_WEIGHTED = "#d62728"  # weighted sample
COL_REF = "0.35"  # census reference (neutral grey)

In [4]:
ATLAS = "schaefer2018tian2020_400_7"
region_col = "index"
# Load important files
DATA_DIR = Path("/home/galkepler/Projects/athletes_brain/data")

# Output directory for figures
OUTPUT_DIR = Path("/home/galkepler/Projects/athletes_brain/figures/fig2")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

parcels = load_parcels()
raw_metric_data = load_and_preprocess_metric_data()

  df = pd.read_csv(DATA_DIR / "processed" / f"{metric}.csv", index_col=0).reset_index(
  df = pd.read_csv(DATA_DIR / "processed" / f"{metric}.csv", index_col=0).reset_index(
  df = pd.read_csv(DATA_DIR / "processed" / f"{metric}.csv", index_col=0).reset_index(
  df = pd.read_csv(DATA_DIR / "processed" / f"{metric}.csv", index_col=0).reset_index(
  df = pd.read_csv(DATA_DIR / "processed" / f"{metric}.csv", index_col=0).reset_index(
  df = pd.read_csv(DATA_DIR / "processed" / f"{metric}.csv", index_col=0).reset_index(
  df = pd.read_csv(DATA_DIR / "processed" / f"{metric}.csv", index_col=0).reset_index(


In [5]:
# 2. Convert to Wide Format and Find Common Sessions
data_wide = {}
for metric, df in raw_metric_data.items():
    data_wide[metric] = long_to_wide(
        df, columns_to_pivot=REGION_COL, demographic_cols=DEMOGRAPHIC_COLS
    )
common_sessions = find_common_sessions(data_wide)
data_wide = filter_by_common_sessions(data_wide, common_sessions)

In [22]:
import numpy as np
import pandas as pd
from typing import Union
from sklearn.base import BaseEstimator
from sklearn.model_selection import BaseCrossValidator
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline  # Make sure this is imported for PolynomialFeatures usage


def cross_val_predict_with_bias_correction(
    model: BaseEstimator,
    X: Union[np.ndarray, pd.DataFrame],
    y_chronological: np.ndarray,  # y in paper's notation
    cv: BaseCrossValidator,
    *,
    post_hoc_degree: int = 1,  # 0 = no correction, 1 = linear de Lange/Beheshti, >1 = polynomial residual correction
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Cross-validated predictions with optional post-hoc bias-correction.

    Parameters
    ----------
    model            : any scikit-learn regressor or pipeline
    X                : shape (n_samples, n_features)
    y_chronological  : shape (n_samples,) – chronological age (y in paper's notation)
    w                : sample weights, shape (n_samples,)
    cv               : cross-validator providing (train_idx, test_idx)
    use_weights      : if False, sample weights are ignored
    post_hoc_degree  : int
                       0 = no bias correction applied.
                       1 = linear de Lange et al. / Beheshti-style correction (fit x=a*y+b, correct (x-b)/a).
                       >1 = polynomial residual correction (fit residuals as poly of predicted_age, add to prediction).

    Returns
    -------
    y_pred_corr      : bias-corrected out-of-fold predictions
    original_residuals : uncorrected residuals (chronological_age - original_predicted_age)
    corrected_residuals : corrected residuals (chronological_age - corrected_predicted_age)
    """

    # Ensure array-like indexing works
    X_arr = X.values if isinstance(X, pd.DataFrame) else X
    y_chronological = np.asarray(y_chronological)

    y_oof_corrected = np.full_like(y_chronological, np.nan, dtype=float)
    original_residuals = np.full_like(y_chronological, np.nan, dtype=float)
    corrected_residuals = np.full_like(y_chronological, np.nan, dtype=float)

    for all_train_idx, test_idx in cv.split(X_arr, y_chronological):
        # Nested split for training base model and fitting bias correction model
        train_idx, val_idx = train_test_split(all_train_idx, test_size=0.2, random_state=42)

        X_tr, X_te, X_val = X_arr[train_idx], X_arr[test_idx], X_arr[val_idx]
        y_tr, y_te, y_val = (
            y_chronological[train_idx],
            y_chronological[test_idx],
            y_chronological[val_idx],
        )
        X_tr = pd.DataFrame(X_tr, columns=X.columns) if isinstance(X, pd.DataFrame) else X_tr
        X_te = pd.DataFrame(X_te, columns=X.columns) if isinstance(X, pd.DataFrame) else X_te
        X_val = pd.DataFrame(X_val, columns=X.columns) if isinstance(X, pd.DataFrame) else X_val
        # ---------------- Fit base model -----------------------------
        model.fit(X_tr, y_tr)

        # ---------------- Predict on test and validation folds -----------------------
        y_pred_te_original = model.predict(X_te)
        y_pred_val = model.predict(X_val)

        # ---------------- Post-hoc bias correction (conditional on post_hoc_degree) ---------------
        if post_hoc_degree > 0:
            if post_hoc_degree == 1:
                # --- Linear de Lange et al. / Beheshti-style correction ---
                # Fit a linear model: x = a*y + b
                # Where x is predicted_age (y_pred_val) and y is chronological_age (y_val)
                bias_corrector = LinearRegression()
                # fit_kwargs_bias = {"sample_weight": w_val} if use_weights else {}
                fit_kwargs_bias = {}
                # Fit: y_pred_val (dependent) on y_val (independent)
                bias_corrector.fit(y_val.reshape(-1, 1), y_pred_val, **fit_kwargs_bias)

                a_coeff = bias_corrector.coef_[0]  # This is 'a' from x = a*y + b
                b_intercept = bias_corrector.intercept_  # This is 'b' from x = a*y + b

                # Apply correction: x_corrected = (x - b) / a
                if a_coeff != 0:  # Avoid division by zero
                    y_pred_te_corrected = (y_pred_te_original - b_intercept) / a_coeff
                else:
                    y_pred_te_corrected = (
                        y_pred_te_original  # No effective correction if no linear relationship
                    )
            else:  # post_hoc_degree > 1
                # --- Polynomial Residual Correction (your original approach) ---
                # Fit residuals (y_val - y_pred_val) as a polynomial function of predicted_age (y_pred_val)
                poly = PolynomialFeatures(degree=post_hoc_degree)
                lin = LinearRegression()

                # Create a pipeline for the residual model: poly features of predicted age -> linear regression
                resid_model_pipeline = Pipeline([("poly", poly), ("lin", lin)])

                # Fit the residual model on validation data
                # Independent variable: y_pred_val (predicted age on validation set)
                # Dependent variable: y_val - y_pred_val (residuals on validation set)
                # fit_kwargs_resid_model = {"lin__sample_weight": w_val} if use_weights else {}
                fit_kwargs_resid_model = {}
                resid_model_pipeline.fit(
                    y_pred_val.reshape(-1, 1), y_val - y_pred_val, **fit_kwargs_resid_model
                )

                # Predict the bias component for the test set using its original predictions
                bias_component_te = resid_model_pipeline.predict(y_pred_te_original.reshape(-1, 1))

                # Add the predicted bias component to the original predictions
                y_pred_te_corrected = y_pred_te_original + bias_component_te
        else:  # post_hoc_degree == 0, no correction
            y_pred_te_corrected = y_pred_te_original

        # Store results
        y_oof_corrected[test_idx] = y_pred_te_corrected
        original_residuals[test_idx] = y_te - y_pred_te_original
        corrected_residuals[test_idx] = y_te - y_pred_te_corrected

    return y_oof_corrected, original_residuals, corrected_residuals

In [12]:
# polynomial features
from xgboost import XGBRegressor

# ---------------------------------------------------------------------
# 2. Parcel-wise base-learner loop
# ---------------------------------------------------------------------
stacked_models = parcels.copy()

predictions = {}

stacked_estimators = {}
predictions["base_stacked"] = {}

In [13]:
tmp_df_template = pd.DataFrame(
    index=list(common_sessions),
    columns=["subject_code", "age_at_scan", "sex", "group", "target"] + list(data_wide.keys()),
)

# Populate the template with demographic/target info from the first metric
first_metric_df = data_wide[list(data_wide.keys())[0]].set_index("session_id")
# drop duplicate indexes to avoid issues with multiple sessions
first_metric_df = first_metric_df[~first_metric_df.index.duplicated(keep="first")]
tmp_df_template[["subject_code", "age_at_scan", "sex", "group", "target"]] = first_metric_df.loc[
    list(common_sessions), ["subject_code", "age_at_scan", "sex", "group", "target"]
]

In [29]:
from sklearn.linear_model import RidgeCV
from athletes_brain.fig2.preprocessing import init_preprocessor
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from athletes_brain.fig2.config import (
    AVAILABLE_MODELS,
    AVAILABLE_PARAMS,
    N_PERMUTATIONS,
    MODEL_NAME,
)
from sklearn.model_selection import KFold


cv = KFold(
    n_splits=5, shuffle=True, random_state=42
)  # Ensure groups are preserved in cross-validation


subjects = tmp_df_template["subject_code"].astype(str)

stacked_models_results = parcels.copy()
predictions_base_stacked = {}

# Initiate CV for subjects to ensure groups are preserved

for i, row in parcels.iterrows():
    roi = row[REGION_COL]
    vals_data = {}
    for metric, df in data_wide.items():
        m_v = df.set_index("session_id")
        m_v = m_v[~m_v.index.duplicated(keep="first")]
        vals_data[metric] = m_v[roi].loc[list(common_sessions)]

    X_roi = pd.DataFrame(vals_data)
    # X_roi["age_at_scan"] = tmp_df_template["age_at_scan"]
    X_roi["sex"] = tmp_df_template["sex"]
    X_roi = X_roi.loc[tmp_df_template.index]

    y = tmp_df_template["age_at_scan"]

    X_roi.columns = X_roi.columns.astype(str)

    preprocessor = init_preprocessor(X_roi)
    pipe = Pipeline(
        [
            ("preprocessor", preprocessor),
            ("imputer", SimpleImputer(strategy="mean")),
            # ("classifier", AVAILABLE_MODELS[MODEL_NAME]),
            ("estimator", RidgeCV(alphas=np.logspace(-3, 3, 7))),
        ]
    )
    print(f"\n--- Training stacked base model for Parcel {roi} ---")
    # Fit the base model and get predictions
    y_oof_corrected, original_residuals, corrected_residuals = (
        cross_val_predict_with_bias_correction(
            model=pipe,
            X=X_roi,
            y_chronological=y,
            cv=cv,
            post_hoc_degree=1,  # Linear de Lange/Beheshti-style correction
        )
    )
    predictions_base_stacked[roi] = y_oof_corrected
    break


--- Training stacked base model for Parcel 1 ---


In [31]:
X_roi.shape

(1169, 8)

In [30]:
from sklearn.metrics import mean_absolute_error, r2_score

# Calculate metrics for the base model
mae_base = mean_absolute_error(y, predictions_base_stacked[roi])
r2_base = r2_score(y, predictions_base_stacked[roi])
print(f"MAE (base model): {mae_base:.2f}")
print(f"R² (base model): {r2_base:.2f}")

MAE (base model): 22.66
R² (base model): -12.18


In [None]:
for i, row in parcels.iterrows():  # i == parcel index (0..453)
    # ------------- build design matrix for parcel i -----------------
    # X_roi : (n_subjects , 5 metrics)
    # X_cov = cov[cov_names["gm_vol"]].to_numpy()
    X_roi = np.hstack([X_dict[m][:, [i]] for m in metrics])
    # X_roi = np.hstack([X_roi, X_cov])  # add covariates

    pipe = Pipeline(
        steps=[
            ("scaler", StandardScaler()),
            # ("poly", PolynomialFeatures(degree=2)),
            ("estimator", RidgeCV(alphas=alphas)),
            # ("estimator", RandomForestRegressor())
        ]
    )

    # ---------------- fit model & predict --------------------------
    y_pred, original_residuals, corrected_residuals = cross_val_predict_with_bias_correction(
        model=pipe,
        X=X_roi,
        y_chronological=y,
        w=w,
        cv=outer_cv,
        use_weights=use_weights,
        post_hoc_degree=0,
        # residual_orthog_degree=post_hoc_degree,
    )

    # ---------------- store predictions & metrics -------------------
    pred_df = cov.copy()
    pred_df["True"] = y
    pred_df["Predicted"] = y_pred
    pred_df["raw_residuals"] = original_residuals
    pred_df["corrected_residuals"] = corrected_residuals
    predictions["base_stacked"][i] = pred_df
    r2 = r2_score(y, y_pred)
    mae = mean_absolute_error(y, y_pred)
    rmse = root_mean_squared_error(y, y_pred)
    r2_weighted = r2_score(y, y_pred, sample_weight=w)
    mae_weighted = mean_absolute_error(y, y_pred, sample_weight=w)
    rmse_weighted = root_mean_squared_error(y, y_pred, sample_weight=w)

    stacked_models.loc[
        i, ["R2", "MAE", "RMSE", "R2_weighted", "MAE_weighted", "RMSE_weighted"]
    ] = [r2, mae, rmse, r2_weighted, mae_weighted, rmse_weighted]
    stacked_estimators[i] = pipe

In [None]:
predictions_base_stacked, stacked_models_results, common_data_template = train_stacked_base_models(
    common_sessions, data_wide, parcels, group_name, FORCE_STACKING
)