In [3]:
# Core libraries
import numpy as np
import pandas as pd
import tensorflow as tf
from copy import deepcopy

# SciPy for statistical analysis
from scipy import stats
from scipy.stats import t

# Scikit-learn utilities
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    confusion_matrix, balanced_accuracy_score, ConfusionMatrixDisplay, classification_report
)
from sklearn.metrics import precision_recall_curve


import matplotlib.pyplot as plt


from sklearn.utils import resample, class_weight
from sklearn.experimental import enable_iterative_imputer  # noqa
from sklearn.impute import IterativeImputer
from sklearn.ensemble import RandomForestRegressor
from sklearn.impute import KNNImputer
from sklearn.model_selection import train_test_split


# TensorFlow and Keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import LSTM, Dense, Dropout, Input, Concatenate
from tensorflow.keras.callbacks import EarlyStopping
import keras.backend as K



import csv
import os


In [4]:
# ===========================
# Load and preprocess dataset
# ===========================

from google.colab import drive
drive.mount('/content/drive')

file_path = '/content/drive/Shared drives/ABCD_Model/for_JS_final_withgroup.csv'
df = pd.read_csv(file_path)

# The numeric site ID is the trailing run of digits in the baseline site-id string.
df['site'] = df['site_id_l.baseline_year_1_arm_1'].str.extract(r'(\d+)$').astype(int)

# Time-series features. Both lists name the same measures in the same order,
# once per timepoint, so position i refers to the same measure in each list.
features_baseline = [
    #'interview_age.baseline_year_1_arm_1',
    'KSADSintern.baseline_year_1_arm_1',
    'nihtbx_cryst_agecorrected.baseline_year_1_arm_1',
    'ACEs.baseline_year_1_arm_1',
    'avgPFCthick_QA.baseline_year_1_arm_1',
    'rsfmri_c_ngd_cgc_ngd_cgc_QA.baseline_year_1_arm_1',
    'rsfmri_c_ngd_dt_ngd_dt_QA.baseline_year_1_arm_1'
]

features_followup = [
    #'interview_age.2_year_follow_up_y_arm_1',
    'KSADSintern.2_year_follow_up_y_arm_1',
    'nihtbx_cryst_agecorrected.2_year_follow_up_y_arm_1',
    'ACEs.2_year_follow_up_y_arm_1',
    'avgPFCthick_QA.2_year_follow_up_y_arm_1',
    'rsfmri_c_ngd_cgc_ngd_cgc_QA.2_year_follow_up_y_arm_1',
    'rsfmri_c_ngd_dt_ngd_dt_QA.2_year_follow_up_y_arm_1',
]

features_all_time = features_baseline + features_followup

# Cross-sectional features: an explicit list of columns that are used once per
# subject rather than once per timepoint.
cross_sectional_features = [
    'rel_family_id',
    #'demo_sex_v2',
    #'race_ethnicity',
    'acs_raked_propensity_score',
    'speechdelays',
    'motordelays',
    'fam_history_8_yes_no',
    ]

# ========================
# Clean data
# ========================

# Coerce every model feature to a numeric dtype. Values are first rendered as
# stripped text; blanks and anything unparseable end up as NaN.
for col in features_all_time + cross_sectional_features:
    stripped = df[col].astype(str).str.strip()
    stripped = stripped.mask(stripped == '', np.nan)
    df[col] = pd.to_numeric(stripped, errors='coerce')

# Keep only rows that have every time-series feature and a label. Missing
# cross-sectional values are not dropped here.
df = df.dropna(subset=features_all_time + ['group_PDvLP_3timepoint'])

# Optional whole-dataset KNN imputation (disabled):
#imputer = KNNImputer(n_neighbors=5) #comment these two lines for non data imputation
#df[features_all_time + cross_sectional_features] = imputer.fit_transform(df[features_all_time + cross_sectional_features])

group_labels = df['group_PDvLP_3timepoint']
print(f"Total samples after cleaning: {len(df)}")
print(f"1s: {(group_labels == 1).sum()} 0s: {(group_labels == 0).sum()}")
site_ids = df['site'].unique()



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Total samples after cleaning: 4153
1s: 160 0s: 3993


In [5]:

