# In [None]:
import os
import warnings
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import xgboost as xgb
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import (
    precision_score, recall_score, f1_score, accuracy_score,
    confusion_matrix, classification_report, ConfusionMatrixDisplay
)
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tensorflow.keras.models import Sequential
from xgboost import XGBClassifier, DMatrix, train as xgb_train
from xgboost import XGBClassifier

# LOAD DATA
# Functional-connectivity feature matrix; NaNs are replaced by 0.0 so that the
# estimators further down only ever see finite values.
# NOTE(review): presumably one row per subject, in the same order as the
# filtered phenotypic table built below — confirm against the script that
# produced the .npy files.
X_fc = np.load("cc200_fc_X_children.npy")
X_fc = np.nan_to_num(X_fc, nan=0.0)
# NOTE(review): loaded but not used in this section; the labels actually used
# for training are rebuilt from the phenotypic table (y_combined).
y_labels = np.load("cc200_fc_y_children.npy")

pheno = pd.read_csv("Phenotypic_V1_0b_preprocessed1.csv")
# filter for children and relabel
pheno = pheno[(pheno["AGE_AT_SCAN"] < 18) & pheno["DX_GROUP"].isin([1, 2])].copy()
# ABIDE codes DX_GROUP as 1 = autism, 2 = control.  The rest of this file
# treats class 1 as ASD (display labels ["Control", "ASD"], positive class for
# precision/recall), so autism must map to 1.  The previous `DX_GROUP - 1`
# produced the opposite encoding (autism -> 0, control -> 1).
pheno["label"] = (pheno["DX_GROUP"] == 1).astype(int)
pheno["SEX"]   = pheno["SEX"].map({1: 1, 2: 0})
pheno["FIQ"]   = pheno["FIQ"].fillna(pheno["FIQ"].mean())
# Rows with a missing age were already dropped by the `< 18` filter above
# (NaN compares False), so this fill is only a safeguard.
pheno["AGE_AT_SCAN"] = pheno["AGE_AT_SCAN"].fillna(pheno["AGE_AT_SCAN"].mean())
pheno["SITE_ID"]      = LabelEncoder().fit_transform(pheno["SITE_ID"])

# match subjects
# Keep only phenotypic rows whose FILE_ID has a matching "<FILE_ID>_rois*.1D"
# file in the nyu_rois directory.
rois_files = os.listdir("nyu_rois")
subject_ids = [f.split("_rois")[0] for f in rois_files if f.endswith(".1D")]
pheno = pheno[pheno["FILE_ID"].isin(subject_ids)].reset_index(drop=True)

# demographic features
X_demo = pheno[["AGE_AT_SCAN","SEX","FIQ","SITE_ID"]].values
X_demo = StandardScaler().fit_transform(X_demo)

# combine features
# Rows of X_fc and X_demo are paired purely by position.  A row-count mismatch
# used to be truncated silently; the truncation is kept for compatibility but
# is now reported, because it means features and labels may not line up.
n_subjects = len(X_demo)
if X_fc.shape[0] != n_subjects:
    warnings.warn(
        f"X_fc has {X_fc.shape[0]} rows but {n_subjects} phenotypic records "
        "matched; rows are paired by position, so features and labels may be "
        "misaligned."
    )
X_combined = np.hstack([X_fc[:n_subjects], X_demo])
y_combined = pheno["label"].values

# univariate feature selection
# The selector is fitted inside each fold on the training split only.  Fitting
# it once on all samples (as before) lets the test labels influence which
# features are kept, which inflates every cross-validated score.
# k is capped so the script also runs when fewer than 2000 features exist.
k_features = min(2000, X_combined.shape[1])

# cross-validated training
skf = StratifiedKFold(n_splits=12, shuffle=True, random_state=30)
# Out-of-fold labels, hard predictions and class-1 probabilities, pooled
# across folds for the aggregate metrics and the ROC curve.
all_y_true, all_y_pred, all_y_score = [], [], []
conf_matrix_total = np.zeros((2, 2), dtype=int)

# storage for per-fold metrics (in percent)
cv_scores = {'accuracy': [], 'precision': [], 'recall': [], 'f1': []}

for fold, (train_idx, test_idx) in enumerate(skf.split(X_combined, y_combined), 1):
    print(f"\n Fold {fold}")
    y_train, y_test = y_combined[train_idx], y_combined[test_idx]

    # Fit the feature selector on the training split, then apply it to both.
    selector = SelectKBest(score_func=f_classif, k=k_features)
    X_train = selector.fit_transform(X_combined[train_idx], y_train)
    X_test = selector.transform(X_combined[test_idx])

    model = XGBClassifier(
        n_estimators=150,
        max_depth=5,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        objective="binary:logistic",
        use_label_encoder=False,
        eval_metric=["logloss","error"],
        random_state=42
    )
    # The training split is the only eval set, so "validation_0" in
    # evals_result() holds training-set curves.
    model.fit(
        X_train, y_train,
        eval_set=[(X_train, y_train)],
        verbose=False
    )

    # Probability of label 1 on the held-out split (previously taken on the
    # training split, which made the ROC curve a training-set curve).
    y_scores = model.predict_proba(X_test)[:, 1]

    # collect predictions and confusion
    y_pred = model.predict(X_test)
    all_y_true.extend(y_test)
    all_y_pred.extend(y_pred)
    all_y_score.extend(y_scores)
    # labels=[0, 1] guarantees a 2x2 matrix even if a class is absent from
    # both y_test and y_pred in some fold.
    conf_matrix_total += confusion_matrix(y_test, y_pred, labels=[0, 1])

    cv_scores['accuracy'].append(accuracy_score(y_test, y_pred) * 100)
    cv_scores['precision'].append(precision_score(y_test, y_pred) * 100)
    cv_scores['recall'].append(recall_score(y_test, y_pred) * 100)
    cv_scores['f1'].append(f1_score(y_test, y_pred) * 100)

