
# Predict age from white matter features

Predict subject age from white matter features. This example fetches the
Weston-Havens dataset described in Yeatman et al. [1]_. This dataset contains
tractometry features from 77 subjects aged 6-50. The plots display the absolute
value of the mean regression coefficients (averaged across cross-validation
splits) for the mean diffusivity (MD) features.

Predictive performance for this example is quite poor. In a research setting,
one might have to ensemble a number of SGL estimators together and conduct a
more thorough search of the hyperparameter space. For more details, please see
[2]_.

.. [1]  Jason D. Yeatman, Brian A. Wandell, and Aviv A. Mezer, "Lifespan
   maturation and degeneration of human brain white matter", Nature
   Communications, vol. 5:1, pp. 4932, 2014. DOI: 10.1038/ncomms5932

.. [2]  Adam Richie-Halford, Jason Yeatman, Noah Simon, and Ariel Rokem,
   "Multidimensional analysis and detection of informative features in human
   brain white matter", PLOS Computational Biology, 2021. DOI:
   10.1371/journal.pcbi.1009136


In [1]:
import os.path as op

import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from groupyr.decomposition import GroupPCA

from afqinsight.neurocombat_sklearn import CombatModel
from afqinsight.plot import plot_tract_profiles
from afqinsight import make_afq_regressor_pipeline
from afqinsight import AFQDataset, load_afq_data

## Fetch example data

The :func:`download_weston_havens` function downloads the data used in this
example and places it in the `~/.cache/afq-insight/weston_havens` directory.
If the directory does not exist, it is created. The data follows the format
expected by the :func:`load_afq_data` function: a file called `nodes.csv` that
contains AFQ tract profiles and a file called `subjects.csv` that contains
information about the subjects. The two files are linked through the
`subjectID` column that should exist in both of them. For more information
about this format, see also the [AFQ-Browser documentation](https://yeatmanlab.github.io/AFQ-Browser/dataformat.html) (items 2 and 3).
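
To make the two-file layout concrete, here is a minimal sketch that builds miniature stand-ins for `nodes.csv` and `subjects.csv` with pandas and links them through `subjectID`. The subject IDs, tract name, and values are invented for illustration, and the pivot only approximates the kind of reshaping a loader has to perform; it is not the implementation used by :func:`load_afq_data`.

```python
import pandas as pd

# Hypothetical miniature tables; real files have many more rows and columns.
nodes = pd.DataFrame(
    {
        "subjectID": ["sub-01", "sub-01", "sub-02", "sub-02"],
        "tractID": ["Left Arcuate"] * 4,
        "nodeID": [0, 1, 0, 1],
        "md": [0.71, 0.73, 0.69, 0.70],
        "fa": [0.45, 0.47, 0.48, 0.50],
    }
)
subjects = pd.DataFrame({"subjectID": ["sub-01", "sub-02"], "age": [8.5, 31.0]})

# One row per subject, one column per (metric, tract, node) combination.
features = nodes.pivot(
    index="subjectID", columns=["tractID", "nodeID"], values=["md", "fa"]
)

# Align the targets with the feature rows through the shared subjectID column.
targets = subjects.set_index("subjectID").loc[features.index, "age"]

print(features.shape)
print(targets.to_list())
```

Each subject ends up as a single feature row, which is the shape that scikit-learn estimators expect.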



In [None]:
afqdata = AFQDataset.from_files(
    fn_nodes="/data/afq-insight/hbn/nodes.csv",
    fn_subjects="/data/afq-insight/hbn/subjects.tsv",
    dwi_metrics=["dki_md", "dki_fa"],
    target_cols=["age", "sex", "scan_site_id"],
    label_encode_cols=["sex", "scan_site_id"],
    index_col="subject_id",
)

In [None]:
afqdata.drop_target_na()
print(afqdata)
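
As a rough illustration of what dropping subjects with missing targets involves, the NumPy sketch below removes every row that has a NaN in any target column. We assume that this mirrors the effect of `drop_target_na`; the arrays are toy data, and the real method operates on the dataset object in place.

```python
import numpy as np

# Toy feature matrix (4 subjects, 2 features) and targets (2 target columns).
X = np.arange(8.0).reshape(4, 2)
y = np.array(
    [
        [10.0, 0.0],
        [np.nan, 1.0],  # first target missing
        [12.0, np.nan],  # second target missing
        [9.0, 1.0],
    ]
)

# Keep only the subjects whose targets are all observed.
keep = ~np.isnan(y).any(axis=1)
X_clean, y_clean = X[keep], y[keep]

print(X_clean.shape, y_clean.shape)
```

Features and targets are filtered with the same mask so that their rows stay aligned.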

In [None]:
"""
===============================
Harmonize HBN data using ComBat
===============================

This example loads AFQ data from the Healthy Brain Network (HBN) preprocessed
diffusion derivatives [1]_. The HBN is a landmark pediatric mental health study.
Over the course of the study, it will collect diffusion MRI data from
approximately 5,000 children and adolescents. We recently processed the
available data from over 2,000 of these subjects, and provide the tract profiles
from this dataset, which can be downloaded from AWS thanks to
[INDI](http://fcon_1000.projects.nitrc.org/).

We first load the data by using the :func:`AFQDataset.from_files` static method
and supplying AWS S3 URIs instead of local file names. We then impute missing
values and plot the mean bundle profiles by scanning site, noting that there are
substantial site differences. Lastly, we harmonize the site differences using
NeuroComBat [2]_ and plot the harmonized bundle profiles to verify that the site
differences have been removed.

.. [1]  Adam Richie-Halford, Matthew Cieslak, Lei Ai, Sendy Caffarra, Sydney
   Covitz, Alexandre R. Franco, Iliana I. Karipidis, John Kruper, Michael
   Milham, Bárbara Avelar-Pereira, Ethan Roy, Valerie J. Sydnor, Jason Yeatman,
   The Fibr Community Science Consortium, Theodore D. Satterthwaite, and Ariel
   Rokem,
   "An open, analysis-ready, and quality controlled resource for pediatric brain
   white-matter research"
   bioRxiv 2022.02.24.481303;
   doi: https://doi.org/10.1101/2022.02.24.481303

.. [2] Jean-Philippe Fortin, Drew Parker, Birkan Tunc, Takanori Watanabe, Mark A
   Elliott, Kosha Ruparel, David R Roalf, Theodore D Satterthwaite, Ruben C Gur,
   Raquel E Gur, Robert T Schultz, Ragini Verma, Russell T Shinohara.
   "Harmonization Of Multi-Site Diffusion Tensor Imaging Data"
   NeuroImage, 161, 149-170, 2017;
   doi: https://doi.org/10.1016/j.neuroimage.2017.08.047

"""

#############################################################################
# Train / test split
# ------------------
#
# We can pass the :class:`AFQDataset` class instance to scikit-learn's
# :func:`train_test_split` function, just as we would with an array.

dataset_train, dataset_test = train_test_split(afqdata, test_size=0.25)

##########################################################################
# Impute missing values
# ---------------------
#
# Next we impute missing values using median imputation. We fit the imputer
# using the training set and then use it to transform both the training and test
# sets.

imputer = dataset_train.model_fit(SimpleImputer(strategy="median"))
dataset_train = dataset_train.model_transform(imputer)
dataset_test = dataset_test.model_transform(imputer)

##########################################################################
# Harmonize the sites and replot
# ------------------------------
#
# We can see that there are substantial scan site differences in both the
# FA and MD profiles. Let's use neuroComBat to harmonize the site differences
# and then replot the mean bundle profiles.
#

# N.B. We use the excellent `neurocombat_sklearn
# <https://github.com/Warvito/neurocombat_sklearn>`_ package, which we have
# ported and updated to support recent versions of scikit-learn, to apply
# ComBat to our data. We love this library; however, it is not fully compliant
# with the scikit-learn transformer API, so we cannot use the
# :func:`AFQDataset.model_fit_transform` method to apply this transformer to our
# dataset. No problem! We can simply copy the unharmonized dataset into a new
# variable and then overwrite the features of the new dataset with the ComBat
# output.
#
# Lastly, we replot the mean bundle profiles and confirm that ComBat did its
# job.

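# Fit the ComBat transformer to the training set. The columns of ``y`` follow
# the order of ``target_cols``: 0 is age, 1 is sex, and 2 is scan_site_id.

combat = CombatModel()
combat.fit(
    dataset_train.X,
    dataset_train.y[:, 2][:, np.newaxis],  # sites
    dataset_train.y[:, 1][:, np.newaxis],  # discrete covariates (sex)
    dataset_train.y[:, 0][:, np.newaxis],  # continuous covariates (age)
)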


# And then transform a copy of the test set and a copy of the train set:
harmonized_test = dataset_test.copy()
harmonized_test.X = combat.transform(
    dataset_test.X,
    dataset_test.y[:, 2][:, np.newaxis],
    dataset_test.y[:, 1][:, np.newaxis],
    dataset_test.y[:, 0][:, np.newaxis],
)

harmonized_train = dataset_train.copy()
harmonized_train.X = combat.transform(
    dataset_train.X,
    dataset_train.y[:, 2][:, np.newaxis],
    dataset_train.y[:, 1][:, np.newaxis],
    dataset_train.y[:, 0][:, np.newaxis],
)
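
The fit-on-train, transform-both pattern used above can be sketched with a much simpler stand-in for ComBat: a per-site location and scale adjustment. This is not the ComBat algorithm, which additionally preserves covariate effects such as age and sex and shrinks the site estimates with empirical Bayes. The function names and the synthetic data are hypothetical.

```python
import numpy as np


def fit_site_adjustment(X, sites):
    """Estimate per-site means and standard deviations from training data."""
    site_ids = np.unique(sites)
    means = {s: X[sites == s].mean(axis=0) for s in site_ids}
    stds = {s: X[sites == s].std(axis=0) for s in site_ids}
    return {
        "means": means,
        "stds": stds,
        "target_mean": X.mean(axis=0),
        "target_std": np.mean([stds[s] for s in site_ids], axis=0),
    }


def apply_site_adjustment(X, sites, params):
    """Standardize within each site, then map onto the shared mean and scale."""
    unseen = set(np.unique(sites)) - set(params["means"])
    if unseen:
        raise ValueError(f"Sites not seen during fitting: {sorted(unseen)}")
    X_out = np.empty_like(X, dtype=float)
    for s in params["means"]:
        mask = sites == s
        z = (X[mask] - params["means"][s]) / params["stds"][s]
        X_out[mask] = z * params["target_std"] + params["target_mean"]
    return X_out


# Synthetic data: site 1 is shifted by +2 relative to site 0.
rng = np.random.default_rng(0)
sites_train = np.repeat([0, 1], 50)
X_train = rng.normal(size=(100, 3)) + sites_train[:, None] * 2.0
sites_test = np.repeat([0, 1], 10)
X_test = rng.normal(size=(20, 3)) + sites_test[:, None] * 2.0

# Parameters come from the training set only and are reused for the test set.
params = fit_site_adjustment(X_train, sites_train)
X_train_h = apply_site_adjustment(X_train, sites_train, params)
X_test_h = apply_site_adjustment(X_test, sites_test, params)
```

Estimating the site parameters on the training set alone keeps information from the test set out of the harmonization step.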


## Create an analysis pipeline
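
The pipeline in this section can optionally apply group-wise PCA before the regression step. As a rough sketch of the idea, the example below runs scikit-learn's `PCA` separately on each feature group and records which output columns belong to which group. It only approximates what `GroupPCA` does and is not groupyr's implementation; the grouping and the data are invented.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 12))

# Hypothetical grouping: three groups of four consecutive columns.
groups = [np.arange(0, 4), np.arange(4, 8), np.arange(8, 12)]
n_components = 2

blocks, groups_out, start = [], [], 0
for idx in groups:
    # Reduce each group independently to ``n_components`` columns.
    block = PCA(n_components=n_components).fit_transform(X[:, idx])
    blocks.append(block)
    # Track the column indices that this group occupies in the output.
    groups_out.append(np.arange(start, start + block.shape[1]))
    start += block.shape[1]

X_reduced = np.hstack(blocks)
print(X_reduced.shape)
print([g.tolist() for g in groups_out])
```

Because the number of columns per group changes, a downstream group-aware estimator needs the new group indices rather than the original ones.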




In [None]:
do_group_pca = True

if do_group_pca:
    n_components = 10

    # The next two lines retrieve the group structure of the group-wise PCA
    # and store it in ``groups_pca``. We do not use this GroupPCA transformer
    # for anything else. We fit it on the training set so that the test set
    # stays untouched until evaluation.
    gpca = GroupPCA(n_components=n_components, groups=afqdata.groups)
    groups_pca = gpca.fit(harmonized_train.X).groups_out_

    transformer = GroupPCA
    transformer_kwargs = {"groups": afqdata.groups, "n_components": n_components}
else:
    transformer = False
    transformer_kwargs = None

pipe = make_afq_regressor_pipeline(
    imputer_kwargs={"strategy": "median"},  # Use median imputation
    use_cv_estimator=True,  # Automatically determine the best hyperparameters
    scaler="standard",  # Standard scale the features before regression
    feature_transformer=transformer,  # See note above about group PCA
    feature_transformer_kwargs=transformer_kwargs,
    # SGL uses the PCA feature groups if group PCA is enabled and the original
    # feature groups otherwise.
    groups=groups_pca if do_group_pca else afqdata.groups,
    verbose=0,  # Be quiet!
    pipeline_verbosity=False,  # No really, be quiet!
    tuning_strategy="bayes",  # Use BayesSearchCV to determine optimal hyperparameters
    n_bayes_iter=20,  # Consider this many points in hyperparameter space
    cv=3,  # Use three CV splits to evaluate each hyperparameter combination
    l1_ratio=[0.0, 1.0],  # Explore the entire range of ``l1_ratio``
    eps=5e-2,  # This is the ratio of the smallest to largest ``alpha`` value
)
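
To give a feel for the `l1_ratio` and `eps` arguments, here is a minimal sketch of a sparse group lasso penalty and of a log-spaced `alpha` grid. The square-root group-size weighting is an assumption taken from the common SGL formulation and may differ from the estimator's internals. The function `sgl_penalty` and the value of `alpha_max` are hypothetical.

```python
import numpy as np


def sgl_penalty(coef, groups, alpha, l1_ratio):
    """Blend a lasso penalty (l1_ratio=1) with a group lasso penalty (l1_ratio=0)."""
    lasso_term = np.abs(coef).sum()
    # Assumed weighting: each group norm is scaled by sqrt(group size).
    group_term = sum(np.sqrt(len(g)) * np.linalg.norm(coef[g]) for g in groups)
    return alpha * (l1_ratio * lasso_term + (1.0 - l1_ratio) * group_term)


coef = np.array([3.0, 4.0, 0.0, 0.0])
groups = [np.array([0, 1]), np.array([2, 3])]

pure_lasso = sgl_penalty(coef, groups, alpha=1.0, l1_ratio=1.0)
pure_group = sgl_penalty(coef, groups, alpha=1.0, l1_ratio=0.0)

# ``eps`` sets the ratio of the smallest to the largest ``alpha`` on the grid.
alpha_max = 1.0  # hypothetical value; in practice it is derived from the data
eps = 5e-2
alphas = np.geomspace(alpha_max * eps, alpha_max, num=5)
```

Setting `l1_ratio=[0.0, 1.0]` in the pipeline therefore lets the search range between a pure group lasso and a pure lasso.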

In [None]:
pipe.fit(harmonized_train.X, harmonized_train.y[:, 0])

In [None]:
pred_age = pipe.predict(harmonized_test.X)

In [None]:
fig, ax = plt.subplots()
ax.scatter(harmonized_test.y[:, 0], pred_age)

In [None]:
pipe.score(harmonized_test.X, harmonized_test.y[:, 0])
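
To interpret the score, it helps to recall how the coefficient of determination is computed. The sketch below assumes that the pipeline follows scikit-learn's convention of returning R² from a regressor's `score` method, and it checks a manual calculation against `sklearn.metrics.r2_score`. The ages are invented.

```python
import numpy as np
from sklearn.metrics import r2_score

# Hypothetical true and predicted ages.
y_true = np.array([8.0, 12.0, 15.0, 20.0])
y_pred = np.array([9.0, 11.0, 16.0, 18.0])

# R^2 = 1 - (residual sum of squares) / (total sum of squares)
ss_res = np.sum((y_true - y_pred) ** 2)
ss_tot = np.sum((y_true - y_true.mean()) ** 2)
r2_manual = 1.0 - ss_res / ss_tot

print(r2_manual, r2_score(y_true, y_pred))
```

A value of 1 indicates perfect prediction, 0 matches a model that always predicts the mean age, and negative values are worse than that baseline.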