In [1]:
from __future__ import annotations

import sys
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import precision_recall_curve, confusion_matrix, average_precision_score


# Locate the project root: the nearest directory, starting at the current
# working directory and walking upward, that contains app/, models/ and data/.
here = Path.cwd().resolve()
project_root = None
for p in (here, *here.parents):
    if all((p / marker).exists() for marker in ("app", "models", "data")):
        project_root = p
        break
else:
    # The loop ran out of parents without finding all three markers.
    raise FileNotFoundError(f"프로젝트 루트를 못 찾았어. 현재 위치={here}")

# Put the root at the front of sys.path so the project imports below resolve.
if sys.path.count(str(project_root)) == 0:
    sys.path.insert(0, str(project_root))

from app.utils.paths import DEFAULT_PATHS as P, ensure_runtime_dirs
from app.utils.metrics import evaluate_churn_metrics
from app.utils.save import save_model_and_artifacts
from models.model_definitions import MLP_enhance

# Korean font setup is optional: the helper may be missing from this checkout,
# and a failure inside it must never stop the run. Unlike a blanket
# ``except Exception: pass``, an unexpected failure is reported as a warning
# so broken glyphs in the figures can be traced back to their cause.
try:
    from app.utils.plotting import configure_matplotlib_korean
    configure_matplotlib_korean()
except ImportError:
    # Helper not available; matplotlib keeps its default font configuration.
    pass
except Exception as exc:  # best-effort: report the problem, keep running
    import warnings

    warnings.warn(
        f"configure_matplotlib_korean() failed; using default fonts: {exc!r}",
        RuntimeWarning,
    )

# NOTE(review): presumably creates the runtime directories described by
# app.utils.paths; the helper is defined outside this file — confirm there.
ensure_runtime_dirs()


# Seed passed to seed_everything() at the start of main().
SEED = 42

# Top-k percentages for which evaluate_test_and_save() draws a confusion matrix.
REPORT_K_LIST = (5, 10, 15, 30)
# Percentage used for the validation Recall@top-k that ranks epochs and configs.
BEST_K_PCT = 10

# Hyper-parameter grid: main() trains every (lr, weight decay, batch size) combination.
LR_LIST = [3e-4, 1e-3, 2e-3]
WD_LIST = [0.0, 1e-5, 1e-4]
BATCH_LIST = [256, 512]

# Number of epochs each configuration is trained for in train_one_config().
MAX_EPOCHS = 30
# When True, make_criterion() weights the positive class by neg/pos.
USE_POS_WEIGHT = True
# When True, main() retrains on train+val with the best configuration before testing.
RETRAIN_ON_TRAIN_VAL = True
# Passed as ``version`` to save_model_and_artifacts() and recorded in its config.
SAVE_VERSION = "tuned"


def seed_everything(seed: int = 42) -> None:
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to every generator.
    """
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    # Seed the generators of all visible CUDA devices as well.
    torch.cuda.manual_seed_all(seed)


def get_device() -> str:
    """Pick the device name to train on.

    Returns:
        ``"cuda"`` when CUDA is available, otherwise ``"mps"`` when this
        torch build exposes an available MPS backend, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        chosen = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return chosen


