# In [None]:
import streamlit as st
import pandas as pd
import numpy as np
import joblib

# Page chrome, headline and the screening disclaimer shown above the form.
st.set_page_config(layout="centered", page_title="Adult Autism Prediction")
st.title("🧠 Autism Spectrum Disorder (ASD) Prediction — Adult Screener")
st.markdown(
    body=(
        "This tool estimates the probability of ASD **for adults** using clinical/demographic inputs. "
        "It is a screening aid, **not** a diagnosis. Always consult qualified clinicians."
    )
)

@st.cache_resource(show_spinner=False)
def load_pipeline(path: str = "autism_pipeline.pkl"):
    """Load and return the fitted model pipeline stored at ``path``.

    Wrapped in Streamlit's ``cache_resource`` so the file is read once and the
    same object is reused across reruns.

    NOTE(review): ``joblib.load`` unpickles the file, which can execute
    arbitrary code — only point this at a model file you trust.
    """
    fitted = joblib.load(path)
    return fitted

pipeline = load_pipeline()

# Raw column name -> category labels (as strings) learned by any fitted
# OneHotEncoder inside the pipeline's top-level ColumnTransformer(s).
# Used to populate selectbox options; stays empty if nothing is discoverable.
CATEGORY_HINTS: dict[str, list[str]] = {}
try:
    from sklearn.compose import ColumnTransformer
    from sklearn.preprocessing import OneHotEncoder

    for _step in getattr(pipeline, "named_steps", {}).values():
        if not isinstance(_step, ColumnTransformer):
            continue
        for _, _trans, _cols in _step.transformers_:
            # Encoders are often wrapped in their own Pipeline (e.g. imputer ->
            # encoder), so inspect the inner steps as well as the transformer.
            # Assumes earlier inner steps keep one column per input column
            # (true for SimpleImputer) — TODO confirm for other wrappers.
            _candidates = [_trans, *(s for _, s in getattr(_trans, "steps", []))]
            for _enc in _candidates:
                if isinstance(_enc, OneHotEncoder) and hasattr(_enc, "categories_"):
                    for _col, _cats in zip(_cols, _enc.categories_):
                        CATEGORY_HINTS[str(_col)] = [str(c) for c in _cats]
except Exception:
    # Best effort: the form falls back to the hard-coded option lists.
    pass

# Option lists the input form falls back to when the fitted pipeline exposes
# no OneHotEncoder categories for the corresponding column.
ETHNICITIES_FALLBACK = [
    "Asian",
    "Black",
    "Hispanic",
    "Latino",
    "Middle Eastern",
    "Mixed",
    "Native Indian",
    "Others",
    "PaciFica",
    "South Asian",
    "White European",
]
RELATIONS_FALLBACK = [
    "Self",
    "Parent",
    "Health Care Professional",
    "Relative",
    "Others",
]
GENDERS_FALLBACK = ["male", "female"]
AGEBANDS_FALLBACK = ["18 and more"]

def discover_raw_required_cols(pipe) -> list:
    """Return the raw input column names the pipeline's ColumnTransformer(s) select.

    Scans the top-level steps of ``pipe`` for fitted ColumnTransformers and
    collects string column selectors in first-seen order, without duplicates.

    Two kinds of column are skipped:

    * Integer/positional selectors, such as the ``remainder`` entry.
    * Columns routed to a ``"drop"`` transformer. A fitted ColumnTransformer
      does not require dropped columns at transform time, so listing them would
      only add useless form fields and make batch CSVs fail the missing-column
      check.

    Best effort: returns an empty list (or whatever was collected before a
    failure) when sklearn is unavailable or the pipeline has an unexpected shape.
    """
    required: list[str] = []
    try:
        from sklearn.compose import ColumnTransformer

        for step in getattr(pipe, "named_steps", {}).values():
            if not isinstance(step, ColumnTransformer):
                continue
            for _, trans, cols in step.transformers_:
                if isinstance(trans, str) and trans == "drop":
                    continue
                if isinstance(cols, str):
                    required.append(cols)
                elif hasattr(cols, "__iter__"):
                    # Keep string selectors only; numeric indices/masks are ignored.
                    required.extend(c for c in cols if isinstance(c, str))
    except Exception:
        pass
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(required))

