# Figure 1: Athletes Brain Study - Refactored

This notebook demonstrates the refactored, modular approach to generating Figure 1 for the athletes brain study.
The analysis has been organized into reusable modules within the `athletes_brain.fig1` package.

## Setup and Imports

Import the refactored modules from our package.

In [1]:
# Standard imports
from pathlib import Path
import pandas as pd
import numpy as np

# Import our refactored modules
from athletes_brain.fig1 import (
    Fig1Config,
    AthletesBrainDataLoader,
    GroupComparison,
    BrainPlotter,
    generate_figure1
)

# Import specific functions for focused analysis
from athletes_brain.fig1.main import (
    analyze_specific_comparison,
    get_most_significant_regions
)

[32m2025-08-12 15:04:38.485[0m | [1mINFO    [0m | [36mathletes_brain.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /home/groot/Projects/athletes_brain[0m


## Configuration

Set up the configuration and visualization settings.

In [2]:
# Create the Figure 1 configuration object and run its matplotlib setup hook
config = Fig1Config()
config.setup_matplotlib_config()

# Echo the key settings, one per line
print(
    f"Atlas: {config.ATLAS}",
    f"Metrics: {config.METRICS}",
    f"Group labels: {config.CLIMBER_GROUP_LABEL}, {config.BJJ_GROUP_LABEL}, {config.CONTROL_GROUP_LABEL}",
    f"P-value threshold: {config.P_THRESHOLD}",
    f"Visualization range: [{config.VMIN}, {config.VMAX}]",
    sep="\n",
)

Atlas: schaefer2018tian2020_400_7
Metrics: ['gm_vol', 'adc']
Group labels: Climbing, Bjj, Control
P-value threshold: 0.05
Visualization range: [-5, 5]


## Quick Analysis: Generate Complete Figure 1

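Use the main function, `generate_figure1`, to generate all comparisons and visualizations in one call. The call is commented out in the cell below; uncomment it to regenerate the complete figure.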

In [4]:
# Destination folder for the complete Figure 1, resolved under the user's home directory
output_dir = Path.home().joinpath("Projects", "athletes_brain", "figures", "fig1")
# The one-shot generation call is left disabled in this notebook:
# generate_figure1(output_dir=output_dir, config=config)

## Step-by-Step Analysis

For more control, we can run each step individually using the modular components.

### 1. Data Loading

In [5]:
# Build the loader from the shared configuration
data_loader = AthletesBrainDataLoader(config)

# A single call returns four objects, unpacked here in the order the loader provides them
metric_data, parcels, nifti_path, nifti_matlab_path = data_loader.load_all_data()

print(
    f"Loaded {len(metric_data)} metrics",
    f"Loaded {len(parcels)} brain parcels",
    f"Metrics: {list(metric_data.keys())}",
    sep="\n",
)

# Count participants: keep only the first gm_vol row seen for each subject_code
d = metric_data["gm_vol"].drop_duplicates(subset="subject_code")
print(f"N participants: {d.shape[0]}")


[32m2025-08-12 15:04:51.924[0m | [1mINFO    [0m | [36mathletes_brain.fig1.data_loader[0m:[36mload_metric_data[0m:[36m48[0m - [1mLoading metric data...[0m
[32m2025-08-12 15:04:51.925[0m | [1mINFO    [0m | [36mathletes_brain.fig1.data_loader[0m:[36mload_metric_data[0m:[36m54[0m - [1mLoading gm_vol data[0m


  df = pd.read_csv(self.processed_dir / f"{metric}.csv", index_col=0).reset_index(


[32m2025-08-12 15:04:59.305[0m | [1mINFO    [0m | [36mathletes_brain.fig1.data_loader[0m:[36mload_metric_data[0m:[36m74[0m - [1mLoaded 472498 records for gm_vol[0m
[32m2025-08-12 15:04:59.305[0m | [1mINFO    [0m | [36mathletes_brain.fig1.data_loader[0m:[36mload_metric_data[0m:[36m54[0m - [1mLoading adc data[0m


  df = pd.read_csv(self.processed_dir / f"{metric}.csv", index_col=0).reset_index(


[32m2025-08-12 15:05:04.624[0m | [1mINFO    [0m | [36mathletes_brain.fig1.data_loader[0m:[36mload_metric_data[0m:[36m74[0m - [1mLoaded 494664 records for adc[0m
[32m2025-08-12 15:05:04.625[0m | [32m[1mSUCCESS [0m | [36mathletes_brain.fig1.data_loader[0m:[36mload_metric_data[0m:[36m76[0m - [32m[1mSuccessfully loaded 2 metrics[0m
[32m2025-08-12 15:05:04.625[0m | [1mINFO    [0m | [36mathletes_brain.fig1.data_loader[0m:[36mload_atlas_data[0m:[36m32[0m - [1mLoading atlas data for schaefer2018tian2020_400_7[0m
[32m2025-08-12 15:05:04.627[0m | [1mINFO    [0m | [36mathletes_brain.fig1.data_loader[0m:[36mload_atlas_data[0m:[36m39[0m - [1mLoaded 454 parcels from atlas[0m
Loaded 2 metrics
Loaded 454 brain parcels
Metrics: ['gm_vol', 'adc']
N participants: 1050


### 2. Statistical Analysis

In [6]:
import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.stats.anova import anova_lm


def prep(df):
    """Return a cleaned copy of ``df`` ready for the formula-based models.

    Rows with a missing ``value``, ``target``, ``sex`` or ``age_at_scan`` are
    dropped. ``target`` and ``sex`` are cast to strings so their categorical
    levels are stable, and ``age_c`` (age minus the mean age of the remaining
    rows) is appended as the last column. The input frame is not modified.
    """
    required = ["value", "target", "sex", "age_at_scan"]
    complete = df.dropna(subset=required)
    age = complete["age_at_scan"]
    return complete.assign(
        target=complete["target"].astype(str),
        sex=complete["sex"].astype(str),
        age_c=age - age.mean(),
    )


def fit_primary(df, quad_age=True, interaction=False):
    """Fit the primary OLS model of ``value`` on group, sex and centred age.

    Parameters
    ----------
    df : DataFrame
        Frame containing ``value``, ``target``, ``sex`` and ``age_c``.
    quad_age : bool
        When true, add the quadratic age term ``I(age_c**2)``.
    interaction : bool
        When true, add ``C(target):age_c`` and, if ``quad_age`` is also set,
        ``C(target):I(age_c**2)``.

    Returns
    -------
    (results, formula)
        The fitted statsmodels results object and the formula string used.
    """
    rhs = ["C(target, Treatment(reference='False'))", "C(sex)", "age_c"]
    rhs += ["I(age_c**2)"] if quad_age else []
    if interaction:
        rhs += ["C(target):age_c"]
        rhs += ["C(target):I(age_c**2)"] if quad_age else []
    formula = "value ~ " + " + ".join(rhs)
    # Plain .fit(): the default (non-robust) covariance is the primary analysis
    fitted = smf.ols(formula, data=df).fit()
    return fitted, formula


def partial_r2_for_term(full_model, df, term_label="C(target)"):
    """Compute the partial R^2 of a term by refitting the model without it.

    Parameters
    ----------
    full_model : fitted results object
        Must expose ``model.formula`` and ``resid``.
    df : DataFrame
        Data the reduced model is fitted on (should be the data used for
        ``full_model``).
    term_label : str
        Label of the factor to remove. Every ``+``-separated piece of the
        right-hand side that contains this label is dropped, so interactions
        involving the factor are removed together with its main effect.

    Returns
    -------
    float
        ``(SSE_reduced - SSE_full) / SSE_reduced``, floored at 0. Returns 0.0
        when the reduced model already has zero residual error.
    """
    # Match on the factor prefix rather than the closed call: the main effect is
    # written "C(target, Treatment(reference='False'))", which the literal
    # "C(target)" would never match, leaving the term in the "reduced" model.
    needle = term_label.rstrip(")")

    rhs = full_model.model.formula.split("~", 1)[1]
    keep = [piece.strip() for piece in rhs.split("+") if needle not in piece]
    reduced = "value ~ " + " + ".join(keep) if keep else "value ~ 1"
    m0 = smf.ols(reduced, data=df).fit()

    sse0 = float(np.sum(np.asarray(m0.resid) ** 2))
    sse1 = float(np.sum(np.asarray(full_model.resid) ** 2))
    if sse0 <= 0.0:
        # Nothing left to explain; avoid a 0/0 division
        return 0.0
    return max(0.0, (sse0 - sse1) / sse0)


def freedman_lane_p(full_formula, group_term, df, B=2000, seed=123):
    """Freedman–Lane permutation p-value for the group (target) effect.

    The reduced model is the full formula minus every ``+``-separated piece
    containing ``group_term``. Its residuals are permuted, added back to its
    fitted values, and the full model is refitted on each permuted response.
    The statistic is the type-II ANOVA F of the first row whose name starts
    with ``group_term``.

    NOTE(review): when the formula contains interactions with the group factor,
    they are removed from the reduced model too, while the statistic is the F
    of the first matching row only (the main effect when it is listed first).
    Confirm this is the intended null hypothesis.

    Parameters
    ----------
    full_formula : str
        Formula of the full model, with ``value`` as the response.
    group_term : str
        Prefix identifying the group factor, e.g. ``"C(target"``.
    df : DataFrame
        Data used for every fit.
    B : int
        Number of permutations (must be >= 0).
    seed : int
        Seed for the NumPy random generator.

    Returns
    -------
    float
        ``(count(F* >= F_obs) + 1) / (B + 1)``, or NaN when the observed F is
        NaN.

    Raises
    ------
    ValueError
        If ``B`` is negative.
    KeyError
        If no ANOVA row starts with ``group_term``.
    """
    if B < 0:
        raise ValueError(f"B must be non-negative, got {B}")

    rng = np.random.default_rng(seed)
    # Reduced model: full minus the group term(s)
    rhs = full_formula.split("~", 1)[1]
    keep = [t.strip() for t in rhs.split("+") if group_term not in t]
    reduced = "value ~ " + " + ".join(keep) if keep else "value ~ 1"

    m_full = smf.ols(full_formula, data=df).fit()
    m_red = smf.ols(reduced, data=df).fit()

    aov = anova_lm(m_full, typ=2)
    # Find the actual row name for the group term (handles Treatment coding)
    row = next((r for r in aov.index if r.startswith(group_term)), None)
    if row is None:
        # Same exception type the bare .loc lookup raised, but with an actionable message
        raise KeyError(
            f"no ANOVA row starts with {group_term!r}; available rows: {list(aov.index)}"
        )
    F_obs = float(aov.loc[row, "F"])
    if np.isnan(F_obs):
        # Every comparison against NaN is False, which would otherwise report
        # the smallest attainable p-value for an undefined statistic.
        return float("nan")

    y_hat0 = m_red.fittedvalues
    resid = m_red.resid.values
    ge = 0
    for _ in range(B):
        y_star = y_hat0 + rng.permutation(resid)
        m_star = smf.ols(full_formula, data=df.assign(value=y_star)).fit()
        aov_s = anova_lm(m_star, typ=2)
        if row in aov_s.index and float(aov_s.loc[row, "F"]) >= F_obs:
            ge += 1
    return (ge + 1) / (B + 1)


def analyze(df, quad_age=True, interaction=False, B_perm=2000):
    """Run the end-to-end group analysis for one data frame.

    The frame is cleaned with ``prep``, the primary model is fitted with
    ``fit_primary``, and the group coefficient is summarised together with an
    optional Freedman–Lane permutation p-value and the partial R^2 of the
    group factor.

    Parameters
    ----------
    df : DataFrame
        Raw rows; must contain the columns ``prep`` expects.
    quad_age, interaction : bool
        Forwarded to ``fit_primary``.
    B_perm : int or falsy
        Number of permutations. A falsy value skips the permutation test and
        reports ``p_perm`` as ``None``.

    Returns
    -------
    (dict, results)
        Summary dictionary and the fitted results object. ``n_total``,
        ``n_athletes`` and ``n_controls`` count rows of the cleaned frame, not
        unique participants. NOTE(review): if a subject contributes several
        rows, these counts (and the OLS independence assumption) should be
        double-checked.

    Raises
    ------
    IndexError
        If the fitted model has no coefficient for the ``True`` level of
        ``target``.
    """
    df = prep(df)
    res, formula = fit_primary(df, quad_age=quad_age, interaction=interaction)

    # Coefficient on the athlete indicator. In the additive model it is the adjusted
    # mean difference; with interaction terms it is the difference at age_c == 0,
    # i.e. at the mean age of the cleaned rows.
    prefixes = (
        "C(target, Treatment(reference='False'))[T.True]",
        "C(target)[T.True]",  # fallback if patsy prints differently
    )
    coef_name = next(
        (n for prefix in prefixes for n in res.params.index if n.startswith(prefix)),
        None,
    )
    if coef_name is None:
        # IndexError is kept so existing handlers still work; the message now says why
        raise IndexError(
            "no coefficient for target level 'True' in the fitted model; "
            f"available coefficients: {list(res.params.index)}"
        )

    delta = res.params[coef_name]
    ci_l, ci_u = res.conf_int().loc[coef_name].tolist()
    p_model = res.pvalues[coef_name]

    p_perm = float(freedman_lane_p(formula, "C(target", df, B=B_perm)) if B_perm else None
    pr2 = partial_r2_for_term(res, df, term_label="C(target")

    out = {
        "n_total": len(df),
        "n_athletes": int((df["target"] == "True").sum()),
        "n_controls": int((df["target"] == "False").sum()),
        "model": formula,
        "coef_delta": float(delta),
        "ci95": (float(ci_l), float(ci_u)),
        "p_model": float(p_model),
        "p_perm": p_perm,
        "partial_R2_target": float(pr2),
        "adj_R2": float(res.rsquared_adj),
    }
    return out, res

In [41]:
from statsmodels.stats.multitest import multipletests
from tqdm import tqdm

# Model specification shared by every parcel-wise fit in this cell
quad_age = False
interaction = True

results = {}
for metric, df in metric_data.items():
    print(f"Analyzing metric: {metric}")
    df_prep = prep(df)

    # Collect one flat record per parcel and build the table once at the end.
    # Enlarging `stats` cell by cell with .loc is slow and upcasts the integer
    # counts to float (they were displayed as e.g. 1089.0).
    row_labels = []
    records = []
    for i, row in tqdm(parcels.iterrows(), total=parcels.shape[0], desc="Processing parcels"):
        roi_df = df_prep[df_prep["index"] == row["index"]]
        out, res = analyze(roi_df, quad_age=quad_age, interaction=interaction, B_perm=False)
        record = {}
        for key, val in out.items():
            if key == "ci95":
                # Split the (lower, upper) tuple into two scalar columns
                record[f"{key}_lower"], record[f"{key}_upper"] = val
            else:
                record[key] = val
        row_labels.append(i)
        records.append(record)

    per_parcel = pd.DataFrame(records, index=row_labels)
    # p_perm is None when the permutation test is skipped; keep the column numeric (NaN)
    per_parcel = per_parcel.astype({"p_perm": float})
    stats = parcels.join(per_parcel)

    # Benjamini–Hochberg FDR across parcels. Non-finite p-values are left as NaN
    # and excluded, so they cannot distort the correction of the remaining tests.
    pvals = stats["p_model"].to_numpy(dtype=float)
    tested = np.isfinite(pvals)
    corrected_pvals = np.full(pvals.shape, np.nan)
    if tested.any():
        corrected_pvals[tested] = multipletests(pvals[tested], method="fdr_bh")[1]
    stats["p_model_corrected"] = corrected_pvals
    results[metric] = stats

Analyzing metric: gm_vol


Processing parcels: 100%|██████████| 454/454 [00:10<00:00, 41.79it/s]


Analyzing metric: adc


Processing parcels: 100%|██████████| 454/454 [00:10<00:00, 41.85it/s]


In [42]:
# Show the 20 ADC parcels with the smallest FDR-corrected p-values. `results` is
# indexed explicitly instead of relying on the `stats` variable left over from the
# last iteration of the loop above (which was the "adc" table).
results["adc"].sort_values("p_model_corrected").head(20)

Unnamed: 0,index,name,base_name,Label Name,network,component,hemisphere,n_total,n_athletes,n_controls,model,coef_delta,ci95_lower,ci95_upper,p_model,p_perm,partial_R2_target,adj_R2,p_model_corrected
191,192,7Networks_LH_Default_pCunPCC_3,7networks_lh_default_pcunpcc,7Networks_LH_Default_pCunPCC,default,precuneus posterior cingulate cortex,L,1089.0,134.0,955.0,"value ~ C(target, Treatment(reference='False')...",-1.6e-05,-2.4e-05,-9e-06,1.1e-05,,0.019599,0.086912,0.005012
426,427,aGP-rh,aGP,"Pallidum, anterior part",subcortex,Pallidum,R,1090.0,133.0,957.0,"value ~ C(target, Treatment(reference='False')...",3.2e-05,1.7e-05,4.7e-05,2.8e-05,,0.016041,0.036285,0.006463
362,363,7Networks_RH_Default_Par_2,7networks_rh_default_par,7Networks_RH_Default_Par,default,parietal,R,1089.0,133.0,956.0,"value ~ C(target, Treatment(reference='False')...",-1.7e-05,-2.6e-05,-8e-06,0.000127,,0.013942,0.050966,0.011541
391,392,7Networks_RH_Default_pCunPCC_1,7networks_rh_default_pcunpcc,7Networks_RH_Default_pCunPCC,default,precuneus posterior cingulate cortex,R,1089.0,133.0,956.0,"value ~ C(target, Treatment(reference='False')...",-1.5e-05,-2.3e-05,-8e-06,8.5e-05,,0.014717,0.159879,0.011541
189,190,7Networks_LH_Default_pCunPCC_1,7networks_lh_default_pcunpcc,7Networks_LH_Default_pCunPCC,default,precuneus posterior cingulate cortex,L,1089.0,134.0,955.0,"value ~ C(target, Treatment(reference='False')...",-1.8e-05,-2.7e-05,-9e-06,0.000115,,0.021351,0.097758,0.011541
325,326,7Networks_RH_Limbic_TempPole_2,7networks_rh_limbic_temppole,7Networks_RH_Limbic_TempPole,limbic,temporal pole,R,1089.0,133.0,956.0,"value ~ C(target, Treatment(reference='False')...",1.4e-05,7e-06,2.2e-05,0.000206,,0.012707,0.021642,0.015561
393,394,7Networks_RH_Default_pCunPCC_3,7networks_rh_default_pcunpcc,7Networks_RH_Default_pCunPCC,default,precuneus posterior cingulate cortex,R,1089.0,133.0,956.0,"value ~ C(target, Treatment(reference='False')...",-2.5e-05,-3.8e-05,-1.1e-05,0.000331,,0.013531,0.140638,0.019727
152,153,7Networks_LH_Default_Temp_5,7networks_lh_default_temp,7Networks_LH_Default_Temp,default,temporal,L,1090.0,134.0,956.0,"value ~ C(target, Treatment(reference='False')...",1.8e-05,8e-06,2.7e-05,0.000348,,0.016193,0.017985,0.019727
190,191,7Networks_LH_Default_pCunPCC_2,7networks_lh_default_pcunpcc,7Networks_LH_Default_pCunPCC,default,precuneus posterior cingulate cortex,L,1089.0,134.0,955.0,"value ~ C(target, Treatment(reference='False')...",-1.7e-05,-2.7e-05,-7e-06,0.000795,,0.012901,0.155141,0.040084
195,196,7Networks_LH_Default_pCunPCC_7,7networks_lh_default_pcunpcc,7Networks_LH_Default_pCunPCC,default,precuneus posterior cingulate cortex,L,1089.0,134.0,955.0,"value ~ C(target, Treatment(reference='False')...",-1.1e-05,-1.7e-05,-4e-06,0.000983,,0.010795,0.126086,0.043249
