In [None]:
import afqinsight as afqi
import afqinsight.nn.tf_models as nn
import matplotlib.pyplot as plt
import numpy as np
import os.path as op
import pandas as pd
import scipy.stats as stats
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
import tensorflow as tf
import tempfile

from afqinsight import AFQDataset
from afqinsight.plot import plot_tract_profiles
from afqinsight.match import mahalonobis_dist_match
from neurocombat_sklearn import CombatModel
from sklearn.impute import SimpleImputer, KNNImputer
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

In [None]:
# Read the participants table from the path below, keeping only the subject
# identifier and the QC score column.
fn_participants_tsv = (
    "s3://fcp-indi/data/Projects/HBN/BIDS_curated/derivatives/qsiprep/participants.tsv"
)
df_participants = pd.read_csv(
    fn_participants_tsv,
    sep="\t",
    usecols=("subject_id", "dl_qc_score"),
)

# Cell output: number of rows whose QC score is missing.
df_participants["dl_qc_score"].isna().sum()

In [None]:
# Display the participants table (only subject_id and dl_qc_score were loaded).
df_participants

In [None]:
def get_dataset(assessment, target_cols, matching_target, norm_high=True, qc_cutoff=0.0, median_split=False):
    """Load, QC-filter, site-harmonize and group-match tract profiles for one assessment.

    Reads ``combined_tract_profiles_merged.csv`` and ``pheno_merged.csv`` from
    ``../{assessment}``, keeps subjects whose ``dl_qc_score`` is strictly
    greater than ``qc_cutoff``, median-imputes the features, harmonizes them
    across sites with ``CombatModel``, assigns every subject a low (0) / high (1)
    ``status`` from quantiles of ``matching_target`` and matches the two groups
    with ``mahalonobis_dist_match``.

    This function reads the notebook-level ``df_participants`` frame (columns
    ``subject_id`` and ``dl_qc_score``), so that cell must have been run first.

    Side effects: prints ``assessment`` and ``matching_target``, draws a
    histogram of ``matching_target`` on a new figure and a seaborn pairplot of
    the matched subjects.

    Parameters
    ----------
    assessment : str
        Name of the sibling directory holding the merged CSV files. When it is
        not ``"barratt"``, ``Barratt_Total`` is added to the matching features.
    target_cols : list of str
        Extra phenotype columns to load in addition to ``Barratt_Total``,
        ``Age`` and ``Sex``.
    matching_target : str
        Column whose quantiles define the low / high groups.
    norm_high : bool, default True
        Only used when ``median_split`` is False. If True the quantile edges
        are 0 / 0.33 / 0.5 / 1.0, otherwise 0 / 0.5 / 0.66 / 1.0. In both cases
        the middle bin is discarded before matching.
    qc_cutoff : float, default 0.0
        Subjects need ``dl_qc_score > qc_cutoff`` (a missing score fails).
    median_split : bool, default False
        If True, split at the median and keep every subject.

    Returns
    -------
    dict
        ``X_matched`` / ``X_unmatched``: float64 arrays produced by
        ``bundles2channels`` (``n_nodes=100``, ``n_channels=48``,
        ``channels_last=True``); ``y_matched``: the frame returned by
        ``mahalonobis_dist_match``; ``y_unmatched``: the QC-passing rows whose
        ``subject_id`` is not in ``y_matched`` (this frame has no ``status``
        column because it is taken from a copy made before the split).
    """
    print(assessment, matching_target)
    workdir = f"../{assessment}"
    fn_nodes = op.join(workdir, "combined_tract_profiles_merged.csv")
    fn_subjects = op.join(workdir, "pheno_merged.csv")

    # Loaded a second time without targets only to recover the
    # "<subject>HBNsite<site>" identifiers produced by concat_subject_session.
    unsupervised_dataset = AFQDataset.from_files(
        fn_nodes=fn_nodes,
        dwi_metrics=["dki_fa", "dki_md"],
        unsupervised=True,
        concat_subject_session=True,
    )

    dataset = AFQDataset.from_files(
        fn_nodes=fn_nodes,
        fn_subjects=fn_subjects,
        dwi_metrics=["dki_fa", "dki_md"],
        target_cols=["Barratt_Total", "Age", "Sex"] + target_cols
    )

    subjects = [sub.split("HBNsite")[0] for sub in unsupervised_dataset.subjects]
    sites = [sub.split("HBNsite")[1] for sub in unsupervised_dataset.subjects]

    # Both loads must list subjects in the same order for the sites to line up.
    assert dataset.subjects == subjects
    dataset.sessions = sites

    df_y = pd.DataFrame(index=dataset.subjects, data=dataset.y, columns=dataset.target_cols)
    # Subject ids are carried in the "subject_id" column after this merge
    # (it is what the matched / unmatched split below keys on).
    df_y = pd.merge(df_y, df_participants, left_index=True, right_on="subject_id", how="left")
    df_y["Site"] = dataset.sessions

    # Filter based on QC cutoff. The mask is applied positionally to the
    # feature matrix, and the frame is copied so that the column assignments
    # below write to an independent frame rather than to a slice.
    qc_pass_mask = df_y["dl_qc_score"] > qc_cutoff
    df_y = df_y[qc_pass_mask].copy()
    dataset.X = dataset.X[qc_pass_mask.to_numpy()]

    df_y["site_index"] = df_y["Site"].map({
        "RU": 0.0,
        "SI": 1.0,
        "CBIC": 2.0,
        "CUNY": 3.0,
    })

    imputer = SimpleImputer(strategy="median")
    X_imputed = imputer.fit_transform(dataset.X)

    X_site_harmonized = CombatModel().fit_transform(
        X_imputed,
        df_y[["site_index"]],
        None,
        None,
    )

    df_nodes = pd.DataFrame(data=X_site_harmonized, index=df_y.index)

    # Draw on a dedicated figure: without an explicit Axes, seaborn would reuse
    # whatever Axes is current, e.g. a panel of the pairplot left over from a
    # previous call made in the same cell.
    _, hist_ax = plt.subplots()
    sns.histplot(data=df_y, x=matching_target, ax=hist_ax)
    hist_ax.set_title(f"{assessment}: {matching_target}")

    # Snapshot of all QC-passing subjects, taken before any are dropped.
    df_y_orig = df_y.copy()

    if median_split:
        quantiles = [0.0, 0.5, 1.0]
    else:
        if norm_high:
            quantiles = [0.0, 0.33, 0.5, 1.0]
        else:
            quantiles = [0.0, 0.5, 0.66, 1.0]

    df_y["status"] = pd.qcut(df_y[matching_target], quantiles, labels=False)

    if not median_split:
        # Three bins (0, 1, 2): drop the middle one and relabel 0 -> 0, 2 -> 1.
        df_y = df_y[df_y["status"] != 1].copy()
        df_y["status"] /= 2

    feature_cols = ["Age", "Sex"]
    if assessment != "barratt":
        feature_cols += ["Barratt_Total"]

    matched = mahalonobis_dist_match(
        df_y, status_col="status",
        feature_cols=feature_cols
    )

    sns.pairplot(
        data=matched,
        vars=["Barratt_Total", "Age", "Sex"] + target_cols,
        hue="status"
    )

    df_nodes_matched = pd.DataFrame(index=matched.index).merge(
        df_nodes, how="left", left_index=True, right_index=True
    )
    # Everything that passed QC but was not selected by the matching step,
    # including the discarded middle bin.
    is_unmatched = ~df_y_orig["subject_id"].isin(matched["subject_id"])
    df_nodes_unmatched = df_nodes[is_unmatched]
    unmatched = df_y_orig[is_unmatched]

    X_matched = afqi.datasets.bundles2channels(
        df_nodes_matched.to_numpy(),
        n_nodes=100,
        n_channels=48,
        channels_last=True,
    ).astype(np.float64)

    X_unmatched = afqi.datasets.bundles2channels(
        df_nodes_unmatched.to_numpy(),
        n_nodes=100,
        n_channels=48,
        channels_last=True,
    ).astype(np.float64)

    return {
        "X_matched": X_matched,
        "y_matched": matched,
        "X_unmatched": X_unmatched,
        "y_unmatched": unmatched,
    }