# Training history of the final fold's model, used by the training-curve
# plots later in the file.  Only the last fold's values were ever kept, so
# they are extracted once after the loop instead of on every iteration.
results = model.evals_result()
num_epochs = len(results["validation_0"]["logloss"])
x_axis = list(range(1, num_epochs + 1))


# compute metrics on the pooled out-of-fold predictions
precision = precision_score(all_y_true, all_y_pred)
recall    = recall_score(all_y_true, all_y_pred)
f1        = f1_score(all_y_true, all_y_pred)
accuracy  = accuracy_score(all_y_true, all_y_pred)

print("\n XGBoost Metrics (FC + Clinical + SelectKBest):")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1-score:  {f1:.4f}")
print(f"Accuracy:  {accuracy:.4f}")

# Boxplots of CV Metrics
metrics = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
data = [
    cv_scores['accuracy'],
    cv_scores['precision'],
    cv_scores['recall'],
    cv_scores['f1']
]

plt.figure(figsize=(8, 5))
plt.boxplot(data, showmeans=True)
# Tick labels are set explicitly rather than through boxplot's `labels=`
# keyword, which Matplotlib 3.9 renamed to `tick_labels`; boxes sit at
# positions 1..N by default.
plt.xticks(range(1, len(metrics) + 1), metrics)
plt.title('Cross-Validation Performance Metrics')
plt.ylabel('Score (%)')
plt.xlabel('Metric')
plt.grid(axis='y', linestyle='--', linewidth=0.5)
plt.savefig("cv_metrics_boxplots.png", dpi=300, bbox_inches="tight")
plt.show()


# Compute ROC curve and AUC from the pooled out-of-fold probabilities.
# Scores come from a different model in each fold, so this is a pooled CV
# curve rather than the curve of a single fitted model.
fpr, tpr, thresholds = roc_curve(all_y_true, all_y_score)
roc_auc = auc(fpr, tpr)


# Plot ROC Curve
plt.figure(figsize=(6, 6))
plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color="navy", lw=1, linestyle="--")  # chance diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate (Recall)")
plt.title("Receiver Operating Characteristic (ROC) Curve")
plt.legend(loc="lower right")
plt.tight_layout()
plt.savefig("roc_curve_auc.png", dpi=300, bbox_inches="tight")
plt.show()

# --- Percentage Confusion Matrix ---
# Each row of the accumulated confusion matrix is divided by its own total,
# so every cell is the share (in %) of that true class given each prediction.
row_totals = conf_matrix_total.sum(axis=1, keepdims=True)
conf_pct = conf_matrix_total / row_totals * 100

# Tick labels are applied in matrix index order (label 0 first, then label 1).
class_names = ["Control","ASD"]

fig_pct = plt.figure(figsize=(6, 5))
ax_pct = sns.heatmap(
    conf_pct,
    annot=True,
    fmt=".2f",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
)
ax_pct.set_title("Confusion Matrix (%)")
ax_pct.set_xlabel("Predicted")
ax_pct.set_ylabel("True")
fig_pct.tight_layout()
fig_pct.savefig("confusion_matrix_percent.png", dpi=300, bbox_inches="tight")
plt.show()


# --- ConfusionMatrixDisplay for raw counts ---
# Draws the accumulated confusion matrix as absolute counts and saves it
# to confusion_matrix_counts.png.
fig_cm, ax_cm = plt.subplots(figsize=(6, 5))
# plot() returns the display object itself, so `disp` still refers to the
# ConfusionMatrixDisplay instance after the chained call.
disp = ConfusionMatrixDisplay(
    conf_matrix_total,
    display_labels=["Control","ASD"],
).plot(ax=ax_cm, cmap="Oranges")
ax_cm.set_title("Confusion Matrix (Counts)")
fig_cm.tight_layout()
fig_cm.savefig("confusion_matrix_counts.png", dpi=300, bbox_inches="tight")
plt.show()


# --- Training Loss Curve ---
# Plots the per-round log loss stored under results["validation_0"].
# `results` and `x_axis` are not defined here; as the file stands they are left
# over from the last cross-validation fold, so this is a single-fold curve.
train_logloss = results["validation_0"]["logloss"]

fig_loss, ax_loss = plt.subplots(figsize=(6, 4))
ax_loss.plot(x_axis, train_logloss, label="Train Log Loss")
ax_loss.set(xlabel="Epoch", ylabel="Log Loss", title="Training Loss Curve")
ax_loss.legend()
fig_loss.tight_layout()
fig_loss.savefig("train_loss_curve.png", dpi=300, bbox_inches="tight")
plt.show()


# --- Training Accuracy Curve ---
# Accuracy is derived as 1 - error from the per-round classification error
# stored under results["validation_0"].  `results` and `x_axis` are not
# defined here; as the file stands they come from the last CV fold.
train_accuracy = 1 - np.asarray(results["validation_0"]["error"])

fig_acc, ax_acc = plt.subplots(figsize=(6, 4))
ax_acc.plot(x_axis, train_accuracy, label="Train Accuracy")
ax_acc.set(xlabel="Epoch", ylabel="Accuracy", title="Training Accuracy Curve")
ax_acc.legend()
fig_acc.tight_layout()
fig_acc.savefig("train_accuracy_curve.png", dpi=300, bbox_inches="tight")
plt.show()