In [None]:
import afqinsight as afqi
import afqinsight.nn.tf_models as nn
import matplotlib.pyplot as plt
import numpy as np
import os.path as op
import pandas as pd
import scipy.stats as stats
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
import tensorflow as tf

from afqinsight import AFQDataset
from afqinsight.plot import plot_tract_profiles
from afqinsight.match import mahalonobis_dist_match
from matplotlib.patches import Rectangle
from neurocombat_sklearn import CombatModel
from sklearn.impute import SimpleImputer, KNNImputer
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegressionCV
from sklearn.metrics import RocCurveDisplay

In [None]:
ses_dir = op.join("..", "ses")

df_ses = pd.read_csv(op.join(ses_dir, "pheno_merged.csv"))
df_ses.rename(columns={
    "FSQ_08": "family_size",
    "FSQ_04": "income_category",
}, inplace=True)

income_map = {
    idx: (idx + 1) * 10000 - 5000 for idx in range(10)
}
income_map[10] = 125000
income_map[11] = 175000
income_map[12] = np.nan

df_ses["income"] = df_ses["income_category"].map(income_map)
df_ses["assistance"] = df_ses.filter(like="FSQ_05").sum(axis="columns")
df_ses.drop([col for col in df_ses.columns if "FSQ" in col], axis="columns", inplace=True)

In [None]:
def get_poverty_limit_df(year):
    """Download the Census historical poverty-threshold table for one year.

    Parameters
    ----------
    year : str
        Four-digit year; its last two characters select ``thresh{YY}.xls``.

    Returns
    -------
    pandas.DataFrame
        A single ``limit`` column indexed by ``family_size`` = 1..9.
    """
    url = (
        "https://www2.census.gov/programs-surveys/cps/tables/time-series/"
        f"historical-poverty-thresholds/thresh{year[-2:]}.xls"
    )
    # Positions (after header=1) of the nine rows used; they are assigned to
    # family sizes 1..9 in this order.
    # NOTE(review): assumes the sheet layout is the same for every year.
    row_positions = [8, 12, *range(15, 22)]
    sheet = pd.read_excel(url, header=1)
    thresholds = sheet.iloc[row_positions, 0:2]
    thresholds.columns = ["family_size", "limit"]
    thresholds["family_size"] = list(range(1, 10))
    return thresholds.set_index("family_size")

# Poverty-threshold tables keyed by enrollment year (as a string), 2016-2020.
years = list(map(str, range(2016, 2021)))
poverty = {}
for year in years:
    poverty[year] = get_poverty_limit_df(year)

In [None]:
def get_inr(row, thresholds=None):
    """Income-to-needs ratio (income / poverty threshold) for one participant.

    Parameters
    ----------
    row : pandas.Series
        Must provide ``income``, ``family_size`` and ``Enroll_Year``.
    thresholds : dict of str -> pandas.DataFrame, optional
        Year-keyed tables with a ``limit`` column indexed by family size.
        Defaults to the module-level ``poverty`` mapping.

    Returns
    -------
    float
        ``income / limit``, or NaN when income, family size or year is missing.

    Raises
    ------
    KeyError
        If the enrollment year has no table, or the family size is below the
        smallest size in the table.
    """
    if thresholds is None:
        thresholds = poverty

    income = row.income
    family_size = row.family_size
    year = row.Enroll_Year
    # Missing questionnaire answers would otherwise raise KeyError in the
    # lookup below and abort the whole DataFrame.apply.
    if pd.isna(income) or pd.isna(family_size) or pd.isna(year):
        return np.nan

    # Columns containing NaN are float, so 2017.0 must become "2017".
    table = thresholds[str(int(year))]
    # The last row of the table covers every larger household ("or more").
    size = int(min(family_size, table.index.max()))
    needs = table.loc[size, "limit"]
    return income / needs

In [None]:
df_ses["inr"] = df_ses.apply(get_inr, axis="columns")
df_ses.drop(["Enroll_Year", "income_category", "family_size", "income"], axis="columns", inplace=True)
df_ses