In [None]:
# Assessment directory name -> phenotype columns to analyze from it.
assessments = {
    "nles": ["NLES_P_TotalEvents", "NLES_P_Upset_Total"],
    "apq": ["APQ_P_CP", "APQ_SR_CP"],
    "cpic": ["CPIC_Perceived_Threat_Total"],
    "wisc": ["WISC_FSIQ"],
}

# One dataset per measure, keyed by the measure name. Only "wisc" uses
# norm_high=True and only "cpic" uses a median split.
datasets = dict(
    (
        measure,
        get_dataset(
            assessment=assessment,
            target_cols=targets,
            matching_target=measure,
            norm_high=(assessment == "wisc"),
            median_split=(assessment == "cpic"),
        ),
    )
    for assessment, targets in assessments.items()
    for measure in targets
)

# Barratt_Total is always loaded, so it needs no extra target columns.
datasets["Barratt_Total"] = get_dataset(
    assessment="barratt",
    target_cols=[],
    matching_target="Barratt_Total",
    norm_high=True,
)

In [None]:
# Per measure: name, number of unmatched subjects, number of matched subjects.
for measure, dset in datasets.items():
    n_unmatched, n_matched = (len(dset[key]) for key in ("X_unmatched", "X_matched"))
    print(measure, n_unmatched, n_matched)

