In [None]:
import os.path as op
import tempfile

import afqinsight as afqi
import afqinsight.nn.tf_models as nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
import tensorflow as tf

from afqinsight import AFQDataset
from afqinsight.match import mahalonobis_dist_match
from afqinsight.plot import plot_tract_profiles
from neurocombat_sklearn import CombatModel
from sklearn.impute import SimpleImputer, KNNImputer
from sklearn.metrics import accuracy_score, mean_squared_error, r2_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [None]:
# Per-subject deep-learning QC scores shipped with the curated HBN qsiprep derivatives.
fn_participants_tsv = (
    "s3://fcp-indi/data/Projects/HBN/BIDS_curated/"
    "derivatives/qsiprep/participants.tsv"
)
df_participants = pd.read_csv(
    fn_participants_tsv,
    sep="\t",
    usecols=["subject_id", "dl_qc_score"],
)
df_participants

In [None]:
# Number of participants whose QC score is exactly zero.
df_participants["dl_qc_score"].eq(0.0).sum()

In [None]:
def get_dataset(assessment, target_cols, matching_target, norm_high=True, qc_cutoff=0.0, median_split=False):
    """Load, QC-filter, site-harmonize and group-match tract profiles for one measure.

    Parameters
    ----------
    assessment : str
        Name of the directory ``../{assessment}`` containing
        ``combined_tract_profiles_merged.csv`` and ``pheno_merged.csv``.
    target_cols : list of str
        Phenotype columns loaded in addition to ``Barratt_Total``, ``Age`` and ``Sex``.
    matching_target : str
        Column whose quantiles define the two groups stored in ``status``.
    norm_high : bool
        Ignored when ``median_split`` is True. If True the groups are the bottom
        33% (status 0.0) vs. the top 50% (status 1.0); otherwise the bottom 50%
        (status 0.0) vs. the top 34% (status 1.0). The middle band is dropped.
    qc_cutoff : float
        Subjects are kept only if ``dl_qc_score`` is strictly greater than this.
        Subjects without a score (absent from ``df_participants``) are dropped too.
    median_split : bool
        If True, split at the median into status 0 / 1 and keep everyone.

    Returns
    -------
    dict
        ``X_matched`` / ``X_unmatched``: float64 arrays from
        ``afqi.datasets.bundles2channels(..., n_nodes=100, n_channels=48,
        channels_last=True)`` for the matched subjects and for all remaining
        QC-passing subjects; ``y_matched``: the frame returned by
        ``mahalonobis_dist_match`` (includes the ``status`` column).

    Notes
    -----
    Imputation and ComBat are fit on all QC-passing subjects, before any
    downstream train/test split. TODO(review): consider fitting them on the
    training split only to avoid leakage.
    """
    print(assessment, matching_target)
    workdir = f"../{assessment}"
    fn_nodes = op.join(workdir, "combined_tract_profiles_merged.csv")
    fn_subjects = op.join(workdir, "pheno_merged.csv")

    # Loaded only to recover the concatenated "<subject>HBNsite<site>" identifiers.
    unsupervised_dataset = AFQDataset.from_files(
        fn_nodes=fn_nodes,
        dwi_metrics=["dki_fa", "dki_md"],
        unsupervised=True,
        concat_subject_session=True,
    )

    dataset = AFQDataset.from_files(
        fn_nodes=fn_nodes,
        fn_subjects=fn_subjects,
        dwi_metrics=["dki_fa", "dki_md"],
        target_cols=["Barratt_Total", "Age", "Sex"] + target_cols,
    )

    subjects = [sub.split("HBNsite")[0] for sub in unsupervised_dataset.subjects]
    sites = [sub.split("HBNsite")[1] for sub in unsupervised_dataset.subjects]

    # Site labels are attached positionally, so both loads must list subjects in the same order.
    assert dataset.subjects == subjects
    dataset.sessions = sites

    df_y = pd.DataFrame(index=dataset.subjects, data=dataset.y, columns=dataset.target_cols)
    # NOTE(review): merging left_index with right_on gives df_y the *right* frame's index
    # (NaN for unmatched subjects), not the subject IDs; the IDs live in "subject_id".
    df_y = pd.merge(df_y, df_participants, left_index=True, right_on="subject_id", how="left")
    df_y["Site"] = dataset.sessions

    # Filter on the QC cutoff. NaN scores compare False and are dropped.
    # A plain boolean array applies the mask by position to both df_y and X,
    # independent of df_y's (possibly NaN-containing) index.
    qc_pass_mask = (df_y["dl_qc_score"] > qc_cutoff).to_numpy()
    # .copy() so the column assignments below modify an independent frame, not a view.
    df_y = df_y[qc_pass_mask].copy()
    X_qc = dataset.X[qc_pass_mask]

    # Numeric batch labels for ComBat. NOTE(review): any site not listed here maps to NaN.
    df_y["site_index"] = df_y["Site"].map({
        "RU": 0.0,
        "SI": 1.0,
        "CBIC": 2.0,
        "CUNY": 3.0,
    })

    imputer = SimpleImputer(strategy="median")
    X_imputed = imputer.fit_transform(X_qc)

    # The two None arguments presumably mean "no discrete / continuous covariates";
    # confirm against neurocombat_sklearn.CombatModel.fit_transform.
    X_site_harmonized = CombatModel().fit_transform(
        X_imputed,
        df_y[["site_index"]],
        None,
        None,
    )

    df_nodes = pd.DataFrame(data=X_site_harmonized, index=df_y.index)

    # Draw on a fresh figure. Without an explicit ax, seaborn plots onto the current
    # axes, which on repeated calls is the last panel of the previous call's pairplot.
    _, hist_ax = plt.subplots()
    sns.histplot(data=df_y, x=matching_target, ax=hist_ax)

    if median_split:
        quantiles = [0.0, 0.5, 1.0]
    elif norm_high:
        quantiles = [0.0, 0.33, 0.5, 1.0]
    else:
        quantiles = [0.0, 0.5, 0.66, 1.0]

    df_y["status"] = pd.qcut(df_y[matching_target], quantiles, labels=False)

    if not median_split:
        # Drop the middle quantile band and relabel bins {0, 2} as {0.0, 1.0}.
        df_y = df_y[df_y["status"] != 1].copy()
        df_y["status"] /= 2

    feature_cols = ["Age", "Sex"]
    if assessment != "barratt":
        # Also balance groups on Barratt_Total unless it is the measure being studied.
        feature_cols += ["Barratt_Total"]

    # Presumably returns the subset of df_y rows forming matched status-0 / status-1 groups.
    matched = mahalonobis_dist_match(
        df_y, status_col="status",
        feature_cols=feature_cols
    )

    sns.pairplot(
        data=matched,
        vars=["Barratt_Total", "Age", "Sex"] + target_cols,
        hue="status"
    )

    # Node features in the same row order as `matched`, plus everyone left over.
    df_nodes_matched = pd.DataFrame(index=matched.index).merge(
        df_nodes, how="left", left_index=True, right_index=True
    )
    df_nodes_unmatched = df_nodes[~df_nodes.index.isin(matched.index)]

    # Reshape flat node columns into (subject, node, channel) arrays.
    # NOTE(review): 48 channels presumably = bundles x 2 DKI metrics; confirm.
    X_matched, X_unmatched = (
        afqi.datasets.bundles2channels(
            frame.to_numpy(),
            n_nodes=100,
            n_channels=48,
            channels_last=True,
        ).astype(np.float64)
        for frame in (df_nodes_matched, df_nodes_unmatched)
    )

    return {
        "X_matched": X_matched,
        "y_matched": matched,
        "X_unmatched": X_unmatched,
    }

