In [1]:
# Standard Library
from abc import ABC, abstractmethod

# Third Party
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import average_precision_score, log_loss, roc_auc_score
from torch.utils.data import DataLoader, TensorDataset


# -----------------------------------------------------------
# Base Neural Network Model
# -----------------------------------------------------------
class BaseNNModel(nn.Module, ABC):
    @abstractmethod
    def __init__(self, **kwargs):
        super(BaseNNModel, self).__init__()

    @abstractmethod
    def build_network(self) -> nn.Module:
        pass

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


# -----------------------------------------------------------
# MLP Model
# -----------------------------------------------------------
class MLPModel(BaseNNModel):
    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        cat_features: list[int] | None = None,
        cat_dims: list[int] | None = None,
        emb_dim: int = 8,
    ):
        super(MLPModel, self).__init__()

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.cat_features = cat_features or []
        self.cat_dims = cat_dims or []
        self.emb_dim = emb_dim

        self.embeddings = None
        self.network = self.build_network()

    def build_network(self) -> nn.Sequential:
        # categorical embedding layer
        if len(self.cat_dims) > 0:
            self.embeddings = nn.ModuleList(
                [nn.Embedding(cat_dim, self.emb_dim) for cat_dim in self.cat_dims]
            )

        combined_input_dim = (
            self.input_dim - len(self.cat_features)
            + len(self.cat_features) * self.emb_dim
        )

        layers = []
        dims = [combined_input_dim] + self.hidden_dims

        for i in range(len(dims) - 1):
            hidden_layer = nn.Linear(dims[i], dims[i + 1])
            nn.init.kaiming_normal_(
                hidden_layer.weight, mode="fan_in", nonlinearity="relu"
            )

            layers.append(hidden_layer)
            layers.append(nn.BatchNorm1d(dims[i + 1]))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))

        # output layer (logit)
        layers.append(nn.Linear(dims[-1], 1))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        continuous_features = []
        embedded_features = []
        cat_idx = 0

        for i in range(self.input_dim):
            if i in self.cat_features:
                cat_values = x[:, i].long()
                embedded = self.embeddings[cat_idx](cat_values)
                embedded_features.append(embedded)
                cat_idx += 1
            else:
                continuous_features.append(x[:, i : i + 1])

        if continuous_features:
            continuous_features = torch.cat(continuous_features, dim=1)
        else:
            continuous_features = None

        if embedded_features:
            embedded_features = torch.cat(embedded_features, dim=1)
        else:
            embedded_features = None

        if continuous_features is not None and embedded_features is not None:
            combined_features = torch.cat(
                [continuous_features, embedded_features], dim=1
            )
        elif embedded_features is not None:
            combined_features = embedded_features
        elif continuous_features is not None:
            combined_features = continuous_features
        else:
            raise ValueError("No features found for forward pass.")

        logits = self.network(combined_features)
        return logits  # (B, 1)


