# In [None]:  (Jupyter/Colab cell marker left over from the notebook export)
# Colab-only: mount Google Drive at /content/drive before any file access.
from google.colab import drive
drive.mount('/content/drive')

import os
import math
import random
import re
import time
import hashlib
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, brier_score_loss

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel

# Candidate input tables. Paths that are not existing files are skipped by
# load_and_merge_sources; the directory of the first existing file is used as
# the output directory.
FILE_PATHS = [
    "/content/drive/MyDrive/CS441/Dallas_Animal_Shelter_Data_Fiscal_Year_2023_-_2026.csv",
    "/content/drive/MyDrive/CS441/dogs_intake_outcome_2021_2025.xlsx",
]

# When True, only rows whose Outcome_Type equals "ADOPTION" are kept.
ADOPTION_ONLY = True
# Day thresholds; one binary task per K with label (stay length <= K).
K_LIST = [7, 14, 30]
# Base seed: run r is seeded with SEED + r; also the random_state of the splits.
SEED = 42
# Training mini-batch size (validation/test loaders use their own batch size).
BATCH_SIZE = 512
EPOCHS = 80
# Adam learning rate and weight decay.
LR = 1e-3
WEIGHT_DECAY = 1e-4
# Targets are smoothed to (1 - s) * y + 0.5 * s.
LABEL_SMOOTHING = 0.05
# Number of training repetitions per K.
N_RUNS = 10
# DataLoader worker processes.
NUM_WORKERS = min(8, os.cpu_count() or 4)
# Feature selection runs only if the feature matrix has more columns than this,
# and keeps this many columns.
TOP_K_FS = 256
# Number of most frequent breeds that get their own indicator column.
TOP_N_BREEDS = 20
# BERT text-embedding settings.
USE_BERT = True
BERT_MODEL_NAME = "bert-base-uncased"
BERT_MAX_LENGTH = 64
BERT_BATCH_SIZE = 64
BERT_USE_AMP = True
# Sub-directory (under the output directory) holding cached embeddings.
BERT_CACHE_DIR = "_bert_cache"
# Columns joined with " [SEP] " to form the text fed to BERT.
TEXT_COLS_FOR_BERT = ["Animal_Breed", "Intake_Condition", "Reason", "Hold_Request"]
# CUDA device indices passed to pick_cuda_device.
BERT_GPU_ID = 0
TRAIN_GPU_ID = 0
# Name of the summary file written into the output directory.
RESULTS_TXT_NAME = "surv_binary3_final_only.txt"
# Age in days assigned to each age group (0.5, 5 and 10.5 years).
AGE_DAYS_MAP = {"puppy": 0.5 * 365.0, "adult": 5.0 * 365.0, "senior": 10.5 * 365.0}

