In [1]:
# -*- coding: utf-8 -*-
"""
Monolithic Tabular-Text Transformer (TTT) — single-class model
- 모든 모델 컴포넌트를 TabularTextTransformerMono 내부 메서드로 통합
- DTQ, 사인/코사인 PE, Overall Attention, FFN을 별도 클래스 없이 구현
- DataLoader엔 CPU 텐서만 넣고, 배치 루프에서만 .to(device) (pin_memory 에러 방지)
"""

# Third Party
import math
import re
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import accuracy_score, roc_auc_score, log_loss
from torch.utils.data import DataLoader, TensorDataset


# =========================
# 0) Utilities
# =========================
def _to_array(X):
    if isinstance(X, pd.DataFrame):
        return X.values
    return np.asarray(X, dtype=object)


class SimpleTokenizer:
    """아주 단순한 공백 기반 토크나이저 + vocab"""
    def __init__(self, lower=True, min_freq=2, max_vocab=30000, pad_token="<pad>", unk_token="<unk>"):
        self.lower = lower
        self.min_freq = min_freq
        self.max_vocab = max_vocab
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.stoi = {pad_token: 0, unk_token: 1}
        self.itos = [pad_token, unk_token]
        self.fitted_ = False

    def _basic_tokenize(self, s: str):
        if s is None:
            return []
        if not isinstance(s, str):
            s = str(s)
        if self.lower:
            s = s.lower()
        return [t for t in re.split(r"\W+", s) if t]

    def fit(self, texts: List[str]):
        from collections import Counter
        c = Counter()
        for t in texts:
            c.update(self._basic_tokenize(t))
        items = sorted([kv for kv in c.items() if kv[1] >= self.min_freq], key=lambda x: (-x[1], x[0]))
        for tok, _ in items[: self.max_vocab - len(self.itos)]:
            if tok not in self.stoi:
                self.stoi[tok] = len(self.itos)
                self.itos.append(tok)
        self.fitted_ = True
        return self

    def encode(self, s: str):
        toks = self._basic_tokenize(s)
        return [self.stoi.get(t, 1) for t in toks]  # 1 = <unk>

    @property
    def vocab_size(self):
        return len(self.itos)


# =========================
# 1) Preprocessor (cat/num/text)
# =========================
class TTTPreprocessor:
    """
    - Categorical: 문자열→정수 ID (0 = OOV)
    - Numeric    : per-column quantiles 저장 (DTQ)
    - Text       : vocab 구성 후 고정 길이 인코딩
    """
    def __init__(
        self,
        categorical_indices: Optional[List[int]] = None,
        text_indices: Optional[List[int]] = None,
        use_oov: bool = True,
        add_na_token: bool = True,
        max_text_len: int = 128,
        min_freq: int = 2,
        max_vocab: int = 30000,
    ):
        self.categorical_indices = None if categorical_indices is None else list(categorical_indices)
        self.text_indices = [] if text_indices is None else list(text_indices)
        self.use_oov = use_oov
        self.add_na_token = add_na_token

        self.cat_idx, self.cont_idx = [], []
        self.cat_maps = {}
        self.cardinalities = []
        self.num_quantiles = None
        self.s = 6

        self.max_text_len = max_text_len
        self.tok = SimpleTokenizer(min_freq=min_freq, max_vocab=max_vocab)
        self.fitted_ = False

    def fit(self, X, categorical_indices=None, text_indices=None, s: int = 6):
        X = _to_array(X)
        n_cols = X.shape[1]
        if categorical_indices is not None:
            self.categorical_indices = list(categorical_indices)
        if text_indices is not None:
            self.text_indices = list(text_indices)

        # 범주형 자동 추론(텍스트 컬럼 제외)
        if self.categorical_indices is None:
            self.categorical_indices = []
            for j in range(n_cols):
                if j in (self.text_indices or []):
                    continue
                if X[:, j].dtype == object:
                    self.categorical_indices.append(j)

        cat_set = set(self.categorical_indices or [])
        text_set = set(self.text_indices or [])
        all_idx = set(range(n_cols))
        self.cont_idx = sorted(list(all_idx - cat_set - text_set))
        self.cat_idx = sorted(list(cat_set))

        # 범주형 매핑
        self.cat_maps, self.cardinalities = {}, []
        for j in self.cat_idx:
            col = X[:, j]
            if self.add_na_token:
                col = np.where(pd.isna(col), "<NA>", col)
            uniq = pd.unique(col)
            self.cat_maps[j] = {v: i + 1 for i, v in enumerate(uniq)}  # 1..K
            self.cardinalities.append(len(uniq))

        # 수치형 분위수
        self.s = int(s)
        if len(self.cont_idx) > 0:
            cont = X[:, self.cont_idx].astype("float32")
            qs = np.linspace(0, 1, self.s, dtype=np.float32)
            self.num_quantiles = np.quantile(cont, qs, axis=0).T  # (n_num, s)
        else:
            self.num_quantiles = None

        # 텍스트 vocab
        if len(self.text_indices) > 0:
            texts = (X[:, self.text_indices].astype(str)
                     if len(self.text_indices) > 1 else X[:, self.text_indices[0]].astype(str))
            if texts.ndim == 2:
                joined = [" ".join(row.tolist()) for row in texts]
            else:
                joined = texts.tolist()
            self.tok.fit(joined)

        self.fitted_ = True
        return self

    def transform(self, X):
        assert self.fitted_, "Call fit() first."
        X = _to_array(X)
        B = X.shape[0]

        # categorical
        if len(self.cat_idx) > 0:
            x_cat = np.zeros((B, len(self.cat_idx)), dtype="int64")
            for ti, j in enumerate(self.cat_idx):
                col = X[:, j]
                if self.add_na_token:
                    col = np.where(pd.isna(col), "<NA>", col)
                m = self.cat_maps[j]
                x_cat[:, ti] = np.array([m.get(v, 0) for v in col], dtype="int64")
        else:
            x_cat = np.zeros((B, 0), dtype="int64")

        # numeric
        x_num = X[:, self.cont_idx].astype("float32") if len(self.cont_idx) > 0 else None

        # text
        if len(self.text_indices) > 0:
            texts = (X[:, self.text_indices].astype(str)
                     if len(self.text_indices) > 1 else X[:, self.text_indices[0]].astype(str))
            if texts.ndim == 2:
                joined = [" ".join(row.tolist()) for row in texts]
            else:
                joined = texts.tolist()
            enc = []
            for s in joined:
                ids = self.tok.encode(s)[: self.max_text_len]
                if len(ids) < self.max_text_len:
                    ids = ids + [0] * (self.max_text_len - len(ids))
                enc.append(ids)
            x_text = np.asarray(enc, dtype="int64")
        else:
            x_text = None

        return x_cat, x_num, x_text

    def fit_transform(self, X, **kw):
        self.fit(X, **kw)
        return self.transform(X)


