In [None]:
import afqinsight as afqi
import matplotlib.pyplot as plt
import numpy as np
import os.path as op
import pandas as pd
import seaborn as sns

from afqinsight import AFQDataset
from afqinsight.plot import plot_tract_profiles
from afqinsight.match import mahalonobis_dist_match
from neurocombat_sklearn import CombatModel
from sklearn.impute import SimpleImputer, KNNImputer

In [None]:
# Input tables for the NLES analysis: merged tract profiles and merged phenotypes.
workdir = "../nles"
fn_nodes = op.join(workdir, "combined_tract_profiles_merged.csv")
fn_subjects = op.join(workdir, "pheno_merged.csv")

# Shared settings so both dataset loads request the same diffusion metrics.
DWI_METRICS = ["dki_fa", "dki_md"]
TARGET_COLS = ["Barratt_Total", "Age", "Sex", "NLES_P_TotalEvents", "NLES_P_Upset_Total"]

# Feature-only load with subject and session concatenated into one id;
# the next cell splits those ids on "HBNsite" to recover each subject's site.
unsupervised_dataset = AFQDataset.from_files(
    fn_nodes=fn_nodes,
    dwi_metrics=DWI_METRICS,
    unsupervised=True,
    concat_subject_session=True,
)

# Same features, joined with the phenotypic targets from the subjects file.
dataset = AFQDataset.from_files(
    fn_nodes=fn_nodes,
    fn_subjects=fn_subjects,
    dwi_metrics=DWI_METRICS,
    target_cols=TARGET_COLS,
)

In [None]:
# Each unsupervised id looks like "<subject>HBNsite<site>"; split every id once
# and reuse the pieces (an id without "HBNsite" raises IndexError, as before).
id_parts = [sub_ses.split("HBNsite") for sub_ses in unsupervised_dataset.subjects]
subjects = [parts[0] for parts in id_parts]
sites = [parts[1] for parts in id_parts]

# Both loads must list subjects in the same order before the sites are reused.
assert dataset.subjects == subjects
dataset.sessions = sites

In [None]:
# Phenotype table indexed by subject, with the recovered site attached as "Site".
df_y = pd.DataFrame(
    dataset.y, index=dataset.subjects, columns=dataset.target_cols
).assign(Site=dataset.sessions)
df_y

In [None]:
df_y["Site"].unique()

In [None]:
# Numeric site labels for ComBat. The mapping is explicit so labels stay stable
# across runs; any site not listed here would map to NaN, so check for that.
SITE_INDEX = {
    "RU": 0.0,
    "SI": 1.0,
    "CBIC": 2.0,
    "CUNY": 3.0,
}
df_y["site_index"] = df_y["Site"].map(SITE_INDEX)
unmapped_sites = sorted(set(df_y.loc[df_y["site_index"].isna(), "Site"]))
assert not unmapped_sites, f"Sites missing from SITE_INDEX: {unmapped_sites}"

# Fill missing node values with each feature column's median before harmonizing.
imputer = SimpleImputer(strategy="median")
X_imputed = imputer.fit_transform(dataset.X)

# Remove site effects only (no covariates preserved).
X_site_harmonized = CombatModel().fit_transform(
    X_imputed,
    df_y[["site_index"]],
    None,
    None,
)

# Remove site effects while preserving Age and Sex effects. Sex is categorical,
# so it goes in the discrete-covariate slot; Age stays continuous.
X_site_and_pheno_harmonized = CombatModel().fit_transform(
    X_imputed,
    df_y[["site_index"]],
    df_y[["Sex"]],
    df_y[["Age"]],
)

In [None]:
sns.histplot(df_y, x="Barratt_Total")

In [None]:
# Binarize impulsivity at the cohort median: 1 = strictly above the median,
# 0 = at or below it. A missing score compares False against the median and
# would silently land in class 0, so require complete scores first.
missing_barratt = df_y["Barratt_Total"].isna()
assert not missing_barratt.any(), (
    f"{int(missing_barratt.sum())} subjects lack Barratt_Total; "
    "drop or impute them before binarizing"
)
df_y["Barratt_class"] = (df_y["Barratt_Total"] > df_y["Barratt_Total"].median()).astype(int)

In [None]:
# Pair high/low Barratt subjects by Mahalanobis distance on Age and Sex.
matched = mahalonobis_dist_match(
    df_y,
    status_col="Barratt_class",
    feature_cols=["Age", "Sex"],
)

In [None]:
sns.pairplot(matched, hue="Barratt_class", vars=["Barratt_Total", "Age", "Sex"])

In [None]:
# Show the cohort returned by Mahalanobis-distance matching on Age and Sex.
matched

In [None]:
# Node-feature matrices at each processing stage, wrapped as subject-indexed frames.
node_matrices = {
    "imputed": X_imputed,
    "site_harmonized": X_site_harmonized,
    "site_pheno_harmonized": X_site_and_pheno_harmonized,
}
df_nodes = {
    stage: pd.DataFrame(data=matrix, index=dataset.subjects)
    for stage, matrix in node_matrices.items()
}

In [None]:
# Restrict each stage to the matched subjects. A left merge onto an empty frame
# indexed by matched.index keeps the matched row order; ids absent from a stage
# come back as all-NaN rows.
df_nodes_matched = {}
for stage, stage_df in df_nodes.items():
    matched_skeleton = pd.DataFrame(index=matched.index)
    df_nodes_matched[stage] = matched_skeleton.merge(
        stage_df, how="left", left_index=True, right_index=True
    )

In [None]:
# Tract profiles for the matched cohort after site + Age/Sex harmonization.
# Rows follow matched.index order, so they line up with matched["Barratt_Total"].
# NOTE(review): `quantiles=4` presumably bins Barratt_Total into quartile groups;
# confirm against the afqinsight plot_tract_profiles docs. The variable name says
# "sites" but the grouping is by Barratt score.
X_matched_harmonized = df_nodes_matched["site_pheno_harmonized"].to_numpy()
nles_sites = plot_tract_profiles(
    X=X_matched_harmonized,
    groups=dataset.groups,
    group_names=dataset.group_names,
    group_by=matched["Barratt_Total"],
    group_by_name="Barratt",
    quantiles=4,
    figsize=(14, 14),
    palette="colorblind",
)