# Predict age from white matter features

This example uses data from the [HBN POD2 dataset](https://www.nature.com/articles/s41597-022-01695-7),
which includes 1867 subjects aged 5-21. We use the sparse group lasso implemented in AFQ-Insight to fit a model that predicts each subject's age from tractometry features. Because white matter develops dramatically during childhood and adolescence, such a model can account for a substantial proportion of the variance in age in a held-out test set.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from groupyr.decomposition import GroupPCA

from afqinsight.neurocombat_sklearn import CombatModel
from afqinsight.plot import plot_tract_profiles
from afqinsight import make_afq_regressor_pipeline
from afqinsight import AFQDataset

## Read the data

The `nodes.csv` file, which is the input here, is the output of pyAFQ processing. The `subjects.tsv` file is a BIDS-compliant participants file whose subject identifiers match those stored in the pyAFQ output. This allows AFQ-Insight to merge the data from the two files.

In [None]:
afqdata = AFQDataset.from_files(
    fn_nodes="/data/tractometry/afq-insight/hbn/nodes.csv",
    fn_subjects="/data/tractometry/afq-insight/hbn/subjects.tsv",
    dwi_metrics=["dki_md", "dki_fa"],
    target_cols=["age", "sex", "scan_site_id"],
    label_encode_cols=["sex", "scan_site_id"],
    index_col="subject_id",
)

In [None]:
afqdata.drop_target_na()
print(afqdata)

## Train / test split

We can pass the `AFQDataset` class instance to scikit-learn's
`train_test_split` function, just as we would with an array.

In [None]:
dataset_train, dataset_test = train_test_split(afqdata, test_size=0.25)

## Impute missing values

Next we impute missing values using median imputation (some values are missing because of noisy MRI scans). We fit the imputer using the training set and then use it to transform both the training and test
sets.

In [None]:
imputer = dataset_train.model_fit(SimpleImputer(strategy="median"))
dataset_train = dataset_train.model_transform(imputer)
dataset_test = dataset_test.model_transform(imputer)
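
The fit-on-train, transform-both pattern above can be sketched with plain NumPy. This is only an illustration of the idea behind median imputation, not the scikit-learn implementation; the toy arrays and the `fill_missing` helper are made up for this example.

```python
import numpy as np

# Toy feature matrices (rows = subjects, columns = features).
# NaN marks a missing value. These numbers are invented for illustration.
X_train = np.array([[1.0, 10.0], [np.nan, 20.0], [3.0, np.nan], [5.0, 40.0]])
X_test = np.array([[np.nan, 30.0], [2.0, np.nan]])

# "Fit": learn one median per column from the training set only, ignoring NaNs.
train_medians = np.nanmedian(X_train, axis=0)


def fill_missing(X, fill_values):
    """Replace each NaN with the fill value for its column (hypothetical helper)."""
    return np.where(np.isnan(X), fill_values, X)


# "Transform": apply the training medians to both sets, so that no
# information from the test set leaks into the imputation.
X_train_imputed = fill_missing(X_train, train_medians)
X_test_imputed = fill_missing(X_test, train_medians)
```

Here the missing test values are filled with medians computed from the training subjects alone, which is the reason for fitting the imputer before transforming the test set.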

## Harmonize the sites

The HBN dataset contains measurements from four different sites,
and there are substantial scan site differences in both the
FA and MD profiles. We use neuroComBat ([Fortin et al., 2017](https://doi.org/10.1016/j.neuroimage.2017.08.047)) to harmonize
the site differences.

In [None]:
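# Fit the ComBat transformer to the training set only.
# The columns of ``y`` follow ``target_cols``: 0 = age, 1 = sex, 2 = scan_site_id.
combat = CombatModel()
combat.fit(
    dataset_train.X,
    dataset_train.y[:, 2][:, np.newaxis],  # site labels
    dataset_train.y[:, 1][:, np.newaxis],  # discrete covariate: sex
    dataset_train.y[:, 0][:, np.newaxis],  # continuous covariate: age
)

# Then transform a copy of the test set and a copy of the training set: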
harmonized_test = dataset_test.copy()
harmonized_test.X = combat.transform(
    dataset_test.X,
    dataset_test.y[:, 2][:, np.newaxis],
    dataset_test.y[:, 1][:, np.newaxis],
    dataset_test.y[:, 0][:, np.newaxis],
)

harmonized_train = dataset_train.copy()
harmonized_train.X = combat.transform(
    dataset_train.X,
    dataset_train.y[:, 2][:, np.newaxis],
    dataset_train.y[:, 1][:, np.newaxis],
    dataset_train.y[:, 0][:, np.newaxis],
)
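
The column indices in the calls above follow the order of `target_cols` given when the data was loaded (`age`, `sex`, `scan_site_id`). A small sketch with a made-up target array shows what each slice yields; the values are invented, and the only point is the shape and column order.

```python
import numpy as np

# Made-up targets for four subjects. Columns follow
# target_cols = ["age", "sex", "scan_site_id"]; sex and site are label-encoded.
y = np.array(
    [
        [8.5, 0.0, 2.0],
        [12.1, 1.0, 0.0],
        [15.3, 1.0, 3.0],
        [6.9, 0.0, 1.0],
    ]
)

# Selecting one column gives a 1-D array of shape (4,);
# indexing with np.newaxis turns it into a column vector of shape (4, 1).
age = y[:, 0][:, np.newaxis]
sex = y[:, 1][:, np.newaxis]
site = y[:, 2][:, np.newaxis]

print(age.shape, sex.shape, site.shape)
```

Each covariate is passed as a two-dimensional column so that it lines up row by row with the subjects in `X`.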

## Create an analysis pipeline

With the data imputed and harmonized, we can now fit the model. AFQ-Insight implements complex pipelines that include multiple analysis steps. Helper functions (such as `make_afq_regressor_pipeline`) create
scikit-learn compatible pipelines that can then be used to fit, predict and score the model.

In [None]:
do_group_pca = True

if do_group_pca:
    n_components = 10

    # The next two lines retrieve the group structure of the group-wise PCA
    # and store it in ``groups_pca``. We do not use this GroupPCA instance
    # for anything else. The data are already imputed, and we fit on the
    # training set so that the test set stays untouched until evaluation.
    gpca = GroupPCA(n_components=n_components, groups=afqdata.groups)
    groups_pca = gpca.fit(harmonized_train.X).groups_out_

    transformer = GroupPCA
    transformer_kwargs = {"groups": afqdata.groups, "n_components": n_components}
else:
    transformer = False
    transformer_kwargs = None

pipe = make_afq_regressor_pipeline(
    imputer_kwargs={"strategy": "median"},  # Use median imputation
    use_cv_estimator=True,  # Automatically determine the best hyperparameters
    scaler="standard",  # Standard scale the features before regression
    feature_transformer=transformer,  # See note above about group PCA
    feature_transformer_kwargs=transformer_kwargs,
    # SGL uses the PCA feature groups or the original feature groups,
    # depending on the choice above
    groups=groups_pca if do_group_pca else afqdata.groups,
    verbose=0,  # Be quiet!
    pipeline_verbosity=False,  # No really, be quiet!
    tuning_strategy="bayes",  # Use BayesSearchCV to determine optimal hyperparameters
    n_bayes_iter=20,  # Consider this many points in hyperparameter space
    cv=3,  # Use three CV splits to evaluate each hyperparameter combination
    l1_ratio=[0.0, 1.0],  # Explore the entire range of ``l1_ratio``
    eps=5e-2,  # This is the ratio of the smallest to largest ``alpha`` value
)
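
To give a sense of what `l1_ratio` trades off, here is a minimal sketch of a sparse group lasso penalty. It assumes a common textbook form, with a lasso term and a group term weighted by the square root of each group's size; the exact scaling used inside the library may differ, and `sgl_penalty` is a hypothetical helper written for this example.

```python
import numpy as np


def sgl_penalty(coef, groups, alpha, l1_ratio):
    """Sparse group lasso penalty under an assumed textbook formulation.

    coef: 1-D coefficient vector.
    groups: list of index arrays, one per feature group.
    alpha: overall regularization strength.
    l1_ratio: 1.0 gives a pure lasso penalty, 0.0 a pure group lasso penalty.
    """
    lasso_term = np.sum(np.abs(coef))
    group_term = sum(np.sqrt(len(g)) * np.linalg.norm(coef[g]) for g in groups)
    return alpha * (l1_ratio * lasso_term + (1.0 - l1_ratio) * group_term)


# Two groups of two features each; only the first group has nonzero weights.
coef = np.array([3.0, 4.0, 0.0, 0.0])
groups = [np.array([0, 1]), np.array([2, 3])]

pure_lasso = sgl_penalty(coef, groups, alpha=1.0, l1_ratio=1.0)
pure_group = sgl_penalty(coef, groups, alpha=1.0, l1_ratio=0.0)
mixed = sgl_penalty(coef, groups, alpha=1.0, l1_ratio=0.5)
```

The lasso term penalizes every nonzero coefficient individually, while the group term charges a group only when any of its coefficients is nonzero. Searching `l1_ratio` over the whole interval from 0 to 1 therefore lets the tuning procedure choose between feature-level and group-level sparsity.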

In [None]:
pipe.fit(harmonized_train.X, harmonized_train.y[:, 0])

In [None]:
pred_age = pipe.predict(harmonized_test.X)

In [None]:
fig, ax = plt.subplots()
ax.scatter(harmonized_test.y[:, 0], pred_age)
ax.set_xlabel("True age (years)")
ax.set_ylabel("Predicted age (years)")

In [None]:
pipe.score(harmonized_test.X, harmonized_test.y[:, 0])
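
Assuming the pipeline uses scikit-learn's default regression scorer, the value returned by `score` is the coefficient of determination, R², that is, the proportion of variance in age that the predictions account for. The sketch below computes it by hand on invented numbers; `r_squared` is a hypothetical helper, not part of AFQ-Insight.

```python
import numpy as np


def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_residual / SS_total."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


# Invented ages (in years) and predictions for five subjects.
true_age_toy = [6.0, 9.0, 12.0, 15.0, 18.0]
pred_age_toy = [7.0, 9.0, 11.0, 16.0, 17.0]

r2_toy = r_squared(true_age_toy, pred_age_toy)

# Predicting the mean age for everyone gives R² = 0 by construction.
r2_baseline = r_squared(true_age_toy, [12.0] * 5)
```

A score of 1 means perfect prediction, 0 means the model does no better than always predicting the mean age, and negative values mean it does worse.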