# -----------------------------------------------------------
# TabTransformer Model (단일 클래스, BaseNNModel 상속)
# -----------------------------------------------------------
class TabTransformerModel(BaseNNModel):
    """
    DeepLearningBinaryClassifier 에서 사용할 단일 클래스 TabTransformer.
    - input_dim: 전체 feature 개수
    - cat_features: 범주형 feature 인덱스 리스트 (X의 column index)
    - cat_dims: 각 범주형 feature의 cardinality (고유값 개수)
    - 나머지 파라미터는 TabTransformer 구조 하이퍼파라미터
    """

    def __init__(
        self,
        input_dim: int,
        cat_features: list[int] | None = None,
        cat_dims: list[int] | None = None,
        d_token: int = 32,
        n_heads: int = 4,
        n_layers: int = 2,
        dim_feedforward: int = 128,
        attn_dropout: float = 0.1,
        embedding_dropout: float = 0.1,
        add_cls: bool = False,
        pooling: str = "concat",  # "concat" or "cls"
        cont_proj: str = "linear",  # "none" or "linear"
        mlp_hidden_dims: tuple[int, ...] = (128, 64),
        mlp_dropout: float = 0.2,
        padding_idx: int = 0,  # 0 을 OOV/pad로 예약한 경우
    ):
        super(TabTransformerModel, self).__init__()

        self.input_dim = input_dim
        self.cat_features = cat_features or []
        self.cat_dims = cat_dims or []
        self.d_token = d_token
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dim_feedforward = dim_feedforward
        self.attn_dropout = attn_dropout
        self.embedding_dropout = embedding_dropout
        self.add_cls = add_cls
        self.pooling = pooling
        self.cont_proj = cont_proj
        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_dropout = mlp_dropout
        self.padding_idx = padding_idx

        assert len(self.cat_features) == len(
            self.cat_dims
        ), "cat_features 개수와 cat_dims 개수가 다릅니다."

        # 연속형 feature 개수
        self.n_cat = len(self.cat_features)
        self.n_cont = self.input_dim - self.n_cat

        # ====== 모듈 구성 (Categorical path) ======
        if self.n_cat == 0:
            self.cat_embeddings = nn.ModuleList()
            self.col_embedding = None
        else:
            self.cat_embeddings = nn.ModuleList(
                [
                    nn.Embedding(
                        num_embeddings=c + (1 if padding_idx is not None else 0),
                        embedding_dim=d_token,
                        padding_idx=padding_idx if padding_idx is not None else None,
                    )
                    for c in self.cat_dims
                ]
            )
            self.col_embedding = nn.Embedding(self.n_cat, d_token)

        if self.add_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, d_token))
            nn.init.normal_(self.cls_token, std=0.02)

        self.embedding_dropout = nn.Dropout(self.embedding_dropout)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_token,
            nhead=self.n_heads,
            dim_feedforward=self.dim_feedforward,
            dropout=self.attn_dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=self.n_layers)

        # cat embedding 초기화
        for emb in self.cat_embeddings:
            nn.init.normal_(emb.weight, std=0.02)
        if self.col_embedding is not None:
            nn.init.normal_(self.col_embedding.weight, std=0.02)

        # ====== 연속형 path ======
        if self.n_cont > 0:
            self.cont_bn = nn.BatchNorm1d(self.n_cont)
            if self.cont_proj == "linear":
                self.cont_linear = nn.Linear(self.n_cont, d_token)
                nn.init.kaiming_uniform_(
                    self.cont_linear.weight, mode="fan_in", nonlinearity="relu"
                )
                cont_out_dim = d_token
            else:
                self.cont_linear = nn.Identity()
                cont_out_dim = self.n_cont
        else:
            self.cont_bn = None
            self.cont_linear = None
            cont_out_dim = 0

        # ====== Head (logit 출력) ======
        backbone_out = (
            self.d_token if self.pooling == "cls" else self.n_cat * self.d_token
        )
        in_dim = backbone_out + cont_out_dim

        layers = []
        prev = in_dim
        for h in self.mlp_hidden_dims:
            lin = nn.Linear(prev, h)
            nn.init.kaiming_uniform_(
                lin.weight, mode="fan_in", nonlinearity="relu"
            )
            layers.extend(
                [lin, nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(self.mlp_dropout)]
            )
            prev = h
        layers.append(nn.Linear(prev, 1))  # ★ Sigmoid 없음: BCEWithLogitsLoss와 맞춤
        self.head = nn.Sequential(*layers)

    # BaseNNModel 요구사항: build_network 구현 (여기서는 self 반환)
    def build_network(self) -> nn.Module:
        return self

    def _encode_categoricals(self, x_cat: torch.LongTensor) -> torch.Tensor:
        """
        x_cat: (B, n_cat) -> contextualized embedding
        """
        B = x_cat.size(0)
        if self.n_cat == 0:
            return torch.zeros(
                B,
                self.d_token if self.pooling == "cls" else 0,
                device=x_cat.device,
                dtype=torch.float32,
            )

        tok_list = []
        for j, emb in enumerate(self.cat_embeddings):
            tok = emb(x_cat[:, j])  # (B, d)
            if self.col_embedding is not None:
                tok = tok + self.col_embedding.weight[j]  # (d,)
            tok_list.append(tok.unsqueeze(1))  # (B, 1, d)

        x_tok = torch.cat(tok_list, dim=1)  # (B, n_cat, d)

        if self.add_cls:
            cls = self.cls_token.expand(B, -1, -1)  # (B, 1, d)
            x_tok = torch.cat([cls, x_tok], dim=1)  # (B, 1 + n_cat, d)

        x_tok = self.embedding_dropout(x_tok)
        z = self.transformer(x_tok)  # (B, T, d)

        if self.pooling == "cls" and self.add_cls:
            out = z[:, 0, :]  # (B, d)
        elif self.pooling == "cls":
            out = z.mean(dim=1)  # (B, d)
        else:
            if self.add_cls:
                z = z[:, 1:, :]  # CLS 제거
            out = z.reshape(B, -1)  # (B, n_cat*d)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim)
        - self.cat_features 인덱스에 해당하는 column은 categorical, 나머지는 continuous로 사용.
        - 출력: (B, 1) logits (sigmoid 전)
        """
        B, D = x.shape
        device = x.device

        # ----- categorical split -----
        if self.n_cat > 0:
            cat_idx_tensor = torch.tensor(self.cat_features, device=device)
            x_cat = x[:, cat_idx_tensor].long()
        else:
            x_cat = torch.zeros(B, 0, dtype=torch.long, device=device)

        # ----- continuous split -----
        if self.n_cont > 0:
            cont_idx = [i for i in range(self.input_dim) if i not in self.cat_features]
            cont_idx_tensor = torch.tensor(cont_idx, device=device)
            x_cont = x[:, cont_idx_tensor].float()
        else:
            x_cont = None

        # categorical path
        z_cat = self._encode_categoricals(x_cat)

        # continuous path
        if (x_cont is not None) and (self.n_cont > 0):
            if x_cont.ndim == 1:
                x_cont = x_cont.unsqueeze(1)
            x_cont = self.cont_bn(x_cont)
            x_cont = self.cont_linear(x_cont)
            z = torch.cat([z_cat, x_cont], dim=1)
        else:
            z = z_cat

        logits = self.head(z)  # (B, 1)
        return logits


# -----------------------------------------------------------
# Deep Learning Binary Classifier
# -----------------------------------------------------------
class DeepLearningBinaryClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        model_type: str = "mlp",
        model_params: dict | None = None,
    ):
        self.model_type = model_type
        self.model_params = model_params or {}
        self.model = None

    @property
    def device(self) -> torch.device:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _build_model(self) -> BaseNNModel:
        model_registry = {
            "mlp": MLPModel,
            "tabtransformer": TabTransformerModel,
        }

        if self.model_type not in model_registry:
            raise ValueError(f"Unknown model type: {self.model_type}")

        valid_params = {
            k: v for k, v in self.model_params.items() if k not in ["loss_fn", "lr"]
        }

        model_class = model_registry[self.model_type](**valid_params)
        return model_class

    def _get_loss_fn(self) -> nn.Module:
        loss_name = self.model_params.get("loss_fn", "logloss")
        if loss_name == "logloss":
            return nn.BCEWithLogitsLoss(reduction="none")
        else:
            raise ValueError(f"Unknown loss function: {loss_name}")

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        sample_weight: np.ndarray | None = None,
        eval_set: list[tuple[np.ndarray, np.ndarray]] | None = None,
        eval_metric: list[str] | None = None,
        max_epochs: int = 10,
        patience: int | None = None,
        batch_size: int = 128,
        verbose: bool = True,
    ) -> "DeepLearningBinaryClassifier":

        lr = self.model_params.get("lr", 0.001)
        eval_metric = eval_metric or ["logloss"]

        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
        y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1).to(self.device)

        if eval_set is not None:
            x_eval_tensor = torch.tensor(
                eval_set[0][0], dtype=torch.float32
            ).to(self.device)
            y_eval_true = eval_set[0][1]
        else:
            x_eval_tensor = None
            y_eval_true = None

        if sample_weight is not None:
            sample_weight_tensor = torch.tensor(
                sample_weight, dtype=torch.float32
            ).to(self.device)
        else:
            sample_weight_tensor = torch.ones_like(y_tensor, dtype=torch.float32)

        train_dataset = TensorDataset(X_tensor, y_tensor, sample_weight_tensor)
        train_dataloader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )

        if self.model is None:
            self.model = self._build_model().to(self.device)

        loss_fn = self._get_loss_fn()
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)

        patience_counter = 0
        best_metric = float("inf")
        best_model_weights = None

        for epoch in range(max_epochs):
            self.model.train()
            epoch_loss = 0.0
            n_batches = 0

            for x_batch, y_batch, weight_batch in train_dataloader:
                optimizer.zero_grad()

                y_pred_logits = self.model(x_batch)
                loss = loss_fn(y_pred_logits, y_batch)
                weighted_loss = (loss * weight_batch).sum() / weight_batch.sum()

                weighted_loss.backward()
                optimizer.step()

                epoch_loss += weighted_loss.item()
                n_batches += 1

            if verbose:
                print(
                    f"Epoch {epoch + 1}/{max_epochs} "
                    f"- [train] loss: {epoch_loss / max(1, n_batches):.6f}"
                )

            # evaluation
            if eval_set is not None:
                self.model.eval()
                with torch.no_grad():
                    y_eval_logits = self.model(x_eval_tensor)
                    y_eval_pred = torch.sigmoid(y_eval_logits).cpu().numpy().ravel()

                eval_metrics = {}
                for metric in eval_metric:
                    if metric == "logloss":
                        eval_metrics["logloss"] = log_loss(y_eval_true, y_eval_pred)
                    elif metric == "average_precision":
                        eval_metrics["average_precision"] = -average_precision_score(
                            y_eval_true, y_eval_pred
                        )
                    elif metric == "auc":
                        eval_metrics["auc"] = -roc_auc_score(y_eval_true, y_eval_pred)
                    else:
                        raise ValueError(f"Unknown metric: {metric}")

                if verbose:
                    metrics_str = ", ".join(
                        [f"{k}: {v:.4f}" for k, v in eval_metrics.items()]
                    )
                    print(f"  - [eval] {metrics_str}")

                # early stopping (기준 metric은 리스트의 첫 번째)
                main_metric_name = eval_metric[0]
                current_metric = eval_metrics.get(
                    main_metric_name, eval_metrics["logloss"]
                )

                if verbose:
                    print(
                        f"    -- (early_stopping) current_metric: {current_metric:.6f}, "
                        f"best_metric: {best_metric:.6f}"
                    )

                if current_metric < best_metric:
                    best_metric = current_metric
                    patience_counter = 0
                    best_model_weights = self.model.state_dict()
                else:
                    patience_counter += 1
                    if patience is not None and patience_counter >= patience:
                        if verbose:
                            print(f"Early stopping at epoch {epoch + 1}")
                        break

        if best_model_weights is not None:
            self.model.load_state_dict(best_model_weights)

        return self

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        self.model = self.model.to(self.device)
        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            self.model.eval()
            logits = self.model(X_tensor)
            probs1 = torch.sigmoid(logits).cpu().numpy()

        if probs1.shape[1] == 1:
            probs1 = probs1.reshape(-1, 1)

        probs0 = 1.0 - probs1
        probs = np.hstack((probs0, probs1))
        return probs.astype("float")

    def predict(self, X: np.ndarray) -> np.ndarray:
        probs = self.predict_proba(X)
        return probs.argmax(axis=1)


# -----------------------------------------------------------
# 한국형 피처 스키마(연령/성별/OS/점수형) 데모 데이터 생성
# -----------------------------------------------------------
def make_kor_feature_demo_data(
    n_samples: int = 5000,
    seed: int = 42,
):
    rng = np.random.RandomState(seed)
    n = n_samples

    # --- 1) 각 피처 생성 ---
    # 연령: 18~80 사이 normal 분포
    age = np.clip(rng.normal(loc=40, scale=12, size=n), 18, 80).astype("float32")

    # 성별: 3개 카테고리 (0: 미상, 1: 남, 2: 여)
    gender = rng.randint(0, 3, size=n).astype("int64")

    # 앱가입건수: 포아송
    app_join_cnt = rng.poisson(lam=2.0, size=n).astype("float32")

    # 앱마지막접속경과일: 지수분포 (최근일수록 작게)
    last_login_days = rng.exponential(scale=10.0, size=n).astype("float32")

    # OS: 4개 카테고리 (0: Android, 1: iOS, 2: Web, 3: 기타)
    os_cat = rng.randint(0, 4, size=n).astype("int64")

    # 앱가입경과일: 평균 200일 정도로
    app_join_days = rng.exponential(scale=200.0, size=n).astype("float32")

    # 점수형(0~1 근처의 score 형태로)
    biz_owner_score = (rng.binomial(1, 0.3, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")
    home_owner_score = (rng.binomial(1, 0.4, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")
    car_owner_score = (rng.binomial(1, 0.5, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")
    married_score   = (rng.binomial(1, 0.5, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")
    child_score     = (rng.binomial(1, 0.4, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")

    # --- 2) 라벨 생성용 latent score (대충 그럴듯하게) ---
    # 범주형 가중치
    w_gender = rng.uniform(-0.3, 0.3, size=3)   # gender 0/1/2
    w_os     = rng.uniform(-0.2, 0.2, size=4)   # os 0~3

    score = np.zeros(n, dtype="float32")

    # 연속형/점수형 효과
    score += 0.05 * (age - 40) / 10.0          # 나이가 40보다 많으면 +, 적으면 -
    score += 0.25 * app_join_cnt              # 가입건수 많을수록 +
    score -= 0.03 * last_login_days           # 마지막 접속이 오래되면 -
    score -= 0.002 * app_join_days            # 가입한 지 너무 오래돼도 -
    score += 0.8 * biz_owner_score            # 사업자일수록 +
    score += 0.6 * home_owner_score           # 자가주택 +
    score += 0.7 * car_owner_score            # 자동차 보유 +
    score += 0.9 * married_score              # 결혼 +
    score += 1.0 * child_score                # 자녀 보유 +

    # 범주형 효과
    score += w_gender[gender]
    score += w_os[os_cat]

    # 노이즈 + bias
    noise = rng.normal(scale=0.5, size=n).astype("float32")
    bias = -0.2
    logit = score + bias + noise
    prob = 1.0 / (1.0 + np.exp(-logit))
    y = (prob > 0.5).astype("int64")

    # --- 3) 최종 X 매트릭스 (TabTransformer는 float32로 받고, 범주형은 내부에서 long으로 캐스팅) ---
    X = np.column_stack(
        [
            age,                    # 0
            gender.astype("float32"),  # 1
            app_join_cnt,           # 2
            last_login_days,        # 3
            os_cat.astype("float32"),  # 4
            app_join_days,          # 5
            biz_owner_score,        # 6
            home_owner_score,       # 7
            car_owner_score,        # 8
            married_score,          # 9
            child_score,            # 10
        ]
    ).astype("float32")

    # TabTransformer용 범주형 인덱스 및 cardinality
    cat_feature_indices = [1, 4]   # 성별, OS
    cat_dims = [3, 4]              # gender:3, os:4

    return X, y, cat_feature_indices, cat_dims


def demo_train_tabtransformer():
    print("\n===== TabTransformer Demo (KOR Feature Schema) =====")

    X, y, cat_features, cat_dims = make_kor_feature_demo_data(
        n_samples=6000,
        seed=123,
    )

    print("X shape:", X.shape, " | y shape:", y.shape)
    print("Categorical feature indices:", cat_features)
    print("Categorical dims (cardinality):", cat_dims)

    # train / val / test split
    N = X.shape[0]
    idx = np.arange(N)
    np.random.shuffle(idx)

    tr_end = int(N * 0.7)
    va_end = int(N * 0.85)
    tr_idx, va_idx, te_idx = idx[:tr_end], idx[tr_end:va_end], idx[va_end:]

    X_tr, y_tr = X[tr_idx], y[tr_idx]
    X_va, y_va = X[va_idx], y[va_idx]
    X_te, y_te = X[te_idx], y[te_idx]

    # 간단히 균일한 sample_weight
    sample_weight = np.ones_like(y_tr, dtype="float32")

    input_dim = X.shape[1]

    model_params = {
        "input_dim": input_dim,
        "cat_features": cat_features,  # [1,4]  성별/OS
        "cat_dims": cat_dims,          # [3,4]
        "d_token": 32,
        "n_heads": 4,
        "n_layers": 2,
        "dim_feedforward": 128,
        "attn_dropout": 0.1,
        "embedding_dropout": 0.05,
        "add_cls": False,
        "pooling": "concat",
        "cont_proj": "linear",
        "mlp_hidden_dims": (128, 64),
        "mlp_dropout": 0.2,
        "padding_idx": 0,
        "lr": 1e-3,
        "loss_fn": "logloss",
    }

    clf = DeepLearningBinaryClassifier(
        model_type="tabtransformer",
        model_params=model_params,
    )

    clf.fit(
        X_tr,
        y_tr,
        sample_weight=sample_weight,
        eval_set=[(X_va, y_va)],
        eval_metric=["logloss"],
        max_epochs=15,
        patience=3,
        batch_size=256,
        verbose=True,
    )

    # 평가
    probs_te = clf.predict_proba(X_te)[:, 1]
    preds_te = (probs_te >= 0.5).astype("int64")

    acc = (preds_te == y_te).mean()
    auc = roc_auc_score(y_te, probs_te)
    ll = log_loss(y_te, probs_te)

    print("\n===== Test Metrics =====")
    print(f"Accuracy : {acc:.4f}")
    print(f"ROC-AUC  : {auc:.4f}")
    print(f"Logloss  : {ll:.4f}")
    print("Sample probs (first 10):", np.round(probs_te[:10], 4))


if __name__ == "__main__":
    demo_train_tabtransformer()


===== TabTransformer Demo (KOR Feature Schema) =====
X shape: (6000, 11)  | y shape: (6000,)
Categorical feature indices: [1, 4]
Categorical dims (cardinality): [3, 4]




Epoch 1/15 - [train] loss: 138.271694
  - [eval] logloss: 0.4680
    -- (early_stopping) current_metric: 0.467974, best_metric: inf
Epoch 2/15 - [train] loss: 102.804705
  - [eval] logloss: 0.3810
    -- (early_stopping) current_metric: 0.380984, best_metric: 0.467974
Epoch 3/15 - [train] loss: 87.605172
  - [eval] logloss: 0.3257
    -- (early_stopping) current_metric: 0.325656, best_metric: 0.380984
Epoch 4/15 - [train] loss: 78.494526
  - [eval] logloss: 0.3000
    -- (early_stopping) current_metric: 0.299957, best_metric: 0.325656
Epoch 5/15 - [train] loss: 74.516504
  - [eval] logloss: 0.2966
    -- (early_stopping) current_metric: 0.296625, best_metric: 0.299957
Epoch 6/15 - [train] loss: 71.521301
  - [eval] logloss: 0.2932
    -- (early_stopping) current_metric: 0.293182, best_metric: 0.296625
Epoch 7/15 - [train] loss: 71.191573
  - [eval] logloss: 0.2861
    -- (early_stopping) current_metric: 0.286077, best_metric: 0.293182
Epoch 8/15 - [train] loss: 70.055059
  - [eval] log

In [None]:
# -*- coding: utf-8 -*-
"""
단일 파일 테스트용 코드:
- BaseNNModel
- MLPModel
- TabTransformerModel (BaseNNModel 상속, 하나의 클래스에 TabTransformer 구조 포함)
- DeepLearningBinaryClassifier (MLP / TabTransformer 공통 학습기)
- 더미 데이터셋 생성 + 학습/평가 데모
"""

# Standard Library
from abc import ABC, abstractmethod

# Third Party
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import average_precision_score, log_loss, roc_auc_score
from torch.utils.data import DataLoader, TensorDataset


# -----------------------------------------------------------
# Base Neural Network Model
# -----------------------------------------------------------
class BaseNNModel(nn.Module, ABC):
    @abstractmethod
    def __init__(self, **kwargs):
        super(BaseNNModel, self).__init__()

    @abstractmethod
    def build_network(self) -> nn.Module:
        pass

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


# -----------------------------------------------------------
# MLP Model
# -----------------------------------------------------------
class MLPModel(BaseNNModel):
    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        cat_features: list[int] | None = None,
        cat_dims: list[int] | None = None,
        emb_dim: int = 8,
    ):
        super(MLPModel, self).__init__()

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.cat_features = cat_features or []
        self.cat_dims = cat_dims or []
        self.emb_dim = emb_dim

        self.embeddings = None
        self.network = self.build_network()

    def build_network(self) -> nn.Sequential:
        # categorical embedding layer
        if len(self.cat_dims) > 0:
            self.embeddings = nn.ModuleList(
                [nn.Embedding(cat_dim, self.emb_dim) for cat_dim in self.cat_dims]
            )

        combined_input_dim = (
            self.input_dim - len(self.cat_features)
            + len(self.cat_features) * self.emb_dim
        )

        layers = []
        dims = [combined_input_dim] + self.hidden_dims

        for i in range(len(dims) - 1):
            hidden_layer = nn.Linear(dims[i], dims[i + 1])
            nn.init.kaiming_normal_(
                hidden_layer.weight, mode="fan_in", nonlinearity="relu"
            )

            layers.append(hidden_layer)
            layers.append(nn.BatchNorm1d(dims[i + 1]))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))

        # output layer (logit)
        layers.append(nn.Linear(dims[-1], 1))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        continuous_features = []
        embedded_features = []
        cat_idx = 0

        for i in range(self.input_dim):
            if i in self.cat_features:
                cat_values = x[:, i].long()
                embedded = self.embeddings[cat_idx](cat_values)
                embedded_features.append(embedded)
                cat_idx += 1
            else:
                continuous_features.append(x[:, i : i + 1])

        if continuous_features:
            continuous_features = torch.cat(continuous_features, dim=1)
        else:
            continuous_features = None

        if embedded_features:
            embedded_features = torch.cat(embedded_features, dim=1)
        else:
            embedded_features = None

        if continuous_features is not None and embedded_features is not None:
            combined_features = torch.cat(
                [continuous_features, embedded_features], dim=1
            )
        elif embedded_features is not None:
            combined_features = embedded_features
        elif continuous_features is not None:
            combined_features = continuous_features
        else:
            raise ValueError("No features found for forward pass.")

        logits = self.network(combined_features)
        return logits  # (B, 1)


# -----------------------------------------------------------
# TabTransformer Model (단일 클래스, BaseNNModel 상속)
# -----------------------------------------------------------
class TabTransformerModel(BaseNNModel):
    """
    DeepLearningBinaryClassifier 에서 사용할 단일 클래스 TabTransformer.
    - input_dim: 전체 feature 개수
    - cat_features: 범주형 feature 인덱스 리스트 (X의 column index)
    - cat_dims: 각 범주형 feature의 cardinality (고유값 개수)
    - 나머지 파라미터는 TabTransformer 구조 하이퍼파라미터
    """

    def __init__(
        self,
        input_dim: int,
        cat_features: list[int] | None = None,
        cat_dims: list[int] | None = None,
        d_token: int = 32,
        n_heads: int = 4,
        n_layers: int = 2,
        dim_feedforward: int = 128,
        attn_dropout: float = 0.1,
        embedding_dropout: float = 0.1,
        add_cls: bool = False,
        pooling: str = "concat",  # "concat" or "cls"
        cont_proj: str = "linear",  # "none" or "linear"
        mlp_hidden_dims: tuple[int, ...] = (128, 64),
        mlp_dropout: float = 0.2,
        padding_idx: int = 0,  # 0 을 OOV/pad로 예약한 경우
    ):
        super(TabTransformerModel, self).__init__()

        self.input_dim = input_dim
        self.cat_features = cat_features or []
        self.cat_dims = cat_dims or []
        self.d_token = d_token
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dim_feedforward = dim_feedforward
        self.attn_dropout = attn_dropout
        self.embedding_dropout = embedding_dropout
        self.add_cls = add_cls
        self.pooling = pooling
        self.cont_proj = cont_proj
        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_dropout = mlp_dropout
        self.padding_idx = padding_idx

        assert len(self.cat_features) == len(
            self.cat_dims
        ), "cat_features 개수와 cat_dims 개수가 다릅니다."

        # 연속형 feature 개수
        self.n_cat = len(self.cat_features)
        self.n_cont = self.input_dim - self.n_cat

        # ====== 모듈 구성 (Categorical path) ======
        if self.n_cat == 0:
            self.cat_embeddings = nn.ModuleList()
            self.col_embedding = None
        else:
            self.cat_embeddings = nn.ModuleList(
                [
                    nn.Embedding(
                        num_embeddings=c + (1 if padding_idx is not None else 0),
                        embedding_dim=d_token,
                        padding_idx=padding_idx if padding_idx is not None else None,
                    )
                    for c in self.cat_dims
                ]
            )
            self.col_embedding = nn.Embedding(self.n_cat, d_token)

        if self.add_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, d_token))
            nn.init.normal_(self.cls_token, std=0.02)

        self.embedding_dropout = nn.Dropout(self.embedding_dropout)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_token,
            nhead=self.n_heads,
            dim_feedforward=self.dim_feedforward,
            dropout=self.attn_dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=self.n_layers)

        # cat embedding 초기화
        for emb in self.cat_embeddings:
            nn.init.normal_(emb.weight, std=0.02)
        if self.col_embedding is not None:
            nn.init.normal_(self.col_embedding.weight, std=0.02)

        # ====== 연속형 path ======
        if self.n_cont > 0:
            self.cont_bn = nn.BatchNorm1d(self.n_cont)
            if self.cont_proj == "linear":
                self.cont_linear = nn.Linear(self.n_cont, d_token)
                nn.init.kaiming_uniform_(
                    self.cont_linear.weight, mode="fan_in", nonlinearity="relu"
                )
                cont_out_dim = d_token
            else:
                self.cont_linear = nn.Identity()
                cont_out_dim = self.n_cont
        else:
            self.cont_bn = None
            self.cont_linear = None
            cont_out_dim = 0

        # ====== Head (logit 출력) ======
        backbone_out = (
            self.d_token if self.pooling == "cls" else self.n_cat * self.d_token
        )
        in_dim = backbone_out + cont_out_dim

        layers = []
        prev = in_dim
        for h in self.mlp_hidden_dims:
            lin = nn.Linear(prev, h)
            nn.init.kaiming_uniform_(
                lin.weight, mode="fan_in", nonlinearity="relu"
            )
            layers.extend(
                [lin, nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(self.mlp_dropout)]
            )
            prev = h
        layers.append(nn.Linear(prev, 1))  # ★ Sigmoid 없음: BCEWithLogitsLoss와 맞춤
        self.head = nn.Sequential(*layers)

    # BaseNNModel 요구사항: build_network 구현 (여기서는 self 반환)
    def build_network(self) -> nn.Module:
        return self

    def _encode_categoricals(self, x_cat: torch.LongTensor) -> torch.Tensor:
        """
        x_cat: (B, n_cat) -> contextualized embedding
        """
        B = x_cat.size(0)
        if self.n_cat == 0:
            return torch.zeros(
                B,
                self.d_token if self.pooling == "cls" else 0,
                device=x_cat.device,
                dtype=torch.float32,
            )

        tok_list = []
        for j, emb in enumerate(self.cat_embeddings):
            tok = emb(x_cat[:, j])  # (B, d)
            if self.col_embedding is not None:
                tok = tok + self.col_embedding.weight[j]  # (d,)
            tok_list.append(tok.unsqueeze(1))  # (B, 1, d)

        x_tok = torch.cat(tok_list, dim=1)  # (B, n_cat, d)

        if self.add_cls:
            cls = self.cls_token.expand(B, -1, -1)  # (B, 1, d)
            x_tok = torch.cat([cls, x_tok], dim=1)  # (B, 1 + n_cat, d)

        x_tok = self.embedding_dropout(x_tok)
        z = self.transformer(x_tok)  # (B, T, d)

        if self.pooling == "cls" and self.add_cls:
            out = z[:, 0, :]  # (B, d)
        elif self.pooling == "cls":
            out = z.mean(dim=1)  # (B, d)
        else:
            if self.add_cls:
                z = z[:, 1:, :]  # CLS 제거
            out = z.reshape(B, -1)  # (B, n_cat*d)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim)
        - self.cat_features 인덱스에 해당하는 column은 categorical, 나머지는 continuous로 사용.
        - 출력: (B, 1) logits (sigmoid 전)
        """
        B, D = x.shape
        device = x.device

        # ----- categorical split -----
        if self.n_cat > 0:
            cat_idx_tensor = torch.tensor(self.cat_features, device=device)
            x_cat = x[:, cat_idx_tensor].long()
        else:
            x_cat = torch.zeros(B, 0, dtype=torch.long, device=device)

        # ----- continuous split -----
        if self.n_cont > 0:
            cont_idx = [i for i in range(self.input_dim) if i not in self.cat_features]
            cont_idx_tensor = torch.tensor(cont_idx, device=device)
            x_cont = x[:, cont_idx_tensor].float()
        else:
            x_cont = None

        # categorical path
        z_cat = self._encode_categoricals(x_cat)

        # continuous path
        if (x_cont is not None) and (self.n_cont > 0):
            if x_cont.ndim == 1:
                x_cont = x_cont.unsqueeze(1)
            x_cont = self.cont_bn(x_cont)
            x_cont = self.cont_linear(x_cont)
            z = torch.cat([z_cat, x_cont], dim=1)
        else:
            z = z_cat

        logits = self.head(z)  # (B, 1)
        return logits


