# Metabolic Age Prediction: MLP Regression Model

This notebook trains multi-layer perceptron (MLP) regression models to predict
chronological age from plasma NMR metabolomics data (UK Biobank, UKBB) and
applies the trained models to the ADNI cohort. A polynomial bias-correction
procedure is then applied to remove the systematic regression-to-the-mean bias
in predicted age (younger subjects' ages over-estimated, older subjects' ages
under-estimated).

**Datasets**
- Training / internal validation: UK Biobank (controls)
- External validation: ADNI (baseline controls)

**Models**
Three models are trained: one on the whole population and one each on the male and female subsets.
The final output is `delta_age = corrected_predicted_age - actual_age`.

## 1. Setup and Imports

In [None]:
import os
import random
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from matplotlib.patches import Rectangle
from scipy.stats import pearsonr

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
    explained_variance_score,
)

# Silence all warnings for cleaner notebook output.
# NOTE(review): this also hides sklearn warnings (e.g. MLPRegressor
# convergence warnings); consider narrowing the filter.
warnings.filterwarnings('ignore')
# Seeds only Python's `random` module. The splits and models below are made
# reproducible through their explicit random_state=42 arguments instead.
random.seed(42)

## 2. Helper Functions

In [None]:
def get_figure_prediction(actual_age, predicted_age, deg=3,
                          range_fig=(35, 75), range_txt_x=55,
                          range_txt_y=(40, 38),
                          title_add='UKBB Test Data'):
    """Scatter plot of actual vs. predicted age with a polynomial fit line.

    Parameters
    ----------
    actual_age, predicted_age : array-like of float
        Chronological and predicted ages, same length and order.
    deg : int
        Degree of the trend polynomial fitted with ``np.polyfit``.
    range_fig : (float, float)
        Limits used for both axes; also the extent of the identity line.
    range_txt_x : float
        x position (data coordinates) of the MAE / Pearson r annotations.
    range_txt_y : (float, float)
        y positions (data coordinates) of the MAE and Pearson r annotations.
    title_add : str
        Suffix appended to the plot title.
    """
    # Accept lists, Series or arrays uniformly.
    actual_age = np.asarray(actual_age, dtype=float)
    predicted_age = np.asarray(predicted_age, dtype=float)

    fig, ax = plt.subplots(figsize=(9, 9))
    ax.scatter(actual_age, predicted_age, s=40, alpha=0.7, edgecolors='k')

    # Polynomial trend of predictions over the observed age span.
    poly = np.poly1d(np.polyfit(actual_age, predicted_age, deg=deg))
    xseq = np.linspace(actual_age.min(), actual_age.max(), num=100)
    ax.plot(xseq, poly(xseq), color='k', lw=2.5)
    # Identity line (perfect prediction) drawn behind the points.
    ax.plot(range_fig, range_fig, 'k-', alpha=0.75, zorder=0)

    ax.set_xlabel('Actual Age')
    ax.set_ylabel('Predicted Age')
    ax.text(range_txt_x, range_txt_y[0],
            f'MAE {mean_absolute_error(actual_age, predicted_age):.2f}',
            fontsize=14)
    ax.text(range_txt_x, range_txt_y[1],
            f'Pearson r {pearsonr(actual_age, predicted_age)[0]:.2f}',
            fontsize=14)
    ax.set_xlim(range_fig)
    ax.set_ylim(range_fig)
    ax.set_title(f'Actual Age vs. Predicted Age ({title_add})')
    plt.show()


def bias_correction(age_t, age_p1, degree, title_set='Training Subjects'):
    """Fit a polynomial bias-correction model and plot the corrected ages.

    A polynomial of ``degree`` is fitted to the raw predictions as a function
    of actual age. Each prediction is then corrected as
    ``age_pc = age_p1 - bias_fit(age_t) + age_t``.

    Parameters
    ----------
    age_t : sequence of float
        Actual (chronological) ages.
    age_p1 : sequence of float
        Raw model predictions, in the same order as ``age_t``.
    degree : int
        Degree of the bias polynomial passed to ``np.polyfit``.
    title_set : str
        Label appended to the plot title.

    Returns
    -------
    age_pc : list
        Bias-corrected predictions.
    b_model_out : numpy.poly1d
        Fitted bias polynomial; can be applied to new data with
        ``correct_w_model``.
    """
    b_model_out = np.poly1d(np.polyfit(age_t, age_p1, deg=degree))

    # age_pc = age_p1 - bias_fit(age_t) + age_t
    age_p2 = [b_model_out(a) for a in age_t]
    age_pc = [p1 - p2 + t for p1, p2, t in zip(age_p1, age_p2, age_t)]

    # Draw the identity line over the plotted data, not a fixed [38, 72]
    # window, so the figure stays correct for cohorts with older subjects.
    lo = min(min(age_t), min(age_pc))
    hi = max(max(age_t), max(age_pc))

    fig, ax = plt.subplots(figsize=(9, 9))
    ax.scatter(age_t, age_pc, s=40, alpha=0.7, edgecolors='k')
    ax.plot([lo, hi], [lo, hi], 'k-', alpha=0.75, zorder=0)
    ax.set_xlabel('Actual Age')
    ax.set_ylabel('Corrected Predicted Age')
    # Place annotations in axes-fraction coordinates so they stay inside
    # the frame whatever the age range.
    ax.text(0.5, 0.08, f'MAE {mean_absolute_error(age_t, age_pc):.3f}',
            fontsize=14, transform=ax.transAxes)
    ax.text(0.5, 0.03, f'Correction degree {degree}',
            fontsize=14, transform=ax.transAxes)
    ax.set_title(f'Actual vs. Corrected Predicted Age – {title_set}')
    plt.show()

    return age_pc, b_model_out


def correct_w_model(age_t, age_p1, b_model_out_poly):
    """Bias-correct predictions with an already-fitted polynomial.

    For every subject: corrected = raw_prediction - poly(actual_age) + actual_age.
    The result is a plain list aligned with the inputs.
    """
    fitted_bias = list(map(b_model_out_poly, age_t))
    corrected = []
    for raw_pred, fitted, actual in zip(age_p1, fitted_bias, age_t):
        corrected.append(raw_pred - fitted + actual)
    return corrected


def print_metrics(y_true, y_pred, label=''):
    """Print a one-line summary of MAE, RMSE, Pearson r and R2.

    ``label`` appears in square brackets at the start of the line.
    """
    squared_err = mean_squared_error(y_true, y_pred)
    mae, rmse = mean_absolute_error(y_true, y_pred), np.sqrt(squared_err)
    r = pearsonr(y_true, y_pred)[0]
    r2 = r2_score(y_true, y_pred)
    print(f'[{label}]  MAE={mae:.3f}  RMSE={rmse:.3f}  r={r:.3f}  R2={r2:.3f}')

## 3. UKBB Data Preparation

Metabolomics features are loaded and converted from natural-log to log2 scale.
The data are first split 90 / 10 into (train + validation) / test, and the 90 %
portion is split again 90 / 10 into train / validation (≈ 81 / 9 / 10 overall).
Each split is stratified on age bins to preserve the age distribution. Separate
scalers are fitted for the whole-population, male-only, and female-only subsets.

In [None]:
# --- Load data ---
x_raw = pd.read_csv('First_Time_Point_Metabolomics_Control.csv', index_col='Unnamed: 0')
# Convert from log-e to log2 scale: log2(exp(x)) == x / ln(2). Dividing
# directly gives the same values without overflowing np.exp for large x.
x_raw = x_raw / np.log(2)

demo = pd.read_csv('First_Time_Point_Demographics_Control.csv', index_col='eid')
# train_test_split pairs X and y by position, so reorder the demographics to
# match the metabolomics rows by subject ID (KeyError if a subject is missing).
demo  = demo.loc[x_raw.index]
y_all = demo.Age_at_TestCenter


def _age_strata(y, n_bins=20):
    """Return age-quantile bin labels used to stratify train/test splits.

    Bin edges are quantiles of the ages in ``y`` (not of the sample count),
    so the strata follow the age distribution. Duplicate edges produced by
    many identical integer ages are merged.
    """
    return np.asarray(pd.qcut(y, q=n_bins, labels=False, duplicates='drop'))


def split_and_scale(x, y, n_bins=20, random_state=42):
    """Stratified train/val/test split and StandardScaler for a subset.

    The data are split 90/10 into (train+val)/test, then 90/10 again into
    train/val (~81/9/10 overall). Each split is stratified on age-quantile
    bins. The scaler is fitted on the training partition only, so no
    validation/test statistics leak into the model inputs.

    Returns
    -------
    tuple
        ``(Xtr, Xva, Xte, Ytr, Yva, Yte, ntr, nva, nte, sc)``: raw feature
        splits, target splits, standardised feature splits (row index
        preserved) and the fitted StandardScaler.
    """
    Xtr, Xte, Ytr, Yte = train_test_split(
        x, y, test_size=0.1, shuffle=True, random_state=random_state,
        stratify=_age_strata(y, n_bins))
    Xtr, Xva, Ytr, Yva = train_test_split(
        Xtr, Ytr, test_size=0.1, shuffle=True, random_state=random_state,
        stratify=_age_strata(Ytr, n_bins))

    sc = StandardScaler().fit(Xtr)

    def _scale(X):
        return pd.DataFrame(sc.transform(X), index=X.index, columns=x.columns)

    return Xtr, Xva, Xte, Ytr, Yva, Yte, _scale(Xtr), _scale(Xva), _scale(Xte), sc


# --- Whole population ---
(X_ukbb_train, X_ukbb_val, X_ukbb_test,
 Y_ukbb_train, Y_ukbb_val, Y_ukbb_test,
 normalized_X_ukbb_train, normalized_X_ukbb_val, normalized_X_ukbb_test,
 scaler) = split_and_scale(x_raw, y_all)


def norm(X):
    """Standardise ``X`` with the whole-population training-set scaler."""
    return pd.DataFrame(scaler.transform(X), index=X.index, columns=x_raw.columns)


# --- Sex-stratified subsets ---
# Sex coding: 0 = Female, 1 = Male
# Because `demo` is aligned to `x_raw`, selecting X rows by the y index keeps
# features and targets in the same order.
y_male   = demo.loc[demo.Sex == 1, 'Age_at_TestCenter']
y_female = demo.loc[demo.Sex == 0, 'Age_at_TestCenter']
x_male   = x_raw.loc[y_male.index]
x_female = x_raw.loc[y_female.index]

(X_male_train, X_male_val, X_male_test,
 Y_male_train, Y_male_val, Y_male_test,
 norm_X_male_train, norm_X_male_val, norm_X_male_test, _) = split_and_scale(x_male, y_male)

(X_female_train, X_female_val, X_female_test,
 Y_female_train, Y_female_val, Y_female_test,
 norm_X_female_train, norm_X_female_val, norm_X_female_test, _) = split_and_scale(x_female, y_female)

## 4. ADNI Data Preparation

ADNI metabolomics data and demographic information are loaded and aligned.
Subjects are divided into whole / male / female subsets. Baseline controls
are identified for bias-correction and validation.

In [None]:
import itertools

adni_met  = pd.read_csv('ADNI_Rearranged_New_Input_data.csv').transpose()
subj_info = pd.read_csv('SubjectId_ADNI.csv')
adni_dem  = pd.read_csv('ADNI_Demographics.csv')

# Filter to diagnostic groups of interest
out_og = adni_dem[adni_dem.Diagnosis.isin(['Controls', 'LMCI', 'EMCI', 'AD'])]

# Populate age, sex, and diagnosis for each subject x visit.
# 'Age' starts as a float (0.0) so that writing fractional ages does not hit
# pandas' incompatible-dtype setitem deprecation on an int column.
subj_info['Age']       = 0.0
subj_info['Sex']       = 'female'
subj_info['Diagnosis'] = 'Controls'

is_bl = subj_info.VISCODE2 == 'bl'      # loop-invariant, computed once
for i in adni_dem.RID:
    is_subj = subj_info.RID == i
    bl_idx  = is_subj & is_bl
    fu_idx  = is_subj & ~is_bl

    # Demographics are taken from the first adni_dem row with this RID
    dem_row = adni_dem[adni_dem.RID == i].iloc[0]
    age_bl  = float(dem_row['age'])
    sex     = dem_row['Sex']
    diag    = dem_row['Diagnosis']

    # Sex and diagnosis are shared by all visits of the subject
    subj_info.loc[is_subj, 'Sex']       = sex
    subj_info.loc[is_subj, 'Diagnosis'] = diag

    # Baseline visits get the baseline age; follow-up visits add the month
    # offset parsed from VISCODE2 (the part after 'm', e.g. 'm12' -> 12).
    subj_info.loc[bl_idx, 'Age'] = age_bl
    if fu_idx.any():
        months = subj_info.loc[fu_idx, 'VISCODE2'].map(
            lambda code: int(code.split('m')[1]))
        subj_info.loc[fu_idx, 'Age'] = age_bl + months / 12

# Restrict to subjects with valid diagnosis
subj_info_2    = subj_info.loc[subj_info.RID.isin(out_og.RID), :]
bl_controls    = subj_info_2[(subj_info_2.VISCODE2 == 'bl') &
                             (subj_info_2.Diagnosis == 'Controls')]
adni_all       = adni_met.loc[list(subj_info_2.Sample_id), :]

# Separate male / female subsets (all visits)
male_adni_x   = adni_all[adni_all.index.isin(
    subj_info_2.Sample_id[subj_info_2.Sex == 'Male'])]
female_adni_x = adni_all[adni_all.index.isin(
    subj_info_2.Sample_id[subj_info_2.Sex == 'Female'])]

male_adni_y   = subj_info_2[subj_info_2.Sample_id.isin(male_adni_x.index)].set_index('Sample_id').Age
female_adni_y = subj_info_2[subj_info_2.Sample_id.isin(female_adni_x.index)].set_index('Sample_id').Age

# Baseline controls only: male_adni_x / female_adni_x are already
# sex-filtered, so intersecting with bl_controls (VISCODE2 == 'bl' and
# Diagnosis == 'Controls') keeps exactly the baseline control visits.
male_adni_x_controls   = male_adni_x[male_adni_x.index.isin(bl_controls.Sample_id)]
female_adni_x_controls = female_adni_x[female_adni_x.index.isin(bl_controls.Sample_id)]

male_adni_y_controls   = subj_info_2[subj_info_2.Sample_id.isin(
    male_adni_x_controls.index)].set_index('Sample_id').Age
female_adni_y_controls = subj_info_2[subj_info_2.Sample_id.isin(
    female_adni_x_controls.index)].set_index('Sample_id').Age

# Standardise each ADNI subset independently
def scale_adni(X):
    """Fit a StandardScaler on X and return (standardised frame, scaler)."""
    sc  = StandardScaler().fit(X)
    out = pd.DataFrame(sc.transform(X), index=X.index, columns=X.columns)
    return out, sc

norm_adni_all,           _  = scale_adni(adni_all)
norm_male_adni_x,        _  = scale_adni(male_adni_x)
norm_female_adni_x,      _  = scale_adni(female_adni_x)
norm_male_adni_controls, _  = scale_adni(male_adni_x_controls)
norm_female_adni_controls, _ = scale_adni(female_adni_x_controls)

# Ages indexed by Sample_id, in subj_info_2 row order. Built as a new Series
# rather than by assigning to `subj_info_2.Age.index`, which can mutate the
# column Series pandas caches for the frame.
adni_y_all = pd.Series(subj_info_2.Age.to_numpy(),
                       index=list(subj_info_2.Sample_id), name='Age')
adni_y_all.index = list(subj_info_2.Sample_id)

## 5. Model Training on UK Biobank

MLP hyperparameters were selected via Optuna (separate tuning run, not shown).
Bias correction is fitted on training data and applied to held-out sets.

### 5.1 Whole Population

In [None]:
DEG_BIAS = 3   # polynomial degree for bias correction

mlp_whole = MLPRegressor(
    hidden_layer_sizes=(25, 75, 100),
    activation='relu',
    learning_rate='invscaling',
    momentum=0.8570072003082716,
    learning_rate_init=0.007430075847532571,
    random_state=42
)
mlp_whole.fit(normalized_X_ukbb_train, Y_ukbb_train)

# Raw (uncorrected) predictions per split, computed once and reused below
pred_train_whole = mlp_whole.predict(normalized_X_ukbb_train)
pred_val_whole   = mlp_whole.predict(normalized_X_ukbb_val)
pred_test_whole  = mlp_whole.predict(normalized_X_ukbb_test)

# Visualise raw (uncorrected) predictions
get_figure_prediction(Y_ukbb_test, pred_test_whole,
                      title_add='Whole UKBB – Test')

# Fit bias-correction model on training data only
_, b_poly_whole = bias_correction(
    list(Y_ukbb_train),
    pred_train_whole,
    degree=DEG_BIAS,
    title_set='Whole UKBB Training'
)

# Apply the training-fitted correction to test and validation sets
age_pc_test_whole = correct_w_model(list(Y_ukbb_test), pred_test_whole, b_poly_whole)
age_pc_val_whole  = correct_w_model(list(Y_ukbb_val), pred_val_whole, b_poly_whole)

get_figure_prediction(list(Y_ukbb_test), age_pc_test_whole,
                      title_add='Whole UKBB – Test (bias corrected)')

print_metrics(list(Y_ukbb_test), age_pc_test_whole, 'Whole – Test')
print_metrics(list(Y_ukbb_val),  age_pc_val_whole,  'Whole – Validation')

# Save results
df_whole = pd.DataFrame(
    {'Predicted_Age': pred_test_whole,
     'Actual_Age': list(Y_ukbb_test),
     'Corrected_Predicted_Age': age_pc_test_whole},
    index=Y_ukbb_test.index,
)
# 'Difference' uses the raw prediction; 'Delta_Age' is the bias-corrected
# gap (corrected_predicted_age - actual_age).
df_whole['Difference'] = df_whole['Predicted_Age'] - df_whole['Actual_Age']
df_whole['Delta_Age']  = df_whole['Corrected_Predicted_Age'] - df_whole['Actual_Age']

### 5.2 Male Population

In [None]:
mlp_male = MLPRegressor(
    hidden_layer_sizes=(25, 50, 75, 100),
    activation='relu',
    learning_rate='adaptive',
    momentum=0.320464912514306,
    learning_rate_init=0.012975784121150543,
    random_state=42
)
mlp_male.fit(norm_X_male_train, Y_male_train)

# Raw (uncorrected) predictions per split, computed once and reused below
pred_train_male = mlp_male.predict(norm_X_male_train)
pred_val_male   = mlp_male.predict(norm_X_male_val)
pred_test_male  = mlp_male.predict(norm_X_male_test)

get_figure_prediction(Y_male_test, pred_test_male,
                      title_add='Male UKBB – Test')

# Fit the bias-correction polynomial on the training split only
_, b_poly_male = bias_correction(
    list(Y_male_train),
    pred_train_male,
    degree=DEG_BIAS,
    title_set='Male UKBB Training'
)

# Apply the training-fitted correction to the held-out splits
age_pc_test_male = correct_w_model(list(Y_male_test), pred_test_male, b_poly_male)
age_pc_val_male  = correct_w_model(list(Y_male_val), pred_val_male, b_poly_male)

get_figure_prediction(list(Y_male_test), age_pc_test_male,
                      title_add='Male UKBB – Test (bias corrected)')

print_metrics(list(Y_male_test), age_pc_test_male, 'Male – Test')
print_metrics(list(Y_male_val),  age_pc_val_male,  'Male – Validation')

### 5.3 Female Population

In [None]:
mlp_female = MLPRegressor(
    hidden_layer_sizes=(25, 50, 75, 100),
    activation='relu',
    learning_rate='adaptive',
    momentum=0.20283295369885226,
    learning_rate_init=0.008211544802815416,
    random_state=42
)
mlp_female.fit(norm_X_female_train, Y_female_train)

# Raw (uncorrected) predictions per split, computed once and reused below
pred_train_female = mlp_female.predict(norm_X_female_train)
pred_val_female   = mlp_female.predict(norm_X_female_val)
pred_test_female  = mlp_female.predict(norm_X_female_test)

get_figure_prediction(Y_female_test, pred_test_female,
                      title_add='Female UKBB – Test')

# Fit the bias-correction polynomial on the training split only
_, b_poly_female = bias_correction(
    list(Y_female_train),
    pred_train_female,
    degree=DEG_BIAS,
    title_set='Female UKBB Training'
)

# Apply the training-fitted correction to the held-out splits
age_pc_test_female = correct_w_model(list(Y_female_test), pred_test_female, b_poly_female)
age_pc_val_female  = correct_w_model(list(Y_female_val), pred_val_female, b_poly_female)

get_figure_prediction(list(Y_female_test), age_pc_test_female,
                      title_add='Female UKBB – Test (bias corrected)')

print_metrics(list(Y_female_test), age_pc_test_female, 'Female – Test')
print_metrics(list(Y_female_val),  age_pc_val_female,  'Female – Validation')

## 6. Model Application on ADNI Baseline Controls

The UKBB-trained bias-correction polynomials are applied directly to the ADNI
baseline control predictions to obtain `corrected_predicted_age`.

In [None]:
# --- Whole population ---
# Baseline-control rows of the ADNI-standardised whole-population matrix,
# plus their chronological ages (both in subj_info_2 row order).
norm_adni_bl_controls = norm_adni_all.loc[list(bl_controls.Sample_id)]
is_bl_control_row     = subj_info_2.Sample_id.isin(bl_controls.Sample_id)
adni_y_bl_controls    = subj_info_2.loc[is_bl_control_row].set_index('Sample_id').Age

true_age_whole_bl    = list(adni_y_bl_controls)
raw_pred_whole_bl    = mlp_whole.predict(norm_adni_bl_controls)
# UKBB-fitted polynomial applied as-is to ADNI
age_pc_adni_whole_bl = correct_w_model(true_age_whole_bl, raw_pred_whole_bl, b_poly_whole)

adni_plot_kw = dict(range_fig=[40, 95], range_txt_x=45, range_txt_y=[90, 85])
get_figure_prediction(true_age_whole_bl, age_pc_adni_whole_bl, **adni_plot_kw,
                      title_add='ADNI Whole – Baseline Controls')
print_metrics(true_age_whole_bl, age_pc_adni_whole_bl, 'ADNI Whole BL Controls')

In [None]:
# --- Male population ---
# Male ADNI rows (standardised on the male ADNI data) restricted to baseline
# controls, plus the matching chronological ages.
is_male_bl_row    = norm_male_adni_x.index.isin(bl_controls.Sample_id)
norm_adni_male_bl = norm_male_adni_x.loc[is_male_bl_row]
adni_y_male_bl    = male_adni_y[male_adni_y.index.isin(bl_controls.Sample_id)]

true_age_male_bl    = list(adni_y_male_bl)
raw_pred_male_bl    = mlp_male.predict(norm_adni_male_bl)
# UKBB-fitted polynomial applied as-is to ADNI
age_pc_adni_male_bl = correct_w_model(true_age_male_bl, raw_pred_male_bl, b_poly_male)

adni_plot_kw = dict(range_fig=[40, 95], range_txt_x=45, range_txt_y=[90, 85])
get_figure_prediction(true_age_male_bl, age_pc_adni_male_bl, **adni_plot_kw,
                      title_add='ADNI Male – Baseline Controls')
print_metrics(true_age_male_bl, age_pc_adni_male_bl, 'ADNI Male BL Controls')

In [None]:
# --- Female population ---
# Female ADNI rows (standardised on the female ADNI data) restricted to
# baseline controls, plus the matching chronological ages.
is_female_bl_row    = norm_female_adni_x.index.isin(bl_controls.Sample_id)
norm_adni_female_bl = norm_female_adni_x.loc[is_female_bl_row]
adni_y_female_bl    = female_adni_y[female_adni_y.index.isin(bl_controls.Sample_id)]

true_age_female_bl    = list(adni_y_female_bl)
raw_pred_female_bl    = mlp_female.predict(norm_adni_female_bl)
# UKBB-fitted polynomial applied as-is to ADNI
age_pc_adni_female_bl = correct_w_model(true_age_female_bl, raw_pred_female_bl, b_poly_female)

adni_plot_kw = dict(range_fig=[40, 95], range_txt_x=45, range_txt_y=[90, 85])
get_figure_prediction(true_age_female_bl, age_pc_adni_female_bl, **adni_plot_kw,
                      title_add='ADNI Female – Baseline Controls')
print_metrics(true_age_female_bl, age_pc_adni_female_bl, 'ADNI Female BL Controls')

## 7. Apply Whole-Population Model to All ADNI Visits

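Predictions are generated with the whole-population model for all ADNI subjects
and visits. In Section 6 the UKBB polynomials were applied directly. Here, instead,
the bias-correction polynomial is re-fitted on the ADNI baseline controls (to
account for domain shift between UKBB and ADNI) and then applied to every visit.
The output CSV is used as input for the downstream survival and longitudinal analyses.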

In [None]:
# Obtain a bias-correction model fitted on ADNI baseline controls
# (accounts for domain shift between UKBB and ADNI). Uses the same
# polynomial degree as the UKBB models (DEG_BIAS).
_, b_poly_adni_whole = bias_correction(
    list(adni_y_bl_controls),
    mlp_whole.predict(norm_adni_bl_controls),
    degree=DEG_BIAS,
    title_set='ADNI Whole – Baseline Controls (re-fit)'
)

# Predict once for all ADNI visits and apply the ADNI-fitted correction
pred_adni_all   = mlp_whole.predict(norm_adni_all)
actual_adni_all = list(adni_y_all)
age_pc_adni_all = correct_w_model(actual_adni_all, pred_adni_all, b_poly_adni_whole)

# Assemble output data frame
df_adni_all = pd.DataFrame(
    {'Predicted_Age': pred_adni_all,
     'Actual_Age': actual_adni_all,
     'Corrected_Predicted_Age': age_pc_adni_all},
    index=norm_adni_all.index,
)
df_adni_all['Difference'] = df_adni_all['Predicted_Age'] - df_adni_all['Actual_Age']

# Subject metadata, positionally aligned with df_adni_all (its rows were
# built from subj_info_2.Sample_id in the same order).
adni_meta = subj_info_2[subj_info_2.Sample_id.isin(df_adni_all.index)]
for meta_col in ['RID', 'Diagnosis', 'VISCODE2']:
    df_adni_all[meta_col] = list(adni_meta[meta_col])

# Bias-corrected delta age (corrected_predicted_age - actual_age); appended
# last so the existing column order is unchanged.
df_adni_all['Delta_Age'] = (df_adni_all['Corrected_Predicted_Age']
                            - df_adni_all['Actual_Age'])

df_adni_all.to_csv('Predicted_Age_Whole_ADNI_Population_ADNIPolyBias.csv')
print('Saved:', 'Predicted_Age_Whole_ADNI_Population_ADNIPolyBias.csv')

## 8. Performance Summary Table (ADNI Baseline Controls)

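Sample size N, MAE, RMSE, and Pearson correlation (with its p-value) are reported
for three populations (whole, male, female). Each population is reported overall and
for two age bands: 55–71 years (both ends inclusive) and above 71 up to 91 years.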

In [None]:
def get_performance_metrics(y_true, y_pred):
    """Compute ``[N, MAE, RMSE, Pearson r, p-value]`` overall and per age band.

    Bands are defined on the true age: '55-71' is 55 <= age <= 71 and
    '71-91' is 71 < age <= 91. Segments with fewer than two subjects get
    NaN for every metric, because Pearson r is undefined there.
    """
    actual = np.array(y_true)
    predicted = np.array(y_pred)

    in_55_71 = (actual >= 55) & (actual <= 71)
    in_71_91 = (actual > 71) & (actual <= 91)

    def summarise(t, p):
        n = len(t)
        if n < 2:
            return [n, np.nan, np.nan, np.nan, np.nan]
        r, pval = pearsonr(t, p)
        return [n,
                mean_absolute_error(t, p),
                np.sqrt(mean_squared_error(t, p)),
                r,
                pval]

    return {
        'Overall': summarise(actual, predicted),
        '55-71':   summarise(actual[in_55_71], predicted[in_55_71]),
        '71-91':   summarise(actual[in_71_91], predicted[in_71_91]),
    }


# Bias-corrected ADNI baseline-control predictions for each population
groups = {
    'Whole Population': (np.array(adni_y_bl_controls), np.array(age_pc_adni_whole_bl)),
    'Male':             (np.array(adni_y_male_bl),      np.array(age_pc_adni_male_bl)),
    'Female':           (np.array(adni_y_female_bl),    np.array(age_pc_adni_female_bl)),
}

# One row per (population, age segment)
rows = [
    [group_name, seg] + vals
    for group_name, (y_t, y_p) in groups.items()
    for seg, vals in get_performance_metrics(y_t, y_p).items()
]

perf_table = pd.DataFrame(
    rows,
    columns=['Population', 'Age Segment', 'N', 'MAE', 'RMSE', 'Pearson r', 'P-value']
)

# Format for display: 3 decimals for errors / r, scientific notation for p
fmt = perf_table.copy()
three_dp = '{:.3f}'.format
for col in ('MAE', 'RMSE', 'Pearson r'):
    fmt[col] = fmt[col].map(three_dp)
fmt['P-value'] = fmt['P-value'].map('{:.2e}'.format)

print('### ADNI Baseline Controls – Performance Summary ###')
display(fmt)

fmt.to_csv('adni_performance_summary.csv', index=False)