def plot_confusion_matrix_figure(y_true, y_pred, title: str, labels=("non_m2", "m2")):
    """Draw a binary (2x2) confusion matrix as an annotated heat map.

    Args:
        y_true: Ground-truth class ids (0/1); flattened and cast to int.
        y_pred: Predicted class ids (0/1); flattened and cast to int.
        title: Title placed above the axes.
        labels: Tick labels for class 0 and class 1, in that order.

    Returns:
        The matplotlib figure. The caller is responsible for closing it.
    """
    y_true_arr = np.asarray(y_true, dtype=int).reshape(-1)
    y_pred_arr = np.asarray(y_pred, dtype=int).reshape(-1)
    # Pin the class order explicitly. Without ``labels`` the matrix is built
    # only from the classes that actually occur, so single-class input would
    # yield a 1x1 matrix that no longer lines up with the two tick labels.
    cm = confusion_matrix(y_true_arr, y_pred_arr, labels=[0, 1])

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(cm, interpolation="nearest", aspect="equal", cmap="Blues")
    fig.colorbar(im, ax=ax)

    ax.set_title(title)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(list(labels))
    ax.set_yticklabels(list(labels))

    # Cells darker than half the maximum count get white text for contrast.
    thresh = float(cm.max()) / 2.0 if cm.size else 0.0
    for i, j in np.ndindex(*cm.shape):
        ax.text(
            j, i, str(cm[i, j]),
            ha="center", va="center",
            color="white" if float(cm[i, j]) > thresh else "black",
            fontsize=12,
        )

    # Fix the limits so edge cells are not clipped; the y axis is inverted so
    # row 0 (actual class 0) is drawn at the top.
    ax.set_xlim(-0.5, cm.shape[1] - 0.5)
    ax.set_ylim(cm.shape[0] - 0.5, -0.5)
    fig.tight_layout()
    return fig


def topk_threshold(y_prob: np.ndarray, k_pct: int) -> float:
    """Return the lowest score that still belongs to the top ``k_pct`` percent.

    Args:
        y_prob: Scores; flattened and cast to float.
        k_pct: Percentage of items to select. At least one item is always
            selected, and never more than the number of scores.

    Returns:
        The score of the last selected item when sorted in descending order.

    Raises:
        ValueError: If ``y_prob`` contains no scores.
    """
    prob = np.asarray(y_prob, dtype=float).reshape(-1)
    if prob.size == 0:
        raise ValueError("topk_threshold() needs at least one score")

    # Multiply before dividing: ``100 * (29 / 100.0)`` evaluates to
    # 28.999999999999996 and would floor to 28, silently dropping one item.
    n_sel = int(np.floor(prob.size * float(k_pct) / 100.0))
    # Clamp to [1, number of scores] so k_pct > 100 cannot index out of range.
    n_sel = min(max(n_sel, 1), prob.size)

    order = np.argsort(-prob)
    return float(prob[order[n_sel - 1]])


def plot_confusion_topk(y_true, y_prob, k_pct: int, labels=("non_m2", "m2")):
    """Plot the confusion matrix obtained by flagging the top ``k_pct`` percent.

    Every item whose score is at or above the top-k threshold is predicted
    as class 1; the threshold is shown in the figure title.

    Args:
        y_true: Ground-truth class ids.
        y_prob: Predicted scores.
        k_pct: Percentage of highest-scoring items to flag.
        labels: Tick labels forwarded to plot_confusion_matrix_figure().

    Returns:
        The figure produced by plot_confusion_matrix_figure().
    """
    k = int(k_pct)
    scores = np.asarray(y_prob, dtype=float)
    cutoff = topk_threshold(scores, k)
    flagged = (scores >= cutoff).astype(int)
    figure_title = f"Confusion Matrix (Top {k}%, thr={cutoff:.5f})"
    return plot_confusion_matrix_figure(y_true, flagged, title=figure_title, labels=labels)


def recall_at_topk(y_true: np.ndarray, y_prob: np.ndarray, k_pct: int) -> float:
    """Compute the share of all positives found in the top ``k_pct`` percent.

    Args:
        y_true: Binary labels; flattened and cast to int.
        y_prob: Scores aligned with ``y_true``; flattened and cast to float.
        k_pct: Percentage of highest-scoring items to inspect (at least one
            item is always inspected).

    Returns:
        Positives inside the selection divided by the total number of
        positives; 0.0 when there are no positives at all.
    """
    y_true = np.asarray(y_true).astype(int).reshape(-1)
    y_prob = np.asarray(y_prob).astype(float).reshape(-1)
    order = np.argsort(-y_prob)
    # Multiply before dividing: ``100 * (29 / 100.0)`` evaluates to
    # 28.999999999999996 and would floor to 28, shrinking the selection.
    n = max(int(np.floor(len(y_true) * float(k_pct) / 100.0)), 1)
    top_idx = order[:n]
    # The denominator is floored at 1 so an all-negative input yields 0.0.
    return float(y_true[top_idx].sum() / max(y_true.sum(), 1))


