In [None]:
# Age 2–3.5 ASD Screening Model Training

import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
    roc_curve,
    auc,
)

# Apply seaborn's "whitegrid" style globally so every figure drawn below shares it.
sns.set(style="whitegrid")

# ------------------------------------------------------------------
# Environment + paths (works locally and in Google Colab)
# ------------------------------------------------------------------

def is_colab() -> bool:
    """Report whether this notebook is executing inside Google Colab.

    Detection is done by attempting to import ``google.colab``; any failure
    of that import is interpreted as "not running in Colab".
    """
    try:
        import google.colab  # type: ignore
    except Exception:
        return False
    else:
        return True

IN_COLAB = is_colab()
print("Running in Colab:", IN_COLAB)

# Local runs are expected to start from .../ML_TRAINING/age_specific_models.
# Two levels up from there is the folder that CONTAINS ML_TRAINING (the
# Cognitive_Flexibility project folder, per the original note), and that folder
# is what PROJECT_ROOT points to.
# In Colab the files are uploaded by hand, so /content is used as the workspace.
PROJECT_ROOT = Path("/content") if IN_COLAB else Path.cwd().parent.parent

SAMPLE_DATA_DIR = PROJECT_ROOT / "SAMPLE_DATASETS"
ONLINE_DATA_DIR = PROJECT_ROOT / "Online Datasets"

# Echo the resolved locations so a wrong working directory is spotted early.
for _label, _location in (
    ("PROJECT_ROOT:", PROJECT_ROOT),
    ("SAMPLE_DATA_DIR:", SAMPLE_DATA_DIR),
    ("ONLINE_DATA_DIR:", ONLINE_DATA_DIR),
):
    print(_label, _location)

# Small helper for Colab uploads

def upload_files_if_needed(required_filenames: list[str]) -> dict[str, Path]:
    """Make sure the required files are available in the Colab workspace.

    Outside Colab this is a no-op that returns an empty dict.

    In Colab, files that already exist in ``/content`` are reused and only the
    ones still missing are requested through the upload dialog. This keeps the
    cell re-runnable: Colab stores a re-uploaded file under a new name such as
    ``"data (1).csv"``, so prompting again for a file that is already present
    would make the membership check below fail.

    Args:
        required_filenames: File names (not paths) that must be present.

    Returns:
        Mapping of file name to its path under ``/content``. Contains every
        required file plus any additional file the user uploaded.

    Raises:
        FileNotFoundError: If a required file is neither already present nor
            part of the upload.
    """
    if not IN_COLAB:
        return {}

    workspace = Path("/content")

    # Reuse whatever is already in the workspace (e.g. from an earlier run).
    paths: dict[str, Path] = {
        name: workspace / name
        for name in required_filenames
        if (workspace / name).is_file()
    }
    pending = [name for name in required_filenames if name not in paths]

    if pending:
        from google.colab import files  # type: ignore

        print("\nPlease upload these files from your PC:")
        for name in pending:
            print(" -", name)

        uploaded = files.upload()  # user selects files
        for name in uploaded:
            paths[name] = workspace / name

    missing = [name for name in required_filenames if name not in paths]
    if missing:
        raise FileNotFoundError(f"Missing uploads: {missing}. Please upload them and run this cell again.")

    return paths


In [None]:
# Load online datasets (Toddler + Autism Screening Combined)
# - Local: read from your repo folder `Online Datasets/`
# - Colab: upload the two CSVs from your PC

TODDLER_FILENAME = "Toddler Autism dataset July 2018.csv"
COMBINED_FILENAME = "Autism_Screening_Data_Combined.csv"

if not IN_COLAB:
    # Local run: both CSVs live under the repo's "Online Datasets" folder.
    TODDLER_PATH = ONLINE_DATA_DIR / TODDLER_FILENAME
    COMBINED_PATH = ONLINE_DATA_DIR / "Autism screening data for toddlers" / COMBINED_FILENAME
else:
    # Colab run: the user uploads both CSVs and we read them from the workspace.
    uploaded_paths = upload_files_if_needed([TODDLER_FILENAME, COMBINED_FILENAME])
    TODDLER_PATH = uploaded_paths[TODDLER_FILENAME]
    COMBINED_PATH = uploaded_paths[COMBINED_FILENAME]

