In [None]:
import math
from typing import Optional, List

import numpy as np
import torch
import torch.nn as nn


class TabularTextTransformerModel(BaseNNModel):
    """
    Tabular-Text Transformer (TTT) 단일 클래스 버전.
    - 논문/공식 코드 구조를 최대한 따르면서,
    - BaseNNModel을 상속해 DeepLearningBinaryClassifier에서 바로 쓸 수 있게 래핑.

    가정:
        X shape: (B, input_dim_total)
        앞쪽 input_dim_tab 열: tabular (수치 + 범주)
        뒤쪽 text_seq_len 열: text token id (정수, float32여도 long으로 캐스팅)

        num_features, cat_features 인덱스는 [0 .. input_dim_tab-1] 범위 안에서 사용.

    forward:
        logits = model(X)   # (B, 1)  (binary classification logit)
    """

    def __init__(
        self,
        # ---- 전체 입력 구조 ----
        input_dim_total: int,
        input_dim_tab: int,           # 앞쪽 tabular feature 개수
        num_features: List[int],      # tab 안에서 수치형 인덱스
        cat_features: List[int],      # tab 안에서 범주형 인덱스
        cat_dims: List[int],          # 각 범주형 cardinality
        # ---- 텍스트 쪽 설정 ----
        text_vocab_size: int,
        text_pad_idx: int,
        text_seq_len: int,
        # ---- 공통 Transformer 설정 ----
        d_model: int = 64,
        n_heads: int = 4,
        n_layers_overall: int = 2,
        n_layers_self: int = 2,
        d_ff: Optional[int] = None,
        dropout: float = 0.1,
        d_fc: int = 128,
        # ---- 수치형 encoding 설정 ----
        numeric_embedding: str = "dq",  # "dq" (distance-to-quantile) or "linear"
        num_quantiles: int = 6,
        # quantiles: (n_num_features, num_quantiles) float32 or np.ndarray
        quantiles: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        # ===== 입력 관련 설정 =====
        self.input_dim_total = input_dim_total
        self.input_dim_tab = input_dim_tab
        self.text_seq_len = text_seq_len

        assert input_dim_tab + text_seq_len == input_dim_total, \
            "input_dim_tab + text_seq_len != input_dim_total"

        self.num_features = num_features
        self.cat_features = cat_features
        self.cat_dims = cat_dims

        self.n_num = len(num_features)
        self.n_cat = len(cat_features)

        # tab 인덱스를 buffer로 등록 (0~input_dim_tab-1 내부 기준)
        num_idx = torch.tensor(num_features, dtype=torch.long)
        cat_idx = torch.tensor(cat_features, dtype=torch.long)
        self.register_buffer("num_idx_tensor", num_idx, persistent=False)
        self.register_buffer("cat_idx_tensor", cat_idx, persistent=False)

        # ===== 하이퍼파라미터 저장 =====
        self.text_vocab_size = text_vocab_size
        self.text_pad_idx = text_pad_idx
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers_overall = n_layers_overall
        self.n_layers_self = n_layers_self
        self.d_ff = d_ff or (4 * d_model)
        self.dropout_p = dropout
        self.d_fc = d_fc

        self.numeric_embedding = numeric_embedding
        self.num_quantiles = num_quantiles
        self.n_classes = 1  # binary logit 하나

        # ===== 네트워크 구성 =====
        self.build_network(quantiles)

    # ------------------------------------------------------------------
    # Sinusoidal Positional Encoding (텍스트용, 별도 클래스 없이 내부 구현)
    # ------------------------------------------------------------------
    def _build_text_positional_encoding(self, max_len: int):
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (L, 1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / self.d_model)
        )  # (d/2,)

        pe = torch.zeros(max_len, self.d_model, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, L, d)
        self.register_buffer("pe_text", pe, persistent=False)

    def _add_text_positional_encoding(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, L, d_model)
        """
        L = x.size(1)
        return x + self.pe_text[:, :L, :]

    # ------------------------------------------------------------------
    # 네트워크 구성 (OverallAttentionBlock / Core 구조 포함)
    # ------------------------------------------------------------------
    def build_network(self, quantiles: Optional[torch.Tensor]) -> nn.Module:
        dropout = self.dropout_p
        d_model = self.d_model

        # ---------- 텍스트 임베딩 ----------
        self.text_embedding = nn.Embedding(
            num_embeddings=self.text_vocab_size,
            embedding_dim=d_model,
            padding_idx=self.text_pad_idx,
        )
        nn.init.normal_(self.text_embedding.weight, std=0.02)
        self.text_dropout = nn.Dropout(dropout)

        # [CLS]_text / [CLS]_tab
        self.cls_text = nn.Parameter(torch.zeros(1, 1, d_model))
        self.cls_tab = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls_text, std=0.02)
        nn.init.normal_(self.cls_tab, std=0.02)

        # 텍스트 positional encoding
        self._build_text_positional_encoding(self.text_seq_len + 1)

        # ---------- 범주형 임베딩 ----------
        if self.n_cat > 0:
            self.cat_embeddings = nn.ModuleList()
            for c in self.cat_dims:
                emb = nn.Embedding(num_embeddings=c, embedding_dim=d_model)
                nn.init.normal_(emb.weight, std=0.02)
                self.cat_embeddings.append(emb)
            self.cat_dropout = nn.Dropout(dropout)
        else:
            self.cat_embeddings = nn.ModuleList()
            self.cat_dropout = nn.Identity()

        # ---------- 수치형 encoding ----------
        self.use_dq = (self.numeric_embedding == "dq") and (self.n_num > 0)

        if self.n_num > 0:
            if self.use_dq:
                # distance-to-quantile 용 quantiles 등록
                if quantiles is None:
                    raise ValueError(
                        "numeric_embedding='dq' 인 경우 quantiles 텐서를 넣어주세요."
                    )
                if isinstance(quantiles, np.ndarray):
                    quantiles = torch.from_numpy(quantiles.astype("float32"))

                assert quantiles.shape == (
                    self.n_num,
                    self.num_quantiles,
                ), "quantiles shape must be (n_num_features, num_quantiles)"

                self.register_buffer(
                    "num_quantiles_tensor",
                    quantiles.clone().detach().float(),
                    persistent=False,
                )
                # S_{j,k} embedding
                self.num_quantile_embeddings = nn.Parameter(
                    torch.randn(self.n_num, self.num_quantiles, d_model) * 0.02
                )
            else:
                # feature별 Linear(1 -> d_model)
                self.num_linears = nn.ModuleList(
                    [nn.Linear(1, d_model) for _ in range(self.n_num)]
                )
                for lin in self.num_linears:
                    nn.init.zeros_(lin.bias)
                    nn.init.kaiming_uniform_(lin.weight)
                # dummy buffers (사용 안 함)
                self.register_buffer(
                    "num_quantiles_tensor",
                    torch.zeros(0, self.num_quantiles),
                    persistent=False,
                )
                self.num_quantile_embeddings = None
        else:
            self.register_buffer(
                "num_quantiles_tensor",
                torch.zeros(0, self.num_quantiles),
                persistent=False,
            )
            self.num_quantile_embeddings = None
            self.num_linears = nn.ModuleList()

        # ---------- Overall Attention blocks (각 레이어별 구성) ----------
        # 따로 OverallAttentionBlock 클래스를 만들지 않고, 모듈리스트에 구성 요소 저장
        self.over_text_attn = nn.ModuleList()
        self.over_text_ln_q = nn.ModuleList()
        self.over_text_ln_kv = nn.ModuleList()
        self.over_text_ln_ff = nn.ModuleList()
        self.over_text_ffn = nn.ModuleList()

        self.over_tab_attn = nn.ModuleList()
        self.over_tab_ln_q = nn.ModuleList()
        self.over_tab_ln_kv = nn.ModuleList()
        self.over_tab_ln_ff = nn.ModuleList()
        self.over_tab_ffn = nn.ModuleList()

        for _ in range(self.n_layers_overall):
            # text stream
            self.over_text_attn.append(
                nn.MultiheadAttention(
                    embed_dim=d_model,
                    num_heads=self.n_heads,
                    dropout=dropout,
                    batch_first=True,
                )
            )
            self.over_text_ln_q.append(nn.LayerNorm(d_model))
            self.over_text_ln_kv.append(nn.LayerNorm(d_model))
            self.over_text_ln_ff.append(nn.LayerNorm(d_model))
            self.over_text_ffn.append(
                nn.Sequential(
                    nn.Linear(d_model, self.d_ff),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(self.d_ff, d_model),
                )
            )

            # tab stream
            self.over_tab_attn.append(
                nn.MultiheadAttention(
                    embed_dim=d_model,
                    num_heads=self.n_heads,
                    dropout=dropout,
                    batch_first=True,
                )
            )
            self.over_tab_ln_q.append(nn.LayerNorm(d_model))
            self.over_tab_ln_kv.append(nn.LayerNorm(d_model))
            self.over_tab_ln_ff.append(nn.LayerNorm(d_model))
            self.over_tab_ffn.append(
                nn.Sequential(
                    nn.Linear(d_model, self.d_ff),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(self.d_ff, d_model),
                )
            )

        self.over_dropout_attn = nn.Dropout(dropout)
        self.over_dropout_ffn = nn.Dropout(dropout)

        # ---------- Self-Attention encoders ----------
        encoder_layer_text = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=self.n_heads,
            dim_feedforward=self.d_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.text_transformer_encoder = nn.TransformerEncoder(
            encoder_layer_text,
            num_layers=self.n_layers_self,
        )

        encoder_layer_tab = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=self.n_heads,
            dim_feedforward=self.d_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.tab_transformer_encoder = nn.TransformerEncoder(
            encoder_layer_tab,
            num_layers=self.n_layers_self,
        )

        # ---------- 최종 FC Head ----------
        self.fc = nn.Sequential(
            nn.Linear(2 * d_model, self.d_fc),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(self.d_fc, self.n_classes),
        )
        nn.init.zeros_(self.fc[0].bias)
        nn.init.kaiming_uniform_(self.fc[0].weight)
        nn.init.zeros_(self.fc[3].bias)
        nn.init.kaiming_uniform_(self.fc[3].weight)

        return self

    # ------------------------------------------------------------------
    # 수치형 distance-to-quantile embedding
    # ------------------------------------------------------------------
    def _embed_numerical_dq(self, x_num: torch.Tensor) -> torch.Tensor:
        """
        x_num: (B, n_num)
        return: (B, n_num, d_model)
        """
        B = x_num.size(0)
        if self.n_num == 0:
            return torch.zeros(
                B, 0, self.d_model, device=x_num.device, dtype=torch.float32
            )

        v = x_num.unsqueeze(-1)  # (B, n_num, 1)
        q = self.num_quantiles_tensor.unsqueeze(0)  # (1, n_num, s)

        dist = torch.abs(v - q)  # (B, n_num, s)
        eps = 1e-8

        eq_mask = dist < 1e-6              # (B, n_num, s)
        has_eq = eq_mask.any(dim=-1, keepdim=True)  # (B, n_num, 1)

        inv_dist = 1.0 / (dist + eps)
        weights = torch.where(has_eq, eq_mask.float(), inv_dist)
        weights = weights / (weights.sum(dim=-1, keepdim=True) + eps)

        S = self.num_quantile_embeddings.unsqueeze(0)  # (1, n_num, s, d)
        emb = (weights.unsqueeze(-1) * S).sum(dim=2)   # (B, n_num, d)
        return emb

    # ------------------------------------------------------------------
    # OverallAttentionBlock 한 층 수행 (text/tab 공용 내부 함수)
    # ------------------------------------------------------------------
    def _overall_block(
        self,
        x_q: torch.Tensor,
        x_kv: torch.Tensor,
        attn: nn.MultiheadAttention,
        ln_q: nn.LayerNorm,
        ln_kv: nn.LayerNorm,
        ln_ff: nn.LayerNorm,
        ffn: nn.Sequential,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        x_q: query 시퀀스 (B, T_q, d)
        x_kv: key/value 시퀀스 (B, T_kv, d)
        """
        q = ln_q(x_q)
        kv = ln_kv(x_kv)

        attn_out, _ = attn(
            q, kv, kv, key_padding_mask=key_padding_mask
        )  # (B, T_q, d)

        x = x_q + self.over_dropout_attn(attn_out)  # residual 1

        y = ln_ff(x)
        y = ffn(y)
        out = x + self.over_dropout_ffn(y)          # residual 2
        return out

    # ------------------------------------------------------------------
    # forward
    # ------------------------------------------------------------------
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim_total) float32
        1) 앞쪽 input_dim_tab -> tab (수치/범주)
        2) 나머지 text_seq_len -> text 토큰
        반환: logits (B, 1)
        """
        B, D = x.shape
        device = x.device

        # ===== 1. tab / text split =====
        x_tab = x[:, : self.input_dim_tab]  # (B, input_dim_tab)
        x_text = x[:, self.input_dim_tab : self.input_dim_tab + self.text_seq_len]  # (B, L)

        # --- tab: numerical / categorical ---
        if self.n_num > 0:
            x_num = x_tab.index_select(1, self.num_idx_tensor).float()  # (B, n_num)
        else:
            x_num = None

        if self.n_cat > 0:
            x_cat = x_tab.index_select(1, self.cat_idx_tensor).long()   # (B, n_cat)
        else:
            x_cat = None

        # --- text: long + padding mask ---
        x_text_long = x_text.long()                     # (B, L)
        text_padding_mask = x_text_long.eq(self.text_pad_idx)  # (B, L), True=PAD

        # ===== 2. 텍스트 임베딩 =====
        z_text = self.text_embedding(x_text_long)    # (B, L, d)
        z_text = self.text_dropout(z_text)

        cls_text = self.cls_text.expand(B, 1, -1)    # (B,1,d)
        z_text = torch.cat([cls_text, z_text], dim=1) * math.sqrt(self.d_model)  # (B, L+1, d)
        z_text = self._add_text_positional_encoding(z_text)

        # 텍스트 key_padding_mask (CLS 앞에 False)
        text_kpm = torch.cat(
            [
                torch.zeros(B, 1, dtype=torch.bool, device=device),
                text_padding_mask,
            ],
            dim=1,
        )  # (B, L+1)

        # ===== 3. 탭 임베딩 =====
        # --- 범주형 ---
        if self.n_cat > 0 and x_cat is not None:
            cat_emb_list = [emb(x_cat[:, j]) for j, emb in enumerate(self.cat_embeddings)]  # (B,d)
            z_cat = torch.stack(cat_emb_list, dim=1)  # (B, n_cat, d)
            z_cat = self.cat_dropout(z_cat)
        else:
            z_cat = torch.zeros(B, 0, self.d_model, device=device)

        # --- 수치형 ---
        if self.n_num > 0 and x_num is not None:
            if self.use_dq:
                z_num = self._embed_numerical_dq(x_num)  # (B, n_num, d)
            else:
                num_emb_list = []
                for j, lin in enumerate(self.num_linears):
                    num_emb_list.append(
                        lin(x_num[:, j].view(B, 1, 1))
                    )  # (B,1,d)
                z_num = torch.cat(num_emb_list, dim=1)  # (B, n_num, d)
        else:
            z_num = torch.zeros(B, 0, self.d_model, device=device)

        # --- 탭 concat + [CLS]_tab ---
        z_tab_feats = torch.cat([z_cat, z_num], dim=1)  # (B, t_tab, d)
        cls_tab = self.cls_tab.expand(B, 1, -1)         # (B,1,d)
        z_tab = torch.cat([cls_tab, z_tab_feats], dim=1) * math.sqrt(self.d_model)

        # padding 개념 없음
        tab_kpm = None

        # 초기 상태 저장 (cross-modal key용)
        z_text_init = z_text
        z_tab_init = z_tab

        # ===== 4. Dual Overall Attention (L_overall 층) =====
        for i in range(self.n_layers_overall):
            # 텍스트 스트림
            kv_text = torch.cat([z_text, z_tab_init], dim=1)
            z_text = self._overall_block(
                x_q=z_text,
                x_kv=kv_text,
                attn=self.over_text_attn[i],
                ln_q=self.over_text_ln_q[i],
                ln_kv=self.over_text_ln_kv[i],
                ln_ff=self.over_text_ln_ff[i],
                ffn=self.over_text_ffn[i],
                key_padding_mask=None,  # 간단 버전: 전체 사용
            )

            # 탭 스트림
            kv_tab = torch.cat([z_tab, z_text_init], dim=1)
            z_tab = self._overall_block(
                x_q=z_tab,
                x_kv=kv_tab,
                attn=self.over_tab_attn[i],
                ln_q=self.over_tab_ln_q[i],
                ln_kv=self.over_tab_ln_kv[i],
                ln_ff=self.over_tab_ln_ff[i],
                ffn=self.over_tab_ffn[i],
                key_padding_mask=None,
            )

        # ===== 5. Self-Attention Encoders =====
        z_text = self.text_transformer_encoder(
            z_text, src_key_padding_mask=text_kpm
        )  # (B, L+1, d)
        z_tab = self.tab_transformer_encoder(z_tab)   # (B, t_tab+1, d)

        # [CLS] 추출 후 concat
        cls_text_out = z_text[:, 0, :]  # (B, d)
        cls_tab_out = z_tab[:, 0, :]    # (B, d)
        mm_feat = torch.cat([cls_text_out, cls_tab_out], dim=-1)  # (B, 2d)

        logits = self.fc(mm_feat)  # (B,1)
        if logits.ndim == 1:
            logits = logits.view(-1, 1)
        elif logits.size(1) != 1:
            logits = logits[:, :1]

        return logits

In [6]:
# -*- coding: utf-8 -*-
# ===========================================================
# Tabular-Text Transformer (TTT, text optional) + Demo
# ===========================================================

from __future__ import annotations

import math
from abc import ABC, abstractmethod
from typing import Optional, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import log_loss, roc_auc_score
from torch.utils.data import DataLoader, TensorDataset


# -----------------------------------------------------------
# Base Neural Network Model
# -----------------------------------------------------------
class BaseNNModel(nn.Module, ABC):
    def __init__(self, **kwargs):
        super().__init__()

    @abstractmethod
    def build_network(self) -> nn.Module:
        ...

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...


# -----------------------------------------------------------
# Tabular-Text Transformer (TTT) - text optional
# -----------------------------------------------------------
class TabularTextTransformerModel(BaseNNModel):
    """
    Tabular-Text Transformer (TTT) 단일 클래스 버전.
    - 탭 + (옵션) 텍스트를 함께 처리.
    - 텍스트가 없어도(text_seq_len=0) 돌아가도록 구현.

    입력 X 형태:
        X: (B, input_dim_total) float32

        * 앞쪽 input_dim_tab 열: tabular (수치 + 범주)
        * 뒤쪽 text_seq_len 열: text token id (정수, float32여도 long 캐스팅)

        num_features, cat_features 인덱스는 [0 .. input_dim_tab-1] 범위 안에서 사용.

    사용 예 (텍스트 O):
        input_dim_tab = 11
        text_seq_len = 16
        input_dim_total = 27 = 11 + 16

    사용 예 (텍스트 X):
        input_dim_tab = input_dim_total = 11
        text_seq_len = 0

    forward:
        logits = model(X)   # (B, 1)  (binary classification logit)
    """

    def __init__(
        self,
        # ---- 전체 입력 구조 ----
        input_dim_total: int,
        input_dim_tab: int,           # 앞쪽 tabular feature 개수
        num_features: List[int],      # tab 안에서 수치형 인덱스
        cat_features: List[int],      # tab 안에서 범주형 인덱스
        cat_dims: List[int],          # 각 범주형 cardinality
        # ---- 텍스트 쪽 설정 ----
        text_vocab_size: int,
        text_pad_idx: int,
        text_seq_len: int,            # 0이면 텍스트 없음
        # ---- 공통 Transformer 설정 ----
        d_model: int = 64,
        n_heads: int = 4,
        n_layers_overall: int = 2,
        n_layers_self: int = 2,
        d_ff: Optional[int] = None,
        dropout: float = 0.1,
        d_fc: int = 128,
        # ---- 수치형 encoding 설정 ----
        numeric_embedding: str = "dq",  # "dq" (distance-to-quantile) or "linear"
        num_quantiles: int = 6,
        # quantiles: (n_num_features, num_quantiles) float32 or np.ndarray
        quantiles: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        # ===== 입력 관련 설정 =====
        self.input_dim_total = input_dim_total
        self.input_dim_tab = input_dim_tab
        self.text_seq_len = text_seq_len

        # 텍스트가 없으면: input_dim_total == input_dim_tab, text_seq_len=0
        assert input_dim_tab + text_seq_len == input_dim_total, \
            "input_dim_tab + text_seq_len != input_dim_total"

        self.num_features = num_features
        self.cat_features = cat_features
        self.cat_dims = cat_dims

        self.n_num = len(num_features)
        self.n_cat = len(cat_features)

        # tab 인덱스를 buffer로 등록 (0~input_dim_tab-1 내부 기준)
        num_idx = torch.tensor(num_features, dtype=torch.long)
        cat_idx = torch.tensor(cat_features, dtype=torch.long)
        self.register_buffer("num_idx_tensor", num_idx, persistent=False)
        self.register_buffer("cat_idx_tensor", cat_idx, persistent=False)

        # ===== 하이퍼파라미터 저장 =====
        self.text_vocab_size = text_vocab_size
        self.text_pad_idx = text_pad_idx
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers_overall = n_layers_overall
        self.n_layers_self = n_layers_self
        self.d_ff = d_ff or (4 * d_model)
        self.dropout_p = dropout
        self.d_fc = d_fc

        self.numeric_embedding = numeric_embedding
        self.num_quantiles = num_quantiles
        self.n_classes = 1  # binary logit 하나
        self.scale = d_model ** 0.5

        # quantiles를 build_network에서 쓰기 위해 일단 저장
        self._init_quantiles = quantiles

        # ===== 네트워크 구성 =====
        self.build_network()

    # ------------------------------------------------------------------
    # Sinusoidal Positional Encoding (텍스트용)
    # ------------------------------------------------------------------
    def _build_text_positional_encoding(self, max_len: int):
        pe = torch.zeros(max_len, self.d_model, dtype=torch.float32)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe_text", pe, persistent=False)

    def _add_text_positional_encoding(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, L, d_model)
        """
        L = x.size(1)
        return x + self.pe_text[:, :L, :]

    # ------------------------------------------------------------------
    # 네트워크 구성 (OverallAttentionBlock / Core 구조 포함)
    # ------------------------------------------------------------------
    def build_network(self) -> nn.Module:
        dropout = self.dropout_p
        d_model = self.d_model

        # ---------- 텍스트 임베딩 ----------
        self.text_embedding = nn.Embedding(
            num_embeddings=self.text_vocab_size,
            embedding_dim=d_model,
            padding_idx=self.text_pad_idx,
        )
        nn.init.normal_(self.text_embedding.weight, std=0.02)
        self.text_dropout = nn.Dropout(dropout)

        # [CLS]_text / [CLS]_tab
        self.cls_text = nn.Parameter(torch.zeros(1, 1, d_model))
        self.cls_tab = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls_text, std=0.02)
        nn.init.normal_(self.cls_tab, std=0.02)

        # 텍스트 positional encoding
        max_len_for_pe = (self.text_seq_len + 1) if self.text_seq_len > 0 else 1
        self._build_text_positional_encoding(max_len_for_pe)

        # ---------- 범주형 임베딩 ----------
        if self.n_cat > 0:
            self.cat_embeddings = nn.ModuleList()
            for c in self.cat_dims:
                emb = nn.Embedding(num_embeddings=c, embedding_dim=d_model)
                nn.init.normal_(emb.weight, std=0.02)
                self.cat_embeddings.append(emb)
            self.cat_dropout = nn.Dropout(dropout)
        else:
            self.cat_embeddings = nn.ModuleList()
            self.cat_dropout = nn.Identity()

        # ---------- 수치형 encoding ----------
        self.use_dq = (self.numeric_embedding == "dq") and (self.n_num > 0)

        quantiles = self._init_quantiles

        if self.n_num > 0:
            if self.use_dq:
                # distance-to-quantile 용 quantiles 등록
                if quantiles is None:
                    raise ValueError(
                        "numeric_embedding='dq' 인 경우 quantiles 텐서를 넣어주세요."
                    )
                if isinstance(quantiles, np.ndarray):
                    quantiles = torch.from_numpy(quantiles.astype("float32"))

                assert quantiles.shape == (
                    self.n_num,
                    self.num_quantiles,
                ), "quantiles shape must be (n_num_features, num_quantiles)"

                self.register_buffer(
                    "num_quantiles_tensor",
                    quantiles.clone().detach().float(),
                    persistent=False,
                )
                # S_{j,k} embedding
                self.num_quantile_embeddings = nn.Parameter(
                    torch.randn(self.n_num, self.num_quantiles, d_model) * 0.02
                )
                # linear path는 사용 안 함
                self.num_linears = nn.ModuleList()
            else:
                # feature별 Linear(1 -> d_model)
                self.num_linears = nn.ModuleList(
                    [nn.Linear(1, d_model) for _ in range(self.n_num)]
                )
                for lin in self.num_linears:
                    nn.init.zeros_(lin.bias)
                    nn.init.kaiming_uniform_(lin.weight)
                # dummy buffers (DQ용은 사용 안 함)
                self.register_buffer(
                    "num_quantiles_tensor",
                    torch.zeros(0, self.num_quantiles),
                    persistent=False,
                )
                self.num_quantile_embeddings = None
        else:
            self.register_buffer(
                "num_quantiles_tensor",
                torch.zeros(0, self.num_quantiles),
                persistent=False,
            )
            self.num_quantile_embeddings = None
            self.num_linears = nn.ModuleList()

        # ---------- Overall Attention blocks ----------
        self.over_text_attn = nn.ModuleList()
        self.over_text_ln_q = nn.ModuleList()
        self.over_text_ln_kv = nn.ModuleList()
        self.over_text_ln_ff = nn.ModuleList()
        self.over_text_ffn = nn.ModuleList()

        self.over_tab_attn = nn.ModuleList()
        self.over_tab_ln_q = nn.ModuleList()
        self.over_tab_ln_kv = nn.ModuleList()
        self.over_tab_ln_ff = nn.ModuleList()
        self.over_tab_ffn = nn.ModuleList()

        for _ in range(self.n_layers_overall):
            # text stream
            self.over_text_attn.append(
                nn.MultiheadAttention(
                    embed_dim=d_model,
                    num_heads=self.n_heads,
                    dropout=dropout,
                    batch_first=True,
                )
            )
            self.over_text_ln_q.append(nn.LayerNorm(d_model))
            self.over_text_ln_kv.append(nn.LayerNorm(d_model))
            self.over_text_ln_ff.append(nn.LayerNorm(d_model))
            self.over_text_ffn.append(
                nn.Sequential(
                    nn.Linear(d_model, self.d_ff),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(self.d_ff, d_model),
                )
            )

            # tab stream
            self.over_tab_attn.append(
                nn.MultiheadAttention(
                    embed_dim=d_model,
                    num_heads=self.n_heads,
                    dropout=dropout,
                    batch_first=True,
                )
            )
            self.over_tab_ln_q.append(nn.LayerNorm(d_model))
            self.over_tab_ln_kv.append(nn.LayerNorm(d_model))
            self.over_tab_ln_ff.append(nn.LayerNorm(d_model))
            self.over_tab_ffn.append(
                nn.Sequential(
                    nn.Linear(d_model, self.d_ff),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(self.d_ff, d_model),
                )
            )

        self.over_dropout_attn = nn.Dropout(dropout)
        self.over_dropout_ffn = nn.Dropout(dropout)

        # ---------- Self-Attention encoders ----------
        encoder_layer_text = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=self.n_heads,
            dim_feedforward=self.d_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.text_transformer_encoder = nn.TransformerEncoder(
            encoder_layer_text,
            num_layers=self.n_layers_self,
        )

        encoder_layer_tab = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=self.n_heads,
            dim_feedforward=self.d_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.tab_transformer_encoder = nn.TransformerEncoder(
            encoder_layer_tab,
            num_layers=self.n_layers_self,
        )

        # ---------- 최종 FC Head ----------
        self.fc = nn.Sequential(
            nn.Linear(2 * d_model, self.d_fc),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(self.d_fc, self.n_classes),
        )
        nn.init.zeros_(self.fc[0].bias)
        nn.init.kaiming_uniform_(self.fc[0].weight)
        nn.init.zeros_(self.fc[3].bias)
        nn.init.kaiming_uniform_(self.fc[3].weight)

        return self

    # ------------------------------------------------------------------
    # 수치형 distance-to-quantile embedding
    # ------------------------------------------------------------------
    def _embed_numerical_dq(self, x_num: torch.Tensor) -> torch.Tensor:
        """
        x_num: (B, n_num)
        return: (B, n_num, d_model)
        """
        B = x_num.size(0)
        if self.n_num == 0:
            return torch.zeros(B, 0, self.d_model, device=x_num.device)
        v = x_num.unsqueeze(-1)
        q = self.num_quantiles_tensor.unsqueeze(0)
        dist = torch.abs(v - q)
        eps = 1e-8
        eq_mask = dist < 1e-6
        has_eq = eq_mask.any(dim=-1, keepdim=True)
        inv_dist = 1.0 / (dist + eps)
        weights = torch.where(has_eq, eq_mask.float(), inv_dist)
        weights = weights / (weights.sum(dim=-1, keepdim=True) + eps)
        S = self.num_quantile_embeddings.unsqueeze(0)
        emb = (weights.unsqueeze(-1) * S).sum(dim=2)
        return emb

    # ------------------------------------------------------------------
    # OverallAttentionBlock 한 층 수행 (text/tab 공용 내부 함수)
    # ------------------------------------------------------------------
    def _overall_block(
        self,
        x_q: torch.Tensor,
        x_kv: torch.Tensor,
        attn: nn.MultiheadAttention,
        ln_q: nn.LayerNorm,
        ln_kv: nn.LayerNorm,
        ln_ff: nn.LayerNorm,
        ffn: nn.Sequential,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        x_q: query 시퀀스 (B, T_q, d)
        x_kv: key/value 시퀀스 (B, T_kv, d)
        """
        q = ln_q(x_q)
        kv = ln_kv(x_kv)
        attn_out, _ = attn(q, kv, kv, key_padding_mask=key_padding_mask)
        x = x_q + self.over_dropout_attn(attn_out)
        y = ln_ff(x)
        y = ffn(y)
        out = x + self.over_dropout_ffn(y)
        return out

    # ------------------------------------------------------------------
    # forward
    # ------------------------------------------------------------------
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, input_dim_total) float32

        1) 앞쪽 input_dim_tab -> tab (수치/범주)
        2) 나머지 text_seq_len -> text 토큰 (text_seq_len>0일 때만)
        반환: logits (B, 1)
        """
        B, D = x.shape
        device = x.device

        # ===== 1. tab / text split =====
        x_tab = x[:, : self.input_dim_tab]  # (B, input_dim_tab)

        # --- 텍스트 처리 (옵션) ---
        if self.text_seq_len > 0:
            x_text = x[:, self.input_dim_tab : self.input_dim_tab + self.text_seq_len]  # (B, L)
            x_text_long = x_text.long()                     # (B, L)
            text_padding_mask = x_text_long.eq(self.text_pad_idx)  # (B, L), True=PAD

            # 텍스트 임베딩
            z_text = self.text_embedding(x_text_long)    # (B, L, d)
            z_text = self.text_dropout(z_text)

            cls_text = self.cls_text.expand(B, 1, -1)    # (B,1,d)
            z_text = torch.cat([cls_text, z_text], dim=1) * self.scale  # (B, L+1, d)
            z_text = self._add_text_positional_encoding(z_text)

            # 텍스트 key_padding_mask (CLS 앞에 False)
            text_kpm = torch.cat(
                [
                    torch.zeros(B, 1, dtype=torch.bool, device=device),
                    text_padding_mask,
                ],
                dim=1,
            )  # (B, L+1)
        else:
            # 텍스트 컬럼이 전혀 없는 경우:
            # z_text는 CLS 하나짜리 더미 시퀀스로 만든다.
            x_text_long = None
            text_kpm = torch.zeros(B, 1, dtype=torch.bool, device=device)  # (B,1)

            cls_text = self.cls_text.expand(B, 1, -1) * self.scale  # (B,1,d)
            # positional encoding도 길이 1짜리만 적용
            z_text = self._add_text_positional_encoding(cls_text)  # (B,1,d)

        # --- tab: numerical / categorical ---
        if self.n_num > 0:
            x_num = x_tab.index_select(1, self.num_idx_tensor).float()  # (B, n_num)
        else:
            x_num = None

        if self.n_cat > 0:
            x_cat = x_tab.index_select(1, self.cat_idx_tensor).long()   # (B, n_cat)
        else:
            x_cat = None

        # ===== 2. 탭 임베딩 =====
        # --- 범주형 ---
        if self.n_cat > 0 and x_cat is not None:
            cat_emb_list = [emb(x_cat[:, j]) for j, emb in enumerate(self.cat_embeddings)]
            z_cat = torch.stack(cat_emb_list, dim=1)  # (B, n_cat, d)
            z_cat = self.cat_dropout(z_cat)
        else:
            z_cat = torch.zeros(B, 0, self.d_model, device=device)

        # --- 수치형 ---
        if self.n_num > 0 and x_num is not None:
            if self.use_dq:
                z_num = self._embed_numerical_dq(x_num)  # (B, n_num, d)
            else:
                num_emb_list = []
                for j, lin in enumerate(self.num_linears):
                    num_emb_list.append(
                        lin(x_num[:, j].view(B, 1, 1))
                    )  # (B,1,d)
                z_num = torch.cat(num_emb_list, dim=1)  # (B, n_num, d)
        else:
            z_num = torch.zeros(B, 0, self.d_model, device=device)

        # --- 탭 concat + [CLS]_tab ---
        z_tab_feats = torch.cat([z_cat, z_num], dim=1)  # (B, t_tab, d)
        cls_tab = self.cls_tab.expand(B, 1, -1)         # (B,1,d)
        z_tab = torch.cat([cls_tab, z_tab_feats], dim=1) * self.scale  # (B, t_tab+1, d)

        # 초기 상태 저장 (cross-modal key용)
        z_text_init = z_text
        z_tab_init = z_tab

        # ===== 3. Dual Overall Attention (L_overall 층) =====
        for i in range(self.n_layers_overall):
            # 텍스트 스트림
            kv_text = torch.cat([z_text, z_tab_init], dim=1)
            z_text = self._overall_block(
                x_q=z_text,
                x_kv=kv_text,
                attn=self.over_text_attn[i],
                ln_q=self.over_text_ln_q[i],
                ln_kv=self.over_text_ln_kv[i],
                ln_ff=self.over_text_ln_ff[i],
                ffn=self.over_text_ffn[i],
                key_padding_mask=None,  # 간단 버전: 전체 사용
            )

            # 탭 스트림
            kv_tab = torch.cat([z_tab, z_text_init], dim=1)
            z_tab = self._overall_block(
                x_q=z_tab,
                x_kv=kv_tab,
                attn=self.over_tab_attn[i],
                ln_q=self.over_tab_ln_q[i],
                ln_kv=self.over_tab_ln_kv[i],
                ln_ff=self.over_tab_ln_ff[i],
                ffn=self.over_tab_ffn[i],
                key_padding_mask=None,
            )

        # ===== 4. Self-Attention Encoders =====
        z_text = self.text_transformer_encoder(
            z_text, src_key_padding_mask=text_kpm
        )  # (B, L_text, d)  (L_text = L+1 or 1)
        z_tab = self.tab_transformer_encoder(z_tab)   # (B, t_tab+1, d)

        # [CLS] 추출 후 concat
        cls_text_out = z_text[:, 0, :]  # (B, d)
        cls_tab_out = z_tab[:, 0, :]    # (B, d)
        mm_feat = torch.cat([cls_text_out, cls_tab_out], dim=-1)  # (B, 2d)

        logits = self.fc(mm_feat)  # (B,1)
        if logits.ndim == 1:
            logits = logits.view(-1, 1)
        elif logits.size(1) != 1:
            logits = logits[:, :1]

        return logits


# -----------------------------------------------------------
# Deep Learning Binary Classifier
# -----------------------------------------------------------
class DeepLearningBinaryClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        model_type: str = "tabtexttransformer",
        model_params: dict | None = None,
    ):
        self.model_type = model_type
        self.model_params = model_params or {}
        self.model = None

    @property
    def device(self) -> torch.device:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _build_model(self) -> BaseNNModel:
        model_registry = {
            "tabtexttransformer": TabularTextTransformerModel,
            # 필요 시 다른 모델들도 추가 가능
        }

        if self.model_type not in model_registry:
            raise ValueError(f"Unknown model type: {self.model_type}")

        # lr, loss_fn은 여기서 제거
        valid_params = {
            k: v for k, v in self.model_params.items() if k not in ["loss_fn", "lr"]
        }

        model_class = model_registry[self.model_type](**valid_params)
        return model_class

    def _get_loss_fn(self) -> nn.Module:
        loss_name = self.model_params.get("loss_fn", "logloss")
        if loss_name == "logloss":
            return nn.BCEWithLogitsLoss(reduction="none")
        else:
            raise ValueError(f"Unknown loss function: {loss_name}")

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        sample_weight: np.ndarray | None = None,
        eval_set: list[tuple[np.ndarray, np.ndarray]] | None = None,
        eval_metric: list[str] | None = None,
        max_epochs: int = 10,
        patience: int | None = None,
        batch_size: int = 128,
        verbose: bool = True,
    ) -> "DeepLearningBinaryClassifier":

        lr = self.model_params.get("lr", 0.001)
        eval_metric = eval_metric or ["logloss"]

        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
        y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1).to(self.device)

        if eval_set is not None:
            x_eval_tensor = torch.tensor(
                eval_set[0][0], dtype=torch.float32
            ).to(self.device)
            y_eval_true = eval_set[0][1]
        else:
            x_eval_tensor = None
            y_eval_true = None

        if sample_weight is not None:
            sample_weight_tensor = torch.tensor(
                sample_weight, dtype=torch.float32
            ).to(self.device)
        else:
            sample_weight_tensor = torch.ones_like(y_tensor, dtype=torch.float32)

        train_dataset = TensorDataset(X_tensor, y_tensor, sample_weight_tensor)
        train_dataloader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )

        if self.model is None:
            self.model = self._build_model().to(self.device)

        loss_fn = self._get_loss_fn()
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)

        patience_counter = 0
        best_metric = float("inf")
        best_model_weights = None

        for epoch in range(max_epochs):
            self.model.train()
            epoch_loss = 0.0
            n_batches = 0

            for x_batch, y_batch, weight_batch in train_dataloader:
                optimizer.zero_grad()

                y_pred_logits = self.model(x_batch)
                loss = loss_fn(y_pred_logits, y_batch)
                weighted_loss = (loss * weight_batch).sum() / weight_batch.sum()

                weighted_loss.backward()
                optimizer.step()

                epoch_loss += weighted_loss.item()
                n_batches += 1

            if verbose:
                print(
                    f"Epoch {epoch + 1}/{max_epochs} "
                    f"- [train] loss: {epoch_loss / max(1, n_batches):.6f}"
                )

            # evaluation
            if eval_set is not None:
                self.model.eval()
                with torch.no_grad():
                    y_eval_logits = self.model(x_eval_tensor)
                    y_eval_pred = torch.sigmoid(y_eval_logits).cpu().numpy().ravel()

                eval_metrics = {}
                for metric in eval_metric:
                    if metric == "logloss":
                        eval_metrics["logloss"] = log_loss(y_eval_true, y_eval_pred)
                    elif metric == "auc":
                        eval_metrics["auc"] = -roc_auc_score(y_eval_true, y_eval_pred)
                    else:
                        raise ValueError(f"Unknown metric: {metric}")

                if verbose:
                    metrics_str = ", ".join(
                        [f"{k}: {v:.4f}" for k, v in eval_metrics.items()]
                    )
                    print(f"  - [eval] {metrics_str}")

                # early stopping (기준 metric은 리스트의 첫 번째)
                main_metric_name = eval_metric[0]
                current_metric = eval_metrics.get(
                    main_metric_name, eval_metrics["logloss"]
                )

                if verbose:
                    print(
                        f"    -- (early_stopping) current_metric: {current_metric:.6f}, "
                        f"best_metric: {best_metric:.6f}"
                    )

                if current_metric < best_metric:
                    best_metric = current_metric
                    patience_counter = 0
                    best_model_weights = self.model.state_dict()
                else:
                    patience_counter += 1
                    if patience is not None and patience_counter >= patience:
                        if verbose:
                            print(f"Early stopping at epoch {epoch + 1}")
                        break

        if best_model_weights is not None:
            self.model.load_state_dict(best_model_weights)

        return self

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        self.model = self.model.to(self.device)
        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            self.model.eval()
            logits = self.model(X_tensor)
            probs1 = torch.sigmoid(logits).cpu().numpy()

        if probs1.shape[1] == 1:
            probs1 = probs1.reshape(-1, 1)

        probs0 = 1.0 - probs1
        probs = np.hstack((probs0, probs1))
        return probs.astype("float")

    def predict(self, X: np.ndarray) -> np.ndarray:
        probs = self.predict_proba(X)
        return probs.argmax(axis=1)


