# PyTorch FT-Transformer Fraud Classification

This notebook runs a hyperparameter grid search for an FT-Transformer fraud-detection classifier.

In [None]:
# Paths to the preprocessed inputs (all live in ../data relative to this notebook)
_data_dir = '../data'
train_dataset_path = f'{_data_dir}/train.csv'
test_dataset_path = f'{_data_dir}/test.csv'
metadata_path = f'{_data_dir}/preprocessing_metadata.json'

# Target column name and the seed shared by every RNG in the notebook
class_label, random_seed = 'Class', 42

In [None]:
# Import required libraries
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report, roc_auc_score,
    precision_recall_curve, roc_curve, average_precision_score
)
import pickle
import warnings
import multiprocessing as mp
warnings.filterwarnings('ignore')

# Set multiprocessing start method for Jupyter compatibility
try:
    mp.set_start_method('spawn', force=True)
    print("Multiprocessing start method set to 'spawn' for Jupyter compatibility")
except RuntimeError:
    print("Multiprocessing start method already set")

# Disable multiprocessing in DataLoaders for Jupyter safety
import os
os.environ['PYTORCH_DATALOADER_NUM_WORKERS'] = '0'

# Set random seeds for reproducibility
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    # Ensure deterministic behavior on CUDA
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Device configuration - prioritize CUDA if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("All libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
print(f"Using device: {device}")

# Set plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

!pip install rtdl_revisiting_models

## 1. Data Loading and Exploration

Loading the balanced training data and original test data for fraud detection.


In [None]:
# Load the datasets
print("Loading datasets...")

# Load balanced training data (SMOTE applied)
train_df = pd.read_csv(train_dataset_path)
print(f"Balanced training data shape: {train_df.shape}")

# Load original test data (imbalanced)
test_df = pd.read_csv(test_dataset_path)
print(f"Test data shape: {test_df.shape}")

# Display basic information
print("\n=== Training Data Info ===")
# DataFrame.info() writes its own report and returns None; wrapping it in
# print() would emit a stray "None" line.
train_df.info()
print("\nClass distribution in training data:")
print(train_df[class_label].value_counts())
print(f"Training fraud percentage: {train_df[class_label].mean()*100:.2f}%")

print("\n=== Test Data Info ===")
print("\nClass distribution in test data:")
print(test_df[class_label].value_counts())
print(f"Test fraud percentage: {test_df[class_label].mean()*100:.2f}%")


## 2. Data Preprocessing for FT-Transformer

Preparing the data for FT-Transformer training: separating the features from the target and converting them to float32 NumPy arrays.


In [None]:
# Data preprocessing for FT-Transformer
def preprocess_data(train_df, test_df, metadata_file=None, label_col=None):
    """
    Split train/test frames into float32 feature matrices and integer label vectors.

    The preprocessing metadata JSON is loaded and its ``feature_columns`` entry
    (if present) is compared with the current feature order. A mismatch is
    reported but not corrected. All feature columns are treated as numerical;
    no categorical columns are produced.

    Args:
        train_df: Training DataFrame with the feature columns and the label column.
        test_df: Test DataFrame with the same feature columns and label column.
        metadata_file: Path to the metadata JSON. Defaults to the notebook-level
            ``metadata_path``.
        label_col: Name of the target column. Defaults to the notebook-level
            ``class_label``.

    Returns:
        (X_train, y_train, X_test, y_test, feature_cols, numerical_cols), where:
        - X_* are float32 numpy arrays;
        - y_* are int numpy arrays;
        - the last two are lists of column names.
    """
    import json

    # Fall back to the notebook-level configuration when not given explicitly.
    if metadata_file is None:
        metadata_file = metadata_path
    if label_col is None:
        label_col = class_label

    with open(metadata_file, 'r') as f:
        metadata = json.load(f)
    print(f"Loaded metadata: {metadata}")

    # Separate features and target
    feature_cols = [col for col in train_df.columns if col != label_col]

    # No categorical handling: every feature column is treated as numerical.
    categorical_cols = []
    numerical_cols = [col for col in feature_cols if col not in categorical_cols]

    y_train = train_df[label_col].values.astype(int)
    y_test = test_df[label_col].values.astype(int)

    print(f"Feature columns: {len(feature_cols)}")
    print(f"Training samples: {len(train_df)}")
    print(f"Test samples: {len(test_df)}")

    # Verify the column order against the order recorded during preprocessing.
    metadata_features = list(metadata.get('feature_columns', []))
    print(f"Metadata feature order: {metadata_features}")
    print(f"Current feature order: {feature_cols}")
    if metadata_features and metadata_features != feature_cols:
        print("WARNING: current feature order differs from preprocessing metadata")

    # Select features in the same order for both splits; float32 is what the torch model consumes.
    X_train = train_df[feature_cols].values.astype(np.float32)
    X_test = test_df[feature_cols].values.astype(np.float32)

    print(f"\nConverted features to float32 numpy arrays")
    print(f"X_train type: {type(X_train)}, shape: {X_train.shape}")
    print(f"X_test type: {type(X_test)}, shape: {X_test.shape}")

    return X_train, y_train, X_test, y_test, feature_cols, numerical_cols

# Run preprocessing, then report the resulting array shapes and label balance
X_train, y_train, X_test, y_test, feature_names, numerical_cols = preprocess_data(train_df, test_df)

print(f"\nFinal shapes:")
for _name, _arr in (("X_train", X_train), ("y_train", y_train), ("X_test", X_test), ("y_test", y_test)):
    print(f"{_name}: {_arr.shape}")
print(f"Class distribution in y_train: {np.bincount(y_train)}")
print(f"Class distribution in y_test: {np.bincount(y_test)}")


In [None]:
# Hold out 10% of the training data for validation, stratified on the label
X_train_df = pd.DataFrame(X_train, columns=feature_names)
y_train_df = pd.DataFrame(y_train, columns=[class_label])

_split_parts = train_test_split(
    X_train_df,
    y_train_df,
    test_size=0.1,
    random_state=random_seed,
    stratify=y_train_df,
)
X_train_split, X_val, y_train_split, y_val = _split_parts

for _label, _frame in (("Training split", X_train_split), ("Validation split", X_val)):
    print(f"{_label}: {_frame.shape}")

## 3. FT-Transformer Model Definition and Training

Defining the FT-Transformer classifier wrapper (model, optimizer, training loop and inference) used for fraud detection.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from rtdl_revisiting_models import FTTransformer
import copy
import random

class FraudFTTransformer:
    """Training/inference wrapper around ``rtdl_revisiting_models.FTTransformer``.

    Holds:
    - the model, with 2 output logits;
    - a cross-entropy loss;
    - an AdamW optimizer;
    - a ReduceLROnPlateau scheduler driven by validation loss;
    - a seeded CPU generator used to shuffle training batches.

    Only continuous features are passed to the model; its categorical input
    is always ``None``.
    """

    def __init__(self, n_num_features, n_cat_features, n_cont_features, cat_cardinalities,
                 device=None, seed: int = 42, **model_params):
        """Create the model and its optimizer/scheduler.

        Args:
            n_num_features: Kept for call-site compatibility; not used.
            n_cat_features: Kept for call-site compatibility; not used.
            n_cont_features: Number of continuous input columns.
            cat_cardinalities: Categorical cardinalities passed to FTTransformer
                (``[]`` when there are none).
            device: Target torch device. Falls back to the notebook-global
                ``device``, then to CUDA if available, else CPU.
            seed: Seed applied to ``random``, NumPy, torch and the shuffle generator.
            **model_params: Overrides for the FTTransformer architecture defaults below.
        """
        # Use global device if not specified
        if device is None:
            device = globals().get('device', torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

        self.device = device
        self.seed = seed

        # Set seeds for reproducibility across libs
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        print(f"Initializing FraudFTTransformer on device: {self.device}")

        default_params = {
            'n_blocks': 3,
            'd_block': 264,
            'attention_n_heads': 12,
            'attention_dropout': 0.1,
            'ffn_d_hidden': None,
            'ffn_d_hidden_multiplier': 5 / 4,
            'ffn_dropout': 0.1,
            'residual_dropout': 0.0,
        }

        # Update with provided parameters
        default_params.update(model_params)
        self.model_params = default_params

        print(f"Model parameters: {self.model_params}")

        self.model = FTTransformer(
            n_cont_features=n_cont_features,
            cat_cardinalities=cat_cardinalities,
            d_out=2,
            **default_params
        ).to(self.device)

        # Plain (unweighted) cross-entropy over the two output logits.
        self.criterion = nn.CrossEntropyLoss().to(self.device)

        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=0.0003,
            weight_decay=1e-3
        )

        # CPU generator seeded per instance so training-batch shuffling is reproducible.
        self.generator = torch.Generator().manual_seed(self.seed)

        # Multiply the LR by 0.7 after 5 epochs without validation-loss improvement.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=5, factor=0.7, min_lr=1e-6
        )

        # Print model info
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Model parameters: {total_params:,} total, {trainable_params:,} trainable")

        if torch.cuda.is_available() and self.device.type == 'cuda':
            print(f"GPU memory allocated: {torch.cuda.memory_allocated(self.device) / 1024**2:.1f} MB")

    def fit(self, X_num, y, X_val_num=None, y_val=None,
            epochs=100, batch_size=512, patience=20):
        """Train the model, with optional validation-based early stopping.

        Args:
            X_num: 2-D array of continuous features.
            y: Integer class labels, shape (n,) or (n, 1).
            X_val_num, y_val: Optional validation data. When given:
                - the LR scheduler and early stopping track validation loss;
                - the weights with the lowest validation loss are restored at the end.
            epochs: Maximum number of epochs.
            batch_size: Mini-batch size for training and validation.
            patience: Epochs without validation-loss improvement before stopping.

        Returns:
            dict of per-epoch lists with keys 'train_loss', 'train_accuracy',
            'val_loss', 'val_auc' and 'val_accuracy'. The 'val_*' lists stay
            empty without validation data.

        Raises:
            ValueError: if the training data is empty.
        """
        pin_memory = torch.cuda.is_available() and self.device.type == 'cuda'
        # Use num_workers=0 in Jupyter to avoid multiprocessing issues
        num_workers = 0

        print(f"DataLoader config: pin_memory={pin_memory}, num_workers={num_workers}")

        # Flatten labels so every batch target is 1-D (batch,), even when the last
        # batch holds a single sample; squeeze() would turn that into a 0-d tensor.
        train_dataset = TensorDataset(
            torch.FloatTensor(X_num),
            torch.as_tensor(np.asarray(y).reshape(-1), dtype=torch.long)
        )
        if len(train_dataset) == 0:
            raise ValueError("Training data is empty")

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            generator=self.generator,
            pin_memory=pin_memory,
            num_workers=num_workers
        )

        val_loader = None
        if X_val_num is not None:
            val_dataset = TensorDataset(
                torch.FloatTensor(X_val_num),
                torch.as_tensor(np.asarray(y_val).reshape(-1), dtype=torch.long)
            )
            val_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                pin_memory=pin_memory,
                num_workers=num_workers
            )

        best_val_loss = float('inf')
        best_state_dict = None
        patience_counter = 0
        history = {'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_auc': [], 'val_accuracy': []}

        print(f"Starting training for {epochs} epochs on {self.device}")

        for epoch in range(epochs):
            # Training phase
            self.model.train()
            running_loss = 0.0
            train_correct = 0
            train_total = 0

            for batch_idx, (batch_num, batch_y) in enumerate(train_loader):
                batch_num = batch_num.to(self.device, non_blocking=True)
                batch_y = batch_y.to(self.device, non_blocking=True).view(-1)

                self.optimizer.zero_grad()
                outputs = self.model(batch_num, None)
                loss = self.criterion(outputs, batch_y)
                loss.backward()

                # Gradient clipping for stability
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                self.optimizer.step()
                running_loss += loss.item()

                predicted = outputs.detach().argmax(dim=1)
                train_total += batch_y.size(0)
                train_correct += (predicted == batch_y).sum().item()

                # Periodically release cached CUDA blocks (only meaningful on a CUDA device)
                if self.device.type == 'cuda' and batch_idx % 100 == 0:
                    torch.cuda.empty_cache()

            train_loss = running_loss / len(train_loader)
            train_accuracy = train_correct / train_total
            # Record training metrics before any early-stopping break so every
            # completed epoch is present in the history.
            history['train_loss'].append(train_loss)
            history['train_accuracy'].append(train_accuracy)

            if val_loader is None:
                print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}")
                continue

            # Validation phase
            val_loss, val_auc, val_accuracy = self._validate(val_loader)
            history['val_loss'].append(val_loss)
            history['val_auc'].append(val_auc)
            history['val_accuracy'].append(val_accuracy)

            self.scheduler.step(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Deep copy: state_dict() returns live references to the parameters.
                best_state_dict = copy.deepcopy(self.model.state_dict())
            else:
                patience_counter += 1

            print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}, Val Acc: {val_accuracy:.4f}")

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

        # Restore best model weights if available
        if best_state_dict is not None:
            self.model.load_state_dict(best_state_dict)
            self.best_val_loss = best_val_loss
            self.best_state_dict = best_state_dict
        else:
            print("No validation data provided - using final model weights")
            self.best_val_loss = None
            self.best_state_dict = None

        return history

    def _validate(self, val_loader):
        """Evaluate on ``val_loader``.

        Returns:
            (mean batch loss, ROC-AUC of the class-1 probability, accuracy).
            ROC-AUC is NaN when the labels contain a single class, where it is undefined.
        """
        self.model.eval()
        total_loss = 0.0
        prob_chunks = []
        label_chunks = []
        correct = 0
        total = 0

        with torch.no_grad():
            for batch_num, batch_y in val_loader:
                batch_num = batch_num.to(self.device, non_blocking=True)
                batch_y = batch_y.to(self.device, non_blocking=True).view(-1)

                outputs = self.model(batch_num, None)
                total_loss += self.criterion(outputs, batch_y).item()

                prob_chunks.append(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())
                label_chunks.append(batch_y.cpu().numpy())

                correct += (outputs.argmax(dim=1) == batch_y).sum().item()
                total += batch_y.size(0)

        labels = np.concatenate(label_chunks)
        probs = np.concatenate(prob_chunks)
        val_auc = roc_auc_score(labels, probs) if np.unique(labels).size > 1 else float('nan')
        return total_loss / len(val_loader), val_auc, correct / total

    def predict_proba(self, X_num, batch_size=512):
        """Return class probabilities as a float array of shape (n_samples, 2)."""
        self.model.eval()

        # CUDA optimizations for inference
        pin_memory = torch.cuda.is_available() and self.device.type == 'cuda'

        loader = DataLoader(
            TensorDataset(torch.FloatTensor(X_num)),
            batch_size=batch_size,
            pin_memory=pin_memory,
            num_workers=0
        )

        chunks = []
        with torch.no_grad():
            for (batch_num,) in loader:
                batch_num = batch_num.to(self.device, non_blocking=True)
                outputs = self.model(batch_num, None)
                chunks.append(torch.softmax(outputs, dim=1).cpu().numpy())

        if not chunks:
            # Empty input: keep the (n, 2) shape contract.
            return np.empty((0, 2), dtype=np.float32)
        return np.concatenate(chunks, axis=0)

    def get_memory_usage(self):
        """Get current memory usage information"""
        if torch.cuda.is_available() and self.device.type == 'cuda':
            allocated = torch.cuda.memory_allocated(self.device) / 1024**2
            cached = torch.cuda.memory_reserved(self.device) / 1024**2
            return f"GPU Memory - Allocated: {allocated:.1f} MB, Cached: {cached:.1f} MB"
        else:
            return "CPU mode - no GPU memory tracking"

    def cleanup_memory(self):
        """Run garbage collection and, if CUDA is available, empty the CUDA cache."""
        import gc

        # Force garbage collection to clean up any lingering DataLoader references
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print("CUDA cache cleared")

        print("Memory cleanup completed")

    def save_model(self, filepath):
        """Save the best state dict (or current weights if none), params, best loss, device and seed."""
        save_dict = {
            'model_state_dict': self.best_state_dict if hasattr(self, 'best_state_dict') and self.best_state_dict is not None else self.model.state_dict(),
            'model_params': self.model_params,
            'best_val_loss': getattr(self, 'best_val_loss', None),
            'device': str(self.device),
            'seed': self.seed
        }
        torch.save(save_dict, filepath)
        print(f"Model saved to {filepath}")

    def load_model(self, filepath):
        """Load weights from a checkpoint written by ``save_model`` and mark them as best."""
        checkpoint = torch.load(filepath, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.best_val_loss = checkpoint.get('best_val_loss', None)
        self.best_state_dict = checkpoint['model_state_dict']
        print(f"Model loaded from {filepath}")
        if self.best_val_loss is not None:
            print(f"Best validation loss: {self.best_val_loss:.4f}")

    def ensure_best_model(self):
        """Load ``best_state_dict`` into the model when one is available."""
        if hasattr(self, 'best_state_dict') and self.best_state_dict is not None:
            self.model.load_state_dict(self.best_state_dict)
            print("Using best model weights for evaluation")
        else:
            print("No best model weights available, using current weights")

# Numerical-only feature matrices and flat (n,) integer label vectors for training.
# NOTE: this rebinds the notebook-level `y_train` (full training labels from the
# preprocessing cell) to the *split* labels. Re-running the validation-split cell
# after this one fails on a length mismatch, so re-run from preprocessing instead.
X_train_num = X_train_split[numerical_cols].reset_index(drop=True).to_numpy()
y_train = y_train_split.to_numpy().ravel()

X_val_num = X_val[numerical_cols].reset_index(drop=True).to_numpy()
y_val_val = y_val.to_numpy().ravel()  # validation labels as a 1-D vector

# No categorical features: every column is treated as continuous.
categorical_cols = []
cat_dims = []

## 4. Hyperparameter Grid Search

Performing grid search to find the optimal hyperparameters for the FT-Transformer model.

### Grid Search Implementation

We'll search over key hyperparameters that significantly impact model performance:
- Model architecture parameters (d_block, n_blocks, attention_n_heads)
- Regularization parameters (attention_dropout, ffn_dropout)
- Training parameters (learning rate, batch size)

In [None]:
import itertools
from sklearn.metrics import roc_auc_score, average_precision_score
import os
import json
from datetime import datetime

class GridSearchFTTransformer:
    """
    Grid/random search over FT-Transformer hyperparameters.

    Each trial trains a FraudFTTransformer on (X_train, y_train), with early
    stopping on (X_val, y_val). Configurations are *selected* by validation
    PR-AUC. Test ROC-AUC / PR-AUC are computed for every trial but only
    reported, so the held-out test set does not influence which model is chosen.

    Attributes:
        results: One dict per trial (see ``train_single_model``).
        best_params: Parameters of the selected trial.
        best_val_score: Validation PR-AUC of the selected trial (selection metric).
        best_score: Test PR-AUC of the selected trial (reporting only).
        best_model: The selected FraudFTTransformer instance.
    """

    def __init__(self, X_train, y_train, X_val, y_val, X_test, y_test, feature_names, numerical_cols,
                 categorical_cols=None, cat_dims=None, device=None, seed=42):
        self.X_train = X_train
        self.y_train = y_train
        self.X_val = X_val
        self.y_val = y_val
        self.X_test = X_test
        self.y_test = y_test
        self.feature_names = feature_names
        self.numerical_cols = numerical_cols
        self.categorical_cols = categorical_cols or []
        self.cat_dims = cat_dims or []
        self.device = device or globals().get('device', torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        self.seed = seed

        self.results = []
        self.best_params = None
        self.best_val_score = -np.inf  # validation PR-AUC of the selected trial
        self.best_score = -np.inf      # test PR-AUC of the selected trial (reported only)
        self.best_model = None

    def define_search_space(self, search_type='quick'):
        """
        Return the hyperparameter grid as a dict of name -> candidate values.

        Args:
            search_type: 'quick' for a small grid. Any other value (normally
                'comprehensive') selects the large grid.
        """
        if search_type == 'quick':
            return {
                # Model architecture
                'd_block': [128, 192, 256],
                'n_blocks': [2, 3],
                'attention_n_heads': [4, 8],

                # Regularization
                'attention_dropout': [0.1, 0.2],
                'ffn_dropout': [0.1, 0.2],

                # Training parameters
                'learning_rate': [0.0003, 0.0005, 0.001],
                'batch_size_multiplier': [1, 2]
            }

        if search_type != 'comprehensive':
            print(f"Unknown search_type {search_type!r}; using the comprehensive grid")
        return {
            # Model architecture
            'd_block': [96, 128, 192, 256, 320],
            'n_blocks': [1, 2, 3, 4],
            'attention_n_heads': [2, 4, 6, 8, 10, 12],

            # Regularization
            'attention_dropout': [0.0, 0.1, 0.2, 0.3],
            'ffn_dropout': [0.0, 0.1, 0.2, 0.3],
            'residual_dropout': [0.0, 0.1],

            # Training parameters
            'learning_rate': [0.0001, 0.0003, 0.0005, 0.001, 0.002],
            'batch_size_multiplier': [0.5, 1, 2, 4]
        }

    @staticmethod
    def _is_valid_combination(params):
        """True unless d_block is not divisible by attention_n_heads.

        Multi-head attention splits d_block evenly across heads, so such
        combinations would fail at model construction.
        """
        d_block = params.get('d_block')
        n_heads = params.get('attention_n_heads')
        if d_block is None or n_heads is None:
            return True
        return n_heads > 0 and d_block % n_heads == 0

    def create_param_combinations(self, param_grid, max_combinations=None):
        """Expand ``param_grid`` into a list of parameter dicts.

        Invalid d_block/attention_n_heads pairs are dropped. If more than
        ``max_combinations`` remain, a seeded random subset is sampled without
        replacement, using a local RNG so the global NumPy state is untouched.
        """
        keys = list(param_grid.keys())
        all_combos = [dict(zip(keys, values)) for values in itertools.product(*param_grid.values())]

        combos = [c for c in all_combos if self._is_valid_combination(c)]
        skipped = len(all_combos) - len(combos)
        if skipped:
            print(f"Skipped {skipped} combinations where d_block is not divisible by attention_n_heads")

        if max_combinations and len(combos) > max_combinations:
            rng = np.random.RandomState(self.seed)
            indices = rng.choice(len(combos), max_combinations, replace=False)
            print(f"Randomly sampled {max_combinations} combinations from {len(combos)}")
            combos = [combos[i] for i in indices]

        return combos

    def train_single_model(self, params, trial_num, total_trials):
        """Train and evaluate one configuration.

        Returns a result dict with keys:
        - 'trial', 'params' (incl. learning_rate and batch_size);
        - 'val_pr_auc' (selection metric), 'test_roc_auc', 'test_pr_auc';
        - 'val_loss', 'training_epochs', 'timestamp'.
        Failed trials get -1 metrics, an 'error' key and ``val_loss=inf``.
        """
        import gc

        print(f"\nTrial {trial_num}/{total_trials}")
        print(f"Parameters: {params}")

        # Work on a copy so the caller's dict and the failure record keep every key.
        model_params = dict(params)
        learning_rate = model_params.pop('learning_rate', 0.0005)
        batch_size_multiplier = model_params.pop('batch_size_multiplier', 1)
        model = None

        try:
            base_batch_size = 4096 if torch.cuda.is_available() and self.device.type == 'cuda' else 1024
            batch_size = int(base_batch_size * batch_size_multiplier)

            print(f"🔧 Using batch_size: {batch_size} (base: {base_batch_size}, multiplier: {batch_size_multiplier})")

            model = FraudFTTransformer(
                n_num_features=len(self.feature_names),
                n_cat_features=len(self.categorical_cols),
                n_cont_features=len(self.numerical_cols),
                cat_cardinalities=self.cat_dims,
                device=self.device,
                seed=self.seed,
                **model_params  # model architecture parameters
            )

            for param_group in model.optimizer.param_groups:
                param_group['lr'] = learning_rate

            history = model.fit(
                self.X_train, self.y_train,
                self.X_val, self.y_val,
                epochs=17,  # Reduced epochs for grid search
                batch_size=batch_size,
                patience=5
            )
            model.ensure_best_model()

            # Validation PR-AUC drives model selection.
            val_probs = model.predict_proba(self.X_val)[:, 1]
            val_pr_auc = average_precision_score(np.asarray(self.y_val).reshape(-1), val_probs)

            # Test metrics are recorded for reporting only.
            y_test_flat = np.asarray(self.y_test).reshape(-1)
            test_probs = model.predict_proba(self.X_test)[:, 1]
            test_roc_auc = roc_auc_score(y_test_flat, test_probs)
            test_pr_auc = average_precision_score(y_test_flat, test_probs)

            best_val_loss = getattr(model, 'best_val_loss', None)
            if best_val_loss is None and history['val_loss']:
                best_val_loss = min(history['val_loss'])

            result = {
                'trial': trial_num,
                'params': {**model_params, 'learning_rate': learning_rate, 'batch_size': batch_size},
                'val_pr_auc': float(val_pr_auc),  # selection metric
                'test_roc_auc': float(test_roc_auc),
                'test_pr_auc': float(test_pr_auc),
                'val_loss': float(best_val_loss) if best_val_loss is not None else float('nan'),
                'training_epochs': len(history['val_auc']),
                'timestamp': datetime.now().isoformat()
            }

            print(f"Trial {trial_num} completed - Val PR-AUC: {val_pr_auc:.4f}, "
                  f"Test PR-AUC: {test_pr_auc:.4f}, Test ROC-AUC: {test_roc_auc:.4f}, "
                  f"Val Loss: {result['val_loss']:.4f}")

            if val_pr_auc > self.best_val_score:
                self.best_val_score = val_pr_auc
                self.best_score = test_pr_auc
                self.best_params = result['params'].copy()
                self.best_model = model
                print(f"New best model! Val PR-AUC: {val_pr_auc:.4f} "
                      f"(Test PR-AUC: {test_pr_auc:.4f}, Test ROC-AUC: {test_roc_auc:.4f})")

            return result

        except Exception as e:
            print(f"Trial {trial_num} failed: {str(e)}")
            return {
                'trial': trial_num,
                'params': dict(params),
                'val_pr_auc': -1,
                'test_roc_auc': -1,
                'test_pr_auc': -1,
                'val_loss': float('inf'),
                'error': str(e),
                'timestamp': datetime.now().isoformat()
            }

        finally:
            # Drop this trial's local reference for every trial (not only new bests).
            # The selected model stays alive through self.best_model.
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def run_grid_search(self, search_type='quick', max_combinations=20, save_results=True):
        """
        Run the complete grid search and return the list of per-trial results.

        Args:
            search_type: 'quick' or 'comprehensive'
            max_combinations: Maximum number of combinations to try
            save_results: Whether to save results to file
        """
        print(f"Starting {search_type} grid search for FT-Transformer")
        print(f"Maximum combinations: {max_combinations}")
        print(f"Device: {self.device}")

        param_grid = self.define_search_space(search_type)
        print(f"Search space: {param_grid}")

        param_combinations = self.create_param_combinations(param_grid, max_combinations)
        total_trials = len(param_combinations)
        print(f"Total trials to run: {total_trials}")

        start_time = datetime.now()
        for i, params in enumerate(param_combinations, 1):
            result = self.train_single_model(params.copy(), i, total_trials)
            self.results.append(result)

            # Save intermediate results
            if save_results and i % 5 == 0:
                self.save_results(f'../results/grid_search_intermediate_{i}.json')
        duration = datetime.now() - start_time

        print(f"\nGrid search completed!")
        print(f"Total time: {duration}")
        print(f"Best validation PR-AUC (selection metric): {self.best_val_score:.4f}")
        print(f"Test PR-AUC of selected model: {self.best_score:.4f}")
        print(f"Best parameters: {self.best_params}")

        if save_results:
            self.save_results('../results/grid_search_final_results.json')

        return self.results

    def save_results(self, filepath):
        """Write a JSON summary (best scores/params and all trial results) to ``filepath``."""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        results_summary = {
            'search_completed': datetime.now().isoformat(),
            'total_trials': len(self.results),
            'best_score': self.best_score,
            'best_val_score': self.best_val_score,
            'best_params': self.best_params,
            'device_used': str(self.device),
            'all_results': self.results
        }

        with open(filepath, 'w') as f:
            json.dump(results_summary, f, indent=2, default=str)

        print(f"Results saved to {filepath}")

    def get_top_results(self, n=5, metric='test_pr_auc'):
        """Return the top ``n`` successful trials sorted by ``metric``, descending.

        Failed trials carry -1 metrics and are excluded. ``metric`` defaults to
        'test_pr_auc' (reporting); pass 'val_pr_auc' to rank by the selection metric.
        """
        if not self.results:
            return []

        valid_results = [r for r in self.results if r.get(metric, -1) > 0]
        return sorted(valid_results, key=lambda r: r[metric], reverse=True)[:n]

    def plot_results(self):
        """Plot test-set metrics of all successful trials as a 2x3 panel figure."""
        if not self.results:
            print("No results to plot")
            return

        valid_results = [r for r in self.results if r['test_pr_auc'] > 0]
        if not valid_results:
            print("No valid results to plot")
            return

        pr_aucs = [r['test_pr_auc'] for r in valid_results]
        roc_aucs = [r['test_roc_auc'] for r in valid_results]
        trials = [r['trial'] for r in valid_results]
        has_best = bool(np.isfinite(self.best_score))

        plt.figure(figsize=(15, 10))

        # Plot 1: test PR-AUC by trial
        plt.subplot(2, 3, 1)
        plt.plot(trials, pr_aucs, 'bo-', alpha=0.7)
        if has_best:
            plt.axhline(y=self.best_score, color='r', linestyle='--',
                        label=f'Selected model: {self.best_score:.4f}')
            plt.legend()
        plt.xlabel('Trial')
        plt.ylabel('Test PR-AUC')
        plt.title('Test PR-AUC by Trial')
        plt.grid(True, alpha=0.3)

        # Plot 2: test ROC-AUC by trial
        plt.subplot(2, 3, 2)
        plt.plot(trials, roc_aucs, 'go-', alpha=0.7)
        plt.xlabel('Trial')
        plt.ylabel('Test ROC-AUC')
        plt.title('Test ROC-AUC by Trial')
        plt.grid(True, alpha=0.3)

        # Plot 3: test PR-AUC distribution (at least one bin so a single trial still plots)
        plt.subplot(2, 3, 3)
        plt.hist(pr_aucs, bins=max(1, min(20, len(pr_aucs) // 2)), alpha=0.7, edgecolor='black')
        if has_best:
            plt.axvline(x=self.best_score, color='r', linestyle='--',
                        label=f'Selected model: {self.best_score:.4f}')
            plt.legend()
        plt.xlabel('Test PR-AUC')
        plt.ylabel('Frequency')
        plt.title('Distribution of Test PR-AUC')
        plt.grid(True, alpha=0.3)

        # Plot 4: mean test PR-AUC per d_block (if we have enough data)
        if len(valid_results) > 5:
            plt.subplot(2, 3, 4)
            pr_auc_by_d_block = {}
            for r, pr_auc in zip(valid_results, pr_aucs):
                pr_auc_by_d_block.setdefault(r['params'].get('d_block', 192), []).append(pr_auc)

            d_block_keys = sorted(pr_auc_by_d_block)
            plt.bar([str(k) for k in d_block_keys],
                    [np.mean(pr_auc_by_d_block[k]) for k in d_block_keys], alpha=0.7)
            plt.xlabel('d_block')
            plt.ylabel('Mean Test PR-AUC')
            plt.title('Performance by d_block')
            plt.grid(True, alpha=0.3)

        # Plot 5: PR-AUC vs ROC-AUC scatter
        plt.subplot(2, 3, 5)
        plt.scatter(roc_aucs, pr_aucs, alpha=0.7)
        plt.xlabel('Test ROC-AUC')
        plt.ylabel('Test PR-AUC')
        plt.title('PR-AUC vs ROC-AUC (Test)')
        plt.grid(True, alpha=0.3)

        # Plot 6: Top results
        plt.subplot(2, 3, 6)
        top_results = self.get_top_results(min(10, len(valid_results)))
        top_pr_aucs = [r['test_pr_auc'] for r in top_results]
        top_trials = [r['trial'] for r in top_results]

        plt.barh(range(len(top_pr_aucs)), top_pr_aucs, alpha=0.7)
        plt.yticks(range(len(top_pr_aucs)), [f"Trial {t}" for t in top_trials])
        plt.xlabel('Test PR-AUC')
        plt.title(f'Top {len(top_results)} Results (Test PR-AUC)')
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

def create_safe_dataloader(dataset, batch_size, shuffle=False, pin_memory=False):
    """
    Build a DataLoader that never spawns worker processes.

    All loading happens in the main (kernel) process (num_workers=0), and the
    final partial batch is kept (drop_last=False).
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=pin_memory,
        num_workers=0,
        drop_last=False,
    )
    return DataLoader(dataset, **loader_kwargs)

print("Grid search implementation ready!")
print("DataLoader configuration optimized for Jupyter notebook environment")

# Test features restricted to the same numerical columns, in the same order, as training.
X_test_df = pd.DataFrame(X_test, columns=feature_names)
X_test_num_gs = X_test_df[numerical_cols].reset_index(drop=True).to_numpy()

grid_search = GridSearchFTTransformer(
    X_train=X_train_num,
    y_train=y_train,
    X_val=X_val_num,
    y_val=y_val_val,
    X_test=X_test_num_gs,
    y_test=y_test,
    feature_names=feature_names,
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
    cat_dims=cat_dims,
    device=device,
    seed=random_seed
)

# 'comprehensive' samples up to `max_combinations` configurations from the large grid.
# Use search_type='quick' for a much smaller grid while iterating.
search_results = grid_search.run_grid_search(
    search_type='comprehensive',
    max_combinations=100,  # Adjust based on your time/compute budget
    save_results=False
)


In [None]:
# Analyze grid search results
print("\n" + "="*60)
print("GRID SEARCH RESULTS ANALYSIS")
print("="*60)

# Display top results
top_results = grid_search.get_top_results(5)
print(f"\nTop 5 Results (ranked by Test PR-AUC):")
for i, result in enumerate(top_results, 1):
    print(f"\n{i}. Trial {result['trial']} - Test PR-AUC: {result['test_pr_auc']:.4f}, Test ROC-AUC: {result['test_roc_auc']:.4f}")
    if 'val_pr_auc' in result:
        print(f"   Val PR-AUC: {result['val_pr_auc']:.4f}")
    print(f"   Parameters: {result['params']}")

# Plot results
print(f"\nPlotting grid search results...")
grid_search.plot_results()

# Pick up the selected model. `ft_model` is always bound here, so the report
# below cannot hit a NameError when every trial failed.
best_model = grid_search.best_model
ft_model = None
if best_model is not None:
    print(f"\nBest model loaded with Test PR-AUC: {grid_search.best_score:.4f}")
    print(f"Best parameters: {grid_search.best_params}")

    # Ensure we're using the best weights
    best_model.ensure_best_model()

    ft_model = best_model
    print("Updated global ft_model to use the best hyperparameters")
else:
    print("No successful grid-search trial - no model available")

print(f"\nGrid search completed successfully!")
if ft_model is not None:
    print(f"Final model memory usage: {ft_model.get_memory_usage()}")