print("TODDLER_PATH:", TODDLER_PATH)
print("COMBINED_PATH:", COMBINED_PATH)

df_toddler = pd.read_csv(TODDLER_PATH)
df_combined = pd.read_csv(COMBINED_PATH)

print("Toddler shape:", df_toddler.shape)
print("Combined shape:", df_combined.shape)

# Quick visual check of the first rows of each raw dataset.
display(df_toddler.head())
display(df_combined.head())

In [None]:
# Feature engineering (ONLINE) for age 2–3.5
# This follows the same engineered features used by your pipeline:
# - `critical_items_failed`
# - age-normalized z-scores: `social_responsiveness_zscore`, `joint_attention_zscore`, `total_score_zscore`
# - risk flags: `low_attention_flag`, `high_critical_items_flag`, `low_social_flag`

from scipy import stats

LABEL_MAP = {"YES": 1, "Yes": 1, "Y": 1, 1: 1,
             "NO": 0, "No": 0, "N": 0, 0: 0}

AGE_BINS = [24, 30, 36, 42]


def extract_features_from_external(df: pd.DataFrame, age_col: str, qchat_col: str | None = None) -> pd.DataFrame:
    """Create the engineered feature set from an external dataset."""
    df = df.copy()
    features = pd.DataFrame(index=df.index)

    # Age
    features["age_months"] = df[age_col].astype(float)

    # Ensure A1-A10 exist (some datasets may not have them)
    a_cols = [f"A{i}" for i in range(1, 11)]
    for col in a_cols:
        if col not in df.columns:
            df[col] = 0

    # Critical items failed
    features["critical_items_failed"] = df[a_cols].sum(axis=1).astype(int)

    # Completion time (external datasets don't have it) -> constant placeholder
    features["completion_time_sec"] = 300.0

    # Domain raw scores
    social_items = [c for c in ["A1", "A4", "A5"] if c in df.columns]
    joint_items = [c for c in ["A5", "A9"] if c in df.columns]

    features["social_responsiveness_raw"] = (df[social_items].sum(axis=1) / max(len(social_items), 1) * 100) if social_items else 0
    features["joint_attention_raw"] = (df[joint_items].sum(axis=1) / max(len(joint_items), 1) * 100) if joint_items else 0

    # Total score raw (prefer QCHAT if available)
    if qchat_col and qchat_col in df.columns:
        features["total_score_raw"] = df[qchat_col].astype(float) * 10
    else:
        features["total_score_raw"] = df[a_cols].sum(axis=1).astype(float) * 10

    # Age bins + z-scores
    features["age_bin"] = pd.cut(
        features["age_months"],
        bins=AGE_BINS,
        labels=["24-30", "30-36", "36-42"],
        include_lowest=True,
    )

    for col in ["social_responsiveness_raw", "joint_attention_raw", "total_score_raw"]:
        zcol = col.replace("_raw", "_zscore")
        features[zcol] = features.groupby("age_bin")[col].transform(
            lambda x: stats.zscore(x.fillna(x.mean())) if len(x) > 1 and x.std() > 0 else 0
        ).fillna(0)

    # Binary flags
    features["low_attention_flag"] = ((df["A1"] == 1) | (df["A4"] == 1)).astype(int)
    features["high_critical_items_flag"] = (features["critical_items_failed"] >= 3).astype(int)
    features["low_social_flag"] = (features["social_responsiveness_raw"] < 50).astype(int)

    # Label
    if "Class/ASD Traits" in df.columns:
        features["asd_label"] = df["Class/ASD Traits"].map(LABEL_MAP)
    elif "Class/ASD Traits " in df.columns:
        features["asd_label"] = df["Class/ASD Traits "].map(LABEL_MAP)
    elif "Class" in df.columns:
        features["asd_label"] = df["Class"].map(LABEL_MAP)
    else:
        features["asd_label"] = 0

    final_cols = [
        "age_months",
        "critical_items_failed",
        "completion_time_sec",
        "social_responsiveness_zscore",
        "joint_attention_zscore",
        "total_score_zscore",
        "low_attention_flag",
        "high_critical_items_flag",
        "low_social_flag",
        "asd_label",
    ]
    return features[final_cols].dropna(subset=["asd_label"]).reset_index(drop=True)