@torch.no_grad()
def predict_probs(model: nn.Module, loader: DataLoader, device: str):
    """Run the model over a loader and collect labels and sigmoid scores.

    The model is switched to eval mode and is left in that mode.

    Args:
        model: Network whose output is flattened and treated as one logit per row.
        loader: Yields ``(features, labels)`` batches.
        device: Device the feature batches are moved to before the forward pass.

    Returns:
        ``(y_true, y_prob)``: 1-D integer labels and 1-D float probabilities,
        in loader order.
    """
    model.eval()
    label_chunks = []
    prob_chunks = []
    for features, targets in loader:
        scores = torch.sigmoid(model(features.to(device)).view(-1))
        prob_chunks.append(scores.detach().cpu().numpy().reshape(-1))
        label_chunks.append(targets.detach().cpu().numpy().reshape(-1))
    collected_labels = np.concatenate(label_chunks).astype(int).reshape(-1)
    collected_probs = np.concatenate(prob_chunks).astype(float).reshape(-1)
    return collected_labels, collected_probs


def make_criterion(y_bin: np.ndarray, device: str):
    """Build the binary cross-entropy-with-logits loss used for training.

    Args:
        y_bin: Binary labels used to count positives and negatives.
        device: Device the positive-class weight tensor is moved to.

    Returns:
        A plain ``nn.BCEWithLogitsLoss`` when USE_POS_WEIGHT is off; otherwise
        one whose ``pos_weight`` is ``negatives / max(positives, 1)``.
    """
    if USE_POS_WEIGHT:
        n_pos = int((y_bin == 1).sum())
        n_neg = int((y_bin == 0).sum())
        # Flooring the positive count at 1 avoids a division by zero.
        ratio = n_neg / max(n_pos, 1)
        weight = torch.tensor([ratio], dtype=torch.float32).to(device)
        return nn.BCEWithLogitsLoss(pos_weight=weight)
    return nn.BCEWithLogitsLoss()


