In [None]:
"""DICE algorithm
Input : parquet file with classification value 1 as first column, and feature values 
Output : list of valid counterfactuals, with classification value 0 at first column, 
    and with a new last column with the original factual row index"""

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import time
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

class MLP(nn.Module):
    """Binary classifier: three Linear -> BatchNorm -> LeakyReLU -> Dropout stages
    (widths 2000, 200, 20) followed by a Linear -> Sigmoid head with one output."""

    def __init__(self, input_size):
        super().__init__()
        width1, width2, width3 = 2000, 200, 20
        drop_rate = 0.02
        # Attribute names must stay fc*/bn*/dropout*: they are the keys that
        # load_state_dict() matches against the saved weights.
        self.fc1 = nn.Linear(input_size, width1)
        self.bn1 = nn.BatchNorm1d(width1)
        self.dropout1 = nn.Dropout(drop_rate)
        self.fc2 = nn.Linear(width1, width2)
        self.bn2 = nn.BatchNorm1d(width2)
        self.dropout2 = nn.Dropout(drop_rate)
        self.fc3 = nn.Linear(width2, width3)
        self.bn3 = nn.BatchNorm1d(width3)
        self.dropout3 = nn.Dropout(drop_rate)
        self.fc4 = nn.Linear(width3, 1)
        self.sigmoid = nn.Sigmoid()
        self.leaky_relu = nn.LeakyReLU(0.01)

    def forward(self, x):
        """Map a 2-D batch (rows, input_size) to sigmoid probabilities (rows, 1)."""
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, self.dropout3),
        )
        out = x
        for linear, norm, drop in stages:
            out = drop(self.leaky_relu(norm(linear(out))))
        return self.sigmoid(self.fc4(out))

def load_model(model_path, input_size, device):
    """Build an MLP for `input_size` features on `device`, load its state dict
    from `model_path`, and return it switched to evaluation mode."""
    network = MLP(input_size)
    network = network.to(device)
    weights = torch.load(model_path, map_location=device)
    network.load_state_dict(weights)
    return network.eval()  # Module.eval() returns the module itself

def hinge_loss(predictions, desired_class=0, eps=1e-6):
    """
    DiCE-style hinge loss on the logit of the model's output probability.

    For desired_class 0 the loss is max(0, 1 + logit); for any other value it is
    max(0, 1 - logit). It is therefore zero only once the prediction is past the
    unit margin (logit <= -1, i.e. p <= ~0.269, for class 0; logit >= 1, i.e.
    p >= ~0.731, otherwise), not merely on the correct side of 0.5.

    Args:
        predictions: tensor of sigmoid probabilities.
        desired_class: target class; 0 pushes predictions down, anything else up.
        eps: probabilities are clamped to [eps, 1 - eps] before taking the logit.
            A saturated float32 sigmoid returns exactly 1.0 (or 0.0), for which
            the unclamped logit is +/-inf and the backward pass yields NaN
            (inf * 0 through the sigmoid). Inside the clamped region the
            gradient is zero instead.

    Returns:
        Element-wise loss tensor with the same shape as `predictions`.
    """
    probs = predictions.clamp(eps, 1.0 - eps)
    # logit(p) = log(p) - log(1 - p); log1p keeps precision for small p
    logits = torch.log(probs) - torch.log1p(-probs)
    # Class 0 wants negative logits, any other class wants positive logits
    sign = -1.0 if desired_class == 0 else 1.0
    return torch.clamp(1.0 - sign * logits, min=0.0)

def compute_l1_distance(x1, x2):
    """Sum of absolute differences of `x1` and `x2` along dimension 1, keeping that dim.

    For 2-D (rows, features) inputs this is the per-row L1 distance, shape (rows, 1).
    For higher-rank inputs the reduction is still over dimension 1 only.
    """
    return torch.sum(torch.abs(x1 - x2), dim=1, keepdim=True)

def compute_dpp_diversity(counterfactuals, epsilon=1e-6):
    """
    Compute determinantal point process (DPP) diversity per sample.

    K_ij = 1 / (1 + dist(c_i, c_j)), with dist the L1 distance, and the result
    is det(K) for each sample in the batch.

    A constant, deterministic jitter `epsilon` is added to K's diagonal (as in
    DiCE) so that identical counterfactuals give a small positive determinant
    instead of an exactly singular matrix. Random diagonal noise would make the
    loss nondeterministic and, being signed, could drive the determinant negative.

    Args:
        counterfactuals: tensor of shape (batch, num_cfs, num_features).
        epsilon: non-negative value added to every diagonal entry of K.

    Returns:
        Tensor of shape (batch,) with det(K) for each sample.
    """
    _, num_cfs, _ = counterfactuals.shape

    # Pairwise L1 distances via broadcasting: (batch, num_cfs, num_cfs)
    pairwise_distances = (
        counterfactuals.unsqueeze(2) - counterfactuals.unsqueeze(1)
    ).abs().sum(dim=-1)

    kernel = 1.0 / (1.0 + pairwise_distances)
    jitter = epsilon * torch.eye(num_cfs, dtype=kernel.dtype, device=kernel.device)
    kernel = kernel + jitter  # broadcasts over the batch dimension

    return torch.linalg.det(kernel)

def generate_counterfactuals(
    data_path,
    model_path,
    proximity_weight=0.5,
    diversity_weight=1.0,
    gpu_id=0,
    num_cfs=5,
    batch_size=32,
    learning_rate=0.01,
    max_iterations=1000,
    early_stop_threshold=0.001,
    output_path='counterfactuals.parquet'
):
    """
    Generate counterfactual explanations for every sample in the dataset.

    Each batch is optimised with `optimize_counterfactuals`; only counterfactuals
    that the model predicts as class 0 (probability < 0.5) are kept.

    Parameters:
    - data_path: Path to input DataFrame (parquet format); first column is the
      label, remaining columns are the features
    - model_path: Path to pretrained MLP state dict
    - proximity_weight: Weight for proximity loss (λ1)
    - diversity_weight: Weight for diversity loss (λ2)
    - gpu_id: GPU device ID to use
    - num_cfs: Number of counterfactuals to generate per sample
    - batch_size: Batch size for processing
    - learning_rate: Learning rate for optimization
    - max_iterations: Maximum number of optimization iterations
    - early_stop_threshold: Threshold for early stopping
    - output_path: Path to save generated counterfactuals

    Returns:
    - DataFrame with columns ['class', <feature columns>, 'original_index'] (also
      written to `output_path`), or None if no valid counterfactual was found.
      'original_index' holds the exact DataFrame index label of the factual row.
    """
    start_time = time.time()

    # Set up GPU
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load data
    print("Loading data...")
    df = pd.read_parquet(data_path)

    # Features only: the label column (first) is not needed for generation.
    X = df.iloc[:, 1:].values
    # Index labels are kept in their native dtype. Packing them into a float32
    # tensor (as before) silently rounds integer labels above 2**24.
    index_labels = df.index.to_numpy()

    X_tensor = torch.tensor(X, dtype=torch.float32)
    dataloader = DataLoader(TensorDataset(X_tensor), batch_size=batch_size, shuffle=False)

    # Load model
    input_size = X.shape[1]
    model = load_model(model_path, input_size, device)
    # Only the counterfactual inputs are optimised; freezing the weights stops
    # every backward pass from computing and accumulating unused weight gradients.
    model.requires_grad_(False)

    feature_chunks = []
    label_chunks = []

    print(f"Generating {num_cfs} counterfactuals per sample in batches of {batch_size}...")

    for batch_idx, (X_batch,) in enumerate(tqdm(dataloader)):
        # shuffle=False, so batch positions map directly onto DataFrame rows
        batch_start_idx = batch_idx * batch_size
        batch_size_actual = X_batch.size(0)  # last batch may be smaller

        X_batch = X_batch.to(device)

        counterfactuals = optimize_counterfactuals(
            model=model,
            original_samples=X_batch,
            num_cfs=num_cfs,
            proximity_weight=proximity_weight,
            diversity_weight=diversity_weight,
            learning_rate=learning_rate,
            max_iterations=max_iterations,
            early_stop_threshold=early_stop_threshold,
            device=device
        ).detach()

        # Evaluate model predictions on counterfactuals
        with torch.no_grad():
            predictions = model(
                counterfactuals.reshape(-1, input_size)
            ).reshape(batch_size_actual, num_cfs)

        # Valid counterfactuals are those predicted as class 0. nonzero() returns
        # (sample, cf) pairs in row-major order, i.e. the same order as a nested
        # loop over samples and then counterfactuals.
        sample_pos, cf_pos = torch.nonzero(predictions < 0.5, as_tuple=True)
        if sample_pos.numel() == 0:
            continue
        feature_chunks.append(counterfactuals[sample_pos, cf_pos].cpu().numpy())
        label_chunks.append(index_labels[batch_start_idx + sample_pos.cpu().numpy()])

    # Convert results to DataFrame and save
    if feature_chunks:
        features = np.concatenate(feature_chunks, axis=0)
        result_df = pd.DataFrame(features, columns=list(df.columns[1:]))
        # Target class is 0 for every kept counterfactual
        result_df.insert(0, 'class', np.zeros(len(features), dtype=features.dtype))
        result_df.insert(result_df.shape[1], 'original_index', np.concatenate(label_chunks))
        result_df.to_parquet(output_path)
        print(f"Generated {len(result_df)} counterfactuals saved to {output_path}")
    else:
        result_df = None
        print("No counterfactuals were successfully generated.")

    elapsed_time = time.time() - start_time
    print(f"Total execution time: {elapsed_time:.2f} seconds")

    return result_df

def optimize_counterfactuals(
    model,
    original_samples,
    num_cfs=5,
    proximity_weight=0.5,
    diversity_weight=1.0,
    learning_rate=0.01,
    max_iterations=1000,
    early_stop_threshold=0.001,
    device=None
):
    """
    Optimize counterfactuals for a batch of samples with Adam.

    Loss = hinge(y, class 0) + λ1 * proximity + λ2 * (-DPP diversity). After
    every step the counterfactuals are clamped back into [0, 1].

    Parameters:
    - model: The MLP model
    - original_samples: Batch of original samples, shape (batch, num_features)
    - num_cfs: Number of counterfactuals to generate per sample
    - proximity_weight: Weight for proximity loss (λ1)
    - diversity_weight: Weight for diversity loss (λ2)
    - learning_rate: Learning rate for optimization
    - max_iterations: Maximum number of optimization iterations
    - early_stop_threshold: Stop when the total loss changes by less than this
      between two checks (checks happen every 50 iterations)
    - device: Device to use for computation (defaults to original_samples.device)

    Returns:
    - counterfactuals: Tensor of shape (batch_size, num_cfs, num_features)
    """
    if device is None:
        device = original_samples.device
    original_samples = original_samples.to(device)

    batch_size, num_features = original_samples.shape

    # Start every counterfactual at its factual sample plus N(0, 0.1^2) noise,
    # projected into [0, 1] (features are assumed to be normalised to that range).
    start = original_samples.unsqueeze(1).repeat(1, num_cfs, 1)
    start = (start + torch.randn_like(start) * 0.1).clamp(0, 1)
    counterfactuals = start.clone().detach().requires_grad_(True)

    # Loop-invariant reference for the proximity term
    original_expanded = original_samples.unsqueeze(1).expand_as(counterfactuals)

    optimizer = optim.Adam([counterfactuals], lr=learning_rate)
    prev_loss = float('inf')

    for iteration in range(max_iterations):
        predictions = model(
            counterfactuals.reshape(-1, num_features)
        ).reshape(batch_size, num_cfs)

        y_loss = hinge_loss(predictions, desired_class=0).mean()

        # NOTE: compute_l1_distance reduces over dim=1, which for these
        # (batch, num_cfs, features) tensors is the counterfactual axis, not the
        # feature axis. After .mean() this term equals num_cfs * mean |cf - x|.
        # It is kept as-is because changing it would rescale proximity_weight.
        proximity_loss = compute_l1_distance(counterfactuals, original_expanded).mean()

        diversity_loss = -compute_dpp_diversity(counterfactuals).mean()

        total_loss = y_loss + proximity_weight * proximity_loss + diversity_weight * diversity_loss

        # Differentiate only w.r.t. the counterfactuals. Unlike backward(), this
        # never computes or accumulates gradients on the model's weights.
        (cf_grad,) = torch.autograd.grad(total_loss, counterfactuals)
        counterfactuals.grad = cf_grad
        optimizer.step()

        # Project back to [0, 1] bounds
        with torch.no_grad():
            counterfactuals.clamp_(0, 1)

        # Early stopping check: compare with the loss seen 50 iterations ago
        if iteration % 50 == 0:
            current_loss = total_loss.item()
            if abs(prev_loss - current_loss) < early_stop_threshold:
                break
            prev_loss = current_loss

    return counterfactuals


# Generate 3 counterfactuals per factual row with a light proximity weight (0.01).
generate_counterfactuals(
    'data_eval_norm.parquet',
    'mlp_model.pth',
    num_cfs=3,
    proximity_weight=0.01,
    diversity_weight=1,
    batch_size=1024,
    gpu_id=0,
    output_path='3cf_p001.parquet',
)