In [None]:
df_ses.to_csv(op.join(ses_dir, "inr.csv"))

In [None]:
sns.pairplot(df_ses, kind="kde")

In [None]:
fn_participants_tsv = "s3://fcp-indi/data/Projects/HBN/BIDS_curated/derivatives/qsiprep/participants.tsv"
df_participants = pd.read_csv(fn_participants_tsv, sep="\t", usecols=("subject_id", "dl_qc_score"))
df_participants

In [None]:
def get_dataset(assessment, target_cols, matching_target, norm_high=True, qc_cutoff=0.0, median_split=False):
    """Build a QC-filtered, site-harmonized and covariate-matched cohort.

    Reads ``../{assessment}/combined_tract_profiles_merged.csv`` (DKI FA and MD
    tract profiles) and ``../{assessment}/pheno_merged.csv`` (phenotypes).
    Joins the module-level ``df_participants`` QC table and harmonizes the node
    features across sites with ComBat. Splits participants into low/high groups
    on the quantiles of ``matching_target``. Matches the groups on Age and Sex,
    plus Barratt_Total unless ``assessment == "barratt"``. Also draws a
    histogram of ``matching_target`` and a pair plot of the matched cohort.

    Parameters
    ----------
    assessment : str
        Directory name under ``..`` holding the two input CSV files.
    target_cols : list of str
        Phenotype columns loaded in addition to Barratt_Total, Age and Sex.
    matching_target : str
        Column whose quantiles define the groups (status 0 = low, 1 = high).
    norm_high : bool, default=True
        Only used without ``median_split``. True cuts at the 33rd and 50th
        percentiles; False cuts at the 50th and 66th. Participants between the
        two cuts are dropped.
    qc_cutoff : float, default=0.0
        Keep participants whose ``dl_qc_score`` is strictly greater than this.
        Participants without a score are dropped.
    median_split : bool, default=False
        Split at the median and keep every participant.

    Returns
    -------
    dict
        ``X_matched``/``y_matched``: harmonized features and phenotypes of the
        matched cohort. ``X_unmatched``/``y_unmatched``: grouped participants
        that the matcher did not select. Rows of each X/y pair refer to the
        same participants in the same order.
    """
    print(assessment, matching_target)
    workdir = f"../{assessment}"
    fn_nodes = op.join(workdir, "combined_tract_profiles_merged.csv")
    fn_subjects = op.join(workdir, "pheno_merged.csv")

    # Loaded without phenotypes only to recover the concatenated
    # "<subject>HBNsite<site>" identifiers; the site is parsed from them below.
    unsupervised_dataset = AFQDataset.from_files(
        fn_nodes=fn_nodes,
        dwi_metrics=["dki_fa", "dki_md"],
        unsupervised=True,
        concat_subject_session=True,
    )

    dataset = AFQDataset.from_files(
        fn_nodes=fn_nodes,
        fn_subjects=fn_subjects,
        dwi_metrics=["dki_fa", "dki_md"],
        target_cols=["Barratt_Total", "Age", "Sex"] + target_cols,
    )

    subjects = [sub.split("HBNsite")[0] for sub in unsupervised_dataset.subjects]
    sites = [sub.split("HBNsite")[1] for sub in unsupervised_dataset.subjects]

    # Sites are attached positionally, so both loaders must agree on order.
    assert dataset.subjects == subjects
    dataset.sessions = sites

    df_y = pd.DataFrame(index=dataset.subjects, data=dataset.y, columns=dataset.target_cols)
    df_y = pd.merge(df_y, df_participants, left_index=True, right_on="subject_id", how="left")
    # A repeated subject_id in df_participants would duplicate rows here and
    # break the positional alignment with `sites` and dataset.X used below.
    assert len(df_y) == len(dataset.X), "merging df_participants changed the row count"
    df_y["Site"] = dataset.sessions

    # Filter based on QC cutoff; a missing dl_qc_score compares False.
    qc_pass_mask = (df_y["dl_qc_score"] > qc_cutoff).to_numpy()
    # .copy() so the column assignments below act on an independent frame
    # rather than on a slice (SettingWithCopyWarning).
    df_y = df_y[qc_pass_mask].copy()
    dataset.X = dataset.X[qc_pass_mask]

    df_y["site_index"] = df_y["Site"].map({
        "RU": 0.0,
        "SI": 1.0,
        "CBIC": 2.0,
        "CUNY": 3.0,
    })

    imputer = SimpleImputer(strategy="median")
    X_imputed = imputer.fit_transform(dataset.X)

    X_site_harmonized = CombatModel().fit_transform(
        X_imputed,
        df_y[["site_index"]],
        None,
        None,
    )

    df_nodes = pd.DataFrame(data=X_site_harmonized, index=df_y.index)

    # Use a fresh figure: without an explicit ax, histplot draws on the current
    # axes, which after a previous call is a panel of that call's pair plot.
    _, hist_ax = plt.subplots()
    sns.histplot(data=df_y, x=matching_target, ax=hist_ax)

    if median_split:
        quantiles = [0.0, 0.5, 1.0]
    elif norm_high:
        quantiles = [0.0, 0.33, 0.5, 1.0]
    else:
        quantiles = [0.0, 0.5, 0.66, 1.0]

    df_y["status"] = pd.qcut(df_y[matching_target], quantiles, labels=False)

    if not median_split:
        # Drop the middle bin and map bin labels {0, 2} to status {0.0, 1.0}.
        df_y = df_y[df_y["status"] != 1].copy()
        df_y["status"] /= 2

    feature_cols = ["Age", "Sex"]
    if assessment != "barratt":
        feature_cols += ["Barratt_Total"]

    matched = mahalonobis_dist_match(
        df_y, status_col="status",
        feature_cols=feature_cols,
    )

    sns.pairplot(
        data=matched,
        vars=["Barratt_Total", "Age", "Sex"] + target_cols,
        hue="status",
    )

    df_nodes_matched = pd.DataFrame(index=matched.index).merge(
        df_nodes, how="left", left_index=True, right_index=True
    )
    unmatched = df_y[~df_y.index.isin(matched.index)]
    # Select from the participants that survived grouping (unmatched comes from
    # df_y), not from all of df_nodes. df_nodes still holds the dropped middle
    # bin, so selecting from it directly would misalign X_unmatched with
    # y_unmatched.
    df_nodes_unmatched = df_nodes[df_nodes.index.isin(unmatched.index)]

    X_matched = df_nodes_matched.to_numpy()
    X_unmatched = df_nodes_unmatched.to_numpy()

    return {
        "X_matched": X_matched,
        "y_matched": matched,
        "X_unmatched": X_unmatched,
        "y_unmatched": unmatched,
    }