def load_merged_dataset():
    """Load the anchors, features and labels tables and build split arrays.

    The three parquet files are inner-joined on ``(user_id, anchor_time)``.
    The binary target is 1 where ``label == "m2"`` and 0 otherwise. Missing
    feature values are replaced with 0.0.

    Returns:
        ``(X_train, y_train, X_val, y_val, X_test, y_test, feature_cols)``:
        the ``X_*`` values are float arrays, the ``y_*`` values are int
        arrays, and ``feature_cols`` lists the feature column names in the
        column order of the ``X_*`` arrays.

    Raises:
        KeyError: If labels lacks ``split``, ``label`` or ``anchor_time``.
        ValueError: If any of the train/val/test splits is empty.
    """
    print("Load parquet files")

    anchors = pd.read_parquet(P.parquet_path("anchors"))
    features = pd.read_parquet(P.parquet_path("features_ml_clean"))
    labels = pd.read_parquet(P.parquet_path("labels"))

    join_keys = ["user_id", "anchor_time"]

    # Normalise the id dtype so the joins below compare like with like.
    for df in (anchors, features, labels):
        df["user_id"] = df["user_id"].astype(str)

    if "split" not in labels.columns:
        raise KeyError(f"labels.parquet에 split 컬럼이 없습니다. labels columns head: {list(labels.columns)[:50]}")
    # Fail early with a clear message instead of a bare KeyError after the merge.
    missing = [c for c in ("anchor_time", "label") if c not in labels.columns]
    if missing:
        raise KeyError(f"labels.parquet is missing required columns: {missing}")
    need_cols = [*join_keys, "label", "split"]

    # Anchors only decide which rows survive the inner join. Restricting them
    # to the key columns keeps a same-named anchor column (including "split")
    # from being suffixed by merge(), which would break the feature selection.
    data = anchors[join_keys].merge(features, on=join_keys, how="inner")
    data = data.merge(labels[need_cols], on=join_keys, how="inner")

    data["target"] = data["label"].astype(str).eq("m2").astype(int)
    split_col = data["split"].astype(str)

    feature_cols = [c for c in features.columns if c not in join_keys]
    X_all = data.loc[:, feature_cols].fillna(0.0).to_numpy(dtype=float)
    y_all = data["target"].to_numpy(dtype=int)

    masks = {name: split_col.eq(name).to_numpy() for name in ("train", "val", "test")}
    counts = {name: int(mask.sum()) for name, mask in masks.items()}

    if min(counts.values()) == 0:
        raise ValueError(
            f"split 분리가 이상해. train/val/test={counts['train']}/{counts['val']}/{counts['test']}"
        )

    X_train, y_train = X_all[masks["train"]], y_all[masks["train"]]
    X_val, y_val = X_all[masks["val"]], y_all[masks["val"]]
    X_test, y_test = X_all[masks["test"]], y_all[masks["test"]]

    print("rows:", len(data), "train/val/test:", counts["train"], counts["val"], counts["test"])
    print("n_features:", int(X_train.shape[1]))

    return X_train, y_train, X_val, y_val, X_test, y_test, feature_cols