def set_seed(seed: int = 42) -> None:
    """Seed the NumPy, Python and PyTorch random number generators.

    All CUDA generators are seeded as well when CUDA is available.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def pick_cuda_device(gpu_id: int) -> torch.device:
    """Choose the torch device to run on.

    With CUDA available, returns ``cuda:<gpu_id>``; an out-of-range index
    prints a warning and falls back to ``cuda:0``. Without CUDA, returns
    ``mps`` when that backend is available, otherwise ``cpu``.
    """
    if not torch.cuda.is_available():
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    n = torch.cuda.device_count()
    if n == 0:
        return torch.device("cpu")
    if not (0 <= gpu_id < n):
        print(f"[WARN] gpu_id={gpu_id} out of range (0..{n-1}), fallback cuda:0")
        gpu_id = 0
    return torch.device(f"cuda:{gpu_id}")

def print_gpu_info():
    """Print CUDA availability and, if available, the count and names of the GPUs."""
    has_cuda = torch.cuda.is_available()
    print("[HW] torch.cuda.is_available:", has_cuda)
    if not has_cuda:
        return
    gpu_count = torch.cuda.device_count()
    print("[HW] torch.cuda.device_count:", gpu_count)
    for i in range(gpu_count):
        print(f"[HW] GPU {i}: {torch.cuda.get_device_name(i)}")

def now_hms() -> str:
    """Return the current local time formatted as ``HH:MM:SS``."""
    return time.strftime("%H:%M:%S", time.localtime())

def status(msg: str) -> None:
    """Print *msg* prefixed with the current time in square brackets."""
    stamp = now_hms()
    print("[{}] {}".format(stamp, msg))

def write_final_line(results_txt_path: str, msg: str) -> None:
    """Append *msg* followed by a newline to the UTF-8 text file at *results_txt_path*."""
    line = "".join((msg, "\n"))
    with open(results_txt_path, mode="a", encoding="utf-8") as handle:
        handle.write(line)

def read_any_table(path: str) -> pd.DataFrame:
    """Load a table from *path*, choosing the reader by file extension.

    ``.csv`` goes through ``pd.read_csv`` and ``.xlsx``/``.xls`` through
    ``pd.read_excel``; the extension check is case-insensitive.

    Raises:
        ValueError: if the extension is none of the above.
    """
    _, suffix = os.path.splitext(path)
    ext = suffix.lower()
    if ext == ".csv":
        return pd.read_csv(path, low_memory=False)
    if ext in (".xlsx", ".xls"):
        return pd.read_excel(path)
    raise ValueError(f"Unsupported file type: {ext}, path={path}")

def _first_present(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
    """Return the first name in *candidates* that is a column of *df*, else None."""
    return next((name for name in candidates if name in df.columns), None)


def _age_label_to_group(label: str) -> str:
    """Map an upper-cased "Age Group at Intake" label to puppy/senior/adult."""
    if "PUPPY" in label or "UNWEANED" in label:
        return "puppy"
    if "SENIOR" in label:
        return "senior"
    # "ADULT" and every unrecognised label (including "") count as adult.
    return "adult"


def standardize_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with the columns the pipeline relies on.

    Source tables use different header spellings; this maps them onto one
    schema. Columns guaranteed on the result: ``Animal_Type``,
    ``Animal_Breed``, ``Intake_Date``, ``Outcome_Date``, ``Outcome_Type``,
    ``Intake_Condition``, ``Reason``, ``Hold_Request``, ``AgeGroup`` and
    ``Days_in_Custody``. ``Intake_Type`` is added only if the source has it
    under either spelling. The input frame is not modified.
    """
    out = df.copy()

    # Species: upper-cased and stripped so "dog " and "DOG" compare equal.
    species_col = _first_present(out, ["Animal_Type", "Species"])
    if species_col is None:
        out["Animal_Type"] = "UNKNOWN"
    else:
        out["Animal_Type"] = out[species_col].astype(str).str.upper().str.strip()

    if "Animal_Breed" in out.columns:
        # Replace missing values before the str cast; otherwise NaN becomes the
        # literal "nan" and is later treated as a real breed.
        breed_raw = out["Animal_Breed"]
        out["Animal_Breed"] = (
            breed_raw.astype(object).where(breed_raw.notna(), "Unknown").astype(str)
        )
    else:
        # The fallback must share out's index, otherwise the boolean masks
        # below misalign whenever df does not use a default RangeIndex.
        blank = pd.Series("", index=out.index, dtype=object)
        prim = out.get("Primary Breed", blank).fillna("").astype(str).str.strip()
        sec = out.get("Secondary Breed", blank).fillna("").astype(str).str.strip()
        breed = prim.copy()
        both = sec.ne("") & prim.ne("") & sec.ne(prim)
        breed.loc[both] = prim.loc[both] + " / " + sec.loc[both]
        only_secondary = prim.eq("") & sec.ne("")
        breed.loc[only_secondary] = sec.loc[only_secondary]
        out["Animal_Breed"] = breed.replace("", "Unknown")

    # Dates: unparseable values become NaT.
    intake_col = _first_present(out, ["Intake_Date", "Intake Date", "Intake_DateTime"])
    if intake_col is None:
        out["Intake_Date"] = pd.NaT
    else:
        out["Intake_Date"] = pd.to_datetime(out[intake_col], errors="coerce")

    outcome_col = _first_present(out, ["Outcome_Date", "Outcome Date"])
    if outcome_col is None:
        out["Outcome_Date"] = pd.NaT
    else:
        out["Outcome_Date"] = pd.to_datetime(out[outcome_col], errors="coerce")

    outcome_type_col = _first_present(out, ["Outcome_Type", "Outcome Type"])
    if outcome_type_col is None:
        out["Outcome_Type"] = "UNKNOWN"
    else:
        out["Outcome_Type"] = out[outcome_type_col].astype(str).str.upper().str.strip()

    intake_type_col = _first_present(out, ["Intake_Type", "Intake Type"])
    if intake_type_col is not None:
        out["Intake_Type"] = out[intake_type_col].astype(str)

    # Free-text columns: missing values (or a missing column) become "".
    for text_col in ("Intake_Condition", "Reason", "Hold_Request"):
        if text_col in out.columns:
            out[text_col] = out[text_col].fillna("").astype(str)
        else:
            out[text_col] = ""

    if "AgeGroup" in out.columns:
        out["AgeGroup"] = out["AgeGroup"].fillna("").astype(str)
    elif "Age Group at Intake" in out.columns:
        labels = out["Age Group at Intake"].fillna("").astype(str).str.upper()
        out["AgeGroup"] = labels.apply(_age_label_to_group)
    else:
        out["AgeGroup"] = ""

    # An already-standardized "Days_in_Custody" column is kept (coerced to
    # numeric) instead of being overwritten with NaN.
    days_col = _first_present(out, ["Days in Custody", "StayLength", "Days_in_Custody"])
    if days_col is None:
        out["Days_in_Custody"] = np.nan
    else:
        out["Days_in_Custody"] = pd.to_numeric(out[days_col], errors="coerce")

    return out

def load_and_merge_sources(paths: List[str]) -> Tuple[pd.DataFrame, str]:
    """Read every existing file in *paths*, standardize it and stack the results.

    Each table is tagged with a ``__source__`` column holding the file's base
    name. Returns the merged frame together with the directory of the first
    existing file (the current working directory if that path has no
    directory part).

    Raises:
        FileNotFoundError: if none of the paths is an existing file.
    """
    found = [candidate for candidate in paths if candidate and os.path.isfile(candidate)]
    if len(found) == 0:
        raise FileNotFoundError("No valid file in FILE_PATHS. Please check paths.")

    out_dir = os.path.dirname(found[0]) or os.getcwd()

    frames = []
    for src in found:
        status(f"[LOAD] reading: {src}")
        table = standardize_columns(read_any_table(src))
        table["__source__"] = os.path.basename(src)
        frames.append(table)

    merged = pd.concat(frames, axis=0, ignore_index=True)
    status(f"[LOAD] merged rows={len(merged)} from {len(found)} files")
    return merged, out_dir

# Breed-group keywords used by map_breed_use. Matching is by substring on the
# upper-cased breed string, and groups are tried in the order herding, working,
# sporting, toy, terrier, hound, nonsporting; the first group with a hit wins.
# NOTE(review): because matching is by substring, short keys such as "PIT" or
# "ROTT" also match inside longer words — confirm this is intended.
HERDING_KEYS = {"SHEPHERD","SHEEPDOG","MALINOIS","CATTLE","HEELER","BORDER COLLIE","KELPIE","COLLIE","CORGI","AUSSIE"}
WORKING_KEYS = {"ROTTWEILER","ROTT","MASTIFF","PYRENEES","ST BERNARD","SAINT BERNARD","BERNESE","BOXER","DOBERMAN","DOBERMANN",
                "NEWFOUNDLAND","GREAT DANE","AKITA","RIDGEBACK","CANE CORSO","PRESA","BULLMASTIFF","KOMONDOR"}
