# TabFormer: Tabular BERT for Credit Card Transaction Modeling

This notebook trains a hierarchical BERT model on credit card transaction data using the TabFormer architecture.
Designed to run in a Kaggle environment with GPU acceleration.

## Overview
- **Model**: Hierarchical Tabular BERT with field-wise cross-entropy
- **Task**: Masked Language Modeling on transaction sequences
- **Data**: Synthetic credit card transactions (24M records)
- **Hardware**: GPU-accelerated training

## 1. Install Required Packages

In [None]:
# Install required packages for Kaggle environment
# Uninstall existing torch packages to avoid CUDA version conflicts
!pip uninstall -y torch torchvision torchaudio

# Install PyTorch with CUDA 11.8 support (all from same index)
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install transformers with torch extras and other dependencies
!pip install -q "transformers[torch]>=4.30.0" scikit-learn>=1.0.0 pandas>=1.3.0 numpy>=1.21.0

## 2. Import Libraries and Setup

In [None]:
import os
import sys
import random
import math
import pickle
import logging
from collections import OrderedDict
from typing import Dict, List, Tuple, Union

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from sklearn.preprocessing import LabelEncoder, MinMaxScaler

from transformers import (
    BertTokenizer,
    BertConfig,
    BertModel,
    BertForMaskedLM,
    PreTrainedModel,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)

from IPython.display import display, HTML, Markdown

# Send INFO-and-above records to the notebook output, prefixed with a timestamp,
# the level and the logger name.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Pick the compute device once and report what was found.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
print(f"\n{'='*60}")
print(f"Device: {device}")
if has_cuda:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_props.total_memory / 1024**3:.2f} GB")
print(f"{'='*60}\n")

## 3. Configuration

In [None]:
# Every tunable of the notebook lives in this one mapping.
CONFIG = dict(
    seed=42,
    data_root='./data/credit_card/',
    data_fname='card_transaction.v1',
    output_dir='./output',
    vocab_dir='./output',

    # Model hyperparameters
    field_hidden_size=64,  # Reduced for faster training
    mlm_prob=0.15,

    # Data parameters
    seq_len=10,  # Sequence length (number of transactions)
    stride=5,  # Stride for sliding window
    num_bins=10,  # Number of bins for quantization
    nrows=100000,  # Number of rows to use (None for all data)
    user_ids=None,  # Filter by user IDs (None for all users)

    # Training parameters
    num_train_epochs=3,
    batch_size=32,
    save_steps=500,
    logging_steps=100,
)

# Make sure the directories written to later exist.
for _dir_key in ('output_dir', 'vocab_dir'):
    os.makedirs(CONFIG[_dir_key], exist_ok=True)