def train_one_config(
    X_train: np.ndarray, y_train: np.ndarray,
    X_val: np.ndarray, y_val: np.ndarray,
    lr: float, weight_decay: float, batch_size: int,
    device: str,
):
    """Train one hyper-parameter configuration and keep its best epoch.

    Features are standardised with a scaler fitted on the training split.
    After each of the MAX_EPOCHS epochs the validation split is scored, and
    the epoch with the highest Recall@BEST_K_PCT% is remembered; ties
    (within 1e-12) are broken by the higher PR-AUC.

    Args:
        X_train, y_train: Training features and binary labels.
        X_val, y_val: Validation features and binary labels.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        batch_size: Batch size for both loaders.
        device: Device the model and batches are placed on.

    Returns:
        ``(model, scaler, best)``: the model with the best epoch's weights
        loaded, the fitted scaler, and a dict with keys ``epoch``,
        ``val_recall_k``, ``val_pr_auc`` and ``state_dict``.
    """
    scaler = StandardScaler()
    train_scaled = scaler.fit_transform(X_train)
    val_scaled = scaler.transform(X_val)

    def as_dataset(features, targets):
        # Both features and labels are stored as float32 tensors.
        return TensorDataset(
            torch.tensor(features, dtype=torch.float32),
            torch.tensor(targets, dtype=torch.float32),
        )

    train_loader = DataLoader(as_dataset(train_scaled, y_train), batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(as_dataset(val_scaled, y_val), batch_size=batch_size, shuffle=False, num_workers=0)

    model = MLP_enhance(input_dim=int(train_scaled.shape[1])).to(device)

    criterion = make_criterion(y_train, device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    def run_epoch(loader, train_mode: bool) -> float:
        # One pass over ``loader``; returns the sample-weighted mean loss.
        model.train(train_mode)
        total_loss = 0.0
        seen = 0
        with torch.set_grad_enabled(train_mode):
            for xb, yb in loader:
                inputs = xb.to(device)
                targets = yb.to(device).view(-1)

                if train_mode:
                    optimizer.zero_grad()

                loss = criterion(model(inputs).view(-1), targets)

                if train_mode:
                    loss.backward()
                    optimizer.step()

                n_rows = int(inputs.shape[0])
                total_loss += float(loss.item()) * n_rows
                seen += n_rows

        return float(total_loss / max(seen, 1))

    best = dict(epoch=0, val_recall_k=-1.0, val_pr_auc=-1.0, state_dict=None)

    for epoch in range(1, MAX_EPOCHS + 1):
        tr_loss = run_epoch(train_loader, True)
        va_loss = run_epoch(val_loader, False)

        yv_true, yv_prob = predict_probs(model, val_loader, device)
        val_pr_auc = float(average_precision_score(yv_true, yv_prob))
        val_recall_k = recall_at_topk(yv_true, yv_prob, BEST_K_PCT)

        print(
            f"[lr={lr:g} wd={weight_decay:g} bs={batch_size}] "
            f"epoch {epoch}/{MAX_EPOCHS} "
            f"train_loss={tr_loss:.5f} val_loss={va_loss:.5f} "
            f"val_pr_auc={val_pr_auc:.5f} val_recall_at_{BEST_K_PCT}pct={val_recall_k:.5f}"
        )

        # Primary criterion: recall at top-k. Secondary (on a tie): PR-AUC.
        clearly_higher = val_recall_k > best["val_recall_k"] + 1e-12
        tied = abs(val_recall_k - best["val_recall_k"]) <= 1e-12
        if not (clearly_higher or (tied and val_pr_auc > best["val_pr_auc"])):
            continue

        best["epoch"] = epoch
        best["val_recall_k"] = val_recall_k
        best["val_pr_auc"] = val_pr_auc
        # Snapshot the weights on the CPU so later epochs cannot overwrite them.
        best["state_dict"] = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

    model.load_state_dict(best["state_dict"])
    return model, scaler, best


def retrain_train_val(
    X_train: np.ndarray, y_train: np.ndarray,
    X_val: np.ndarray, y_val: np.ndarray,
    lr: float, weight_decay: float, batch_size: int,
    best_epoch: int,
    device: str,
):
    """Train a fresh model on train+val for a fixed number of epochs.

    The two splits are stacked, a new scaler is fitted on the combined
    features, and a new MLP_enhance is trained for ``best_epoch`` epochs.

    Args:
        X_train, y_train: Training features and binary labels.
        X_val, y_val: Validation features and binary labels.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        batch_size: Batch size of the shuffled loader.
        best_epoch: Number of epochs to train.
        device: Device the model and batches are placed on.

    Returns:
        ``(model, scaler)``: the trained model and the scaler fitted on train+val.
    """
    features_all = np.concatenate((X_train, X_val), axis=0)
    labels_all = np.concatenate((y_train, y_val), axis=0)

    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features_all)

    combined = TensorDataset(
        torch.tensor(features_scaled, dtype=torch.float32),
        torch.tensor(labels_all, dtype=torch.float32),
    )
    loader = DataLoader(combined, batch_size=batch_size, shuffle=True, num_workers=0)

    model = MLP_enhance(input_dim=int(features_scaled.shape[1])).to(device)

    criterion = make_criterion(labels_all, device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    model.train()
    epoch = 1
    while epoch <= best_epoch:
        running_loss = 0.0
        rows_seen = 0
        for xb, yb in loader:
            inputs = xb.to(device)
            targets = yb.to(device).view(-1)

            optimizer.zero_grad()
            loss = criterion(model(inputs).view(-1), targets)
            loss.backward()
            optimizer.step()

            n_rows = int(inputs.shape[0])
            running_loss += float(loss.item()) * n_rows
            rows_seen += n_rows

        # Sample-weighted mean loss of this epoch.
        mean_loss = running_loss / max(rows_seen, 1)
        print(f"retrain epoch {epoch}/{best_epoch} loss={mean_loss:.5f}")
        epoch += 1

    return model, scaler


def evaluate_test_and_save(model, scaler, device, X_test, y_test):
    """Score the test split, build the report figures and save all artifacts.

    Args:
        model: Trained network, scored through predict_probs().
        scaler: Fitted scaler applied to ``X_test`` before scoring.
        device: Device used for inference.
        X_test: Test features.
        y_test: Test binary labels.

    Side effects:
        Prints the test PR-AUC and the entries returned by
        save_model_and_artifacts(). Every figure created here is closed
        before returning, including when plotting or saving raises.
    """
    X_test_s = scaler.transform(X_test)
    test_ds = TensorDataset(torch.tensor(X_test_s, dtype=torch.float32),
                            torch.tensor(y_test, dtype=torch.float32))
    test_loader = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=0)

    y_true, y_prob = predict_probs(model, test_loader, device)

    # NOTE(review): evaluate_churn_metrics is a project helper; it is used
    # here as a mapping that may contain "PR-AUC (Average Precision)".
    metrics = evaluate_churn_metrics(y_true, y_prob)

    # Resolve the PR-AUC once so the printed value and the plot legend agree;
    # fall back to a direct computation when the helper did not provide it.
    pr_auc_val = metrics.get("PR-AUC (Average Precision)")
    if pr_auc_val is None:
        pr_auc_val = average_precision_score(y_true, y_prob)
    pr_auc_val = float(pr_auc_val)
    print("test_pr_auc:", pr_auc_val)

    precision, recall, _ = precision_recall_curve(y_true, y_prob)

    figures = {}
    try:
        fig_pr, ax_pr = plt.subplots(figsize=(6, 5))
        # Register each figure immediately so the finally-block can close it.
        figures["pr_curve"] = fig_pr
        ax_pr.plot(recall, precision, lw=2, label=f"PR-AUC={pr_auc_val:.5f}")
        ax_pr.set_xlabel("Recall")
        ax_pr.set_ylabel("Precision")
        ax_pr.set_title("Precision-Recall Curve")
        ax_pr.grid(alpha=0.3)
        ax_pr.legend()
        fig_pr.tight_layout()

        for k_pct in REPORT_K_LIST:
            figures[f"confusion_matrix_top{k_pct}"] = plot_confusion_topk(
                y_true=y_true,
                y_prob=y_prob,
                k_pct=int(k_pct),
                labels=("non_m2", "m2"),
            )

        saved = save_model_and_artifacts(
            model=model,
            model_name="mlp_enhance",
            model_type="dl",
            model_id="dl__mlp_enhance",
            split="test",
            metrics=metrics,
            y_true=y_true,
            y_prob=y_prob,
            version=SAVE_VERSION,
            scaler=scaler,
            figures=figures,
            config={
                "model_name": "mlp_enhance",
                "model_type": "dl",
                "version": SAVE_VERSION,
                "feature_source": "features_ml_clean.parquet",
                "best_selection": f"val Recall@{BEST_K_PCT}%",
                "use_pos_weight": bool(USE_POS_WEIGHT),
                "note": "hp tuned on train/val, test evaluated once",
            },
        )
    finally:
        # Release figure memory even if plotting or saving failed part-way.
        for fig in figures.values():
            plt.close(fig)

    print("saved keys:", list(saved.keys()))
    for k, v in saved.items():
        print(k, "->", v)


def main():
    """Grid-search the hyper-parameters, build the final model and evaluate it.

    Every (lr, weight decay, batch size) combination is trained with
    train_one_config(). Configurations are ranked by validation recall at
    top-k, with PR-AUC as the tie-breaker. The winner is either retrained on
    train+val (RETRAIN_ON_TRAIN_VAL) or restored from its saved weights, then
    scored on the test split and saved.

    Raises:
        RuntimeError: If no configuration was selected.
    """
    seed_everything(SEED)
    device = get_device()
    print("device:", device)

    X_train, y_train, X_val, y_val, X_test, y_test, _feature_cols = load_merged_dataset()

    best_overall = dict(score=(-1.0, -1.0), hp=None, epoch=0, model_state=None, scaler=None)

    print("tuning start")
    # Same visiting order as three nested loops: lr outermost, batch size innermost.
    grid = [(lr, wd, bs) for lr in LR_LIST for wd in WD_LIST for bs in BATCH_LIST]
    for lr, wd, bs in grid:
        model, scaler, best = train_one_config(
            X_train, y_train, X_val, y_val,
            lr=lr, weight_decay=wd, batch_size=bs,
            device=device,
        )

        cand_recall, cand_pr_auc = best["val_recall_k"], best["val_pr_auc"]
        top_recall, top_pr_auc = best_overall["score"]
        clearly_higher = cand_recall > top_recall + 1e-12
        tied = abs(cand_recall - top_recall) <= 1e-12
        if not (clearly_higher or (tied and cand_pr_auc > top_pr_auc)):
            continue

        best_overall["score"] = (cand_recall, cand_pr_auc)
        best_overall["hp"] = {"lr": lr, "weight_decay": wd, "batch_size": bs}
        best_overall["epoch"] = int(best["epoch"])
        # CPU copy of the winning weights, independent of the live model.
        best_overall["model_state"] = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }
        best_overall["scaler"] = scaler

    for key in ("score", "hp", "epoch"):
        print(f"best_{key}:", best_overall[key])

    hp = best_overall["hp"]
    if hp is None:
        raise RuntimeError("No best hyperparameters found")

    if not RETRAIN_ON_TRAIN_VAL:
        # Reuse the tuned weights and the scaler fitted on the training split.
        final_model = MLP_enhance(input_dim=int(X_train.shape[1])).to(device)
        final_model.load_state_dict(best_overall["model_state"])
        final_model.eval()
        final_scaler = best_overall["scaler"]
    else:
        final_model, final_scaler = retrain_train_val(
            X_train, y_train, X_val, y_val,
            lr=float(hp["lr"]),
            weight_decay=float(hp["weight_decay"]),
            batch_size=int(hp["batch_size"]),
            best_epoch=int(best_overall["epoch"]),
            device=device,
        )

    print("test_eval_start")
    evaluate_test_and_save(final_model, final_scaler, device, X_test, y_test)
    print("Done")