In [None]:
# Display the matched phenotype frame for the NLES total-events measure.
datasets["NLES_P_TotalEvents"]["y_matched"]

In [None]:
def nrmse(y_true, y_pred):
    """Return the root-mean-squared error divided by the range of ``y_true``.

    Both inputs are flattened before comparison, so an ``(n, 1)`` prediction
    array can be scored directly against an ``(n,)`` target vector. The RMSE is
    computed with numpy rather than ``mean_squared_error(..., squared=False)``,
    whose ``squared`` argument is deprecated in recent scikit-learn releases.

    Parameters
    ----------
    y_true : array-like
        Observed values. Their range (max - min) is the normalizer.
    y_pred : array-like
        Predicted values, with the same number of elements as ``y_true``.

    Returns
    -------
    float
        ``sqrt(mean((y_true - y_pred) ** 2)) / (max(y_true) - min(y_true))``.
        The result is ``inf`` or ``nan`` when ``y_true`` is constant.

    Raises
    ------
    ValueError
        If the inputs differ in number of elements, or are empty.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same number of elements, "
            f"got {y_true.size} and {y_pred.size}"
        )
    y_range = np.ptp(y_true)
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    return rmse / y_range

In [None]:
# Lowest QC score among the matched Barratt subjects (sanity check of the QC filter).
datasets["Barratt_Total"]["y_matched"]["dl_qc_score"].min()

In [None]:
def _fit_and_restore_best(model, X_train, y_train, epochs):
    """Fit ``model`` in place and reload the weights of the best ``val_loss`` epoch."""
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        min_delta=0.001,
        mode="min",
        patience=100
    )
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=20,
        verbose=1
    )

    # The checkpoint lives in a private temporary directory that is removed as
    # soon as the best weights are back in the model, so no file is left behind.
    with tempfile.TemporaryDirectory() as ckpt_dir:
        # The name ends in ".h5" (HDF5 weights) and in ".weights.h5", the
        # suffix newer Keras releases require for weights-only checkpoints.
        ckpt_filepath = op.join(ckpt_dir, "best.weights.h5")
        ckpt = tf.keras.callbacks.ModelCheckpoint(
            filepath=ckpt_filepath,
            monitor="val_loss",
            verbose=0,
            save_best_only=True,
            save_weights_only=True,
            mode="auto",
        )
        model.fit(
            X_train,
            y_train,
            epochs=epochs,
            validation_split=0.2,
            callbacks=[early_stopping, reduce_lr, ckpt],
        )
        model.load_weights(ckpt_filepath)


def train_model(dataset, epochs=1000, lr=0.01, output_activation="linear", n_classes=1, normative_status=1, use_unmatched_for_training=True, random_state=None):
    """Train a ResNet age regressor and compute brain-age gaps (predicted - actual age).

    Parameters
    ----------
    dataset : dict
        Needs the keys ``X_matched``, ``y_matched``, ``X_unmatched`` and
        ``y_unmatched``; the ``y_*`` frames need an ``Age`` column.
    epochs : int, default 1000
        Upper bound on training epochs (early stopping may end sooner).
    lr : float, default 0.01
        Initial Adam learning rate.
    output_activation : str, default "linear"
        Forwarded to ``nn.cnn_resnet``.
    n_classes : int, default 1
        Forwarded to ``nn.cnn_resnet``.
    normative_status : int, default 1
        Only used when ``use_unmatched_for_training`` is False: matched
        subjects whose ``status`` equals this value are split 70/30 into
        "train" and "norm"; all other matched subjects are "test".
    use_unmatched_for_training : bool, default True
        If True, train on the unmatched subjects and test on all matched ones.
    random_state : int or None, default None
        Seed for the train/norm split. It does not seed TensorFlow.

    Returns
    -------
    dict
        ``model``: the fitted model with best-epoch weights; ``y_pred``: raw
        ``model.predict`` output for the test split; ``y_test``: test ages;
        ``y_delta``: flattened ``y_pred`` minus ``y_test``; ``df``: phenotype
        frame with added ``bag``, ``y_pred`` and ``split`` columns.
    """
    model = nn.cnn_resnet(
        input_shape=(100, 48),
        n_classes=n_classes,
        output_activation=output_activation,
    )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss="mean_squared_error",
        metrics=[
            'mean_squared_error',
            tf.keras.metrics.RootMeanSquaredError(name='rmse'),
            'mean_absolute_error'
        ],
    )

    if use_unmatched_for_training:
        X_train = dataset["X_unmatched"]
        X_test = dataset["X_matched"]
        y_train = dataset["y_unmatched"]["Age"].to_numpy().astype(np.float64)
        y_test = dataset["y_matched"]["Age"].to_numpy().astype(np.float64)

        _fit_and_restore_best(model, X_train, y_train, epochs)

        df = dataset["y_matched"].copy()
        y_pred = model.predict(X_test)
        y_delta_test = y_pred.flatten() - y_test.flatten()
        df["bag"] = y_delta_test
        df["y_pred"] = y_pred.flatten()
        df["split"] = "test"

        df_train = dataset["y_unmatched"].copy()
        y_pred_train = model.predict(X_train)
        df_train["bag"] = y_pred_train.flatten() - y_train.flatten()
        df_train["y_pred"] = y_pred_train.flatten()
        df_train["split"] = "train"

        df = pd.concat([df, df_train])
    else:
        y_matched = dataset["y_matched"]
        X_all = dataset["X_matched"]
        ages = y_matched["Age"].to_numpy().astype(np.float64)

        train_norm_mask = (y_matched["status"] == normative_status).to_numpy()
        test_mask = np.logical_not(train_norm_mask)

        train_idx, norm_idx = train_test_split(
            np.where(train_norm_mask)[0], test_size=0.3, random_state=random_state
        )
        train_mask = np.zeros_like(test_mask).astype(bool)
        train_mask[train_idx] = True

        norm_mask = np.zeros_like(test_mask).astype(bool)
        norm_mask[norm_idx] = True

        _fit_and_restore_best(model, X_all[train_mask], ages[train_mask], epochs)

        df = y_matched.copy()
        # Start as float NaN: the three masks below cover every row, and a
        # float column avoids an int -> float upcast on assignment.
        df["bag"] = np.nan
        df["y_pred"] = np.nan

        predictions = {}
        for split_name, mask in (("test", test_mask), ("train", train_mask), ("norm", norm_mask)):
            predictions[split_name] = model.predict(X_all[mask])
            flat_pred = predictions[split_name].flatten()
            df.loc[mask, "bag"] = flat_pred - ages[mask]
            df.loc[mask, "y_pred"] = flat_pred
            df.loc[mask, "split"] = split_name

        y_pred = predictions["test"]
        y_test = ages[test_mask]
        y_delta_test = y_pred.flatten() - y_test.flatten()

    return dict(
        model=model,
        y_pred=y_pred,
        y_test=y_test,
        y_delta=y_delta_test,
        df=df,
    )

In [None]:
# One brain-age model per measure. Each call stays a separate statement so
# that results already computed remain bound if a later training run fails.
# train_model only reads normative_status when use_unmatched_for_training is
# False, which is not the case for these calls (the default is True).
cpic_results = train_model(
    dataset=datasets["CPIC_Perceived_Threat_Total"], normative_status=0
)
apq_p_results = train_model(
    dataset=datasets["APQ_P_CP"], normative_status=0
)
apq_sr_results = train_model(
    dataset=datasets["APQ_SR_CP"], normative_status=0
)
nles_tot_results = train_model(
    dataset=datasets["NLES_P_TotalEvents"], normative_status=0
)
nles_upset_results = train_model(
    dataset=datasets["NLES_P_Upset_Total"], normative_status=0
)
barratt_results = train_model(
    dataset=datasets["Barratt_Total"], normative_status=1
)
wisc_results = train_model(
    dataset=datasets["WISC_FSIQ"], normative_status=1
)

In [None]:
# NOTE(review): leftover scratch expression; its value is not used by any
# other cell visible in this notebook, so the cell looks safe to delete.
min(2, 3)

In [None]:
def _plot_identity_line(ax):
    """Draw the predicted-equals-actual reference line without changing the axis limits."""
    ymin, ymax = ax.get_ylim()
    xmin, xmax = ax.get_xlim()
    ax.plot([0, 100], [0, 100], color="black", ls="--", lw=3)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)


def report(covariate, results, label, use_unmatched_as_norm=True, lmplot=True):
    """Plot predicted vs. chronological age and fit a GLM of the brain-age gap.

    Prints the test-split RMSE, range-normalized RMSE and mean gap, followed by
    the summary of ``bag ~ Scaled_Age + Sex + <covariate>``, where age and the
    covariate are z-scored.

    Inside this function ``bag`` is recomputed as ``Age - y_pred``, replacing
    the ``bag`` column stored by ``train_model`` (``y_pred - Age``).
    NOTE(review): the two signs are opposite, so coefficient signs in the
    summary follow the ``Age - y_pred`` convention; confirm this is intended.

    Parameters
    ----------
    covariate : str
        Column of ``results["df"]`` added to the GLM formula.
    results : dict
        Output of ``train_model``; ``y_test``, ``y_pred`` and ``df`` are read.
    label : str
        Legend title of the lmplot figure.
    use_unmatched_as_norm : bool, default True
        If True, color the non-training subjects by low / high ``status`` and
        fit the GLM on all rows. If False, drop the training rows and group
        by ``split`` instead.
    lmplot : bool, default True
        If True (and ``use_unmatched_as_norm`` is True) draw a seaborn
        ``lmplot``; otherwise draw a two-panel scatter / violin figure.

    Returns
    -------
    seaborn.FacetGrid or bool
        The lmplot grid when one was drawn, otherwise ``True``.
    """
    y_true = results["y_test"].flatten()
    y_pred = results["y_pred"].flatten()
    y_delta = y_true - y_pred

    df = results["df"].copy()

    # The two-panel figure is needed on every path that does not draw the
    # lmplot, including use_unmatched_as_norm=False with the default
    # lmplot=True (which previously failed because ``axes`` did not exist).
    axes = None
    if not (use_unmatched_as_norm and lmplot):
        _, axes = plt.subplots(1, 2, figsize=(14, 6))

    lm = True

    if use_unmatched_as_norm:
        df["Class"] = df["status"].map({0.0: "Low", 1.0: "High"})
        df["bag"] = df["Age"] - df["y_pred"]

        if lmplot:
            lm = sns.lmplot(x="Age",
                            y="y_pred",
                            data=df[df["split"] != "train"],
                            hue="Class",
                            legend=False,
                            scatter_kws={"s": 60, "alpha": 0.7})
            ax = lm.axes[0, 0]
            # Training subjects as a gray backdrop behind the colored groups.
            _ = sns.scatterplot(
                x="Age",
                y="y_pred",
                data=df[df["split"] == "train"],
                ax=ax,
                color="gray",
                alpha=0.5,
                zorder=-100
            )
            # Use one common range on both axes so the identity line is the
            # diagonal of a square plot.
            xmin, xmax = ax.get_xlim()
            ymin, ymax = ax.get_ylim()
            global_min = min(xmin, ymin)
            global_max = max(xmax, ymax)
            _ = ax.plot([global_min, global_max], [global_min, global_max], ls="--", lw=2, zorder=-300, color="black", label="BAG=0")
            _ = ax.set_xlim(global_min, global_max)
            _ = ax.set_ylim(global_min, global_max)
            _ = ax.set_ylabel("Brain Age", fontsize=18)
            _ = ax.set_xlabel("Chronological Age", fontsize=18)
            _ = ax.tick_params(axis='both', which='major', labelsize=14)
            _ = ax.legend(fontsize=14, title=label, title_fontsize=14)
        else:
            sns.scatterplot(x="Age", y="y_pred", data=df[df["split"] == "train"], ax=axes[0], color="gray", alpha=0.7)
            sns.scatterplot(x="Age", y="y_pred", data=df[df["split"] != "train"], hue="Class", ax=axes[0])
            _plot_identity_line(axes[0])
            # Target the right-hand panel explicitly instead of relying on it
            # being the current Axes.
            sns.violinplot(data=df, x="Class", y="bag", ax=axes[1])
    else:
        # Copy so the bag column is written to an independent frame.
        df = df[df["split"] != "train"].copy()

        sns.scatterplot(x="Age", y="y_pred", data=df, hue="split", ax=axes[0])
        _plot_identity_line(axes[0])

        df["bag"] = df["Age"] - df["y_pred"]
        sns.violinplot(data=df, x="split", y="bag", ax=axes[1])

    scaler = StandardScaler()
    df["Scaled_Age"] = scaler.fit_transform(df["Age"].to_numpy()[:, np.newaxis]).flatten()
    df[covariate] = scaler.fit_transform(df[covariate].to_numpy()[:, np.newaxis]).flatten()

    formula = f"bag ~ Scaled_Age + Sex + {covariate}"
    mod = sm.GLM.from_formula(formula, df).fit()

    # The square root is taken here because the ``squared`` argument of
    # mean_squared_error is deprecated in recent scikit-learn releases.
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))

    print(covariate)
    print("=" * len(covariate))
    print(f"Subject: RMSE={rmse}")
    print(f"Subject: nRMSE={nrmse(y_true, y_pred)}")
    print(f"Subject: mean(bag)={np.mean(y_delta)}")
    print()
    print(mod.summary())
    print()
    print()

    return lm

In [None]:
# Training output per measure, keyed by the phenotype column name.
results = dict(
    CPIC_Perceived_Threat_Total=cpic_results,
    APQ_P_CP=apq_p_results,
    APQ_SR_CP=apq_sr_results,
    NLES_P_TotalEvents=nles_tot_results,
    NLES_P_Upset_Total=nles_upset_results,
    Barratt_Total=barratt_results,
    WISC_FSIQ=wisc_results,
)

# Display name per measure (legend title and output file name).
labels = dict(
    CPIC_Perceived_Threat_Total="CPIC-threat",
    APQ_P_CP="APQ-CP (parent)",
    APQ_SR_CP="APQ-CP (child)",
    NLES_P_TotalEvents="NLES-events",
    NLES_P_Upset_Total="NLES-upset",
    Barratt_Total="BSMSS",
    WISC_FSIQ="FSIQ",
)

In [None]:
# Coefficient of determination of predicted vs. actual age on each test split.
for key in results:
    res = results[key]
    score = r2_score(y_true=res["y_test"], y_pred=res["y_pred"])
    print(key, score)

In [None]:
# Produce the figure and GLM summary for every measure, and keep the returned
# plot object alongside the training output under the "lm" key.
for key in results:
    res = results[key]
    res["lm"] = report(covariate=key, results=res, label=labels[key])

In [None]:
# Write one PDF per measure. The grid's own savefig is used instead of the
# deprecated ``.fig`` attribute; it forwards its arguments to the figure.
for key, res in results.items():
    res["lm"].savefig(f"bag-analysis-{labels[key]}.pdf", bbox_inches="tight")