# -----------------------------------------------------------
# Deep Learning Binary Classifier
# -----------------------------------------------------------
class DeepLearningBinaryClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        model_type: str = "mlp",
        model_params: dict | None = None,
    ):
        self.model_type = model_type
        self.model_params = model_params or {}
        self.model = None

    @property
    def device(self) -> torch.device:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _build_model(self) -> BaseNNModel:
        model_registry = {
            "mlp": MLPModel,
            "tabtransformer": TabTransformerModel,
        }

        if self.model_type not in model_registry:
            raise ValueError(f"Unknown model type: {self.model_type}")

        valid_params = {
            k: v for k, v in self.model_params.items() if k not in ["loss_fn", "lr"]
        }

        model_class = model_registry[self.model_type](**valid_params)
        return model_class

    def _get_loss_fn(self) -> nn.Module:
        loss_name = self.model_params.get("loss_fn", "logloss")
        if loss_name == "logloss":
            return nn.BCEWithLogitsLoss(reduction="none")
        else:
            raise ValueError(f"Unknown loss function: {loss_name}")

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        sample_weight: np.ndarray | None = None,
        eval_set: list[tuple[np.ndarray, np.ndarray]] | None = None,
        eval_metric: list[str] | None = None,
        max_epochs: int = 10,
        patience: int | None = None,
        batch_size: int = 128,
        verbose: bool = True,
    ) -> "DeepLearningBinaryClassifier":

        lr = self.model_params.get("lr", 0.001)
        eval_metric = eval_metric or ["logloss"]

        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
        y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1).to(self.device)

        if eval_set is not None:
            x_eval_tensor = torch.tensor(
                eval_set[0][0], dtype=torch.float32
            ).to(self.device)
            y_eval_true = eval_set[0][1]
        else:
            x_eval_tensor = None
            y_eval_true = None

        if sample_weight is not None:
            sample_weight_tensor = torch.tensor(
                sample_weight, dtype=torch.float32
            ).to(self.device)
        else:
            sample_weight_tensor = torch.ones_like(y_tensor, dtype=torch.float32)

        train_dataset = TensorDataset(X_tensor, y_tensor, sample_weight_tensor)
        train_dataloader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )

        if self.model is None:
            self.model = self._build_model().to(self.device)

        loss_fn = self._get_loss_fn()
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)

        patience_counter = 0
        best_metric = float("inf")
        best_model_weights = None

        for epoch in range(max_epochs):
            self.model.train()
            epoch_loss = 0.0
            n_batches = 0

            for x_batch, y_batch, weight_batch in train_dataloader:
                optimizer.zero_grad()

                y_pred_logits = self.model(x_batch)
                loss = loss_fn(y_pred_logits, y_batch)
                weighted_loss = (loss * weight_batch).sum() / weight_batch.sum()

                weighted_loss.backward()
                optimizer.step()

                epoch_loss += weighted_loss.item()
                n_batches += 1

            if verbose:
                print(
                    f"Epoch {epoch + 1}/{max_epochs} "
                    f"- [train] loss: {epoch_loss / max(1, n_batches):.6f}"
                )

            # evaluation
            if eval_set is not None:
                self.model.eval()
                with torch.no_grad():
                    y_eval_logits = self.model(x_eval_tensor)
                    y_eval_pred = torch.sigmoid(y_eval_logits).cpu().numpy().ravel()

                eval_metrics = {}
                for metric in eval_metric:
                    if metric == "logloss":
                        eval_metrics["logloss"] = log_loss(y_eval_true, y_eval_pred)
                    elif metric == "average_precision":
                        eval_metrics["average_precision"] = -average_precision_score(
                            y_eval_true, y_eval_pred
                        )
                    elif metric == "auc":
                        eval_metrics["auc"] = -roc_auc_score(y_eval_true, y_eval_pred)
                    else:
                        raise ValueError(f"Unknown metric: {metric}")

                if verbose:
                    metrics_str = ", ".join(
                        [f"{k}: {v:.4f}" for k, v in eval_metrics.items()]
                    )
                    print(f"  - [eval] {metrics_str}")

                # early stopping (기준 metric은 리스트의 첫 번째)
                main_metric_name = eval_metric[0]
                current_metric = eval_metrics.get(
                    main_metric_name, eval_metrics["logloss"]
                )

                if verbose:
                    print(
                        f"    -- (early_stopping) current_metric: {current_metric:.6f}, "
                        f"best_metric: {best_metric:.6f}"
                    )

                if current_metric < best_metric:
                    best_metric = current_metric
                    patience_counter = 0
                    best_model_weights = self.model.state_dict()
                else:
                    patience_counter += 1
                    if patience is not None and patience_counter >= patience:
                        if verbose:
                            print(f"Early stopping at epoch {epoch + 1}")
                        break

        if best_model_weights is not None:
            self.model.load_state_dict(best_model_weights)

        return self

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        self.model = self.model.to(self.device)
        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            self.model.eval()
            logits = self.model(X_tensor)
            probs1 = torch.sigmoid(logits).cpu().numpy()

        if probs1.shape[1] == 1:
            probs1 = probs1.reshape(-1, 1)

        probs0 = 1.0 - probs1
        probs = np.hstack((probs0, probs1))
        return probs.astype("float")

    def predict(self, X: np.ndarray) -> np.ndarray:
        probs = self.predict_proba(X)
        return probs.argmax(axis=1)


# -----------------------------------------------------------
# 더미 데이터셋 생성 (범주형 + 연속형 섞인 이진분류)
# -----------------------------------------------------------
def make_dummy_tabular_data(
    n_samples: int = 5000,
    n_cont: int = 10,
    seed: int = 42,
):
    rng = np.random.RandomState(seed)

    # 범주형 feature 3개
    cat_dims = [3, 5, 4]  # 각 컬럼의 cardinality
    n_cat = len(cat_dims)

    X_cat = np.zeros((n_samples, n_cat), dtype=np.int64)
    for j, k in enumerate(cat_dims):
        X_cat[:, j] = rng.randint(0, k, size=n_samples)

    # 연속형 feature
    X_cont = rng.randn(n_samples, n_cont).astype("float32")

    # latent score 생성 (cat + cont 섞어서)
    w_cat_cols = [rng.uniform(-1.0, 1.0, size=k) for k in cat_dims]
    w_cont = rng.randn(n_cont).astype("float32")

    score_cat = np.zeros(n_samples, dtype="float32")
    for j in range(n_cat):
        score_cat += w_cat_cols[j][X_cat[:, j]]

    score_cont = (X_cont * w_cont).sum(axis=1)

    bias = 0.1
    noise = rng.normal(scale=0.5, size=n_samples).astype("float32")
    logit = 0.7 * score_cat + 0.8 * score_cont + bias + noise
    prob = 1.0 / (1.0 + np.exp(-logit))
    y = (prob > 0.5).astype("int64")

    # 최종 feature: [cat | cont] 형태
    X = np.concatenate(
        [X_cat.astype("float32"), X_cont.astype("float32")], axis=1
    ).astype("float32")

    cat_feature_indices = list(range(n_cat))  # 0,1,2
    return X, y, cat_feature_indices, cat_dims