RAW_REQUIRED = discover_raw_required_cols(pipeline)

# Columns the form always collects, regardless of what the pipeline reports.
EXPECTED_BASE = [
    "result",     # screening score (the form labels it "AQ-10 score", 0-10)
    "ethnicity",
    "austim",     # yes/no flag; spelling presumably mirrors the training data — TODO confirm
    "jundice",    # yes/no flag; spelling presumably mirrors the training data — TODO confirm
    "relation",
    "age",        # years
    "gender",
]

# Base columns first, then any extra raw columns the pipeline selects;
# dict.fromkeys removes duplicates while keeping first-seen order.
ALL_INPUT_COLS = list(dict.fromkeys(EXPECTED_BASE + RAW_REQUIRED))

# ---- Sidebar: user-tunable decision settings ----
st.sidebar.header("Settings")
# Probability cut-off: a score at or above it is labelled ASD.
thr = st.sidebar.slider(
    "Decision threshold (ASD if probability ≥ threshold)",
    min_value=0.0,
    max_value=1.0,
    value=0.50,
    step=0.01,
)
# Half-width of the band around `thr` inside which a result is flagged inconclusive.
uncert_band = st.sidebar.slider(
    "Inconclusive band (± around threshold)",
    min_value=0.0,
    max_value=0.20,
    value=0.05,
    step=0.01,
)
show_debug = st.sidebar.checkbox("Show debug info", value=False)

st.sidebar.markdown("---")
st.sidebar.subheader("Model notes")
st.sidebar.caption("Trained for **adults**. Inputs outside the training distribution may reduce reliability.")

# Class labels (compared case-insensitively as strings) treated as the positive / ASD class.
_POSITIVE_CLASS_LABELS = frozenset({"1", "asd", "yes", "true", "positive", "pos"})


def _positive_class_index(pipe) -> int:
    """Return the ``predict_proba`` column that holds the positive (ASD) class.

    Reads ``classes_`` from ``pipe`` itself (an sklearn Pipeline forwards it
    from its final estimator; a bare estimator has it directly), else from the
    last pipeline step. Falls back to column 1 — the positive class for sorted
    0/1 labels — when nothing better can be determined.
    """
    try:
        classes = getattr(pipe, "classes_", None)
        if classes is None and getattr(pipe, "steps", None):
            classes = getattr(pipe.steps[-1][1], "classes_", None)
        if classes is None:
            return 1
        labels = [str(c) for c in classes]
        if "1" in labels:
            return labels.index("1")
        for i, label in enumerate(labels):
            if label.lower() in _POSITIVE_CLASS_LABELS:
                return i
    except Exception:
        # Best effort: an unusual estimator keeps the default column.
        pass
    return 1


def _label_to_score(label) -> float:
    """Map a hard class label from ``predict`` to a score (numeric labels pass through)."""
    try:
        return float(label)
    except (TypeError, ValueError):
        # Text labels such as "YES"/"NO" cannot be cast; match them by name.
        return 1.0 if str(label).strip().lower() in _POSITIVE_CLASS_LABELS else 0.0


def predict_proba_safe(pipe, X: pd.DataFrame) -> float:
    """Return the positive-class (ASD) probability for the first row of ``X``.

    Uses ``predict_proba`` and picks the column of the positive class. If the
    model cannot produce probabilities, falls back to the hard label from
    ``predict``, mapped to 1.0 (positive) / 0.0.

    Raises:
        RuntimeError: if neither ``predict_proba`` nor ``predict`` succeeds.
            Raising (instead of returning 0.0) keeps a model failure from being
            displayed as a confident "Not Likely" screening result.
    """
    try:
        proba = pipe.predict_proba(X)
    except Exception:
        try:
            pred = pipe.predict(X)
        except Exception as err:
            raise RuntimeError(f"Model could not score the input: {err}") from err
        return _label_to_score(np.ravel(pred)[0])

    pos_idx = _positive_class_index(pipe)
    try:
        return float(proba[0, pos_idx])
    except (IndexError, TypeError):
        # Non-2-D output (e.g. a flat list for a single sample).
        return float(np.ravel(proba)[pos_idx])