In [None]:
# Target columns to study, grouped by the assessment directory they are stored in.
assessments = {
    "nles": ["NLES_P_TotalEvents", "NLES_P_Upset_Total"],
    "apq": ["APQ_P_CP", "APQ_SR_CP"],
    "cpic": ["CPIC_Perceived_Threat_Total"],
    "wisc": ["WISC_FSIQ"],
}

# One matched dataset per measure: WISC treats high scores as the larger "normal"
# group, and CPIC uses a plain median split.
datasets = dict(
    (
        measure,
        get_dataset(
            assessment=assessment,
            target_cols=targets,
            matching_target=measure,
            norm_high=(assessment == "wisc"),
            median_split=(assessment == "cpic"),
        ),
    )
    for assessment, targets in assessments.items()
    for measure in targets
)

# Barratt has no extra target columns; its matching target is Barratt_Total itself.
datasets.update(
    Barratt_Total=get_dataset(
        assessment="barratt",
        target_cols=[],
        matching_target="Barratt_Total",
        norm_high=True,
    )
)

In [None]:
# Per measure: number of unmatched subjects, then number of matched subjects.
for measure, dset in datasets.items():
    print(measure, *(len(dset[key]) for key in ("X_unmatched", "X_matched")))

In [None]:
def train_classification_model(dataset, epochs=1000, lr=0.0001, output_activation="softmax", n_classes=2):
    """Train a ``nn.cnn_resnet`` classifier to predict the group ``status`` of matched subjects.

    Parameters
    ----------
    dataset : dict
        Output of ``get_dataset``; uses ``dataset["X_matched"]`` and the
        ``"status"`` column of ``dataset["y_matched"]``.
    epochs : int
        Maximum number of epochs (EarlyStopping may stop sooner).
    lr : float
        Initial Adam learning rate (halved by ReduceLROnPlateau on plateaus).
    output_activation : str
        Output-layer activation, forwarded to ``nn.cnn_resnet``.
    n_classes : int
        Number of output units. Should equal the number of distinct ``status``
        values, because targets are one-hot encoded with ``pd.get_dummies``.

    Returns
    -------
    dict
        ``model``: the network with the lowest-``val_loss`` weights restored.
        ``y_pred`` / ``y_test``: test-set predictions and one-hot targets, both
        flattened row-major (``n_classes`` values per subject).
        ``X_test``: the test inputs.
        ``df``: long frame of flattened targets/predictions for both splits,
        tagged by a ``split`` column.
    """
    model = nn.cnn_resnet(
        input_shape=(100, 48),
        n_classes=n_classes,
        output_activation=output_activation,
    )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss="binary_crossentropy",
        metrics=[
            'accuracy',
            tf.keras.metrics.AUC(name="roc_auc"),
        ],
    )

    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        min_delta=0.001,
        mode="min",
        patience=100
    )

    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=20,
        verbose=1
    )

    X_train, X_test, y_train, y_test = train_test_split(
        dataset["X_matched"],
        pd.get_dummies(dataset["y_matched"]["status"]).to_numpy().astype(np.float64),
        random_state=0,
    )

    # Keep the best checkpoint in a private temporary directory that is removed
    # once the weights are restored. The ".weights.h5" suffix selects HDF5 weights
    # in tf.keras 2 and is the suffix Keras 3 requires for weights-only checkpoints.
    with tempfile.TemporaryDirectory() as ckpt_dir:
        ckpt_filepath = op.join(ckpt_dir, "best.weights.h5")
        ckpt = tf.keras.callbacks.ModelCheckpoint(
            filepath=ckpt_filepath,
            monitor="val_loss",
            verbose=0,
            save_best_only=True,
            save_weights_only=True,
            mode="auto",
        )
        callbacks = [early_stopping, reduce_lr, ckpt]

        # validation_split holds out the last 33% of the (already shuffled) training rows.
        model.fit(X_train, y_train, epochs=epochs, validation_split=0.33, callbacks=callbacks)
        model.load_weights(ckpt_filepath)

    # Predict each split once and reuse the arrays.
    y_pred_test = model.predict(X_test)
    y_pred_train = model.predict(X_train)

    df = pd.concat([
        pd.DataFrame(dict(y_true=y_test.flatten(),
                          y_pred=y_pred_test.flatten(),
                          split="test")),
        pd.DataFrame(dict(y_true=y_train.flatten(),
                          y_pred=y_pred_train.flatten(),
                          split="train")),
    ])

    return dict(
        model=model,
        y_pred=y_pred_test.flatten(),
        y_test=y_test.flatten(),
        X_test=X_test,
        df=df,
    )