# =========================
# 2) Monolithic TTT Model (single class)
# =========================
class TabularTextTransformerMono(nn.Module):
    """
    단일 클래스 내부에:
      - 텍스트 임베딩(+사인/코사인 위치인코딩), [CLS]_text
      - 범주형 임베딩(+OOV=0)
      - 수치형 DTQ(분위수 역거리 가중합 임베딩)
      - Dual-stream Overall Attention (text/tab)
      - 듀얼 헤드(BCE), 최종 출력은 평균
    """
    def __init__(
        self,
        cat_cardinalities: List[int],
        n_num: int,
        vocab_size: int,
        d_model=128, n_heads=4, n_layers=2, dim_feedforward=256, dropout=0.1,
        num_quantiles: Optional[np.ndarray] = None,  # (n_num, s)
        max_text_len: int = 2048,
        text_pad_id: int = 0,
        use_text_positional_encoding: bool = True,
    ):
        super().__init__()
        self.n_cat = len(cat_cardinalities)
        self.n_num = n_num
        self.vocab_size = vocab_size
        self.d = d_model
        self.text_pad_id = text_pad_id
        self.use_text_pos = use_text_positional_encoding
        self.max_text_len = max_text_len

        # ---- 텍스트 임베딩 & [CLS] ----
        self.text_tok = nn.Embedding(vocab_size, d_model, padding_idx=text_pad_id)
        nn.init.normal_(self.text_tok.weight, std=0.02)
        self.cls_text = nn.Parameter(torch.zeros(1, 1, d_model)); nn.init.normal_(self.cls_text, std=0.02)

        # ---- 사인/코사인 위치인코딩 버퍼 ----
        if self.use_text_pos:
            pe = self._build_sinusoidal_pe(d_model, max_len=max_text_len + 1)  # +1 for [CLS]
            self.register_buffer("pe_text", pe)  # (max_len+1, d)
        else:
            self.register_buffer("pe_text", None)

        # ---- 범주형 임베딩 ----
        if self.n_cat > 0:
            self.cat_embs = nn.ModuleList([nn.Embedding(c + 1, d_model, padding_idx=0) for c in cat_cardinalities])
            for emb in self.cat_embs:
                nn.init.normal_(emb.weight, std=0.02)
        else:
            self.cat_embs = nn.ModuleList()

        # ---- DTQ 파라미터/버퍼 ----
        if n_num > 0:
            assert num_quantiles is not None and num_quantiles.shape[0] == n_num
            self.register_buffer("quantiles", torch.tensor(num_quantiles.astype("float32")))  # (n_num, s)
            self.S_dtq = nn.Parameter(torch.randn(n_num, self.quantiles.size(1), d_model) * 0.02)
        else:
            self.register_buffer("quantiles", None)
            self.S_dtq = None

        # ---- [CLS]_tab ----
        self.cls_tab = nn.Parameter(torch.zeros(1, 1, d_model)); nn.init.normal_(self.cls_tab, std=0.02)
        self.emb_drop = nn.Dropout(dropout)

        # ---- Dual-stream Overall Attention 레이어 스택 ----
        self.layers_text = nn.ModuleList([
            nn.ModuleDict(dict(
                ln_q = nn.LayerNorm(d_model),
                ln_kv= nn.LayerNorm(d_model),
                attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True),
                ln_ff= nn.LayerNorm(d_model),
                ffn  = nn.Sequential(
                    nn.Linear(d_model, dim_feedforward),
                    nn.ReLU(inplace=True),
                    nn.Dropout(dropout),
                    nn.Linear(dim_feedforward, d_model),
                    nn.Dropout(dropout),
                ),
            )) for _ in range(n_layers)
        ])
        self.layers_tab = nn.ModuleList([
            nn.ModuleDict(dict(
                ln_q = nn.LayerNorm(d_model),
                ln_kv= nn.LayerNorm(d_model),
                attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True),
                ln_ff= nn.LayerNorm(d_model),
                ffn  = nn.Sequential(
                    nn.Linear(d_model, dim_feedforward),
                    nn.ReLU(inplace=True),
                    nn.Dropout(dropout),
                    nn.Linear(dim_feedforward, d_model),
                    nn.Dropout(dropout),
                ),
            )) for _ in range(n_layers)
        ])

        # ---- Heads ----
        self.head_text = nn.Sequential(nn.Linear(d_model, 1), nn.Sigmoid())
        self.head_tab  = nn.Sequential(nn.Linear(d_model, 1), nn.Sigmoid())

    # ===== 내부 유틸 =====
    @staticmethod
    def _build_sinusoidal_pe(d_model: int, max_len: int):
        pe = torch.zeros(max_len, d_model, dtype=torch.float32)
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        return pe  # (max_len, d)

    def _add_pos_encoding_text(self, z: torch.Tensor) -> torch.Tensor:
        # z: (B, 1+T, d)  (앞에 [CLS])
        if self.pe_text is None:
            return z
        T = z.size(1)
        return z + self.pe_text[:T].unsqueeze(0)

    def _embed_text(self, x_text: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        x_text: (B, T) long (0=pad)
        return: z_text0 (B, 1+T, d), pad_mask_text (B, 1+T)
        """
        if x_text is None:
            # 텍스트가 아예 없는 경우: [CLS] 하나만
            B = 1
            z = self.cls_text.expand(B, 1, -1)
            pad_mask = torch.zeros(B, 1, dtype=torch.bool, device=z.device)
            return z, pad_mask

        B, T = x_text.size()
        z = self.text_tok(x_text)             # (B, T, d)
        cls = self.cls_text.expand(B, 1, -1)  # (B, 1, d)
        z = torch.cat([cls, z], dim=1)        # (B, 1+T, d)
        z = self._add_pos_encoding_text(z)
        z = self.emb_drop(z)

        pad_mask = (x_text == self.text_pad_id)  # (B, T)
        pad_mask = torch.cat([torch.zeros(B, 1, dtype=torch.bool, device=x_text.device), pad_mask], dim=1)
        return z, pad_mask

    def _embed_tab(self, x_cat: Optional[torch.Tensor], x_num: Optional[torch.Tensor]) -> torch.Tensor:
        """
        반환: z_tab0 (B, 1+T_tab, d)  with [CLS]_tab
        """
        B = (x_cat.size(0) if x_cat is not None else (x_num.size(0) if x_num is not None else 1))
        parts = []
        if self.n_cat > 0 and x_cat is not None:
            cat_tok = [emb(x_cat[:, j]).unsqueeze(1) for j, emb in enumerate(self.cat_embs)]  # (B,1,d) list
            parts.append(torch.cat(cat_tok, dim=1))  # (B, n_cat, d)

        if self.n_num > 0 and x_num is not None and self.S_dtq is not None:
            # DTQ: 역거리 가중합으로 per-feature 임베딩
            # x_num: (B, n_num), Q: (n_num, s), S: (n_num, s, d)
            x = x_num.unsqueeze(-1)                         # (B, n_num, 1)
            Q = self.quantiles.unsqueeze(0)                 # (1, n_num, s)
            eq = (x == Q)
            any_eq = eq.any(dim=-1, keepdim=True)
            inv = torch.where(eq, torch.ones_like(Q), 1.0 / (torch.abs(x - Q) + 1e-12))
            inv = torch.where(any_eq, torch.where(eq, torch.ones_like(inv), torch.zeros_like(inv)), inv)
            w = inv / (inv.sum(dim=-1, keepdim=True) + 1e-12)   # (B, n_num, s)

            # (B, n_num, s) @ (n_num, s, d) → (B, n_num, d)  (feature별로 계산)
            z_num_list = []
            for j in range(self.n_num):
                wj = w[:, j, :].unsqueeze(1)                 # (B, 1, s)
                Sj = self.S_dtq[j:j+1, :, :]                 # (1, s, d)
                z_num_list.append(wj @ Sj)                   # (B, 1, d)
            z_num = torch.cat(z_num_list, dim=1)             # (B, n_num, d)
            parts.append(z_num)

        z = torch.cat(parts, dim=1) if parts else torch.zeros(B, 0, self.d, device=self.cls_tab.device)
        cls = self.cls_tab.expand(B, 1, -1)
        z = torch.cat([cls, z], dim=1)
        return self.emb_drop(z)

    def forward(
        self,
        x_cat: Optional[torch.LongTensor],
        x_num: Optional[torch.FloatTensor],
        x_text: Optional[torch.LongTensor],
    ):
        """
        반환:
          p_avg (B,1) = 0.5*(p_tab + p_text), p_tab (B,1), p_text (B,1)
        """
        # 초기 임베딩
        z_tab0 = self._embed_tab(x_cat, x_num)                 # (B, Ttab0, d)
        pad_mask_tab = torch.zeros(z_tab0.size()[:2], dtype=torch.bool, device=z_tab0.device)

        if x_text is not None:
            z_text0, pad_mask_text = self._embed_text(x_text)  # (B, Ttxt0, d)
        else:
            B = z_tab0.size(0)
            z_text0 = self.cls_text.expand(B, 1, -1)
            pad_mask_text = torch.zeros(B, 1, dtype=torch.bool, device=z_tab0.device)

        # 진행 상태
        z_tab = z_tab0
        z_text = z_text0

        # L 레이어 반복 (Overall Attention: self + initial-cross)
        for i in range(len(self.layers_tab)):
            # --- Tab stream ---
            lt = self.layers_tab[i]
            q = lt.ln_q(z_tab)
            kv = lt.ln_kv(torch.cat([z_tab, z_text0], dim=1))
            kpm = torch.cat([pad_mask_tab, pad_mask_text], dim=1)
            attn_out, _ = lt.attn(query=q, key=kv, value=kv, key_padding_mask=kpm)
            z_tab = z_tab + attn_out
            z_tab = z_tab + lt.ffn(lt.ln_ff(z_tab))

            # --- Text stream ---
            lx = self.layers_text[i]
            qx = lx.ln_q(z_text)
            kvx = lx.ln_kv(torch.cat([z_text, z_tab0], dim=1))
            kpmx = torch.cat([pad_mask_text, pad_mask_tab], dim=1)
            attn_text, _ = lx.attn(query=qx, key=kvx, value=kvx, key_padding_mask=kpmx)
            z_text = z_text + attn_text
            z_text = z_text + lx.ffn(lx.ln_ff(z_text))

        # [CLS] 추출 후 듀얼 헤드
        cls_tab  = z_tab[:, 0, :]
        cls_text = z_text[:, 0, :]
        p_tab  = self.head_tab(cls_tab)     # (B,1)
        p_text = self.head_text(cls_text)   # (B,1)
        p = 0.5 * (p_tab + p_text)
        return p, p_tab, p_text


# =========================
# 3) Sklearn-compatible wrapper (CPU DataLoader + move-to-device in loop)
# =========================
class TTTBinaryClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        auto_preprocess=True,
        categorical_indices: Optional[List[int]] = None,
        text_indices: Optional[List[int]] = None,
        use_oov=True,
        max_text_len=128,
        min_freq=2,
        max_vocab=30000,
        quantiles_s=6,
        d_model=128, n_heads=4, n_layers=2, dim_feedforward=256, dropout=0.1,
        lr=1e-3, weight_decay=1e-4,
        loss_fn="logloss",
        device=None,
    ):
        self.auto_preprocess = auto_preprocess
        self.categorical_indices = categorical_indices
        self.text_indices = text_indices
        self.use_oov = use_oov
        self.max_text_len = max_text_len
        self.min_freq = min_freq
        self.max_vocab = max_vocab
        self.quantiles_s = quantiles_s

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout

        self.lr = lr
        self.weight_decay = weight_decay
        self.loss_fn_name = loss_fn

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # filled later
        self.preproc: Optional[TTTPreprocessor] = None
        self.model: Optional[TabularTextTransformerMono] = None
        self.best_model_state = None

        self.cat_idx: List[int] = []
        self.cont_idx: List[int] = []
        self.text_idx: List[int] = []
        self.cat_cardinalities: List[int] = []
        self.num_quantiles: Optional[np.ndarray] = None
        self.vocab_size: int = 2

    # internals
    def _define_loss(self):
        if self.loss_fn_name == "logloss":
            return nn.BCELoss(reduction="none")
        raise ValueError(self.loss_fn_name)

    def _split_to_cpu_tensors(self, X):
        """DataLoader용 CPU 텐서 반환(루프에서만 .to(device))"""
        if self.auto_preprocess:
            x_cat_np, x_num_np, x_text_np = self.preproc.transform(X)
        else:
            arr = X if not isinstance(X, torch.Tensor) else X.detach().cpu().numpy()
            x_cat_np = arr[:, self.cat_idx].astype("int64") if len(self.cat_idx) > 0 else np.zeros((arr.shape[0], 0), dtype="int64")
            x_num_np = arr[:, self.cont_idx].astype("float32") if len(self.cont_idx) > 0 else None
            x_text_np = arr[:, self.text_idx].astype("int64") if len(self.text_idx) > 0 else None

        x_cat  = torch.tensor(x_cat_np, dtype=torch.long) if x_cat_np is not None else None
        x_num  = torch.tensor(x_num_np, dtype=torch.float32) if x_num_np is not None else None
        x_text = torch.tensor(x_text_np, dtype=torch.long) if x_text_np is not None else None
        return x_cat, x_num, x_text

    def _build_model(self):
        n_num = len(self.cont_idx)
        m = TabularTextTransformerMono(
            cat_cardinalities=self.cat_cardinalities,
            n_num=n_num,
            vocab_size=self.vocab_size,
            d_model=self.d_model,
            n_heads=self.n_heads,
            n_layers=self.n_layers,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            num_quantiles=self.num_quantiles,
            max_text_len=self.max_text_len + 1,  # [CLS] 포함 여유
        )
        return m.to(self.device)

    # public API
    def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
            max_epochs=10, patience=None, batch_size=256, num_workers=0, verbose=True):

        # auto preprocess
        if self.auto_preprocess:
            self.preproc = TTTPreprocessor(
                categorical_indices=self.categorical_indices,
                text_indices=self.text_indices,
                use_oov=self.use_oov,
                max_text_len=self.max_text_len,
                min_freq=self.min_freq,
                max_vocab=self.max_vocab,
            )
            self.preproc.fit(X, s=self.quantiles_s)
            self.cat_idx = self.preproc.cat_idx
            self.cont_idx = self.preproc.cont_idx
            self.text_idx = self.preproc.text_indices
            self.cat_cardinalities = self.preproc.cardinalities
            self.num_quantiles = self.preproc.num_quantiles
            self.vocab_size = self.preproc.tok.vocab_size

        # CPU tensors
        x_cat, x_num, x_text = self._split_to_cpu_tensors(X)
        y_t = torch.tensor(y, dtype=torch.float32).view(-1, 1)
        w_t = torch.tensor(sample_weight, dtype=torch.float32).view(-1, 1) if sample_weight is not None else torch.ones_like(y_t)

        # validation
        if eval_set is not None:
            Xv, yv = eval_set[0]
            xc_v, xn_v, xt_v = self._split_to_cpu_tensors(Xv)
            yv_t = torch.tensor(yv, dtype=torch.float32).view(-1, 1)
        else:
            xc_v = xn_v = xt_v = yv_t = None

        if self.model is None:
            self.model = self._build_model()

        loss_fn = self._define_loss()
        opt = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # datasets/dataloaders (CPU)
        if x_num is None and x_text is None:
            train_ds = TensorDataset(x_cat, y_t, w_t)
            def itr(dl):
                for a, yb, wb in dl: yield a, None, None, yb, wb
        elif x_num is None:
            train_ds = TensorDataset(x_cat, x_text, y_t, w_t)
            def itr(dl):
                for a, t, yb, wb in dl: yield a, None, t, yb, wb
        elif x_text is None:
            train_ds = TensorDataset(x_cat, x_num, y_t, w_t)
            def itr(dl):
                for a, n, yb, wb in dl: yield a, n, None, yb, wb
        else:
            train_ds = TensorDataset(x_cat, x_num, x_text, y_t, w_t)
            def itr(dl):
                for a, n, t, yb, wb in dl: yield a, n, t, yb, wb

        pin = (self.device == "cuda")
        train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin)

        best, pc = float("inf"), 0
        for ep in range(max_epochs):
            self.model.train()
            acc_loss, steps = 0.0, 0
            for a, n, t, yb, wb in itr(train_dl):
                # move to device
                if a is not None: a = a.to(self.device, non_blocking=True)
                if n is not None: n = n.to(self.device, non_blocking=True)
                if t is not None: t = t.to(self.device, non_blocking=True)
                yb = yb.to(self.device, non_blocking=True)
                wb = wb.to(self.device, non_blocking=True)

                opt.zero_grad()
                p, pt, px = self.model(a, n, t)
                l = 0.5 * (loss_fn(pt, yb) + loss_fn(px, yb))  # dual loss
                wloss = (l * wb).sum() / wb.sum()
                wloss.backward()
                opt.step()
                acc_loss += wloss.item(); steps += 1

            if verbose:
                print(f"Epoch {ep+1}/{max_epochs} - train_loss: {acc_loss/max(1,steps):.6f}")

            # validation
            if eval_set is not None:
                self.model.eval()
                with torch.no_grad():
                    if xn_v is None and xt_v is None:
                        vds = TensorDataset(xc_v, yv_t)
                        vdl = DataLoader(vds, batch_size=2048, shuffle=False, pin_memory=pin)
                        ev, c = 0.0, 0
                        for a, yb in vdl:
                            a = a.to(self.device, non_blocking=True)
                            yb = yb.to(self.device, non_blocking=True)
                            _, pt, px = self.model(a, None, None)
                            l = 0.5*(loss_fn(pt, yb) + loss_fn(px, yb))
                            ev += (l.sum()/len(l)).item(); c += 1
                    elif xn_v is None:
                        vds = TensorDataset(xc_v, xt_v, yv_t)
                        vdl = DataLoader(vds, batch_size=2048, shuffle=False, pin_memory=pin)
                        ev, c = 0.0, 0
                        for a, t, yb in vdl:
                            a = a.to(self.device, non_blocking=True)
                            t = t.to(self.device, non_blocking=True)
                            yb = yb.to(self.device, non_blocking=True)
                            _, pt, px = self.model(a, None, t)
                            l = 0.5*(loss_fn(pt, yb) + loss_fn(px, yb))
                            ev += (l.sum()/len(l)).item(); c += 1
                    elif xt_v is None:
                        vds = TensorDataset(xc_v, xn_v, yv_t)
                        vdl = DataLoader(vds, batch_size=2048, shuffle=False, pin_memory=pin)
                        ev, c = 0.0, 0
                        for a, n, yb in vdl:
                            a = a.to(self.device, non_blocking=True)
                            n = n.to(self.device, non_blocking=True)
                            yb = yb.to(self.device, non_blocking=True)
                            _, pt, px = self.model(a, n, None)
                            l = 0.5*(loss_fn(pt, yb) + loss_fn(px, yb))
                            ev += (l.sum()/len(l)).item(); c += 1
                    else:
                        vds = TensorDataset(xc_v, xn_v, xt_v, yv_t)
                        vdl = DataLoader(vds, batch_size=2048, shuffle=False, pin_memory=pin)
                        ev, c = 0.0, 0
                        for a, n, t, yb in vdl:
                            a = a.to(self.device, non_blocking=True)
                            n = n.to(self.device, non_blocking=True)
                            t = t.to(self.device, non_blocking=True)
                            yb = yb.to(self.device, non_blocking=True)
                            _, pt, px = self.model(a, n, t)
                            l = 0.5*(loss_fn(pt, yb) + loss_fn(px, yb))
                            ev += (l.sum()/len(l)).item(); c += 1
                    ev /= max(1, c)
                    if verbose:
                        print(f"          val_loss: {ev:.6f}")
                    if patience is not None:
                        if ev < best:
                            best = ev; pc = 0
                            self.best_model_state = {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}
                        else:
                            pc += 1
                            if pc >= patience:
                                if verbose:
                                    print(f"Early stopping at epoch {ep+1}")
                                break

        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
        return self

    def predict_proba(self, X, batch_size=2048):
        xc, xn, xt = self._split_to_cpu_tensors(X)
        pin = (self.device == "cuda")

        if xn is None and xt is None:
            ds = TensorDataset(xc)
            def itr(dl):
                for (a,) in dl: yield a, None, None
        elif xn is None:
            ds = TensorDataset(xc, xt)
            def itr(dl):
                for a, t in dl: yield a, None, t
        elif xt is None:
            ds = TensorDataset(xc, xn)
            def itr(dl):
                for a, n in dl: yield a, n, None
        else:
            ds = TensorDataset(xc, xn, xt)
            def itr(dl):
                for a, n, t in dl: yield a, n, t

        dl = DataLoader(ds, batch_size=batch_size, shuffle=False, pin_memory=pin)
        self.model.eval()
        outs = []
        with torch.no_grad():
            for a, n, t in itr(dl):
                if a is not None: a = a.to(self.device, non_blocking=True)
                if n is not None: n = n.to(self.device, non_blocking=True)
                if t is not None: t = t.to(self.device, non_blocking=True)
                p, _, _ = self.model(a, n, t)
                outs.append(p.detach().cpu())
        p = torch.cat(outs, dim=0).numpy().astype("float")
        return np.hstack([1.0 - p, p])

    def predict(self, X):
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)

    def predict_with_uncertainty_flag(self, X, batch_size=2048):
        xc, xn, xt = self._split_to_cpu_tensors(X)
        pin = (self.device == "cuda")

        if xn is None and xt is None:
            ds = TensorDataset(xc)
            def itr(dl):
                for (a,) in dl: yield a, None, None
        elif xn is None:
            ds = TensorDataset(xc, xt)
            def itr(dl):
                for a, t in dl: yield a, None, t
        elif xt is None:
            ds = TensorDataset(xc, xn)
            def itr(dl):
                for a, n in dl: yield a, n, None
        else:
            ds = TensorDataset(xc, xn, xt)
            def itr(dl):
                for a, n, t in dl: yield a, n, t

        dl = DataLoader(ds, batch_size=batch_size, shuffle=False, pin_memory=pin)
        self.model.eval()
        ps, tabs, texts = [], [], []
        loss_fn = self._define_loss()
        with torch.no_grad():
            for a, n, t in itr(dl):
                if a is not None: a = a.to(self.device, non_blocking=True)
                if n is not None: n = n.to(self.device, non_blocking=True)
                if t is not None: t = t.to(self.device, non_blocking=True)
                p, pt, px = self.model(a, n, t)
                ps.append(p.detach().cpu()); tabs.append(pt.detach().cpu()); texts.append(px.detach().cpu())
        p = torch.cat(ps).numpy().reshape(-1)
        pt = torch.cat(tabs).numpy().reshape(-1)
        px = torch.cat(texts).numpy().reshape(-1)
        disagree = ((pt >= 0.5).astype(int) != (px >= 0.5).astype(int))
        return p, disagree


# =========================
# 4) 샘플 데이터 생성 (텍스트 포함)
# =========================
def make_mixed_text_sample(n_samples=40000, seed=7):
    rng = np.random.RandomState(seed)

    genders = np.array(["남성", "여성", "기타"], dtype=object)
    cities  = np.array(["서울","부산","대구","인천","수원","고양"], dtype=object)
    devices = np.array(["ios","android","web"], dtype=object)

    gender_col = rng.choice(genders, size=n_samples, p=[0.48,0.48,0.04])
    city_col   = rng.choice(cities , size=n_samples)
    device_col = rng.choice(devices, size=n_samples, p=[0.35,0.55,0.10])

    n_cont = 10
    X_cont = rng.randn(n_samples, n_cont).astype("float32")

    adjs = ["빠른","튼튼한","가벼운","비싼","저렴한","세련된","불편한","편안한","예쁜","심플한","강력한"]
    nouns= ["자켓","코트","셔츠","신발","가방","지갑","바지","스웨터","시계","안경"]
    sentiments = ["최고","별로","만족","불만","추천","애매"]
    def make_text():
        k = rng.randint(5, 16)
        toks = []
        for _ in range(k):
            bucket = rng.choice([0,1,2], p=[0.5,0.35,0.15])
            if bucket == 0: toks.append(rng.choice(adjs))
            elif bucket == 1: toks.append(rng.choice(nouns))
            else: toks.append(rng.choice(sentiments))
        return " ".join(toks)

    reviews = np.array([make_text() for _ in range(n_samples)], dtype=object)
    titles  = np.array([rng.choice(adjs) + " " + rng.choice(nouns) for _ in range(n_samples)], dtype=object)

    w_gender = {g: w for g, w in zip(genders, rng.uniform(-0.8, 0.8, len(genders)))}
    w_city   = {c: w for c, w in zip(cities , rng.uniform(-0.6, 1.0, len(cities )))}
    w_device = {d: w for d, w in zip(devices, rng.uniform(-0.5, 0.9, len(devices)))}
    w_cont   = rng.randn(n_cont).astype("float32")

    score_cat = (np.vectorize(w_gender.get)(gender_col)
                 + np.vectorize(w_city.get)(city_col)
                 + np.vectorize(w_device.get)(device_col)).astype("float32")
    score_cont = (X_cont * w_cont).sum(axis=1).astype("float32")
    score_text = np.array([1.0 if "최고" in r or "추천" in r else (-0.7 if "불만" in r else 0.0) for r in reviews], dtype="float32")

    bias, noise = 0.1, rng.normal(scale=0.5, size=n_samples).astype("float32")
    logit = 0.6*score_cat + 0.8*score_cont + 0.5*score_text + bias + noise
    prob = 1/(1+np.exp(-logit))
    y = (prob > 0.5).astype("int64")

    X = np.empty((n_samples, 5 + n_cont), dtype=object)
    X[:, 0] = gender_col
    X[:, 1] = city_col
    X[:, 2] = device_col
    X[:, 3] = titles
    X[:, 4] = reviews
    X[:, 5:] = X_cont

    categorical_feature_indices = [0,1,2]
    text_feature_indices = [3,4]
    return X, y, categorical_feature_indices, text_feature_indices


# =========================
# 5) 데모 (학습/평가)
# =========================
def train_and_evaluate_demo():
    np.random.seed(0); torch.manual_seed(0)

    X, y, cat_idx, text_idx = make_mixed_text_sample(n_samples=30000, seed=123)

    N = X.shape[0]; idx = np.arange(N); np.random.shuffle(idx)
    tr_end = int(N*0.7); va_end = int(N*0.85)
    tr, va, te = idx[:tr_end], idx[tr_end:va_end], idx[va_end:]
    X_tr, y_tr = X[tr], y[tr]
    X_va, y_va = X[va], y[va]
    X_te, y_te = X[te], y[te]

    clf = TTTBinaryClassifier(
        auto_preprocess=True,
        categorical_indices=cat_idx,
        text_indices=text_idx,
        use_oov=True,
        max_text_len=96,
        min_freq=2,
        max_vocab=40000,
        quantiles_s=6,
        d_model=128, n_heads=4, n_layers=2, dim_feedforward=256, dropout=0.1,
        lr=1e-3, weight_decay=1e-4,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    clf.fit(
        X_tr, y_tr,
        eval_set=[(X_va, y_va)],
        eval_metric=["logloss"],
        max_epochs=10,
        patience=2,
        batch_size=512,
        verbose=True,
    )

    proba = clf.predict_proba(X_te)[:, 1]
    pred  = (proba >= 0.5).astype(int)
    acc = accuracy_score(y_te, pred)
    auc = roc_auc_score(y_te, proba)
    ll  = log_loss(y_te, np.vstack([1-proba, proba]).T)

    print("\n===== Test Metrics =====")
    print(f"Accuracy : {acc:.4f}")
    print(f"ROC-AUC  : {auc:.4f}")
    print(f"Logloss  : {ll:.4f}")

    p_avg, disagree = clf.predict_with_uncertainty_flag(X_te)
    print(f"Uncertainty rate (stream disagreement): {disagree.mean():.3f}")


if __name__ == "__main__":
    train_and_evaluate_demo()

Epoch 1/10 - train_loss: 0.460041
          val_loss: 0.285489
Epoch 2/10 - train_loss: 0.248642
          val_loss: 0.210745
Epoch 3/10 - train_loss: 0.226196
          val_loss: 0.267167
Epoch 4/10 - train_loss: 0.208846
          val_loss: 0.191431
Epoch 5/10 - train_loss: 0.191726
          val_loss: 0.422729
Epoch 6/10 - train_loss: 0.228785
          val_loss: 0.198735
Early stopping at epoch 6

===== Test Metrics =====
Accuracy : 0.9222
ROC-AUC  : 0.9821
Logloss  : 0.1785
Uncertainty rate (stream disagreement): 0.020
