In [1]:
# -*- coding: utf-8 -*-
# FT-Transformer 구현 (TabTransformerModel 스타일과 통일)

from abc import ABC, abstractmethod
import torch
import torch.nn as nn


# -----------------------------------------------------------
# BaseNNModel (이미 있으시면 아래 클래스만 가져가시면 됩니다)
# -----------------------------------------------------------
class BaseNNModel(nn.Module, ABC):
    @abstractmethod
    def __init__(self, **kwargs):
        super(BaseNNModel, self).__init__()

    @abstractmethod
    def build_network(self) -> nn.Module:
        pass

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


# -----------------------------------------------------------
# FT-Transformer Model
# -----------------------------------------------------------
class FTTransformerModel(BaseNNModel):
    """
    FT-Transformer (Feature Tokenizer + Transformer) 구현.

    - input_dim: 전체 feature 개수
    - cat_features: 범주형 feature 인덱스 리스트 (X의 column index)
    - cat_dims: 각 범주형 feature의 cardinality (고유값 개수)
        * 각 범주형 값은 0 ~ (cardinality-1)의 int 인덱스로 들어온다고 가정.
    - 나머지 feature는 모두 numerical feature로 처리.
    """

    def __init__(
        self,
        input_dim: int,
        cat_features: list[int] | None = None,
        cat_dims: list[int] | None = None,
        d_token: int = 32,
        n_heads: int = 4,
        n_layers: int = 3,
        dim_feedforward: int | None = None,
        attn_dropout: float = 0.1,
        token_dropout: float = 0.0,
        mlp_dropout: float = 0.2,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.cat_features = cat_features or []
        self.cat_dims = cat_dims or []
        self.d_token = d_token
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dim_feedforward = dim_feedforward or (4 * d_token)  # 논문 스타일: 4 * d
        self.attn_dropout = attn_dropout
        self.token_dropout_p = token_dropout
        self.mlp_dropout = mlp_dropout

        self.build_network()

    def build_network(self) -> nn.Module:
        # --- feature 개수 세기 ---
        assert len(self.cat_features) == len(
            self.cat_dims
        ), "cat_features 개수와 cat_dims 개수가 다릅니다."

        self.n_cat = len(self.cat_features)
        self.n_num = self.input_dim - self.n_cat

        # ===== feature 인덱스를 buffer로 저장 (forward에서 재사용) =====
        if self.n_cat > 0:
            cat_idx = torch.tensor(self.cat_features, dtype=torch.long)
            self.register_buffer("cat_idx_tensor", cat_idx, persistent=False)
        else:
            self.register_buffer(
                "cat_idx_tensor",
                torch.zeros(0, dtype=torch.long),
                persistent=False,
            )

        num_idx = [i for i in range(self.input_dim) if i not in self.cat_features]
        if len(num_idx) != self.n_num:
            raise ValueError("numerical feature 개수 계산이 잘못되었습니다.")
        num_idx = torch.tensor(num_idx, dtype=torch.long)
        self.register_buffer("num_idx_tensor", num_idx, persistent=False)

        # ===== Feature Tokenizer =====
        # 1) Numerical features: T_j = b_j + x_j * W_j
        if self.n_num > 0:
            # shape: (1, n_num, d_token)로 두고 broadcasting 사용
            self.num_weight = nn.Parameter(
                torch.empty(1, self.n_num, self.d_token)
            )
            self.num_bias = nn.Parameter(
                torch.empty(1, self.n_num, self.d_token)
            )
            nn.init.normal_(self.num_weight, std=0.02)
            nn.init.normal_(self.num_bias, std=0.02)
        else:
            self.num_weight = None
            self.num_bias = None

        # 2) Categorical features: T_j = b_j + Embedding(x_j)
        if self.n_cat > 0:
            self.cat_embeddings = nn.ModuleList()
            for c in self.cat_dims:
                emb = nn.Embedding(num_embeddings=c, embedding_dim=self.d_token)
                nn.init.normal_(emb.weight, std=0.02)
                self.cat_embeddings.append(emb)

            # 각 범주형 feature별 bias (1, n_cat, d)
            self.cat_bias = nn.Parameter(
                torch.empty(1, self.n_cat, self.d_token)
            )
            nn.init.normal_(self.cat_bias, std=0.02)
        else:
            self.cat_embeddings = nn.ModuleList()
            self.cat_bias = None

        # CLS 토큰
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.d_token))
        nn.init.normal_(self.cls_token, std=0.02)

        # 토큰 dropout (feature-level)
        self.token_dropout = nn.Dropout(self.token_dropout_p)

        # ===== Transformer Encoder (PreNorm) =====
        # FT-Transformer는 PreNorm 구조를 사용 → norm_first=True
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_token,
            nhead=self.n_heads,
            dim_feedforward=self.dim_feedforward,
            dropout=self.attn_dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=self.n_layers,
        )

        # ===== Head (Prediction) =====
        # 최종: ŷ = Linear(ReLU(LayerNorm(T_CLS^L)))  (binary classification logit)
        self.final_norm = nn.LayerNorm(self.d_token)
        self.final_relu = nn.ReLU()
        self.final_dropout = nn.Dropout(self.mlp_dropout)
        self.head = nn.Linear(self.d_token, 1)

        return self

    def _tokenize(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim)
        return: tokens T (B, k, d_token)  -- k = n_num + n_cat
        """
        B, D = x.shape

        tokens_list = []

        # ----- numerical -----
        if self.n_num > 0:
            # (B, n_num)
            x_num = x.index_select(1, self.num_idx_tensor).unsqueeze(-1)  # (B, n_num, 1)
            # (1, n_num, d) broadcast + (B, n_num, 1) -> (B, n_num, d)
            T_num = self.num_bias + x_num * self.num_weight
            tokens_list.append(T_num)

        # ----- categorical -----
        if self.n_cat > 0:
            x_cat = x.index_select(1, self.cat_idx_tensor).long()  # (B, n_cat)
            cat_tokens = []
            for j, emb in enumerate(self.cat_embeddings):
                tj = emb(x_cat[:, j])  # (B, d_token)
                cat_tokens.append(tj.unsqueeze(1))  # (B, 1, d_token)
            T_cat = torch.cat(cat_tokens, dim=1)  # (B, n_cat, d_token)
            T_cat = T_cat + self.cat_bias  # broadcast: (1, n_cat, d)
            tokens_list.append(T_cat)

        if len(tokens_list) == 0:
            raise ValueError("No features (numerical or categorical) were provided.")

        # feature dimension 기준으로 concat
        T = torch.cat(tokens_list, dim=1)  # (B, k, d_token)
        return T

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim), float32 (범주형도 float->long 캐스팅 예정)
        return: (B, 1) logits (sigmoid 이전)
        """
        B, D = x.shape

        # 1) Feature Tokenizer
        T = self._tokenize(x)  # (B, k, d)
        T = self.token_dropout(T)

        # 2) [CLS] 토큰 붙이기
        cls = self.cls_token.expand(B, -1, -1)  # (B, 1, d)
        T = torch.cat([cls, T], dim=1)  # (B, 1 + k, d)

        # 3) Transformer 인코더
        Z = self.transformer(T)  # (B, 1 + k, d)

        # 4) [CLS] 벡터에서 예측
        cls_rep = Z[:, 0, :]  # (B, d)
        h = self.final_norm(cls_rep)
        h = self.final_relu(h)
        h = self.final_dropout(h)
        logits = self.head(h)  # (B, 1)
        return logits


