In [1]:
# Third Party
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import accuracy_score, roc_auc_score, log_loss
from torch.utils.data import DataLoader, TensorDataset

# -----------------------------
# TabTransformer Backbone
# -----------------------------
class TabTransformerBackbone(nn.Module):
    """
    Backbone that contextualizes only categorical columns using a Transformer.
    - Sum of per-column embedding and column embedding (tabular positional embedding)
    - (Optional) prepend a [CLS] token for pooling
    - Contextualize with TransformerEncoder
    """
    def __init__(
        self,
        cat_cardinalities,              # List[int], number of unique values per categorical column
        d_token=32,                     # token embedding dimension (Transformer d_model)
        n_heads=4,
        n_layers=2,
        dim_feedforward=128,
        attn_dropout=0.1,               # PyTorch uses one dropout (applies in both self-attn and FFN)
        ff_dropout=0.1,                 # kept for API compatibility (treated same as attn_dropout)
        embedding_dropout=0.1,
        add_cls=False,                  # if True, prepend a [CLS] token
        pooling="concat",               # "concat" or "cls"
        padding_idx=None,               # specify padding index if needed for embeddings
        norm_first=True                 # Pre-LN architecture
    ):
        super().__init__()
        assert pooling in ("concat", "cls")
        self.n_cat = len(cat_cardinalities)
        self.d_token = d_token
        self.add_cls = add_cls
        self.pooling = pooling

        if self.n_cat == 0:
            # No categorical columns -> act as a dummy pass-through
            self.cat_embeddings = nn.ModuleList()
            self.col_embedding = None
        else:
            # Per-column embeddings
            self.cat_embeddings = nn.ModuleList([
                nn.Embedding(
                    num_embeddings=c + (1 if (padding_idx is None) else 0),
                    embedding_dim=d_token,
                    padding_idx=None if (padding_idx is None) else padding_idx
                )
                for c in cat_cardinalities
            ])
            # Column embeddings (tabular counterpart of positional encoding)
            self.col_embedding = nn.Embedding(self.n_cat, d_token)

        # (Optional) CLS token
        if self.add_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, d_token))
            nn.init.normal_(self.cls_token, std=0.02)

        self.embedding_dropout = nn.Dropout(embedding_dropout)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_token,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=attn_dropout,   # single dropout parameter in PyTorch
            batch_first=True,
            norm_first=norm_first
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Weight initialization
        for emb in self.cat_embeddings:
            nn.init.normal_(emb.weight, std=0.02)
        if self.col_embedding is not None:
            nn.init.normal_(self.col_embedding.weight, std=0.02)

    def forward(self, x_cat: torch.LongTensor):
        """
        Args:
            x_cat: LongTensor of shape (B, n_cat)
        Returns:
            If pooling='concat' -> FloatTensor of shape (B, n_cat * d_token)
            If pooling='cls'    -> FloatTensor of shape (B, d_token)
        """
        B = x_cat.size(0)
        if self.n_cat == 0:
            # No categorical columns
            if self.pooling == "cls":
                out = torch.zeros(B, self.d_token, device=x_cat.device, dtype=torch.float32)
            else:
                out = torch.zeros(B, 0, device=x_cat.device, dtype=torch.float32)
            return out

        tok_list = []
        for j, emb in enumerate(self.cat_embeddings):
            tok = emb(x_cat[:, j])                         # (B, d)
            if self.col_embedding is not None:
                tok = tok + self.col_embedding.weight[j]   # add column embedding (d,)
            tok_list.append(tok.unsqueeze(1))              # (B, 1, d)

        x_tok = torch.cat(tok_list, dim=1)                 # (B, n_cat, d)

        if self.add_cls:
            cls = self.cls_token.expand(B, -1, -1)         # (B, 1, d)
            x_tok = torch.cat([cls, x_tok], dim=1)         # (B, 1 + n_cat, d)

        x_tok = self.embedding_dropout(x_tok)              # (B, T, d)
        z = self.transformer(x_tok)                        # (B, T, d)

        if self.pooling == "cls" and self.add_cls:
            out = z[:, 0, :]                               # (B, d)
        elif self.pooling == "cls" and not self.add_cls:
            out = z.mean(dim=1)                            # (B, d)
        else:
            if self.add_cls:
                z = z[:, 1:, :]                            # (B, n_cat, d) — drop CLS
            out = z.reshape(B, -1)                         # (B, n_cat * d)

        return out


