# TabTransformer++ for Residual Learning

> **Based on the original TabTransformer paper:**
> 
> Huang, X., Khetan, A., Cvitkovic, M., & Karnin, Z. (2020). *TabTransformer: Tabular Data Modeling Using Contextual Embeddings*. arXiv:2012.06678
> 
> 📄 [arXiv Paper](https://arxiv.org/abs/2012.06678) | 🔗 [GitHub](https://github.com/lucidrains/tab-transformer-pytorch)

This notebook extends the TabTransformer architecture with additional innovations for **residual learning**. The key idea is:

1. Train a simple "base" model (Ridge Regression) to make initial predictions
2. Train a TabTransformer to predict the **residuals** (errors) of the base model
3. Combine: `Final Prediction = Base Prediction + Predicted Residual`

This stacking technique often yields better results than either model alone.
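
The three steps above can be sketched on synthetic data with NumPy alone. This is a minimal illustration, not the notebook's pipeline: the data-generating function, the least-squares "base" model, and the sinusoidal second-stage features are all assumptions chosen to keep the example small.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-2.0, 2.0, size=500)
y = 2.0 * x + np.sin(3.0 * x) + rng.normal(0.0, 0.1, size=500)

def rmse(a, b):
    return float(np.sqrt(np.mean((a - b) ** 2)))

# Step 1: simple base model (ordinary least squares on [1, x]).
A_base = np.column_stack([np.ones_like(x), x])
w_base, *_ = np.linalg.lstsq(A_base, y, rcond=None)
base_pred = A_base @ w_base

# Step 2: a second model fits the residuals of the base model.
# (Hypothetical stand-in for the transformer: least squares on sin/cos features.)
residual = y - base_pred
A_res = np.column_stack([np.sin(3.0 * x), np.cos(3.0 * x)])
w_res, *_ = np.linalg.lstsq(A_res, residual, rcond=None)
residual_pred = A_res @ w_res

# Step 3: final prediction = base prediction + predicted residual.
final_pred = base_pred + residual_pred

base_rmse = rmse(y, base_pred)
final_rmse = rmse(y, final_pred)
print(f"base RMSE:  {base_rmse:.3f}")
print(f"final RMSE: {final_rmse:.3f}")
```

Both errors here are measured in-sample, which flatters the second stage; the notebook avoids this by training on out-of-fold base predictions and reporting a holdout score.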

## Key Components
- **Quantile Binning**: Converts continuous features into discrete tokens
- **Gated Fusion**: Learns to balance binned tokens with raw scalar values *(novel extension)*
- **EMA (Exponential Moving Average)**: Polyak averaging for more stable predictions
- **Isotonic Calibration**: Post-processing to improve residual predictions

## 1. Setup & Configuration

Import required libraries and define hyperparameters:

- **Feature Engineering**: Number of bins for quantile discretization
- **Model Architecture**: Embedding dimensions, attention heads, transformer layers
- **Training**: Learning rate, batch size, EMA decay for Polyak averaging

In [None]:
import os
import gc
import time
import warnings
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import KFold, train_test_split
from sklearn.linear_model import Ridge
from sklearn.ensemble import RandomForestRegressor
from sklearn.isotonic import IsotonicRegression
from sklearn.metrics import root_mean_squared_error

warnings.filterwarnings("ignore")
pd.set_option("mode.copy_on_write", True)

# =============================================================================
# Configuration - TabTransformer++ Hyperparameters
# =============================================================================
class Config:
    SEED            = 2025
    
    # --- Quantile Binning (Tokenization) ---
    NBINS           = 32        # Bins for raw numeric features
    NBINS_BASE      = 128       # Finer bins for base model predictions
    NBINS_DT        = 64        # Bins for tree model predictions
    
    # --- TabTransformer++ Architecture ---
    EMB_DIM         = 64        # Embedding dimension (d_model)
    N_HEADS         = 4         # Multi-head attention heads
    N_LAYERS        = 3         # Transformer encoder layers
    MLP_HID         = 192       # Prediction head hidden dimension
    DROPOUT         = 0.1       # Attention & feedforward dropout
    EMB_DROPOUT     = 0.05      # Post-embedding dropout
    TOKENDROP_P     = 0.12      # TokenDrop regularization probability
    
    # --- Training ---
    EPOCHS          = 10        # Training epochs (shortened for demo)
    BATCH_SIZE      = 1024
    LR              = 2e-3      # AdamW learning rate
    WEIGHT_DECAY    = 1e-5      # AdamW decoupled weight decay
    EMA_DECAY       = 0.995     # Exponential Moving Average (Polyak averaging)
    DEVICE          = "cuda" if torch.cuda.is_available() else "cpu"

def seed_everything(seed):
    """Set random seeds for reproducibility."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

seed_everything(Config.SEED)
print(f"Running on {Config.DEVICE}")

Running on cpu


## 2. Data Simulation: Building the "Stack"

This section simulates a real-world stacking scenario:

1. **Load Data**: California Housing dataset (predicting median house values)
2. **Train Base Models** using K-Fold cross-validation:
   - `Ridge` regression → generates `base_pred` (our primary predictions)
   - `RandomForest` → generates `dt_pred` (provides additional signal)
3. **Calculate Residuals**: `residual = target - base_pred`
   - This is what the TabTransformer will learn to predict

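The out-of-fold (OOF) predictions prevent data leakage: every training row is predicted by a model that never saw that row, so the residuals reflect genuine generalization error rather than in-sample fit.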
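
The OOF mechanics can be shown without scikit-learn. The sketch below is an assumption-laden toy (synthetic data, a plain least-squares model, and a hand-rolled fold split in place of `KFold`), intended only to show that each row receives exactly one prediction from a model fitted on the other folds.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 100, 5
X = rng.normal(size=(n, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(0.0, 0.1, size=n)

# Shuffle row indices once, then cut them into k disjoint validation folds.
folds = np.array_split(rng.permutation(n), k)

oof_pred = np.full(n, np.nan)
for val_idx in folds:
    tr_idx = np.setdiff1d(np.arange(n), val_idx)
    # Fit on the other k-1 folds only ...
    w, *_ = np.linalg.lstsq(X[tr_idx], y[tr_idx], rcond=None)
    # ... and predict the held-out fold, so no row is scored by a model that saw it.
    oof_pred[val_idx] = X[val_idx] @ w

residual = y - oof_pred  # what a second-stage model would be trained on
print(f"rows predicted: {np.isfinite(oof_pred).sum()} of {n}")
```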

In [None]:
def get_simulated_data():
    """
    Simulate a model stacking scenario for residual learning.
    
    Steps:
        1. Load California Housing dataset
        2. Train base models (Ridge, RandomForest) with K-Fold CV
        3. Generate out-of-fold (OOF) predictions to avoid leakage
        4. Calculate residuals: target - base_prediction
    
    Returns:
        train_df: Training data with base_pred, dt_pred, residual columns
        test_df: Test data with base_pred, dt_pred columns
        features: List of original feature column names
    """
    print("\n--- 1. Simulating Base & DT Models (The 'Stack') ---")
    data = fetch_california_housing(as_frame=True)
    df = data.frame
    target_col = "MedHouseVal"
    
    # Hold out 20% as test set (simulates private leaderboard)
    train_df, test_df = train_test_split(df, test_size=0.2, random_state=Config.SEED)
    train_df = train_df.reset_index(drop=True)
    test_df = test_df.reset_index(drop=True)
    
    # Initialize prediction columns
    train_df["base_pred"] = 0.0
    train_df["dt_pred"] = 0.0
    train_df["fold"] = -1
    
    # K-Fold for leak-free OOF predictions
    kf = KFold(n_splits=5, shuffle=True, random_state=Config.SEED)
    
    # Base models for the stack
    model_base = Ridge(alpha=1.0)
    model_dt = RandomForestRegressor(n_estimators=20, max_depth=8, n_jobs=-1, random_state=Config.SEED)
    
    print("Generating OOF predictions...")
    for fold, (tr_idx, val_idx) in enumerate(kf.split(train_df)):
        X_tr = train_df.loc[tr_idx].drop(columns=[target_col, "base_pred", "dt_pred", "fold"])
        y_tr = train_df.loc[tr_idx, target_col]
        X_val = train_df.loc[val_idx].drop(columns=[target_col, "base_pred", "dt_pred", "fold"])
        
        # Ridge regression (base model)
        model_base.fit(X_tr, y_tr)
        train_df.loc[val_idx, "base_pred"] = model_base.predict(X_val)
        
        # Random Forest (tree-based model)
        model_dt.fit(X_tr, y_tr)
        train_df.loc[val_idx, "dt_pred"] = model_dt.predict(X_val)
        
        train_df.loc[val_idx, "fold"] = fold

    # Generate test predictions (trained on full training set)
    print("Generating Test predictions...")
    X_full = train_df.drop(columns=[target_col, "base_pred", "dt_pred", "fold"])
    y_full = train_df[target_col]
    X_test = test_df.drop(columns=[target_col])
    
    model_base.fit(X_full, y_full)
    test_df["base_pred"] = model_base.predict(X_test)
    
    model_dt.fit(X_full, y_full)
    test_df["dt_pred"] = model_dt.predict(X_test)
    
    # Calculate residuals - this is what TabTransformer++ will predict
    train_df["residual"] = train_df[target_col] - train_df["base_pred"]
    
    # Extract original feature names
    features = [c for c in train_df.columns if c not in [target_col, "base_pred", "dt_pred", "fold", "residual"]]
    
    base_rmse = root_mean_squared_error(train_df[target_col], train_df['base_pred'])
    print(f"Base Model RMSE (Train OOF): {base_rmse:.4f}")
    
    return train_df, test_df, features

# Run simulation
train_df, test_df, features = get_simulated_data()


--- 1. Simulating Base & DT Models (The 'Stack') ---
Generating OOF predictions...
Generating Test predictions...
Base Model RMSE (Train OOF): 0.8094


## 3. Tabular Tokenizer

The `TabularTokenizer` prepares data for the transformer:

### Quantile Binning (Discretization)
- Converts continuous features into discrete "tokens" (like words in NLP)
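- Uses quantile-based bins so each bin holds roughly the same number of samples; duplicate quantile edges are merged, so heavily tied features can end up with fewer bins than requested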
- Different bin counts for features (32), base predictions (128), and tree predictions (64)

### Z-Score Normalization
- Standardizes raw values: `(x - mean) / std`
- Preserves the original numeric information alongside tokens

This dual representation (tokens + scalars) gives the model both discrete patterns and continuous precision.
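
A small worked example of the dual representation, using the same NumPy calls as the tokenizer (`np.quantile`, `np.unique`, `np.searchsorted`). The toy feature and the choice of 4 bins are assumptions for illustration only.

```python
import numpy as np

x = np.arange(100, dtype=float)   # toy feature: 0, 1, ..., 99
nbins = 4

# Quantile edges; np.unique merges duplicate edges for heavily tied features.
edges = np.unique(np.quantile(x, np.linspace(0.0, 1.0, nbins + 1)))

# Token ID = index of the bin containing each value, clipped to valid bins.
tokens = np.clip(np.searchsorted(edges, x, side="right") - 1, 0, len(edges) - 2)

# Scalar value = z-score, which keeps the within-bin position that tokens discard.
z = (x - x.mean()) / (x.std() + 1e-8)

print(np.bincount(tokens))        # samples per bin
print(tokens[25], tokens[49])     # 25.0 and 49.0 share a token ...
print(z[25] < z[49])              # ... but keep distinct scalar values
```

The values 25.0 and 49.0 land in the same bin, so only the z-scored scalar lets the model tell them apart.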

In [None]:
class TabularTokenizer:
    """
    Dual-representation tokenizer for TabTransformer++.
    
    Creates two representations per feature:
        1. Token IDs: Quantile bin indices (discrete)
        2. Scalar values: Z-score normalized (continuous)
    
    This enables the Gated Fusion mechanism to blend discrete patterns
    with continuous precision.
    """
    def __init__(self, cols):
        self.cols = cols
        self.edges = {}   # Quantile bin edges per feature
        self.stats = {}   # (mean, std) for z-scoring
        
    def _make_edges(self, x, nbins):
        """Compute quantile-based bin edges."""
        x = x[np.isfinite(x)]
        if len(x) == 0: 
            return np.array([0.0, 1.0])
        qs = np.linspace(0.0, 1.0, nbins + 1)
        edges = np.unique(np.quantile(x, qs))
        if len(edges) < 2: 
            edges = np.array([x.min(), x.max() + 1e-6])
        return edges

    def fit(self, df):
        """Fit tokenizer on training data only (leak-free)."""
        # Original features
        for c in self.cols:
            self.edges[c] = self._make_edges(df[c].values, Config.NBINS)
            self.stats[c] = (df[c].mean(), df[c].std() + 1e-8)
            
        # Base model predictions (finer bins for precision)
        self.edges["_base_"] = self._make_edges(df["base_pred"].values, Config.NBINS_BASE)
        self.stats["_base_"] = (df["base_pred"].mean(), df["base_pred"].std() + 1e-8)
        
        # Tree model predictions
        self.edges["_dt_"] = self._make_edges(df["dt_pred"].values, Config.NBINS_DT)
        self.stats["_dt_"] = (df["dt_pred"].mean(), df["dt_pred"].std() + 1e-8)
        
        # Target (residual) statistics for z-scoring
        self.stats["_target_"] = (df["residual"].mean(), df["residual"].std() + 1e-8)

    def transform(self, df):
        """
        Transform data to dual representation.
        
        Returns:
            toks: Token IDs [N, T] - discrete bin indices
            vals: Scalar values [N, T] - z-score normalized
        """
        N = len(df)
        T = len(self.cols) + 2  # features + base_pred + dt_pred
        
        toks = np.zeros((N, T), dtype=np.int64)
        vals = np.zeros((N, T), dtype=np.float32)
        
        def _process_column(col_name, edge_key, stat_key, out_idx):
            v = df[col_name].values
            # Discretize: assign to quantile bins
            idx = np.searchsorted(self.edges[edge_key], v, side="right") - 1
            toks[:, out_idx] = np.clip(idx, 0, len(self.edges[edge_key]) - 2)
            # Normalize: z-score standardization
            mu, sd = self.stats[stat_key]
            vals[:, out_idx] = (v - mu) / sd

        # Process original features
        for i, c in enumerate(self.cols):
            _process_column(c, c, c, i)
            
        # Process stacked predictions (base & tree models)
        _process_column("base_pred", "_base_", "_base_", T - 2)
        _process_column("dt_pred", "_dt_", "_dt_", T - 1)
        
        return toks, vals
    
    def get_vocab_sizes(self):
        """Get vocabulary size for each feature's embedding layer."""
        sizes = [len(self.edges[c]) - 1 for c in self.cols]
        sizes.append(len(self.edges["_base_"]) - 1)
        sizes.append(len(self.edges["_dt_"]) - 1)
        return sizes

## 4. Model Architecture

### TabTransformerGated: Architectural Innovations

This implementation extends the original TabTransformer with several key innovations for tabular data:

---

### 🔷 Innovation 1: Dual Representation (Tokens + Scalars)

Unlike standard transformers that use only discrete tokens, we maintain **both representations**:

| Representation | How it's Created | What it Captures |
|----------------|------------------|------------------|
| **Token Embedding** | Quantile bin → learned embedding | Discrete patterns, ordinal relationships |
| **Value Embedding** | Z-scored scalar → MLP projection | Precise numeric magnitude |

**Why both?** Binning loses precision (e.g., 1.01 and 1.99 might share a bin), but raw scalars lack the pattern-matching power of embeddings.

---

### 🔷 Innovation 2: Learnable Gated Fusion

Each feature has a **learnable gate parameter** (initialized to 0) that controls the blend:

```
final_embedding[i] = token_emb[i] + σ(gate[i]) × value_emb[i]
```

- `σ(gate)` is a sigmoid, so the fusion weight lies in (0, 1); a gate parameter of 0 gives an initial weight of `σ(0) = 0.5`
- **σ(gate) ≈ 0**: Model relies mostly on discrete token patterns
- **σ(gate) ≈ 1**: Model adds the full value embedding to the token embedding
- Gates are learned per-feature, so the model adapts to each column's characteristics
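
The blend can be traced numerically. The following is a minimal sketch with made-up 3-dimensional embeddings (the vectors and the sampled gate values are assumptions); it shows that a gate parameter of 0 corresponds to a fusion weight of 0.5, while strongly negative or positive parameters push the weight toward 0 or 1.

```python
import numpy as np

def sigmoid(g):
    return 1.0 / (1.0 + np.exp(-g))

token_emb = np.array([0.2, -0.1, 0.4])   # toy 3-dim token embedding
value_emb = np.array([1.0, 1.0, -1.0])   # toy 3-dim value embedding

for gate in (-4.0, 0.0, 4.0):
    w = sigmoid(gate)
    fused = token_emb + w * value_emb
    print(f"gate={gate:+.1f}  weight={w:.3f}  fused={np.round(fused, 3)}")

init_weight = sigmoid(0.0)   # fusion weight at initialization
```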

---

### 🔷 Innovation 3: Per-Token Value MLPs

Instead of a single shared MLP for all features, each feature gets its **own projection network**:

```python
PerTokenValMLP: Linear(1 → 64) → GELU → Linear(64 → 64) → LayerNorm
```

This allows different features to learn different transformations (e.g., log-like for skewed features, linear for normal ones).
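
The cost of giving every feature its own network is easy to estimate by hand. The arithmetic below assumes `EMB_DIM = 64` and 10 tokens (the 8 California Housing features plus `base_pred` and `dt_pred`), and counts weights and biases for each layer of the value MLP.

```python
emb_dim = 64          # Config.EMB_DIM
n_tokens = 8 + 2      # 8 California Housing features + base_pred + dt_pred

linear_in = 1 * emb_dim + emb_dim            # Linear(1 -> 64): weights + bias
linear_out = emb_dim * emb_dim + emb_dim     # Linear(64 -> 64): weights + bias
layer_norm = 2 * emb_dim                     # LayerNorm: scale + shift

per_feature = linear_in + linear_out + layer_norm
total = per_feature * n_tokens
print(per_feature, total)
```

A single shared MLP would need only the per-feature count, so the per-token design multiplies this part of the parameter budget by the number of tokens.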

---

### 🔷 Innovation 4: TokenDrop Regularization

During training, we randomly **zero out** feature embeddings with probability `p=0.12`:

```python
# x: [B, 1+T, D] embeddings with the CLS token at position 0
mask = (torch.rand(x.shape[0], x.shape[1], 1, device=x.device) > p).float()  # per-sample, per-token
mask[:, 0, :] = 1.0  # Never drop CLS token
x = x * mask
```

Benefits:
- Forces the model not to over-rely on any single feature
- Similar to dropout but at the feature level (surviving embeddings are not rescaled by `1/(1-p)`)
- Can improve generalization on tabular data
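
The masking logic can be checked in isolation. This NumPy sketch mirrors the mask construction on an all-ones tensor; the batch size, token count, and embedding width are arbitrary assumptions, and NumPy stands in for PyTorch so the example has no framework dependency.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, D, p = 10_000, 9, 4, 0.12
x = np.ones((B, 1 + T, D))              # [batch, CLS + T features, emb_dim]

# One keep/drop decision per (sample, token), broadcast across the embedding dim.
mask = (rng.random((B, 1 + T, 1)) > p).astype(x.dtype)
mask[:, 0, :] = 1.0                      # the CLS token is never dropped
x_dropped = x * mask                     # no 1/(1-p) rescaling is applied

drop_rate = 1.0 - mask[:, 1:, :].mean()
print(f"empirical drop rate on feature tokens: {drop_rate:.3f}")
```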

---

### 🔷 Innovation 5: CLS Token Aggregation

Following BERT's approach, we prepend a special `[CLS]` token:

```
Input:  [CLS, feat_1, feat_2, ..., feat_n, base_pred, dt_pred]
Output: Use CLS embedding for final prediction
```

The transformer's self-attention allows CLS to attend to all features and learn a global representation.

---

### 🔷 Innovation 6: Pre-LayerNorm Transformer

We use `norm_first=True` (Pre-LN) instead of Post-LN:

```
Pre-LN:  x = x + Attention(LayerNorm(x))
Post-LN: x = LayerNorm(x + Attention(x))
```

Pre-LN is generally more stable to train and less sensitive to learning-rate warmup, because the residual path carries `x` forward without passing it through a normalization layer.
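
The two orderings differ only in where the normalization sits relative to the residual sum. The sketch below makes that concrete with a toy linear sublayer; the `sublayer` function and random weights are hypothetical stand-ins for attention, and the layer norm omits the learnable scale and shift for brevity.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(8, 8))

def sublayer(h):
    # Hypothetical stand-in for attention or the feed-forward block.
    return h @ W

x = rng.normal(loc=3.0, scale=5.0, size=(2, 8))

pre_ln = x + sublayer(layer_norm(x))     # residual path carries x through unchanged
post_ln = layer_norm(x + sublayer(x))    # the residual sum itself is normalized

print("pre-LN  row std:", pre_ln.std(axis=-1).round(2))
print("post-LN row std:", post_ln.std(axis=-1).round(2))
```

In the Pre-LN form the input `x` reaches the output untouched, whereas Post-LN rescales the whole sum at every layer.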

---

### Full Architecture Summary

```
Input: (tokens, values) for each of T features
         ↓
┌─────────────────────────────────────────┐
│  Per-Feature Processing (for i in T):   │
│    token_emb = Embedding(token[i])      │
│    value_emb = MLP(value[i])            │
│    gate = sigmoid(learnable_param[i])   │
│    feat[i] = token_emb + gate*value_emb │
└─────────────────────────────────────────┘
         ↓
    Embedding Dropout (p=0.05)
         ↓
    Prepend [CLS] token
         ↓
    TokenDrop (p=0.12, training only)
         ↓
┌─────────────────────────────────────────┐
│  Transformer Encoder (3 layers):        │
│    - 4 attention heads                  │
│    - dim=64, feedforward=256            │
│    - Pre-LayerNorm, GELU activation     │
└─────────────────────────────────────────┘
         ↓
    Extract [CLS] embedding
         ↓
┌─────────────────────────────────────────┐
│  Prediction Head:                       │
│    LayerNorm → Linear(64→192) → GELU    │
│    → Dropout → Linear(192→1)            │
└─────────────────────────────────────────┘
         ↓
    Output: Predicted Residual (z-scored)
```

In [None]:
class TokenDrop(nn.Module):
    """
    Feature-level dropout regularization (Innovation #4).
    
    Randomly zeros out feature embeddings during training to prevent
    over-reliance on any single feature. CLS token is never dropped.
    """
    def __init__(self, p=0.1):
        super().__init__()
        self.p = p
        
    def forward(self, x):
        # x: [B, 1+T, D] where first token is CLS
        if not self.training or self.p <= 0: 
            return x
        mask = (torch.rand(x.shape[0], x.shape[1], 1, device=x.device) > self.p).float()
        mask[:, 0, :] = 1.0  # Preserve CLS token
        return x * mask


class PerTokenValMLP(nn.Module):
    """
    Per-feature value projection network (Innovation #3).
    
    Each feature gets its own MLP to project scalar values to embedding space.
    This allows different features to learn different transformations.
    """
    def __init__(self, emb_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
            nn.LayerNorm(emb_dim)
        )
        
    def forward(self, x): 
        return self.net(x)


class TabTransformerGated(nn.Module):
    """
    TabTransformer++ with Gated Fusion.
    
    Key architectural innovations:
        1. Dual representation (tokens + scalars)
        2. Learnable per-feature gates for fusion
        3. Per-token value MLPs
        4. TokenDrop regularization
        5. CLS token for aggregation
        6. Pre-LayerNorm transformer (norm_first=True)
    """
    def __init__(self, vocab_sizes):
        super().__init__()
        self.num_tokens = len(vocab_sizes)
        
        # Innovation #1: Token embeddings (discrete representation)
        self.embs = nn.ModuleList([
            nn.Embedding(v + 1, Config.EMB_DIM) for v in vocab_sizes
        ])
        
        # Innovation #3: Per-feature value MLPs (continuous representation)
        self.val_mlps = nn.ModuleList([
            PerTokenValMLP(Config.EMB_DIM) for _ in vocab_sizes
        ])
        
        # Innovation #2: Learnable gates for fusion (initialized to 0)
        self.gates = nn.ParameterList([
            nn.Parameter(torch.zeros(1)) for _ in vocab_sizes
        ])
        self.sigmoid = nn.Sigmoid()
        
        # Innovation #5: CLS token for global aggregation
        self.cls_token = nn.Parameter(torch.zeros(1, 1, Config.EMB_DIM))
        self.emb_dropout = nn.Dropout(Config.EMB_DROPOUT)
        
        # Innovation #4: TokenDrop regularization
        self.tokendrop = TokenDrop(Config.TOKENDROP_P)
        
        # Innovation #6: Pre-LayerNorm Transformer (stable training)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=Config.EMB_DIM, 
            nhead=Config.N_HEADS, 
            dim_feedforward=Config.EMB_DIM * 4,
            dropout=Config.DROPOUT, 
            batch_first=True, 
            norm_first=True,  # Pre-LN for stability
            activation="gelu"
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=Config.N_LAYERS)
        
        # Prediction head
        self.head = nn.Sequential(
            nn.LayerNorm(Config.EMB_DIM),
            nn.Linear(Config.EMB_DIM, Config.MLP_HID),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(Config.MLP_HID, 1)
        )
        
    def forward(self, x_tok, x_val):
        B = x_tok.shape[0]
        
        # Gated Fusion: embedding[i] = token_emb + sigmoid(gate) * value_emb
        emb_list = []
        for i in range(self.num_tokens):
            tok_e = self.embs[i](x_tok[:, i])           # Discrete embedding
            val_e = self.val_mlps[i](x_val[:, i:i+1])   # Continuous embedding
            g = self.sigmoid(self.gates[i])             # Learnable blend weight
            emb_list.append(tok_e + g * val_e)
            
        x = torch.stack(emb_list, dim=1)  # [B, T, D]
        x = self.emb_dropout(x)
        
        # Prepend CLS token for global aggregation
        cls = self.cls_token.expand(B, 1, -1)
        x = torch.cat([cls, x], dim=1)  # [B, 1+T, D]
        
        # Apply TokenDrop and Transformer encoder
        x = self.tokendrop(x)
        x = self.encoder(x)
        
        # Extract CLS embedding for prediction
        return self.head(x[:, 0, :]).squeeze(-1)


class TTDataset(Dataset):
    """PyTorch Dataset for TabTransformer++ dual-representation data."""
    def __init__(self, toks, vals, y=None):
        self.toks = torch.as_tensor(toks, dtype=torch.long)
        self.vals = torch.as_tensor(vals, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.float32) if y is not None else None
        
    def __len__(self): 
        return len(self.toks)
    
    def __getitem__(self, i): 
        return (self.toks[i], self.vals[i]), (self.y[i] if self.y is not None else 0.0)

## 5. Training Loop

### Cross-Validation Strategy
For each of the 5 folds:

1. **Leak-Free Tokenization**: Fit tokenizer only on training data
2. **Z-Score Targets**: Normalize residuals for stable training
3. **Train with EMA**: 
   - Main model learns via gradient descent
   - EMA model maintains exponential moving average of weights (Polyak averaging)
   - EMA often generalizes better than the final trained weights
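
Because the EMA model starts as a copy of the untrained weights and each update multiplies the running average by the decay, the initial weights still carry a share of `decay ** steps` after training. The back-of-envelope calculation below estimates that share for this notebook's settings; it assumes the California Housing dataset has 20,640 rows and uses the `BATCH_SIZE`, `EPOCHS`, and `EMA_DECAY` values from `Config`.

```python
import math

n_rows = 20_640                              # California Housing rows (assumed dataset size)
n_train = n_rows - n_rows // 5               # after the 20% holdout
n_fold_train = math.ceil(n_train * 4 / 5)    # largest 4-of-5-fold training split
steps = math.ceil(n_fold_train / 1024) * 10  # batches per epoch * EPOCHS

decay = 0.995
init_weight = decay ** steps                 # share of the EMA still held by the initial weights
print(f"{steps} steps -> initial weights keep {init_weight:.1%} of the EMA")
```

With so few optimizer steps, roughly half of the EMA is still the random initialization. Longer training, a smaller decay, or a bias-corrected EMA would reduce that share; none of these is part of the notebook as written.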

### Isotonic Calibration
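After training, we calibrate predictions using **Isotonic Regression**:
- Maps the model's z-scored outputs back to actual residual values
- Monotonic (non-decreasing) transformation that can correct systematic biases
- Fitted on each fold's validation data, then applied to test predictions
- Because the calibrator is fitted and scored on the same validation rows, the per-fold and CV RMSE values are optimistic; the holdout test RMSE is the cleaner estimate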
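
Isotonic regression finds the non-decreasing sequence closest to the targets in the least-squares sense, and its core is the pool-adjacent-violators algorithm. The pure-Python sketch below shows that core on four points; the function name `pav` is hypothetical, and scikit-learn's `IsotonicRegression` additionally handles ties, sample weights, and interpolation between fitted points.

```python
def pav(y):
    """Pool Adjacent Violators: least-squares non-decreasing fit to y."""
    blocks = []  # each block holds [sum of values, number of values]
    for v in y:
        blocks.append([float(v), 1])
        # Merge backwards while the previous block's mean exceeds the latest one.
        while len(blocks) > 1 and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]:
            s, c = blocks.pop()
            blocks[-1][0] += s
            blocks[-1][1] += c
    fitted = []
    for s, c in blocks:
        fitted.extend([s / c] * c)  # every point in a block gets the block mean
    return fitted

# Targets listed in order of increasing model prediction.
targets_sorted_by_pred = [1.0, 3.0, 2.0, 4.0]
calibrated = pav(targets_sorted_by_pred)
print(calibrated)
```

The out-of-order pair (3.0, 2.0) is pooled into its mean, which is how the calibrator enforces monotonicity while staying as close as possible to the observed residuals.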

### Final Prediction
```
final_prediction = base_pred + calibrated_residual
```

In [None]:
# =============================================================================
# Training Loop with K-Fold Cross-Validation
# =============================================================================

# Storage for predictions
oof_preds = np.zeros(len(train_df))        # Out-of-fold residual predictions
test_preds_accum = np.zeros(len(test_df))  # Averaged test predictions

folds = sorted(train_df["fold"].unique())
print(f"\n--- 2. Training TabTransformer++ for Residual Learning ({len(folds)} folds) ---")

for k in folds:
    # =========================================================================
    # A. Leak-Free Data Preparation
    # =========================================================================
    tr_mask = train_df["fold"] != k
    va_mask = train_df["fold"] == k
    
    # Fit tokenizer on training fold only (prevents data leakage)
    tokenizer = TabularTokenizer(features)
    tokenizer.fit(train_df[tr_mask])
    
    # Transform to dual representation (tokens + scalars)
    X_tr_tok, X_tr_val = tokenizer.transform(train_df[tr_mask])
    X_va_tok, X_va_val = tokenizer.transform(train_df[va_mask])
    X_te_tok, X_te_val = tokenizer.transform(test_df)
    
    # Z-score normalize targets for stable training
    y_mu, y_std = tokenizer.stats["_target_"]
    y_tr = (train_df.loc[tr_mask, "residual"].values - y_mu) / y_std
    y_va_raw = train_df.loc[va_mask, "residual"].values
    
    # =========================================================================
    # B. Create DataLoaders
    # =========================================================================
    dl_tr = DataLoader(TTDataset(X_tr_tok, X_tr_val, y_tr), 
                       batch_size=Config.BATCH_SIZE, shuffle=True)
    dl_va = DataLoader(TTDataset(X_va_tok, X_va_val), 
                       batch_size=Config.BATCH_SIZE, shuffle=False)
    dl_te = DataLoader(TTDataset(X_te_tok, X_te_val), 
                       batch_size=Config.BATCH_SIZE, shuffle=False)
    
    # =========================================================================
    # C. Initialize Models (Main + EMA for Polyak Averaging)
    # =========================================================================
    model = TabTransformerGated(tokenizer.get_vocab_sizes()).to(Config.DEVICE)
    ema_model = TabTransformerGated(tokenizer.get_vocab_sizes()).to(Config.DEVICE)
    ema_model.load_state_dict(model.state_dict())
    
    opt = torch.optim.AdamW(model.parameters(), lr=Config.LR, weight_decay=Config.WEIGHT_DECAY)
    loss_fn = nn.SmoothL1Loss(beta=1.0)  # Huber loss for robustness
    
    # =========================================================================
    # D. Training Loop with EMA Updates
    # =========================================================================
    for epoch in range(Config.EPOCHS):
        model.train()
        for (xt, xv), y in dl_tr:
            xt, xv, y = xt.to(Config.DEVICE), xv.to(Config.DEVICE), y.to(Config.DEVICE)
            
            opt.zero_grad()
            pred = model(xt, xv)
            loss = loss_fn(pred, y)
            loss.backward()
            opt.step()
            
            # Update EMA model (Polyak averaging for better generalization)
            with torch.no_grad():
                for p, ema_p in zip(model.parameters(), ema_model.parameters()):
                    ema_p.data.mul_(Config.EMA_DECAY).add_(p.data, alpha=1 - Config.EMA_DECAY)
    
    # =========================================================================
    # E. Evaluation with Isotonic Calibration
    # =========================================================================
    ema_model.eval()
    
    # Predict validation set (in z-score space)
    preds_z = []
    with torch.no_grad():
        for (xt, xv), _ in dl_va:
            preds_z.append(ema_model(xt.to(Config.DEVICE), xv.to(Config.DEVICE)).cpu().numpy())
    preds_z = np.concatenate(preds_z)
    
    # Isotonic calibration: map z-scored predictions to actual residuals
    iso = IsotonicRegression(out_of_bounds="clip")
    iso.fit(preds_z, y_va_raw)
    calib_preds = iso.predict(preds_z)
    
    oof_preds[va_mask] = calib_preds
    rmse = root_mean_squared_error(y_va_raw, calib_preds)
    print(f"Fold {k} | Residual RMSE: {rmse:.4f}")
    
    # Apply calibration to test predictions
    preds_te_z = []
    with torch.no_grad():
        for (xt, xv), _ in dl_te:
            preds_te_z.append(ema_model(xt.to(Config.DEVICE), xv.to(Config.DEVICE)).cpu().numpy())
    preds_te_z = np.concatenate(preds_te_z)
    test_preds_accum += iso.predict(preds_te_z) / len(folds)  # Average across folds
    
    # Cleanup
    del model, ema_model, opt, dl_tr
    if Config.DEVICE == "cuda": 
        torch.cuda.empty_cache()

# =============================================================================
# Final Results: Base + Predicted Residual
# =============================================================================
final_oof = train_df["base_pred"] + oof_preds
final_test = test_df["base_pred"] + test_preds_accum

base_cv = root_mean_squared_error(train_df["MedHouseVal"], train_df["base_pred"])
tt_cv = root_mean_squared_error(train_df["MedHouseVal"], final_oof)

base_test = root_mean_squared_error(test_df["MedHouseVal"], test_df["base_pred"])
tt_test = root_mean_squared_error(test_df["MedHouseVal"], final_test)

print("\n" + "=" * 50)
print("FINAL RESULTS: TabTransformer++ Residual Learning")
print("=" * 50)
print(f"TRAIN (Cross-Validation) RMSE:")
print(f"  Base Model Only:           {base_cv:.5f}")
print(f"  Base + TabTransformer++:   {tt_cv:.5f}")
print("-" * 30)
print(f"TEST (Holdout) RMSE:")
print(f"  Base Model Only:           {base_test:.5f}")
print(f"  Base + TabTransformer++:   {tt_test:.5f}")
print("=" * 50)


--- 2. Training Residual TabTransformer (5 folds) ---
Fold 0 | Residual RMSE: 0.6073
Fold 1 | Residual RMSE: 0.9915
Fold 2 | Residual RMSE: 0.6077
Fold 3 | Residual RMSE: 0.6098
Fold 4 | Residual RMSE: 0.6089

FINAL RESULTS SUMMARY
TRAIN (CV) RMSE:
  Base Model Only:      0.80939
  Base + TT Residual:   0.70200
--------------------
TEST (Holdout) RMSE:
  Base Model Only:      0.73611
  Base + TT Residual:   0.59240