# -----------------------------------------------------------
# 간단한 학습/평가 데모
# -----------------------------------------------------------
def demo_train_tabtransformer():
    print("\n===== TabTransformer Demo =====")

    X, y, cat_features, cat_dims = make_dummy_tabular_data(
        n_samples=6000, n_cont=10, seed=123
    )

    # train / val / test split
    N = X.shape[0]
    idx = np.arange(N)
    np.random.shuffle(idx)

    tr_end = int(N * 0.7)
    va_end = int(N * 0.85)
    tr_idx, va_idx, te_idx = idx[:tr_end], idx[tr_end:va_end], idx[va_end:]

    X_tr, y_tr = X[tr_idx], y[tr_idx]
    X_va, y_va = X[va_idx], y[va_idx]
    X_te, y_te = X[te_idx], y[te_idx]

    # 간단히 균일한 sample_weight
    sample_weight = np.ones_like(y_tr, dtype="float32")

    input_dim = X.shape[1]

    model_params = {
        "input_dim": input_dim,
        "cat_features": cat_features,  # [0,1,2]
        "cat_dims": cat_dims,          # [3,5,4]
        "d_token": 32,
        "n_heads": 4,
        "n_layers": 2,
        "dim_feedforward": 128,
        "attn_dropout": 0.1,
        "embedding_dropout": 0.05,
        "add_cls": False,
        "pooling": "concat",
        "cont_proj": "linear",
        "mlp_hidden_dims": (128, 64),
        "mlp_dropout": 0.2,
        "padding_idx": 0,
        "lr": 1e-3,
        "loss_fn": "logloss",
    }

    clf = DeepLearningBinaryClassifier(
        model_type="tabtransformer",
        model_params=model_params,
    )

    clf.fit(
        X_tr,
        y_tr,
        sample_weight=sample_weight,
        eval_set=[(X_va, y_va)],
        eval_metric=["logloss"],
        max_epochs=15,
        patience=3,
        batch_size=256,
        verbose=True,
    )

    # 평가
    probs_te = clf.predict_proba(X_te)[:, 1]
    preds_te = (probs_te >= 0.5).astype("int64")

    acc = (preds_te == y_te).mean()
    auc = roc_auc_score(y_te, probs_te)
    ll = log_loss(y_te, probs_te)

    print("\n===== Test Metrics =====")
    print(f"Accuracy : {acc:.4f}")
    print(f"ROC-AUC  : {auc:.4f}")
    print(f"Logloss  : {ll:.4f}")
    print("Sample probs (first 10):", np.round(probs_te[:10], 4))


if __name__ == "__main__":
    demo_train_tabtransformer()


===== TabTransformer Demo =====




Epoch 1/15 - [train] loss: 156.340005
  - [eval] logloss: 0.5136
    -- (early_stopping) current_metric: 0.513568, best_metric: inf
Epoch 2/15 - [train] loss: 115.444524
  - [eval] logloss: 0.3385
    -- (early_stopping) current_metric: 0.338491, best_metric: 0.513568
Epoch 3/15 - [train] loss: 83.781380
  - [eval] logloss: 0.2459
    -- (early_stopping) current_metric: 0.245882, best_metric: 0.338491
Epoch 4/15 - [train] loss: 66.346304
  - [eval] logloss: 0.2065
    -- (early_stopping) current_metric: 0.206536, best_metric: 0.245882
Epoch 5/15 - [train] loss: 56.908177
  - [eval] logloss: 0.1844
    -- (early_stopping) current_metric: 0.184432, best_metric: 0.206536
Epoch 6/15 - [train] loss: 49.811581
  - [eval] logloss: 0.1709
    -- (early_stopping) current_metric: 0.170867, best_metric: 0.184432
Epoch 7/15 - [train] loss: 46.979477
  - [eval] logloss: 0.1658
    -- (early_stopping) current_metric: 0.165841, best_metric: 0.170867
Epoch 8/15 - [train] loss: 43.672781
  - [eval] log

In [None]:
# 1. TabTransformerNet 추가 (logit 출력 버전)

import torch
import torch.nn as nn
import numpy as np
from abc import ABC, abstractmethod

# ... BaseNNModel, MLPModel 는 그대로 두고, 그 아래에 추가 ...


class TabTransformerNet(nn.Module):
    """
    TabTransformer backbone + head (binary, logit 출력).
    - cat_cardinalities: 각 범주형 컬럼의 unique 개수 (OOV 제외)
    - n_continuous: 연속형 컬럼 개수
    """
    def __init__(
        self,
        cat_cardinalities,          # List[int] (exclude OOV)
        n_continuous=0,
        d_token=32,
        n_heads=4,
        n_layers=2,
        dim_feedforward=128,
        attn_dropout=0.1,
        embedding_dropout=0.1,
        add_cls=False,
        pooling="concat",           # "concat" or "cls"
        cont_proj="linear",         # "none" or "linear"
        mlp_hidden_dims=(128, 64),
        mlp_dropout=0.2,
        padding_idx=0,              # reserve 0 for OOV/pad if not None
        norm_first=True,
    ):
        super().__init__()
        assert pooling in ("concat", "cls")
        self.n_cat = len(cat_cardinalities)
        self.n_cont = n_continuous
        self.d_token = d_token
        self.add_cls = add_cls
        self.pooling = pooling
        self.cont_proj = cont_proj

        # ---- Categorical path ----
        if self.n_cat == 0:
            self.cat_embeddings = nn.ModuleList()
            self.col_embedding = None
        else:
            # +1 slot if padding_idx is used (OOV/pad)
            self.cat_embeddings = nn.ModuleList([
                nn.Embedding(
                    num_embeddings=c + (1 if (padding_idx is not None) else 0),
                    embedding_dim=d_token,
                    padding_idx=padding_idx,
                )
                for c in cat_cardinalities
            ])
            self.col_embedding = nn.Embedding(self.n_cat, d_token)

        if self.add_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, d_token))
            nn.init.normal_(self.cls_token, std=0.02)

        self.embedding_dropout = nn.Dropout(embedding_dropout)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_token,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=attn_dropout,
            batch_first=True,
            norm_first=norm_first,
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=n_layers)

        # init categorical embeddings
        for emb in self.cat_embeddings:
            nn.init.normal_(emb.weight, std=0.02)
        if self.col_embedding is not None:
            nn.init.normal_(self.col_embedding.weight, std=0.02)

        # ---- Continuous path ----
        if n_continuous > 0:
            self.cont_bn = nn.BatchNorm1d(n_continuous)
            if cont_proj == "linear":
                self.cont_linear = nn.Linear(n_continuous, d_token)
                nn.init.kaiming_uniform_(
                    self.cont_linear.weight, mode="fan_in", nonlinearity="relu"
                )
                cont_out_dim = d_token
            else:
                self.cont_linear = nn.Identity()
                cont_out_dim = n_continuous
        else:
            self.cont_bn = None
            self.cont_linear = None
            cont_out_dim = 0

        # ---- Head (마지막에 Sigmoid 없음: logit 출력) ----
        backbone_out = (d_token if pooling == "cls" else self.n_cat * d_token)
        in_dim = backbone_out + cont_out_dim

        layers = []
        prev = in_dim
        for h in mlp_hidden_dims:
            lin = nn.Linear(prev, h)
            nn.init.kaiming_uniform_(lin.weight, mode="fan_in", nonlinearity="relu")
            layers.extend([lin, nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(mlp_dropout)])
            prev = h
        layers.append(nn.Linear(prev, 1))  # logit
        self.head = nn.Sequential(*layers)

    def _encode_categoricals(self, x_cat: torch.LongTensor):
        """
        x_cat: (B, n_cat) -> contextualized representation
        returns:
          pooling='concat' -> (B, n_cat*d)
          pooling='cls'    -> (B, d)
        """
        B = x_cat.size(0)
        if self.n_cat == 0:
            return torch.zeros(
                B,
                self.d_token if self.pooling == "cls" else 0,
                device=x_cat.device,
                dtype=torch.float32,
            )

        tok_list = []
        for j, emb in enumerate(self.cat_embeddings):
            tok = emb(x_cat[:, j])  # (B, d)
            if self.col_embedding is not None:
                tok = tok + self.col_embedding.weight[j]  # (d,)
            tok_list.append(tok.unsqueeze(1))  # (B, 1, d)

        x_tok = torch.cat(tok_list, dim=1)  # (B, n_cat, d)

        if self.add_cls:
            cls = self.cls_token.expand(B, -1, -1)  # (B, 1, d)
            x_tok = torch.cat([cls, x_tok], dim=1)  # (B, 1+n_cat, d)

        x_tok = self.embedding_dropout(x_tok)
        z = self.transformer(x_tok)  # (B, T, d)

        if self.pooling == "cls" and self.add_cls:
            out = z[:, 0, :]  # (B, d)
        elif self.pooling == "cls":
            out = z.mean(dim=1)  # (B, d)
        else:
            if self.add_cls:
                z = z[:, 1:, :]  # drop CLS
            out = z.reshape(B, -1)  # (B, n_cat*d)
        return out

    def forward(self, x_cat: torch.LongTensor, x_cont: torch.FloatTensor | None = None):
        z_cat = self._encode_categoricals(x_cat)
        if (x_cont is not None) and (self.n_cont > 0):
            if x_cont.ndim == 1:
                x_cont = x_cont.unsqueeze(1)
            x_cont = self.cont_bn(x_cont)
            x_cont = self.cont_linear(x_cont)
            z = torch.cat([z_cat, x_cont], dim=1)
        else:
            z = z_cat
        out = self.head(z)  # (B, 1) logits
        return out


# 2. BaseNNModel 용 TabTransformer 래퍼 추가