def build_LSTM_model(
    timesteps,
    ts_features,
    cross_features,
    lstm_units=64,
    dropout_rate=0.3,
    dense_units=32,
    alpha=0.75  # forwarded to focal_loss as its alpha
):
    """Build and compile a two-input binary classifier.

    One input carries a sequence of shape ``(timesteps, ts_features)`` that is
    fed through an LSTM followed by dropout; the other carries a flat vector of
    ``cross_features`` values. The two are concatenated, passed through one
    ReLU dense layer of ``dense_units`` units and a single sigmoid output unit.

    The model is compiled with the Adam optimizer, ``focal_loss(alpha=alpha)``
    as the loss, and accuracy as the tracked metric.

    Returns the compiled Keras ``Model`` whose inputs are, in order, the
    time-series input and the cross-sectional input.
    """
    # Sequence branch: only the LSTM's final output is kept (no per-step outputs).
    sequence_in = Input(shape=(timesteps, ts_features), name='time_series_input')
    sequence_encoding = LSTM(lstm_units, return_sequences=False)(sequence_in)
    sequence_encoding = Dropout(dropout_rate)(sequence_encoding)

    # Static branch: used as-is, without any layer of its own.
    static_in = Input(shape=(cross_features,), name='cross_sectional_input')

    # Joint head over both branches.
    merged = Concatenate()([sequence_encoding, static_in])
    hidden = Dense(dense_units, activation='relu')(merged)
    probability = Dense(1, activation='sigmoid')(hidden)

    classifier = Model(inputs=[sequence_in, static_in], outputs=probability)
    classifier.compile(
        optimizer='adam',
        loss=focal_loss(alpha=alpha),
        metrics=['accuracy'],
    )
    return classifier

In [6]:
def focal_loss(gamma=2., alpha=0.75):
    """Create a binary focal-loss function suitable for ``model.compile(loss=...)``.

    For a label ``y`` and predicted probability ``p`` the per-element loss is
    ``w * CE`` where ``CE = -y*log(p) - (1-y)*log(1-p)`` and
    ``w = alpha*(1-p)**gamma*y + (1-alpha)*p**gamma*(1-y)``.
    The returned function averages this over all elements.

    Parameters
    ----------
    gamma : float
        Exponent of the modulating factor in ``w``.
    alpha : float
        Weight applied to the ``y`` (positive) term; ``1 - alpha`` is applied
        to the ``1 - y`` (negative) term.

    Returns
    -------
    callable
        ``focal_loss_fixed(y_true, y_pred)`` returning a scalar tensor.
    """
    def focal_loss_fixed(y_true, y_pred):
        # Labels can arrive as integers; the arithmetic below mixes them with
        # float predictions, so bring them to the prediction dtype first.
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)

        # Keep probabilities away from exactly 0 and 1 so both logs are finite.
        epsilon = K.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)

        cross_entropy = -y_true * tf.math.log(y_pred) - (1 - y_true) * tf.math.log(1 - y_pred)
        weight = alpha * tf.pow(1 - y_pred, gamma) * y_true + (1 - alpha) * tf.pow(y_pred, gamma) * (1 - y_true)
        loss = weight * cross_entropy

        return tf.reduce_mean(loss)

    return focal_loss_fixed

In [7]:
def _stack_timepoints(frame, features_baseline, features_followup):
    """Stack two timepoints of ``frame`` into a float32 array of shape (rows, 2, features)."""
    baseline = frame[features_baseline].to_numpy(dtype=np.float32)
    followup = frame[features_followup].to_numpy(dtype=np.float32)
    # axis=1 places the timepoint axis between the sample axis and the feature axis.
    return np.stack([baseline, followup], axis=1)