SPORTING_KEYS = {"LABRADOR","GOLDEN","RETRIEVER","SETTER","POINTER","SPANIEL","WEIMARANER","VIZSLA","BRITTANY","GRIFFON"}
NONSPORTING_KEYS = {"POODLE","BULLDOG","DALMATIAN","BICHON","BOSTON","FRENCH","CHOW","SHIBA","SHAR PEI","KEESHOND"}
TOY_KEYS = {"CHIHUAHUA","POMERANIAN","YORKSHIRE","YORKIE","SHIH","MALTESE","PUG","PAPILLON",
            "MIN PINSCHER","MINIATURE PINSCHER","MINI PIN","TOY POODLE","MINIATURE POODLE","PEKINGESE"}
TERRIER_KEYS = {"TERRIER","PIT","STAFFORDSHIRE","STAFFY","BULL TERRIER","AIREDALE","FOX TERRIER","SCOTTISH","WEST HIGHLAND","JACK RUSSELL"}
# "RIDGEBACK" is also in WORKING_KEYS, which is checked first.
HOUND_KEYS = {"HOUND","BEAGLE","BASSET","GREYHOUND","DACHSHUND","AFGHAN","BLOODHOUND","COONHOUND","RHODESIAN","RIDGEBACK"}

# Substrings that make map_breed_temper label a breed "fierce".
FIERCE_KEYS = [
    "PIT BULL","PITBULL","AMERICAN PIT","AM PIT","AMERICAN STAFFORDSHIRE","STAFFORDSHIRE TERRIER","AM STAFF",
    "BULL TERRIER","MINIATURE BULL TERRIER","ROTTWEILER","ROTT","DOBERMAN","DOBERMANN","MASTIFF","BULLMASTIFF",
    "CANE CORSO","PRESA CANARIO","DOGO ARGENTINO","TOSA","FILA BRASILEIRO","KANGAL","RIDGEBACK","RHODESIAN RIDGEBACK","AKITA",
]

# Size keywords used by map_breed_size (substring match; tried giant, large, small).
GIANT_KEYS = {"GREAT DANE","MASTIFF","PYRENEES","ST BERNARD","SAINT BERNARD","NEWFOUNDLAND","IRISH WOLFHOUND"}
LARGE_KEYS = {"ROTTWEILER","ROTT","GERMAN SHEPHERD","SHEPHERD","LABRADOR","GOLDEN","MALINOIS","BOXER","DOBERMAN","RIDGEBACK","AKITA","HUSKY","MALAMUTE"}
SMALL_KEYS = {"CHIHUAHUA","POMERANIAN","YORKSHIRE","YORKIE","SHIH","MALTESE","PUG","PAPILLON","DACHSHUND","JACK RUSSELL",
              "MINIATURE","MINI ","TOY","BOSTON","BEAGLE","CAVALIER","PEKINGESE"}

# Whole-token words (see tokenize_upper) that infer_age_group_from_condition
# treats as puppy / senior indicators in the intake-condition text.
PUPPY_WORDS = {"UNDERAGE", "PUPPY", "JUVENILE", "BABY", "NEONATE"}
SENIOR_WORDS = {"GERIATRIC", "SENIOR", "OLD", "AGED"}

# Whole-token words per health group, consumed by normalize_intake_condition.
HEALTH_TOKENS = {
    "healthy": {"APP", "WNL", "NORMAL"},
    "injured": {"INJ", "INJURED"},
    "sick": {"SICK"},
    "critical": {"CRITICAL"},
    "dead": {"DECEASED", "FATAL"},
    "underage": {"UNDERAGE"},
    "geriatric": {"GERIATRIC"},
}

def map_breed_use(breed: str) -> str:
    """Classify a breed string into a breed-use group by keyword substring.

    Groups are tried in a fixed order and the first one with a matching
    keyword wins. Breeds matching no group are "mixed_unknown" when they
    mention a mix, otherwise "unknown". Non-string or empty input is
    "unknown".
    """
    if not (isinstance(breed, str) and breed):
        return "unknown"

    upper = breed.upper()
    ordered_groups = (
        ("herding", HERDING_KEYS),
        ("working", WORKING_KEYS),
        ("sporting", SPORTING_KEYS),
        ("toy", TOY_KEYS),
        ("terrier", TERRIER_KEYS),
        ("hound", HOUND_KEYS),
        ("nonsporting", NONSPORTING_KEYS),
    )
    for label, keywords in ordered_groups:
        for keyword in keywords:
            if keyword in upper:
                return label

    if "MIX" in upper or "MIXED" in upper:
        return "mixed_unknown"
    return "unknown"

def map_breed_size(breed: str) -> str:
    """Classify a breed string as giant / large / small / medium by keyword.

    Explicit size words in the breed text ("GIANT", "TOY", "MINIATURE",
    "MINI ") take precedence over the breed keyword tables. Anything that
    matches nothing is "medium"; non-string or empty input is "unknown".
    """
    if not (isinstance(breed, str) and breed):
        return "unknown"

    upper = breed.upper()
    if "GIANT" in upper:
        return "giant"
    for marker in ("TOY", "MINIATURE", "MINI "):
        if marker in upper:
            return "small"

    for label, keywords in (("giant", GIANT_KEYS), ("large", LARGE_KEYS), ("small", SMALL_KEYS)):
        if any(keyword in upper for keyword in keywords):
            return label
    return "medium"

def map_breed_temper(breed: str) -> str:
    """Return "fierce" if the upper-cased breed contains any FIERCE_KEYS entry.

    Other non-empty strings are "normal"; non-string or empty input is
    "unknown".
    """
    if not (isinstance(breed, str) and breed):
        return "unknown"
    upper = breed.upper()
    for key in FIERCE_KEYS:
        if key in upper:
            return "fierce"
    return "normal"

def tokenize_upper(s: str) -> set:
    """Return the set of upper-cased A-Z letter runs found in *s*.

    Every character other than A-Z acts as a separator. Non-string input
    yields an empty set.
    """
    if not isinstance(s, str):
        return set()
    pieces = re.split(r"[^A-Z]+", s.upper())
    return set(filter(None, pieces))