# -----------------------------
# Full Model (Backbone + Head)
# -----------------------------
class TabTransformerModel(nn.Module):
    """
    Full TabTransformer:
    - Transformer backbone contextualizes categorical features
    - Continuous features are normalized + optionally projected
    - Concatenate -> MLP head -> single logit -> sigmoid
    """
    def __init__(
        self,
        cat_cardinalities,          # List[int]
        n_continuous=0,
        d_token=32,
        n_heads=4,
        n_layers=2,
        dim_feedforward=128,
        attn_dropout=0.1,
        ff_dropout=0.1,             # kept for API compatibility
        embedding_dropout=0.1,
        add_cls=False,
        pooling="concat",
        cont_proj="linear",         # "none" or "linear"
        mlp_hidden_dims=(128, 64),  # head MLP sizes
        mlp_dropout=0.2
    ):
        super().__init__()
        self.n_cont = n_continuous
        self.cont_proj = cont_proj

        # Backbone
        self.backbone = TabTransformerBackbone(
            cat_cardinalities=cat_cardinalities,
            d_token=d_token,
            n_heads=n_heads,
            n_layers=n_layers,
            dim_feedforward=dim_feedforward,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            embedding_dropout=embedding_dropout,
            add_cls=add_cls,
            pooling=pooling,
            norm_first=True
        )

        # Continuous feature processing
        if n_continuous > 0:
            self.cont_bn = nn.BatchNorm1d(n_continuous)
            if cont_proj == "linear":
                self.cont_linear = nn.Linear(n_continuous, d_token)
                nn.init.kaiming_uniform_(self.cont_linear.weight, mode="fan_in", nonlinearity="relu")
                cont_out_dim = d_token
            else:
                self.cont_linear = nn.Identity()
                cont_out_dim = n_continuous
        else:
            self.cont_bn = None
            self.cont_linear = None
            cont_out_dim = 0

        # Backbone output dimension
        if pooling == "cls":
            backbone_out = d_token
        else:
            backbone_out = len(cat_cardinalities) * d_token

        in_dim = backbone_out + cont_out_dim

        # Head MLP
        mlp_layers = []
        prev = in_dim
        for h in mlp_hidden_dims:
            lin = nn.Linear(prev, h)
            nn.init.kaiming_uniform_(lin.weight, mode="fan_in", nonlinearity="relu")
            mlp_layers.extend([lin, nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(mlp_dropout)])
            prev = h
        mlp_layers.append(nn.Linear(prev, 1))
        mlp_layers.append(nn.Sigmoid())
        self.head = nn.Sequential(*mlp_layers)

    def forward(self, x_cat: torch.LongTensor, x_cont: torch.FloatTensor = None):
        """
        Args:
            x_cat: (B, n_cat) long
            x_cont: (B, n_cont) float or None
        Returns:
            (B, 1) sigmoid probability
        """
        z_cat = self.backbone(x_cat)  # (B, d_backbone)

        if (x_cont is not None) and (self.n_cont > 0):
            if x_cont.ndim == 1:
                x_cont = x_cont.unsqueeze(1)
            x_cont = self.cont_bn(x_cont)
            x_cont = self.cont_linear(x_cont)
            z = torch.cat([z_cat, x_cont], dim=1)
        else:
            z = z_cat

        out = self.head(z)            # (B, 1)
        return out