class TabTransformerModel(BaseNNModel):
    """
    DeepLearningBinaryClassifier 에서 사용할 TabTransformer 래퍼.
    - input_dim: 전체 feature 개수
    - cat_features: 범주형 feature 인덱스 (X의 column index)
    - cat_dims: 각 범주형 feature의 cardinality (고유 값 개수)
    - 나머지 파라미터는 TabTransformerNet 에 그대로 전달
    """
    def __init__(
        self,
        input_dim: int,
        cat_features: list[int] | None = None,
        cat_dims: list[int] | None = None,
        d_token: int = 32,
        n_heads: int = 4,
        n_layers: int = 2,
        dim_feedforward: int = 128,
        attn_dropout: float = 0.1,
        embedding_dropout: float = 0.1,
        add_cls: bool = False,
        pooling: str = "concat",
        cont_proj: str = "linear",
        mlp_hidden_dims: tuple[int, ...] = (128, 64),
        mlp_dropout: float = 0.2,
        padding_idx: int = 0,
    ):
        super(TabTransformerModel, self).__init__()

        self.input_dim = input_dim
        self.cat_features = cat_features or []
        self.cat_dims = cat_dims or []
        self.n_cont = self.input_dim - len(self.cat_features)

        assert len(self.cat_features) == len(
            self.cat_dims
        ), "cat_features 개수와 cat_dims 개수가 다릅니다."

        self.network = TabTransformerNet(
            cat_cardinalities=self.cat_dims,
            n_continuous=self.n_cont,
            d_token=d_token,
            n_heads=n_heads,
            n_layers=n_layers,
            dim_feedforward=dim_feedforward,
            attn_dropout=attn_dropout,
            embedding_dropout=embedding_dropout,
            add_cls=add_cls,
            pooling=pooling,
            cont_proj=cont_proj,
            mlp_hidden_dims=mlp_hidden_dims,
            mlp_dropout=mlp_dropout,
            padding_idx=padding_idx,
            norm_first=True,
        )

    def build_network(self) -> nn.Module:
        # DeepLearningBinaryClassifier 에서는 __init__ 시점에 network 를 이미 만든 상태이므로
        return self.network

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim)
        - cat_features 인덱스에 해당하는 column 은 LongTensor 로 cast해서 categorical로 사용
        - 나머지는 continuous 로 사용
        """
        B, D = x.shape
        device = x.device

        if len(self.cat_features) > 0:
            cat_idx_tensor = torch.tensor(self.cat_features, device=device)
            x_cat = x[:, cat_idx_tensor].long()
        else:
            x_cat = torch.zeros(B, 0, dtype=torch.long, device=device)

        if self.n_cont > 0:
            cont_idx = [i for i in range(self.input_dim) if i not in self.cat_features]
            cont_idx_tensor = torch.tensor(cont_idx, device=device)
            x_cont = x[:, cont_idx_tensor].float()
        else:
            x_cont = None

        logits = self.network(x_cat, x_cont)  # (B, 1)
        return logits


# 3. DeepLearningBinaryClassifier 에 TabTransformer 등록
class DeepLearningBinaryClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        model_type: str = "mlp",
        model_params: dict | None = None,
    ):
        self.model_type = model_type
        self.model_params = model_params or {}
        self.model = None

    @property
    def device(self) -> torch.device:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _build_model(self) -> BaseNNModel:
        model_registry = {
            "mlp": MLPModel,
            "tabtransformer": TabTransformerModel,   # ★ 여기 추가
        }

        if self.model_type not in model_registry:
            raise ValueError(f"Unknown model type: {self.model_type}")

        valid_params = {
            k: v
            for k, v in self.model_params.items()
            if k not in ["loss_fn", "lr"]
        }

        model_class = model_registry[self.model_type](**valid_params)
        return model_class


# 사용
from custom_model import DeepLearningBinaryClassifier

# 예: 전체 feature 100개, 그 중 10, 11, 15번 column이 범주형이라고 가정
input_dim = X_train.shape[1]
cat_features = [10, 11, 15]
cat_dims = [20, 5, 12]  # 각 컬럼의 cardinality

tabtrans_params = {
    "input_dim": input_dim,
    "cat_features": cat_features,
    "cat_dims": cat_dims,
    "d_token": 32,
    "n_heads": 4,
    "n_layers": 2,
    "dim_feedforward": 128,
    "attn_dropout": 0.1,
    "embedding_dropout": 0.05,
    "pooling": "concat",
    "add_cls": False,
    "cont_proj": "linear",
    "mlp_hidden_dims": (128, 64),
    "mlp_dropout": 0.2,
    "padding_idx": 0,
    "lr": 1e-3,           # ★ DeepLearningBinaryClassifier 의 학습용
    "loss_fn": "logloss", # ★ BCEWithLogitsLoss 사용
}

clf = DeepLearningBinaryClassifier(
    model_type="tabtransformer",
    model_params=tabtrans_params,
)

clf.fit(
    X_train,
    y_train,
    sample_weight=sample_weight_train,             # 있으면
    eval_set=[(X_valid, y_valid)],
    eval_metric=["logloss"],
    max_epochs=10,
    patience=2,
    batch_size=1024,
)

proba_test = clf.predict_proba(X_test)[:, 1]
pred_test = clf.predict(X_test)


In [None]:
from custom_model import DeepLearningBinaryClassifier

# 예: 전체 feature 100개, 그 중 10, 11, 15번 column이 범주형이라고 가정
input_dim = X_train.shape[1]
cat_features = [10, 11, 15]
cat_dims = [20, 5, 12]  # 각 컬럼의 cardinality

tabtrans_params = {
    "input_dim": input_dim,
    "cat_features": cat_features,
    "cat_dims": cat_dims,
    "d_token": 32,
    "n_heads": 4,
    "n_layers": 2,
    "dim_feedforward": 128,
    "attn_dropout": 0.1,
    "embedding_dropout": 0.05,
    "pooling": "concat",
    "add_cls": False,
    "cont_proj": "linear",
    "mlp_hidden_dims": (128, 64),
    "mlp_dropout": 0.2,
    "padding_idx": 0,
    "lr": 1e-3,           # ★ DeepLearningBinaryClassifier 의 학습용
    "loss_fn": "logloss", # ★ BCEWithLogitsLoss 사용
}

clf = DeepLearningBinaryClassifier(
    model_type="tabtransformer",
    model_params=tabtrans_params,
)

clf.fit(
    X_train,
    y_train,
    sample_weight=sample_weight_train,             # 있으면
    eval_set=[(X_valid, y_valid)],
    eval_metric=["logloss"],
    max_epochs=10,
    patience=2,
    batch_size=1024,
)

proba_test = clf.predict_proba(X_test)[:, 1]
pred_test = clf.predict(X_test)


In [None]:
# deep_tab_transformer_demo.py
# -*- coding: utf-8 -*-

# Standard Library
from abc import ABC, abstractmethod

# Third Party
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import (
    average_precision_score,
    log_loss,
    roc_auc_score,
)
from torch.utils.data import DataLoader, TensorDataset


# ===========================================================
# Base Neural Network Model
# ===========================================================
class BaseNNModel(nn.Module, ABC):
    @abstractmethod
    def __init__(self, **kwargs):
        super(BaseNNModel, self).__init__()

    @abstractmethod
    def build_network(self) -> nn.Module:
        pass

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


# ===========================================================
# MLP Model (기존 MLP)
# ===========================================================
class MLPModel(BaseNNModel):
    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        cat_features: list[int] | None = None,
        cat_dims: list[int] | None = None,
        emb_dim: int = 8,
    ):
        super(MLPModel, self).__init__()

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.cat_features = cat_features or []
        self.cat_dims = cat_dims or []
        self.emb_dim = emb_dim

        self.embeddings: nn.ModuleList | None = None
        self.network = self.build_network()

    def build_network(self) -> nn.Sequential:
        # categorical embedding layer
        if len(self.cat_dims) > 0:
            self.embeddings = nn.ModuleList(
                [nn.Embedding(cat_dim, self.emb_dim) for cat_dim in self.cat_dims]
            )
        else:
            self.embeddings = None

        combined_input_dim = (
            self.input_dim - len(self.cat_features)
            + len(self.cat_features) * self.emb_dim
        )

        layers: list[nn.Module] = []
        dims = [combined_input_dim] + self.hidden_dims

        for i in range(len(dims) - 1):
            hidden_layer = nn.Linear(dims[i], dims[i + 1])
            nn.init.kaiming_normal_(
                hidden_layer.weight, mode="fan_in", nonlinearity="relu"
            )
            layers.append(hidden_layer)
            layers.append(nn.BatchNorm1d(dims[i + 1]))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))

        # output layer: logit 1개
        layers.append(nn.Linear(dims[-1], 1))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        continuous_features = []
        embedded_features = []
        cat_idx = 0

        for i in range(self.input_dim):
            if i in self.cat_features and self.embeddings is not None:
                cat_values = x[:, i].long()
                embedded = self.embeddings[cat_idx](cat_values)
                embedded_features.append(embedded)
                cat_idx += 1
            else:
                continuous_features.append(x[:, i : i + 1])

        if continuous_features:
            continuous_features = torch.cat(continuous_features, dim=1)
        else:
            continuous_features = None

        if embedded_features:
            embedded_features = torch.cat(embedded_features, dim=1)
        else:
            embedded_features = None

        if continuous_features is not None and embedded_features is not None:
            combined_features = torch.cat(
                [continuous_features, embedded_features], dim=1
            )
        elif embedded_features is not None:
            combined_features = embedded_features
        elif continuous_features is not None:
            combined_features = continuous_features
        else:
            raise ValueError("No features found for forward pass.")

        return self.network(combined_features)  # logits


# ===========================================================
# TabTransformer Model (BaseNNModel 하나로 통합)
# ===========================================================
class TabTransformerModel(BaseNNModel):
    """
    BaseNNModel을 직접 상속하는 단일 TabTransformer 모델.

    - 입력: x (batch, input_dim)
      * cat_features 위치: 정수 인덱스(0,1,2,...) — StringIndexer 등으로 변환된 값
      * 나머지: 연속형 특징 (float)
    - 출력: (batch, 1) 로짓 (sigmoid 전 값)
    """

    def __init__(
        self,
        input_dim: int,
        cat_features: list[int] | None = None,
        cat_dims: list[int] | None = None,
        d_token: int = 32,
        n_heads: int = 4,
        n_layers: int = 2,
        dim_feedforward: int = 128,
        attn_dropout: float = 0.1,
        embedding_dropout: float = 0.1,
        add_cls: bool = False,
        pooling: str = "concat",      # "concat" or "cls"
        cont_proj: str = "linear",    # "none" or "linear"
        hidden_dims: list[int] | tuple[int, ...] = (128, 64),
        mlp_dropout: float = 0.2,
        padding_idx: int | None = 0,
    ):
        super(TabTransformerModel, self).__init__()

        # 기본 설정
        self.input_dim = input_dim
        self.cat_features = cat_features or []
        self.cat_dims = cat_dims or []
        self.cont_features = [
            i for i in range(self.input_dim) if i not in self.cat_features
        ]

        self.d_token = d_token
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dim_feedforward = dim_feedforward
        self.attn_dropout = attn_dropout
        self.embedding_dropout = embedding_dropout
        self.add_cls = add_cls
        self.pooling = pooling
        self.cont_proj = cont_proj
        self.hidden_dims = list(hidden_dims)
        self.mlp_dropout = mlp_dropout
        self.padding_idx = padding_idx

        # ---- Categorical path ----
        self.n_cat = len(self.cat_features)
        self.n_cont = len(self.cont_features)

        if self.n_cat > 0 and len(self.cat_dims) != self.n_cat:
            raise ValueError(
                f"len(cat_dims) ({len(self.cat_dims)}) must match len(cat_features) ({self.n_cat})"
            )

        if self.n_cat > 0:
            # 각 범주형 컬럼별 embedding
            self.cat_embeddings = nn.ModuleList(
                [
                    nn.Embedding(
                        num_embeddings=c + (1 if padding_idx is not None else 0),
                        embedding_dim=self.d_token,
                        padding_idx=(
                            padding_idx if padding_idx is not None else None
                        ),
                    )
                    for c in self.cat_dims
                ]
            )
            # column embedding
            self.col_embedding = nn.Embedding(self.n_cat, self.d_token)
        else:
            self.cat_embeddings = nn.ModuleList()
            self.col_embedding = None

        if self.add_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.d_token))
            nn.init.normal_(self.cls_token, std=0.02)
        else:
            self.cls_token = None

        self.embedding_dropout_layer = nn.Dropout(self.embedding_dropout)

        # Transformer encoder
        enc_layer = nn.TransformerEncoderLayer(
            d_model=self.d_token,
            nhead=self.n_heads,
            dim_feedforward=self.dim_feedforward,
            dropout=self.attn_dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=self.n_layers)

        # init categorical embeddings
        for emb in self.cat_embeddings:
            nn.init.normal_(emb.weight, std=0.02)
        if self.col_embedding is not None:
            nn.init.normal_(self.col_embedding.weight, std=0.02)

        # ---- Continuous path ----
        if self.n_cont > 0:
            self.cont_bn = nn.BatchNorm1d(self.n_cont)
            if self.cont_proj == "linear":
                self.cont_linear = nn.Linear(self.n_cont, self.d_token)
                nn.init.kaiming_uniform_(
                    self.cont_linear.weight, mode="fan_in", nonlinearity="relu"
                )
                cont_out_dim = self.d_token
            else:
                self.cont_linear = nn.Identity()
                cont_out_dim = self.n_cont
        else:
            self.cont_bn = None
            self.cont_linear = None
            cont_out_dim = 0

        # ---- Head (MLP, 마지막엔 sigmoid 없이 logit만 출력) ----
        if self.pooling == "cls":
            backbone_out_dim = self.d_token
        else:  # concat
            backbone_out_dim = self.n_cat * self.d_token if self.n_cat > 0 else 0

        in_dim = backbone_out_dim + cont_out_dim

        layers: list[nn.Module] = []
        prev_dim = in_dim if in_dim > 0 else self.d_token
        if in_dim == 0:
            # 극단적으로 cat/cont 둘 다 없는 경우 방어 (실제론 안 쓰이겠지만)
            layers.append(nn.Linear(self.d_token, self.d_token))
            prev_dim = self.d_token

        for h in self.hidden_dims:
            lin = nn.Linear(prev_dim, h)
            nn.init.kaiming_uniform_(lin.weight, mode="fan_in", nonlinearity="relu")
            layers.extend(
                [
                    lin,
                    nn.BatchNorm1d(h),
                    nn.ReLU(),
                    nn.Dropout(self.mlp_dropout),
                ]
            )
            prev_dim = h

        # 출력: 로짓 1개
        layers.append(nn.Linear(prev_dim, 1))
        self.head = nn.Sequential(*layers)

    def build_network(self) -> nn.Module:
        # 네트워크 전체가 self 이므로 별도 모듈은 필요 없음
        return self

    # 내부: categorical 인코딩
    def _encode_categoricals(self, x_cat: torch.LongTensor) -> torch.Tensor:
        """
        x_cat: (B, n_cat) -> (B, n_cat*d) or (B, d) depending on pooling
        """
        B = x_cat.size(0)

        if self.n_cat == 0:
            if self.pooling == "cls":
                return torch.zeros(
                    B, self.d_token, device=x_cat.device, dtype=torch.float32
                )
            else:
                return torch.zeros(
                    B, 0, device=x_cat.device, dtype=torch.float32
                )

        tok_list = []
        for j, emb in enumerate(self.cat_embeddings):
            tok = emb(x_cat[:, j])  # (B, d)
            if self.col_embedding is not None:
                tok = tok + self.col_embedding.weight[j]  # column embedding 더하기
            tok_list.append(tok.unsqueeze(1))  # (B,1,d)

        x_tok = torch.cat(tok_list, dim=1)  # (B, n_cat, d)

        if self.add_cls and self.cls_token is not None:
            cls = self.cls_token.expand(B, 1, -1)  # (B,1,d)
            x_tok = torch.cat([cls, x_tok], dim=1)  # (B,1+n_cat,d)

        x_tok = self.embedding_dropout_layer(x_tok)
        z = self.transformer(x_tok)  # (B,T,d)

        if self.pooling == "cls":
            if self.add_cls and self.cls_token is not None:
                out = z[:, 0, :]
            else:
                out = z.mean(dim=1)
        else:  # "concat"
            if self.add_cls and self.cls_token is not None:
                z = z[:, 1:, :]  # CLS 제외
            out = z.reshape(B, -1)

        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim)
        - self.cat_features 위치: 정수 인덱스 (Long)
        - self.cont_features 위치: float
        반환: (B, 1) — 로짓
        """
        device = x.device

        # 범주형 피처
        if self.cat_features:
            x_cat = x[:, self.cat_features].long().to(device)  # (B, n_cat)
        else:
            x_cat = torch.zeros(
                (x.size(0), 0), dtype=torch.long, device=device
            )

        # 연속형 피처
        if self.cont_features:
            x_cont = x[:, self.cont_features].float().to(device)  # (B, n_cont)
        else:
            x_cont = None

        # categorical encode
        z_cat = self._encode_categoricals(x_cat)  # (B, d*)

        # continuous path
        if x_cont is not None and self.n_cont > 0:
            if x_cont.ndim == 1:
                x_cont = x_cont.unsqueeze(1)
            x_cont = self.cont_bn(x_cont)
            x_cont = self.cont_linear(x_cont)
            z = torch.cat([z_cat, x_cont], dim=1)
        else:
            z = z_cat

        logit = self.head(z)
        return logit