# --- Apply to each ONLINE dataset ---

# Toddler Autism dataset: age comes from Age_Mons, total score from Qchat-10-Score.
df_toddler_24_42 = df_toddler.loc[
    lambda frame: frame["Age_Mons"].ge(24) & frame["Age_Mons"].lt(42)
].copy()
df_toddler_features = extract_features_from_external(
    df_toddler_24_42,
    age_col="Age_Mons",
    qchat_col="Qchat-10-Score",
)

# Autism Screening Data Combined: age comes from Age, score is derived from A1-A10.
# NOTE(review): this filter treats `Age` as MONTHS. Confirm the unit used in this
# CSV — if it is years, the filter keeps 24-41 year olds instead of toddlers.
# NOTE(review): if this file already contains the toddler dataset's rows, the
# concat below duplicates them (and duplicates can straddle a later train/test
# split) — verify.
df_combined_24_42 = df_combined.loc[
    lambda frame: frame["Age"].ge(24) & frame["Age"].lt(42)
].copy()
df_combined_features = extract_features_from_external(
    df_combined_24_42,
    age_col="Age",
    qchat_col=None,
)

# Stack both engineered sets into a single online dataset.
df_online_features = pd.concat(
    [df_toddler_features, df_combined_features],
    ignore_index=True,
)

print("Online engineered dataset shape:", df_online_features.shape)
display(df_online_features.head())
display(df_online_features["asd_label"].value_counts().to_frame("count"))

In [None]:
# Feature engineering (HOSPITAL/SYSTEM) for age 2–3.5
# We transform your collected questionnaire + clinician reflection outputs into
# the SAME engineered feature columns as the online pipeline.

from scipy import stats

HOSP_ASD_FILENAME = "age_2_3_questionnaire_asd.csv"
HOSP_CTRL_FILENAME = "age_2_3_questionnaire_control.csv"

if not IN_COLAB:
    # Local run: read the sample exports from the repo.
    HOSP_ASD = SAMPLE_DATA_DIR / HOSP_ASD_FILENAME
    HOSP_CTRL = SAMPLE_DATA_DIR / HOSP_CTRL_FILENAME
else:
    # Colab run: ask for both CSVs and read them from the workspace.
    uploaded_paths = upload_files_if_needed([HOSP_ASD_FILENAME, HOSP_CTRL_FILENAME])
    HOSP_ASD = uploaded_paths[HOSP_ASD_FILENAME]
    HOSP_CTRL = uploaded_paths[HOSP_CTRL_FILENAME]

print("HOSP_ASD:", HOSP_ASD)
print("HOSP_CTRL:", HOSP_CTRL)

df_hosp_asd_raw = pd.read_csv(HOSP_ASD)
df_hosp_ctrl_raw = pd.read_csv(HOSP_CTRL)

# When a file carries no label of its own, the label is implied by which file
# the row came from: 1 for the ASD export, 0 for the control export.
for _frame, _implied_label in ((df_hosp_asd_raw, 1), (df_hosp_ctrl_raw, 0)):
    if "asd_label" not in _frame.columns:
        _frame["asd_label"] = _implied_label