# -----------------------------
# Sklearn-Compatible Classifier
# -----------------------------
class TabTransformerBinaryClassifier(BaseEstimator, ClassifierMixin):
    """
    Example:
        clf = TabTransformerBinaryClassifier(
            cat_idx=[0, 1, 5],                # categorical column indices
            cat_cardinalities=[10, 5, 20],    # cardinality per categorical column
            cont_idx=[2, 3, 4],               # continuous column indices (or [])
            d_token=32, n_heads=4, n_layers=2,
            hidden_dims=(128, 64), lr=1e-3
        )
    """
    def __init__(
        self,
        cat_idx,
        cat_cardinalities,
        cont_idx=None,
        d_token=32,
        n_heads=4,
        n_layers=2,
        dim_feedforward=128,
        attn_dropout=0.1,
        ff_dropout=0.1,
        embedding_dropout=0.1,
        add_cls=False,
        pooling="concat",
        cont_proj="linear",
        hidden_dims=(128, 64),
        mlp_dropout=0.2,
        lr=1e-3,
        weight_decay=1e-4,
        loss_fn="logloss",
        device=None
    ):
        self.cat_idx = list(cat_idx)
        self.cat_cardinalities = list(cat_cardinalities)
        self.cont_idx = list(cont_idx) if cont_idx is not None else []

        # Model / training hyperparameters
        self.d_token = d_token
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dim_feedforward = dim_feedforward
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.embedding_dropout = embedding_dropout
        self.add_cls = add_cls
        self.pooling = pooling
        self.cont_proj = cont_proj
        self.hidden_dims = hidden_dims
        self.mlp_dropout = mlp_dropout
        self.lr = lr
        self.weight_decay = weight_decay
        self.loss_fn_name = loss_fn

        # Internal state
        self.model = None
        self.best_model_weights = None
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # -------------------------
    # Internals
    # -------------------------
    def _build_model(self):
        n_cont = len(self.cont_idx)
        model = TabTransformerModel(
            cat_cardinalities=self.cat_cardinalities,
            n_continuous=n_cont,
            d_token=self.d_token,
            n_heads=self.n_heads,
            n_layers=self.n_layers,
            dim_feedforward=self.dim_feedforward,
            attn_dropout=self.attn_dropout,
            ff_dropout=self.ff_dropout,
            embedding_dropout=self.embedding_dropout,
            add_cls=self.add_cls,
            pooling=self.pooling,
            cont_proj=self.cont_proj,
            mlp_hidden_dims=self.hidden_dims,
            mlp_dropout=self.mlp_dropout
        )
        return model.to(self.device)

    def _define_loss_fn(self):
        if self.loss_fn_name == "logloss":
            return nn.BCELoss(reduction="none")
        else:
            raise Exception(f"{self.loss_fn_name} is not defined")

    def _split_X(self, X):
        """
        Split input X (np.ndarray or torch.Tensor) into:
        - x_cat: LongTensor (B, n_cat)
        - x_cont: FloatTensor (B, n_cont) or None
        """
        if isinstance(X, torch.Tensor):
            X_np = X.detach().cpu().numpy()
        else:
            X_np = X

        if len(self.cat_idx) > 0:
            x_cat_np = X_np[:, self.cat_idx].astype("int64")
        else:
            x_cat_np = np.zeros((X_np.shape[0], 0), dtype="int64")

        if len(self.cont_idx) > 0:
            x_cont_np = X_np[:, self.cont_idx].astype("float32")
        else:
            x_cont_np = None

        x_cat = torch.tensor(x_cat_np, dtype=torch.long, device=self.device)
        x_cont = torch.tensor(x_cont_np, dtype=torch.float32, device=self.device) if x_cont_np is not None else None
        return x_cat, x_cont

    # -------------------------
    # Public API
    # -------------------------
    def fit(
        self,
        X,
        y,
        sample_weight=None,
        eval_set=None,            # list of tuples: [(X_val, y_val)]
        eval_metric=None,         # supports ["logloss"] only
        max_epochs=10,
        patience=None,
        batch_size=32,
        num_workers=0,
        verbose=True,
        pin_memory=None           # set True when using CUDA for faster host->device transfer
    ):
        if pin_memory is None:
            pin_memory = (self.device == "cuda")

        # Prepare tensors
        x_cat, x_cont = self._split_X(X)
        y_tensor = torch.tensor(y, dtype=torch.float32, device=self.device).view(-1, 1)

        if sample_weight is not None:
            w_tensor = torch.tensor(sample_weight, dtype=torch.float32, device=self.device).view(-1, 1)
        else:
            w_tensor = torch.ones_like(y_tensor, dtype=torch.float32, device=self.device)

        if eval_set is not None:
            X_val, y_val = eval_set[0]
            x_cat_val, x_cont_val = self._split_X(X_val)
            y_val_tensor = torch.tensor(y_val, dtype=torch.float32, device=self.device).view(-1, 1)
        else:
            x_cat_val = x_cont_val = y_val_tensor = None

        # Build model
        if self.model is None:
            self.model = self._build_model()

        loss_fn = self._define_loss_fn()
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # DataLoader
        if x_cont is None:
            train_dataset = TensorDataset(x_cat, y_tensor, w_tensor)
        else:
            train_dataset = TensorDataset(x_cat, x_cont, y_tensor, w_tensor)

        def _make_train_loader():
            return DataLoader(
                train_dataset, batch_size=batch_size, shuffle=True,
                num_workers=num_workers, pin_memory=pin_memory
            )

        train_loader = _make_train_loader()

        best_loss = float("inf")
        patience_counter = 0

        for epoch in range(max_epochs):
            self.model.train()
            epoch_loss = 0.0
            n_steps = 0

            if x_cont is None:
                for Xc_b, y_b, w_b in train_loader:
                    optimizer.zero_grad()
                    y_pred = self.model(Xc_b, None)
                    loss = loss_fn(y_pred, y_b)
                    weighted_loss = (loss * w_b).sum() / w_b.sum()
                    weighted_loss.backward()
                    optimizer.step()
                    epoch_loss += weighted_loss.item()
                    n_steps += 1
            else:
                for Xc_b, Xn_b, y_b, w_b in train_loader:
                    optimizer.zero_grad()
                    y_pred = self.model(Xc_b, Xn_b)
                    loss = loss_fn(y_pred, y_b)
                    weighted_loss = (loss * w_b).sum() / w_b.sum()
                    weighted_loss.backward()
                    optimizer.step()
                    epoch_loss += weighted_loss.item()
                    n_steps += 1

            train_avg = epoch_loss / max(1, n_steps)
            if verbose:
                print(f"Epoch {epoch + 1}/{max_epochs} - train_loss: {train_avg:.6f}")

            # -------------------------
            # Validation
            # -------------------------
            if eval_set is not None:
                self.model.eval()
                if eval_metric is not None:
                    for m in eval_metric:
                        if m != "logloss":
                            raise Exception(f"{eval_metric} is not defined")

                with torch.no_grad():
                    if x_cont_val is None:
                        val_dataset = TensorDataset(x_cat_val, y_val_tensor)
                        val_loader = DataLoader(
                            val_dataset, batch_size=2048, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory
                        )
                        eval_loss = 0.0
                        n_eval = 0
                        for Xc_v, y_v in val_loader:
                            y_pred_v = self.model(Xc_v, None)
                            loss_v = loss_fn(y_pred_v, y_v)
                            eval_loss += (loss_v.sum() / len(loss_v)).item()
                            n_eval += 1
                    else:
                        val_dataset = TensorDataset(x_cat_val, x_cont_val, y_val_tensor)
                        val_loader = DataLoader(
                            val_dataset, batch_size=2048, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory
                        )
                        eval_loss = 0.0
                        n_eval = 0
                        for Xc_v, Xn_v, y_v in val_loader:
                            y_pred_v = self.model(Xc_v, Xn_v)
                            loss_v = loss_fn(y_pred_v, y_v)
                            eval_loss += (loss_v.sum() / len(loss_v)).item()
                            n_eval += 1

                    eval_loss = eval_loss / max(1, n_eval)
                    if verbose:
                        print(f"          val_loss: {eval_loss:.6f}")

                    if patience is not None:
                        if eval_loss < best_loss:
                            best_loss = eval_loss
                            patience_counter = 0
                            self.best_model_weights = {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}
                        else:
                            patience_counter += 1
                            if patience_counter >= patience:
                                if verbose:
                                    print(f"Early stopping at epoch {epoch + 1}")
                                break

        if self.best_model_weights is not None:
            self.model.load_state_dict(self.best_model_weights)

        return self

    def predict_proba(self, X):
        x_cat, x_cont = self._split_X(X)
        self.model.eval()
        with torch.no_grad():
            probs1 = self.model(x_cat, x_cont).detach().cpu().numpy()  # (B, 1) sigmoid
        probs1 = probs1.astype("float")
        probs0 = 1.0 - probs1
        return np.hstack([probs0, probs1])

    def predict(self, X):
        probs = self.predict_proba(X)
        return probs.argmax(axis=1)