In [None]:
# Phenotype columns loaded from each assessment directory. Each column is also
# used, in turn, as the matching target that defines the low/high groups.
assessments = {
    "nles": ["NLES_P_TotalEvents", "NLES_P_Upset_Total"],
    "apq": ["APQ_P_CP", "APQ_SR_CP"],
    "cpic": ["CPIC_Perceived_Threat_Total"],
    "wisc": ["WISC_FSIQ"],
}

datasets = {}
for assessment, targets in assessments.items():
    # WISC uses the norm_high cut points; CPIC is split at the median.
    use_norm_high = assessment == "wisc"
    use_median_split = assessment == "cpic"
    for measure in targets:
        datasets[measure] = get_dataset(
            assessment=assessment,
            target_cols=targets,
            matching_target=measure,
            norm_high=use_norm_high,
            median_split=use_median_split,
        )

# Barratt is its own assessment and loads no extra target columns.
datasets["Barratt_Total"] = get_dataset(
    assessment="barratt",
    target_cols=[],
    matching_target="Barratt_Total",
    norm_high=True,
)

In [None]:
# Age distribution of the NLES-events groups before (gray, "Original") and after
# (colored, "Matched") matching, drawn as a back-to-back bar chart. Low-group
# counts are negated so they extend left of zero; high-group counts go right.
fig, ax = plt.subplots(1, 1, figsize=(10, 5))