def extract_features_from_hospital(df: pd.DataFrame) -> pd.DataFrame:
    """Create the engineered feature set from hospital/system questionnaire data.

    Produces the same feature columns as the online pipeline so both sources
    can be stacked. Z-scores are computed within the frame that is passed in,
    so pass every row that should share one normalisation in a single call.

    Args:
        df: Raw questionnaire export. Must contain ``asd_label`` and an age
            column (``age_months``, ``Age_Mons`` or ``Age``, in that order of
            preference). Optional columns, each replaced by a default when
            absent: ``critical_items_failed`` (0), ``completion_time_sec``
            (300.0), ``social_responsiveness_score`` (0),
            ``joint_attention_score`` (0), ``percentage_score`` or
            ``total_score`` + ``total_questions`` (0), ``attention_level``
            (flag 0), ``severity_label`` (omitted).

    Returns:
        Frame indexed like ``df`` with ``age_months``,
        ``critical_items_failed``, ``completion_time_sec``, the three
        ``*_zscore`` columns, the three ``*_flag`` columns, ``asd_label`` and,
        when available, ``severity_label``.

    Raises:
        KeyError: If no age column or no ``asd_label`` column is present.
    """
    df = df.copy()
    out = pd.DataFrame(index=df.index)

    # Age: take the first age column that exists and fail loudly if none does.
    age_source = next((c for c in ("age_months", "Age_Mons", "Age") if c in df.columns), None)
    if age_source is None:
        raise KeyError("No age column found; expected one of 'age_months', 'Age_Mons' or 'Age'.")
    out["age_months"] = df[age_source].astype(float)

    # Critical items failed
    if "critical_items_failed" in df.columns:
        out["critical_items_failed"] = df["critical_items_failed"].astype(float)
    else:
        out["critical_items_failed"] = 0

    # Completion time
    if "completion_time_sec" in df.columns:
        out["completion_time_sec"] = df["completion_time_sec"].astype(float)
    else:
        out["completion_time_sec"] = 300.0

    # Domain raw scores (prefer the collected domain scores if present)
    if "social_responsiveness_score" in df.columns:
        out["social_responsiveness_raw"] = df["social_responsiveness_score"].astype(float)
    else:
        out["social_responsiveness_raw"] = 0

    if "joint_attention_score" in df.columns:
        out["joint_attention_raw"] = df["joint_attention_score"].astype(float)
    else:
        out["joint_attention_raw"] = 0

    # Total score raw (use percentage_score to match 0–100 scale)
    if "percentage_score" in df.columns:
        out["total_score_raw"] = df["percentage_score"].astype(float)
    elif "total_score" in df.columns and "total_questions" in df.columns:
        percentage = df["total_score"].astype(float) / df["total_questions"].astype(float) * 100
        # total_questions == 0 yields +/-inf (or NaN for 0/0); treat all of them as 0.
        out["total_score_raw"] = percentage.replace([np.inf, -np.inf], np.nan).fillna(0)
    else:
        out["total_score_raw"] = 0

    # Age bins (same edges as online); labels are derived from AGE_BINS.
    bin_labels = [f"{lo}-{hi}" for lo, hi in zip(AGE_BINS, AGE_BINS[1:])]
    out["age_bin"] = pd.cut(
        out["age_months"],
        bins=AGE_BINS,
        labels=bin_labels,
        include_lowest=True,
    )

    def _zscore_or_zero(values: pd.Series):
        # A z-score is undefined for a single row or a constant group -> use 0.
        if len(values) > 1 and values.std() > 0:
            return stats.zscore(values.fillna(values.mean()))
        return 0

    # Z-scores within each age bin. observed=True only visits bins that occur
    # (same result, and avoids pandas' warning about the changing default).
    # Rows whose age falls outside every bin get NaN here and end up as 0.
    for col in ["social_responsiveness_raw", "joint_attention_raw", "total_score_raw"]:
        zcol = col.replace("_raw", "_zscore")
        out[zcol] = (
            out.groupby("age_bin", observed=True)[col]
            .transform(_zscore_or_zero)
            .fillna(0)
        )

    # Binary flags
    if "attention_level" in df.columns:
        out["low_attention_flag"] = (df["attention_level"].astype(float) <= 2).astype(int)
    else:
        out["low_attention_flag"] = 0

    out["high_critical_items_flag"] = (out["critical_items_failed"] >= 3).astype(int)

    # NOTE(review): when social_responsiveness_score is absent the raw score is 0,
    # so this flag is 1 for every row — confirm that is acceptable.
    out["low_social_flag"] = (out["social_responsiveness_raw"] < 50).astype(int)

    # Labels
    out["asd_label"] = df["asd_label"].astype(int)
    if "severity_label" in df.columns:
        out["severity_label"] = df["severity_label"]

    final_cols = [
        "age_months",
        "critical_items_failed",
        "completion_time_sec",
        "social_responsiveness_zscore",
        "joint_attention_zscore",
        "total_score_zscore",
        "low_attention_flag",
        "high_critical_items_flag",
        "low_social_flag",
        "asd_label",
    ]
    if "severity_label" in out.columns:
        final_cols.append("severity_label")

    return out[final_cols]