# Set random seeds for reproducibility
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs (plus all CUDA devices when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed every RNG before any data or model code runs.
set_seed(CONFIG['seed'])

config_banner = HTML(f"<h3>Configuration Set</h3><p>Using {CONFIG['nrows']} rows for training</p>")
display(config_banner)

## 4. Vocabulary Class

Manages the vocabulary for tabular data with field-aware tokenization.

In [None]:
class AttrDict(dict):
    """A dict whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Using the mapping itself as the instance __dict__ makes
        # d.key and d['key'] refer to the same storage.
        self.__dict__ = self


class Vocabulary:
    """Field-aware token vocabulary for tabular data.

    Tokens are registered per field. Each (field, token) pair gets a *global*
    id, unique across the whole vocabulary, and a *local* id, unique within
    its field. Special tokens live in the reserved ``SPECIAL`` field and take
    the first global ids.
    """

    def __init__(self, adap_thres=10000, target_column_name="Is Fraud?"):
        # Special tokens
        self.unk_token = "[UNK]"
        self.sep_token = "[SEP]"
        self.pad_token = "[PAD]"
        self.cls_token = "[CLS]"
        self.mask_token = "[MASK]"
        self.bos_token = "[BOS]"
        self.eos_token = "[EOS]"

        self.adap_thres = adap_thres
        self.adap_sm_cols = set()

        self.target_column_name = target_column_name
        self.special_field_tag = "SPECIAL"

        # The list order fixes the global ids of the special tokens (0..6).
        self.special_tokens = [self.unk_token, self.sep_token, self.pad_token,
                               self.cls_token, self.mask_token, self.bos_token, self.eos_token]

        self.token2id = OrderedDict()  # {field: {token: [global_id, local_id]}, ...}
        self.id2token = OrderedDict()  # {global_id: [token, field, local_id]}
        self.field_keys = OrderedDict()  # {field: global_id of that field's [SEP] entry}
        self.token2id[self.special_field_tag] = OrderedDict()

        # Path of the last file written by save_vocab ('' until then).
        self.filename = ''

        # Register the special tokens first so they own the lowest global ids.
        for token in self.special_tokens:
            self.set_id(token, self.special_field_tag)

    def set_field_keys(self, keys):
        """Create an empty sub-vocabulary for every field in ``keys``.

        Each field is seeded with its own ``[SEP]`` entry (local id 0), whose
        global id is recorded in ``field_keys``.
        """
        for key in keys:
            self.token2id[key] = OrderedDict()
            self.field_keys[key] = self.set_id(self.sep_token, key)

    def set_id(self, token, field_name, return_local=False):
        """Register ``token`` under ``field_name`` if new and return its id.

        Returns the local id when ``return_local`` is true, else the global id.
        """
        field_vocab = self.token2id[field_name]
        if token not in field_vocab:
            # New ids are simply the current sizes, so ids are dense and
            # id2token's insertion order equals ascending global id.
            global_id = len(self.id2token)
            local_id = len(field_vocab)
            field_vocab[token] = [global_id, local_id]
            self.id2token[global_id] = [token, field_name, local_id]
        else:
            global_id, local_id = field_vocab[token]

        return local_id if return_local else global_id

    def get_id(self, token, field_name="", special_token=False, return_local=False):
        """Look up the id of an already registered token.

        ``special_token=True`` forces the lookup into the ``SPECIAL`` field.
        Raises ``Exception`` when the token is unknown for that field.
        """
        if special_token:
            field_name = self.special_field_tag

        if token not in self.token2id[field_name]:
            raise Exception(f"Token {token} not found in field: {field_name}")

        global_id, local_id = self.token2id[field_name][token]
        return local_id if return_local else global_id

    def get_special_tokens(self):
        """Return the special token strings as an AttrDict keyed by role."""
        return AttrDict(
            unk_token=self.unk_token,
            sep_token=self.sep_token,
            pad_token=self.pad_token,
            cls_token=self.cls_token,
            mask_token=self.mask_token,
            bos_token=self.bos_token,
            eos_token=self.eos_token
        )

    def get_from_local_ids(self, field, local_ids):
        """Return the global ids of the tokens in ``field`` whose local id is in ``local_ids``."""
        return [global_id
                for global_id, local_id in self.token2id[field].values()
                if local_id in local_ids]

    def save_vocab(self, file_name):
        """Write the vocabulary as a plain-text file, one token per line.

        The line number (0-based) of each token equals its global id, which is
        the format BertTokenizer expects for its ``vocab_file``. The previous
        implementation pickled the object instead; the pickle could not be
        read as a text vocabulary by the tokenizer that is built from
        ``self.filename``.

        Special tokens are written verbatim (e.g. ``[MASK]``) so the strings
        returned by ``get_special_tokens`` resolve to their ids. All other
        tokens are written as ``<field>_<token>`` to keep equal values from
        different fields distinct.
        """
        self.filename = file_name
        with open(file_name, 'w', encoding='utf-8') as f:
            for token, field, _ in self.id2token.values():
                entry = token if field == self.special_field_tag else f"{field}_{token}"
                f.write(f"{entry}\n")
        logger.info(f"Vocabulary saved to {file_name}")

    def __len__(self):
        """Total number of (field, token) entries, special tokens included."""
        return len(self.id2token)

print("✓ Vocabulary class defined")

## 5. Transaction Dataset Class

Processes credit card transaction data into sequences for training.

In [None]:
class TransactionDataset(Dataset):
    """Credit card transactions turned into fixed-length token windows.

    Construction runs the whole pipeline: read and encode the CSV, build the
    vocabulary, cut every user's history into sliding windows of ``seq_len``
    transactions, and save the vocabulary to ``vocab_dir``.

    After construction:
      * ``data[i]``         - flat list of global token ids of window ``i``
      * ``labels[i]``       - per-transaction fraud labels of window ``i``
      * ``window_label[i]`` - 1 if any transaction in window ``i`` is fraud
      * ``ncols``           - number of tokens emitted per transaction
    """

    def __init__(self,
                 mlm=True,
                 user_ids=None,
                 seq_len=10,
                 num_bins=10,
                 cached=False,
                 root="./data/card/",
                 fname="card_trans",
                 vocab_dir="checkpoints",
                 fextension="",
                 nrows=None,
                 flatten=False,
                 stride=5,
                 adap_thres=10**8,
                 return_labels=False,
                 skip_user=False):

        self.root = root
        self.fname = fname
        self.nrows = nrows
        self.fextension = f'_{fextension}' if fextension else ''
        self.cached = cached
        self.user_ids = user_ids
        self.return_labels = return_labels
        self.skip_user = skip_user

        self.mlm = mlm
        self.trans_stride = stride
        self.flatten = flatten

        self.vocab = Vocabulary(adap_thres)
        self.seq_len = seq_len
        self.encoder_fit = {}

        self.trans_table = None
        self.data = []
        self.labels = []
        self.window_label = []

        self.ncols = None
        self.num_bins = num_bins

        # Process data
        self.encode_data()
        self.init_vocab()
        self.prepare_samples()
        self.save_vocab(vocab_dir)

    def __getitem__(self, index):
        """Return window ``index`` as a long tensor.

        The tensor is 1-D when ``flatten`` is set, otherwise it is reshaped to
        ``(seq_len, tokens_per_transaction)``. With ``return_labels`` the
        result is a ``(tokens, labels)`` tuple.
        """
        return_data = torch.tensor(self.data[index], dtype=torch.long)
        if not self.flatten:
            return_data = return_data.reshape(self.seq_len, -1)

        if self.return_labels:
            return_data = (return_data, torch.tensor(self.labels[index], dtype=torch.long))

        return return_data

    def __len__(self):
        return len(self.data)

    def save_vocab(self, vocab_dir):
        """Save the vocabulary as ``vocab<_fextension>.nb`` inside ``vocab_dir``."""
        if vocab_dir:
            # The default directory ("checkpoints") may not exist yet.
            os.makedirs(vocab_dir, exist_ok=True)
        file_name = os.path.join(vocab_dir, f'vocab{self.fextension}.nb')
        self.vocab.save_vocab(file_name)

    @staticmethod
    def label_fit_transform(column, enc_type="label"):
        """Fit an encoder on ``column`` and return ``(fitted_encoder, transformed_column)``.

        ``enc_type="label"`` uses a LabelEncoder, anything else a MinMaxScaler.
        """
        if enc_type == "label":
            mfit = LabelEncoder()
        else:
            mfit = MinMaxScaler()
        mfit.fit(column)
        return mfit, mfit.transform(column)

    @staticmethod
    def nanNone(x):
        """Replace NaN with 'None'"""
        return x.fillna('None')

    @staticmethod
    def nanZero(x):
        """Replace NaN with 0"""
        return x.fillna(0)

    @staticmethod
    def fraudEncoder(x):
        """Map 'Yes' to 1 and every other value to 0."""
        return x.map(lambda val: 1 if val == 'Yes' else 0)

    def _quantize(self, inputs, bin_edges):
        """Return, for each value, the index of the nearest entry of ``bin_edges``."""
        return [np.abs(bin_edges - val).argmin() for val in inputs]

    def _feature_columns(self):
        """Columns that become tokens: all but the label, and 'User' when skipped."""
        columns = list(self.trans_table.columns)
        columns.remove('Is Fraud?')
        if self.skip_user:
            columns.remove('User')
        return columns

    def _tokens_per_transaction(self, num_fields):
        """Number of tokens prepare_samples emits for one transaction.

        Flattened rows hold one token per field. Hierarchical rows are laid
        out as ``[CLS] f1 [SEP] f2 [SEP] ...``, i.e. ``1 + 2 * num_fields``.
        """
        return num_fields if self.flatten else 1 + 2 * num_fields

    def encode_data(self):
        """Read the CSV and turn every column into integer codes."""
        data_file = os.path.join(self.root, f"{self.fname}.csv")

        logger.info(f"Reading data from {data_file}")
        data = pd.read_csv(data_file, nrows=self.nrows)
        logger.info(f"Loaded {len(data)} transactions")

        # Handle NaN values
        data['Errors?'] = self.nanNone(data['Errors?'])
        data['Is Fraud?'] = self.fraudEncoder(data['Is Fraud?'])
        data['Zip'] = self.nanZero(data['Zip'])
        data['Merchant State'] = self.nanNone(data['Merchant State'])

        # Remove dollar sign from Amount
        data['Amount'] = data['Amount'].apply(lambda x: float(x.replace('$', '')))

        # Encode categorical columns (compared as strings)
        cat_cols = ['Merchant Name', 'Merchant City', 'Merchant State', 'Zip',
                    'MCC', 'Errors?', 'Use Chip', 'Year', 'Month', 'Day']

        for col in cat_cols:
            self.encoder_fit[col], data[col] = self.label_fit_transform(data[col].astype(str))

        # Quantize Amount to the nearest of num_bins evenly spaced levels
        min_val, max_val = data['Amount'].min(), data['Amount'].max()
        bin_edges = np.linspace(min_val, max_val, self.num_bins)
        data['Amount'] = self._quantize(data['Amount'].values, bin_edges)

        # Convert "HH:MM" to minutes since midnight, then label-encode
        data['Time'] = data['Time'].apply(lambda x: int(x.split(':')[0]) * 60 + int(x.split(':')[1]))
        self.encoder_fit['Time'], data['Time'] = self.label_fit_transform(data['Time'])

        if not self.skip_user:
            self.encoder_fit['User'], data['User'] = self.label_fit_transform(data['User'])

        self.encoder_fit['Card'], data['Card'] = self.label_fit_transform(data['Card'])

        self.trans_table = data
        logger.info("Data encoding completed")

    def init_vocab(self):
        """Register every distinct encoded value of every feature column."""
        columns = self._feature_columns()

        self.vocab.set_field_keys(columns)

        # Build vocabulary
        for column in tqdm(columns, desc="Building vocabulary"):
            for val in self.trans_table[column].unique():
                self.vocab.set_id(str(val), column)

        logger.info(f"Vocabulary size: {len(self.vocab)}")

    def user_level_data(self):
        """Return ``{user_id: that user's rows of trans_table}``.

        Uses ``self.user_ids`` when given, otherwise every user in the table.
        NOTE(review): the ids are compared against the 'User' column as stored
        in trans_table, which is label-encoded unless skip_user is set -
        confirm callers pass ids in that space.
        """
        if self.user_ids:
            user_ids = [int(u) for u in self.user_ids]
        else:
            user_ids = self.trans_table['User'].unique()

        return {user_id: self.trans_table[self.trans_table['User'] == user_id]
                for user_id in user_ids}

    def prepare_samples(self):
        """Tokenize every transaction and cut per-user sliding windows."""
        logger.info("Preparing samples...")

        columns = self._feature_columns()
        user_data = self.user_level_data()

        cls_token = self.vocab.get_id(self.vocab.cls_token, special_token=True)
        sep_token = self.vocab.get_id(self.vocab.sep_token, special_token=True)

        for user_id, user_trans in tqdm(user_data.items(), desc="Processing users"):
            user_vocab_ids = []
            user_labels = []

            for _, trans in user_trans.iterrows():
                # Hierarchical rows start with [CLS] and put [SEP] after each field.
                trans_ids = [] if self.flatten else [cls_token]

                for column in columns:
                    trans_ids.append(self.vocab.get_id(str(trans[column]), column))
                    if not self.flatten:
                        trans_ids.append(sep_token)

                user_vocab_ids.append(trans_ids)
                user_labels.append(trans['Is Fraud?'])

            # Tokens and labels share the same windows, so cut both in one pass.
            for start in range(0, len(user_vocab_ids) - self.seq_len + 1, self.trans_stride):
                stop = start + self.seq_len
                self.data.append([tok for row in user_vocab_ids[start:stop] for tok in row])

                window_labels = user_labels[start:stop]
                self.labels.append(window_labels)
                self.window_label.append(1 if sum(window_labels) > 0 else 0)

        assert len(self.data) == len(self.labels)

        # ncols must equal the width of one tokenized transaction so that it
        # agrees with the (seq_len, -1) reshape in __getitem__. The previous
        # formula, len(field_keys) - 2 + 1, assumed field_keys also contained
        # entries for the special-token and label fields, which this
        # vocabulary never adds, and so did not match the emitted rows.
        self.ncols = self._tokens_per_transaction(len(columns))

        logger.info(f"Prepared {len(self.data)} samples")
        logger.info(f"Number of columns: {self.ncols}")

print("✓ TransactionDataset class defined")

## 6. Load and Prepare Data

In [None]:
# Build the dataset (reads the CSV, encodes it and cuts the windows).
display(HTML("<h3>Loading Transaction Data</h3>"))

dataset_kwargs = dict(
    root=CONFIG['data_root'],
    fname=CONFIG['data_fname'],
    vocab_dir=CONFIG['vocab_dir'],
    nrows=CONFIG['nrows'],
    user_ids=CONFIG['user_ids'],
    mlm=True,
    cached=False,
    stride=CONFIG['stride'],
    flatten=False,  # Use hierarchical model
    return_labels=False,
    skip_user=False,
    seq_len=CONFIG['seq_len'],
    num_bins=CONFIG['num_bins'],
)
dataset = TransactionDataset(**dataset_kwargs)

vocab = dataset.vocab
custom_special_tokens = vocab.get_special_tokens()

# Report what was loaded.
rule = '=' * 60
print(f"\n{rule}")
print("Dataset Statistics:")
print(f"  Total samples: {len(dataset):,}")
print(f"  Vocabulary size: {len(vocab):,}")
print(f"  Number of columns: {dataset.ncols}")
print(f"  Sequence length: {CONFIG['seq_len']}")
print(f"{rule}\n")

## 7. Split Dataset

In [None]:
def random_split_dataset(dataset, lengths, random_seed=20200706):
    """Split ``dataset`` into train/eval/test with a fixed, private seed.

    The global Python, NumPy and PyTorch RNG states are saved, the RNGs are
    seeded with ``random_seed`` for the split only, and the saved states are
    put back afterwards - also when the split raises - so the caller's random
    stream is never disturbed.

    Args:
        dataset: any object accepted by ``torch.utils.data.random_split``.
        lengths: three sizes ``[train, eval, test]``.
        random_seed: seed used for the split only.

    Returns:
        ``(train_dataset, eval_dataset, test_dataset)``.
    """
    use_cuda = torch.cuda.is_available()

    # Save state. manual_seed_all below reseeds every CUDA device, so the
    # state of every device (not just the current one) has to be saved.
    state = {
        'python_state': random.getstate(),
        'numpy_state': np.random.get_state(),
        'torch_state': torch.get_rng_state(),
        'cuda_state': torch.cuda.get_rng_state_all() if use_cuda else None
    }

    try:
        # Set seed
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        if use_cuda:
            torch.cuda.manual_seed_all(random_seed)

        # Split
        train_dataset, eval_dataset, test_dataset = torch.utils.data.random_split(dataset, lengths)
    finally:
        # Restore state, even if the split failed (e.g. lengths do not add up).
        random.setstate(state['python_state'])
        np.random.set_state(state['numpy_state'])
        torch.set_rng_state(state['torch_state'])
        if use_cuda and state['cuda_state'] is not None:
            torch.cuda.set_rng_state_all(state['cuda_state'])

    return train_dataset, eval_dataset, test_dataset

# 60% train; the remainder is halved into validation and test
# (test receives the extra sample when the remainder is odd).
totalN = len(dataset)
trainN = int(0.6 * totalN)
valtestN = totalN - trainN
valN = valtestN // 2
testN = valtestN - valN

lengths = [trainN, valN, testN]
train_dataset, eval_dataset, test_dataset = random_split_dataset(dataset, lengths)

rule = '=' * 60
print(f"\n{rule}")
print("Dataset Split:")
for split_label, split_size in (("Train:", trainN), ("Val:  ", valN), ("Test: ", testN)):
    print(f"  {split_label} {split_size:,} ({split_size/totalN:.1%})")
print(f"{rule}\n")

## 8. TabFormer Model Architecture

Hierarchical BERT model with field embeddings for tabular data.

In [None]:
# Field Embeddings for hierarchical model
class TabFormerEmbeddings(nn.Module):
    """Token + position embeddings for flat or hierarchical tabular input.

    Reads ``vocab_size``, ``field_hidden_size``, ``pad_token_id``,
    ``max_position_embeddings``, ``layer_norm_eps`` and
    ``hidden_dropout_prob`` from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.field_hidden_size,
                                            padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.field_hidden_size)
        self.LayerNorm = nn.LayerNorm(config.field_hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Shape (1, max_position_embeddings); sliced per call in forward.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(self, input_ids):
        """Embed ``input_ids``.

        Args:
            input_ids: long tensor, either ``(batch, seq_len)`` or the
                hierarchical ``(batch, seq_len, ncols)``. Positions are always
                counted along dimension 1.

        Returns:
            ``(batch, seq_len, field_hidden_size)`` for 2-D input. For 3-D
            input the per-field vectors of each row are concatenated, giving
            ``(batch, seq_len, ncols * field_hidden_size)``. TabFormerBertLM
            sets ``hidden_size = field_hidden_size * ncols`` for the
            hierarchical model, so the widths agree when ``ncols`` equals the
            number of tokens per row.
        """
        seq_length = input_ids.size(1)
        hierarchical = input_ids.dim() == 3

        position_ids = self.position_ids[:, :seq_length]

        inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)  # (1, seq_len, F)
        if hierarchical:
            # Without this extra axis (1, seq_len, F) cannot broadcast against
            # (batch, seq_len, ncols, F) and the addition raises.
            position_embeddings = position_embeddings.unsqueeze(2)  # (1, seq_len, 1, F)

        embeddings = inputs_embeds + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        if hierarchical:
            # One vector per transaction row: concatenate its field vectors.
            embeddings = embeddings.flatten(start_dim=2)

        return embeddings


# Custom BERT Configuration
class TabFormerBertConfig(BertConfig):
    """BertConfig extended with the tabular settings ``ncols``, ``field_hidden_size`` and ``flatten``.

    All other keyword arguments are passed through to BertConfig unchanged.
    """

    def __init__(self, ncols=12, field_hidden_size=64, flatten=False, **kwargs):
        # Let BertConfig consume the standard options first, then attach ours.
        super().__init__(**kwargs)
        self.ncols = ncols
        self.field_hidden_size = field_hidden_size
        self.flatten = flatten


# TabFormer BERT for Masked LM
class TabFormerBertForMaskedLM(BertForMaskedLM):
    """BertForMaskedLM that accepts labels under the ``masked_lm_labels`` keyword.

    The vocabulary is stored on the instance; forward simply delegates to
    BertForMaskedLM, passing ``masked_lm_labels`` as ``labels``.
    NOTE(review): no field-wise loss is computed here, and the parent computes
    one prediction per sequence position - confirm that the label shape
    produced for hierarchical batches matches.
    """

    def __init__(self, config, vocab):
        super().__init__(config)
        self.vocab = vocab

    def forward(self, input_ids=None, inputs_embeds=None, masked_lm_labels=None, **kwargs):
        """Run the parent model on token ids when only ids are given, else on embeddings."""
        ids_only = inputs_embeds is None and input_ids is not None
        if ids_only:
            return super().forward(input_ids=input_ids, labels=masked_lm_labels, **kwargs)
        return super().forward(inputs_embeds=inputs_embeds, labels=masked_lm_labels, **kwargs)


# Hierarchical TabFormer Model
class TabFormerHierarchicalLM(PreTrainedModel):
    """TabFormerEmbeddings followed by a TabFormerBertForMaskedLM.

    Token ids are embedded by ``tab_embeddings`` and the result is fed to
    ``tb_model`` as ``inputs_embeds``.
    """
    base_model_prefix = "bert"

    def __init__(self, config, vocab):
        super().__init__(config)
        self.config = config
        self.tab_embeddings = TabFormerEmbeddings(config)
        self.tb_model = TabFormerBertForMaskedLM(config, vocab)

    def forward(self, input_ids, **input_args):
        """Embed ``input_ids`` and forward every other argument to the BERT model."""
        return self.tb_model(inputs_embeds=self.tab_embeddings(input_ids), **input_args)


# TabFormer BERT Language Model Wrapper
class TabFormerBertLM:
    """Bundle of config, tokenizer and model for TabFormer masked-LM training.

    Args:
        special_tokens: mapping of tokenizer special-token arguments
            (``unk_token``, ``pad_token``, ...), forwarded to BertTokenizer.
        vocab: vocabulary object; its ``filename`` must point at a saved
            text vocabulary readable by BertTokenizer.
        field_ce: select the TabFormer model variant (see ``get_model``).
        flatten: True for flat sequences, False for the hierarchical model.
        ncols: number of tokens per transaction row. Required: it sets the
            number of attention heads and, for the hierarchical model, the
            hidden size.
        field_hidden_size: embedding width of a single field.

    Raises:
        ValueError: if ``ncols`` is None.
    """

    def __init__(self, special_tokens, vocab, field_ce=False, flatten=False,
                 ncols=None, field_hidden_size=768):
        if ncols is None:
            # Previously this surfaced later as an opaque TypeError.
            raise ValueError("ncols must be provided (number of tokens per transaction row)")

        self.ncols = ncols
        self.vocab = vocab
        vocab_file = self.vocab.filename

        # Hierarchical rows concatenate one field vector per column.
        hidden_size = field_hidden_size if flatten else (field_hidden_size * self.ncols)

        # BertConfig defaults pad_token_id to 0, which is [UNK] in this
        # vocabulary; pass the real [PAD] id so the padding embedding row
        # (zeroed, never updated) is the right one.
        pad_token_id = self.vocab.get_id(self.vocab.pad_token, special_token=True)

        self.config = TabFormerBertConfig(
            vocab_size=len(self.vocab),
            ncols=self.ncols,
            hidden_size=hidden_size,
            field_hidden_size=field_hidden_size,
            flatten=flatten,
            num_attention_heads=self.ncols,
            pad_token_id=pad_token_id
        )

        self.tokenizer = BertTokenizer(
            vocab_file,
            do_lower_case=False,
            **special_tokens
        )

        self.model = self.get_model(field_ce, flatten)

    def get_model(self, field_ce, flatten):
        """Instantiate the model variant selected by ``field_ce`` and ``flatten``.

        * flatten, no field_ce -> plain BertForMaskedLM
        * flatten, field_ce    -> TabFormerBertForMaskedLM
        * not flatten          -> TabFormerHierarchicalLM
        """
        if not flatten:
            # Hierarchical field CE BERT
            return TabFormerHierarchicalLM(self.config, self.vocab)
        if field_ce:
            # Flattened field CE BERT
            return TabFormerBertForMaskedLM(self.config, self.vocab)
        # Flattened vanilla BERT
        return BertForMaskedLM(self.config)

print("✓ TabFormer model architecture defined")

## 9. Initialize Model

In [None]:
display(HTML("<h3>Initializing TabFormer BERT Model</h3>"))

model_kwargs = dict(
    vocab=vocab,
    field_ce=True,  # Use field-wise cross-entropy
    flatten=False,  # Use hierarchical model
    ncols=dataset.ncols,
    field_hidden_size=CONFIG['field_hidden_size'],
)
tab_net = TabFormerBertLM(custom_special_tokens, **model_kwargs)

# Count parameters in a single pass over the model.
total_params = 0
trainable_params = 0
for param in tab_net.model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

rule = '=' * 60
print(f"\n{rule}")
print("Model Configuration:")
print(f"  Type: {tab_net.model.__class__.__name__}")
print(f"  Vocabulary size: {len(vocab):,}")
print(f"  Hidden size: {tab_net.config.hidden_size}")
print(f"  Field hidden size: {tab_net.config.field_hidden_size}")
print(f"  Number of attention heads: {tab_net.config.num_attention_heads}")
print(f"  Number of layers: {tab_net.config.num_hidden_layers}")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"{rule}\n")