def warn_on_unseen_categories(pipe, X: pd.DataFrame):
    """Show an ``st.warning`` for each first-row value the encoder never saw in training.

    How it works:

    * Walks every fitted ``OneHotEncoder`` inside the pipeline's top-level
      ColumnTransformer(s), including one nested in a per-column sub-Pipeline
      such as imputer -> encoder.
    * Checks the first row of ``X`` against that encoder's learned categories.
    * A value counts as seen if it equals a category either directly or as a
      string, so ``"1"`` vs ``1`` does not raise a spurious warning.
    * NaN values are never reported.
    * Each column is reported at most once.

    Best effort: any failure while inspecting the pipeline is ignored, so this
    diagnostic never blocks the prediction itself.
    """
    try:
        from sklearn.compose import ColumnTransformer
        from sklearn.preprocessing import OneHotEncoder

        if X.empty:
            return
        warned = set()
        for step in getattr(pipe, "named_steps", {}).values():
            if not isinstance(step, ColumnTransformer):
                continue
            for _, trans, cols in step.transformers_:
                # Assumes inner steps before the encoder keep one column per
                # input column (true for SimpleImputer) — TODO confirm.
                candidates = [trans, *(s for _, s in getattr(trans, "steps", []))]
                for enc in candidates:
                    if not (isinstance(enc, OneHotEncoder) and hasattr(enc, "categories_")):
                        continue
                    for col, cats in zip(cols, enc.categories_):
                        name = str(col)
                        if name in warned or name not in X.columns:
                            continue
                        val = X[name].iloc[0]
                        if pd.isna(val):
                            continue
                        cat_list = list(cats)
                        # Python-level equality avoids numpy str-vs-int comparison quirks.
                        if val in cat_list or str(val) in {str(c) for c in cat_list}:
                            continue
                        warned.add(name)
                        st.warning(f"Value '{val}' for **{name}** was not seen during training; predictions may be less reliable.")
    except Exception:
        return

def _hint_labels_are_numeric(labels) -> bool:
    """Return True if every label parses as a number (e.g. encoder categories "0"/"1")."""
    try:
        for label in labels:
            float(label)
    except (TypeError, ValueError):
        return False
    return True


def _render_input_field(key: str):
    """Draw the Streamlit widget for model column ``key`` and return its raw value."""
    lkey = key.lower()
    if lkey == "result":
        return st.slider("Screening Result (AQ-10 score)", min_value=0, max_value=10, value=1, step=1)
    if lkey in {"austim", "jundice"}:
        hints = CATEGORY_HINTS.get(key)
        if hints and not _hint_labels_are_numeric(hints):
            # The fitted encoder learned text labels (e.g. "no"/"yes"); sending
            # 0/1 would look like an unseen category, so pass the label through.
            return st.selectbox(key, hints)
        yn = st.selectbox(key, ["no", "yes"])
        return 1 if yn == "yes" else 0
    if lkey == "ethnicity":
        opts = CATEGORY_HINTS.get(key, sorted(ETHNICITIES_FALLBACK))
        return st.selectbox("Ethnicity", opts)
    if lkey == "relation":
        opts = CATEGORY_HINTS.get(key, RELATIONS_FALLBACK)
        return st.selectbox("Relation to the person being tested", opts)
    if lkey == "age":
        return st.number_input("Age (years)", min_value=0, max_value=120, value=30, step=1)
    if lkey in {"gender", "sex"}:
        opts = CATEGORY_HINTS.get(key, GENDERS_FALLBACK)
        return st.selectbox(key.capitalize(), opts)
    if lkey == "age_desc":
        opts = CATEGORY_HINTS.get(key, AGEBANDS_FALLBACK)
        return st.selectbox("Age Group", opts)
    # Any other column the pipeline needs: offer learned categories if known.
    opts = CATEGORY_HINTS.get(key)
    if opts:
        return st.selectbox(key, opts)
    return st.text_input(key)


def _coerce_input_value(key: str, value):
    """Cast one form value to the dtype the pipeline is fed for that column."""
    lkey = key.lower()
    # 0/1 flags become ints, unless a text label was passed through above.
    if lkey == "result" or (lkey in {"austim", "jundice"} and not isinstance(value, str)):
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0
    if lkey == "age":
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan
    return str(value)