if __name__ == "__main__":
    main()

device: mps
Load parquet files
rows: 813540 train/val/test: 574092 137615 101833
n_features: 14
tuning start
[lr=0.0003 wd=0 bs=256] epoch 1/30 train_loss=0.25078 val_loss=0.22558 val_pr_auc=0.91385 val_recall_at_10pct=0.11190
[lr=0.0003 wd=0 bs=256] epoch 2/30 train_loss=0.24858 val_loss=0.22514 val_pr_auc=0.91400 val_recall_at_10pct=0.11197
[lr=0.0003 wd=0 bs=256] epoch 3/30 train_loss=0.24824 val_loss=0.22539 val_pr_auc=0.91399 val_recall_at_10pct=0.11194
[lr=0.0003 wd=0 bs=256] epoch 4/30 train_loss=0.24810 val_loss=0.22513 val_pr_auc=0.91433 val_recall_at_10pct=0.11201
[lr=0.0003 wd=0 bs=256] epoch 5/30 train_loss=0.24796 val_loss=0.22620 val_pr_auc=0.91419 val_recall_at_10pct=0.11199
[lr=0.0003 wd=0 bs=256] epoch 6/30 train_loss=0.24782 val_loss=0.22493 val_pr_auc=0.91444 val_recall_at_10pct=0.11209
[lr=0.0003 wd=0 bs=256] epoch 7/30 train_loss=0.24779 val_loss=0.22452 val_pr_auc=0.91449 val_recall_at_10pct=0.11212
[lr=0.0003 wd=0 bs=256] epoch 8/30 train_loss=0.24774 val_loss=0.