In [None]:
from afqinsight.datasets import download_weston_havens

Welcome to this tractometry introduction tutorial! This is designed to introduce you to the analysis and interpretation of tractometry data, which involves quantifying the microstructural properties of white matter tracts in the brain using pre-processed diffusion MRI data. In this tutorial, we are not collecting or processing any MRI data; tractometry data from the Weston Havens dataset is provided, and we will analyze it.

Goals of the Notebook:
Introduction to Tractometry Data: Gain a basic understanding of what tractometry data entails, including the different metrics (like FA, MD, AD, RD) and the structure of the data (different tracts, different nodes).

Visualization: Visualize tractometry data using the python libraries matplotlib and seaborn.

Statistical Modeling: Show how to apply statistical models to tractometry data to find relationships between tractometry measures and phenotypic information (here, age).

At the end of this notebook you must continue this analysis in whichever direction most interests you. Do not spend more than 1-2 hours extending this analysis, and we encourage you to ask questions at any point in the process! This can be done via email or by posting issues on this tutorial's GitHub repository.


In [None]:
# The download_weston_havens() function downloads the data used in this example
# and places it in the ~/.cache/afq-insight/weston_havens directory.
# `workdir` is the returned directory path; the CSV files are read from it below.
workdir = download_weston_havens()

In [None]:
import pandas as pd
import os.path as op

In [None]:
# nodes.csv holds the per-node tract profiles; subjects.csv holds the per-subject
# information (including Age), which is later joined on subjectID.
tract_profiles, subject_information = (
    pd.read_csv(op.join(workdir, csv_name))
    for csv_name in ("nodes.csv", "subjects.csv")
)

In [None]:
# Peek at the tract-profile table read from nodes.csv. The columns used below include
# subjectID, tractID, nodeID, fa and md.
tract_profiles

In [None]:
# Peek at the subject table read from subjects.csv; it provides Age and is joined to
# the tract profiles on subjectID further below.
subject_information

### Data Exploration
The information in `nodes.csv` comes from diffusion MRI data analyzed with tractometry. You can read more about diffusion MRI and tractometry from these three links: 

1. https://andysbrainbook.readthedocs.io/en/latest/MRtrix/MRtrix_Course/MRtrix_00_Diffusion_Overview.html

2. https://yeatmanlab.github.io/pyAFQ/explanations/index.html

3. https://yeatmanlab.github.io/pyAFQ/explanations/tractometry_pipeline.html

Below is some code to get started exploring the data.


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

In [None]:
# Age is our target variable. Let's start by looking at its distribution.
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(subject_information['Age'], kde=True, bins=30, color='skyblue', ax=ax)
ax.set(title='Distribution of Age', xlabel='Age', ylabel='Frequency')
plt.show()

In [None]:
# We are going to compare age to "tract profiles". What do they look like?
# Start with a single example: one subject, one tract, one measure.
tract_profiles_sub000 = tract_profiles[tract_profiles.subjectID == "subject_000"]  # only subject 000
tract_profile_cstl = tract_profiles_sub000[
    tract_profiles_sub000.tractID == "Left Corticospinal"
]  # only the left corticospinal tract

plt.figure(figsize=(10, 6))
sns.lineplot(data=tract_profile_cstl, x='nodeID', y='fa')  # only fractional anisotropy (FA)
plt.title('FA profile of the Left Corticospinal tract (subject_000)')
plt.xlabel('Node along the tract (nodeID)')
plt.ylabel('FA')
plt.show()  # also keeps the Axes repr out of the cell output

In [None]:
# Let's look at all tract profiles for a single subject now (one line per tract).
plt.figure(figsize=(10, 6))
sns.lineplot(data=tract_profiles_sub000, x='nodeID', y='fa', hue="tractID")
# Place the long legend outside the axes so it does not cover the profiles.
plt.legend(loc='upper left', bbox_to_anchor=(1, 1), title='Tract')
plt.title('FA profiles of all tracts (subject_000)')
plt.xlabel('Node along the tract (nodeID)')
plt.ylabel('FA')
plt.show()  # also keeps the Legend repr out of the cell output

In [None]:
# What a mess! And that is just for one measure. There is also MD (mean diffusivity).
plt.figure(figsize=(10, 6))
sns.lineplot(data=tract_profiles_sub000, x='nodeID', y='md', hue="tractID")
# Place the long legend outside the axes so it does not cover the profiles.
plt.legend(loc='upper left', bbox_to_anchor=(1, 1), title='Tract')
plt.title('MD profiles of all tracts (subject_000)')
plt.xlabel('Node along the tract (nodeID)')
plt.ylabel('MD')
plt.show()  # also keeps the Legend repr out of the cell output

In [None]:
# Let's return to using all subjects.
# To simplify, take the mean across nodes for now (one value per subject), and look
# at the distribution of FA for a single bundle: the Left Corticospinal tract.
tract_profile_cstl = tract_profiles[tract_profiles.tractID == "Left Corticospinal"]
# numeric_only=True averages the measurement columns and skips non-numeric ones
# (such as tractID) instead of failing on them.
grouped_data = tract_profile_cstl.groupby('subjectID').mean(numeric_only=True)

plt.figure(figsize=(10, 6))
# No hue here, so there is nothing for a legend to show; calling plt.legend() would
# only warn "No artists with labels found" and draw an empty legend box.
sns.histplot(data=grouped_data, x='fa', kde=True, bins=30)
plt.title('Distribution of mean FA in the Left Corticospinal tract')
plt.xlabel('FA (mean across nodes)')
plt.ylabel('Number of subjects')
plt.show()

In [None]:
# Now let's see the distribution of FA for all bundles
# (one mean-across-nodes value per subject and tract).
grouped_data = (
    tract_profiles
    .groupby(['subjectID', 'tractID'])
    .mean(numeric_only=True)
    .reset_index()  # make tractID an ordinary column so it can be used as `hue`
)

plt.figure(figsize=(10, 6))
ax = sns.histplot(data=grouped_data, x='fa', hue='tractID', kde=True, bins=30)
# seaborn draws its own hue legend from unattached proxy artists. A bare plt.legend()
# would replace it with an empty legend, so move seaborn's legend outside instead.
sns.move_legend(ax, 'upper left', bbox_to_anchor=(1, 1), title='Tract')
plt.title('Distribution of mean FA, by tract')
plt.xlabel('FA (mean across nodes)')
plt.ylabel('Count')
plt.show()

In [None]:
# So we have seen the distributions for our dMRI data, and for our age.
# Now, merge these sources of information to see how they relate.

# One row per (subject, tract): the mean of each measure across that tract's nodes.
# reset_index() turns the subjectID/tractID group keys back into columns for the merge.
grouped_data = (
    tract_profiles
    .groupby(['subjectID', 'tractID'])
    .mean(numeric_only=True)
    .reset_index()
)

# Attach each subject's information (including Age) to every one of their tract rows.
# pd.merge defaults to an inner join: subjects missing from either table are dropped.
merged_data = pd.merge(grouped_data, subject_information, on='subjectID')

plt.figure(figsize=(10, 6))
sns.scatterplot(data=merged_data, x='Age', y='fa', hue='tractID', alpha=0.6)
plt.legend(loc='upper left', bbox_to_anchor=(1, 1), title='Tract')
plt.title('Age vs Fractional Anisotropy (FA)')
plt.xlabel('Age')
# Each point is one subject's FA averaged over the nodes of one tract.
plt.ylabel('FA (mean across nodes)')
plt.show()

In [None]:
# You might notice a slight drop off in FA after around age 40.
# There is also lower FA below the age of 10.
# Let's see if something similar occurs in MD
plt.figure(figsize=(10, 6))
sns.scatterplot(data=merged_data, x='Age', y='md', hue='tractID', alpha=0.6)
plt.legend(loc='upper left', bbox_to_anchor=(1, 1), title='Tract')
plt.title('Age vs Mean Diffusivity (MD)')
plt.xlabel('Age')
# Each point is one subject's MD averaged over the nodes of one tract.
plt.ylabel('MD (mean across nodes)')
plt.show()

### Modelling the data
Below is some code to model the relationship between the dMRI measures and age (age from `subjects.csv`).

To start, we do some data wrangling. We simplify the data by taking the mean across nodeID. This might not be the right decision, but it is a good place to start. There are currently 100 nodes per tract, and adjacent nodes are highly correlated. This large number of correlated measurements may cause problems for fitting models. However, parts of the tract may be more predictive than other parts of the tract, and taking an average throws away that information. Instead of taking an average, one can also sample every tenth (or fifth) node to reduce the number of correlated measurements.

In [None]:
print(tract_profiles.head(5))
# Collapse the nodes: one row per (subject, tract) holding the mean of each numeric
# measure. reset_index() turns the group keys back into ordinary columns.
grouped_data = (
    tract_profiles
    .groupby(['subjectID', 'tractID'])
    .mean(numeric_only=True)
    .reset_index()
)
print(grouped_data.head(5))

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer

# For simplicity, let's focus on FA and MD as our features to predict Age.
# Each row of merged_data is one (subject, tract) pair.
X = merged_data[['fa', 'md']]
y = merged_data['Age']

# Split by *subject*, not by row. Every subject contributes one row per tract, so a
# row-wise split would put the same subject (and the same age) in both the training
# and the test set, and the test score would overstate performance on unseen people.
subject_ids = merged_data['subjectID'].unique()
train_subjects, test_subjects = train_test_split(subject_ids, test_size=0.2, random_state=42)
is_train = merged_data['subjectID'].isin(train_subjects)
is_test = merged_data['subjectID'].isin(test_subjects)
X_train, X_test = X[is_train], X[is_test]
y_train, y_test = y[is_train], y[is_test]

# Let's start with imputation followed by linear regression.
# Imputation guards against missing fa/md values (e.g. bundles that were not found in
# every subject), which LinearRegression cannot handle.
pipeline_steps = [
    ('imputer', SimpleImputer(strategy='mean')),  # Impute missing values using the mean
    ('linear_reg', LinearRegression())            # Then, apply linear regression
]
linear_reg_pipeline = Pipeline(steps=pipeline_steps)
linear_reg_pipeline.fit(X_train, y_train)
y_pred = linear_reg_pipeline.predict(X_test)

# There are a variety of ways to evaluate a model, here are two
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"Linear Regression Model MSE: {mse:.2f}")
print(f"Linear Regression Model R² score: {r2:.2f}")

plt.figure(figsize=(10, 6))
sns.scatterplot(x=y_test, y=y_pred)
plt.title('Linear Regression predictions')
plt.xlabel('Age')
plt.ylabel('Predicted Age')
plt.show()

In [None]:
# We can try different models. This cell reuses X_train/X_test/y_train/y_test from
# the previous cell, so both models are scored on the same split.
from sklearn.ensemble import RandomForestRegressor

pipeline_steps = [
    ('imputer', SimpleImputer(strategy='mean')),
    ('forest_reg', RandomForestRegressor(n_estimators=100, random_state=42))
]
forest_reg_pipeline = Pipeline(steps=pipeline_steps)
forest_reg_pipeline.fit(X_train, y_train)
y_pred_rf = forest_reg_pipeline.predict(X_test)

mse_rf = mean_squared_error(y_test, y_pred_rf)
r2_rf = r2_score(y_test, y_pred_rf)

print(f"Random Forest Regressor MSE: {mse_rf:.2f}")
print(f"Random Forest Regressor R² score: {r2_rf:.2f}")
plt.figure(figsize=(10, 6))
# Plot the random forest's own predictions (y_pred_rf), not the linear model's y_pred.
sns.scatterplot(x=y_test, y=y_pred_rf)
plt.title('Random Forest predictions')
plt.xlabel('Age')
plt.ylabel('Predicted Age')
plt.show()

It looks like linear regression works better than random forest here (compare the MSE and R² scores above). Some possible next steps include (in no particular order):

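1. Investigate which tract(s) are most predictive of age, and which are least. You can use the linear regression coefficients for this, but first reshape the data so that each tract's measures are separate features (one row per subject, one column per tract and measure, e.g. with `DataFrame.pivot`). With the current `fa`/`md` features, the coefficients are not tract-specific.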
2. You could change the hyperparameters of random forest, or try different models to see if something can beat linear regression.
3. You could try not simplifying the data so much, by using other tissue properties (rd, cl, ad) or by not averaging across the nodes.

# Tackle one of these next steps, or continue to explore the data in other creative ways!