# ============================================================
# Synthetic Data Generator + Training Demo
# ============================================================
def make_synthetic_tabular(
    n_samples=20000,
    cat_cardinalities=(12, 7, 25),   # 3 categorical columns
    n_cont=4,                        # 4 continuous columns
    seed=42
):
    rng = np.random.RandomState(seed)
    n_cat = len(cat_cardinalities)

    # Generate categorical columns (each column 0..K-1)
    X_cat = np.column_stack([
        rng.randint(0, c, size=n_samples).astype("int64") for c in cat_cardinalities
    ])  # (N, n_cat)

    # Generate continuous columns
    X_cont = rng.randn(n_samples, n_cont).astype("float32")  # (N, n_cont)

    # Latent score: categorical contribution + continuous linear combo + noise
    cat_weights = [rng.randn(c) * rng.uniform(0.3, 1.0) for c in cat_cardinalities]
    score_cat = np.zeros(n_samples, dtype="float32")
    for j in range(n_cat):
        score_cat += cat_weights[j][X_cat[:, j]]

    w_cont = rng.randn(n_cont).astype("float32")
    score_cont = (X_cont * w_cont).sum(axis=1)

    bias = 0.2
    noise = rng.normal(scale=0.5, size=n_samples).astype("float32")

    logit = 0.6 * score_cat + 0.8 * score_cont + bias + noise
    prob = 1 / (1 + np.exp(-logit))
    y = (prob > 0.5).astype("int64")

    # Concatenate as [categorical | continuous]; later split by indices
    X = np.concatenate([X_cat.astype("float32"), X_cont], axis=1)
    cat_idx = list(range(n_cat))
    cont_idx = list(range(n_cat, n_cat + n_cont))

    return X, y, cat_idx, cont_idx, list(cat_cardinalities)