def run_losocv_lstm(
    df,
    features_baseline,
    features_followup,
    cross_sectional_features,
    site_column='site',
    label_column='group_PDvLP_3timepoint',
    lstm_units=64,
    dropout_rate=0.3,
    dense_units=32,
    alpha=0.75,
    batch_size=16,
    verbose=0
):
    """Leave-one-site-out cross-validation of the LSTM classifier.

    For every distinct value in ``df[site_column]`` the rows of that site are
    held out as the test set and a fresh model from ``build_LSTM_model`` is
    trained on the remaining rows. Within each fold:

    * the two timepoints are stacked into a ``(rows, 2, features)`` array;
    * cross-sectional features are KNN-imputed (imputer fitted on the training
      rows only);
    * both feature groups are standardised with scalers fitted on the training
      rows only;
    * the model is trained for at most 20 epochs with early stopping on
      ``val_loss`` (Keras takes the 10% validation split from the end of the
      training arrays);
    * predictions are thresholded at 0.5 and per-site metrics are printed.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``site_column``, ``label_column`` and all feature columns.
    features_baseline, features_followup : sequence of str
        Column names for the first and second timepoint; must have equal length.
    cross_sectional_features : sequence of str
        Column names of the non-time-series features.
    site_column, label_column : str
        Names of the grouping and target columns.
    lstm_units, dropout_rate, dense_units, alpha
        Forwarded to ``build_LSTM_model``.
    batch_size : int
        Training batch size.
    verbose : int
        Keras verbosity, used for both ``fit`` and ``predict``.

    Returns
    -------
    list of tuple
        One ``(site, accuracy, precision, recall, npv, auc, specificity, f1)``
        tuple per evaluated site. AUC is NaN when the held-out site has only
        one class or when predictions contain NaN; NPV and specificity fall
        back to 0 when their denominator is 0.

    Raises
    ------
    ValueError
        If ``features_baseline`` and ``features_followup`` differ in length,
        since the two timepoints could not be stacked.
    """
    features_baseline = list(features_baseline)
    features_followup = list(features_followup)
    if len(features_baseline) != len(features_followup):
        raise ValueError(
            "features_baseline and features_followup must have the same length "
            f"(got {len(features_baseline)} and {len(features_followup)})."
        )

    site_ids = df[site_column].unique()
    site_metrics = []

    for test_site in site_ids:
        print(f"\n==== Testing on site {test_site} ====")

        df_train = df[df[site_column] != test_site]
        df_test = df[df[site_column] == test_site]

        if len(df_train) == 0 or len(df_test) == 0:
            print(f"Skipping site {test_site} due to no valid samples.")
            continue

        # === Time series features ===
        X_train_ts = _stack_timepoints(df_train, features_baseline, features_followup)
        X_test_ts = _stack_timepoints(df_test, features_baseline, features_followup)
        y_train = df_train[label_column].to_numpy().astype(int)
        y_test = df_test[label_column].to_numpy().astype(int)

        # === Cross-sectional features ===
        # The imputer only ever sees training rows when fitting.
        imputer = KNNImputer(n_neighbors=5)
        X_train_cross = imputer.fit_transform(df_train[cross_sectional_features])
        X_test_cross = imputer.transform(df_test[cross_sectional_features])

        # === Normalize time series ===
        # Flatten (rows, timesteps, features) to (rows*timesteps, features) so each
        # feature is scaled with statistics pooled over both timepoints.
        scaler_ts = StandardScaler()
        n_ts_features = X_train_ts.shape[2]
        X_train_ts_scaled = scaler_ts.fit_transform(
            X_train_ts.reshape(-1, n_ts_features)
        ).reshape(X_train_ts.shape)
        X_test_ts_scaled = scaler_ts.transform(
            X_test_ts.reshape(-1, n_ts_features)
        ).reshape(X_test_ts.shape)

        # === Normalize cross-sectional ===
        scaler_cross = StandardScaler()
        X_train_cross_scaled = scaler_cross.fit_transform(X_train_cross)
        X_test_cross_scaled = scaler_cross.transform(X_test_cross)

        # === Build and train model ===
        # A new model is created for every fold; drop the previous fold's Keras
        # state so memory does not accumulate over many folds.
        tf.keras.backend.clear_session()

        model = build_LSTM_model(
            timesteps=X_train_ts_scaled.shape[1],
            ts_features=X_train_ts_scaled.shape[2],
            cross_features=X_train_cross_scaled.shape[1],
            lstm_units=lstm_units,
            dropout_rate=dropout_rate,
            dense_units=dense_units,
            alpha=alpha
        )

        early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

        model.fit(
            [X_train_ts_scaled, X_train_cross_scaled], y_train,
            epochs=20,
            batch_size=batch_size,
            validation_split=0.1,
            callbacks=[early_stop],
            verbose=verbose
        )

        # === Evaluate ===
        # Forward `verbose` so verbose=0 also silences the prediction progress bar.
        y_pred_probs = model.predict(
            [X_test_ts_scaled, X_test_cross_scaled], verbose=verbose
        ).flatten()
        y_pred = (y_pred_probs > 0.5).astype(int)

        if np.isnan(y_pred_probs).any():
            print("NaNs detected in prediction probabilities, skipping AUC calculation for this site.")
            auc = float('nan')
        elif len(np.unique(y_test)) == 2:
            auc = roc_auc_score(y_test, y_pred_probs)
        else:
            # AUC is undefined when the held-out site contains a single class.
            auc = float('nan')

        acc = accuracy_score(y_test, y_pred)
        prec = precision_score(y_test, y_pred, zero_division=0)
        rec = recall_score(y_test, y_pred, zero_division=0)
        f1 = f1_score(y_test, y_pred, zero_division=0)

        # labels=[0, 1] fixes the matrix to 2x2 even if a class is absent.
        tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=[0, 1]).ravel()

        npv = tn / (tn + fn) if (tn + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

        print(f"Accuracy:  {acc:.3f}")
        print(f"Precision: {prec:.3f}")
        print(f"Recall:    {rec:.3f}")
        print(f"F1 Score:  {f1:.3f}")
        print(f"AUC:       {auc:.3f}")
        print(f"Negative Predictive Value (NPV): {npv:.2f}")
        print(f"Specificity: {specificity:.3f}")

        site_metrics.append((test_site, acc, prec, rec, npv, auc, specificity, f1))

    return site_metrics


In [8]:
# Results file for the ablation runs and the column layout of each row.
csv_file = '/content/drive/Shared drives/ABCD_Model/Explainability/no_demos_feature_ablation_results.csv'
fieldnames = ['ablated_feature', 'feature_type',  'Accuracy', 'Precision', 'Recall', 'NPV', 'AUC', 'Specificity', 'F1 Score']

# Only create the file (with its header row) when it is not there yet, so rows
# from earlier runs are kept.
if not os.path.isfile(csv_file):
    with open(csv_file, mode='w', newline='') as header_handle:
        csv.DictWriter(header_handle, fieldnames=fieldnames).writeheader()

def write_result_to_csv(result_dict):
    """Append ``result_dict`` as one row to the results CSV.

    The destination is the module-level ``csv_file`` and the column order is
    the module-level ``fieldnames``. No header is written here.
    """
    with open(csv_file, mode='a', newline='') as out_handle:
        csv.DictWriter(out_handle, fieldnames=fieldnames).writerow(result_dict)

def get_existing_ablated_features(csv_file):
    """Return the set of ``ablated_feature`` values found in ``csv_file``.

    An empty set is returned when ``csv_file`` is not an existing file. The
    file is read with ``csv.DictReader``, so its first row supplies the
    column names.
    """
    if os.path.isfile(csv_file):
        with open(csv_file, mode='r', newline='') as in_handle:
            return {record['ablated_feature'] for record in csv.DictReader(in_handle)}
    return set()

def core_name(feat):
    """Return the part of ``feat`` that precedes its first '.' (all of it if there is none)."""
    prefix, _, _ = feat.partition('.')
    return prefix


In [10]:

results = []

# Core name (text before the first '.') -> full column name, per timepoint.
baseline_core_to_full = {core_name(col): col for col in features_baseline}
followup_core_to_full = {core_name(col): col for col in features_followup}

# All core features (assuming same set in baseline and follow-up)
core_features = list(baseline_core_to_full.keys())

existing_ablated = get_existing_ablated_features(csv_file)

# Position of each metric inside the per-site tuples returned by run_losocv_lstm.
ts_metric_positions = (
    ('Accuracy', 1),
    ('Precision', 2),
    ('Recall', 3),
    ('NPV', 4),
    ('AUC', 5),
    ('Specificity', 6),
    ('F1 Score', 7),
)

for core_feat in core_features:
    # Features already present in the CSV are not re-run.
    if core_feat in existing_ablated:
        print(f"Skipping {core_feat} — already in results.")
        continue

    print(f"\n==== Ablating time-series feature: {core_feat} ====")

    # Remove the measure from both timepoints so the two lists stay aligned.
    features_baseline_ablate = [col for col in features_baseline if core_name(col) != core_feat]
    features_followup_ablate = [col for col in features_followup if core_name(col) != core_feat]

    print(f"Baseline features: {features_baseline_ablate}")
    print(f"Follow-up features: {features_followup_ablate}")

    metrics = run_losocv_lstm(
        df,
        features_baseline=features_baseline_ablate,
        features_followup=features_followup_ablate,
        cross_sectional_features=cross_sectional_features,
        verbose=0
    )

    # Average each metric over sites, ignoring NaN entries.
    mean_metrics = {'ablated_feature': core_feat, 'feature_type': 'time_series'}
    mean_metrics.update({
        metric_name: np.nanmean([site_row[position] for site_row in metrics])
        for metric_name, position in ts_metric_positions
    })

    results.append(mean_metrics)
    write_result_to_csv(mean_metrics)
    print(f"AUC after ablating {core_feat}: {mean_metrics['AUC']:.4f}")



==== Ablating time-series feature: KSADSintern ====
Baseline features: ['nihtbx_cryst_agecorrected.baseline_year_1_arm_1', 'ACEs.baseline_year_1_arm_1', 'avgPFCthick_QA.baseline_year_1_arm_1', 'rsfmri_c_ngd_cgc_ngd_cgc_QA.baseline_year_1_arm_1', 'rsfmri_c_ngd_dt_ngd_dt_QA.baseline_year_1_arm_1']
Follow-up features: ['nihtbx_cryst_agecorrected.2_year_follow_up_y_arm_1', 'ACEs.2_year_follow_up_y_arm_1', 'avgPFCthick_QA.2_year_follow_up_y_arm_1', 'rsfmri_c_ngd_cgc_ngd_cgc_QA.2_year_follow_up_y_arm_1', 'rsfmri_c_ngd_dt_ngd_dt_QA.2_year_follow_up_y_arm_1']

==== Testing on site 21 ====
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
Accuracy:  0.931
Precision: 0.000
Recall:    0.000
F1 Score:  0.000
AUC:       0.706
Negative Predictive Value (NPV): 0.94
Specificity: 0.995

==== Testing on site 11 ====
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step
Accuracy:  0.979
Precision: 0.000
Recall:    0.000
F1 Score:  0.000
AUC:       0.572
Negative P

In [11]:
print("\nAblating cross-sectional features...\n")

# Re-read the CSV so rows written since the last check are taken into account.
existing_ablated = get_existing_ablated_features(csv_file)

# Position of each metric inside the per-site tuples returned by run_losocv_lstm.
cross_metric_positions = (
    ('Accuracy', 1),
    ('Precision', 2),
    ('Recall', 3),
    ('NPV', 4),
    ('AUC', 5),
    ('Specificity', 6),
    ('F1 Score', 7),
)

for cross_feat in cross_sectional_features:
    # Features already present in the CSV are not re-run.
    if cross_feat in existing_ablated:
        print(f"Skipping {cross_feat} — already in results.")
        continue

    print(f"\n==== Ablating cross-sectional feature: {cross_feat} ====")

    # Every cross-sectional feature except the one being ablated.
    cross_sectional_ablate = [name for name in cross_sectional_features if name != cross_feat]

    metrics = run_losocv_lstm(
        df,
        features_baseline=features_baseline,
        features_followup=features_followup,
        cross_sectional_features=cross_sectional_ablate,
        verbose=0
    )

    # Average each metric over sites, ignoring NaN entries.
    mean_metrics = {'ablated_feature': cross_feat, 'feature_type': 'cross_feat'}
    mean_metrics.update({
        metric_name: np.nanmean([site_row[position] for site_row in metrics])
        for metric_name, position in cross_metric_positions
    })

    results.append(mean_metrics)
    write_result_to_csv(mean_metrics)
    print(f"AUC after ablating {cross_feat}: {mean_metrics['AUC']:.4f}")



Ablating cross-sectional features...


==== Ablating cross-sectional feature: rel_family_id ====

==== Testing on site 21 ====
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step
Accuracy:  0.936
Precision: 0.500
Recall:    0.286
F1 Score:  0.364
AUC:       0.701
Negative Predictive Value (NPV): 0.95
Specificity: 0.980

==== Testing on site 11 ====
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step
Accuracy:  0.979
Precision: 0.500
Recall:    0.667
F1 Score:  0.571
AUC:       0.993
Negative Predictive Value (NPV): 0.99
Specificity: 0.985

==== Testing on site 4 ====
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 48ms/step
Accuracy:  0.940
Precision: 0.333
Recall:    0.214
F1 Score:  0.261
AUC:       0.673
Negative Predictive Value (NPV): 0.96
Specificity: 0.978

==== Testing on site 5 ====
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 48ms/step
Accuracy:  0.942
Precision: 0.000
Recall:    0.000
F1 Score:  0.000
AU

In [12]:
# One row per ablated feature, ordered from lowest to highest mean AUC.
df_results = pd.DataFrame(results)
df_results_sorted = df_results.sort_values('AUC')
print(df_results_sorted)

                ablated_feature feature_type  Accuracy  Precision    Recall  \
0                   KSADSintern  time_series  0.963207   0.060606  0.020455   
7    acs_raked_propensity_score   cross_feat  0.954017   0.417911  0.329212   
4   rsfmri_c_ngd_cgc_ngd_cgc_QA  time_series  0.959282   0.409055  0.377785   
3                avgPFCthick_QA  time_series  0.957616   0.358658  0.330728   
6                 rel_family_id   cross_feat  0.960457   0.424729  0.331252   
2                          ACEs  time_series  0.955818   0.405195  0.359282   
10         fam_history_8_yes_no   cross_feat  0.960892   0.436147  0.330349   
8                  speechdelays   cross_feat  0.960726   0.427381  0.358904   
1     nihtbx_cryst_agecorrected  time_series  0.961975   0.484524  0.366188   
5     rsfmri_c_ngd_dt_ngd_dt_QA  time_series  0.960125   0.457558  0.356036   
9                   motordelays   cross_feat  0.954282   0.418001  0.344468   

         NPV       AUC  Specificity  F1 Score  
0  