## 10. Data Collator for MLM

Custom data collator for hierarchical masked language modeling.

In [None]:
class TransDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
    """MLM data collator that keeps the shape of hierarchical examples.

    Examples are stacked into one batch tensor, flattened to
    ``(batch, -1)`` for masking, and reshaped back to the original batch
    shape before being returned.
    """

    def __call__(self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``examples`` into model inputs.

        Returns ``{"input_ids", "masked_lm_labels"}`` when ``self.mlm`` is set,
        otherwise ``{"input_ids", "labels"}`` with padding positions set to -100.
        """
        batch = self._stack_examples(examples)
        sz = batch.shape

        if self.mlm:
            # Flatten for masking
            batch = batch.view(sz[0], -1)
            inputs, labels = self.mask_tokens(batch)
            # Reshape back to hierarchical format
            return {"input_ids": inputs.view(sz), "masked_lm_labels": labels.view(sz)}

        labels = batch.clone().detach()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        return {"input_ids": batch, "labels": labels}

    @staticmethod
    def _stack_examples(examples) -> torch.Tensor:
        """Stack equally shaped examples into one long tensor.

        Replaces the ``_tensorize_batch`` helper of older transformers
        releases, which current DataCollatorForLanguageModeling no longer
        provides. Accepts tensors, lists of ids, or dicts holding
        ``"input_ids"``. All examples must share one shape.
        """
        rows = []
        for example in examples:
            if isinstance(example, dict):
                example = example["input_ids"]
            if not isinstance(example, torch.Tensor):
                example = torch.tensor(example, dtype=torch.long)
            rows.append(example)
        return torch.stack(rows)

    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prepare masked tokens: 80% MASK, 10% random, 10% original.

        ``inputs`` is modified in place and returned together with the labels
        (-100 everywhere except at the positions selected for prediction).
        """
        if self.tokenizer.mask_token is None:
            raise ValueError("This tokenizer does not have a mask token")

        labels = inputs.clone()

        # Sample tokens for MLM (15% by default); special and padding tokens are never selected.
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
            for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)

        # Public accessor instead of the private tokenizer._pad_token attribute.
        if self.tokenizer.pad_token_id is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)

        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # Only compute loss on masked tokens

        # 80% of the time, replace with [MASK]
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, replace with random word (half of the remaining 20%)
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # 10% of the time, keep the token unchanged
        return inputs, labels