In [None]:
def _train_seeded(measure, seed=0):
    """Train the classifier for ``datasets[measure]`` after re-seeding all RNGs.

    ``tf.keras.utils.set_random_seed`` seeds Python, NumPy and TensorFlow. Calling
    it right before each model should make each model's weight initialization and
    shuffling independent of training order and reproducible under
    "Restart & Run All". GPU kernels may still add some nondeterminism.
    """
    tf.keras.utils.set_random_seed(seed)
    return train_classification_model(datasets[measure])


cpic_results = _train_seeded("CPIC_Perceived_Threat_Total")
apq_p_results = _train_seeded("APQ_P_CP")
apq_sr_results = _train_seeded("APQ_SR_CP")
nles_tot_results = _train_seeded("NLES_P_TotalEvents")
nles_upset_results = _train_seeded("NLES_P_Upset_Total")
barratt_results = _train_seeded("Barratt_Total")
wisc_results = _train_seeded("WISC_FSIQ")

In [None]:
def report(covariate, results):
    """Print test-set accuracy and ROC AUC for one trained classifier.

    Parameters
    ----------
    covariate : str
        Heading printed above the metrics.
    results : dict
        Output of ``train_classification_model``. Uses ``y_test`` (flattened
        one-hot test targets), ``y_pred`` (flattened predicted class scores in the
        same row-major layout) and ``X_test`` (one entry per test subject).
    """
    n_test = len(results["X_test"])
    # Undo the row-major flattening: one row per test subject, one column per class.
    y_true_onehot = np.asarray(results["y_test"]).reshape(n_test, -1)
    y_prob = np.asarray(results["y_pred"]).reshape(n_test, -1)

    y_true = y_true_onehot.argmax(axis=1)
    y_pred = y_prob.argmax(axis=1)

    accuracy = accuracy_score(y_true, y_pred)
    # Macro one-vs-rest AUC over the indicator matrix. With complementary two-class
    # (softmax) outputs, both columns give the same AUC, so this is the binary AUC.
    roc_auc = roc_auc_score(y_true_onehot, y_prob)

    print(covariate)
    print("=" * len(covariate))
    print(f"Accuracy: {accuracy}")
    print(f"ROC AUC: {roc_auc}")
    print(f"n_classes: {y_prob.shape[1]}")
    print()

In [None]:
# Trained-model outputs keyed by the measure each classifier was trained on.
results = dict(
    CPIC_Perceived_Threat_Total=cpic_results,
    APQ_P_CP=apq_p_results,
    APQ_SR_CP=apq_sr_results,
    NLES_P_TotalEvents=nles_tot_results,
    NLES_P_Upset_Total=nles_upset_results,
    Barratt_Total=barratt_results,
    WISC_FSIQ=wisc_results,
)

In [None]:
# Print the test-set metrics for every trained model, in insertion order.
for key, res in results.items():
    report(covariate=key, results=res)