In [4]:
# -*- coding: utf-8 -*-
# ===========================================================
# MLP + TabTransformer + FT-Transformer 전체 데모 코드
# ===========================================================

# Standard Library
from abc import ABC, abstractmethod

# Third Party
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import average_precision_score, log_loss, roc_auc_score
from torch.utils.data import DataLoader, TensorDataset


# -----------------------------------------------------------
# Base Neural Network Model
# -----------------------------------------------------------
class BaseNNModel(nn.Module, ABC):
    @abstractmethod
    def __init__(self, **kwargs):
        super(BaseNNModel, self).__init__()

    @abstractmethod
    def build_network(self) -> nn.Module:
        pass

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


# -----------------------------------------------------------
# MLP Model
# -----------------------------------------------------------
class MLPModel(BaseNNModel):
    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        cat_features: list[int] | None = None,
        cat_dims: list[int] | None = None,
        emb_dim: int = 8,
    ):
        super(MLPModel, self).__init__()

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.cat_features = cat_features or []
        self.cat_dims = cat_dims or []
        self.emb_dim = emb_dim

        self.embeddings = None
        self.network = self.build_network()

    def build_network(self) -> nn.Sequential:
        # categorical embedding layer
        if len(self.cat_dims) > 0:
            self.embeddings = nn.ModuleList(
                [nn.Embedding(cat_dim, self.emb_dim) for cat_dim in self.cat_dims]
            )

        combined_input_dim = (
            self.input_dim - len(self.cat_features)
            + len(self.cat_features) * self.emb_dim
        )

        layers = []
        dims = [combined_input_dim] + self.hidden_dims

        for i in range(len(dims) - 1):
            hidden_layer = nn.Linear(dims[i], dims[i + 1])
            nn.init.kaiming_normal_(
                hidden_layer.weight, mode="fan_in", nonlinearity="relu"
            )

            layers.append(hidden_layer)
            layers.append(nn.BatchNorm1d(dims[i + 1]))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))

        # output layer (logit)
        layers.append(nn.Linear(dims[-1], 1))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        continuous_features = []
        embedded_features = []
        cat_idx = 0

        for i in range(self.input_dim):
            if i in self.cat_features:
                cat_values = x[:, i].long()
                embedded = self.embeddings[cat_idx](cat_values)
                embedded_features.append(embedded)
                cat_idx += 1
            else:
                continuous_features.append(x[:, i : i + 1])

        if continuous_features:
            continuous_features = torch.cat(continuous_features, dim=1)
        else:
            continuous_features = None

        if embedded_features:
            embedded_features = torch.cat(embedded_features, dim=1)
        else:
            embedded_features = None

        if continuous_features is not None and embedded_features is not None:
            combined_features = torch.cat(
                [continuous_features, embedded_features], dim=1
            )
        elif embedded_features is not None:
            combined_features = embedded_features
        elif continuous_features is not None:
            combined_features = continuous_features
        else:
            raise ValueError("No features found for forward pass.")

        logits = self.network(combined_features)
        return logits  # (B, 1)