# ===========================================================
# Deep Learning Binary Classifier (공통 Wrapper)
# ===========================================================
class DeepLearningBinaryClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        model_type: str = "mlp",
        model_params: dict | None = None,
    ):
        self.model_type = model_type
        self.model_params = model_params or {}
        self.model: BaseNNModel | None = None

    @property
    def device(self) -> torch.device:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _build_model(self) -> BaseNNModel:
        model_registry: dict[str, type[BaseNNModel]] = {
            "mlp": MLPModel,
            "tabtransformer": TabTransformerModel,
        }

        if self.model_type not in model_registry:
            raise ValueError(f"Unknown model type: {self.model_type}")

        valid_params = {
            k: v
            for k, v in self.model_params.items()
            if k not in ["loss_fn", "lr"]
        }

        model_class = model_registry[self.model_type](**valid_params)
        return model_class

    def _get_loss_fn(self) -> nn.Module:
        loss_name = self.model_params.get("loss_fn", "logloss")
        if loss_name == "logloss":
            # logits + BCEWithLogitsLoss
            return nn.BCEWithLogitsLoss(reduction="none")
        else:
            raise ValueError(f"Unknown loss function: {loss_name}")

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        sample_weight: np.ndarray | None = None,
        eval_set: list[tuple[np.ndarray, np.ndarray]] | None = None,
        eval_metric: list[str] | None = None,
        max_epochs: int = 10,
        patience: int | None = None,
        batch_size: int = 128,
    ) -> "DeepLearningBinaryClassifier":

        lr = self.model_params.get("lr", 0.001)
        eval_metric = eval_metric or ["logloss"]

        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
        y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1).to(self.device)

        if eval_set is not None:
            x_eval_tensor = torch.tensor(
                eval_set[0][0], dtype=torch.float32
            ).to(self.device)
            y_eval_true = eval_set[0][1]
        else:
            x_eval_tensor = None
            y_eval_true = None

        if sample_weight is not None:
            sample_weight_tensor = torch.tensor(
                sample_weight, dtype=torch.float32
            ).view(-1, 1).to(self.device)
        else:
            sample_weight_tensor = torch.ones_like(y_tensor, dtype=torch.float32)

        train_dataset = TensorDataset(X_tensor, y_tensor, sample_weight_tensor)
        train_dataloader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )

        if self.model is None:
            self.model = self._build_model().to(self.device)

        loss_fn = self._get_loss_fn()
        optimizer = optim.Adam(
            self.model.parameters(), lr=lr, weight_decay=1e-4
        )

        patience_counter = 0
        best_metric = float("inf")
        best_model_weights = None

        for epoch in range(max_epochs):
            self.model.train()
            epoch_loss = 0.0
            n_batches = 0

            for x_batch, y_batch, weight_batch in train_dataloader:
                optimizer.zero_grad()

                y_pred = self.model(x_batch)  # logits
                loss = loss_fn(y_pred, y_batch)
                weighted_loss = (loss * weight_batch).sum() / weight_batch.sum()

                weighted_loss.backward()
                optimizer.step()

                epoch_loss += weighted_loss.item()
                n_batches += 1

            avg_train_loss = epoch_loss / max(1, n_batches)
            print(f"Epoch {epoch + 1}/{max_epochs}")
            print(f"- [train] loss: {avg_train_loss:.6f}")

            # evaluation
            if eval_set is not None and x_eval_tensor is not None:
                self.model.eval()
                with torch.no_grad():
                    logits_eval = self.model(x_eval_tensor)
                    y_eval_pred = (
                        torch.sigmoid(logits_eval).cpu().numpy().ravel()
                    )

                eval_metrics: dict[str, float] = {}
                for metric in eval_metric:
                    if metric == "logloss":
                        eval_metrics["logloss"] = log_loss(
                            y_eval_true, y_eval_pred, eps=1e-7
                        )
                    elif metric == "average_precision":
                        # 최대화 → 부호를 바꿔 최소화 기준에 맞추기
                        eval_metrics["average_precision"] = (
                            -average_precision_score(y_eval_true, y_eval_pred)
                        )
                    elif metric == "auc":
                        eval_metrics["auc"] = -roc_auc_score(
                            y_eval_true, y_eval_pred
                        )
                    else:
                        raise ValueError(f"Unknown metric: {metric}")

                metrics_str = ", ".join(
                    [f"{k}: {v:.4f}" for k, v in eval_metrics.items()]
                )
                print(f"- [eval] {metrics_str}")

                # early stopping: 첫 번째 metric 기준
                primary_metric_name = eval_metric[0]
                current_metric = eval_metrics.get(
                    primary_metric_name, eval_metrics.get("logloss")
                )
                print(
                    f"  -- (early_stopping) current_metric: {current_metric:.6f}, best_metric: {best_metric:.6f}"
                )

                if current_metric < best_metric:
                    best_metric = current_metric
                    patience_counter = 0
                    best_model_weights = {
                        k: v.detach().cpu().clone()
                        for k, v in self.model.state_dict().items()
                    }
                else:
                    patience_counter += 1
                    if patience is not None and patience_counter >= patience:
                        print(f"early stopping at epoch {epoch + 1}")
                        break

        if best_model_weights is not None:
            self.model.load_state_dict(best_model_weights)

        return self

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        if self.model is None:
            raise RuntimeError("Model is not trained yet.")

        self.model = self.model.to(self.device)
        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            self.model.eval()
            logits = self.model(X_tensor)
            probs1 = torch.sigmoid(logits).cpu().numpy()

        if probs1.shape[1] == 1:
            probs = np.hstack((1 - probs1, probs1))
        else:
            probs = probs1

        return probs.astype("float")

    def predict(self, X: np.ndarray) -> np.ndarray:
        probs = self.predict_proba(X)
        return probs.argmax(axis=1)


# ===========================================================
# 간단 테스트용 데이터 생성 & 실행 예제
# ===========================================================
def make_dummy_tabular(
    n_samples: int = 5000,
    n_cont: int = 10,
    random_state: int = 42,
):
    rng = np.random.RandomState(random_state)

    # 세 개의 범주형 피처 (0,1,2열)
    cat1 = rng.randint(0, 5, size=n_samples)   # cardinality 5
    cat2 = rng.randint(0, 4, size=n_samples)   # cardinality 4
    cat3 = rng.randint(0, 3, size=n_samples)   # cardinality 3

    # 연속형 피처 (3~)
    X_cont = rng.randn(n_samples, n_cont).astype("float32")

    # 간단한 로짓 생성 (cat + cont 조합)
    w_cat1 = rng.uniform(-0.5, 0.8, size=5)
    w_cat2 = rng.uniform(-0.3, 0.6, size=4)
    w_cat3 = rng.uniform(-1.0, 0.4, size=3)
    w_cont = rng.randn(n_cont)

    score_cat = (
        w_cat1[cat1] + w_cat2[cat2] + w_cat3[cat3]
    ).astype("float32")
    score_cont = (X_cont * w_cont).sum(axis=1).astype("float32")

    logit = 0.7 * score_cat + 0.5 * score_cont + rng.normal(
        scale=0.5, size=n_samples
    ).astype("float32")
    prob = 1.0 / (1.0 + np.exp(-logit))
    y = (prob > 0.5).astype("int64")

    # 최종 X: (cat1, cat2, cat3, cont...)
    X = np.zeros((n_samples, 3 + n_cont), dtype="float32")
    X[:, 0] = cat1
    X[:, 1] = cat2
    X[:, 2] = cat3
    X[:, 3:] = X_cont

    cat_features = [0, 1, 2]
    cat_dims = [5, 4, 3]

    return X, y, cat_features, cat_dims


def train_demo():
    X, y, cat_features, cat_dims = make_dummy_tabular(
        n_samples=8000, n_cont=10, random_state=123
    )

    # train / valid split
    N = X.shape[0]
    idx = np.arange(N)
    np.random.shuffle(idx)
    tr_end = int(N * 0.8)

    tr_idx = idx[:tr_end]
    va_idx = idx[tr_end:]

    X_tr, y_tr = X[tr_idx], y[tr_idx]
    X_va, y_va = X[va_idx], y[va_idx]

    print("=== Shape ===")
    print("X_tr:", X_tr.shape, "y_tr:", y_tr.shape)
    print("X_va:", X_va.shape, "y_va:", y_va.shape)
    print("cat_features:", cat_features)
    print("cat_dims:", cat_dims)

    # --------------------------------------------------
    # 1) TabTransformer
    # --------------------------------------------------
    print("\n==============================")
    print("Training TabTransformer model")
    print("==============================")

    tab_clf = DeepLearningBinaryClassifier(
        model_type="tabtransformer",
        model_params={
            "input_dim": X_tr.shape[1],
            "cat_features": cat_features,
            "cat_dims": cat_dims,
            "hidden_dims": [128, 64],
            "d_token": 32,
            "n_heads": 4,
            "n_layers": 2,
            "dim_feedforward": 128,
            "attn_dropout": 0.1,
            "embedding_dropout": 0.05,
            "pooling": "concat",
            "add_cls": False,
            "cont_proj": "linear",
            "mlp_dropout": 0.2,
            "lr": 1e-3,
            "loss_fn": "logloss",
        },
    )

    tab_clf.fit(
        X_tr,
        y_tr,
        eval_set=[(X_va, y_va)],
        eval_metric=["logloss"],
        max_epochs=10,
        patience=3,
        batch_size=512,
    )

    proba_tab = tab_clf.predict_proba(X_va)[:, 1]
    pred_tab = (proba_tab >= 0.5).astype(int)

    auc_tab = roc_auc_score(y_va, proba_tab)
    ap_tab = average_precision_score(y_va, proba_tab)
    ll_tab = log_loss(y_va, proba_tab)

    print("\n[TabTransformer] Metrics on valid:")
    print(f"AUC: {auc_tab:.4f} | AP: {ap_tab:.4f} | Logloss: {ll_tab:.4f}")

    # --------------------------------------------------
    # 2) MLP (비교용)
    # --------------------------------------------------
    print("\n==============================")
    print("Training MLP model")
    print("==============================")

    mlp_clf = DeepLearningBinaryClassifier(
        model_type="mlp",
        model_params={
            "input_dim": X_tr.shape[1],
            "cat_features": cat_features,
            "cat_dims": cat_dims,
            "hidden_dims": [128, 64],
            "emb_dim": 8,
            "lr": 1e-3,
            "loss_fn": "logloss",
        },
    )

    mlp_clf.fit(
        X_tr,
        y_tr,
        eval_set=[(X_va, y_va)],
        eval_metric=["logloss"],
        max_epochs=10,
        patience=3,
        batch_size=512,
    )

    proba_mlp = mlp_clf.predict_proba(X_va)[:, 1]
    pred_mlp = (proba_mlp >= 0.5).astype(int)

    auc_mlp = roc_auc_score(y_va, proba_mlp)
    ap_mlp = average_precision_score(y_va, proba_mlp)
    ll_mlp = log_loss(y_va, proba_mlp)

    print("\n[MLP] Metrics on valid:")
    print(f"AUC: {auc_mlp:.4f} | AP: {ap_mlp:.4f} | Logloss: {ll_mlp:.4f}")


if __name__ == "__main__":
    train_demo()


=== Shape ===
X_tr: (6400, 13) y_tr: (6400,)
X_va: (1600, 13) y_va: (1600,)
cat_features: [0, 1, 2]
cat_dims: [5, 4, 3]

Training TabTransformer model




Epoch 1/10
- [train] loss: 0.610459


TypeError: got an unexpected keyword argument 'eps'

In [None]:
# -*- coding: utf-8 -*-
# =========================================================
# 통합 코드:
#  - TabularPreprocessor + TabTransformerNet
#  - BaseNNModel / MLPModel / TabTransformerModel
#  - DeepLearningBinaryClassifier (mlp + tabtransformer 지원)
#  - 샘플 데이터 생성 + TabTransformer 실행 데모
# =========================================================

# Standard Library
from abc import ABC, abstractmethod

# Third Party
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    log_loss,
    roc_auc_score,
)
from torch.utils.data import DataLoader, TensorDataset