def get_user_input() -> tuple[pd.DataFrame, bool]:
    """Render the single-case form and return ``(one_row_df, submitted)``.

    One widget is drawn per column in ``ALL_INPUT_COLS``. Values are cast to
    the dtypes fed to the pipeline and assembled into a one-row DataFrame with
    exactly those columns, in that order. ``submitted`` is True only on the run
    where the "Predict" button was pressed.
    """
    with st.form("user_form"):
        payload: dict[str, object] = {}
        for col in ALL_INPUT_COLS:
            key = str(col)
            payload[key] = _render_input_field(key)
        submitted = st.form_submit_button("Predict")

    # Defensive fills for alias/constant columns the pipeline may require.
    if ("age_desc" in RAW_REQUIRED) and ("age_desc" not in payload):
        payload["age_desc"] = "18 and more"
    if ("sex" in RAW_REQUIRED) and ("sex" not in payload) and ("gender" in payload):
        payload["sex"] = str(payload["gender"])
    if ("gender" in RAW_REQUIRED) and ("gender" not in payload) and ("sex" in payload):
        payload["gender"] = str(payload["sex"])

    payload = {k: _coerce_input_value(k, v) for k, v in payload.items()}
    df = pd.DataFrame([payload], columns=ALL_INPUT_COLS)
    return df, submitted


input_df, submitted = get_user_input()

if submitted:
    try:
        # Align to the canonical column order before scoring.
        input_df = input_df.reindex(columns=ALL_INPUT_COLS)
        warn_on_unseen_categories(pipeline, input_df)
        proba_pos = predict_proba_safe(pipeline, input_df)
        is_asd = proba_pos >= thr
        near_threshold = abs(proba_pos - thr) < uncert_band

        st.subheader("🔍 Prediction")
        st.metric("Predicted probability (ASD)", f"{proba_pos:.2f}")

        # The inconclusive notice is shown in addition to the hard label below.
        if near_threshold:
            st.warning("Result **inconclusive** near threshold — consider clinical follow-up.")

        if is_asd:
            st.error("🔴 Likely Autism Spectrum Disorder")
        else:
            st.success("🟢 Not Likely Autism Spectrum Disorder")

        st.caption(f"Decision threshold: {thr:.2f} | Inconclusive band: ±{uncert_band:.2f}")

        if show_debug:
            with st.expander("See input & dtypes sent to the model"):
                st.write(input_df)
                st.write(input_df.dtypes)
    except Exception as exc:
        st.error(f"Prediction Error: {exc}")
        if show_debug:
            with st.expander("Debug payload"):
                st.write(input_df)
                st.write(input_df.dtypes)

st.markdown("---")
st.subheader("📦 Batch Predictions (CSV)")
st.caption(f"Upload a CSV with the columns the **model expects**: {', '.join(map(str, ALL_INPUT_COLS))}")

_template = pd.DataFrame(columns=ALL_INPUT_COLS).to_csv(index=False)
st.download_button(
    label="Download CSV template",
    data=_template,
    file_name="asd_template.csv",
    mime="text/csv",
)

# Class labels (case-insensitive strings) treated as the positive / ASD class.
_BATCH_POSITIVE_LABELS = frozenset({"1", "asd", "yes", "true", "positive", "pos"})
# Text answers accepted in the 0/1 flag columns (austim, jundice).
_BATCH_YES_NO = {"yes": 1, "no": 0, "true": 1, "false": 0}


def _batch_positive_index(pipe) -> int:
    """Return the predict_proba column of the positive class (defaults to 1)."""
    try:
        classes = getattr(pipe, "classes_", None)
        if classes is None and getattr(pipe, "steps", None):
            classes = getattr(pipe.steps[-1][1], "classes_", None)
        if classes is None:
            return 1
        labels = [str(c) for c in classes]
        if "1" in labels:
            return labels.index("1")
        for i, label in enumerate(labels):
            if label.lower() in _BATCH_POSITIVE_LABELS:
                return i
    except Exception:
        pass
    return 1