# -----------------------------------------------------------
# TabTransformer Model (개선 버전)
# -----------------------------------------------------------
class TabTransformerModel(BaseNNModel):
    """
    - input_dim: 전체 feature 개수
    - cat_features: 범주형 feature 인덱스 리스트 (X의 column index)
    - cat_dims: 각 범주형 feature의 cardinality (고유값 개수)
      * 각 범주형 값은 0 ~ (cardinality-1)의 int 인덱스로 들어온다고 가정.
    - 나머지 feature는 모두 continuous feature로 처리.
    """

    def __init__(
        self,
        input_dim: int,
        cat_features: list[int] | None = None,
        cat_dims: list[int] | None = None,
        d_token: int = 32,
        n_heads: int = 4,
        n_layers: int = 2,
        dim_feedforward: int | None = None,
        attn_dropout: float = 0.1,
        embedding_dropout: float = 0.1,
        add_cls: bool = False,
        pooling: str = "concat",  # "concat" or "cls"
        cont_proj: str = "linear",  # "none" or "linear"
        mlp_hidden_dims: tuple[int, ...] = (128, 64),
        mlp_dropout: float = 0.2,
        use_missing_category: bool = False,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.cat_features = cat_features or []
        self.cat_dims = cat_dims or []
        self.d_token = d_token
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dim_feedforward = dim_feedforward or (4 * d_token)
        self.attn_dropout = attn_dropout
        self.embedding_dropout_p = embedding_dropout
        self.add_cls = add_cls
        self.pooling = pooling
        self.cont_proj = cont_proj
        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_dropout = mlp_dropout
        self.use_missing_category = use_missing_category

        self.build_network()

    def build_network(self) -> nn.Module:
        assert len(self.cat_features) == len(
            self.cat_dims
        ), "cat_features 개수와 cat_dims 개수가 다릅니다."

        # 연속형 / 범주형 개수
        self.n_cat = len(self.cat_features)
        self.n_cont = self.input_dim - self.n_cat

        # ----- 인덱스 buffer 등록 -----
        if self.n_cat > 0:
            cat_idx = torch.tensor(self.cat_features, dtype=torch.long)
            self.register_buffer("cat_idx_tensor", cat_idx, persistent=False)
        else:
            self.register_buffer(
                "cat_idx_tensor",
                torch.zeros(0, dtype=torch.long),
                persistent=False,
            )

        cont_idx = [i for i in range(self.input_dim) if i not in self.cat_features]
        if len(cont_idx) != self.n_cont:
            raise ValueError("continuous feature 개수 계산이 잘못되었습니다.")
        cont_idx = torch.tensor(cont_idx, dtype=torch.long)
        self.register_buffer("cont_idx_tensor", cont_idx, persistent=False)

        # ====== Categorical path ======
        if self.n_cat == 0:
            self.cat_embeddings = nn.ModuleList()
            self.col_embedding = None
        else:
            self.cat_embeddings = nn.ModuleList()
            for c in self.cat_dims:
                n_embeddings = c + (1 if self.use_missing_category else 0)
                emb = nn.Embedding(
                    num_embeddings=n_embeddings,
                    embedding_dim=self.d_token,
                )
                nn.init.normal_(emb.weight, std=0.02)
                self.cat_embeddings.append(emb)

            # column embedding (선택적 token-type 역할)
            self.col_embedding = nn.Embedding(self.n_cat, self.d_token)
            nn.init.normal_(self.col_embedding.weight, std=0.02)

        if self.add_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.d_token))
            nn.init.normal_(self.cls_token, std=0.02)

        self.embedding_dropout = nn.Dropout(self.embedding_dropout_p)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=self.d_token,
            nhead=self.n_heads,
            dim_feedforward=self.dim_feedforward,
            dropout=self.attn_dropout,
            batch_first=True,
            norm_first=False,  # 논문 스타일(Post-LN)에 가깝게
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=self.n_layers)

        # ====== 연속형 path ======
        if self.n_cont > 0:
            self.cont_bn = nn.BatchNorm1d(self.n_cont)
            if self.cont_proj == "linear":
                self.cont_linear = nn.Linear(self.n_cont, self.d_token)
                nn.init.kaiming_uniform_(
                    self.cont_linear.weight, mode="fan_in", nonlinearity="relu"
                )
                cont_out_dim = self.d_token
            else:
                self.cont_linear = nn.Identity()
                cont_out_dim = self.n_cont
        else:
            self.cont_bn = None
            self.cont_linear = None
            cont_out_dim = 0

        # ====== Head (logit 출력) ======
        backbone_out = (
            self.d_token if self.pooling == "cls" else self.n_cat * self.d_token
        )
        in_dim = backbone_out + cont_out_dim

        layers = []
        prev = in_dim
        for h in self.mlp_hidden_dims:
            lin = nn.Linear(prev, h)
            nn.init.kaiming_uniform_(lin.weight, mode="fan_in", nonlinearity="relu")
            layers.extend(
                [lin, nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(self.mlp_dropout)]
            )
            prev = h
        layers.append(nn.Linear(prev, 1))  # logits
        self.head = nn.Sequential(*layers)

        return self

    def _encode_categoricals(self, x_cat: torch.LongTensor) -> torch.Tensor:
        """
        x_cat: (B, n_cat) -> contextualized embedding
        - use_missing_category=True 인 경우, 전처리에서 missing을 마지막 인덱스로 매핑했다고 가정.
        """
        B = x_cat.size(0)
        if self.n_cat == 0:
            return torch.zeros(
                B,
                self.d_token if self.pooling == "cls" else 0,
                device=x_cat.device,
                dtype=torch.float32,
            )

        tok_list = []
        for j, emb in enumerate(self.cat_embeddings):
            tok = emb(x_cat[:, j])  # (B, d)
            if self.col_embedding is not None:
                tok = tok + self.col_embedding.weight[j]  # (d,)
            tok_list.append(tok.unsqueeze(1))  # (B, 1, d)

        x_tok = torch.cat(tok_list, dim=1)  # (B, n_cat, d)

        if self.add_cls:
            cls = self.cls_token.expand(B, -1, -1)  # (B, 1, d)
            x_tok = torch.cat([cls, x_tok], dim=1)  # (B, 1 + n_cat, d)

        x_tok = self.embedding_dropout(x_tok)
        z = self.transformer(x_tok)  # (B, T, d)

        if self.pooling == "cls" and self.add_cls:
            out = z[:, 0, :]  # (B, d)
        elif self.pooling == "cls":
            out = z.mean(dim=1)  # (B, d)
        else:
            if self.add_cls:
                z = z[:, 1:, :]  # CLS 제거
            out = z.reshape(B, -1)  # (B, n_cat*d)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim)
        - self.cat_features 인덱스에 해당하는 column은 categorical, 나머지는 continuous로 사용.
        - 출력: (B, 1) logits (sigmoid 전)
        """
        B, D = x.shape
        device = x.device

        # ----- categorical split -----
        if self.n_cat > 0:
            x_cat = x.index_select(1, self.cat_idx_tensor).long()
        else:
            x_cat = torch.zeros(B, 0, dtype=torch.long, device=device)

        # ----- continuous split -----
        if self.n_cont > 0:
            x_cont = x.index_select(1, self.cont_idx_tensor).float()
        else:
            x_cont = None

        # categorical path
        z_cat = self._encode_categoricals(x_cat)

        # continuous path
        if (x_cont is not None) and (self.n_cont > 0):
            if x_cont.ndim == 1:
                x_cont = x_cont.unsqueeze(1)
            x_cont = self.cont_bn(x_cont)
            x_cont = self.cont_linear(x_cont)
            z = torch.cat([z_cat, x_cont], dim=1)
        else:
            z = z_cat

        logits = self.head(z)  # (B, 1)
        return logits


# -----------------------------------------------------------
# FT-Transformer Model (단일 클래스)
# -----------------------------------------------------------
class FTTransformerModel(BaseNNModel):
    """
    FT-Transformer (Feature Tokenizer + Transformer) 구현.

    - input_dim: 전체 feature 개수
    - cat_features: 범주형 feature 인덱스 리스트 (X의 column index)
    - cat_dims: 각 범주형 feature의 cardinality (고유값 개수)
        * 각 범주형 값은 0 ~ (cardinality-1)의 int 인덱스로 들어온다고 가정.
    - 나머지 feature는 모두 numerical feature로 처리.
    """

    def __init__(
        self,
        input_dim: int,
        cat_features: list[int] | None = None,
        cat_dims: list[int] | None = None,
        d_token: int = 32,
        n_heads: int = 4,
        n_layers: int = 3,
        dim_feedforward: int | None = None,
        attn_dropout: float = 0.1,
        token_dropout: float = 0.0,
        mlp_dropout: float = 0.2,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.cat_features = cat_features or []
        self.cat_dims = cat_dims or []
        self.d_token = d_token
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dim_feedforward = dim_feedforward or (4 * d_token)
        self.attn_dropout = attn_dropout
        self.token_dropout_p = token_dropout
        self.mlp_dropout = mlp_dropout

        self.build_network()

    def build_network(self) -> nn.Module:
        # --- feature 개수 세기 ---
        assert len(self.cat_features) == len(
            self.cat_dims
        ), "cat_features와 cat_dims 길이가 다릅니다."

        self.n_cat = len(self.cat_features)
        self.n_num = self.input_dim - self.n_cat

        # ===== feature 인덱스를 buffer로 저장 =====
        if self.n_cat > 0:
            cat_idx = torch.tensor(self.cat_features, dtype=torch.long)
            self.register_buffer("cat_idx_tensor", cat_idx, persistent=False)
        else:
            self.register_buffer(
                "cat_idx_tensor",
                torch.zeros(0, dtype=torch.long),
                persistent=False,
            )

        num_idx = [i for i in range(self.input_dim) if i not in self.cat_features]
        if len(num_idx) != self.n_num:
            raise ValueError("numerical feature 개수 계산이 잘못되었습니다.")
        num_idx = torch.tensor(num_idx, dtype=torch.long)
        self.register_buffer("num_idx_tensor", num_idx, persistent=False)

        # ===== Feature Tokenizer =====
        # 1) Numerical features: T_j = b_j + x_j * W_j
        if self.n_num > 0:
            self.num_weight = nn.Parameter(
                torch.empty(1, self.n_num, self.d_token)
            )
            self.num_bias = nn.Parameter(
                torch.empty(1, self.n_num, self.d_token)
            )
            nn.init.normal_(self.num_weight, std=0.02)
            nn.init.normal_(self.num_bias, std=0.02)
        else:
            self.num_weight = None
            self.num_bias = None

        # 2) Categorical features: T_j = b_j + Embedding(x_j)
        if self.n_cat > 0:
            self.cat_embeddings = nn.ModuleList()
            for c in self.cat_dims:
                emb = nn.Embedding(num_embeddings=c, embedding_dim=self.d_token)
                nn.init.normal_(emb.weight, std=0.02)
                self.cat_embeddings.append(emb)

            self.cat_bias = nn.Parameter(
                torch.empty(1, self.n_cat, self.d_token)
            )
            nn.init.normal_(self.cat_bias, std=0.02)
        else:
            self.cat_embeddings = nn.ModuleList()
            self.cat_bias = None

        # CLS 토큰
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.d_token))
        nn.init.normal_(self.cls_token, std=0.02)

        # 토큰 dropout
        self.token_dropout = nn.Dropout(self.token_dropout_p)

        # ===== Transformer Encoder (PreNorm) =====
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_token,
            nhead=self.n_heads,
            dim_feedforward=self.dim_feedforward,
            dropout=self.attn_dropout,
            batch_first=True,
            norm_first=True,  # FT-Transformer: PreNorm
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=self.n_layers,
        )

        # ===== Head (Prediction) =====
        self.final_norm = nn.LayerNorm(self.d_token)
        self.final_relu = nn.ReLU()
        self.final_dropout = nn.Dropout(self.mlp_dropout)
        self.head = nn.Linear(self.d_token, 1)

        return self

    def _tokenize(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim)
        return: tokens T (B, k, d_token)  -- k = n_num + n_cat
        """
        B, D = x.shape

        tokens_list = []

        # ----- numerical -----
        if self.n_num > 0:
            x_num = x.index_select(1, self.num_idx_tensor).unsqueeze(-1)  # (B, n_num, 1)
            T_num = self.num_bias + x_num * self.num_weight  # (B, n_num, d)
            tokens_list.append(T_num)

        # ----- categorical -----
        if self.n_cat > 0:
            x_cat = x.index_select(1, self.cat_idx_tensor).long()  # (B, n_cat)
            cat_tokens = []
            for j, emb in enumerate(self.cat_embeddings):
                tj = emb(x_cat[:, j])  # (B, d_token)
                cat_tokens.append(tj.unsqueeze(1))  # (B, 1, d_token)
            T_cat = torch.cat(cat_tokens, dim=1)  # (B, n_cat, d_token)
            T_cat = T_cat + self.cat_bias  # (B, n_cat, d_token)
            tokens_list.append(T_cat)

        if len(tokens_list) == 0:
            raise ValueError("No features (numerical or categorical) were provided.")

        T = torch.cat(tokens_list, dim=1)  # (B, k, d_token)
        return T

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim), float32 (범주형도 float->long 캐스팅 예정)
        return: (B, 1) logits
        """
        B, D = x.shape

        # 1) Feature Tokenizer
        T = self._tokenize(x)  # (B, k, d)
        T = self.token_dropout(T)

        # 2) [CLS] 토큰 붙이기
        cls = self.cls_token.expand(B, -1, -1)  # (B, 1, d)
        T = torch.cat([cls, T], dim=1)  # (B, 1 + k, d)

        # 3) Transformer
        Z = self.transformer(T)  # (B, 1 + k, d)

        # 4) [CLS]로 예측
        cls_rep = Z[:, 0, :]  # (B, d)
        h = self.final_norm(cls_rep)
        h = self.final_relu(h)
        h = self.final_dropout(h)
        logits = self.head(h)  # (B, 1)
        return logits


# -----------------------------------------------------------
# Deep Learning Binary Classifier
# -----------------------------------------------------------
class DeepLearningBinaryClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        model_type: str = "mlp",
        model_params: dict | None = None,
    ):
        self.model_type = model_type
        self.model_params = model_params or {}
        self.model = None

    @property
    def device(self) -> torch.device:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _build_model(self) -> BaseNNModel:
        model_registry = {
            "mlp": MLPModel,
            "tabtransformer": TabTransformerModel,
            "fttransformer": FTTransformerModel,
        }

        if self.model_type not in model_registry:
            raise ValueError(f"Unknown model type: {self.model_type}")

        valid_params = {
            k: v for k, v in self.model_params.items() if k not in ["loss_fn", "lr"]
        }

        model_class = model_registry[self.model_type](**valid_params)
        return model_class

    def _get_loss_fn(self) -> nn.Module:
        loss_name = self.model_params.get("loss_fn", "logloss")
        if loss_name == "logloss":
            return nn.BCEWithLogitsLoss(reduction="none")
        else:
            raise ValueError(f"Unknown loss function: {loss_name}")

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        sample_weight: np.ndarray | None = None,
        eval_set: list[tuple[np.ndarray, np.ndarray]] | None = None,
        eval_metric: list[str] | None = None,
        max_epochs: int = 10,
        patience: int | None = None,
        batch_size: int = 128,
        verbose: bool = True,
    ) -> "DeepLearningBinaryClassifier":

        lr = self.model_params.get("lr", 0.001)
        eval_metric = eval_metric or ["logloss"]

        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
        y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1).to(self.device)

        if eval_set is not None:
            x_eval_tensor = torch.tensor(
                eval_set[0][0], dtype=torch.float32
            ).to(self.device)
            y_eval_true = eval_set[0][1]
        else:
            x_eval_tensor = None
            y_eval_true = None

        if sample_weight is not None:
            sample_weight_tensor = torch.tensor(
                sample_weight, dtype=torch.float32
            ).to(self.device)
        else:
            sample_weight_tensor = torch.ones_like(y_tensor, dtype=torch.float32)

        train_dataset = TensorDataset(X_tensor, y_tensor, sample_weight_tensor)
        train_dataloader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )

        if self.model is None:
            self.model = self._build_model().to(self.device)

        loss_fn = self._get_loss_fn()
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)

        patience_counter = 0
        best_metric = float("inf")
        best_model_weights = None

        for epoch in range(max_epochs):
            self.model.train()
            epoch_loss = 0.0
            n_batches = 0

            for x_batch, y_batch, weight_batch in train_dataloader:
                optimizer.zero_grad()

                y_pred_logits = self.model(x_batch)
                loss = loss_fn(y_pred_logits, y_batch)
                weighted_loss = (loss * weight_batch).sum() / weight_batch.sum()

                weighted_loss.backward()
                optimizer.step()

                epoch_loss += weighted_loss.item()
                n_batches += 1

            if verbose:
                print(
                    f"Epoch {epoch + 1}/{max_epochs} "
                    f"- [train] loss: {epoch_loss / max(1, n_batches):.6f}"
                )

            # evaluation
            if eval_set is not None:
                self.model.eval()
                with torch.no_grad():
                    y_eval_logits = self.model(x_eval_tensor)
                    y_eval_pred = torch.sigmoid(y_eval_logits).cpu().numpy().ravel()

                eval_metrics = {}
                for metric in eval_metric:
                    if metric == "logloss":
                        eval_metrics["logloss"] = log_loss(y_eval_true, y_eval_pred)
                    elif metric == "average_precision":
                        eval_metrics["average_precision"] = -average_precision_score(
                            y_eval_true, y_eval_pred
                        )
                    elif metric == "auc":
                        eval_metrics["auc"] = -roc_auc_score(y_eval_true, y_eval_pred)
                    else:
                        raise ValueError(f"Unknown metric: {metric}")

                if verbose:
                    metrics_str = ", ".join(
                        [f"{k}: {v:.4f}" for k, v in eval_metrics.items()]
                    )
                    print(f"  - [eval] {metrics_str}")

                # early stopping (기준 metric은 리스트의 첫 번째)
                main_metric_name = eval_metric[0]
                current_metric = eval_metrics.get(
                    main_metric_name, eval_metrics["logloss"]
                )

                if verbose:
                    print(
                        f"    -- (early_stopping) current_metric: {current_metric:.6f}, "
                        f"best_metric: {best_metric:.6f}"
                    )

                if current_metric < best_metric:
                    best_metric = current_metric
                    patience_counter = 0
                    best_model_weights = self.model.state_dict()
                else:
                    patience_counter += 1
                    if patience is not None and patience_counter >= patience:
                        if verbose:
                            print(f"Early stopping at epoch {epoch + 1}")
                        break

        if best_model_weights is not None:
            self.model.load_state_dict(best_model_weights)

        return self

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        self.model = self.model.to(self.device)
        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            self.model.eval()
            logits = self.model(X_tensor)
            probs1 = torch.sigmoid(logits).cpu().numpy()

        if probs1.shape[1] == 1:
            probs1 = probs1.reshape(-1, 1)

        probs0 = 1.0 - probs1
        probs = np.hstack((probs0, probs1))
        return probs.astype("float")

    def predict(self, X: np.ndarray) -> np.ndarray:
        probs = self.predict_proba(X)
        return probs.argmax(axis=1)


# -----------------------------------------------------------
# 한국형 피처 스키마(연령/성별/OS/점수형) 데모 데이터 생성
# -----------------------------------------------------------
def make_kor_feature_demo_data(
    n_samples: int = 5000,
    seed: int = 42,
):
    rng = np.random.RandomState(seed)
    n = n_samples

    # --- 1) 각 피처 생성 ---
    # 연령: 18~80 사이 normal 분포
    age = np.clip(rng.normal(loc=40, scale=12, size=n), 18, 80).astype("float32")

    # 성별: 3개 카테고리 (0: 미상, 1: 남, 2: 여)
    gender = rng.randint(0, 3, size=n).astype("int64")

    # 앱가입건수: 포아송
    app_join_cnt = rng.poisson(lam=2.0, size=n).astype("float32")

    # 앱마지막접속경과일: 지수분포 (최근일수록 작게)
    last_login_days = rng.exponential(scale=10.0, size=n).astype("float32")

    # OS: 4개 카테고리 (0: Android, 1: iOS, 2: Web, 3: 기타)
    os_cat = rng.randint(0, 4, size=n).astype("int64")

    # 앱가입경과일: 평균 200일 정도로
    app_join_days = rng.exponential(scale=200.0, size=n).astype("float32")

    # 점수형(0~1 근처의 score 형태로)
    biz_owner_score = (rng.binomial(1, 0.3, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")
    home_owner_score = (rng.binomial(1, 0.4, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")
    car_owner_score = (rng.binomial(1, 0.5, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")
    married_score   = (rng.binomial(1, 0.5, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")
    child_score     = (rng.binomial(1, 0.4, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")

    # --- 2) 라벨 생성용 latent score ---
    w_gender = rng.uniform(-0.3, 0.3, size=3)   # gender 0/1/2
    w_os     = rng.uniform(-0.2, 0.2, size=4)   # os 0~3

    score = np.zeros(n, dtype="float32")

    score += 0.05 * (age - 40) / 10.0
    score += 0.25 * app_join_cnt
    score -= 0.03 * last_login_days
    score -= 0.002 * app_join_days
    score += 0.8 * biz_owner_score
    score += 0.6 * home_owner_score
    score += 0.7 * car_owner_score
    score += 0.9 * married_score
    score += 1.0 * child_score

    score += w_gender[gender]
    score += w_os[os_cat]

    noise = rng.normal(scale=0.5, size=n).astype("float32")
    bias = -0.2
    logit = score + bias + noise
    prob = 1.0 / (1.0 + np.exp(-logit))
    y = (prob > 0.5).astype("int64")

    # --- 3) 최종 X 매트릭스 ---
    X = np.column_stack(
        [
            age,                        # 0
            gender.astype("float32"),   # 1
            app_join_cnt,               # 2
            last_login_days,            # 3
            os_cat.astype("float32"),   # 4
            app_join_days,              # 5
            biz_owner_score,            # 6
            home_owner_score,           # 7
            car_owner_score,            # 8
            married_score,              # 9
            child_score,                # 10
        ]
    ).astype("float32")

    cat_feature_indices = [1, 4]   # 성별, OS
    cat_dims = [3, 4]              # gender:3, os:4

    return X, y, cat_feature_indices, cat_dims


# -----------------------------------------------------------
# TabTransformer 데모
# -----------------------------------------------------------
def demo_train_tabtransformer():
    print("\n===== TabTransformer Demo (KOR Feature Schema) =====")

    X, y, cat_features, cat_dims = make_kor_feature_demo_data(
        n_samples=2000,
        seed=123,
    )

    print("X shape:", X.shape, " | y shape:", y.shape)
    print("Categorical feature indices:", cat_features)
    print("Categorical dims (cardinality):", cat_dims)

    # train / val / test split
    N = X.shape[0]
    idx = np.arange(N)
    np.random.shuffle(idx)

    tr_end = int(N * 0.7)
    va_end = int(N * 0.85)
    tr_idx, va_idx, te_idx = idx[:tr_end], idx[tr_end:va_end], idx[va_end:]

    X_tr, y_tr = X[tr_idx], y[tr_idx]
    X_va, y_va = X[va_idx], y[va_idx]
    X_te, y_te = X[te_idx], y[te_idx]

    sample_weight = np.ones_like(y_tr, dtype="float32")

    input_dim = X.shape[1]

    model_params = {
        "input_dim": input_dim,
        "cat_features": cat_features,
        "cat_dims": cat_dims,
        "d_token": 32,
        "n_heads": 4,
        "n_layers": 2,
        "dim_feedforward": None,   # None이면 4*d_token
        "attn_dropout": 0.1,
        "embedding_dropout": 0.05,
        "add_cls": False,
        "pooling": "concat",
        "cont_proj": "linear",
        "mlp_hidden_dims": (128, 64),
        "mlp_dropout": 0.2,
        "use_missing_category": False,
        "lr": 1e-3,
        "loss_fn": "logloss",
    }

    clf = DeepLearningBinaryClassifier(
        model_type="tabtransformer",
        model_params=model_params,
    )

    clf.fit(
        X_tr,
        y_tr,
        sample_weight=sample_weight,
        eval_set=[(X_va, y_va)],
        eval_metric=["logloss"],
        max_epochs=5,      # 데모용
        patience=2,
        batch_size=256,
        verbose=True,
    )

    probs_te = clf.predict_proba(X_te)[:, 1]
    preds_te = (probs_te >= 0.5).astype("int64")

    acc = (preds_te == y_te).mean()
    auc = roc_auc_score(y_te, probs_te)
    ll = log_loss(y_te, probs_te)

    print("\n===== TabTransformer Test Metrics =====")
    print(f"Accuracy : {acc:.4f}")
    print(f"ROC-AUC  : {auc:.4f}")
    print(f"Logloss  : {ll:.4f}")
    print("Sample probs (first 10):", np.round(probs_te[:10], 4))


# -----------------------------------------------------------
# MLP 데모
# -----------------------------------------------------------
def demo_train_mlp():
    print("\n===== MLP Demo (KOR Feature Schema) =====")

    X, y, cat_features, cat_dims = make_kor_feature_demo_data(
        n_samples=2000,
        seed=321,
    )

    print("X shape:", X.shape, " | y shape:", y.shape)
    print("Categorical feature indices:", cat_features)
    print("Categorical dims (cardinality):", cat_dims)

    N = X.shape[0]
    idx = np.arange(N)
    np.random.shuffle(idx)

    tr_end = int(N * 0.7)
    va_end = int(N * 0.85)
    tr_idx, va_idx, te_idx = idx[:tr_end], idx[tr_end:va_end], idx[va_end:]

    X_tr, y_tr = X[tr_idx], y[tr_idx]
    X_va, y_va = X[va_idx], y[va_idx]
    X_te, y_te = X[te_idx], y[te_idx]

    sample_weight = np.ones_like(y_tr, dtype="float32")

    input_dim = X.shape[1]

    model_params = {
        "input_dim": input_dim,
        "hidden_dims": [128, 64],
        "cat_features": cat_features,
        "cat_dims": cat_dims,
        "emb_dim": 8,
        "lr": 1e-3,
        "loss_fn": "logloss",
    }

    clf = DeepLearningBinaryClassifier(
        model_type="mlp",
        model_params=model_params,
    )

    clf.fit(
        X_tr,
        y_tr,
        sample_weight=sample_weight,
        eval_set=[(X_va, y_va)],
        eval_metric=["logloss"],
        max_epochs=5,      # 데모용
        patience=2,
        batch_size=256,
        verbose=True,
    )

    probs_te = clf.predict_proba(X_te)[:, 1]
    preds_te = (probs_te >= 0.5).astype("int64")

    acc = (preds_te == y_te).mean()
    auc = roc_auc_score(y_te, probs_te)
    ll = log_loss(y_te, probs_te)

    print("\n===== MLP Test Metrics =====")
    print(f"Accuracy : {acc:.4f}")
    print(f"ROC-AUC  : {auc:.4f}")
    print(f"Logloss  : {ll:.4f}")
    print("Sample probs (first 10):", np.round(probs_te[:10], 4))


# -----------------------------------------------------------
# FT-Transformer 데모
# -----------------------------------------------------------
def demo_train_fttransformer():
    print("\n===== FT-Transformer Demo (KOR Feature Schema) =====")

    X, y, cat_features, cat_dims = make_kor_feature_demo_data(
        n_samples=2000,
        seed=456,
    )

    print("X shape:", X.shape, " | y shape:", y.shape)
    print("Categorical feature indices:", cat_features)
    print("Categorical dims (cardinality):", cat_dims)

    # train / val / test split (70 / 15 / 15)
    N = X.shape[0]
    idx = np.arange(N)
    np.random.shuffle(idx)

    tr_end = int(N * 0.7)
    va_end = int(N * 0.85)
    tr_idx, va_idx, te_idx = idx[:tr_end], idx[tr_end:va_end], idx[va_end:]

    X_tr, y_tr = X[tr_idx], y[tr_idx]
    X_va, y_va = X[va_idx], y[va_idx]
    X_te, y_te = X[te_idx], y[te_idx]

    sample_weight = np.ones_like(y_tr, dtype="float32")

    input_dim = X.shape[1]

    # FT-Transformer용 하이퍼파라미터
    model_params = {
        "input_dim": input_dim,
        "cat_features": cat_features,
        "cat_dims": cat_dims,
        "d_token": 32,
        "n_heads": 4,
        "n_layers": 3,
        "dim_feedforward": None,  # None이면 4 * d_token
        "attn_dropout": 0.1,
        "token_dropout": 0.05,
        "mlp_dropout": 0.2,
        "lr": 1e-3,
        "loss_fn": "logloss",
    }

    clf = DeepLearningBinaryClassifier(
        model_type="fttransformer",
        model_params=model_params,
    )

    clf.fit(
        X_tr,
        y_tr,
        sample_weight=sample_weight,
        eval_set=[(X_va, y_va)],
        eval_metric=["logloss"],
        max_epochs=5,      # 데모용
        patience=2,
        batch_size=256,
        verbose=True,
    )

    # 평가
    probs_te = clf.predict_proba(X_te)[:, 1]
    preds_te = (probs_te >= 0.5).astype("int64")

    acc = (preds_te == y_te).mean()
    auc = roc_auc_score(y_te, probs_te)
    ll = log_loss(y_te, probs_te)

    print("\n===== FT-Transformer Test Metrics =====")
    print(f"Accuracy : {acc:.4f}")
    print(f"ROC-AUC  : {auc:.4f}")
    print(f"Logloss  : {ll:.4f}")
    print("Sample probs (first 10):", np.round(probs_te[:10], 4))


# -----------------------------------------------------------
# 메인
# -----------------------------------------------------------
if __name__ == "__main__":
    # demo_train_tabtransformer()
    # demo_train_mlp()
    demo_train_fttransformer()


===== FT-Transformer Demo (KOR Feature Schema) =====
X shape: (2000, 11)  | y shape: (2000,)
Categorical feature indices: [1, 4]
Categorical dims (cardinality): [3, 4]




Epoch 1/5 - [train] loss: 140.687468
  - [eval] logloss: 0.5333
    -- (early_stopping) current_metric: 0.533306, best_metric: inf
Epoch 2/5 - [train] loss: 128.496629
  - [eval] logloss: 0.5203
    -- (early_stopping) current_metric: 0.520328, best_metric: 0.533306
Epoch 3/5 - [train] loss: 126.481900
  - [eval] logloss: 0.4843
    -- (early_stopping) current_metric: 0.484252, best_metric: 0.520328
Epoch 4/5 - [train] loss: 117.148365
  - [eval] logloss: 0.4436
    -- (early_stopping) current_metric: 0.443629, best_metric: 0.484252
Epoch 5/5 - [train] loss: 107.854777
  - [eval] logloss: 0.4695
    -- (early_stopping) current_metric: 0.469536, best_metric: 0.443629

===== FT-Transformer Test Metrics =====
Accuracy : 0.7200
ROC-AUC  : 0.7987
Logloss  : 0.4887
Sample probs (first 10): [0.8738 0.6252 0.3958 0.8739 0.3938 0.4466 0.4032 0.8745 0.3937 0.873 ]