# ---------------------------------------------------------
# 0) Preprocessor: string categorical -> integer IDs (with OOV=0)
#    + auto cat_idx/cont_idx + cardinalities
# ---------------------------------------------------------
class TabularPreprocessor:
    def __init__(self, categorical_indices=None, use_oov=True, oov_token=0, add_na_token=True):
        self.categorical_indices = None if categorical_indices is None else list(categorical_indices)
        self.use_oov = use_oov
        self.oov_token = int(oov_token)  # 0 recommended
        self.add_na_token = add_na_token
        self.cat_maps = {}          # col_idx -> {category_value: int_id}
        self.cardinalities = []     # per-categorical-column unique count (K, excluding OOV)
        self.cat_idx = []
        self.cont_idx = []
        self.fitted_ = False

    def _ensure_ndarray(self, X):
        if isinstance(X, pd.DataFrame):
            X = X.values
        return np.asarray(X, dtype=object)  # safe for mixed types

    def fit(self, X, categorical_indices=None):
        X = self._ensure_ndarray(X)
        n_cols = X.shape[1]

        if categorical_indices is not None:
            self.categorical_indices = list(categorical_indices)

        if self.categorical_indices is None:
            # fallback: infer by dtype==object
            self.categorical_indices = [j for j in range(n_cols) if X[:, j].dtype == object]

        cat_set = set(self.categorical_indices)
        self.cat_idx = sorted(list(cat_set))
        self.cont_idx = [j for j in range(n_cols) if j not in cat_set]

        # build per-column maps: real categories 1..K (0 reserved for OOV)
        self.cat_maps = {}
        self.cardinalities = []
        for j in self.cat_idx:
            col = X[:, j]
            if self.add_na_token:
                col = np.where(pd.isna(col), "<NA>", col)
            uniques = pd.unique(col)
            id_map = {val: i + 1 for i, val in enumerate(uniques)}  # 1..K
            self.cat_maps[j] = id_map
            self.cardinalities.append(len(uniques))  # exclude OOV

        self.fitted_ = True
        return self

    def transform(self, X):
        assert self.fitted_, "Call fit() before transform()."
        X = self._ensure_ndarray(X)

        # categorical
        if len(self.cat_idx) > 0:
            x_cat = np.zeros((X.shape[0], len(self.cat_idx)), dtype="int64")
            for ti, j in enumerate(self.cat_idx):
                col = X[:, j]
                if self.add_na_token:
                    col = np.where(pd.isna(col), "<NA>", col)
                id_map = self.cat_maps[j]
                x_cat[:, ti] = np.array([id_map.get(v, self.oov_token) for v in col], dtype="int64")
        else:
            x_cat = np.zeros((X.shape[0], 0), dtype="int64")

        # continuous
        if len(self.cont_idx) > 0:
            x_cont = X[:, self.cont_idx].astype("float32")
        else:
            x_cont = None

        return x_cat, x_cont

    def fit_transform(self, X, categorical_indices=None):
        self.fit(X, categorical_indices)
        return self.transform(X)


# ---------------------------------------------------------
# 1) Single-class TabTransformer (Backbone + Continuous + Head)
#    (원본과 동일, Sigmoid까지 포함)
# ---------------------------------------------------------
class TabTransformerNet(nn.Module):
    """
    Single-class TabTransformer:
    - Categorical: per-column embedding (+ OOV/pad), column embedding, Transformer encoder, pooling
    - Continuous: BatchNorm + (optional) Linear projection to d_token
    - Head: MLP -> sigmoid
    """
    def __init__(
        self,
        cat_cardinalities,          # List[int] (exclude OOV)
        n_continuous=0,
        d_token=32,
        n_heads=4,
        n_layers=2,
        dim_feedforward=128,
        attn_dropout=0.1,
        embedding_dropout=0.1,
        add_cls=False,
        pooling="concat",           # "concat" or "cls"
        cont_proj="linear",         # "none" or "linear"
        mlp_hidden_dims=(128, 64),
        mlp_dropout=0.2,
        padding_idx=0,              # reserve 0 for OOV/pad if not None
        norm_first=True
    ):
        super().__init__()
        assert pooling in ("concat", "cls")
        self.n_cat = len(cat_cardinalities)
        self.n_cont = n_continuous
        self.d_token = d_token
        self.add_cls = add_cls
        self.pooling = pooling
        self.cont_proj = cont_proj

        # ---- Categorical path ----
        if self.n_cat == 0:
            self.cat_embeddings = nn.ModuleList()
            self.col_embedding = None
        else:
            # +1 slot if padding_idx is used (OOV/pad)
            self.cat_embeddings = nn.ModuleList([
                nn.Embedding(
                    num_embeddings=c + (1 if (padding_idx is not None) else 0),
                    embedding_dim=d_token,
                    padding_idx=padding_idx
                )
                for c in cat_cardinalities
            ])
            self.col_embedding = nn.Embedding(self.n_cat, d_token)

        if self.add_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, d_token))
            nn.init.normal_(self.cls_token, std=0.02)

        self.embedding_dropout = nn.Dropout(embedding_dropout)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_token,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=attn_dropout,
            batch_first=True,
            norm_first=norm_first
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=n_layers)

        # init categorical embeddings
        for emb in self.cat_embeddings:
            nn.init.normal_(emb.weight, std=0.02)
        if self.col_embedding is not None:
            nn.init.normal_(self.col_embedding.weight, std=0.02)

        # ---- Continuous path ----
        if n_continuous > 0:
            self.cont_bn = nn.BatchNorm1d(n_continuous)
            if cont_proj == "linear":
                self.cont_linear = nn.Linear(n_continuous, d_token)
                nn.init.kaiming_uniform_(self.cont_linear.weight, mode="fan_in", nonlinearity="relu")
                cont_out_dim = d_token
            else:
                self.cont_linear = nn.Identity()
                cont_out_dim = n_continuous
        else:
            self.cont_bn = None
            self.cont_linear = None
            cont_out_dim = 0

        # ---- Head ----
        backbone_out = (d_token if pooling == "cls" else self.n_cat * d_token)
        in_dim = backbone_out + cont_out_dim

        layers = []
        prev = in_dim
        for h in mlp_hidden_dims:
            lin = nn.Linear(prev, h)
            nn.init.kaiming_uniform_(lin.weight, mode="fan_in", nonlinearity="relu")
            layers.extend([lin, nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(mlp_dropout)])
            prev = h
        layers.append(nn.Linear(prev, 1))
        layers.append(nn.Sigmoid())
        self.head = nn.Sequential(*layers)

    def _encode_categoricals(self, x_cat: torch.LongTensor):
        """
        x_cat: (B, n_cat) -> contextualized representation
        returns:
          pooling='concat' -> (B, n_cat*d)
          pooling='cls'    -> (B, d)
        """
        B = x_cat.size(0)
        if self.n_cat == 0:
            return torch.zeros(B, self.d_token if self.pooling == "cls" else 0,
                               device=x_cat.device, dtype=torch.float32)

        tok_list = []
        for j, emb in enumerate(self.cat_embeddings):
            tok = emb(x_cat[:, j])                         # (B, d)
            if self.col_embedding is not None:
                tok = tok + self.col_embedding.weight[j]   # (d,)
            tok_list.append(tok.unsqueeze(1))              # (B, 1, d)

        x_tok = torch.cat(tok_list, dim=1)                 # (B, n_cat, d)

        if self.add_cls:
            cls = self.cls_token.expand(B, -1, -1)         # (B, 1, d)
            x_tok = torch.cat([cls, x_tok], dim=1)         # (B, 1+n_cat, d)

        x_tok = self.embedding_dropout(x_tok)
        z = self.transformer(x_tok)                        # (B, T, d)

        if self.pooling == "cls" and self.add_cls:
            out = z[:, 0, :]                               # (B, d)
        elif self.pooling == "cls":
            out = z.mean(dim=1)                            # (B, d)
        else:
            if self.add_cls:
                z = z[:, 1:, :]                            # drop CLS
            out = z.reshape(B, -1)                         # (B, n_cat*d)
        return out

    def forward(self, x_cat: torch.LongTensor, x_cont: torch.FloatTensor = None):
        z_cat = self._encode_categoricals(x_cat)
        if (x_cont is not None) and (self.n_cont > 0):
            if x_cont.ndim == 1:
                x_cont = x_cont.unsqueeze(1)
            x_cont = self.cont_bn(x_cont)
            x_cont = self.cont_linear(x_cont)
            z = torch.cat([z_cat, x_cont], dim=1)
        else:
            z = z_cat
        out = self.head(z)
        return out  # (B, 1), sigmoid된 확률


# ---------------------------------------------------------
# 2) 샘플 mixed-type dataset generator (문자열 포함)
# ---------------------------------------------------------
def make_mixed_sample(
    n_samples=40000,
    seed=7
):
    rng = np.random.RandomState(seed)

    genders = np.array(["남성", "여성", "기타"], dtype=object)
    cities = np.array(["서울", "부산", "대구", "인천", "수원", "고양"], dtype=object)
    devices = np.array(["ios", "android", "web"], dtype=object)

    # categorical columns
    gender_col = rng.choice(genders, size=n_samples, p=[0.48, 0.48, 0.04])
    city_col   = rng.choice(cities, size=n_samples)
    device_col = rng.choice(devices, size=n_samples, p=[0.35, 0.55, 0.10])

    # continuous columns
    n_cont = 10
    X_cont = rng.randn(n_samples, n_cont).astype("float32")

    # latent score
    w_gender = {g: w for g, w in zip(genders, rng.uniform(-0.8, 0.8, size=len(genders)))}
    w_city   = {c: w for c, w in zip(cities, rng.uniform(-0.6, 1.0, size=len(cities)))}
    w_device = {d: w for d, w in zip(devices, rng.uniform(-0.5, 0.9, size=len(devices)))}
    w_cont   = rng.randn(n_cont).astype("float32")

    score_cat = (np.vectorize(lambda v: w_gender[v])(gender_col) +
                 np.vectorize(lambda v: w_city[v])(city_col) +
                 np.vectorize(lambda v: w_device[v])(device_col)).astype("float32")
    score_cont = (X_cont * w_cont).sum(axis=1).astype("float32")

    bias = 0.1
    noise = rng.normal(scale=0.5, size=n_samples).astype("float32")
    logit = 0.7 * score_cat + 0.8 * score_cont + bias + noise
    prob = 1.0 / (1.0 + np.exp(-logit))
    y = (prob > 0.5).astype("int64")

    # combine into feature array (object dtype)
    X = np.empty((n_samples, 3 + n_cont), dtype=object)
    X[:, 0] = gender_col
    X[:, 1] = city_col
    X[:, 2] = device_col
    X[:, 3:] = X_cont

    categorical_feature_indices = [0, 1, 2]
    return X, y, categorical_feature_indices


# -----------------------------------------------------------
# Base Neural Network Model
# -----------------------------------------------------------
class BaseNNModel(nn.Module, ABC):
    @abstractmethod
    def __init__(self, **kwargs):
        super(BaseNNModel, self).__init__()

    @abstractmethod
    def build_network(self) -> nn.Module:
        pass

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


# -----------------------------------------------------------
# MLP Model
# -----------------------------------------------------------
class MLPModel(BaseNNModel):
    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        cat_features: list[int] | None = None,
        cat_dims: list[int] | None = None,
        emb_dim: int = 8,
    ):
        super(MLPModel, self).__init__()

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.cat_features = cat_features or []
        self.cat_dims = cat_dims or []
        self.emb_dim = emb_dim

        self.embeddings = None
        self.network = self.build_network()

    def build_network(self) -> nn.Sequential:
        # categorical embedding layer
        if len(self.cat_dims) > 0:
            self.embeddings = nn.ModuleList([
                nn.Embedding(cat_dim, self.emb_dim) for cat_dim in self.cat_dims
            ])

        combined_input_dim = (
            self.input_dim - len(self.cat_features)
            + len(self.cat_features) * self.emb_dim
        )

        layers = []
        dims = [combined_input_dim] + self.hidden_dims

        for i in range(len(dims) - 1):
            hidden_layer = nn.Linear(dims[i], dims[i + 1])
            nn.init.kaiming_normal_(hidden_layer.weight, mode="fan_in", nonlinearity="relu")

            layers.append(hidden_layer)
            layers.append(nn.BatchNorm1d(dims[i + 1]))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))

        # output layer (logit)
        layers.append(nn.Linear(dims[-1], 1))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        continuous_features = []
        embedded_features = []
        cat_idx = 0

        for i in range(self.input_dim):
            if i in self.cat_features:
                cat_values = x[:, i].long()
                embedded = self.embeddings[cat_idx](cat_values)
                embedded_features.append(embedded)
                cat_idx += 1
            else:
                continuous_features.append(x[:, i:i+1])

        if continuous_features:
            continuous_features = torch.cat(continuous_features, dim=1)
        else:
            continuous_features = None

        if embedded_features:
            embedded_features = torch.cat(embedded_features, dim=1)
        else:
            embedded_features = None

        if continuous_features is not None and embedded_features is not None:
            combined_features = torch.cat([continuous_features, embedded_features], dim=1)
        elif embedded_features is not None:
            combined_features = embedded_features
        elif continuous_features is not None:
            combined_features = continuous_features
        else:
            raise ValueError("No features found for forward pass.")

        # logit 출력
        return self.network(combined_features)