nles_data = datasets["NLES_P_TotalEvents"]
df_unmatched = pd.concat([nles_data["y_unmatched"], nles_data["y_matched"]])
df_matched = nles_data["y_matched"].copy()

df_matched["age_bin"] = df_matched["Age"].round().astype(int)
df_unmatched["age_bin"] = df_unmatched["Age"].round().astype(int)

df_count_matched = df_matched.groupby(["age_bin", "status"])["NLES_P_TotalEvents"].count().reset_index()
df_count_unmatched = df_unmatched.groupby(["age_bin", "status"])["NLES_P_TotalEvents"].count().reset_index()

columns = ["Age", "status", "count"]
df_count_matched.columns = columns
df_count_unmatched.columns = columns

df_count_matched.loc[df_count_matched["status"] == 0.0, "count"] *= -1
df_count_unmatched.loc[df_count_unmatched["status"] == 0.0, "count"] *= -1

df_count_matched["NLES"] = df_count_matched["status"].map({0.0: "Matched Low", 1.0: "Matched High"})
df_count_unmatched["NLES"] = df_count_unmatched["status"].map({0.0: "Original Low", 1.0: "Original High"})

df_count = pd.concat([df_count_matched, df_count_unmatched])

# Each seaborn categorical call places its own categories at positions 0..n-1,
# so every layer must share one explicit order, or bars for the same age land
# on different rows. The original data (which includes the matched rows)
# contains every age that appears in any layer.
age_order = sorted(df_count_unmatched["Age"].unique())

for status in [0.0, 1.0]:
    sns.barplot(
        x="count", y="Age",
        data=df_count_unmatched[df_count_unmatched["status"] == status],
        order=age_order,
        orient="horizontal",
        dodge=False,
        ax=ax,
        lw=0,
        color="Gray",
        alpha=0.7,
    )

sns.barplot(
    x="count", y="Age",
    data=df_count_matched,
    hue="NLES",
    hue_order=["Matched Low", "Matched High"],
    order=age_order,
    orient="horizontal",
    dodge=False,
    ax=ax,
    lw=0,
)

ax.invert_yaxis()
ax.set_xlabel("Count", fontsize=18)
ax.tick_params(axis="both", which="major", labelsize=14)
ax.set_ylabel("Age", fontsize=18)

# Prepend a gray proxy patch so the legend also describes the "Original" bars.
# Distinct names avoid clobbering the notebook-level `labels` dict.
legend_handles, legend_labels = ax.get_legend_handles_labels()
legend_labels = ["Original data"] + legend_labels
legend_handles = [Rectangle((0, 0), 1, 1, color="Gray", alpha=0.7)] + legend_handles
ax.legend(legend_handles, legend_labels, fontsize=16, title="NLES", title_fontsize=18)

# Low-group counts were negated for plotting, so label the ticks by magnitude.
# A formatter follows the ticks actually drawn, unlike set_xticklabels on
# ticks that were never fixed.
ax.xaxis.set_major_formatter(lambda value, _pos: str(abs(int(value))))
ax.set_title("Original and Matched Cohorts", fontsize=20)

fig.savefig("matching.pdf", bbox_inches="tight")