def normalize_intake_condition(cond: str) -> str:
    """Collapse a free-text intake condition into one health group.

    The text is tokenized with tokenize_upper and compared with HEALTH_TOKENS
    in a fixed priority order, so the first matching group wins. Returns
    "unknown" when there are no tokens and "other" when no group matches.
    """
    toks = tokenize_upper(cond)
    if len(toks) == 0:
        return "unknown"
    priority = ("dead", "critical", "injured", "sick", "underage", "geriatric", "healthy")
    for group in priority:
        if not HEALTH_TOKENS[group].isdisjoint(toks):
            return group
    return "other"

def infer_age_group_from_condition(cond: str) -> str:
    """Guess the age group from intake-condition text.

    Returns "puppy" if any PUPPY_WORDS token is present, else "senior" if any
    SENIOR_WORDS token is present, else "adult".
    """
    toks = tokenize_upper(cond)
    if not PUPPY_WORDS.isdisjoint(toks):
        return "puppy"
    return "adult" if SENIOR_WORDS.isdisjoint(toks) else "senior"

def build_bert_text_column(df: pd.DataFrame) -> pd.Series:
    """Join the TEXT_COLS_FOR_BERT columns present in *df* into one text Series.

    Missing values become empty strings and the columns are joined with
    " [SEP] " in the order given by TEXT_COLS_FOR_BERT. If none of those
    columns exist, a Series of empty strings on df's index is returned.
    """
    available = [name for name in TEXT_COLS_FOR_BERT if name in df.columns]
    if not available:
        return pd.Series([""] * len(df), index=df.index)

    combined = None
    for name in available:
        piece = df[name].fillna("").astype(str)
        combined = piece if combined is None else combined + " [SEP] " + piece
    return combined

def compute_stay_length_days(df: pd.DataFrame) -> pd.Series:
    stay = (df["Outcome_Date"] - df["Intake_Date"]).dt.total_seconds() / 86400.0
    stay = stay.round().astype("float")
    stay = stay.fillna(df.get("Days_in_Custody", pd.Series([np.nan] * len(df), index=df.index)))
    return stay

def extract_intake_hour(df: pd.DataFrame) -> pd.Series:
    """Return the intake hour per row as floats, or all-NaN when unusable.

    All-NaN is returned when ``Intake_Date`` is not a datetime column, or
    when every timestamp shares a single clock time whose hour and minute
    are both zero (i.e. the data looks date-only).
    """
    stamps = df["Intake_Date"]
    all_missing = pd.Series([np.nan] * len(df), index=df.index)
    if not pd.api.types.is_datetime64_any_dtype(stamps):
        return all_missing

    hours, mins, secs = (
        getattr(stamps.dt, part).astype("float") for part in ("hour", "minute", "second")
    )

    single_clock_time = all(
        component.nunique(dropna=True) <= 1 for component in (hours, mins, secs)
    )
    if not single_clock_time:
        return hours

    known_hours = hours.dropna()
    if len(known_hours) == 0:
        return hours
    # Only midnight (hour and minute both zero) is treated as "no time recorded".
    if float(known_hours.iloc[0]) == 0.0 and float(mins.dropna().iloc[0]) == 0.0:
        return all_missing
    return hours

def _hash_texts_for_cache(texts: pd.Series, model_name: str, max_length: int) -> str:
    h = hashlib.md5()
    h.update(model_name.encode("utf-8"))
    h.update(str(max_length).encode("utf-8"))
    n = len(texts)
    sample = pd.concat([texts.iloc[: min(2000, n)], texts.iloc[max(0, n - 2000):]], axis=0)
    joined = "\n".join(sample.astype(str).tolist())
    h.update(joined.encode("utf-8", errors="ignore"))
    h.update(str(n).encode("utf-8"))
    return h.hexdigest()