# Create data collator
data_collator = TransDataCollatorForLanguageModeling(
    tokenizer=tab_net.tokenizer,
    mlm=True,
    mlm_probability=CONFIG['mlm_prob']
)

print("✓ Data collator initialized")

## 11. Training Configuration

In [None]:
training_args = TrainingArguments(
    output_dir=CONFIG['output_dir'],
    num_train_epochs=CONFIG['num_train_epochs'],
    per_device_train_batch_size=CONFIG['batch_size'],
    per_device_eval_batch_size=CONFIG['batch_size'],
    logging_dir=os.path.join(CONFIG['output_dir'], 'logs'),
    logging_steps=CONFIG['logging_steps'],
    save_steps=CONFIG['save_steps'],
    save_total_limit=2,
    prediction_loss_only=True,
    overwrite_output_dir=True,
    do_train=True,
    do_eval=False,
    # Use the notebook-wide seed; otherwise the Trainer keeps its own default
    # seed and changing CONFIG['seed'] would not affect training.
    seed=CONFIG['seed'],
    # GPU optimization
    fp16=torch.cuda.is_available(),  # Use mixed precision on GPU
    dataloader_pin_memory=True,
    dataloader_num_workers=2,
)

print(f"\n{'='*60}")
print(f"Training Configuration:")
print(f"  Epochs: {CONFIG['num_train_epochs']}")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Mixed precision (FP16): {training_args.fp16}")
print(f"  Device: {device}")
print(f"{'='*60}\n")