# Restrict ages to 24–42 months, then engineer features for BOTH groups together.
#
# The extractor standardises scores within the frame it receives. Calling it
# once per file would z-score ASD rows against ASD rows only and control rows
# against control rows only, which removes exactly the between-group difference
# the model is meant to learn. So: filter each file, stack them, extract once.
#
# The raw frames are no longer filtered in place, so re-running this cell always
# starts from the data as loaded. The two exports are assumed to share a schema.
_hosp_frames_24_42 = []
for df_h in (df_hosp_asd_raw, df_hosp_ctrl_raw):
    _age_col = next((c for c in ("age_months", "Age_Mons", "Age") if c in df_h.columns), None)
    if _age_col is not None:
        df_h = df_h[(df_h[_age_col] >= 24) & (df_h[_age_col] < 42)]
        if _age_col != "age_months":
            # Give both files the same age column so they line up after concat.
            df_h = df_h.assign(age_months=df_h[_age_col])
    _hosp_frames_24_42.append(df_h)

_n_asd_rows = len(_hosp_frames_24_42[0])
df_hosp_raw_24_42 = pd.concat(_hosp_frames_24_42, ignore_index=True)

# Engineer (single call -> one shared normalisation per age bin)
df_hosp_features = extract_features_from_hospital(df_hosp_raw_24_42)

# Per-file views, kept for inspection; rows are in file order (ASD first).
df_hosp_asd_features = df_hosp_features.iloc[:_n_asd_rows].reset_index(drop=True)
df_hosp_ctrl_features = df_hosp_features.iloc[_n_asd_rows:].reset_index(drop=True)

print("Hospital engineered dataset shape:", df_hosp_features.shape)
display(df_hosp_features.head())
display(df_hosp_features["asd_label"].value_counts().to_frame("count"))

In [None]:
# Combine ONLINE + HOSPITAL engineered features (same feature columns)

# severity_label exists only on the hospital side, so it is dropped before stacking.
DF_ALL = pd.concat(
    [
        df_online_features,
        df_hosp_features.drop(columns=["severity_label"], errors="ignore"),
    ],
    ignore_index=True,
)

print("Combined engineered dataset shape:", DF_ALL.shape)
display(DF_ALL.head())

# Basic descriptive statistics
display(DF_ALL.describe(include="all"))

# Class balance plot
fig, ax = plt.subplots(figsize=(4, 4))
sns.countplot(x="asd_label", data=DF_ALL, ax=ax)
ax.set_title("ASD vs Non-ASD (Age 2–3.5)")
ax.set_xticks([0, 1])
ax.set_xticklabels(["Non-ASD", "ASD"])
plt.show()

# Age distribution by class (each class normalised on its own)
fig, ax = plt.subplots(figsize=(6, 4))
sns.kdeplot(data=DF_ALL, x="age_months", hue="asd_label", common_norm=False, ax=ax)
ax.set_title("Age Distribution by ASD Label (2–3.5)")
plt.show()

# Box plots by class: critical items failed, then age-normalised total score
for _y_col, _title in (
    ("critical_items_failed", "Critical Items Failed by ASD Label"),
    ("total_score_zscore", "Total Score (z-score) by ASD Label"),
):
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.boxplot(x="asd_label", y=_y_col, data=DF_ALL, ax=ax)
    ax.set_title(_title)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(["Non-ASD", "ASD"])
    plt.show()

In [None]:
# Correlation heatmap (numeric features only)

# Pairwise correlations across every numeric column of the combined dataset.
numeric_cols = DF_ALL.select_dtypes(include=["number"]).columns
corr = DF_ALL.loc[:, numeric_cols].corr()

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(corr, annot=False, cmap="coolwarm", center=0, ax=ax)
ax.set_title("Correlation Heatmap (Numeric Features)")
plt.show()

In [None]:
# Train/test split and preprocessing (ENGINEERED FEATURES)

FEATURE_COLS = [
    "age_months",
    "critical_items_failed",
    "completion_time_sec",
    "social_responsiveness_zscore",
    "joint_attention_zscore",
    "total_score_zscore",
    "low_attention_flag",
    "high_critical_items_flag",
    "low_social_flag",
]