def _batch_label_score(label) -> float:
    """Map a hard label from ``predict`` to a score (numeric labels pass through)."""
    try:
        return float(label)
    except (TypeError, ValueError):
        return 1.0 if str(label).strip().lower() in _BATCH_POSITIVE_LABELS else 0.0


def _batch_hints_are_text(col) -> bool:
    """True if the fitted encoder for ``col`` learned non-numeric labels (e.g. "no"/"yes")."""
    hints = CATEGORY_HINTS.get(str(col))
    if not hints:
        return False
    try:
        for hint in hints:
            float(hint)
    except (TypeError, ValueError):
        return True
    return False


def _batch_to_int(series: pd.Series, yes_no: bool) -> tuple[pd.Series, int]:
    """Coerce ``series`` to int, optionally mapping yes/no words to 1/0.

    Returns the coerced series and how many entries could not be parsed and
    were defaulted to 0.
    """
    values = pd.to_numeric(series, errors="coerce")
    if yes_no:
        words = series.astype(str).str.strip().str.lower().map(_BATCH_YES_NO)
        values = values.fillna(words)
    n_defaulted = int(values.isna().sum())
    return values.fillna(0).astype(int), n_defaulted


uploaded = st.file_uploader("Upload CSV for batch prediction", type=["csv"], accept_multiple_files=False)
if uploaded is not None:
    try:
        df_in = pd.read_csv(uploaded)

        # Derive alias/constant columns BEFORE the missing-column check, so a
        # CSV that only has "sex" (or lacks the constant "age_desc") is accepted.
        if ("sex" in df_in.columns) and ("gender" not in df_in.columns) and ("gender" in ALL_INPUT_COLS):
            df_in["gender"] = df_in["sex"].astype(str)
        if ("gender" in df_in.columns) and ("sex" not in df_in.columns) and ("sex" in ALL_INPUT_COLS):
            df_in["sex"] = df_in["gender"].astype(str)
        if ("age_desc" in ALL_INPUT_COLS) and ("age_desc" not in df_in.columns):
            df_in["age_desc"] = "18 and more"

        missing = [c for c in ALL_INPUT_COLS if c not in df_in.columns]
        if missing:
            st.error(f"Your file is missing required columns: {missing}")
        else:
            df_in = df_in[ALL_INPUT_COLS].copy()

            for c in ALL_INPUT_COLS:
                lc = str(c).lower()
                if lc == "result" or (lc in {"austim", "jundice"} and not _batch_hints_are_text(c)):
                    df_in[c], n_bad = _batch_to_int(df_in[c], yes_no=lc != "result")
                    if n_bad:
                        # Still defaulted to 0 as before, but no longer silently.
                        st.warning(f"{n_bad} value(s) in **{c}** were missing or not numeric and were set to 0.")
                elif lc == "age":
                    df_in[c] = pd.to_numeric(df_in[c], errors="coerce")
                else:
                    df_in[c] = df_in[c].astype(str)

            try:
                proba = np.asarray(pipeline.predict_proba(df_in))
            except Exception:
                # No usable probabilities: fall back to hard labels mapped to 0/1.
                preds = np.ravel(pipeline.predict(df_in))
                probas = np.array([_batch_label_score(p) for p in preds], dtype=float)
            else:
                if proba.ndim == 2 and proba.shape[1] >= 2:
                    # Same positive-class lookup as the single-row path.
                    probas = proba[:, _batch_positive_index(pipeline)]
                else:
                    probas = np.ravel(proba)

            labels = (probas >= thr).astype(int)
            out = df_in.copy()
            out["prob_asd"] = probas
            out["pred_label"] = labels
            out["inconclusive"] = (np.abs(out["prob_asd"] - thr) < uncert_band)

            st.success("Batch prediction complete.")
            st.dataframe(out.head(100))

            csv_out = out.to_csv(index=False)
            st.download_button(
                label="Download results.csv",
                data=csv_out,
                file_name="asd_batch_predictions.csv",
                mime="text/csv",
            )

    except Exception as e:
        st.error(f"Batch prediction error: {e}")

st.markdown("---")
st.caption("© 2025 ASD Screening Aid — Educational use only.")