# -----------------------------------------------------------
# TabTransformer Model (DeepLearningBinaryClassifier용 래퍼)
# -----------------------------------------------------------
class TabTransformerModel(BaseNNModel):
    def __init__(
        self,
        input_dim: int,
        cat_features: list[int],
        cat_cardinalities: list[int],
        d_token: int = 32,
        n_heads: int = 4,
        n_layers: int = 2,
        dim_feedforward: int = 128,
        attn_dropout: float = 0.1,
        embedding_dropout: float = 0.1,
        add_cls: bool = False,
        pooling: str = "concat",
        cont_proj: str = "linear",
        hidden_dims: tuple[int, ...] = (128, 64),
        mlp_dropout: float = 0.2,
    ):
        """
        input_dim        : 전체 feature 개수 (cat + cont)
        cat_features     : categorical feature의 컬럼 인덱스 리스트
        cat_cardinalities: 각 categorical 컬럼의 cardinality (OOV 제외, TabTransformerNet 정의와 동일)
        """
        super(TabTransformerModel, self).__init__()

        self.input_dim = input_dim
        self.cat_features = list(cat_features)
        self.cont_features = [i for i in range(input_dim) if i not in self.cat_features]
        self.cat_cardinalities = cat_cardinalities

        # TabTransformerNet 하이퍼파라미터
        self.d_token = d_token
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dim_feedforward = dim_feedforward
        self.attn_dropout = attn_dropout
        self.embedding_dropout = embedding_dropout
        self.add_cls = add_cls
        self.pooling = pooling
        self.cont_proj = cont_proj
        self.hidden_dims = hidden_dims
        self.mlp_dropout = mlp_dropout

        self.network = self.build_network()

    def build_network(self) -> nn.Module:
        # n_continuous = 전체 - categorical 개수
        n_cont = len(self.cont_features)

        model = TabTransformerNet(
            cat_cardinalities=self.cat_cardinalities,
            n_continuous=n_cont,
            d_token=self.d_token,
            n_heads=self.n_heads,
            n_layers=self.n_layers,
            dim_feedforward=self.dim_feedforward,
            attn_dropout=self.attn_dropout,
            embedding_dropout=self.embedding_dropout,
            add_cls=self.add_cls,
            pooling=self.pooling,
            cont_proj=self.cont_proj,
            mlp_hidden_dims=self.hidden_dims,
            mlp_dropout=self.mlp_dropout,
            padding_idx=0,       # OOV=0
            norm_first=True,
        )
        return model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim)
          - cat_features 위치의 값들은 이미 integer ID라고 가정
          - 나머지는 float형 연속 변수
        """
        # categorical
        if len(self.cat_features) > 0:
            x_cat = x[:, self.cat_features].long()
        else:
            x_cat = torch.zeros(x.size(0), 0, dtype=torch.long, device=x.device)

        # continuous
        if len(self.cont_features) > 0:
            x_cont = x[:, self.cont_features].float()
        else:
            x_cont = None

        # TabTransformerNet은 Sigmoid까지 포함된 확률(p) (B,1) 반환
        out = self.network(x_cat, x_cont)
        return out  # 확률


# -----------------------------------------------------------
# Deep Learning Binary Classifier
#   - model_type: "mlp" / "tabtransformer"
# -----------------------------------------------------------
class DeepLearningBinaryClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        model_type: str = "mlp",
        model_params: dict | None = None,
    ):
        self.model_type = model_type
        self.model_params = model_params or {}
        self.model = None

    @property
    def device(self) -> torch.device:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _build_model(self) -> BaseNNModel:
        model_registry = {
            "mlp": MLPModel,
            "tabtransformer": TabTransformerModel,
        }

        if self.model_type not in model_registry:
            raise ValueError(f"Unknown model type: {self.model_type}")

        # lr, loss_fn 은 모델 __init__에 넘기지 않음
        valid_params = {
            k: v for k, v in self.model_params.items()
            if k not in ["loss_fn", "lr"]
        }

        model_class = model_registry[self.model_type](**valid_params)
        return model_class

    def _get_loss_fn(self) -> nn.Module:
        loss_name = self.model_params.get("loss_fn", "logloss")

        if loss_name == "logloss":
            # TabTransformer는 이미 Sigmoid 된 확률을 출력하므로 BCELoss,
            # MLP는 logit 출력이므로 BCEWithLogitsLoss
            if self.model_type == "tabtransformer":
                return nn.BCELoss(reduction="none")
            else:
                return nn.BCEWithLogitsLoss(reduction="none")
        else:
            raise ValueError(f"Unknown loss function: {loss_name}")

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        sample_weight: np.ndarray | None = None,
        eval_set: list[tuple[np.ndarray, np.ndarray]] | None = None,
        eval_metric: list[str] | None = None,
        max_epochs: int = 10,
        patience: int | None = None,
        batch_size: int = 128,
    ) -> None:

        lr = self.model_params.get("lr", 0.001)
        eval_metric = eval_metric or ["logloss"]

        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
        y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1).to(self.device)

        if eval_set is not None:
            x_eval_tensor = torch.tensor(eval_set[0][0], dtype=torch.float32).to(self.device)
            y_eval_true = eval_set[0][1]
        else:
            x_eval_tensor = None
            y_eval_true = None

        if sample_weight is not None:
            sample_weight_tensor = torch.tensor(sample_weight, dtype=torch.float32).view(-1, 1).to(self.device)
        else:
            sample_weight_tensor = torch.ones_like(y_tensor, dtype=torch.float32)

        train_dataset = TensorDataset(X_tensor, y_tensor, sample_weight_tensor)
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        if self.model is None:
            self.model = self._build_model().to(self.device)

        loss_fn = self._get_loss_fn()
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)

        patience_counter = 0
        best_metric = float("inf")
        best_model_weights = None

        for epoch in range(max_epochs):
            epoch_loss = 0.0

            for x_batch, y_batch, weight_batch in train_dataloader:
                self.model.train()
                optimizer.zero_grad()

                y_pred = self.model(x_batch)

                # MLP: y_pred = logits, loss_fn = BCEWithLogitsLoss
                # TabTransformer: y_pred = prob,   loss_fn = BCELoss
                loss = loss_fn(y_pred, y_batch)
                weighted_loss = (loss * weight_batch).sum() / weight_batch.sum()

                weighted_loss.backward()
                optimizer.step()

                epoch_loss += weighted_loss.item()

            print(f"Epoch {epoch + 1}/{max_epochs}")
            print(f"- [train] loss: {epoch_loss / len(train_dataloader):.6f}")

            # evaluation
            if eval_set is not None and x_eval_tensor is not None:
                with torch.no_grad():
                    self.model.eval()
                    raw_eval_pred = self.model(x_eval_tensor)

                    if self.model_type == "tabtransformer":
                        y_eval_pred = raw_eval_pred.cpu().numpy().ravel()
                    else:
                        y_eval_pred = torch.sigmoid(raw_eval_pred).cpu().numpy().ravel()

                eval_metrics = {}
                for metric in eval_metric:
                    if metric == "logloss":
                        eval_metrics["logloss"] = log_loss(y_eval_true, y_eval_pred)
                    elif metric == "average_precision":
                        eval_metrics["average_precision"] = -average_precision_score(
                            y_eval_true, y_eval_pred
                        )  # early stopping을 "minimize" 기준으로 쓰기 위해 -
                    elif metric == "auc":
                        eval_metrics["auc"] = -roc_auc_score(y_eval_true, y_eval_pred)
                    else:
                        raise ValueError(f"Unknown metric: {metric}")

                metrics_str = ", ".join(
                    [f"{k}: {v:.4f}" for k, v in eval_metrics.items()]
                )
                print(f"- [eval] {metrics_str}")

                # early stopping 기준 metric
                key_for_es = eval_metric[0] if eval_metric[0] in eval_metrics else "logloss"
                current_metric = eval_metrics[key_for_es]
                print(f" -- (early_stopping) current_metric: {current_metric:.6f}, best_metric: {best_metric:.6f}")

                if current_metric < best_metric:
                    best_metric = current_metric
                    patience_counter = 0
                    best_model_weights = self.model.state_dict()
                else:
                    patience_counter += 1
                    if patience is not None and patience_counter >= patience:
                        print(f"early stopping at epoch {epoch + 1}")
                        break

        if best_model_weights is not None:
            self.model.load_state_dict(best_model_weights)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        self.model = self.model.to(self.device)
        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            self.model.eval()
            raw_out = self.model(X_tensor)

            if self.model_type == "tabtransformer":
                # 이미 확률
                probs1 = raw_out
            else:
                # logits -> sigmoid
                probs1 = torch.sigmoid(raw_out)

            probs = probs1.cpu().numpy()

        if probs.shape[1] == 1:
            probs = np.hstack((1 - probs, probs))

        return probs.astype("float")

    def predict(self, X: np.ndarray) -> np.ndarray:
        probs = self.predict_proba(X)
        return probs.argmax(axis=1)


# -----------------------------------------------------------
# (옵션) 나머지 Rule Classifier들 – 원래 인프라 유지용
# -----------------------------------------------------------
class WeightedRuleClassifier:
    def __init__(self, rule_weights: list[float]):
        self.rule_weights = rule_weights

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        probs = np.zeros((X.shape[0], 2))
        for i, rule_weight in enumerate(self.rule_weights):
            probs[:, 1] += rule_weight * X[:, i]
        probs[:, 0] = (np.ones(X.shape[0],) - probs[:, 1])
        return probs.astype("float")

    def predict(self, X: np.ndarray) -> np.ndarray:
        return self.predict_proba(X).argmax(axis=1)


class CustomRuleClassifier:
    def __init__(self, rule: str):
        self.rule = rule


# -----------------------------------------------------------
# 4) TabTransformer + DeepLearningBinaryClassifier 실행 데모
# -----------------------------------------------------------
def tabtransformer_demo():
    np.random.seed(0)
    torch.manual_seed(0)

    # 1) 문자열 섞인 샘플 데이터 생성
    X_raw, y, categorical_feature_indices = make_mixed_sample(
        n_samples=30000,
        seed=123
    )

    # 2) 전처리: 문자열 카테고리 -> 정수 ID
    preproc = TabularPreprocessor(
        categorical_indices=categorical_feature_indices,
        use_oov=True,
        oov_token=0,
        add_na_token=True,
    )
    x_cat, x_cont = preproc.fit_transform(X_raw)

    # TabTransformerModel은 X 전체(np.float32)를 받되,
    # cat_features 인덱스에 카테고리 ID가 들어있다고 가정
    if x_cont is not None:
        X_all = np.concatenate([x_cat, x_cont], axis=1)
    else:
        X_all = x_cat

    # categorical 위치: 앞쪽부터 x_cat.shape[1]개
    cat_features = list(range(x_cat.shape[1]))
    cat_cardinalities = preproc.cardinalities  # 각 cat column의 cardinality (OOV 제외)

    # 3) train/val/test split
    N = X_all.shape[0]
    idx = np.arange(N)
    np.random.shuffle(idx)

    tr_end = int(N * 0.7)
    va_end = int(N * 0.85)
    tr_idx, va_idx, te_idx = idx[:tr_end], idx[tr_end:va_end], idx[va_end:]

    X_tr, y_tr = X_all[tr_idx], y[tr_idx]
    X_va, y_va = X_all[va_idx], y[va_idx]
    X_te, y_te = X_all[te_idx], y[te_idx]

    # 4) DeepLearningBinaryClassifier + TabTransformer 설정
    clf = DeepLearningBinaryClassifier(
        model_type="tabtransformer",
        model_params={
            "input_dim": X_all.shape[1],
            "cat_features": cat_features,
            "cat_cardinalities": cat_cardinalities,
            # TabTransformerNet hyperparams
            "d_token": 32,
            "n_heads": 4,
            "n_layers": 2,
            "dim_feedforward": 128,
            "attn_dropout": 0.1,
            "embedding_dropout": 0.05,
            "pooling": "concat",
            "add_cls": False,
            "cont_proj": "linear",
            "hidden_dims": (128, 64),
            "mlp_dropout": 0.2,
            # DL classifier options
            "lr": 1e-3,
            "loss_fn": "logloss",
        },
    )

    clf.fit(
        X_tr,
        y_tr,
        eval_set=[(X_va, y_va)],
        eval_metric=["logloss"],
        max_epochs=10,
        patience=2,
        batch_size=1024,
    )

    # 5) 평가
    proba_te = clf.predict_proba(X_te)[:, 1]
    pred_te = (proba_te >= 0.5).astype(int)

    acc = accuracy_score(y_te, pred_te)
    auc = roc_auc_score(y_te, proba_te)
    ll = log_loss(y_te, proba_te)

    print("\n===== Test Metrics (TabTransformer via DeepLearningBinaryClassifier) =====")
    print(f"Accuracy : {acc:.4f}")
    print(f"ROC-AUC  : {auc:.4f}")
    print(f"Logloss  : {ll:.4f}")

    print("\nSample predictions (first 10):")
    print(np.round(proba_te[:10], 4))

    print("\n[Info] cat_features indices:", cat_features)
    print("[Info] cat cardinalities   :", cat_cardinalities)


if __name__ == "__main__":
    tabtransformer_demo()




Epoch 1/10
- [train] loss: 0.553935
- [eval] logloss: 0.4053
 -- (early_stopping) current_metric: 0.405295, best_metric: inf
Epoch 2/10
- [train] loss: 0.333160
- [eval] logloss: 0.2482
 -- (early_stopping) current_metric: 0.248219, best_metric: 0.405295
Epoch 3/10
- [train] loss: 0.230129
- [eval] logloss: 0.1967
 -- (early_stopping) current_metric: 0.196677, best_metric: 0.248219
Epoch 4/10
- [train] loss: 0.181525
- [eval] logloss: 0.1572
 -- (early_stopping) current_metric: 0.157211, best_metric: 0.196677
Epoch 5/10
- [train] loss: 0.157481
- [eval] logloss: 0.1420
 -- (early_stopping) current_metric: 0.141978, best_metric: 0.157211
Epoch 6/10
- [train] loss: 0.144002
- [eval] logloss: 0.1331
 -- (early_stopping) current_metric: 0.133109, best_metric: 0.141978
Epoch 7/10
- [train] loss: 0.135624
- [eval] logloss: 0.1270
 -- (early_stopping) current_metric: 0.127006, best_metric: 0.133109
Epoch 8/10
- [train] loss: 0.131817
- [eval] logloss: 0.1242
 -- (early_stopping) current_metri