In [None]:
with plt.xkcd():
    # Two hand-drawn-style schematics: a CST tract-profile comparison (fig0)
    # and a cartoon of the brain-age-gap analysis (fig2).
    fig0, ax0 = plt.subplots(1, 1, figsize=(8, 5))
    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 5))

    # Identity line, dashed offsets for two example points, three hollow markers.
    ax2.plot([0, 1], [0, 1], ls="-", marker="", color="black")
    ax2.vlines([0.2, 0.8], [0.55, 0.45], [0.2, 0.8], color="black", ls="--")
    for x_pt, y_pt in ((0.2, 0.55), (0.5, 0.50), (0.8, 0.45)):
        ax2.plot([x_pt], [y_pt], marker="o", mec="black", mfc="white", ms=10, mew=3)
    label_style = dict(transform=ax2.transAxes, ha="center", fontsize=18)
    ax2.text(s="Positive BAG", x=0.2, y=0.6, va="bottom", **label_style)
    ax2.text(s="Negative BAG", x=0.8, y=0.4, va="top", **label_style)

    # Schematic look: no ticks or tick labels, no top/right spines.
    no_ticks = dict(
        bottom=False,
        top=False,
        left=False,
        right=False,
        labelbottom=False,
        labelleft=False,
    )
    for ax in [ax0, ax2]:
        ax.tick_params(axis="both", which="both", **no_ticks)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    ax2.set_xlabel("Chronological Age", fontsize=18)
    ax2.set_ylabel("Brain Age", fontsize=18)
    ax0.set_xlabel("Tract Profile", fontsize=18)

    # FA along left and right CST, aggregated across rows by seaborn's default estimator.
    df_bundles = pd.read_csv("../nles/combined_tract_profiles_merged.csv")
    df_plot_bundles = df_bundles.loc[df_bundles["tractID"].isin(["CST_L", "CST_R"])]
    sns.lineplot(data=df_plot_bundles, x="nodeID", y="dki_fa", hue="tractID", ax=ax0)
    ax0.legend().set_visible(False)

    ax0.set_title("Potential Group Difference", fontsize=18)
    ax2.set_title("Brain Age Gap (BAG) Analysis", fontsize=18)

    fig0.savefig("tract_profile_diff.pdf", bbox_inches="tight")
    fig2.savefig("bag_analysis_cartoon.pdf", bbox_inches="tight")

In [None]:
# Rows outside vs. inside the matched cohort, per measure.
for measure, dset in datasets.items():
    n_unmatched = len(dset["X_unmatched"])
    n_matched = len(dset["X_matched"])
    print(f"{measure} {n_unmatched} {n_matched}")

In [None]:
datasets["Barratt_Total"]["y_matched"]

In [None]:
def train_classification_model(dataset, standard_scale=True, random_state=0):
    """Fit a PCA + L1 logistic-regression pipeline on a matched cohort.

    The pipeline uses median imputation, optional standard scaling, PCA and
    ``LogisticRegressionCV`` (saga, L1).

    Parameters
    ----------
    dataset : dict
        One entry of ``datasets`` (output of ``get_dataset``). Uses
        ``X_matched`` as features and ``y_matched["status"]`` as labels.
    standard_scale : bool, default=True
        Standard-scale features. If False, no scaler step is added.
    random_state : int, default=0
        Seed for the stratified train/test split and the saga solver.

    Returns
    -------
    dict
        ``model``: the fitted pipeline.
        ``y_pred``, ``y_prob``: test-set predictions and the probability of
        ``model.classes_[1]``.
        ``y_test``, ``X_test``: the held-out labels and features.
        ``df``: per-sample predictions for both splits, tagged by ``split``.
    """
    model = afqi.pipeline.make_base_afq_pipeline(
        imputer_kwargs={"strategy": "median"},
        feature_transformer=PCA,
        scaler="standard" if standard_scale else None,
        estimator=LogisticRegressionCV,
        estimator_kwargs={
            "verbose": 0,
            "Cs": 50,
            "penalty": "l1",
            "cv": 3,
            "n_jobs": 8,
            "solver": "saga",
            "max_iter": 5000,
            # saga shuffles the data; seed it so reruns are reproducible.
            "random_state": random_state,
        },
        verbose=0,
    )

    labels = dataset["y_matched"]["status"].to_numpy().astype(np.float64)
    X_train, X_test, y_train, y_test = train_test_split(
        dataset["X_matched"],
        labels,
        random_state=random_state,
        stratify=labels,
    )

    model.fit(X_train, y_train)

    # Predict once per split and reuse the arrays for the summary frame and
    # the returned test-set results.
    y_pred_test = model.predict(X_test).flatten()
    y_prob_test = model.predict_proba(X_test)[:, 1]
    y_pred_train = model.predict(X_train).flatten()
    y_prob_train = model.predict_proba(X_train)[:, 1]

    df = pd.concat([
        pd.DataFrame(dict(y_true=y_test.flatten(),
                          y_pred=y_pred_test,
                          y_prob=y_prob_test,
                          split="test")),
        pd.DataFrame(dict(y_true=y_train.flatten(),
                          y_pred=y_pred_train,
                          y_prob=y_prob_train,
                          split="train")),
    ])

    return dict(
        model=model,
        y_pred=y_pred_test,
        y_prob=y_prob_test,
        y_test=y_test.flatten(),
        X_test=X_test,
        df=df,
    )