# -----------------------------------------------------------
# 한국형 피처 스키마(연령/성별/OS/점수형) 데모 데이터 생성 (탭만)
# -----------------------------------------------------------
def make_kor_feature_demo_data(
    n_samples: int = 5000,
    seed: int = 42,
):
    rng = np.random.RandomState(seed)
    n = n_samples

    # --- 1) 각 피처 생성 ---
    age = np.clip(rng.normal(loc=40, scale=12, size=n), 18, 80).astype("float32")
    gender = rng.randint(0, 3, size=n).astype("int64")
    app_join_cnt = rng.poisson(lam=2.0, size=n).astype("float32")
    last_login_days = rng.exponential(scale=10.0, size=n).astype("float32")
    os_cat = rng.randint(0, 4, size=n).astype("int64")
    app_join_days = rng.exponential(scale=200.0, size=n).astype("float32")

    biz_owner_score = (rng.binomial(1, 0.3, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")
    home_owner_score = (rng.binomial(1, 0.4, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")
    car_owner_score = (rng.binomial(1, 0.5, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")
    married_score   = (rng.binomial(1, 0.5, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")
    child_score     = (rng.binomial(1, 0.4, size=n) * rng.uniform(0.5, 1.0, size=n)).astype("float32")

    # --- 2) 라벨 생성용 latent score ---
    w_gender = rng.uniform(-0.3, 0.3, size=3)   # gender 0/1/2
    w_os     = rng.uniform(-0.2, 0.2, size=4)   # os 0~3

    score = np.zeros(n, dtype="float32")

    score += 0.05 * (age - 40) / 10.0
    score += 0.25 * app_join_cnt
    score -= 0.03 * last_login_days
    score -= 0.002 * app_join_days
    score += 0.8 * biz_owner_score
    score += 0.6 * home_owner_score
    score += 0.7 * car_owner_score
    score += 0.9 * married_score
    score += 1.0 * child_score

    score += w_gender[gender]
    score += w_os[os_cat]

    noise = rng.normal(scale=0.5, size=n).astype("float32")
    bias = -0.2
    logit = score + bias + noise
    prob = 1.0 / (1.0 + np.exp(-logit))
    y = (prob > 0.5).astype("int64")

    # --- 3) 최종 Tab X 매트릭스 ---
    X_tab = np.column_stack(
        [
            age,                        # 0
            gender.astype("float32"),   # 1
            app_join_cnt,               # 2
            last_login_days,            # 3
            os_cat.astype("float32"),   # 4
            app_join_days,              # 5
            biz_owner_score,            # 6
            home_owner_score,           # 7
            car_owner_score,            # 8
            married_score,              # 9
            child_score,                # 10
        ]
    ).astype("float32")

    cat_feature_indices = [1, 4]   # 성별, OS
    cat_dims = [3, 4]              # gender:3, os:4

    return X_tab, y, cat_feature_indices, cat_dims


# -----------------------------------------------------------
# 수치형 quantile 계산 유틸 (DQ embedding용)
# -----------------------------------------------------------
def compute_numeric_quantiles(
    X: np.ndarray,
    numeric_indices: List[int],
    num_quantiles: int = 6,
    quantile_grid: Optional[np.ndarray] = None,
) -> torch.Tensor:
    """
    X: feature matrix (N, D) - 여기서는 tab 부분만 넘기는 걸 추천
    numeric_indices: 수치형 feature column index 리스트
    num_quantiles: quantile 개수
    quantile_grid: np.ndarray shape (num_quantiles,) 0~1 사이 값 (None이면 균등)

    return: torch.Tensor (n_num_features, num_quantiles) float32
    """
    X_num = X[:, numeric_indices].astype("float32")
    n_num = X_num.shape[1]

    if quantile_grid is None:
        quantile_grid = np.linspace(0.0, 1.0, num_quantiles, dtype=np.float32)

    q_list = []
    for j in range(n_num):
        col = X_num[:, j]
        q_col = np.quantile(col, quantile_grid, method="linear").astype("float32")
        q_list.append(q_col)

    quantiles = np.stack(q_list, axis=0)  # (n_num, num_quantiles)
    return torch.from_numpy(quantiles)


# -----------------------------------------------------------
# Tabular-Text Transformer 데모 (텍스트 있는 버전)
# -----------------------------------------------------------
def demo_train_tabtexttransformer():
    print("\n===== Tabular-Text Transformer Demo (with text) =====")

    # 1) Tab 데이터 생성
    X_tab, y, cat_features, cat_dims = make_kor_feature_demo_data(
        n_samples=2000,
        seed=123,
    )

    print("Tab X shape:", X_tab.shape, "| y shape:", y.shape)
    print("Categorical feature indices:", cat_features)
    print("Categorical dims:", cat_dims)

    # 2) 텍스트 토큰 생성 (간단 랜덤 토큰)
    n_samples = X_tab.shape[0]
    text_seq_len = 16
    text_vocab_size = 50
    text_pad_idx = 0

    rng = np.random.RandomState(999)
    X_text = rng.randint(1, text_vocab_size, size=(n_samples, text_seq_len)).astype("int64")

    # 일부 위치를 pad(0)로 설정
    pad_mask = rng.rand(n_samples, text_seq_len) < 0.1  # 10% 정도 pad
    X_text[pad_mask] = text_pad_idx

    # 3) Tab + Text 합치기
    X_total = np.concatenate(
        [X_tab, X_text.astype("float32")],
        axis=1,
    ).astype("float32")

    input_dim_tab = X_tab.shape[1]
    input_dim_total = X_total.shape[1]

    # 4) train / val / test split
    N = X_total.shape[0]
    idx = np.arange(N)
    np.random.shuffle(idx)

    tr_end = int(N * 0.7)
    va_end = int(N * 0.85)
    tr_idx, va_idx, te_idx = idx[:tr_end], idx[tr_end:va_end], idx[va_end:]

    X_tr, y_tr = X_total[tr_idx], y[tr_idx]
    X_va, y_va = X_total[va_idx], y[va_idx]
    X_te, y_te = X_total[te_idx], y[te_idx]

    sample_weight = np.ones_like(y_tr, dtype="float32")

    # 5) 수치형/범주형 인덱스 (tab 부분 기준)
    num_features = [i for i in range(input_dim_tab) if i not in cat_features]

    # 6) DQ embedding용 quantiles (train tab 부분에서 계산)
    quantile_grid = np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], dtype=np.float32)
    X_tr_tab = X_tr[:, :input_dim_tab]  # tab 부분만
    quantiles = compute_numeric_quantiles(
        X_tr_tab,
        numeric_indices=num_features,
        num_quantiles=len(quantile_grid),
        quantile_grid=quantile_grid,
    )

    print("Numeric feature indices:", num_features)
    print("Quantiles shape:", quantiles.shape)

    # 7) 모델 파라미터 구성
    model_params = {
        "input_dim_total": input_dim_total,
        "input_dim_tab": input_dim_tab,
        "num_features": num_features,
        "cat_features": cat_features,
        "cat_dims": cat_dims,
        "text_vocab_size": text_vocab_size,
        "text_pad_idx": text_pad_idx,
        "text_seq_len": text_seq_len,   # ← 텍스트가 있으므로 >0
        "d_model": 64,
        "n_heads": 4,
        "n_layers_overall": 2,
        "n_layers_self": 2,
        "d_ff": None,          # None이면 4*d_model
        "dropout": 0.1,
        "d_fc": 128,
        "numeric_embedding": "dq",  # distance-to-quantile
        "num_quantiles": quantiles.shape[1],
        "quantiles": quantiles,
        "lr": 1e-3,
        "loss_fn": "logloss",
    }

    clf = DeepLearningBinaryClassifier(
        model_type="tabtexttransformer",
        model_params=model_params,
    )

    clf.fit(
        X_tr,
        y_tr,
        sample_weight=sample_weight,
        eval_set=[(X_va, y_va)],
        eval_metric=["logloss"],
        max_epochs=5,      # 데모용
        patience=2,
        batch_size=256,
        verbose=True,
    )

    # 8) 평가
    probs_te = clf.predict_proba(X_te)[:, 1]
    preds_te = (probs_te >= 0.5).astype("int64")

    acc = (preds_te == y_te).mean()
    auc = roc_auc_score(y_te, probs_te)
    ll = log_loss(y_te, probs_te)

    print("\n===== Tabular-Text Transformer Test Metrics =====")
    print(f"Accuracy : {acc:.4f}")
    print(f"ROC-AUC  : {auc:.4f}")
    print(f"Logloss  : {ll:.4f}")
    print("Sample probs (first 10):", np.round(probs_te[:10], 4))


# -----------------------------------------------------------
# 메인
# -----------------------------------------------------------
if __name__ == "__main__":
    demo_train_tabtexttransformer()
    # 텍스트 없이 쓰고 싶다면:
    #  - X_total 대신 X_tab만 사용
    #  - input_dim_total = input_dim_tab = X_tab.shape[1]
    #  - text_seq_len = 0
    #  - text_vocab_size는 1, text_pad_idx=0 정도로 두면 됩니다.


===== Tabular-Text Transformer Demo (with text) =====
Tab X shape: (2000, 11) | y shape: (2000,)
Categorical feature indices: [1, 4]
Categorical dims: [3, 4]
Numeric feature indices: [0, 2, 3, 5, 6, 7, 8, 9, 10]
Quantiles shape: torch.Size([9, 6])
Epoch 1/5 - [train] loss: 226.742945
  - [eval] logloss: 0.6107
    -- (early_stopping) current_metric: 0.610688, best_metric: inf
Epoch 2/5 - [train] loss: 130.932189
  - [eval] logloss: 0.5732
    -- (early_stopping) current_metric: 0.573194, best_metric: 0.610688
Epoch 3/5 - [train] loss: 127.385031
  - [eval] logloss: 0.6110
    -- (early_stopping) current_metric: 0.611009, best_metric: 0.573194
Epoch 4/5 - [train] loss: 125.481831
  - [eval] logloss: 0.5632
    -- (early_stopping) current_metric: 0.563169, best_metric: 0.573194
Epoch 5/5 - [train] loss: 121.487501
  - [eval] logloss: 0.5738
    -- (early_stopping) current_metric: 0.573800, best_metric: 0.563169

===== Tabular-Text Transformer Test Metrics =====
Accuracy : 0.7800
ROC-AUC