def train_and_evaluate_demo():
    # Reproducibility
    np.random.seed(0)
    torch.manual_seed(0)

    # ---------------- Data ----------------
    X, y, cat_idx, cont_idx, cat_cardinalities = make_synthetic_tabular(
        n_samples=20000, cat_cardinalities=(12, 7, 25), n_cont=4, seed=13
    )

    # Train/Val/Test split
    N = X.shape[0]
    idx = np.arange(N)
    np.random.shuffle(idx)

    tr_end = int(N * 0.7)
    va_end = int(N * 0.85)

    tr_idx = idx[:tr_end]
    va_idx = idx[tr_end:va_end]
    te_idx = idx[va_end:]

    X_tr, y_tr = X[tr_idx], y[tr_idx]
    X_va, y_va = X[va_idx], y[va_idx]
    X_te, y_te = X[te_idx], y[te_idx]

    # ---------------- Model ----------------
    clf = TabTransformerBinaryClassifier(
        cat_idx=cat_idx,
        cat_cardinalities=cat_cardinalities,
        cont_idx=cont_idx,
        d_token=32,
        n_heads=4,
        n_layers=2,
        dim_feedforward=128,
        attn_dropout=0.1,
        embedding_dropout=0.05,
        pooling="concat",          # or "cls" (with add_cls=True)
        add_cls=False,
        cont_proj="linear",
        hidden_dims=(128, 64),
        mlp_dropout=0.2,
        lr=1e-3,
        weight_decay=1e-4,
        loss_fn="logloss",
        device="cuda" if torch.cuda.is_available() else "cpu"
    )

    clf.fit(
        X_tr, y_tr,
        eval_set=[(X_va, y_va)],
        eval_metric=["logloss"],
        max_epochs=15,
        patience=3,
        batch_size=256,
        verbose=True
    )

    # ---------------- Eval ----------------
    proba_va = clf.predict_proba(X_va)[:, 1]
    proba_te = clf.predict_proba(X_te)[:, 1]
    pred_te = (proba_te >= 0.5).astype(int)

    acc = accuracy_score(y_te, pred_te)
    auc = roc_auc_score(y_te, proba_te)
    ll  = log_loss(y_te, np.vstack([1 - proba_te, proba_te]).T)

    print("\n===== Test Metrics =====")
    print(f"Accuracy : {acc:.4f}")
    print(f"ROC-AUC  : {auc:.4f}")
    print(f"Logloss  : {ll:.4f}")

    # Show sample predictions
    print("\nSample predictions (first 10):")
    print(np.round(proba_te[:10], 4))

if __name__ == "__main__":
    train_and_evaluate_demo()




Epoch 1/15 - train_loss: 0.345680
          val_loss: 0.193235
Epoch 2/15 - train_loss: 0.179509
          val_loss: 0.147929
Epoch 3/15 - train_loss: 0.149519
          val_loss: 0.130593
Epoch 4/15 - train_loss: 0.141352
          val_loss: 0.126637
Epoch 5/15 - train_loss: 0.138444
          val_loss: 0.122914
Epoch 6/15 - train_loss: 0.138383
          val_loss: 0.123396
Epoch 7/15 - train_loss: 0.133943
          val_loss: 0.124393
Epoch 8/15 - train_loss: 0.134532
          val_loss: 0.124173
Early stopping at epoch 8

===== Test Metrics =====
Accuracy : 0.9550
ROC-AUC  : 0.9940
Logloss  : 0.1048

Sample predictions (first 10):
[0.0577 0.0019 0.9995 0.     0.9976 0.0043 0.0666 0.9945 0.9992 0.9986]