In [None]:
# Fit one classifier per measure.
(
    cpic_results,
    apq_p_results,
    apq_sr_results,
    nles_tot_results,
    nles_upset_results,
    barratt_results,
    wisc_results,
) = [
    train_classification_model(datasets[measure])
    for measure in (
        "CPIC_Perceived_Threat_Total",
        "APQ_P_CP",
        "APQ_SR_CP",
        "NLES_P_TotalEvents",
        "NLES_P_Upset_Total",
        "Barratt_Total",
        "WISC_FSIQ",
    )
]

In [None]:
def report(covariate, results):
    """Print test-set accuracy, ROC AUC and class order for one fitted model.

    Parameters
    ----------
    covariate : str
        Measure name, printed as an underlined heading.
    results : dict
        Output of ``train_classification_model``. Uses ``y_test``, ``y_pred``,
        ``y_prob`` (test-set probability of ``model.classes_[1]``) and ``model``.
    """
    y_true = results["y_test"].flatten()
    y_pred = results["y_pred"].flatten()
    y_prob = results["y_prob"].flatten()

    accuracy = accuracy_score(y_true, y_pred)
    # y_prob and y_pred both come from X_test, so they already align.
    roc_auc = roc_auc_score(y_true, y_prob)

    print(covariate)
    print("=" * len(covariate))
    print(f"Accuracy: {accuracy}")
    print(f"ROC AUC: {roc_auc}")
    print(f"model.classes_: {results['model'].classes_}")
    print()

In [None]:
# (measure key, fitted results, display label) for each model in the report
# and the ROC figure. WISC_FSIQ (wisc_results) is not part of this comparison.
_model_table = [
    ("CPIC_Perceived_Threat_Total", cpic_results, "CPIC-threat"),
    ("APQ_P_CP", apq_p_results, "APQ-CP (parent)"),
    ("APQ_SR_CP", apq_sr_results, "APQ-CP (child)"),
    ("NLES_P_TotalEvents", nles_tot_results, "NLES-events"),
    ("NLES_P_Upset_Total", nles_upset_results, "NLES-upset"),
    ("Barratt_Total", barratt_results, "BSMSS"),
]

results = {key: res for key, res, _ in _model_table}
labels = {key: label for key, _, label in _model_table}

In [None]:
# Print the metrics block for every model in the comparison.
for key in results:
    res = results[key]
    report(key, res)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(9, 5))

# Reference diagonal for a classifier at chance.
ax.plot([0, 1], [0, 1], ls="--", lw=3, color="black", label="Chance")

# Measures whose curve is drawn with classes_[0] as the positive class.
# NOTE(review): swapping the positive class together with its probability
# column should leave the AUC in the legend unchanged; confirm the flip is intended.
first_class_positive = ("Barratt_Total", "WISC_FSIQ")
for key, res in results.items():
    classes = res["model"].classes_
    pos_label = classes[0] if key in first_class_positive else classes[1]
    RocCurveDisplay.from_estimator(
        res["model"], res["X_test"], res["y_test"],
        ax=ax, pos_label=pos_label, name=labels[key],
    )

ax.set_xlabel("False Positive Rate", fontsize=18)
ax.set_ylabel("True Positive Rate", fontsize=18)
ax.set_title("ROC Curve", fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=14)
ax.legend(fontsize=14, loc="lower right", bbox_to_anchor=(1.14, 0.0))
fig.savefig("roc_curve_pcr_lasso.pdf", bbox_inches="tight")