X = DF_ALL[FEATURE_COLS].copy()
y = DF_ALL["asd_label"].astype(int)

# Split FIRST. The split depends only on y and random_state, so the same rows
# land in train/test as before; only the imputation below changes.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)

# Fill missing numeric values (if any) with medians learned from the TRAINING
# rows only, so no statistic of the held-out rows leaks into preprocessing.
# NOTE: these medians are not saved with the model; persist `train_medians`
# too if inputs at inference time can contain missing values.
train_medians = X_train.median(numeric_only=True)
X_train = X_train.fillna(train_medians)
X_test = X_test.fillna(train_medians)

# The scaler is likewise fitted on the training rows and only applied to test.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

print("Train size:", X_train.shape[0], " Test size:", X_test.shape[0])
print("Features used:", FEATURE_COLS)

In [None]:
# Train logistic regression model (engineered features)

# class_weight="balanced" re-weights the classes inversely to their frequency.
log_reg = LogisticRegression(max_iter=2000, class_weight="balanced")
log_reg.fit(X_train_scaled, y_train)

y_pred = log_reg.predict(X_test_scaled)
y_proba = log_reg.predict_proba(X_test_scaled)[:, 1]  # probability of class 1 (ASD)

print("Classification report (threshold 0.5):")
print(classification_report(y_test, y_pred, target_names=["Non-ASD", "ASD"]))

# Confusion matrix on the held-out rows
cm = confusion_matrix(y_test, y_pred)
cm_display = ConfusionMatrixDisplay(cm, display_labels=["Non-ASD", "ASD"]).plot(cmap="Blues")
cm_display.ax_.set_title("Confusion Matrix (Age 2–3.5)")
plt.show()

# ROC curve
fpr, tpr, thr = roc_curve(y_test, y_proba)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots(figsize=(5, 5))
ax.plot(fpr, tpr, label=f"ROC AUC = {roc_auc:.3f}")
ax.plot([0, 1], [0, 1], "k--")  # chance line
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate (Sensitivity)")
ax.set_title("ROC Curve (Age 2–3.5)")
ax.legend(loc="lower right")
plt.show()

# Feature weights, ordered by magnitude (signed values are what gets plotted)
coef = pd.Series(log_reg.coef_[0], index=FEATURE_COLS).sort_values(key=np.abs, ascending=False)

fig, ax = plt.subplots(figsize=(7, 4))
sns.barplot(x=coef.values, y=coef.index, ax=ax)
ax.set_title("Logistic Regression Coefficients (absolute importance)")
ax.set_xlabel("Coefficient")
ax.set_ylabel("Feature")
plt.show()

display(coef.to_frame("coef"))

In [None]:
# Save trained model + scaler (and download if running in Colab)

import joblib

# Where to save artifacts
# - Local: save into your repo so backend can use them
# - Colab: save into /content and download

MODEL_NAME = "model_age_2_3_5_questionnaire"

# Colab writes into the workspace; a local run writes into the repo.
out_dir = (
    Path("/content") / "model_artifacts"
    if IN_COLAB
    else PROJECT_ROOT / "ML_TRAINING" / "models"
)
out_dir.mkdir(parents=True, exist_ok=True)

model_path = out_dir / f"{MODEL_NAME}.pkl"
scaler_path = out_dir / f"scaler_{MODEL_NAME}.pkl"

# The model and the scaler it was trained with are stored as separate files.
for _artifact, _target in ((log_reg, model_path), (scaler, scaler_path)):
    joblib.dump(_artifact, _target)

print("Saved model:", model_path)
print("Saved scaler:", scaler_path)

if IN_COLAB:
    # Colab: download the files to your PC
    from google.colab import files  # type: ignore

    for _target in (model_path, scaler_path):
        files.download(str(_target))
else:
    # Optional: also copy into backend production folder (local only)
    import shutil

    backend_model_dir = PROJECT_ROOT / "senseai_backend" / "ml_engine" / "models"
    backend_model_dir.mkdir(parents=True, exist_ok=True)

    for _target in (model_path, scaler_path):
        shutil.copy2(_target, backend_model_dir / _target.name)

    print("Copied into backend:", backend_model_dir)