def compute_bert_embeddings(
    texts: pd.Series,
    device: torch.device,
    cache_dir: str,
    model_name: str = BERT_MODEL_NAME,
    max_length: int = BERT_MAX_LENGTH,
    batch_size: int = BERT_BATCH_SIZE,
    use_amp: bool = True,
) -> np.ndarray:
    """Return one [CLS] embedding per text, using an on-disk cache.

    The cache file name is derived from ``_hash_texts_for_cache``. A cached
    array is used only if it loads cleanly and has one row per text;
    otherwise the embeddings are recomputed. New results are written to a
    temporary file and moved into place, so an interrupted run cannot leave
    a truncated cache behind.

    Args:
        texts: Series of strings to embed (row order is preserved).
        device: Device the model runs on.
        cache_dir: Directory for the ``.npy`` cache (created if missing).
        model_name: Hugging Face model identifier.
        max_length: Tokenizer truncation length.
        batch_size: Number of texts per forward pass.
        use_amp: Use CUDA autocast when *device* is a CUDA device.

    Returns:
        float32 array with one row per text.

    Raises:
        ValueError: if *texts* is empty.
    """
    n = len(texts)
    if n == 0:
        # Fail before loading the model; there is nothing to concatenate later.
        raise ValueError("compute_bert_embeddings: texts is empty, nothing to embed.")

    os.makedirs(cache_dir, exist_ok=True)
    key = _hash_texts_for_cache(texts, model_name=model_name, max_length=max_length)
    cache_path = os.path.join(cache_dir, f"bert_emb_{key}.npy")
    if os.path.isfile(cache_path):
        try:
            cached = np.load(cache_path)
        except (OSError, ValueError, EOFError) as exc:
            status(f"[BERT] cache unreadable ({exc}); recomputing")
        else:
            if cached.ndim == 2 and cached.shape[0] == n:
                status(f"[BERT] cache hit: {os.path.basename(cache_path)}")
                return cached.astype(np.float32)
            status(f"[BERT] cache shape {cached.shape} does not match {n} texts; recomputing")

    status(f"[BERT] computing embeddings on {device} ...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()
    all_embs = []
    use_cuda_amp = (device.type == "cuda") and use_amp
    t0 = time.time()
    for start in range(0, n, batch_size):
        batch_texts = texts.iloc[start:start + batch_size].tolist()
        enc = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            if use_cuda_amp:
                with torch.cuda.amp.autocast():
                    outputs = model(**enc)
            else:
                outputs = model(**enc)
            # Hidden state at sequence position 0 (the [CLS] token).
            cls_emb = outputs.last_hidden_state[:, 0, :]
        all_embs.append(cls_emb.detach().cpu().numpy())
        if (start // batch_size) % 20 == 0:
            done = min(start + batch_size, n)
            status(f"[BERT] {done}/{n} done, elapsed {int(time.time() - t0)}s")
    embs = np.concatenate(all_embs, axis=0).astype(np.float32)

    # The temp name ends in ".npy" so np.save does not append another suffix.
    tmp_path = cache_path + ".tmp.npy"
    np.save(tmp_path, embs)
    os.replace(tmp_path, cache_path)
    status(f"[BERT] saved cache: {os.path.basename(cache_path)}")
    return embs

def build_features_and_time(df: pd.DataFrame, bert_device: torch.device, out_dir: str):
    """Filter the merged table to usable dog rows and build the feature matrix.

    Steps: keep ``Animal_Type == DOG`` rows with a non-negative stay length
    (and adoptions only when ADOPTION_ONLY is set); derive breed, health, age
    and calendar features; one-hot encode categoricals; optionally append
    BERT text embeddings; and, if there are more than TOP_K_FS columns, keep
    the TOP_K_FS columns with the largest L1-logistic coefficients.

    Args:
        df: Merged, standardized table (see standardize_columns).
        bert_device: Device used to compute BERT embeddings.
        out_dir: Directory under which the BERT cache directory is created.

    Returns:
        ``(df_used, X_full, feature_names)``: the filtered frame with the
        derived columns, the float32 feature matrix (one row per row of
        ``df_used``) and the matching column names.

    Raises:
        ValueError: if no rows remain after filtering, or if the stay-length
            bins used for feature selection contain a single class.

    NOTE(review): z-scoring and feature selection use every row returned
    here, i.e. they run before any train/validation/test split a caller
    makes later — confirm this leakage is acceptable.
    """
    df = df[df["Animal_Type"].astype(str).str.upper().eq("DOG")].copy()
    status(f"[FILTER] Animal_Type==DOG rows={len(df)}")
    df["Intake_Date"] = pd.to_datetime(df["Intake_Date"], errors="coerce")
    df["Outcome_Date"] = pd.to_datetime(df["Outcome_Date"], errors="coerce")
    df["StayLength"] = compute_stay_length_days(df)
    df = df[df["StayLength"].notna() & (df["StayLength"] >= 0)].reset_index(drop=True)
    if ADOPTION_ONLY:
        before = len(df)
        df = df[df["Outcome_Type"].eq("ADOPTION")].copy()
        status(f"[FILTER] ADOPTION_ONLY keep={len(df)}/{before}")
    if len(df) == 0:
        # Without rows, every later step fails with a less helpful message.
        raise ValueError("build_features_and_time: no rows left after filtering.")
    status(f"[STAY] StayLength describe:\n{df['StayLength'].describe()}")

    # Breed-derived categorical features.
    df["Breed_Use"] = df["Animal_Breed"].apply(map_breed_use)
    df["Breed_Size"] = df["Animal_Breed"].apply(map_breed_size)
    df["Breed_Temper"] = df["Animal_Breed"].apply(map_breed_temper)
    df["IsFierceBreed"] = (df["Breed_Temper"] == "fierce").astype(int)
    df["Intake_Condition_Group"] = df["Intake_Condition"].apply(normalize_intake_condition)

    # Age group: use the provided column unless it is entirely blank.
    if "AgeGroup" not in df.columns or df["AgeGroup"].astype(str).str.strip().eq("").all():
        df["AgeGroup"] = df["Intake_Condition"].apply(infer_age_group_from_condition)
    else:
        df["AgeGroup"] = df["AgeGroup"].fillna("adult").astype(str).str.lower()
    df["AgeDays"] = df["AgeGroup"].map(AGE_DAYS_MAP).fillna(AGE_DAYS_MAP["adult"])
    df["AgeLog"] = np.log1p(df["AgeDays"].astype(float))

    # Calendar features; a missing intake date maps to 0.
    df["Intake_Month"] = df["Intake_Date"].dt.month.astype("Int64").fillna(0).astype(int)
    df["Intake_Year"] = df["Intake_Date"].dt.year.astype("Int64").fillna(0).astype(int)
    df["Intake_Weekday"] = df["Intake_Date"].dt.weekday.astype("Int64").fillna(0).astype(int)
    df["IsWeekend"] = (df["Intake_Weekday"] >= 5).astype(int)
    try:
        import holidays
        us_holidays = holidays.US()
        df["IsHoliday"] = df["Intake_Date"].dt.date.apply(lambda d: d in us_holidays if pd.notna(d) else False).astype(int)
    except Exception as exc:
        # Best effort: the optional `holidays` package may be missing. Report
        # the fallback so it is visible in the run log.
        status(f"[WARN] holiday lookup unavailable ({type(exc).__name__}: {exc}); IsHoliday=0")
        df["IsHoliday"] = 0
    df["Intake_Hour"] = extract_intake_hour(df)

    def _time_of_day_bin(hour):
        if pd.isna(hour): return "unknown"
        h = int(hour)
        if h < 6: return "night"
        if h < 12: return "morning"
        if h < 18: return "afternoon"
        return "evening"

    df["TimeOfDayBin"] = df["Intake_Hour"].apply(_time_of_day_bin)

    # Interaction features, encoded as "a||b" category strings.
    df["Use__x__Health"] = df["Breed_Use"].astype(str) + "||" + df["Intake_Condition_Group"].astype(str)
    df["Size__x__Health"] = df["Breed_Size"].astype(str) + "||" + df["Intake_Condition_Group"].astype(str)
    df["Temper__x__Health"] = df["Breed_Temper"].astype(str) + "||" + df["Intake_Condition_Group"].astype(str)
    df["Weekend__x__TimeBin"] = df["IsWeekend"].astype(str) + "||" + df["TimeOfDayBin"].astype(str)

    # Indicator columns for the most frequent breeds, plus an "other" flag.
    df["Breed_Clean"] = df["Animal_Breed"].fillna("Unknown").astype(str).str.upper().str.strip()
    top_breeds = df["Breed_Clean"].value_counts().index[:TOP_N_BREEDS].tolist()
    breed_ohe_cols: List[str] = []
    for b in top_breeds:
        col = "Breed_" + re.sub(r"[^A-Z0-9]+", "_", b)[:30]
        if col in breed_ohe_cols:
            # Sanitizing and truncating can map two breeds to one name;
            # disambiguate so the earlier indicator is not overwritten.
            col = f"{col}_{len(breed_ohe_cols)}"
        df[col] = (df["Breed_Clean"] == b).astype(int)
        breed_ohe_cols.append(col)
    df["Breed_Other"] = (~df["Breed_Clean"].isin(top_breeds)).astype(int)

    bert_emb_all = None
    bert_feature_names: List[str] = []
    if USE_BERT:
        status("[BERT] building text column ...")
        df["text_for_bert"] = build_bert_text_column(df)
        cache_dir = os.path.join(out_dir, BERT_CACHE_DIR)
        bert_emb_all = compute_bert_embeddings(
            df["text_for_bert"],
            device=bert_device,
            cache_dir=cache_dir,
            model_name=BERT_MODEL_NAME,
            max_length=BERT_MAX_LENGTH,
            batch_size=BERT_BATCH_SIZE,
            use_amp=BERT_USE_AMP,
        )
        bert_feature_names = [f"BERT_{i}" for i in range(bert_emb_all.shape[1])]

    numeric_cols = [
        "Intake_Month","Intake_Year","Intake_Weekday","IsWeekend","IsHoliday",
        "AgeLog","IsFierceBreed",
    ] + breed_ohe_cols + ["Breed_Other"]
    base_cat_cols = [
        "Breed_Use","Breed_Size","Breed_Temper","AgeGroup","Intake_Condition_Group","TimeOfDayBin",
        "Use__x__Health","Size__x__Health","Temper__x__Health","Weekend__x__TimeBin",
    ]
    extra_cat_cols: List[str] = []
    if "Intake_Type" in df.columns:
        extra_cat_cols.append("Intake_Type")
    cat_cols = base_cat_cols + extra_cat_cols

    # Z-score the numeric block; constant columns keep a divisor of 1.
    X_num_raw = df[numeric_cols].fillna(0.0).to_numpy().astype(np.float32)
    num_mean = X_num_raw.mean(axis=0)
    num_std = X_num_raw.std(axis=0)
    num_std[num_std == 0] = 1.0
    X_num = ((X_num_raw - num_mean) / num_std).astype(np.float32)
    cat_ohe = pd.get_dummies(df[cat_cols].fillna("unknown"), prefix=cat_cols, dummy_na=False)
    X_cat = cat_ohe.to_numpy().astype(np.float32)
    X_list = [X_num, X_cat]
    feature_names = numeric_cols + list(cat_ohe.columns)
    if USE_BERT and bert_emb_all is not None:
        X_list.append(bert_emb_all)
        feature_names += bert_feature_names
    X_full = np.concatenate(X_list, axis=1).astype(np.float32)
    input_dim = X_full.shape[1]

    if input_dim > TOP_K_FS:
        # LogisticRegression(multi_class="ovr") is deprecated in scikit-learn;
        # OneVsRestClassifier is the documented replacement.
        from sklearn.multiclass import OneVsRestClassifier

        status("[FS] L1 Logistic feature selection (saga + n_jobs=-1) ...")
        # Proxy target for selection: stay-length tertiles.
        y_fs = pd.qcut(df["StayLength"], q=3, labels=False, duplicates="drop").astype(int).to_numpy()
        if np.unique(y_fs).size < 2:
            raise ValueError("[FS] StayLength falls into a single quantile bin; cannot select features.")
        selector = OneVsRestClassifier(
            LogisticRegression(
                penalty="l1",
                C=0.1,
                solver="saga",
                max_iter=2000,
            ),
            n_jobs=-1,
        )
        selector.fit(X_full, y_fs)
        # One coefficient row per fitted binary problem; rank each feature by
        # its largest absolute coefficient across them.
        coef_rows = np.vstack([est.coef_ for est in selector.estimators_])
        coef_abs = np.abs(coef_rows).max(axis=0)
        order = np.argsort(-coef_abs)
        # Sorting the kept indices preserves the original column order.
        selected_idx = np.sort(order[:TOP_K_FS])
        X_full = X_full[:, selected_idx]
        feature_names = [feature_names[i] for i in selected_idx]
        status(f"[FS] original_dim={input_dim}, selected_dim={X_full.shape[1]}")
    else:
        status(f"[FS] no need, dim={input_dim}")
    return df, X_full, feature_names

class BinaryDataset(Dataset):
    """In-memory dataset of (feature row, label) pairs stored as float32 tensors."""

    def __init__(self, X: np.ndarray, y: np.ndarray):
        # torch.tensor copies its input, so the tensors do not share the
        # arrays' memory.
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self) -> int:
        return self.X.size(0)

    def __getitem__(self, idx: int):
        features = self.X[idx]
        label = self.y[idx]
        return features, label

class BinaryMLP(nn.Module):
    """Two-hidden-layer MLP (256 -> 128) that outputs one logit per sample.

    Each hidden layer is Linear -> ReLU -> BatchNorm1d -> Dropout. Dropout can
    optionally be applied to the input features as well.
    """

    def __init__(self, input_dim: int, feat_dropout: float = 0.0):
        super().__init__()
        hidden_widths = (256, 128)
        drop_rates = (0.3, 0.2)

        # Dropout on the raw inputs; None disables it.
        self.feat_dropout = nn.Dropout(feat_dropout) if feat_dropout > 0 else None

        layers = []
        fan_in = input_dim
        for width, rate in zip(hidden_widths, drop_rates):
            layers.extend(
                [nn.Linear(fan_in, width), nn.ReLU(), nn.BatchNorm1d(width), nn.Dropout(rate)]
            )
            fan_in = width
        self.trunk = nn.Sequential(*layers)
        self.head = nn.Linear(fan_in, 1)

    def forward(self, x: torch.Tensor):
        if self.feat_dropout is not None:
            x = self.feat_dropout(x)
        # Drop the trailing singleton dimension produced by the head.
        return self.head(self.trunk(x)).squeeze(1)

def bce_with_label_smoothing(logits: torch.Tensor, y: torch.Tensor, smoothing: float) -> torch.Tensor:
    """Binary cross-entropy on logits with optional label smoothing.

    With a positive *smoothing* ``s`` the targets become
    ``(1 - s) * y + 0.5 * s``; otherwise they are used unchanged.
    """
    targets = y
    if smoothing and smoothing > 0:
        targets = y * (1.0 - smoothing) + smoothing * 0.5
    return nn.functional.binary_cross_entropy_with_logits(logits, targets)

@torch.no_grad()
def eval_binary(model: nn.Module, loader: DataLoader, device: torch.device) -> Dict[str, float]:
    """Evaluate *model* on *loader* and return accuracy, F1, ROC-AUC and Brier score.

    Probabilities are the sigmoid of the model's logits and hard predictions
    use a 0.5 threshold. AUC and Brier score become NaN when the metric
    function raises ValueError.
    """
    model.eval()
    label_chunks = []
    prob_chunks = []
    for features, targets in loader:
        logits = model(features.to(device, non_blocking=True))
        prob_chunks.append(torch.sigmoid(logits).detach().cpu().numpy())
        label_chunks.append(targets.numpy())

    y_true = np.concatenate(label_chunks).astype(int)
    p = np.concatenate(prob_chunks)
    y_pred = (p >= 0.5).astype(int)

    def _or_nan(metric):
        # Score on probabilities; NaN when the metric rejects the inputs.
        try:
            return metric(y_true, p)
        except ValueError:
            return float("nan")

    return {
        "acc": float(accuracy_score(y_true, y_pred)),
        "f1": float(f1_score(y_true, y_pred)),
        "auc": float(_or_nan(roc_auc_score)),
        "brier": float(_or_nan(brier_score_loss)),
    }

def train_one_run_binary(
    X: np.ndarray,
    y: np.ndarray,
    device: torch.device,
    feat_dropout: float = 0.0,
    use_class_weight: bool = True,
) -> Dict[str, float]:
    """Train one BinaryMLP and report validation and test metrics.

    The data is split 70/15/15 into train/validation/test, stratified on *y*
    with ``random_state=SEED``, so the split is the same on every call; only
    the state of the global RNGs differs between runs. The model is
    validated every ``max(5, EPOCHS // 5)`` epochs and the checkpoint with
    the best validation AUC is restored before the final evaluation.

    Args:
        X: Feature matrix, one row per sample.
        y: Binary labels (0/1) aligned with *X*.
        device: Device to train on.
        feat_dropout: Dropout probability applied to the input features.
        use_class_weight: Weight positives by ``n_negative / n_positive``.

    Returns:
        Dict with ``val_*`` and ``test_*`` entries for acc, f1, auc and brier.
    """
    idx_all = np.arange(len(y))
    idx_trainval, idx_test = train_test_split(idx_all, test_size=0.15, random_state=SEED, stratify=y)
    y_trainval = y[idx_trainval]
    # 0.15 / 0.85 of the remainder equals 15% of the full data set.
    idx_train, idx_val = train_test_split(
        idx_trainval, test_size=0.15 / 0.85, random_state=SEED, stratify=y_trainval
    )
    X_train, y_train = X[idx_train], y[idx_train]
    X_val, y_val = X[idx_val], y[idx_val]
    X_test, y_test = X[idx_test], y[idx_test]
    train_ds = BinaryDataset(X_train, y_train)
    val_ds = BinaryDataset(X_val, y_val)
    test_ds = BinaryDataset(X_test, y_test)
    pin = (device.type == "cuda")

    # BatchNorm1d raises in training mode on a batch of one sample. If the
    # final training batch would hold exactly one sample, drop it; with
    # shuffling a different sample is left out each epoch.
    n_train = len(train_ds)
    drop_last = n_train > BATCH_SIZE and n_train % BATCH_SIZE == 1
    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
        pin_memory=pin, drop_last=drop_last,
    )
    val_loader = DataLoader(val_ds, batch_size=2048, shuffle=False, num_workers=NUM_WORKERS, pin_memory=pin)
    test_loader = DataLoader(test_ds, batch_size=2048, shuffle=False, num_workers=NUM_WORKERS, pin_memory=pin)

    model = BinaryMLP(input_dim=X.shape[1], feat_dropout=feat_dropout).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    pos_weight = None
    if use_class_weight:
        pos = float(np.sum(y_train == 1))
        neg = float(np.sum(y_train == 0))
        if pos > 0:
            pos_weight = torch.tensor([neg / max(pos, 1.0)], device=device, dtype=torch.float32)

    # Mixed precision is enabled only on CUDA; elsewhere both are no-ops.
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
    best_val_auc = -1e9
    best_state = None
    for ep in range(EPOCHS):
        model.train()
        losses = []
        for xb, yb in train_loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
                logits = model(xb)
                if pos_weight is None:
                    loss = bce_with_label_smoothing(logits, yb, LABEL_SMOOTHING)
                else:
                    y_smooth = (1.0 - LABEL_SMOOTHING) * yb + 0.5 * LABEL_SMOOTHING
                    loss = nn.functional.binary_cross_entropy_with_logits(logits, y_smooth, pos_weight=pos_weight)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            losses.append(float(loss.detach().cpu().item()))
        if (ep + 1) % max(5, EPOCHS // 5) == 0:
            val_m = eval_binary(model, val_loader, device=device)
            status(f"[TRAIN] ep {ep+1}/{EPOCHS} loss={np.mean(losses):.4f} | val AUC={val_m['auc']:.4f}, F1={val_m['f1']:.4f}")
            # `x == x` is False only for NaN, so NaN AUCs never become "best".
            if val_m["auc"] == val_m["auc"] and val_m["auc"] > best_val_auc:
                best_val_auc = val_m["auc"]
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    if best_state is not None:
        model.load_state_dict(best_state)
    test_m = eval_binary(model, test_loader, device=device)
    val_m = eval_binary(model, val_loader, device=device)
    return {
        "val_acc": val_m["acc"], "val_f1": val_m["f1"], "val_auc": val_m["auc"], "val_brier": val_m["brier"],
        "test_acc": test_m["acc"], "test_f1": test_m["f1"], "test_auc": test_m["auc"], "test_brier": test_m["brier"],
    }

def run_binary_task_many_runs(
    X: np.ndarray,
    stay_days: np.ndarray,
    K: int,
    device: torch.device,
    n_runs: int,
) -> Dict[str, float]:
    """Train the "stay <= K days" classifier *n_runs* times and aggregate metrics.

    Run ``r`` (0-based) is seeded with ``SEED + r``. The result holds ``"K"``
    plus a NaN-ignoring ``<metric>_mean`` and ``<metric>_std`` for every
    validation and test metric.

    Raises:
        ValueError: if the labels for this K contain only one class.
    """
    y = (stay_days <= K).astype(int)
    if np.unique(y).size < 2:
        raise ValueError(f"K={K}: only one class present, cannot train.")
    status(f"[TASK] K={K} days binary, positive_rate={y.mean():.4f}")

    per_run = []
    for run_no in range(1, n_runs + 1):
        status(f"[RUN] K={K} run {run_no}/{n_runs} start")
        set_seed(SEED + run_no - 1)
        result = train_one_run_binary(X, y, device=device, feat_dropout=0.0, use_class_weight=True)
        status(f"[RUN] K={K} run {run_no}/{n_runs} DONE | test AUC={result['test_auc']:.4f}, F1={result['test_f1']:.4f}")
        per_run.append(result)

    summary = {"K": K}
    for split in ("val", "test"):
        for metric in ("acc", "f1", "auc", "brier"):
            key = f"{split}_{metric}"
            values = np.array([result[key] for result in per_run], dtype=float)
            summary[key + "_mean"] = float(np.nanmean(values))
            summary[key + "_std"] = float(np.nanstd(values))
    return summary

def main():
    """Run the pipeline: load data, build features, train per K, write the summary file."""
    set_seed(SEED)
    print_gpu_info()

    df, out_dir = load_and_merge_sources(FILE_PATHS)
    results_txt_path = os.path.join(out_dir, RESULTS_TXT_NAME)
    # Opening in "w" mode truncates the results file left by a previous run.
    with open(results_txt_path, "w", encoding="utf-8") as f:
        f.write("Binary 3-task (K=7/14/30) Final Summary Only\n")

    bert_device = pick_cuda_device(BERT_GPU_ID)
    train_device = pick_cuda_device(TRAIN_GPU_ID)
    status(f"[DEV] bert_device={bert_device}, train_device={train_device}")

    df_used, X_full, feat_names = build_features_and_time(df, bert_device=bert_device, out_dir=out_dir)
    stay_days = df_used["StayLength"].to_numpy().astype(float)
    status(f"[DATA] rows_used={len(df_used)}, X_dim={X_full.shape[1]}")

    header_lines = [
        "",
        f"ADOPTION_ONLY={ADOPTION_ONLY}, USE_BERT={USE_BERT}, N_RUNS={N_RUNS}",
        f"rows_used={len(df_used)}, dim={X_full.shape[1]}",
        "",
        "K\tvalAUC_mean\tvalAUC_std\tvalF1_mean\tvalF1_std\ttestAUC_mean\ttestAUC_std\ttestF1_mean\ttestF1_std\t(brier_test_mean)",
    ]
    for line in header_lines:
        write_final_line(results_txt_path, line)

    all_task_summaries = []
    t0 = time.time()
    for K in K_LIST:
        status(f"[START] K={K} task begin")
        summ = run_binary_task_many_runs(X_full, stay_days, K=K, device=train_device, n_runs=N_RUNS)
        all_task_summaries.append(summ)

        # One tab-separated row: K, mean/std pairs, then the test Brier mean.
        cells = [f"{K}"]
        for name in ("val_auc", "val_f1", "test_auc", "test_f1"):
            cells.append(f"{summ[name + '_mean']:.4f}")
            cells.append(f"{summ[name + '_std']:.4f}")
        cells.append(f"{summ['test_brier_mean']:.4f}")
        write_final_line(results_txt_path, "\t".join(cells))
        status(f"[END] K={K} task done")

    test_auc_avg = float(np.mean([s["test_auc_mean"] for s in all_task_summaries]))
    test_f1_avg = float(np.mean([s["test_f1_mean"] for s in all_task_summaries]))
    write_final_line(results_txt_path, "")
    write_final_line(results_txt_path, f"OVERALL(avg over K): testAUC_mean={test_auc_avg:.4f}, testF1_mean={test_f1_avg:.4f}")
    status(f"[DONE] total elapsed {int(time.time() - t0)}s")
    status(f"[DONE] final summary saved to: {results_txt_path}")

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()