In [None]:
from scipy import stats

In [None]:
df_y_matched = datasets["NLES_P_TotalEvents"]["y_matched"].copy()
nodes_matched = datasets["NLES_P_TotalEvents"]["X_matched"].copy()

high_mask = df_y_matched["status"].astype(bool).to_numpy()
low_mask = ~high_mask

nodes_high = nodes_matched[high_mask]
nodes_low = nodes_matched[low_mask]

node_tests = []
for col_num in range(nodes_matched.shape[1]):
    levene = stats.levene(nodes_high[:, col_num], nodes_low[:, col_num])
    shapiro_low = stats.shapiro(nodes_low[:, col_num])
    shapiro_high = stats.shapiro(nodes_high[:, col_num])
    ttest = stats.ttest_ind(nodes_low[:, col_num], nodes_high[:, col_num])
    wilcoxon = stats.wilcoxon(nodes_low[:, col_num], nodes_high[:, col_num])
    
    node_tests.append({
        "node": col_num,
        "levene_statistic": levene.statistic,
        "levene_pvalue": levene.pvalue,
        "low_shapiro_statistic": shapiro_low.statistic,
        "low_shapiro_pvalue": shapiro_low.pvalue,
        "high_shapiro_statistic": shapiro_high.statistic,
        "high_shapiro_pvalue": shapiro_high.pvalue,
        "ttest_statistic": ttest.statistic,
        "ttest_pvalue": ttest.pvalue,
        "wilcoxon_statistic": wilcoxon.statistic,
        "wilcoxon_pvalue": wilcoxon.pvalue,        
    })

In [None]:
# One row per node with every test statistic and p-value.
df_tests = pd.DataFrame.from_records(node_tests)

In [None]:
len(df_tests[df_tests["ttest_pvalue"] < 0.05 / df_tests.shape[0]])

In [None]:
len(df_tests[df_tests["wilcoxon_pvalue"] < 0.05 / df_tests.shape[0]])

In [None]:
len(df_tests[df_tests["ttest_pvalue"] < 0.05]) / len(df_tests)

In [None]:
len(df_tests[df_tests["wilcoxon_pvalue"] < 0.05]) / len(df_tests)

In [None]:
from statsmodels.stats.multitest import multipletests

In [None]:
fdr = multipletests(df_tests["ttest_pvalue"], method="fdr_bh")

In [None]:
from statsmodels.multivariate.manova import MANOVA

# One MANOVA per bundle: the 100 node values of the bundle are the dependent
# variables and the group indicator "ela" (0 = low, 1 = high) is the factor.
n_nodes = 100  # consecutive feature columns per bundle
node_names = [f"node{idx}" for idx in range(n_nodes)]
formula = " + ".join(node_names) + " ~ ela"

maov_dicts = []
for bundle_idx in range(nodes_matched.shape[1] // n_nodes):
    bundle_cols = slice(bundle_idx * n_nodes, (bundle_idx + 1) * n_nodes)
    df_bundle = pd.concat([
        pd.DataFrame(columns=node_names, data=nodes_low[:, bundle_cols]).assign(ela=0),
        pd.DataFrame(columns=node_names, data=nodes_high[:, bundle_cols]).assign(ela=1),
    ])
    df_bundle["ela"] = df_bundle["ela"].astype(int)

    maov = MANOVA.from_formula(formula, data=df_bundle)
    # p-values ("Pr > F") of each multivariate test statistic for the ela term.
    maov_test_dict = maov.mv_test().results["ela"]["stat"]["Pr > F"].to_dict()
    maov_test_dict["bundle_idx"] = bundle_idx
    maov_dicts.append(maov_test_dict)

In [None]:
# One row per bundle with the MANOVA p-values and the bundle index.
df_manova = pd.DataFrame.from_records(maov_dicts)

In [None]:
# For each of the first four columns (the MANOVA test p-values), print how
# many bundles have p < 0.05.
for test_type in df_manova.columns[:4]:
    n_significant = int(df_manova[test_type].lt(0.05).sum())
    print(n_significant)