In [None]:
"""DICE post-hoc sparsity algorithm, adapted with percentile and gradient method, and ranking reversion by ascending or descending values
Input : Output of previous script
Output : Same format but sparser"""

import torch
import pandas as pd
import numpy as np
import time
from tqdm import tqdm

class MLP(torch.nn.Module):
    """Binary classifier: three Linear -> BatchNorm -> LeakyReLU -> Dropout stages
    (widths 2000, 200, 20) followed by a Linear -> Sigmoid head with one output."""

    def __init__(self, input_size):
        super().__init__()
        width1, width2, width3 = 2000, 200, 20
        drop_rate = 0.02
        # Attribute names must stay fc*/bn*/dropout*: they are the state_dict keys.
        self.fc1 = torch.nn.Linear(input_size, width1)
        self.bn1 = torch.nn.BatchNorm1d(width1)
        self.dropout1 = torch.nn.Dropout(drop_rate)
        self.fc2 = torch.nn.Linear(width1, width2)
        self.bn2 = torch.nn.BatchNorm1d(width2)
        self.dropout2 = torch.nn.Dropout(drop_rate)
        self.fc3 = torch.nn.Linear(width2, width3)
        self.bn3 = torch.nn.BatchNorm1d(width3)
        self.dropout3 = torch.nn.Dropout(drop_rate)
        self.fc4 = torch.nn.Linear(width3, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.leaky_relu = torch.nn.LeakyReLU(0.01)

    def _run_layers(self, h):
        """Apply the full layer stack to a batch and return sigmoid probabilities."""
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, self.dropout3),
        )
        for linear, norm, drop in stages:
            h = drop(self.leaky_relu(norm(linear(h))))
        return self.sigmoid(self.fc4(h))

    def forward(self, x):
        """Return probabilities of shape (rows, 1) for a batch of feature rows."""
        if x.dim() == 2 and x.size(0) == 1:
            # A lone row is duplicated into a batch of two so that BatchNorm1d
            # can run on it; only the first output row is returned.
            return self._run_layers(torch.cat([x, x], dim=0))[0:1]
        return self._run_layers(x)

def load_model(model_path, input_size, device):
    """Instantiate the MLP on `device`, load the state dict stored at
    `model_path`, and hand it back in evaluation mode."""
    mlp = MLP(input_size).to(device)
    state = torch.load(model_path, map_location=device)
    mlp.load_state_dict(state)
    return mlp.eval()  # Module.eval() returns the module itself