## 12. Initialize Trainer

In [None]:
# Wire model, arguments, collator and data splits into the Hugging Face Trainer.
trainer_kwargs = dict(
    model=tab_net.model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer = Trainer(**trainer_kwargs)

print("✓ Trainer initialized")
ready_banner = HTML("<h3>Ready to Train!</h3>")
display(ready_banner)

## 13. Train the Model

Start training the TabFormer model on credit card transactions.

In [None]:
start_banner = HTML("<h2>🚀 Training Started</h2>")
done_banner = HTML("<h2>✅ Training Complete!</h2>")

display(start_banner)

# Run the full training loop configured above.
trainer.train()

display(done_banner)

## 14. Save Model

In [None]:
# Persist the trained model under <output_dir>/final_model.
final_model_path = os.path.join(CONFIG['output_dir'], 'final_model')
trainer.save_model(final_model_path)

rule = '=' * 60
print(f"\n{rule}")
print(f"Model saved to: {final_model_path}")
print(f"{rule}\n")

saved_banner = HTML(f"<h3>✓ Model saved successfully</h3><p>Location: {final_model_path}</p>")
display(saved_banner)

## 15. Training Summary

In [None]:
# Collect the summary lines first, then print them in order.
rule = '=' * 60
summary_lines = [
    f"\n{rule}",
    "TRAINING SUMMARY",
    rule,
    "Dataset: Credit Card Transactions",
    "Model: Hierarchical TabFormer BERT",
    f"Training samples: {trainN:,}",
    f"Validation samples: {valN:,}",
    f"Epochs: {CONFIG['num_train_epochs']}",
    f"Batch size: {CONFIG['batch_size']}",
    f"Field hidden size: {CONFIG['field_hidden_size']}",
    f"Total parameters: {total_params:,}",
    f"Device: {device}",
    f"{rule}\n",
]
for summary_line in summary_lines:
    print(summary_line)

display(Markdown("""
## Next Steps

1. **Evaluate the model**: Use the trained model for inference on test data
2. **Fine-tune**: Adjust hyperparameters for better performance
3. **Deploy**: Export the model for production use
4. **Analyze**: Examine attention patterns and learned representations
"""))