def compute_feature_statistics(data_tensor):
    """
    Compute per-feature median and MAD (median absolute deviation).

    Medians follow the NumPy convention (mean of the two middle values for an
    even number of rows) on both the tensor path and the CPU fallback, so the
    result does not depend on which path ran. Plain `torch.median` returns the
    lower middle value, which made the two paths disagree.

    Args:
        data_tensor: 2-D tensor (rows, features) of original data, without the class column

    Returns:
        Tuple (medians, mads) of 1-D tensors on data_tensor's device and dtype.
        MADs are clamped to at least 1e-6 to avoid division by zero.
    """
    num_features = data_tensor.shape[1]

    def _median_dim0(values):
        # Average the lower and upper middle order statistics along dim 0
        # (they coincide when the row count is odd).
        n_rows = values.shape[0]
        lower = torch.kthvalue(values, (n_rows + 1) // 2, dim=0).values
        upper = torch.kthvalue(values, n_rows // 2 + 1, dim=0).values
        return (lower + upper) / 2

    try:
        # Try to compute statistics directly on the tensor's device
        medians = _median_dim0(data_tensor)
        deviations = torch.abs(data_tensor - medians.unsqueeze(0))
        mads = torch.clamp(_median_dim0(deviations), min=1e-6)
        return medians, mads

    except RuntimeError:
        # Fallback to CPU, e.g. when GPU memory is exceeded (CUDA OOM is a RuntimeError)
        print("WARNING: Falling back to CPU for statistics computation")
        data_np = data_tensor.cpu().numpy()
        medians = np.zeros(num_features)
        mads = np.zeros(num_features)

        for feature_idx in range(num_features):
            values = data_np[:, feature_idx]
            median = np.median(values)
            mad = np.median(np.abs(values - median))
            medians[feature_idx] = median
            mads[feature_idx] = mad if mad > 0 else 1e-6  # Avoid division by zero

        return (
            torch.tensor(medians, dtype=data_tensor.dtype, device=data_tensor.device),
            torch.tensor(mads, dtype=data_tensor.dtype, device=data_tensor.device)
        )

def compute_percentile_thresholds(data_tensor, percentile_param):
    """
    Per feature, compute a percentile of the non-zero absolute deviations from the median.

    The median follows the NumPy convention (mean of the two middle values for
    an even number of rows) on both the tensor path and the CPU fallback. Plain
    `torch.median` returns the lower middle value, which made the two paths
    disagree.

    Args:
        data_tensor: 2-D tensor (rows, features) of original data, without the class column
        percentile_param: Percentile as a fraction in [0, 1] (e.g. 0.1 for the 10th percentile)

    Returns:
        1-D tensor of thresholds on data_tensor's device and dtype. Features
        whose deviations are all zero get 1e-6.
    """
    num_features = data_tensor.shape[1]

    def _median_dim0(values):
        # Average the lower and upper middle order statistics along dim 0
        # (they coincide when the row count is odd).
        n_rows = values.shape[0]
        lower = torch.kthvalue(values, (n_rows + 1) // 2, dim=0).values
        upper = torch.kthvalue(values, n_rows // 2 + 1, dim=0).values
        return (lower + upper) / 2

    try:
        medians = _median_dim0(data_tensor)
        deviations = torch.abs(data_tensor - medians.unsqueeze(0))

        percentiles = torch.full((num_features,), 1e-6,
                                 dtype=data_tensor.dtype,
                                 device=data_tensor.device)

        for feature_idx in range(num_features):
            feature_devs = deviations[:, feature_idx]
            non_zero_devs = feature_devs[feature_devs > 0]
            if non_zero_devs.numel() > 0:
                # torch.quantile interpolates linearly by default (same as np.percentile)
                percentiles[feature_idx] = torch.quantile(non_zero_devs, percentile_param)

        return percentiles

    except (RuntimeError, AttributeError):
        # Fallback to CPU if GPU memory is exceeded, the input is too large for
        # torch.quantile, or this torch version lacks quantile
        print("WARNING: Falling back to CPU for percentile computation")
        data_np = data_tensor.cpu().numpy()
        percentiles = np.zeros(num_features)

        for feature_idx in range(num_features):
            values = data_np[:, feature_idx]
            deviations = np.abs(values - np.median(values))
            non_zero_deviations = deviations[deviations > 0]
            if len(non_zero_deviations) > 0:
                percentiles[feature_idx] = np.percentile(non_zero_deviations, percentile_param * 100)
            else:
                percentiles[feature_idx] = 1e-6  # Small non-zero value

        return torch.tensor(percentiles, dtype=data_tensor.dtype, device=data_tensor.device)

def compute_gradient_importance(model, cf_features, original_features, mads=None, use_mad_norm=False):
    """
    Compute gradient-based importance scores for all features.

    importance = |d model(cf) / d cf| * |cf - original|, optionally divided by MAD.

    Gradients are taken w.r.t. a detached copy of `cf_features` using
    torch.autograd.grad. Neither the caller's tensor nor the model parameters'
    `.grad` attributes are touched; the previous `backward()` call accumulated
    into the model weights. Works even when called under torch.no_grad().

    Args:
        model: PyTorch model
        cf_features: Counterfactual features tensor
        original_features: Original features tensor
        mads: Median absolute deviations (optional, for normalization)
        use_mad_norm: Whether to normalize by MAD

    Returns:
        Tensor of importance scores with the same shape as cf_features
    """
    with torch.enable_grad():
        probe = cf_features.detach().clone().requires_grad_(True)
        output = model(probe)
        (input_grad,) = torch.autograd.grad(output.sum(), probe)

    gradients = input_grad.abs()
    differences = torch.abs(cf_features - original_features)
    importance = gradients * differences

    # Normalize by MAD if requested
    if use_mad_norm and mads is not None:
        importance = importance / mads.unsqueeze(0)

    return importance

def _estimate_bytes_per_sample(model, num_features):
    """Rough float32 memory footprint of one counterfactual row during sparsification."""
    # Feature-sized tensors: cf, original, working copy and temporaries (~7)
    feature_bytes = 4 * 7 * num_features
    # Forward activations: each Linear output is followed by normalisation and
    # activation outputs, so budget ~3 live copies of every Linear's width.
    # For the MLP above this dominates (~26 KB per row).
    hidden_units = sum(
        layer.out_features for layer in model.modules() if isinstance(layer, torch.nn.Linear)
    )
    return feature_bytes + 4 * 3 * hidden_units


def _rank_features_by_gradient(model, cf_features, original_features, use_mad_norm, sort_descending):
    """Return feature indices ordered by batch-averaged gradient importance."""
    mads = compute_feature_statistics(original_features)[1] if use_mad_norm else None
    importance = compute_gradient_importance(
        model, cf_features, original_features, mads=mads, use_mad_norm=use_mad_norm
    )
    return torch.sort(importance.mean(dim=0), descending=sort_descending)[1].tolist()


def _revert_features_greedily(model, cf_features, original_features, valid_mask, feature_order):
    """Revert features, one at a time in `feature_order`, to their original values,
    keeping each revert only for rows whose prediction stays below 0.5.

    Rows not in `valid_mask` are returned unchanged.
    """
    working = cf_features.clone()
    rows = torch.nonzero(valid_mask, as_tuple=True)[0]
    for feature_idx in feature_order:
        # Advanced indexing returns a copy, so this snapshot survives the write below
        previous_values = working[rows, feature_idx]
        working[rows, feature_idx] = original_features[rows, feature_idx]
        preds = model(working[rows]).view(-1)
        flipped = preds >= 0.5  # these rows would stop being class-0 counterfactuals
        if flipped.any():
            working[rows[flipped], feature_idx] = previous_values[flipped]
    return working


@torch.no_grad()
def enhance_sparsity(
    original_data_path,
    counterfactuals_path,
    model_path,
    sparsity_param=0.1,
    method="percentile",
    sort_descending=False,  # Controls the sorting direction of the feature ranking
    gpu_id=0,
    output_path='sparse_counterfactuals.parquet',
    max_gpu_memory_fraction=0.8
):
    """
    Enhance sparsity of counterfactuals by greedily reverting features to their
    original values, in the order given by the chosen ranking method, whenever
    the counterfactual stays predicted as class 0.

    Args:
        original_data_path: Path to original normalized data
        counterfactuals_path: Path to generated counterfactuals
            (columns: class, features..., original_index)
        model_path: Path to trained MLP model
        sparsity_param: Parameter controlling sparsity enhancement:
            - If method='percentile': percentile threshold (0.1 = 10th percentile)
            - If method='gradient': sparsity_param > 0.5 uses MAD normalization
        method: Feature ranking method ('percentile' or 'gradient'/'grad');
            any other value falls back to 'percentile'
        sort_descending: Whether to sort features in descending order (True) or ascending (False)
        gpu_id: GPU device ID to use
        output_path: Path to save sparsity-enhanced counterfactuals
        max_gpu_memory_fraction: Maximum fraction of GPU memory to use for tensor storage

    Returns:
        DataFrame with sparsity-enhanced counterfactuals (also written to output_path)

    Notes:
        - The percentile ranking always uses statistics of the whole original
          dataset. When the data does not fit on the GPU, they are computed on
          the CPU rather than per batch, so the ranking no longer depends on
          available memory.
        - The batch size estimate includes the model's activations, not just
          the feature tensors.
    """
    start_time = time.time()

    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading data...")
    original_data = pd.read_parquet(original_data_path)
    counterfactuals = pd.read_parquet(counterfactuals_path)

    # Check GPU memory to determine if we need to split the data
    if torch.cuda.is_available():
        total_gpu_memory = torch.cuda.get_device_properties(gpu_id).total_memory
        available_gpu_memory = total_gpu_memory * max_gpu_memory_fraction
        print(f"Using up to {available_gpu_memory / 1024**3:.2f} GB of GPU memory")
    else:
        available_gpu_memory = float('inf')  # No GPU limit

    # Column layout: [class, feature_1, ..., feature_n, original_index]
    feature_names = list(counterfactuals.columns)[1:-1]
    num_features = len(feature_names)

    model = load_model(model_path, num_features, device)

    total_counterfactuals = len(counterfactuals)
    memory_per_sample = _estimate_bytes_per_sample(model, num_features)
    total_memory_needed = total_counterfactuals * memory_per_sample

    # max(1, ...) keeps the range() step positive, even for an empty input file
    if total_memory_needed > available_gpu_memory:
        batch_size = max(1, int(available_gpu_memory / memory_per_sample))
        print(f"Processing in batches of {batch_size} samples due to GPU memory constraints")
        use_batching = True
    else:
        batch_size = max(1, total_counterfactuals)
        use_batching = False
        print("Processing all counterfactuals in a single batch")

    # Parse method parameters
    use_mad_norm = False
    if method in ("gradient", "grad"):
        use_mad_norm = sparsity_param > 0.5
        if use_mad_norm:
            print("Using gradient method with MAD normalization")
        else:
            print("Using gradient method without MAD normalization")
    else:  # Default to percentile method
        method = "percentile"
        print(f"Using percentile method with parameter {sparsity_param}")

    # Percentile/MAD thresholds depend only on the original data: compute them once
    global_feature_order = None
    if method == "percentile":
        stats_device = torch.device("cpu") if use_batching else device
        print("Pre-computing statistics for all data...")
        all_original_features = torch.tensor(
            original_data[feature_names].values,
            dtype=torch.float32,
            device=stats_device
        )
        _, mads = compute_feature_statistics(all_original_features)
        percentiles = compute_percentile_thresholds(all_original_features, sparsity_param)
        thresholds = torch.minimum(mads, percentiles)
        global_feature_order = torch.sort(thresholds, descending=sort_descending)[1].tolist()

        del all_original_features
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    sparse_counterfactuals = counterfactuals.copy()

    for batch_start in tqdm(range(0, total_counterfactuals, batch_size)):
        batch_end = min(batch_start + batch_size, total_counterfactuals)
        batch_cf = counterfactuals.iloc[batch_start:batch_end]

        # NOTE(review): the generation script stores the factual's DataFrame index
        # *label* in 'original_index', but it is used positionally here (.iloc).
        # This is only correct if those labels equal row positions in
        # original_data (e.g. a default RangeIndex); confirm.
        original_indices = batch_cf['original_index'].astype(int).values
        original_samples = original_data.iloc[original_indices]

        cf_features = torch.tensor(batch_cf[feature_names].values, dtype=torch.float32, device=device)
        original_features = torch.tensor(original_samples[feature_names].values, dtype=torch.float32, device=device)

        # Only counterfactuals currently predicted as class 0 are sparsified
        valid_cf_mask = model(cf_features).view(-1) < 0.5
        if not valid_cf_mask.any():
            continue

        if global_feature_order is not None:
            feature_order = global_feature_order
        else:
            feature_order = _rank_features_by_gradient(
                model, cf_features, original_features, use_mad_norm, sort_descending
            )

        working_cf_features = _revert_features_greedily(
            model, cf_features, original_features, valid_cf_mask, feature_order
        )
        sparse_counterfactuals.iloc[batch_start:batch_end, 1:-1] = working_cf_features.cpu().numpy()

    sparse_counterfactuals.to_parquet(output_path)

    elapsed_time = time.time() - start_time
    print(f"Total execution time: {elapsed_time:.2f} seconds")
    print(f"Enhanced counterfactuals saved to {output_path}")

    return sparse_counterfactuals


# Sparsify the 3-counterfactual run with the percentile method (10th percentile),
# visiting features in descending threshold order.
enhance_sparsity(
    'data_norm.parquet',
    '3cf_p100.parquet',
    'mlp_model.pth',
    method="percentile",
    sparsity_param=0.1,
    sort_descending=True,
    gpu_id=0,
    max_gpu_memory_fraction=0.9,
    output_path='sparse_3cf_p100_desc.parquet',
)


In [None]:
"""Evaluation script for validity, sparsity, proximity, diversity, 10-NN distance to average, plausibility
Intput : Output dataframe of DICE or post-hoc sparsity algorithm
Output : JSON dictionnary with min, average, max values and std for the dataframe"""

import torch
import numpy as np
import pandas as pd
import time
import os
from tqdm import tqdm
import math

class MLP(torch.nn.Module):
    """Binary classifier: three Linear+BatchNorm+LeakyReLU+Dropout stages, then Linear+Sigmoid.

    Submodule names (fc1..fc4, bn1..bn3, dropout1..dropout3) become the keys
    of the state dict restored by ``load_model``, so they must not be renamed.
    """

    def __init__(self, input_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, 2000)
        self.bn1 = torch.nn.BatchNorm1d(2000)
        self.dropout1 = torch.nn.Dropout(0.02)
        self.fc2 = torch.nn.Linear(2000, 200)
        self.bn2 = torch.nn.BatchNorm1d(200)
        self.dropout2 = torch.nn.Dropout(0.02)
        self.fc3 = torch.nn.Linear(200, 20)
        self.bn3 = torch.nn.BatchNorm1d(20)
        self.dropout3 = torch.nn.Dropout(0.02)
        self.fc4 = torch.nn.Linear(20, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.leaky_relu = torch.nn.LeakyReLU(0.01)

    def _forward_layers(self, x):
        """Apply the whole layer stack to batch ``x``; returns sigmoid outputs of shape (batch, 1)."""
        x = self.dropout1(self.leaky_relu(self.bn1(self.fc1(x))))
        x = self.dropout2(self.leaky_relu(self.bn2(self.fc2(x))))
        x = self.dropout3(self.leaky_relu(self.bn3(self.fc3(x))))
        return self.sigmoid(self.fc4(x))

    def forward(self, x):
        """Return sigmoid outputs in (0, 1) for ``x`` (>= 0.5 is treated as class 1 elsewhere in this file).

        A 2-D input with a single row is duplicated into a batch of two before
        the layers run, and only the first output row is returned. This is
        presumably so BatchNorm1d can also run on one row in training mode.
        """
        if x.dim() == 2 and x.size(0) == 1:
            return self._forward_layers(torch.cat([x, x], dim=0))[0:1]
        return self._forward_layers(x)

    def find_boundary_intersection(self, x_factual, x_counterfactual, eps=1e-5, max_iterations=50):
        """
        Find where the segment from factual to counterfactual crosses the 0.5 decision boundary.

        Bisects t in [0, 1] along ``x_factual + t * (x_counterfactual - x_factual)``.

        Args:
            x_factual: Factual sample tensor
            x_counterfactual: Counterfactual sample tensor
            eps: Tolerance on both |prediction - 0.5| and the width of the t bracket
            max_iterations: Maximum number of bisection steps

        Returns:
            ``(point, distance)``: the boundary point and its L1 distance from
            ``x_factual``. If both endpoints are clearly on the same side
            (both > 0.55 or both < 0.45), returns the counterfactual and its L1
            distance from the factual. If the endpoints (almost) coincide,
            returns ``(x_factual, 0.0)``.
        """
        with torch.no_grad():
            fact_pred = self.forward(x_factual).item()
            cf_pred = self.forward(x_counterfactual).item()

            # Relaxed same-side test: pairs within 0.05 of the boundary are still
            # searched; clearly same-side pairs fall back to the CF distance
            # rather than returning infinity.
            if (fact_pred > 0.55 and cf_pred > 0.55) or (fact_pred < 0.45 and cf_pred < 0.45):
                return x_counterfactual, torch.sum(torch.abs(x_counterfactual - x_factual)).item()

            direction = x_counterfactual - x_factual
            if torch.norm(direction, p=1).item() < eps:  # endpoints (almost) identical
                return x_factual, 0.0

            low, high = 0.0, 1.0  # t=0 is the factual, t=1 the counterfactual
            for _ in range(max_iterations):
                mid = (low + high) / 2.0
                x_mid = x_factual + mid * direction
                pred_mid = self.forward(x_mid).item()

                if abs(pred_mid - 0.5) < eps:
                    return x_mid, torch.sum(torch.abs(x_mid - x_factual)).item()

                # Keep the boundary inside [low, high]
                if (pred_mid > 0.5 and fact_pred > 0.5) or (pred_mid < 0.5 and fact_pred < 0.5):
                    low = mid   # mid is on the factual's side
                else:
                    high = mid  # mid is on the counterfactual's side

                if high - low < eps:
                    break

            # Bracket is tight (or iterations ran out): use its midpoint, not the
            # last probe, which sits on one edge of the bracket.
            x_intersection = x_factual + ((low + high) / 2.0) * direction
            return x_intersection, torch.sum(torch.abs(x_intersection - x_factual)).item()

def load_model(model_path, input_size, device):
    """Build an MLP on ``device``, fill it with the state dict stored at ``model_path``, and switch it to eval mode."""
    state = torch.load(model_path, map_location=device)
    net = MLP(input_size).to(device)
    net.load_state_dict(state)
    net.eval()
    return net

def compute_median_absolute_deviation(data_tensor):
    """Return the per-column median absolute deviation of a 2-D tensor.

    ``torch.median`` picks the lower of the two middle values when the row
    count is even, so both the center and the MAD use that convention.
    Columns whose MAD is 0 are reported as 1.0, so the result can be used
    as a divisor safely.
    """
    center = data_tensor.median(dim=0).values
    spread = (data_tensor - center).abs().median(dim=0).values
    return torch.where(spread == 0, torch.ones_like(spread), spread)

def update_evaluation_file(name, metric, value, results, output_file="evaluation.txt"):
    """Record ``value`` at ``results[name][metric]``, print it, and return ``results``.

    Nothing is written to ``output_file`` here; the parameter is unused.
    ``results[name]`` must already exist (a KeyError is raised otherwise).
    """
    entry = results[name]
    entry[metric] = value
    print(f"Updated results with {name} - {metric}: {value:.4f}")
    return results

def compute_gaussian_probability(distance, mean, variance):
    """
    Two-sided Gaussian tail probability of ``distance``.

    Returns P(|X - mean| >= |distance - mean|) for X ~ N(mean, variance):
    the probability of landing at least as far from the mean as ``distance``
    does. It is 1.0 at the mean and decays towards 0 in both directions.

    Args:
        distance: Observed value (in this file, an average 10-NN distance).
        mean: Mean of the reference Gaussian.
        variance: Variance of the reference Gaussian; must be positive.

    Returns:
        A float in [0, 1].

    Raises:
        ValueError: If ``variance`` is not positive.
    """
    if variance <= 0:
        raise ValueError(f"variance must be positive, got {variance}")

    z_score = abs(distance - mean) / math.sqrt(variance)
    # 2 * (1 - Phi(z)) == erfc(z / sqrt(2)). Using erfc directly avoids the
    # cancellation in 1 - erf(...) that rounds far-tail values to exactly 0.0.
    return math.erfc(z_score / math.sqrt(2))


import json

def _summarize_scores(metric_values, key, scores):
    """Store ``[min, mean, max, std]`` of ``scores`` at ``metric_values[key]`` and print it.

    Empty ``scores`` are ignored, so a metric with nothing to aggregate is
    simply absent from the results.
    """
    if len(scores) == 0:
        return
    values = np.asarray(scores)
    stats = [float(np.min(values)), float(np.mean(values)),
             float(np.max(values)), float(np.std(values))]
    metric_values[key] = stats
    print(f"{key}: min={stats[0]:.4f}, avg={stats[1]:.4f}, max={stats[2]:.4f}, std={stats[3]:.4f}")


def _map_cf_rows_to_original(cf_df, n_original):
    """Map each counterfactual row position to its factual's row position in the original data.

    The factual position is read from ``original_index`` if present,
    otherwise from ``sample_id``. It maps to -1 when it lies outside
    ``[0, n_original)`` or when neither column exists.
    """
    if 'original_index' in cf_df.columns:
        id_column = 'original_index'
    elif 'sample_id' in cf_df.columns:
        id_column = 'sample_id'
    else:
        return {row: -1 for row in range(len(cf_df))}

    mapping = {}
    for row, raw in enumerate(cf_df[id_column].tolist()):
        position = int(raw)
        mapping[row] = position if 0 <= position < n_original else -1
    return mapping


def _valid_pair_batches(cf_to_orig_map, n_rows, batch_size):
    """Yield ``(cf_rows, orig_rows)`` index lists for each run of ``batch_size`` counterfactual rows.

    Only rows that map to a factual (index != -1) are kept, and runs with no
    such row are skipped.
    """
    for start in range(0, n_rows, batch_size):
        pairs = [(i, cf_to_orig_map.get(i, -1)) for i in range(start, min(start + batch_size, n_rows))]
        pairs = [(i, o) for i, o in pairs if o != -1]
        if pairs:
            cf_rows, orig_rows = zip(*pairs)
            yield list(cf_rows), list(orig_rows)


def _mean_knn_distance(queries, reference, n_neighbors=10, chunk_size=5000):
    """Mean L1 distance from each query row to its nearest ``reference`` rows.

    The single closest reference row is skipped (it may coincide with the
    query), and the next ``n_neighbors`` distances are averaged. Fewer are
    averaged when ``reference`` is small. Returns None when ``reference`` has
    fewer than 2 rows. ``reference`` is scanned in chunks of ``chunk_size``
    rows to bound memory.
    """
    k = min(n_neighbors + 1, len(reference))
    if k < 2:
        return None
    best = torch.full((len(queries), k), float('inf'), device=queries.device)
    for start in range(0, len(reference), chunk_size):
        chunk = reference[start:start + chunk_size]
        distances = torch.cdist(queries, chunk, p=1)
        best = torch.topk(torch.cat([best, distances], dim=1), k=k, dim=1, largest=False).values
    return torch.mean(best[:, 1:], dim=1)


@torch.no_grad()
def evaluate_counterfactuals(
    original_data_path,
    counterfactual_dfs,
    model_path,
    metrics_to_compute=None,
    batch_size=512,
    gpu_id=0,
    output_path="evaluation_results.json"
):
    """
    Evaluate counterfactuals using multiple metrics.

    Args:
        original_data_path: Path to the original parquet data. The first
            column is the class label; the remaining columns are features.
        counterfactual_dfs: Dict mapping a name to a counterfactual DataFrame
            or parquet path. The first column is the class. If an
            ``original_index`` column is present it must be the last column.
        model_path: Path to the trained MLP state dict.
        metrics_to_compute: Dict of metric flags; missing keys count as False.
            ``None`` (the default) enables all of: validity, proximity,
            sparsity, sparsity_count, diversity, sparse_diversity,
            avg_10nn_distance, avg_10nn_dataset, avg_10nn_class0.
        batch_size: Number of counterfactual rows processed per batch.
        gpu_id: CUDA device index (CPU is used when CUDA is unavailable).
        output_path: JSON file to write. Entries already in the file under
            other names are kept; entries with the same name are replaced.

    Returns:
        Dict mapping each name evaluated in this call to
        ``{metric: [min, avg, max, std]}``.
    """
    if metrics_to_compute is None:
        metrics_to_compute = {
            key: True for key in (
                'validity', 'proximity', 'sparsity', 'sparsity_count',
                'diversity', 'sparse_diversity', 'avg_10nn_distance',
                'avg_10nn_dataset', 'avg_10nn_class0',
            )
        }

    def wants(metric):
        return bool(metrics_to_compute.get(metric, False))

    start_time = time.time()

    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading original data...")
    original_df = pd.read_parquet(original_data_path)

    cf_data = {
        name: pd.read_parquet(source) if isinstance(source, str) else source
        for name, source in counterfactual_dfs.items()
    }

    # Features are every column after the leading class column
    feature_names = list(original_df.columns)[1:]
    num_features = len(feature_names)

    model = load_model(model_path, num_features, device)

    original_features = torch.tensor(original_df.iloc[:, 1:].values, dtype=torch.float32, device=device)
    # Per-feature scale used to normalise differences in the proximity metric
    mad = compute_median_absolute_deviation(original_features)
    original_classes = original_df.iloc[:, 0].values

    results = {name: {} for name in cf_data}

    # Reference Gaussians for the 10-NN plausibility probabilities.
    # NOTE(review): hard-coded; presumably fitted offline on this dataset - confirm before reuse.
    dataset_gaussian_mean = 8304.722022
    dataset_gaussian_variance = 4707926.898804
    class0_gaussian_mean = 5950.676427
    class0_gaussian_variance = 3400381.504375

    class0_mask = torch.tensor(original_classes == 0, device=device)
    class0_samples = original_features[class0_mask] if torch.any(class0_mask) else None

    for name, cf_df in cf_data.items():
        print(f"\nEvaluating: {name}")

        # original_index / sample_id give row positions in the original data
        cf_to_orig_map = _map_cf_rows_to_original(cf_df, len(original_df))

        # Drop the class column, plus the trailing original_index column if present
        feature_slice = slice(1, -1) if 'original_index' in cf_df.columns else slice(1, None)
        cf_features = torch.tensor(cf_df.iloc[:, feature_slice].values, dtype=torch.float32, device=device)
        n_cf = len(cf_features)

        # Group counterfactual rows by the factual they were generated for
        cf_by_orig = {}
        for cf_idx, orig_idx in cf_to_orig_map.items():
            if orig_idx != -1:
                cf_by_orig.setdefault(orig_idx, []).append(cf_idx)

        metric_values = {}

        # 1. Validity: 1 when the CF's predicted label differs from its factual's, else 0
        if wants('validity'):
            print("Computing validity...")
            validity_scores = []
            for cf_rows, orig_rows in _valid_pair_batches(cf_to_orig_map, n_cf, batch_size):
                # The same rule (prediction >= 0.5 means class 1) applies to both sides
                cf_label = model(cf_features[cf_rows]).view(-1) >= 0.5
                orig_label = model(original_features[orig_rows]).view(-1) >= 0.5
                validity_scores.extend((cf_label != orig_label).float().cpu().tolist())
            _summarize_scores(metric_values, 'validity', validity_scores)

        # 2-3. Proximity and sparsity
        want_proximity = wants('proximity')
        want_sparsity = wants('sparsity')
        want_count = wants('sparsity_count')
        if want_proximity or want_sparsity or want_count:
            print("Computing proximity and sparsity...")
            proximity_scores, sparsity_scores, sparsity_count_scores = [], [], []
            for cf_rows, orig_rows in _valid_pair_batches(cf_to_orig_map, n_cf, batch_size):
                abs_diff = torch.abs(cf_features[cf_rows] - original_features[orig_rows])
                # A feature counts as changed when it moved by more than 1e-2
                changed = (abs_diff > 1e-2).float()
                if want_proximity:
                    # Mean per-feature L1 difference, each feature scaled by its MAD
                    proximity_scores.extend(torch.mean(abs_diff / mad, dim=1).cpu().numpy())
                if want_sparsity:
                    # 1 - fraction of changed features
                    sparsity_scores.extend((1 - torch.mean(changed, dim=1)).cpu().numpy())
                if want_count:
                    sparsity_count_scores.extend(torch.sum(changed, dim=1).cpu().numpy())
            if want_proximity:
                _summarize_scores(metric_values, 'proximity', proximity_scores)
            if want_sparsity:
                _summarize_scores(metric_values, 'sparsity', sparsity_scores)
            if want_count:
                _summarize_scores(metric_values, 'sparsity_count', sparsity_count_scores)

        # 4. Diversity: mean pairwise raw L1 distance between CFs of the same factual
        if wants('diversity'):
            print("Computing diversity...")
            diversity_scores = []
            for cf_indices in cf_by_orig.values():
                if len(cf_indices) < 2:  # need at least one pair
                    continue
                group = cf_features[cf_indices]
                pairwise = torch.cdist(group, group, p=1)
                upper = torch.triu(torch.ones_like(pairwise), diagonal=1).bool()
                distances = pairwise[upper]
                if len(distances) > 0:
                    diversity_scores.append(torch.mean(distances).item())
            _summarize_scores(metric_values, 'diversity', diversity_scores)

        # 5. Sparse diversity: 1 - mean IoU of the changed-feature sets of CF pairs
        if wants('sparse_diversity'):
            print("Computing sparse diversity...")
            sparse_diversity_scores = []
            for orig_idx, cf_indices in cf_by_orig.items():
                if len(cf_indices) < 2:
                    continue
                group = cf_features[cf_indices]
                # NOTE(review): the changed-feature threshold here is 1e-5, while
                # sparsity uses 1e-2 - confirm that the difference is intended.
                changed = (torch.abs(group - original_features[orig_idx].unsqueeze(0)) > 1e-5).float()
                n_cfs = len(cf_indices)
                # Broadcast [n, 1, F] against [1, n, F] to get all pairwise intersections and unions
                features_i = changed.unsqueeze(1)
                features_j = changed.unsqueeze(0)
                intersection = torch.sum(features_i * features_j, dim=2)
                union = torch.sum(torch.clamp(features_i + features_j, 0, 1), dim=2)
                upper = torch.triu(torch.ones(n_cfs, n_cfs, device=device), diagonal=1).bool()
                pair_intersection = intersection[upper]
                pair_union = union[upper]
                nonempty = pair_union > 0
                if torch.any(nonempty):
                    # Pairs whose union is empty count as zero overlap
                    iou = torch.zeros_like(pair_intersection)
                    iou[nonempty] = pair_intersection[nonempty] / pair_union[nonempty]
                    sparse_diversity_scores.append(1 - torch.mean(iou).item())
            _summarize_scores(metric_values, 'sparse_diversity', sparse_diversity_scores)

        # 6-8. Average 10-NN distance and Gaussian plausibility probabilities
        want_knn_distance = wants('avg_10nn_distance')
        want_knn_dataset = wants('avg_10nn_dataset')
        want_knn_class0 = wants('avg_10nn_class0')
        if want_knn_distance or want_knn_dataset or want_knn_class0:
            print("Computing 10-NN distances in dataset and class 0...")
            avg_10nn_distances_list, avg_10nn_dataset_probs, avg_10nn_class0_probs = [], [], []
            use_class0 = want_knn_class0 and class0_samples is not None and len(class0_samples) > 0
            for cf_rows, _ in _valid_pair_batches(cf_to_orig_map, n_cf, batch_size):
                queries = cf_features[cf_rows]
                if want_knn_distance or want_knn_dataset:
                    mean_dist = _mean_knn_distance(queries, original_features)
                    if mean_dist is not None:
                        mean_dist = mean_dist.cpu().numpy()
                        if want_knn_distance:
                            avg_10nn_distances_list.extend(mean_dist)
                        if want_knn_dataset:
                            avg_10nn_dataset_probs.extend(
                                compute_gaussian_probability(d, dataset_gaussian_mean, dataset_gaussian_variance)
                                for d in mean_dist
                            )
                if use_class0:
                    mean_dist = _mean_knn_distance(queries, class0_samples)
                    if mean_dist is not None:
                        avg_10nn_class0_probs.extend(
                            compute_gaussian_probability(d, class0_gaussian_mean, class0_gaussian_variance)
                            for d in mean_dist.cpu().numpy()
                        )
            if want_knn_distance:
                _summarize_scores(metric_values, 'avg_10nn_distance', avg_10nn_distances_list)
            if want_knn_dataset:
                _summarize_scores(metric_values, 'avg_10nn_dataset_probability', avg_10nn_dataset_probs)
            if want_knn_class0:
                _summarize_scores(metric_values, 'avg_10nn_class0_probability', avg_10nn_class0_probs)

        results[name] = metric_values

    # Summary of all sets evaluated in this call (printed once, after the loop)
    print("\nEvaluation Summary:")
    for set_name, metrics in results.items():
        print(f"\n{set_name}:")
        for metric, values in metrics.items():
            print(f"  {metric}: min={values[0]:.4f}, avg={values[1]:.4f}, max={values[2]:.4f}, std={values[3]:.4f}")

    elapsed_time = time.time() - start_time
    print(f"\nTotal evaluation time: {elapsed_time:.2f} seconds")

    # Merge into an existing JSON file if there is one: other names are kept,
    # same names are replaced.
    print(f"\nSaving results to {output_path}")
    to_write = results
    try:
        if os.path.exists(output_path):
            with open(output_path, 'r') as f:
                try:
                    existing_results = json.load(f)
                except json.JSONDecodeError:
                    existing_results = None
                    print("Could not parse existing file as JSON. Creating new file.")
                else:
                    print(f"Successfully loaded existing results from {output_path}")
            if isinstance(existing_results, dict):
                for set_name, metrics in results.items():
                    action = "Updating existing entry" if set_name in existing_results else "Adding new entry"
                    print(f"{action}: {set_name}")
                    existing_results[set_name] = metrics
                to_write = existing_results
            elif existing_results is not None:
                print("Existing file is not a JSON object. Creating new file.")

        with open(output_path, 'w') as f:
            json.dump(to_write, f, indent=2)
        print(f"Results successfully written to {output_path}")
    except Exception as e:
        # Best effort: keep the computed metrics even if the main file cannot be written
        print(f"Error saving results to {output_path}: {str(e)}")
        backup_path = f"{output_path}.backup"
        with open(backup_path, 'w') as f:
            json.dump(to_write, f, indent=2)
        print(f"Backup saved to {backup_path}")

    return results



    
    




# Evaluate one saved counterfactual set with every metric enabled and merge
# the summary into evaluation_with_knn_metrics.json.
results = evaluate_counterfactuals(
    'data_norm.parquet',
    {'best_counterfactuals_2feat_osvm': 'best_counterfactuals_2feat_osvm_round.parquet'},
    'mlp_model.pth',
    metrics_to_compute=dict(
        validity=True,
        proximity=True,
        sparsity=True,
        sparsity_count=True,
        diversity=True,
        sparse_diversity=True,
        avg_10nn_distance=True,
        avg_10nn_dataset=True,
        avg_10nn_class0=True,
    ),
    output_path="evaluation_with_knn_metrics.json",
    gpu_id=0,
    batch_size=512,
)

In [2]:
"""BF 1-feature CF generation
Input: same as the DICE method.
Output: same format as the input, without the original factual index at the end."""


import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from tqdm import tqdm
import time
from datetime import datetime, timedelta
from torch.utils.data import DataLoader, TensorDataset

class MLP(nn.Module):
    """Binary classifier: three hidden stages followed by a single sigmoid output.

    Each hidden stage is Linear -> BatchNorm1d -> LeakyReLU(0.01) -> Dropout(0.02),
    with widths 2000, 200 and 20. Submodules are registered as fc{i}/bn{i}/dropout{i}
    (i = 1..3), then fc4, sigmoid and leaky_relu; these attribute names are the
    state_dict keys of saved checkpoints and must stay stable.
    """

    def __init__(self, input_size):
        super().__init__()
        fan_ins = (input_size, 2000, 200)
        fan_outs = (2000, 200, 20)
        # Modules are created stage by stage (fc, bn, dropout) so parameter
        # initialization consumes the RNG in a fixed order.
        for stage, (n_in, n_out) in enumerate(zip(fan_ins, fan_outs), start=1):
            setattr(self, f'fc{stage}', nn.Linear(n_in, n_out))
            setattr(self, f'bn{stage}', nn.BatchNorm1d(n_out))
            setattr(self, f'dropout{stage}', nn.Dropout(0.02))
        self.fc4 = nn.Linear(20, 1)
        self.sigmoid = nn.Sigmoid()
        self.leaky_relu = nn.LeakyReLU(0.01)

    def forward(self, x):
        """Map a (batch, input_size) tensor to (batch, 1) probabilities."""
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, self.dropout3),
        )
        for linear, norm, drop in stages:
            x = drop(self.leaky_relu(norm(linear(x))))
        return self.sigmoid(self.fc4(x))

def load_model(model_path, input_size, device):
    """Build an MLP for ``input_size`` features on ``device``, load the weights
    stored at ``model_path`` and return the network in eval mode."""
    network = MLP(input_size).to(device)
    weights = torch.load(model_path, map_location=device)
    network.load_state_dict(weights)
    # Module.eval() returns the module itself.
    return network.eval()

def efficient_knn_distances(candidates, reference_data, device, chunk_size=32, k=10):
    """Mean L1 distance from each candidate row to its ``k`` nearest reference rows.

    Args:
        candidates: 2-D tensor (n_candidates, n_features).
        reference_data: 2-D tensor (n_reference, n_features) on the same device.
        device: device on which the result vector is allocated.
        chunk_size: number of candidate rows processed per distance-matrix chunk.
        k: number of neighbours; clipped to ``n_reference``.

    Returns:
        float tensor of shape (n_candidates,) on ``device``.

    Raises:
        ValueError: if ``reference_data`` has no rows (the mean over zero
            neighbours would otherwise silently be NaN).
    """
    n_candidates = candidates.shape[0]
    n_reference = reference_data.shape[0]
    if n_reference == 0:
        raise ValueError("reference_data must contain at least one row")
    k = min(k, n_reference)

    # torch.cdist only accepts floating dtypes and needs both inputs to agree.
    dtype = torch.promote_types(candidates.dtype, reference_data.dtype)
    if not dtype.is_floating_point:
        dtype = torch.float32
    candidates = candidates.to(dtype)
    reference_data = reference_data.to(dtype)

    knn_distances = torch.zeros(n_candidates, device=device)

    for start in range(0, n_candidates, chunk_size):
        stop = min(start + chunk_size, n_candidates)
        # cdist(p=1) yields the (chunk, n_reference) L1 matrix directly, without
        # materializing a (chunk, n_reference, n_features) difference tensor.
        distances = torch.cdist(candidates[start:stop], reference_data, p=1)
        nearest, _ = torch.topk(distances, k, dim=1, largest=False)
        knn_distances[start:stop] = nearest.mean(dim=1)
        del distances, nearest

    return knn_distances

class AdvancedProgressTracker:
    """Console progress reporter with a text bar, smoothed speed and ETA.

    ``update(n)`` takes the cumulative number of finished items and redraws a single
    carriage-return line at most once every ``update_interval`` seconds;
    ``finish()`` prints a closing summary.
    """

    def __init__(self, total_items, item_name, update_interval=5.0):
        self.total_items = total_items
        self.item_name = item_name
        self.update_interval = update_interval
        self.start_time = time.time()
        self.last_update_time = self.start_time
        self.completed_items = 0
        self.speed_history = []
        self.max_history = 20  # Keep last 20 speed measurements for smoothing

        # Display startup info
        print("=" * 80)
        print(f"STARTING {item_name.upper()} COMPUTATION")
        print("=" * 80)
        print(f"Total {item_name}: {self.total_items:,}")
        print(f"Started at: {datetime.now().strftime('%H:%M:%S')}")
        print("-" * 80)

    def update(self, completed_items):
        """Record ``completed_items`` (cumulative count) and redraw the progress line if due."""
        self.completed_items = completed_items
        current_time = time.time()

        # Throttle redraws to one per update_interval seconds.
        if current_time - self.last_update_time < self.update_interval:
            return

        elapsed_time = current_time - self.start_time

        # elapsed_time can be 0 with a zero interval and a coarse clock: skip
        # this redraw rather than divide by zero.
        if self.completed_items > 0 and elapsed_time > 0:
            # Each entry is the cumulative average rate; averaging the most
            # recent entries smooths the ETA.
            self.speed_history.append(self.completed_items / elapsed_time)
            if len(self.speed_history) > self.max_history:
                self.speed_history = self.speed_history[-self.max_history:]
            avg_speed = sum(self.speed_history) / len(self.speed_history)

            # A zero total is treated as already complete instead of dividing by zero.
            if self.total_items:
                progress_pct = (self.completed_items / self.total_items) * 100
            else:
                progress_pct = 100.0
            remaining_items = max(self.total_items - self.completed_items, 0)

            if avg_speed > 0:
                eta_seconds = remaining_items / avg_speed
                eta_str = self._format_time(eta_seconds)
                elapsed_str = self._format_time(elapsed_time)

                # Create progress bar (clamped so over-counting cannot overflow it)
                bar_length = 40
                filled_length = min(bar_length, int(bar_length * progress_pct / 100))
                bar = '█' * filled_length + '░' * (bar_length - filled_length)

                completion_time = datetime.now() + timedelta(seconds=eta_seconds)

                progress_line = (
                    f"\r{progress_pct:6.2f}% [{bar}] "
                    f"{self.completed_items:,}/{self.total_items:,} {self.item_name} | "
                    f"Elapsed: {elapsed_str} | ETA: {eta_str} | "
                    f"Speed: {avg_speed:.0f}/sec | "
                    f"Finish: {completion_time.strftime('%H:%M:%S')}"
                )

                print(progress_line, end='', flush=True)
            else:
                print(f"\rInitializing... {self.completed_items:,} {self.item_name} processed", end='', flush=True)

        self.last_update_time = current_time

    def _format_time(self, seconds):
        """Format seconds as '<n>s', '<n.n>m' or '<n.n>h'."""
        if seconds < 60:
            return f"{seconds:.0f}s"
        elif seconds < 3600:
            minutes = seconds / 60
            return f"{minutes:.1f}m"
        else:
            hours = seconds / 3600
            return f"{hours:.1f}h"

    def finish(self):
        """Print completion summary (total time and average items per second)."""
        print()  # New line after progress bar
        total_time = time.time() - self.start_time
        final_speed = self.total_items / total_time if total_time > 0 else 0

        print("=" * 80)
        print(f"{self.item_name.upper()} COMPLETED!")
        print("-" * 80)
        print(f"Total time: {self._format_time(total_time)}")
        print(f"Final speed: {final_speed:.0f} {self.item_name}/second")
        print(f"Completed at: {datetime.now().strftime('%H:%M:%S')}")
        print("=" * 80)

def precompute_all_predictions(X_eval, all_feature_samples, model, device, 
                               num_points, batch_size=512, save_path=None):
    """Predict the model output for every single-feature perturbation of every sample.

    For each row of ``X_eval`` and each feature ``f``, the row is copied
    ``num_points`` times with feature ``f`` replaced by ``all_feature_samples[f]``,
    and all those candidates are scored by ``model`` in batches of ``batch_size``.
    Candidate rows are built one batch at a time directly on ``device``, so peak
    memory per sample is O(batch_size * n_features) instead of materializing all
    n_features * num_points candidate rows at once.

    Args:
        X_eval: 2-D tensor (n_samples, n_features) of factual rows.
        all_feature_samples: sequence of n_features 1-D grids, each of length num_points.
        model: callable returning one value per row (e.g. shape (B, 1)).
        device: device on which candidates are built and the model is run.
        num_points: grid length per feature.
        batch_size: number of candidate rows per model call.
        save_path: if truthy, the result is also written there with ``torch.save``.

    Returns:
        CPU float32 tensor of shape (n_samples, n_features, num_points).

    Raises:
        ValueError: if the grids do not contain n_features * num_points values in total.
    """
    n_samples, n_features = X_eval.shape
    candidates_per_sample = n_features * num_points
    total_predictions = n_samples * candidates_per_sample

    print(f"PRECOMPUTING ALL PREDICTIONS...")
    print(f"Samples: {n_samples:,} | Features: {n_features:,} | Points: {num_points}")
    print(f"Total predictions needed: {total_predictions:,}")
    print(f"Memory required: ~{total_predictions * 4 / 1e9:.1f}GB")
    print()

    # Storage for all predictions: (n_samples, n_features, num_points)
    all_predictions = torch.zeros(n_samples, n_features, num_points, device='cpu')

    # Flattened feature-major layout: candidate j sets feature feature_of_row[j]
    # to grid_values[j]; this ordering matches the final view(n_features, num_points).
    # Cast to X_eval's dtype because indexed assignment requires matching dtypes.
    grid_values = torch.cat([torch.as_tensor(s) for s in all_feature_samples]).to(
        device=device, dtype=X_eval.dtype)
    if grid_values.numel() != candidates_per_sample:
        raise ValueError(
            f"expected {candidates_per_sample} grid values "
            f"({n_features} features x {num_points} points), got {grid_values.numel()}"
        )
    feature_of_row = torch.arange(n_features, device=device).repeat_interleave(num_points)
    row_index = torch.arange(batch_size, device=device)

    sample_tracker = AdvancedProgressTracker(n_samples, "samples", update_interval=3.0)

    with torch.no_grad():
        for sample_idx in range(n_samples):
            sample = X_eval[sample_idx].to(device)
            sample_predictions = torch.empty(candidates_per_sample, device=device)

            for start in range(0, candidates_per_sample, batch_size):
                stop = min(start + batch_size, candidates_per_sample)
                n_rows = stop - start
                # repeat() always copies, so writing into the batch never touches X_eval.
                batch = sample.unsqueeze(0).repeat(n_rows, 1)
                batch[row_index[:n_rows], feature_of_row[start:stop]] = grid_values[start:stop]
                # reshape(-1) handles both (B, 1) and (B,) outputs, including B == 1.
                sample_predictions[start:stop] = model(batch).reshape(-1)
                del batch

            all_predictions[sample_idx] = sample_predictions.view(n_features, num_points).cpu()
            del sample_predictions

            sample_tracker.update(sample_idx + 1)

            # Periodic GPU memory cleanup (no-op when CUDA was never initialized)
            if (sample_idx + 1) % 50 == 0:
                torch.cuda.empty_cache()

    sample_tracker.finish()

    total_time = time.time() - sample_tracker.start_time
    prediction_rate = total_predictions / total_time if total_time > 0 else 0

    print(f"PRECOMPUTATION STATISTICS:")
    print(f"- Total predictions: {total_predictions:,}")
    print(f"- Prediction rate: {prediction_rate:.0f} predictions/second")
    print(f"- Memory efficiency: {total_predictions * 4 / 1e6:.0f} MB stored")

    if save_path:
        print(f"\nSaving precomputed predictions to {save_path}...")
        save_start = time.time()
        torch.save(all_predictions, save_path)
        save_time = time.time() - save_start
        file_size = all_predictions.numel() * 4 / 1e9
        # Guard against a zero duration on very small files / coarse clocks.
        save_rate = file_size / save_time if save_time > 0 else float('inf')
        print(f"Saved {file_size:.1f}GB in {save_time:.1f}s ({save_rate:.1f}GB/s)")

    return all_predictions

def _load_or_precompute_predictions(X_eval_tensor, all_feature_samples, model, device,
                                    num_points, prediction_batch_size, precompute_save_path):
    """Return the (n_samples, n_features, num_points) prediction cube, reusing a cache that fits.

    A cache file whose shape does not match the current data and ``num_points``
    (e.g. left over from a run with other settings) is ignored and overwritten
    instead of being indexed with the wrong layout. A cache of the right shape
    built from different data or grid ranges cannot be detected here.
    """
    expected_shape = (X_eval_tensor.shape[0], X_eval_tensor.shape[1], num_points)
    if precompute_save_path is not None:
        try:
            print("Loading precomputed predictions...")
            cached = torch.load(precompute_save_path, map_location='cpu')
        except FileNotFoundError:
            print("Precomputed predictions not found. Computing them now...")
        else:
            if isinstance(cached, torch.Tensor) and tuple(cached.shape) == expected_shape:
                print(f"Loaded precomputed predictions: {cached.shape}")
                return cached
            found = tuple(cached.shape) if isinstance(cached, torch.Tensor) else type(cached).__name__
            print(f"Cached predictions in {precompute_save_path} have shape {found}, "
                  f"expected {expected_shape}. Recomputing them now...")
    return precompute_all_predictions(
        X_eval_tensor, all_feature_samples, model, device,
        num_points, prediction_batch_size, precompute_save_path
    )


def compute_scores_with_precomputed_predictions(
    tcga_class0_path,
    data_eval_path,
    model_path,
    num_points=50,
    range_percent=10,
    diff_weight=30,
    nn_weight=10,
    l1_weight=10,
    gpu_id=0,
    batch_size=16,
    prediction_batch_size=512,     # Batch size for precomputation
    knn_chunk_size=16,
    precompute_save_path='precomputed_predictions.pt',
    output_cf_path='best_counterfactuals.parquet',
    output_importance_path='feature_importance.parquet'
):
    """Brute-force single-feature counterfactual search using a cached prediction cube.

    For every evaluation row and every feature, the feature is swept over
    ``num_points`` evenly spaced values (observed min/max widened by
    ``range_percent`` % of the range). Each candidate is scored as

        diff_weight * (initial_pred - pred)
        - nn_weight * |mean 10-NN L1 distance to class-0 rows - class0_10nn_mean| / class0_10nn_std
        - l1_weight * |L1 distance to the original row - l1_mean| / l1_std

    and the best-scoring candidate over all features is kept for each row.
    Predictions for all candidates come from ``precompute_save_path`` when that
    file exists with the expected shape; otherwise they are computed and saved
    there (pass None to skip the cache file).

    Both parquet inputs are read as: first column = label, remaining columns =
    features (only ``iloc[:, 1:]`` is used).

    Returns:
        (best_cf_df, importance_df). ``best_cf_df`` has a 'classification' column
        holding the model prediction of the chosen candidate, then 'feature_i'
        columns; ``importance_df`` has, per feature, the prediction drop of that
        feature's best candidate ('feature_i_diff'). Both are also written to parquet.
    """
    start_time = time.time()

    # GPU setup
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    print(f"Using device: {device}")

    # Normalization constants used to z-score the 10-NN and L1 distances below.
    # They are hard-coded for one dataset (presumably computed offline — confirm
    # before reusing this function on other data).
    class0_10nn_mean = 5950.676427
    class0_10nn_std = np.sqrt(3400381.504375)
    l1_mean = 6471.3016
    l1_std = 1596.4753

    # Load data
    print("Loading data...")
    df_class0 = pd.read_parquet(tcga_class0_path)
    df_eval = pd.read_parquet(data_eval_path)

    X_class0 = df_class0.iloc[:, 1:].values
    X_eval = df_eval.iloc[:, 1:].values

    num_features = X_eval.shape[1]
    n_samples = X_eval.shape[0]
    print(f"Data: {n_samples} samples, {num_features} features")

    # Handle dimension mismatch: zero-pad or truncate class-0 columns to match X_eval.
    if X_class0.shape[1] != num_features:
        if X_class0.shape[1] < num_features:
            padding = np.zeros((X_class0.shape[0], num_features - X_class0.shape[1]))
            X_class0 = np.hstack((X_class0, padding))
        else:
            X_class0 = X_class0[:, :num_features]

    # Class-0 reference rows live on the compute device for the k-NN search.
    X_class0_tensor = torch.tensor(X_class0, dtype=torch.float32).to(device)
    X_eval_tensor = torch.tensor(X_eval, dtype=torch.float32)

    # Precompute feature ranges
    print("Computing feature ranges...")
    feature_min = X_eval.min(axis=0)
    feature_max = X_eval.max(axis=0)
    feature_range = (feature_max - feature_min) * (range_percent / 100)

    # Precompute sampling points
    all_feature_samples = []
    for i in range(num_features):
        min_val = feature_min[i] - feature_range[i]
        max_val = feature_max[i] + feature_range[i]
        samples = torch.linspace(min_val, max_val, num_points, dtype=torch.float32)
        all_feature_samples.append(samples)

    model = load_model(model_path, num_features, device)

    # Predictions for the unmodified rows
    print("Computing initial predictions...")
    initial_predictions = torch.zeros(n_samples, device='cpu')
    for i in range(0, n_samples, prediction_batch_size):
        end_i = min(i + prediction_batch_size, n_samples)
        with torch.no_grad():
            batch_pred = model(X_eval_tensor[i:end_i].to(device)).squeeze().cpu()
            if batch_pred.dim() == 0:
                batch_pred = batch_pred.unsqueeze(0)
            initial_predictions[i:end_i] = batch_pred

    # Prediction cube: cached on disk when possible, validated against the current shape.
    precompute_start = time.time()
    all_predictions = _load_or_precompute_predictions(
        X_eval_tensor, all_feature_samples, model, device,
        num_points, prediction_batch_size, precompute_save_path
    )
    precompute_time = time.time() - precompute_start

    # Scoring only needs lookups into the prediction cube plus k-NN / L1 distances.
    print("\n🚀 Starting ULTRA-FAST scoring with precomputed predictions...")

    best_counterfactuals = []
    feature_importance_data = []

    eval_dataset = TensorDataset(X_eval_tensor)
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

    total_operations = n_samples * num_features
    scoring_tracker = AdvancedProgressTracker(total_operations, "feature evaluations", update_interval=2.0)

    completed_operations = 0

    for batch_idx, (X_batch,) in enumerate(eval_dataloader):
        batch_size_actual = X_batch.size(0)
        batch_start_idx = batch_idx * batch_size

        for i in range(batch_size_actual):
            sample_idx = batch_start_idx + i
            if sample_idx >= n_samples:
                break

            eval_sample = X_batch[i:i+1]
            initial_pred = initial_predictions[sample_idx].item()

            best_score = float('-inf')
            best_cf = eval_sample.clone()
            best_cf_pred = initial_pred
            sample_importance = []

            for feature_idx in range(num_features):
                # Precomputed predictions for this feature's grid: (num_points,)
                feature_predictions = all_predictions[sample_idx, feature_idx]

                # Candidates for k-NN and L1 computation (repeat() copies eval_sample)
                feature_samples = all_feature_samples[feature_idx]
                candidates = eval_sample.repeat(num_points, 1)
                candidates[:, feature_idx] = feature_samples

                candidates_gpu = candidates.to(device)

                nn_distances = efficient_knn_distances(
                    candidates_gpu, X_class0_tensor, device, 
                    chunk_size=knn_chunk_size, k=10
                )

                # L1 distances to original sample
                l1_distances = torch.sum(torch.abs(candidates - eval_sample), dim=1)

                pred_diffs = initial_pred - feature_predictions
                nn_scores = torch.abs(nn_distances.cpu() - class0_10nn_mean) / class0_10nn_std
                l1_scores = torch.abs(l1_distances - l1_mean) / l1_std

                total_scores = (pred_diffs * diff_weight) - (nn_scores * nn_weight) - (l1_scores * l1_weight)

                # Best candidate for this feature
                best_idx = torch.argmax(total_scores)
                feature_best_score = total_scores[best_idx]
                feature_best_pred = feature_predictions[best_idx]
                feature_best_candidate = candidates[best_idx]

                sample_importance.append((initial_pred - feature_best_pred).item())

                # Best over all features for this sample
                if feature_best_score > best_score:
                    best_score = feature_best_score
                    best_cf = feature_best_candidate.unsqueeze(0)
                    best_cf_pred = feature_best_pred.item()

                del candidates_gpu, nn_distances

                completed_operations += 1
                scoring_tracker.update(completed_operations)

            cf_row = [best_cf_pred] + best_cf[0].numpy().tolist()
            best_counterfactuals.append(cf_row)
            feature_importance_data.append(sample_importance)

    scoring_tracker.finish()

    print("Creating output DataFrames...")
    cf_columns = ['classification'] + [f'feature_{i}' for i in range(num_features)]
    importance_columns = [f'feature_{i}_diff' for i in range(num_features)]

    best_cf_df = pd.DataFrame(best_counterfactuals, columns=cf_columns)
    importance_df = pd.DataFrame(feature_importance_data, columns=importance_columns)

    best_cf_df.to_parquet(output_cf_path, engine='pyarrow', compression='snappy', index=False)
    importance_df.to_parquet(output_importance_path, engine='pyarrow', compression='snappy', index=False)

    print(f"Saved: {output_cf_path} and {output_importance_path}")

    # precompute_time covers only the cache load / precomputation step;
    # "Other" is the remaining setup (data loading, ranges, model, initial predictions).
    total_time = time.time() - start_time
    scoring_time = time.time() - scoring_tracker.start_time

    print(f"\n🏁 FINAL TIMING SUMMARY:")
    print("=" * 60)
    print(f"TOTAL EXECUTION TIME: {total_time/60:.1f} minutes ({total_time/3600:.1f} hours)")
    print(f"├─ Precomputation: {precompute_time/60:.1f} minutes ({precompute_time/total_time*100:.1f}%)")
    print(f"├─ Scoring: {scoring_time/60:.1f} minutes ({scoring_time/total_time*100:.1f}%)")
    print(f"└─ Other: {(total_time-precompute_time-scoring_time)/60:.1f} minutes")
    print("=" * 60)

    torch.cuda.empty_cache()

    return best_cf_df, importance_df




# Brute-force one-feature counterfactual search over the evaluation set.
# range_percent=0 keeps each feature's grid within its observed min/max.
bf1_cfs, feature_importances = compute_scores_with_precomputed_predictions(
    tcga_class0_path='tcga_class0.parquet',
    data_eval_path='data_eval_norm.parquet',
    model_path='mlp_model.pth',
    gpu_id=0,
    # search grid
    num_points=25,
    range_percent=0,
    # score weights
    diff_weight=30,
    nn_weight=10,
    l1_weight=10,
    # batching
    batch_size=16,
    prediction_batch_size=2048,
    knn_chunk_size=64,
)


Using device: cuda:0
Loading data...


KeyboardInterrupt: 

In [None]:
"""BF 2 features CF generation
Input : Same as DICE method
Output : Same format as input, without the original factual index at the end"""

import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import time
from datetime import datetime, timedelta
import gc
import json
from torch.cuda.amp import autocast, GradScaler

class MLP(nn.Module):
    """Binary classifier: three hidden stages followed by a single sigmoid output.

    Each hidden stage is Linear -> BatchNorm1d -> LeakyReLU(0.01) -> Dropout(0.02),
    with widths 2000, 200 and 20. Submodules are registered as fc{i}/bn{i}/dropout{i}
    (i = 1..3), then fc4, sigmoid and leaky_relu; these attribute names are the
    state_dict keys of saved checkpoints and must stay stable.
    """

    def __init__(self, input_size):
        super().__init__()
        fan_ins = (input_size, 2000, 200)
        fan_outs = (2000, 200, 20)
        # Modules are created stage by stage (fc, bn, dropout) so parameter
        # initialization consumes the RNG in a fixed order.
        for stage, (n_in, n_out) in enumerate(zip(fan_ins, fan_outs), start=1):
            setattr(self, f'fc{stage}', nn.Linear(n_in, n_out))
            setattr(self, f'bn{stage}', nn.BatchNorm1d(n_out))
            setattr(self, f'dropout{stage}', nn.Dropout(0.02))
        self.fc4 = nn.Linear(20, 1)
        self.sigmoid = nn.Sigmoid()
        self.leaky_relu = nn.LeakyReLU(0.01)

    def forward(self, x):
        """Map a (batch, input_size) tensor to (batch, 1) probabilities."""
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, self.dropout3),
        )
        for linear, norm, drop in stages:
            x = drop(self.leaky_relu(norm(linear(x))))
        return self.sigmoid(self.fc4(x))

class PyTorchOCSVM(nn.Module):
    """Scoring network of the one-class model: MLP input_dim -> 512 -> 128 -> 32 -> 1.

    The first two hidden layers are followed by ReLU and Dropout(0.2), the third by
    ReLU only. Layers live in ``feature_net`` (an nn.Sequential), whose indices
    0..8 are the checkpoint's state_dict keys. ``nu`` is stored but not used here.
    """

    def __init__(self, input_dim, nu=0.1, device='cuda:0'):
        super().__init__()
        self.nu = nu
        self.device = device
        self.input_dim = input_dim

        widths = (input_dim, 512, 128, 32)
        layers = []
        for position, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            layers += [nn.Linear(n_in, n_out), nn.ReLU()]
            if position < 2:
                layers.append(nn.Dropout(0.2))
        layers.append(nn.Linear(32, 1))
        self.feature_net = nn.Sequential(*layers).to(device)

    def forward(self, x):
        """Return the (batch, 1) network output for a (batch, input_dim) tensor."""
        return self.feature_net(x)

class OptimizedGPUOCSVM:
    """Inference wrapper for a saved PyTorchOCSVM checkpoint.

    ``decision_function`` standardizes inputs with the checkpoint's scaler statistics
    (kept in float32), runs the TorchScript-traced network in batches of
    ``batch_size`` rows (fp16 autocast on CUDA devices only), and returns the
    negative squared distance of each output to ``center`` — larger means closer.
    """

    def __init__(self):
        self.model = None
        self.scaler_mean = None
        self.scaler_scale = None
        self.center = None
        self.rho = 0.0
        self.device = None
        self.batch_size = 1024

    def _autocast_enabled(self):
        """fp16 autocast is only used when the model lives on a CUDA device."""
        return self.device is not None and torch.device(self.device).type == 'cuda'

    def load_model(self, filepath, device='cuda:0'):
        """Load a checkpoint, trace the network and choose a batch size.

        Args:
            filepath: checkpoint containing 'nu', 'rho', 'center', 'scaler',
                'input_dim' and 'model_state_dict'.
            device: where the model, scaler statistics and center are placed
                (defaults to 'cuda:0').
        """
        # The checkpoint pickles a fitted scaler object (read via .mean_/.scale_),
        # which torch.load refuses under its weights_only=True default (PyTorch >= 2.6).
        # weights_only=False executes arbitrary pickle code: only load trusted files.
        checkpoint = torch.load(filepath, map_location=device, weights_only=False)

        self.device = device
        self.nu = checkpoint['nu']
        # NOTE(review): rho is stored but decision_function does not use it — confirm intended.
        self.rho = checkpoint['rho']
        self.center = checkpoint['center'].to(device) if checkpoint['center'] is not None else None

        # Keep scaler parameters in float32 for numerical stability
        scaler = checkpoint['scaler']
        self.scaler_mean = torch.tensor(scaler.mean_, dtype=torch.float32, device=device)
        self.scaler_scale = torch.tensor(scaler.scale_, dtype=torch.float32, device=device)

        input_dim = checkpoint['input_dim']
        self.model = PyTorchOCSVM(input_dim, self.nu, self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

        # Trace to TorchScript for faster repeated inference
        dummy_input = torch.randn(1, input_dim, device=self.device)
        self.model = torch.jit.trace(self.model, dummy_input)

        self.batch_size = self._find_optimal_batch_size(input_dim)

        print(f"OSVM model loaded, JIT compiled, optimal batch size: {self.batch_size}")

    def _find_optimal_batch_size(self, n_features):
        """Return the largest candidate batch size whose forward pass fits in GPU memory."""
        for batch_size in [2048, 1536, 1024, 768, 512, 256]:
            try:
                dummy = torch.randn(batch_size, n_features, device=self.device)
                with torch.no_grad():
                    batch_scaled = (dummy - self.scaler_mean) / self.scaler_scale
                    with autocast(enabled=self._autocast_enabled()):
                        _ = self.model(batch_scaled)
                del dummy, batch_scaled
                torch.cuda.empty_cache()
                return batch_size
            except torch.cuda.OutOfMemoryError:
                # Release what the failed attempt cached before trying a smaller size.
                torch.cuda.empty_cache()
                continue

        return 128  # Conservative fallback

    def decision_function(self, X_tensor):
        """Return a float32 vector of -||f(scaled x) - center||^2 for each row of ``X_tensor``.

        ``X_tensor`` is expected on ``self.device`` with the checkpoint's feature count.
        """
        n_rows = len(X_tensor)
        decisions = torch.empty(n_rows, device=self.device, dtype=torch.float32)
        use_autocast = self._autocast_enabled()

        with torch.no_grad():
            for i in range(0, n_rows, self.batch_size):
                end_idx = min(i + self.batch_size, n_rows)

                # Keep scaling in float32 for numerical stability
                batch_scaled = (X_tensor[i:end_idx] - self.scaler_mean) / self.scaler_scale

                # Mixed precision for the forward pass only
                with autocast(enabled=use_autocast):
                    outputs = self.model(batch_scaled)

                # Distance in float32 for precision.
                # NOTE(review): requires self.center; a checkpoint with center=None fails here.
                distances = torch.sum((outputs.float() - self.center) ** 2, dim=1)
                decisions[i:end_idx] = -distances

        return decisions

def load_model(model_path, input_size, device):
    """Build the MLP on ``device``, load weights from ``model_path``, switch to
    eval mode and return a TorchScript trace of it (traced on a (1, input_size)
    random example)."""
    network = MLP(input_size).to(device)
    state = torch.load(model_path, map_location=device)
    network.load_state_dict(state)
    network.eval()

    example = torch.randn(1, input_size, device=device)
    traced = torch.jit.trace(network, example)
    print("MLP model loaded and JIT compiled")
    return traced

def compute_feature_ranges(X_eval, range_percent):
    """Compute feature ranges using GPU quantiles"""
    if range_percent < 0:
        lower_percentile = abs(range_percent)
        upper_percentile = 100 - abs(range_percent)
        min_vals = torch.quantile(X_eval, lower_percentile / 100.0, dim=0)
        max_vals = torch.quantile(X_eval, upper_percentile / 100.0, dim=0)
        print(f"Using percentile range: {lower_percentile}% to {upper_percentile}%")
    else:
        feature_min = X_eval.min(dim=0)[0]
        feature_max = X_eval.max(dim=0)[0]
        feature_range = (feature_max - feature_min) * (range_percent / 100)
        min_vals = feature_min - feature_range
        max_vals = feature_max + feature_range
        print(f"Using expanded range: ¬±{range_percent}% beyond observed min/max")
    
    return min_vals, max_vals

def warmup_models(mlp_model, osvm_model, n_features, device):
    """Run one throw-away forward pass through each model before real work starts.

    A random (64, n_features) batch on ``device`` goes through ``mlp_model`` under
    autocast and through ``osvm_model.decision_function``; outputs are discarded
    and the CUDA cache is emptied afterwards.
    """
    print("Warming up models...")

    with torch.no_grad():
        probe = torch.randn(64, n_features, device=device)
        with autocast():
            mlp_model(probe)
        osvm_model.decision_function(probe)

        # Drop the reference before releasing cached blocks.
        probe = None
        torch.cuda.empty_cache()

    print("Model warmup completed")

def process_candidates_chunked(sample, first_feat_idx, second_feat_global_idx, 
                              first_samples, second_samples, mlp_model, osvm_model,
                              prediction_weight, distance_weight, osvm_weight,
                              L1_NORM_MEAN, L1_NORM_STD, use_mixed_precision, 
                              chunk_size=1024):
    """Grid-search two features of ``sample`` and return the highest-scoring candidate.

    Every pair (a, b) of ``first_samples`` x ``second_samples`` is written into a
    copy of ``sample`` (a 1-D feature vector) at ``first_feat_idx`` /
    ``second_feat_global_idx``. Candidates are scored ``chunk_size`` at a time to
    bound memory, and the arg-max over all chunks is kept (ties resolve to the
    earliest candidate in row-major 'ij' grid order). Each candidate's score is

        -pred * prediction_weight
        - |L1(candidate, sample) - L1_NORM_MEAN| / L1_NORM_STD * distance_weight
        + (1 - osvm_decision) * osvm_weight

    ``sample`` itself is never modified.

    Returns:
        (best_score, best_pred, best_candidate): Python floats and a new tensor.
        For an empty grid this is (-inf, 0.0, sample.clone()).
    """
    grid_first, grid_second = torch.meshgrid(first_samples, second_samples, indexing='ij')
    # Flatten once instead of re-flattening both grids on every chunk.
    flat_first = grid_first.reshape(-1)
    flat_second = grid_second.reshape(-1)
    total_candidates = flat_first.numel()

    best_score = float('-inf')  # maximized (arg-max)
    best_pred = 0.0
    best_candidate = sample.clone()

    for chunk_start in range(0, total_candidates, chunk_size):
        chunk_end = min(chunk_start + chunk_size, total_candidates)
        chunk_size_actual = chunk_end - chunk_start

        # repeat() always allocates new storage. expand(...).contiguous() returned a
        # view of `sample` when chunk_size_actual == 1, so the writes below would
        # have mutated the caller's sample.
        candidates_chunk = sample.unsqueeze(0).repeat(chunk_size_actual, 1)
        candidates_chunk[:, first_feat_idx] = flat_first[chunk_start:chunk_end]
        candidates_chunk[:, second_feat_global_idx] = flat_second[chunk_start:chunk_end]

        with torch.no_grad():
            if use_mixed_precision:
                with autocast():
                    chunk_preds = mlp_model(candidates_chunk).squeeze()
            else:
                chunk_preds = mlp_model(candidates_chunk).squeeze()

            if chunk_preds.dim() == 0:
                chunk_preds = chunk_preds.unsqueeze(0)

            chunk_osvm = osvm_model.decision_function(candidates_chunk)
            chunk_l1 = torch.sum(torch.abs(candidates_chunk - sample), dim=1)

            prediction_component = -chunk_preds * prediction_weight
            distance_component = -(torch.abs(chunk_l1 - L1_NORM_MEAN) / L1_NORM_STD) * distance_weight
            # NOTE(review): if decision_function returns -distance (higher = more
            # inlier-like), (1 - decision) grows for outliers, so a positive
            # osvm_weight rewards outliers — confirm the intended sign.
            osvm_component = (1 - chunk_osvm) * osvm_weight

            chunk_scores = prediction_component + distance_component + osvm_component

            chunk_best_idx = torch.argmax(chunk_scores)
            chunk_best_score = chunk_scores[chunk_best_idx].item()

            # Strict '>' keeps the earliest candidate on ties across chunks.
            if chunk_best_score > best_score:
                best_score = chunk_best_score
                best_pred = chunk_preds[chunk_best_idx].item()
                best_candidate = candidates_chunk[chunk_best_idx].clone()

        del candidates_chunk, chunk_preds, chunk_osvm, chunk_l1, chunk_scores

    return best_score, best_pred, best_candidate

class ProgressTracker:
    """Single-line progress reporter for the two-feature counterfactual search.

    Work is counted in candidate evaluations. For every sample, each primary
    feature is paired with each secondary feature and every pair is scored on
    a ``num_points_first x num_points_second`` grid, so::

        total_evaluations = total_samples * total_features
                            * n_second_features * grid_size

    The status line is rewritten in place (carriage return, no newline). It is
    refreshed at most every ``update_interval`` seconds, and always on the
    first and last sample or when forced.

    NOTE(review): ``datetime`` and ``timedelta`` are used as the classes from
    the ``datetime`` module; their import is not visible in this chunk.
    """

    def __init__(self, total_samples, total_features, n_second_features,
                 num_points_first, num_points_second, update_interval=20):
        """
        Args:
            total_samples: number of factual samples to process.
            total_features: number of primary features searched per sample.
            n_second_features: number of secondary features paired with each
                primary feature.
            num_points_first: grid points along the primary feature.
            num_points_second: grid points along the secondary feature.
            update_interval: minimum number of seconds between two display
                refreshes (default 20).
        """
        self.total_samples = total_samples
        self.total_features = total_features
        self.n_second_features = n_second_features
        self.num_points_first = num_points_first
        self.num_points_second = num_points_second

        # Work accounting
        self.grid_size = num_points_first * num_points_second
        self.evaluations_per_feature = n_second_features * self.grid_size
        self.evaluations_per_sample = total_features * self.evaluations_per_feature
        self.total_evaluations = total_samples * self.evaluations_per_sample

        # Timing
        self.start_time = time.time()
        self.last_update_time = self.start_time
        self.update_interval = update_interval

        # Progress counters
        self.completed_samples = 0
        self.completed_evaluations = 0

        # Rolling window of recent throughput measurements (evaluations/s)
        self.speed_history = []
        self.max_speed_history = 10

        self.print_startup_summary()

    def print_startup_summary(self):
        """Print the run configuration and the total amount of work."""
        print(f"\n{'='*85}")
        print("COUNTERFACTUAL GENERATION STARTED")
        print(f"{'='*85}")
        print("Dataset Configuration:")
        print(f"   • Total samples: {self.total_samples:,}")
        print(f"   • Primary features: {self.total_features:,}")
        print(f"   • Secondary features: {self.n_second_features:,}")
        print("Grid Configuration:")
        print(f"   • Primary grid points: {self.num_points_first}")
        print(f"   • Secondary grid points: {self.num_points_second}")
        print(f"   • Grid size per pair: {self.grid_size:,}")
        print("Computation Scale:")
        print(f"   • Total evaluations: {self.total_evaluations:,}")
        # Computed directly so an empty dataset does not divide by zero.
        print(f"   • Evaluations per sample: {self.evaluations_per_sample:,}")
        print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"{'='*85}\n")

    def update(self, sample_idx, feature_idx=None, force_update=False):
        """Record progress and refresh the display when due.

        Args:
            sample_idx: 0-based index of the sample currently being processed.
            feature_idx: 0-based index of the primary feature that has just
                been finished for ``sample_idx``. ``None`` means the whole
                sample is finished.
            force_update: refresh the display regardless of the interval.
        """
        current_time = time.time()

        self.completed_samples = sample_idx + 1
        finished_before = sample_idx * self.evaluations_per_sample
        if feature_idx is None:
            self.completed_evaluations = finished_before + self.evaluations_per_sample
        else:
            # Features 0..feature_idx (inclusive) of this sample are done.
            self.completed_evaluations = (
                finished_before + (feature_idx + 1) * self.evaluations_per_feature
            )

        should_update = (
            force_update
            or current_time - self.last_update_time >= self.update_interval
            or sample_idx == 0
            or sample_idx == self.total_samples - 1
        )

        if should_update:
            self.last_update_time = current_time
            self._display_progress(current_time)

    def _display_progress(self, current_time):
        """Overwrite the status line with percentage, timing, speed and ETA."""
        elapsed_time = current_time - self.start_time

        progress_pct = (
            self.completed_evaluations / self.total_evaluations * 100
            if self.total_evaluations > 0 else 0
        )
        sample_progress_pct = (
            self.completed_samples / self.total_samples * 100
            if self.total_samples > 0 else 0
        )

        if elapsed_time > 0:
            # Cumulative throughput since start, averaged over the last
            # `max_speed_history` refreshes to smooth the ETA.
            self.speed_history.append(self.completed_evaluations / elapsed_time)
            if len(self.speed_history) > self.max_speed_history:
                self.speed_history.pop(0)
            avg_speed = sum(self.speed_history) / len(self.speed_history)
        else:
            avg_speed = 0

        if avg_speed > 0:
            remaining_evaluations = max(self.total_evaluations - self.completed_evaluations, 0)
            eta_seconds = remaining_evaluations / avg_speed
        else:
            eta_seconds = 0
        finish_time = datetime.now() + timedelta(seconds=eta_seconds)

        elapsed_str = self._format_duration(elapsed_time)
        eta_str = self._format_duration(eta_seconds)

        bar_width = 40
        filled_width = min(max(int(bar_width * progress_pct / 100), 0), bar_width)
        bar = '█' * filled_width + '░' * (bar_width - filled_width)

        print(f"\r{progress_pct:6.2f}% [{bar}] "
              f"Sample: {self.completed_samples:4d}/{self.total_samples} ({sample_progress_pct:5.1f}%) | "
              f"Elapsed: {elapsed_str} | ETA: {eta_str} | "
              f"Speed: {avg_speed:8,.0f} eval/s | "
              f"Finish: {finish_time.strftime('%m-%d %H:%M')}", end='', flush=True)

    def _format_duration(self, seconds):
        """Format a duration in seconds as s / m / h / d with one unit."""
        if seconds < 60:
            return f"{seconds:.0f}s"
        if seconds < 3600:
            return f"{seconds/60:.1f}m"
        if seconds < 86400:
            return f"{seconds/3600:.1f}h"
        return f"{seconds/86400:.1f}d"

    def finish(self):
        """Print the final runtime and throughput summary."""
        total_time = time.time() - self.start_time
        avg_speed = self.completed_evaluations / total_time if total_time > 0 else 0

        print(f"\n\n{'='*85}")
        print("COUNTERFACTUAL GENERATION COMPLETED!")
        print(f"{'='*85}")
        print("Final Statistics:")
        print(f"   • Total runtime: {self._format_duration(total_time)}")
        print(f"   • Evaluations completed: {self.completed_evaluations:,}")
        print(f"   • Samples processed: {self.completed_samples:,}")
        print(f"   • Average speed: {avg_speed:,.0f} evaluations/second")
        print(f"Finished at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"{'='*85}\n")

def _load_second_feature_indices(json_path, n_features):
    """Read secondary feature indices from a JSON list.

    Keeps only in-range integers (booleans excluded), de-duplicated in file
    order. Returns the list of ints, or None after printing an error when the
    file is missing/unparseable or yields no valid index.
    """
    try:
        with open(json_path, 'r') as f:
            raw_indices = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        print(f"Error loading JSON: {e}")
        return None

    if not isinstance(raw_indices, list):
        raw_indices = []

    # bool is a subclass of int, so it must be excluded explicitly.
    valid = [idx for idx in raw_indices
             if isinstance(idx, int) and not isinstance(idx, bool) and 0 <= idx < n_features]
    # A duplicated index would only repeat identical work.
    valid = list(dict.fromkeys(valid))
    if not valid:
        print("Error: No valid second feature indices found")
        return None
    return valid


def _predict_in_batches(mlp_model, X, batch_size, use_mixed_precision):
    """Return MLP outputs for every row of X as a 1-D float32 tensor on X's device."""
    n_rows = X.shape[0]
    predictions = torch.zeros(n_rows, device=X.device, dtype=torch.float32)
    with torch.no_grad():
        for start in range(0, n_rows, batch_size):
            end = min(start + batch_size, n_rows)
            if use_mixed_precision:
                with autocast():
                    batch_pred = mlp_model(X[start:end])
            else:
                batch_pred = mlp_model(X[start:end])
            # (batch, 1) -> (batch,), also correct for a batch of one row.
            predictions[start:end] = batch_pred.reshape(-1)
    return predictions


def _linspace_grid(min_vals, max_vals, feature_indices, num_points, device):
    """Stack linspace(min_vals[i], max_vals[i], num_points) for each i as fp16 rows."""
    grid = torch.empty(len(feature_indices), num_points, device=device, dtype=torch.float16)
    for row, idx in enumerate(feature_indices):
        grid[row] = torch.linspace(min_vals[idx], max_vals[idx], num_points,
                                   device=device, dtype=torch.float16)
    return grid


def generate_2feature_counterfactuals_osvm_optimized(
    data_eval_path,
    model_path,
    osvm_model_path,
    second_feature_json_path,
    num_points_first=25,
    num_points_second=15,
    range_percent=-10,
    classification_weight=10.0,
    distance_weight=10.0,
    osvm_weight=10.0,
    gpu_id=0,
    use_mixed_precision=True,
    prediction_batch_size=2048,
    use_chunked_processing=True,
    chunk_size=1024,
    output_cf_path='best_counterfactuals_2feat_osvm.parquet',
    output_importance_path='feature_importance_2feat_osvm.parquet',
    l1_norm_mean=6471.3016,
    l1_norm_std=1596.4753,
):
    """Grid-search two-feature counterfactuals scored by MLP, L1 distance and OSVM.

    For every sample, every feature (primary) is paired with every index from
    ``second_feature_json_path`` (secondary). Each pair is varied on a
    ``num_points_first x num_points_second`` grid between the ranges from
    ``compute_feature_ranges``. Every grid is scored by
    ``process_candidates_chunked``, and its highest-scoring candidate is kept.
    Per the scoring code in that helper, higher is better; it combines
    ``-prediction * classification_weight``,
    ``-|L1 - l1_norm_mean| / l1_norm_std * distance_weight`` and an OSVM term
    weighted by ``osvm_weight``.

    Precision: grid points are stored as fp16. They are cast to fp32 before
    building candidates. MLP inference runs under autocast when
    ``use_mixed_precision`` is set.

    ``use_chunked_processing`` only controls memory use. When False, the whole
    grid is scored as a single chunk, with the same scoring as the chunked path.

    Args (beyond the paths):
        l1_norm_mean, l1_norm_std: normalisation of the L1 distance term.
            The defaults are the previously hard-coded values.

    Returns:
        ``(best_cf_df, importance_df)``:
            ``best_cf_df`` has one row per sample: ``classification`` followed
                by ``feature_i`` values.
            ``importance_df`` has one row per sample with the prediction drop
                (``initial - best``) per primary feature.
        ``(None, None)`` if no valid secondary feature indices could be loaded.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{gpu_id}" if use_cuda else "cpu")
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        print(f"Using device: {device} with RTX 2080 Ti optimizations")

    # Load data (first column is the class label, the rest are features)
    print("Loading evaluation data...")
    df_eval = pd.read_parquet(data_eval_path)
    X_eval = df_eval.iloc[:, 1:].values
    X_eval_tensor = torch.tensor(X_eval, dtype=torch.float32, device=device)

    n_samples, n_features = X_eval_tensor.shape
    print(f"Data: {n_samples} samples, {n_features} features")

    # Load models
    print("Loading models...")
    mlp_model = load_model(model_path, n_features, device)
    osvm_model = OptimizedGPUOCSVM()
    osvm_model.load_model(osvm_model_path)

    warmup_models(mlp_model, osvm_model, n_features, device)

    # Plain Python ints: iterating a device tensor would yield device scalars.
    print("Loading second feature indices...")
    second_feature_list = _load_second_feature_indices(second_feature_json_path, n_features)
    if second_feature_list is None:
        return None, None

    n_second_features = len(second_feature_list)
    print(f"Loaded {n_second_features} second feature indices")

    print("Computing feature ranges...")
    min_vals, max_vals = compute_feature_ranges(X_eval_tensor, range_percent)

    print("Pre-allocating sampling points...")
    first_feature_samples = _linspace_grid(min_vals, max_vals, range(n_features),
                                           num_points_first, device)
    second_feature_samples = _linspace_grid(min_vals, max_vals, second_feature_list,
                                            num_points_second, device)

    print("Computing initial predictions...")
    initial_predictions = _predict_in_batches(mlp_model, X_eval_tensor,
                                              prediction_batch_size, use_mixed_precision)

    progress_tracker = ProgressTracker(
        total_samples=n_samples,
        total_features=n_features,
        n_second_features=n_second_features,
        num_points_first=num_points_first,
        num_points_second=num_points_second
    )

    print("Starting counterfactual generation...\n")

    # Without chunking, score the whole grid at once through the same scorer,
    # so both modes use the same formula.
    grid_size = num_points_first * num_points_second
    effective_chunk_size = chunk_size if use_chunked_processing else grid_size

    # NOTE(review): the OSVM term inside process_candidates_chunked is
    # (1 - decision_function) * weight, which under argmax favours LOW
    # decision values; confirm this matches OptimizedGPUOCSVM's sign convention.

    best_counterfactuals = []
    feature_importance_data = []

    for sample_idx in range(n_samples):
        sample = X_eval_tensor[sample_idx]  # fp32
        initial_pred = initial_predictions[sample_idx].item()

        # Scores are maximised, so every search starts from -inf.
        best_sample_score = float('-inf')
        best_sample_classification = initial_pred
        best_sample_candidate = sample.clone()
        sample_feature_importance = torch.zeros(n_features, device=device, dtype=torch.float32)

        for first_feat_idx in range(n_features):
            first_samples = first_feature_samples[first_feat_idx].float()

            best_feature_score = float('-inf')
            best_feature_pred = initial_pred
            best_feature_candidate = sample

            for second_row, second_feat_global_idx in enumerate(second_feature_list):
                second_samples = second_feature_samples[second_row].float()

                score, pred, candidate = process_candidates_chunked(
                    sample, first_feat_idx, second_feat_global_idx,
                    first_samples, second_samples, mlp_model, osvm_model,
                    classification_weight, distance_weight, osvm_weight,
                    l1_norm_mean, l1_norm_std, use_mixed_precision,
                    effective_chunk_size
                )

                if candidate is not None and score > best_feature_score:
                    best_feature_score = score
                    best_feature_pred = pred
                    best_feature_candidate = candidate

            # Importance = prediction drop achievable by varying this feature.
            sample_feature_importance[first_feat_idx] = initial_pred - best_feature_pred

            if best_feature_score > best_sample_score:
                best_sample_score = best_feature_score
                best_sample_classification = best_feature_pred
                best_sample_candidate = best_feature_candidate

            progress_tracker.update(sample_idx, first_feat_idx)

            if use_cuda and (first_feat_idx + 1) % 200 == 0:
                torch.cuda.empty_cache()

        cf_row = [best_sample_classification] + best_sample_candidate.cpu().tolist()
        best_counterfactuals.append(cf_row)
        feature_importance_data.append(sample_feature_importance.cpu().tolist())

        progress_tracker.update(sample_idx)

        if (sample_idx + 1) % 20 == 0:
            gc.collect()
            if use_cuda:
                torch.cuda.empty_cache()

    progress_tracker.finish()

    print("Saving results...")
    cf_columns = ['classification'] + [f'feature_{i}' for i in range(n_features)]
    importance_columns = [f'feature_{i}_diff' for i in range(n_features)]

    best_cf_df = pd.DataFrame(best_counterfactuals, columns=cf_columns)
    importance_df = pd.DataFrame(feature_importance_data, columns=importance_columns)

    best_cf_df.to_parquet(output_cf_path, engine='pyarrow', compression='snappy', index=False)
    importance_df.to_parquet(output_importance_path, engine='pyarrow', compression='snappy', index=False)

    print(f"Saved counterfactuals: {output_cf_path}")
    print(f"Saved feature importance: {output_importance_path}")

    if use_cuda:
        torch.cuda.empty_cache()
    gc.collect()

    return best_cf_df, importance_df



# Run the two-feature counterfactual search with this experiment's settings.
best_cfs2dim, feature_importance2dim = generate_2feature_counterfactuals_osvm_optimized(
    data_eval_path='data_eval_norm100.parquet',
    model_path='mlp_model.pth',
    osvm_model_path='OSVM_TCGA_class0_20kfeatures_gpu2.pkl',
    second_feature_json_path='top_2_percent_feature_indexes.json',
    num_points_first=4,
    num_points_second=6,
    range_percent=-15,
    classification_weight=20.0,
    distance_weight=5.0,
    osvm_weight=10.0,
    gpu_id=0,
    use_mixed_precision=True,
    prediction_batch_size=2048,
    use_chunked_processing=True,  # Set to True for maximum memory efficiency
    chunk_size=1024
)

# The function returns (None, None) when the secondary feature indices fail to load.
if best_cfs2dim is not None:
    print(f"\nSuccess! Generated {len(best_cfs2dim)} counterfactuals")
    print(f"Feature importance shape: {feature_importance2dim.shape}")
else:
    print("Generation failed - check error messages above")