In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import time
import pandas as pd
import os
import json
from datetime import datetime
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, mean_absolute_error, mean_squared_error, r2_score
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from torch.multiprocessing import Process, Queue, set_start_method, Manager
from sklearn.decomposition import PCA, NMF, FastICA, KernelPCA, SparsePCA, FactorAnalysis
from sklearn.feature_selection import mutual_info_regression
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import ElasticNet
from sklearn.cluster import KMeans
from sklearn.decomposition import DictionaryLearning
from sklearn.model_selection import train_test_split, KFold
from sklearn.neighbors import NearestNeighbors
try:
    import umap
    UMAP_AVAILABLE = True
except ImportError:
    UMAP_AVAILABLE = False
    print("UMAP not available. UMAP initialization method will fall back to PCA.")
import pandas as pd
import os
import time
import uuid
import warnings
from scipy.stats import entropy
from scipy.linalg import svd
import json
from datetime import datetime

# Silence every warning issued through the `warnings` module, process-wide.
# NOTE: this also hides deprecation and numerical warnings, so re-enable them
# when debugging.
warnings.filterwarnings('ignore')

# Seed the NumPy and PyTorch global random number generators so that runs are
# repeatable.
np.random.seed(42)
torch.manual_seed(42)

# ===== SETUP AND UTILITY FUNCTIONS =====

def setup_device(gpu_id=None):
    """Pick the torch device to run on.

    Args:
        gpu_id: Optional GPU index. It is wrapped modulo the number of visible
            GPUs, and the selected GPU is made the current CUDA device.

    Returns:
        ``torch.device('cpu')`` when CUDA is unavailable or reports no devices;
        otherwise the requested GPU, or ``cuda:0`` when ``gpu_id`` is None
        (in that case the current CUDA device is left unchanged).
    """
    gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
    if gpu_count == 0:
        return torch.device('cpu')

    if gpu_id is None:
        chosen = torch.device('cuda:0')
        print(f"Using default device: {chosen}")
        return chosen

    # Wrap the index so worker ids larger than the GPU count still map to a GPU.
    gpu_id = gpu_id % gpu_count
    chosen = torch.device(f'cuda:{gpu_id}')
    torch.cuda.set_device(chosen)
    print(f"Using device: {chosen}")
    return chosen

def clean_gpu_memory(device=None):
    """Release cached CUDA memory.

    When ``device`` is a CUDA device, the cache is emptied and that device is
    synchronized. Otherwise the cache is emptied only if CUDA is available.
    Nothing happens on a CPU-only machine.
    """
    targets_cuda = device is not None and device.type == 'cuda'
    if targets_cuda:
        torch.cuda.empty_cache()
        torch.cuda.synchronize(device)
        return

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_available_datasets():
    """Check which known datasets have an ``<prefix>_x_train.npy`` file in the
    current working directory.

    Each file is opened memory-mapped, so only the header is needed to read
    its shape.

    Returns:
        Dict keyed by dataset prefix. Available datasets map to
        ``{'available': True, 'x_shape': <tuple>, 'x_path': <str>}``;
        missing or unreadable ones map to
        ``{'available': False, 'error': <str>}``.
    """
    datasets = {}
    for prefix in ['fashion_mnist', 'mnist', 'diabetes', 'cifar10', 'dsprites', 'wine']:
        x_train_path = f'{prefix}_x_train.npy'

        if not os.path.exists(x_train_path):
            print(f"Dataset {prefix} files not found")
            datasets[prefix] = {'available': False, 'error': 'Files not found'}
            continue

        try:
            x_sample = np.load(x_train_path, mmap_mode='r')
            shape = tuple(x_sample.shape)
            # Drop the memmap reference so the file handle is not kept open.
            del x_sample
            if not shape:
                raise ValueError("array is 0-dimensional, expected at least one axis")
        except Exception as e:
            print(f"Error loading {prefix}: {e}")
            datasets[prefix] = {'available': False, 'error': str(e)}
            continue

        # Count features over all trailing axes. A 1-D array has one feature
        # per sample; indexing shape[1] directly would raise for it and
        # under-report image-shaped (N, H, W, C) data.
        num_features = int(np.prod(shape[1:]))
        datasets[prefix] = {'available': True, 'x_shape': shape, 'x_path': x_train_path}
        print(f"Dataset {prefix} available: {shape[0]} samples, {num_features} features")

    return datasets

def sanitize_metrics(metrics):
    """Replace invalid metric values with per-metric fallback defaults.

    A value is valid when it is a real number inside ``[0, 1]``. NaN,
    infinities, out-of-range numbers, non-numeric values (e.g. ``None`` or a
    string) and multi-element tensors/arrays are replaced by the fallback for
    that key (``0.5`` for unknown keys) and a warning is printed.

    Args:
        metrics: Mapping of metric name to value. Single-element tensors and
            arrays are unwrapped to Python scalars.

    Returns:
        New dict with the same keys and only valid numeric values.
    """
    fallback_defaults = {
        'sparsity': 0.5,
        'modularity': 0.5,
        'factor_vae_score': 0.5,
        'sap_score': 0.5,
        'variance_ratio': 0.5,
        'mi_ksg': 0.5,
        'total_correlation': 0.5,
        'recon_error': 1.0
    }
    numeric_types = (int, float, np.integer, np.floating, np.bool_)

    sanitized = {}
    for key, value in metrics.items():
        # Unwrap single-element containers; anything larger cannot be a scalar
        # metric, so it is treated as invalid instead of raising.
        if isinstance(value, torch.Tensor):
            value = value.item() if value.numel() == 1 else None
        elif isinstance(value, np.ndarray):
            value = value.item() if value.size == 1 else None

        # NaN and +/-inf fail the range comparison, so no separate check is
        # needed; the isinstance guard keeps None/str from raising TypeError.
        if isinstance(value, numeric_types) and 0 <= value <= 1:
            sanitized[key] = value
        else:
            sanitized[key] = fallback_defaults.get(key, 0.5)
            print(f"Warning: {key} invalid ({value}), using fallback: {sanitized[key]}")
    return sanitized

def create_factor_analysis_k_matrix(x_data, num_factors, latent_dim, device):
    """Build an initial K tensor from Factor Analysis loadings.

    The data is flattened to 2-D on the CPU (and randomly subsampled to at
    most 10000 rows), Factor Analysis is fitted, and consecutive groups of
    ``latent_dim`` components become the columns of each factor's matrix.
    Columns that cannot be filled from the fitted components are drawn from a
    standard normal. Every column is scaled to (approximately) unit norm.

    Returns:
        Tensor of shape ``(num_factors, n_features, latent_dim)`` on
        ``device``. On any failure the error is printed and the result of
        ``create_random_k_matrix`` is returned instead.
    """
    try:
        # scikit-learn needs a CPU NumPy array with one row per sample.
        samples = x_data.cpu().numpy().reshape(x_data.shape[0], -1)

        # Cap the fitting cost on large datasets.
        max_rows = 10000
        if samples.shape[0] > max_rows:
            keep = np.random.choice(samples.shape[0], max_rows, replace=False)
            samples = samples[keep]

        feature_count = samples.shape[1]
        component_budget = min(num_factors * latent_dim, min(samples.shape))

        analyzer = FactorAnalysis(n_components=component_budget, random_state=42)
        analyzer.fit(samples)

        per_factor = []
        for factor in range(num_factors):
            lo = factor * latent_dim
            hi = min(lo + latent_dim, component_budget)

            if hi <= lo:
                # No fitted components left for this factor.
                basis = torch.randn(feature_count, latent_dim)
            else:
                basis = torch.tensor(analyzer.components_[lo:hi].T, dtype=torch.float32)

            # Top up a partially filled factor with random columns.
            missing = latent_dim - basis.shape[1]
            if missing > 0:
                filler = torch.randn(feature_count, missing)
                basis = torch.cat([basis, filler], dim=1)

            column_norms = torch.norm(basis, dim=0, keepdim=True)
            per_factor.append(basis / (column_norms + 1e-6))

        return torch.stack(per_factor).to(device)
    except Exception as e:
        print(f"FactorAnalysis initialization error: {e}")
        return create_random_k_matrix(x_data, num_factors, latent_dim, device)


def create_pca_k_matrix(x_data, num_factors, latent_dim, device):
    """Build an initial K tensor from principal components.

    The data is flattened to 2-D on the CPU, PCA is fitted, and consecutive
    groups of ``latent_dim`` components become the columns of each factor's
    matrix. Columns that cannot be filled from the fitted components are drawn
    from a standard normal. Every column is scaled to (approximately) unit
    norm.

    Returns:
        Tensor of shape ``(num_factors, n_features, latent_dim)`` on
        ``device``. On any failure the error is printed and the result of
        ``create_random_k_matrix`` is returned instead.
    """
    try:
        # scikit-learn needs a CPU NumPy array with one row per sample.
        flat = x_data.cpu().numpy().reshape(x_data.shape[0], -1)
        n_features = flat.shape[1]

        # PCA cannot return more components than min(n_samples, n_features).
        n_components = min(num_factors * latent_dim, min(flat.shape))

        reducer = PCA(n_components=n_components, random_state=42)
        reducer.fit(flat)

        blocks = []
        for factor_idx in range(num_factors):
            first = factor_idx * latent_dim
            last = min(first + latent_dim, n_components)

            if last <= first:
                # Components exhausted: fall back to a random block.
                block = torch.randn(n_features, latent_dim)
            else:
                block = torch.tensor(reducer.components_[first:last].T, dtype=torch.float32)

            # Pad a short block with random columns up to latent_dim.
            shortfall = latent_dim - block.shape[1]
            if shortfall > 0:
                extra = torch.randn(n_features, shortfall)
                block = torch.cat([block, extra], dim=1)

            norms = torch.norm(block, dim=0, keepdim=True)
            blocks.append(block / (norms + 1e-6))

        return torch.stack(blocks).to(device)
    except Exception as e:
        print(f"PCA initialization error: {e}")
        return create_random_k_matrix(x_data, num_factors, latent_dim, device)


def load_or_create_dataset(dataset_name, available_datasets):
    """
    Load a dataset's features (and labels, when present) and standardize it.

    Labels are searched for next to the feature file under several naming
    patterns; the first file that exists and loads is used. When none is
    found, a zero column of shape ``(n_samples, 1)`` stands in for the labels.
    Features are standardized per column and any NaN/Inf is replaced by 0.

    Args:
        dataset_name: Key of the dataset in ``available_datasets``
        available_datasets: Mapping of dataset name to an info dict with at
            least ``'available'`` and ``'x_path'``

    Returns:
        tuple of (x_data, y_data, is_classification). The task is treated as
        classification when the labels are int64, or when they contain more
        than 1 and fewer than 100 distinct values.

    Raises:
        ValueError: if the dataset is not marked available. Any other loading
            error is printed and re-raised unchanged.
    """
    try:
        entry = available_datasets.get(dataset_name, {})
        if not entry.get('available', False):
            raise ValueError(f"Dataset {dataset_name} is not available")

        x_path = available_datasets[dataset_name]['x_path']
        x_data = torch.tensor(np.load(x_path), dtype=torch.float32)

        # Candidate label files, in priority order.
        folder = os.path.dirname(x_path)
        label_candidates = [
            x_path.replace('_x_train', '_y_train'),
            x_path.replace('_x_', '_y_'),
            os.path.join(folder, f'{dataset_name}_y_train.npy'),
            os.path.join(folder, f'{dataset_name}_labels.npy')
        ]

        y_data = None
        for y_path in label_candidates:
            if not os.path.exists(y_path):
                continue
            try:
                y_data = torch.tensor(np.load(y_path))
            except Exception as e:
                print(f"Error loading {y_path}: {e}")
                continue
            print(f"Loaded labels from {y_path}")
            break

        if y_data is None:
            print(f"No labels found for {dataset_name}, creating dummy values")
            y_data = torch.zeros(x_data.shape[0], 1)

        # Integer labels always mean classification; otherwise decide from the
        # number of distinct label values.
        if y_data.dtype == torch.long:
            is_classification = True
        else:
            distinct = len(torch.unique(y_data))
            is_classification = 1 < distinct < 100

        # Per-column standardization; the epsilon guards constant columns.
        column_mean = x_data.mean(dim=0, keepdim=True)
        column_std = x_data.std(dim=0, keepdim=True) + 1e-6
        x_data = (x_data - column_mean) / column_std

        x_data = torch.nan_to_num(x_data, nan=0.0, posinf=0.0, neginf=0.0)

        return x_data, y_data, is_classification

    except Exception as e:
        print(f"Error loading dataset {dataset_name}: {e}")
        raise


def compare_universal_k_methods(results, factor_counts=(3, 5), latent_dims=(8, 16)):
    """
    Compare performance of Universal K method variations.

    For every dataset and every (factors, dims) configuration, the method
    whose result has the highest ``combined_score`` is reported as
    ``best_method``. Configurations with no matching result are skipped.

    Args:
        results: Dictionary mapping dataset name to a dict that may contain
            ``'k_methods'``: method name -> list of result dicts with keys
            ``'num_factors'``, ``'latent_dim'``, ``'combined_score'`` and
            ``'metrics'``. Datasets without (or with empty) ``'k_methods'``
            are ignored.
        factor_counts: Factor counts to compare (default: 3 and 5).
        latent_dims: Latent dimensions per factor to compare (default: 8, 16).

    Returns:
        DataFrame with one row per (dataset, factors, dims) that has results.
        Besides the base columns it holds ``<method>_score`` and
        ``<method>_<metric>`` columns. If a method has several results for one
        configuration these per-method columns take the values of the last
        one; a method with none gets a score of ``-inf``. An empty DataFrame
        with the base columns is returned when nothing matches.
    """
    comparison_rows = []

    for dataset_name, dataset_results in results.items():
        if 'k_methods' not in dataset_results or not dataset_results['k_methods']:
            continue
        k_methods = dataset_results['k_methods']

        for factors in factor_counts:
            for dims in latent_dims:
                # Results of each method that belong to this configuration.
                matches = {
                    method_name: [
                        result for result in method_results
                        if result['num_factors'] == factors and result['latent_dim'] == dims
                    ]
                    for method_name, method_results in k_methods.items()
                }

                best_score = -float('inf')
                best_method = None
                for method_name, hits in matches.items():
                    for result in hits:
                        if result['combined_score'] > best_score:
                            best_score = result['combined_score']
                            best_method = method_name

                if not best_method:
                    continue

                row = {
                    'dataset': dataset_name,
                    'factors': factors,
                    'dims': dims,
                    'config': f"f{factors}_d{dims}",
                    'best_method': best_method,
                    'combined_score': best_score
                }

                # Per-method columns: the last matching result wins.
                for method_name, hits in matches.items():
                    method_score = -float('inf')
                    for result in hits:
                        method_score = result['combined_score']
                        for metric_name, metric_value in result['metrics'].items():
                            row[f"{method_name}_{metric_name}"] = metric_value
                    row[f"{method_name}_score"] = method_score

                comparison_rows.append(row)

    if comparison_rows:
        return pd.DataFrame(comparison_rows)

    return pd.DataFrame(columns=['dataset', 'factors', 'dims', 'config',
                                 'best_method', 'combined_score'])

def _k_matrix_variance_ratio(x_sampled, z_sampled, sample_size):
    """Return latent-to-data energy ratio (sum of squared singular values), capped at 0.95."""
    x_flat = x_sampled.reshape(sample_size, -1).detach().cpu().numpy()
    z_flat = z_sampled.reshape(sample_size, -1).detach().cpu().numpy()

    # Replace non-finite entries before the decomposition.
    if np.isnan(x_flat).any() or np.isinf(x_flat).any():
        x_flat = np.nan_to_num(x_flat, nan=0.0)
    if np.isnan(z_flat).any() or np.isinf(z_flat).any():
        z_flat = np.nan_to_num(z_flat, nan=0.0)

    # Tiny jitter for numerical stability of the SVD.
    x_flat = x_flat + 1e-8 * np.random.randn(*x_flat.shape)
    z_flat = z_flat + 1e-8 * np.random.randn(*z_flat.shape)

    _, s_x, _ = svd(x_flat, full_matrices=False)
    total_variance_x = np.sum(s_x ** 2)

    _, s_z, _ = svd(z_flat, full_matrices=False)
    total_variance_z = np.sum(s_z ** 2)

    return min(0.95, total_variance_z / (total_variance_x + 1e-6))


def evaluate_k_matrix(x_data, k_matrix, num_factors, latent_dim, device):
    """Evaluate a K matrix with reconstruction and disentanglement metrics.

    The K matrix is normalized, the data is encoded with ``encode_data`` and
    reconstructed factor by factor, and the metrics are computed on a random
    subset of at most 2000 samples. Each metric is computed independently: a
    failure is printed and leaves that metric at its default.

    Args:
        x_data: Input data tensor (assumed flattened to 2-D, since each factor
            reconstruction is a matmul with ``k_matrix[j].T`` - TODO confirm).
        k_matrix: Tensor viewable as ``(num_factors, -1, latent_dim)``.
        num_factors: Number of factors; ``<= 1`` skips the pairwise metrics.
        latent_dim: Latent dimension per factor.
        device: Device on which encoding and reconstruction run.

    Returns:
        Dict with keys ``recon_error``, ``mi_ksg``, ``sparsity``,
        ``total_correlation``, ``modularity``, ``factor_vae_score``,
        ``sap_score`` and ``variance_ratio``, passed through
        ``sanitize_metrics``. If evaluation fails as a whole, the metrics
        gathered so far (defaults otherwise) are returned.
    """
    # Created before the try block so the outer handler can always return
    # something, even when the very first statement fails.
    metrics = {
        'recon_error': 1.0,
        'mi_ksg': 0.5,
        'sparsity': 0.5,
        'total_correlation': 0.5,
        'modularity': 0.5,
        'factor_vae_score': 0.5,
        'sap_score': 0.5,
        'variance_ratio': 0.5
    }

    try:
        x_data = x_data.to(device)
        k_matrix = k_matrix.to(device)

        single_factor = num_factors <= 1
        if single_factor:
            print("Evaluation: Single factor case, special handling")
            metrics['total_correlation'] = 0.0  # No correlation with self
            metrics['mi_ksg'] = 0.0  # No mutual information with self
            metrics['modularity'] = 1.0  # Fully modular with self

        if torch.isnan(k_matrix).any() or torch.isinf(k_matrix).any():
            print("Evaluation Debug: K matrix contains NaN or Inf values!")
            k_matrix = torch.nan_to_num(k_matrix, nan=0.0, posinf=1.0, neginf=-1.0)

        # Normalize each latent column of each factor.
        k_norm = torch.norm(k_matrix.view(num_factors, -1, latent_dim), dim=1, keepdim=True)
        k_matrix = k_matrix / (k_norm + 1e-8)

        z = encode_data(x_data, k_matrix)

        # Reconstruct in batches, summing the contribution of every factor.
        batch_size = 1024
        all_recon = []
        with torch.no_grad():
            for i in range(0, len(x_data), batch_size):
                batch_x = x_data[i:i + batch_size]
                batch_z = z[i:i + batch_size]

                batch_recon = torch.zeros_like(batch_x)
                for j in range(num_factors):
                    batch_recon += torch.matmul(batch_z[:, j], k_matrix[j].T)

                all_recon.append(batch_recon)

        reconstructed = torch.cat(all_recon, dim=0)

        # Reconstruction error relative to the data variance, capped at 1.
        recon_error = F.mse_loss(reconstructed, x_data)
        data_var = torch.var(x_data)
        metrics['recon_error'] = min(1.0, recon_error.item() / (data_var.item() + 1e-8))

        # The remaining metrics are computed on a random subset.
        sample_size = min(2000, z.shape[0])
        sample_indices = torch.randperm(z.shape[0], device=device)[:sample_size]
        z_sampled = z[sample_indices]
        x_sampled = x_data[sample_indices]

        if not single_factor:
            try:
                # Pairwise mutual information between factors (KSG estimator),
                # using only the first latent dimension of each factor.
                z_cpu = z_sampled.detach().cpu()
                mi_scores = []
                for i in range(num_factors):
                    for j in range(i + 1, num_factors):
                        mi_scores.append(safe_mi_ksg_estimator(
                            z_cpu[:, i, 0].flatten(),
                            z_cpu[:, j, 0].flatten()
                        ))
                metrics['mi_ksg'] = np.mean(mi_scores) if mi_scores else 0.5
            except Exception as e:
                print(f"MI KSG calculation error: {e}")
                metrics['mi_ksg'] = 0.5

        try:
            metrics['sparsity'] = sparsity_score(k_matrix).item()
        except Exception as e:
            print(f"Sparsity calculation error: {e}")
            metrics['sparsity'] = 0.5

        if not single_factor:
            try:
                tc_result = robust_total_correlation(z_sampled, num_factors, latent_dim)
                metrics['total_correlation'] = tc_result.item() if isinstance(tc_result, torch.Tensor) else tc_result
            except Exception as e:
                print(f"Total correlation calculation error: {e}")
                metrics['total_correlation'] = 0.5

            try:
                mod_result = robust_modularity_score(z_sampled, num_factors, latent_dim)
                metrics['modularity'] = mod_result.item() if isinstance(mod_result, torch.Tensor) else mod_result
            except Exception as e:
                print(f"Modularity calculation error: {e}")
                metrics['modularity'] = 0.5

            try:
                # FactorVAE score is computed on CPU tensors.
                z_cpu = z_sampled.detach().cpu()
                metrics['factor_vae_score'] = robust_factor_vae_score(z_cpu, num_factors, latent_dim)
            except Exception as e:
                print(f"FactorVAE score calculation error: {e}")
                metrics['factor_vae_score'] = 0.5

            try:
                # SAP score is computed on CPU tensors.
                z_cpu = z_sampled.detach().cpu()
                x_cpu = x_sampled.detach().cpu()
                metrics['sap_score'] = robust_sap_score(z_cpu, x_cpu, num_factors, latent_dim)
            except Exception as e:
                print(f"SAP score calculation error: {e}")
                metrics['sap_score'] = 0.5

        try:
            metrics['variance_ratio'] = _k_matrix_variance_ratio(x_sampled, z_sampled, sample_size)
        except Exception as e:
            print(f"Variance ratio calculation error: {e}")
            metrics['variance_ratio'] = 0.5

        return sanitize_metrics(metrics)

    except Exception as e:
        print(f"Evaluation error: {e}")
        return sanitize_metrics(metrics)



# ===== MODEL IMPLEMENTATIONS =====

class VIB(nn.Module):
    """Variational Information Bottleneck autoencoder.

    An MLP encoder produces the mean and log-variance of a Gaussian latent, a
    decoder reconstructs the input from a sampled latent, and a small
    predictor head maps the latent to a single output.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, beta=1.0):
        super().__init__()

        half_hidden = hidden_dim // 2
        self.beta = beta  # weight of the KL term in the total loss

        # Encoder trunk shared by the two latent heads.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, half_hidden),
            nn.LeakyReLU(0.2),
        )

        # Heads for the latent Gaussian parameters.
        self.mu = nn.Linear(half_hidden, latent_dim)
        self.log_var = nn.Linear(half_hidden, latent_dim)

        # Decoder mirrors the encoder.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, half_hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(half_hidden, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
        )

        # Task head with a single output unit (regression-style).
        self.predictor = nn.Sequential(
            nn.Linear(latent_dim, half_hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(half_hidden, 1),
        )

    def encode(self, x):
        """Return ``(mu, log_var)`` of the latent distribution for ``x``."""
        hidden = self.encoder(x)
        mean = self.mu(hidden)
        log_variance = self.log_var(hidden)
        return mean, log_variance

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * std`` with ``eps`` drawn from N(0, I)."""
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Reconstruct the input from latent ``z``."""
        return self.decoder(z)

    def predict(self, z):
        """Run the task head on latent ``z``."""
        return self.predictor(z)

    def forward(self, x):
        """Return ``(x_recon, mu, log_var, z, pred)`` for a batch ``x``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var, z, self.predict(z)

    def loss_function(self, x, x_recon, mu, log_var, y, pred):
        """Return ``(total, recon, kl, task)`` losses.

        ``recon`` and ``task`` are mean squared errors; ``kl`` is the KL
        divergence to a standard normal, summed over all elements.
        ``total = task + beta * kl + recon``.
        """
        recon_loss = F.mse_loss(x_recon, x)

        kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
        kl_loss = -0.5 * torch.sum(kl_terms)

        task_loss = F.mse_loss(pred, y)

        total_loss = task_loss + self.beta * kl_loss + recon_loss
        return total_loss, recon_loss, kl_loss, task_loss


class BetaVAE(nn.Module):
    """Beta-VAE for disentangled representation learning.

    A variational autoencoder whose KL term is weighted by ``beta``. It also
    carries a single-output predictor head that operates on the sampled
    latent.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, beta=1.0):
        super().__init__()

        bottleneck_width = hidden_dim // 2
        self.beta = beta  # KL weight used by loss_function

        # Encoder trunk
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, bottleneck_width),
            nn.LeakyReLU(0.2),
        )

        # Gaussian posterior parameters
        self.mu = nn.Linear(bottleneck_width, latent_dim)
        self.log_var = nn.Linear(bottleneck_width, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, bottleneck_width),
            nn.LeakyReLU(0.2),
            nn.Linear(bottleneck_width, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
        )

        # Task head with one output unit (regression-style)
        self.predictor = nn.Sequential(
            nn.Linear(latent_dim, bottleneck_width),
            nn.LeakyReLU(0.2),
            nn.Linear(bottleneck_width, 1),
        )

    def encode(self, x):
        """Map ``x`` to the posterior parameters ``(mu, log_var)``."""
        features = self.encoder(x)
        posterior_mean = self.mu(features)
        posterior_log_var = self.log_var(features)
        return posterior_mean, posterior_log_var

    def reparameterize(self, mu, log_var):
        """Draw a latent sample via the reparameterization trick."""
        sigma = torch.exp(log_var * 0.5)
        epsilon = torch.randn_like(sigma)
        return mu + epsilon * sigma

    def decode(self, z):
        """Map a latent ``z`` back to input space."""
        return self.decoder(z)

    def predict(self, z):
        """Apply the predictor head to ``z``."""
        return self.predictor(z)

    def forward(self, x):
        """Return ``(x_recon, mu, log_var, z, pred)``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        pred = self.predict(z)
        return x_recon, mu, log_var, z, pred

    def loss_function(self, x, x_recon, mu, log_var, y, pred):
        """Return ``(total, recon, kl, task)``.

        ``total = task + beta * kl + recon``, where ``recon`` and ``task`` are
        mean squared errors and ``kl`` is summed over every element.
        """
        recon_loss = F.mse_loss(x_recon, x)

        kl_elements = 1 + log_var - mu.pow(2) - log_var.exp()
        kl_loss = -0.5 * torch.sum(kl_elements)

        task_loss = F.mse_loss(pred, y)

        total_loss = task_loss + self.beta * kl_loss + recon_loss
        return total_loss, recon_loss, kl_loss, task_loss


class DropoutRegularizedModel(nn.Module):
    """Three-layer MLP regularized with dropout after each hidden layer.

    Layer widths are ``input_dim -> hidden_dim -> hidden_dim // 2 ->
    output_dim`` with LeakyReLU(0.2) activations. Weight decay is not applied
    here; if wanted, it presumably has to be set on the optimizer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate=0.5):
        super().__init__()

        narrow_dim = hidden_dim // 2
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, narrow_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_rate),
            nn.Linear(narrow_dim, output_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for a batch ``x``."""
        return self.model(x)


class SparseAutoencoder(nn.Module):
    """Autoencoder with an L1 sparsity penalty on the latent code.

    Besides the encoder/decoder pair it carries a single-output predictor
    head that operates on the latent code.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, sparsity_weight=0.01):
        super().__init__()

        self.sparsity_weight = sparsity_weight  # weight of the L1 term

        # Encoder: input -> hidden -> latent
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, latent_dim),
        )

        # Decoder: latent -> hidden -> input
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
        )

        # Task head with one output unit (regression-style)
        head_width = hidden_dim // 2
        self.predictor = nn.Sequential(
            nn.Linear(latent_dim, head_width),
            nn.LeakyReLU(0.2),
            nn.Linear(head_width, 1),
        )

    def forward(self, x):
        """Return ``(x_recon, z, pred)`` for a batch ``x``."""
        code = self.encoder(x)
        reconstruction = self.decoder(code)
        prediction = self.predictor(code)
        return reconstruction, code, prediction

    def loss_function(self, x, x_recon, z, y, pred):
        """Return ``(total, recon, sparsity, task)``.

        ``recon`` and ``task`` are mean squared errors, ``sparsity`` is the
        mean absolute latent activation, and
        ``total = task + recon + sparsity_weight * sparsity``.
        """
        recon_loss = F.mse_loss(x_recon, x)
        sparsity_loss = torch.mean(torch.abs(z))
        task_loss = F.mse_loss(pred, y)

        total_loss = task_loss + recon_loss + self.sparsity_weight * sparsity_loss
        return total_loss, recon_loss, sparsity_loss, task_loss


# ===== DISTILLATION METHODS =====

def knowledge_distillation_loss(student_logits, teacher_logits, y_true, temperature=4.0, alpha=0.5):
    """
    Blend a supervised loss with a teacher-matching (distillation) loss.

    The task type is inferred from ``y_true``: a 1-D tensor is treated as
    class indices (classification), anything else as regression targets.

    Args:
        student_logits: Outputs of the student model
        teacher_logits: Outputs of the teacher model
        y_true: Ground truth labels (class indices) or regression targets
        temperature: Softmax temperature (classification only)
        alpha: Weight of the distillation term; the task term gets 1 - alpha

    Returns:
        ``alpha * distillation_loss + (1 - alpha) * task_loss``
    """
    labels_are_indices = y_true.ndim == 1

    if labels_are_indices:
        hard_loss = F.cross_entropy(student_logits, y_true)

        # Cross-entropy between temperature-softened teacher and student
        # distributions, rescaled by T^2.
        teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
        student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
        per_sample = -(teacher_probs * student_log_probs).sum(dim=1)
        soft_loss = per_sample.mean() * (temperature ** 2)
    else:
        hard_loss = F.mse_loss(student_logits, y_true)
        # For regression the student simply regresses onto the teacher output.
        soft_loss = F.mse_loss(student_logits, teacher_logits)

    return alpha * soft_loss + (1 - alpha) * hard_loss

def attention_transfer_loss(student_features, teacher_features, y_true, model_output, beta=0.5):
    """
    Combine the task loss with an attention-transfer penalty.

    For every (student, teacher) feature pair an attention map is built as
    the mean of squared activations over dim 1, flattened per sample and
    L2-normalized. The penalty is the sum of the MSE between paired maps.

    Args:
        student_features: List of feature maps from student model
        teacher_features: List of feature maps from teacher model
        y_true: Ground truth; a 1-D tensor selects cross-entropy, anything
            else selects MSE
        model_output: Output from the student model
        beta: Weight of the attention penalty

    Returns:
        ``task_loss + beta * attention_loss``
    """
    def attention_map(features):
        energy = features.pow(2).mean(1)
        return F.normalize(energy.view(features.size(0), -1), p=2, dim=1)

    if y_true.ndim == 1:
        supervised = F.cross_entropy(model_output, y_true)
    else:
        supervised = F.mse_loss(model_output, y_true)

    transfer = 0.0
    for s_feat, t_feat in zip(student_features, teacher_features):
        transfer += F.mse_loss(attention_map(s_feat), attention_map(t_feat))

    return supervised + beta * transfer


# ===== EVALUATION FUNCTIONS =====

def evaluate_model_metrics(model, dataloader, device, is_classification=True):
    """
    Evaluate a model and return comprehensive metrics.

    Args:
        model: The trained model to evaluate. VIB/BetaVAE and
            SparseAutoencoder instances return tuples whose last element is
            the prediction; any other model is called directly.
        dataloader: DataLoader yielding ``(x, y)`` batches
        device: Device to run evaluation on
        is_classification: Whether this is a classification task

    Returns:
        Classification: dict with ``accuracy``, ``precision``, ``recall``,
        ``f1_score`` (macro averages) and ``auc``.
        Regression: dict with ``mse``, ``mae``, ``rmse`` and ``r2_score``.

    Raises:
        ValueError: propagated from ``roc_auc_score`` when the AUC is
            undefined for the evaluated targets (e.g. only one class present).
    """
    model.eval()
    all_targets = []
    all_predictions = []
    all_probs = []  # Class probabilities, needed for AUC

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)

            # The prediction is the last output of the autoencoder-style models.
            if isinstance(model, (VIB, BetaVAE)):
                _, _, _, _, pred = model(x)
            elif isinstance(model, SparseAutoencoder):
                _, _, pred = model(x)
            else:  # Standard models like DropoutRegularizedModel
                pred = model(x)

            if is_classification:
                if pred.size(1) > 1:
                    probs = F.softmax(pred, dim=1)
                    _, predicted = torch.max(pred, 1)
                else:
                    # Single-logit binary head: threshold the probability
                    # (not the raw logit) and drop the column axis so the
                    # predictions are 1-D like the multi-class ones.
                    probs = torch.sigmoid(pred)
                    predicted = (probs > 0.5).long().squeeze(1)
                all_probs.append(probs.cpu())
                all_predictions.append(predicted.cpu())
            else:
                all_predictions.append(pred.cpu())
            all_targets.append(y.cpu())

    all_targets = torch.cat(all_targets).numpy()
    all_predictions = torch.cat(all_predictions).numpy()

    if not is_classification:
        mse = mean_squared_error(all_targets, all_predictions)
        mae = mean_absolute_error(all_targets, all_predictions)
        r2 = r2_score(all_targets, all_predictions)

        return {
            'mse': mse,
            'mae': mae,
            'rmse': np.sqrt(mse),
            'r2_score': r2
        }

    all_probs = torch.cat(all_probs).numpy()

    # Column-vector targets would broadcast against the 1-D predictions and
    # turn the accuracy into an N x N comparison, so flatten them first.
    if all_targets.ndim == 2 and all_targets.shape[1] == 1:
        all_targets = all_targets.reshape(-1)

    accuracy = (all_predictions == all_targets).mean()

    # Macro averages treat every class equally regardless of support.
    precision = precision_score(all_targets, all_predictions, average='macro', zero_division=0)
    recall = recall_score(all_targets, all_predictions, average='macro', zero_division=0)
    f1 = f1_score(all_targets, all_predictions, average='macro', zero_division=0)

    if all_probs.shape[1] > 2:  # Multi-class
        # One-hot encode targets for one-vs-rest AUC.
        from sklearn.preprocessing import label_binarize
        classes = list(range(all_probs.shape[1]))
        all_targets_binary = label_binarize(all_targets, classes=classes)
        auc = roc_auc_score(all_targets_binary, all_probs, multi_class='ovr')
    else:  # Binary classification
        positive_scores = all_probs[:, 1] if all_probs.shape[1] > 1 else all_probs.ravel()
        auc = roc_auc_score(all_targets, positive_scores)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'auc': auc
    }


def evaluate_baseline_method(method_name, model, test_loader, device, is_classification=True, k_matrix=None, num_factors=None, latent_dim=None):
    """
    Evaluate a baseline method and compute comprehensive metrics.

    Args:
        method_name: Name of the method being evaluated
        model: The trained model to evaluate
        test_loader: DataLoader for test data (each batch unpacks as ``x, y``)
        device: Device to run evaluation on
        is_classification: Whether this is a classification task
        k_matrix: Optional K matrix for universal K methods
        num_factors: Number of factors for universal K methods
        latent_dim: Latent dimension per factor

    Returns:
        Dictionary with 'method' and 'task_metrics', plus, when applicable:
        'recon_error' (mean per-batch MSE), 'kl_divergence' (mean per-batch
        summed KL, VIB/BetaVAE only), 'sparsity' (mean per-batch |z|,
        SparseAutoencoder only), 'disentanglement' and 'k_metrics'.
    """
    task_metrics = evaluate_model_metrics(model, test_loader, device, is_classification)

    # Add method-specific metrics
    metrics = {
        'method': method_name,
        'task_metrics': task_metrics
    }

    # For autoencoder-based methods, add reconstruction metrics
    if isinstance(model, (VIB, BetaVAE, SparseAutoencoder)):
        model.eval()
        is_variational = isinstance(model, (VIB, BetaVAE))

        recon_total = 0.0
        kl_total = 0.0
        sparsity_total = 0.0
        num_batches = 0
        last_x = last_z = None

        with torch.no_grad():
            for batch in test_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)

                if is_variational:
                    x_recon, mu, log_var, z, _ = model(x)
                    # Accumulate KL over every batch; previously only the last
                    # batch's value survived (and was still divided by the
                    # number of batches).
                    kl_total += -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()).item()
                    last_x, last_z = x, z
                else:
                    x_recon, z, _ = model(x)
                    # Average the sparsity measure instead of keeping the last batch only
                    sparsity_total += torch.mean(torch.abs(z)).item()

                recon_total += F.mse_loss(x_recon, x).item()
                num_batches += 1

        # Guard against an empty loader rather than dividing by zero
        if num_batches > 0:
            metrics['recon_error'] = recon_total / num_batches
            if is_variational:
                metrics['kl_divergence'] = kl_total / num_batches
            else:
                metrics['sparsity'] = sparsity_total / num_batches

        # Disentanglement metrics are expensive, so compute them once on the
        # final batch (the batch whose values the per-batch overwrite used to keep).
        if is_variational and last_z is not None and num_factors is not None and latent_dim is not None:
            try:
                z_reshaped = last_z.view(-1, num_factors, latent_dim)

                # NOTE(review): safe_mi_ksg_estimator takes two signals; the old
                # single-argument call always raised TypeError. Here the score is
                # the mean MI over all factor pairs (each factor flattened) --
                # confirm this is the intended pairing.
                factor_signals = [z_reshaped[:, f, :].flatten() for f in range(num_factors)]
                pair_mi = [
                    safe_mi_ksg_estimator(factor_signals[a], factor_signals[b])
                    for a in range(num_factors)
                    for b in range(a + 1, num_factors)
                ]
                mi_ksg = float(np.mean(pair_mi)) if pair_mi else 0.0

                metrics['disentanglement'] = {
                    'mi_ksg': mi_ksg,
                    'total_correlation': robust_total_correlation(z_reshaped, num_factors, latent_dim),
                    'modularity': robust_modularity_score(z_reshaped, num_factors, latent_dim),
                    'factor_vae_score': robust_factor_vae_score(z_reshaped, num_factors, latent_dim),
                    'sap_score': robust_sap_score(z_reshaped, last_x, num_factors, latent_dim)
                }
            except Exception as e:
                print(f"Error calculating disentanglement metrics: {e}")

    # For K-matrix methods, use the evaluate_k_matrix function if provided
    if k_matrix is not None and 'Universal_K' in method_name:
        try:
            # Only the first test batch is passed to evaluate_k_matrix
            k_metrics = evaluate_k_matrix(next(iter(test_loader))[0].to(device), k_matrix, num_factors, latent_dim, device)
            metrics['k_metrics'] = k_metrics
        except Exception as e:
            print(f"Error evaluating K matrix metrics: {e}")

    return metrics


def train_model(model, train_loader, val_loader, device, is_classification=True, epochs=50, patience=5):
    """
    Fit ``model`` with Adam and stop early when validation task loss stalls.

    Args:
        model: The model to train
        train_loader: DataLoader yielding (inputs, targets) training batches
        val_loader: DataLoader yielding (inputs, targets) validation batches
        device: Device the batches are moved to
        is_classification: Selects cross-entropy (True) or MSE (False) as the
            task loss for models without their own ``loss_function``
        epochs: Maximum number of epochs to train
        patience: Number of consecutive non-improving epochs tolerated

    Returns:
        The same model object, in the state reached when training ended
    """
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    task_criterion = nn.CrossEntropyLoss() if is_classification else nn.MSELoss()

    def compute_losses(inputs, targets):
        """Run one forward pass and return (optimisation loss, task loss)."""
        if isinstance(model, (VIB, BetaVAE)):
            x_recon, mu, log_var, _, pred = model(inputs)
            total, _, _, task = model.loss_function(inputs, x_recon, mu, log_var, targets, pred)
        elif isinstance(model, SparseAutoencoder):
            x_recon, z, pred = model(inputs)
            total, _, _, task = model.loss_function(inputs, x_recon, z, targets, pred)
        else:
            # Plain predictors (e.g. DropoutRegularizedModel): task loss is the whole loss
            pred = model(inputs)
            total = task = task_criterion(pred, targets)
        return total, task

    best_val_loss = float('inf')
    stale_epochs = 0

    for epoch in range(epochs):
        # ---- optimisation pass ----
        model.train()
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            train_loss, _ = compute_losses(inputs, targets)
            train_loss.backward()
            optimizer.step()

        # ---- validation pass (task loss only) ----
        model.eval()
        running = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                _, val_task_loss = compute_losses(inputs, targets)
                running += val_task_loss.item()

        val_loss = running / len(val_loader)

        # ---- early-stopping bookkeeping ----
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            stale_epochs = 0
            continue

        stale_epochs += 1
        if stale_epochs >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

    return model


def refine_k_matrix(x_data, k_matrix, num_factors, latent_dim, device, epochs=100):
    """Refine K matrix with robust optimization.

    Optimises a copy of ``k_matrix`` with Adam on a weighted sum of
    reconstruction, variance, sparsity, orthogonality, within-factor
    correlation and between-factor cross-covariance terms.

    Args:
        x_data: Input tensor; wrapped in a TensorDataset and iterated in
            shuffled batches of at most 1024 rows.
        k_matrix: Initial K matrix. It is viewed as
            (num_factors, -1, latent_dim) for all computations.
        num_factors: Number of factors the matrix is split into.
        latent_dim: Latent dimensions per factor.
        device: Device the optimisation runs on.
        epochs: Maximum number of passes over the data.

    Returns:
        The detached matrix with the lowest mean per-epoch reconstruction
        loss, divided by its per-factor column norms. If a shape error occurs
        inside the loop, or any other exception is raised, the current
        (un-normalised) ``k_matrix`` is returned detached instead.
    """
    try:
        #print(f"Refining K matrix on {device}...")

        # Make a copy for training and ensure it's on the correct device
        k_matrix = k_matrix.clone().to(device).requires_grad_(True)

        # Setup optimizer with conservative learning rate
        optimizer = optim.Adam([k_matrix], lr=5e-5, weight_decay=1e-6)
        # Halve the learning rate after 20 epochs without improvement in the stepped metric
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=20)

        # Setup data loading
        batch_size = min(1024, len(x_data))
        dataset = TensorDataset(x_data)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

        # Track best matrix seen so far
        best_k_matrix = k_matrix.clone().detach()
        best_loss = float('inf')
        # Early stopping: give up after 30 epochs without a new best reconstruction loss
        patience, patience_counter = 30, 0

        for epoch in range(epochs):
            # Per-epoch running sums; only the reconstruction sum drives decisions below
            recon_loss_epoch = 0.0
            tc_loss_epoch = 0.0
            ortho_loss_epoch = 0.0
            batches_processed = 0

            for batch in loader:
                x = batch[0].to(device, non_blocking=True)
                optimizer.zero_grad()

                # Check input shape
                if x.dim() == 1:  # If 1D tensor, reshape to 2D
                    x = x.unsqueeze(0)

                # Reshape k_matrix for computation
                try:
                    k_reshaped = k_matrix.view(num_factors, -1, latent_dim)
                except RuntimeError as e:
                    # Incompatible shape: abort and hand back the current matrix as-is
                    #print(f"Refinement error: shape issue with k_matrix: {e}")
                    #print(f"k_matrix shape: {k_matrix.shape}, num_factors: {num_factors}, latent_dim: {latent_dim}")
                    return k_matrix.detach()

                # Compute latent representations
                z_factors = []
                for j in range(num_factors):
                    try:
                        z_factor = torch.matmul(x, k_reshaped[j])
                        z_factors.append(z_factor)
                    except RuntimeError as e:
                        # Feature dimension of x does not match this factor: abort as above
                        #print(f"Error in factor {j} computation: {e}")
                        #print(f"x shape: {x.shape}, k_reshaped[{j}] shape: {k_reshaped[j].shape}")
                        return k_matrix.detach()

                # Stack the per-factor projections along a new dim 1
                z = torch.stack(z_factors, dim=1)

                # Compute reconstruction: sum over factors of z_j @ K_j^T
                recon = torch.zeros_like(x)
                for j in range(num_factors):
                    recon += torch.matmul(z_factors[j], k_reshaped[j].T)

                # Basic reconstruction loss
                recon_loss = F.mse_loss(recon, x)

                # Additional losses for better disentanglement

                # Variance penalty - encourage each latent dimension to have variance
                # (penalises only when the mean batch variance is below 1.0)
                z_var = torch.var(z, dim=0).mean()
                variance_penalty = 0.1 * torch.clamp(1.0 - z_var, min=0.0)

                # Sparsity loss - encourage sparse k_matrix
                sparsity_loss = 0.01 * torch.mean(torch.abs(k_matrix))

                # Orthogonality loss - encourage factors to be orthogonal
                # (norm of K_i^T K_j summed over every unordered factor pair)
                ortho_loss = 0.0
                for i in range(num_factors):
                    for j in range(i + 1, num_factors):
                        ortho_loss += torch.norm(torch.mm(k_reshaped[i].t(), k_reshaped[j]))
                ortho_loss = 0.01 * ortho_loss

                # Total correlation loss - encourage independence within factors
                tc_loss = 0.0
                z_reshaped = z.view(-1, num_factors, latent_dim)

                for i in range(num_factors):
                    # Calculate correlation matrix for each factor's dimensions
                    z_factor = z_reshaped[:, i, :]
                    z_centered = z_factor - z_factor.mean(0, keepdim=True)
                    cov = torch.mm(z_centered.t(), z_centered) / (z_centered.shape[0] - 1)
                    # Normalize to get correlation
                    var = torch.diag(cov).view(-1, 1)
                    corr = cov / torch.sqrt(var * var.t() + 1e-8)
                    # Sum absolute off-diagonal elements
                    tc_loss += torch.sum(torch.abs(corr * (1 - torch.eye(latent_dim, device=device))))

                tc_loss = 0.01 * tc_loss

                # Modularity loss - encourage between-factor independence
                # (norm of the cross-covariance between every pair of factors)
                modularity_loss = 0.0
                for i in range(num_factors):
                    for j in range(i + 1, num_factors):
                        z_i = z_reshaped[:, i, :]
                        z_j = z_reshaped[:, j, :]
                        z_i_centered = z_i - z_i.mean(0, keepdim=True)
                        z_j_centered = z_j - z_j.mean(0, keepdim=True)
                        cross_corr = torch.mm(z_i_centered.t(), z_j_centered) / (z_i_centered.shape[0] - 1)
                        modularity_loss += torch.norm(cross_corr)

                modularity_loss = 0.01 * modularity_loss

                # Total loss (reconstruction is weighted 3x relative to the regularisers)
                total_loss = (3.0 * recon_loss +
                             variance_penalty +
                             sparsity_loss +
                             ortho_loss +
                             tc_loss +
                             modularity_loss)

                # Check for numerical stability; a bad batch is skipped without an optimizer step
                if torch.isnan(total_loss) or torch.isinf(total_loss):
                    print(f"NaN or Inf loss detected at epoch {epoch}, batch {batches_processed}")
                    continue

                # Backward and optimize
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_([k_matrix], max_norm=1.0)
                optimizer.step()

                # Track metrics
                recon_loss_epoch += recon_loss.item()
                # ortho/tc stay plain floats when num_factors yields no loop iterations
                tc_loss_epoch += tc_loss.item() if isinstance(tc_loss, torch.Tensor) else tc_loss
                ortho_loss_epoch += ortho_loss.item() if isinstance(ortho_loss, torch.Tensor) else ortho_loss
                batches_processed += 1

                # Clean up to prevent memory fragmentation
                # (clean_gpu_memory is defined elsewhere in this file)
                del x, z, z_factors, recon
                clean_gpu_memory(device)

            # Average losses (skipped entirely if every batch of the epoch was rejected)
            if batches_processed > 0:
                recon_loss_epoch /= batches_processed
                tc_loss_epoch /= batches_processed
                ortho_loss_epoch /= batches_processed

                # Update learning rate
                scheduler.step(recon_loss_epoch)

                # Track best model (selection is by reconstruction loss only)
                if recon_loss_epoch < best_loss:
                    best_loss = recon_loss_epoch
                    best_k_matrix = k_matrix.clone().detach()
                    patience_counter = 0
                    #print(f"Epoch {epoch}: New best loss: {best_loss:.6f}")
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        #print(f"Early stopping at epoch {epoch}")
                        break


        # Normalize the best matrix: norm over dim 1 of the (num_factors, -1, latent_dim) view
        k_norm = torch.norm(best_k_matrix.view(num_factors, -1, latent_dim), dim=1, keepdim=True)
        k_matrix_normalized = best_k_matrix / (k_norm + 1e-6)

        # Clean up
        clean_gpu_memory(device)

        return k_matrix_normalized.detach()

    except Exception as e:
        #print(f"Refinement error: {e}")
        # Fallback: return whatever k_matrix currently refers to, detached.
        # NOTE(review): after the clone above this is the in-progress copy, not
        # the caller's original tensor and not best_k_matrix -- confirm intent.
        return k_matrix.detach()

def encode_data(x_data, k_matrix):
    """Project ``x_data`` through every factor of ``k_matrix``.

    The data is processed in chunks of 1024 rows with gradients disabled.
    Each chunk is multiplied by ``k_matrix[f]`` for every index ``f`` along
    the first axis of ``k_matrix`` and the per-factor results are stacked on
    a new dim 1.

    Args:
        x_data: Input tensor; its device decides where the work happens.
        k_matrix: Tensor indexed by factor on its first axis; moved to the
            device of ``x_data`` if necessary.

    Returns:
        Tensor formed by concatenating all encoded chunks along dim 0.
    """
    target_device = x_data.device
    if k_matrix.device != target_device:
        # Keep both operands on the data's device
        k_matrix = k_matrix.to(target_device)

    chunk_size = 1024
    factor_count = k_matrix.shape[0]
    total_rows = len(x_data)
    encoded_chunks = []

    with torch.no_grad():
        start = 0
        while start < total_rows:
            chunk = x_data[start:start + chunk_size]
            # One matmul per factor keeps the peak memory per step small
            per_factor = [torch.matmul(chunk, k_matrix[f]) for f in range(factor_count)]
            encoded_chunks.append(torch.stack(per_factor, dim=1))
            start += chunk_size

    return torch.cat(encoded_chunks, dim=0)


def create_clustered_k_matrix(x_data, num_factors, latent_dim, device):
    """Build an initial K matrix by clustering input features.

    Features are grouped with KMeans on the rows of ``1 - |correlation|``;
    each factor gets one cluster. For every latent dimension of a factor a
    random subset of that cluster's features is switched on, small Gaussian
    noise is added, and the columns are normalised. Factors without a
    (non-empty) cluster get a random Gaussian block instead.

    Args:
        x_data: Input tensor; flattened to (rows, features) on the CPU.
        num_factors: Number of factor blocks to create.
        latent_dim: Number of columns per factor block.
        device: Device the stacked result is moved to.

    Returns:
        Stacked tensor of ``num_factors`` blocks, each of shape
        (features, latent_dim). On any error, the result of
        ``create_random_k_matrix`` with the same arguments.
    """
    try:
        # sklearn works on NumPy arrays, so flatten on the CPU first
        features = x_data.cpu().numpy().reshape(x_data.shape[0], -1)

        # Cap the number of rows used for the correlation estimate
        if features.shape[0] > 10000:
            keep = np.random.choice(features.shape[0], 10000, replace=False)
            features = features[keep]

        n_features = features.shape[1]

        # Feature-to-feature dissimilarity; NaN correlations are treated as 0
        dissimilarity = 1 - np.abs(np.nan_to_num(np.corrcoef(features.T)))

        # Cannot have more clusters than features
        n_clusters = min(num_factors, n_features)
        labels = KMeans(n_clusters=n_clusters, random_state=42, n_init=10).fit_predict(dissimilarity)

        factor_blocks = []
        for factor_idx in range(num_factors):
            if factor_idx < n_clusters:
                members = np.where(labels == factor_idx)[0]
            else:
                members = np.array([])

            if len(members) == 0:
                # No features assigned to this factor: fall back to random values
                block = torch.randn(n_features, latent_dim)
            else:
                block = torch.zeros(n_features, latent_dim)
                for dim_idx in range(latent_dim):
                    # Each latent dimension activates its own random slice of the cluster
                    picked = np.random.choice(
                        len(members),
                        max(1, len(members) // latent_dim),
                        replace=False
                    )
                    block[members[picked], dim_idx] = 1.0

                # Small noise keeps columns from being exactly degenerate
                block = block + torch.randn_like(block) * 0.01

            # Unit-normalise every column
            block = block / (torch.norm(block, dim=0, keepdim=True) + 1e-6)
            factor_blocks.append(block)

        return torch.stack(factor_blocks).to(device)
    except Exception as e:
        print(f"Clustered initialization error: {e}")
        return create_random_k_matrix(x_data, num_factors, latent_dim, device)


def safe_mi_ksg_estimator(x, y, k=3):
    """Robust KSG-style mutual information estimate, normalised to [0, 1].

    For every sample the Chebyshev distance to its k-th nearest neighbour in
    the joint (x, y) space is used as a radius; the number of samples within
    that radius is then counted in each marginal. The estimate is
    ``mean(log n + log k - log nx - log ny)``, divided by ``log n`` and
    clamped to [0, 1].

    The previous implementation took ``len(radius_neighbors(...)[0])``, which
    is the number of *queries* (always 1), not the number of neighbours, so
    every valid input produced 1.0. Marginal counts are now computed directly
    with a sorted search, and the joint k-NN query uses SciPy's KD-tree.

    Args:
        x: 1-D signal (torch tensor or NumPy array; flattened).
        y: Signal with the same number of elements as ``x``.
        k: Neighbour rank used for the joint-space radius.

    Returns:
        Float in [0, 1]; 0.5 when the inputs are invalid or an error occurs.
    """
    try:
        # Local import: the module-level imports only pull in scipy.stats/linalg
        from scipy.spatial import cKDTree

        # Ensure inputs are NumPy arrays on CPU
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        if isinstance(y, torch.Tensor):
            y = y.detach().cpu().numpy()

        x, y = x.flatten(), y.flatten()
        if len(x) < k + 1 or x.shape != y.shape:
            print("MI KSG: Invalid input shapes")
            return 0.5

        # Add small noise for numerical stability (breaks exact ties)
        x = x + np.random.normal(0, 1e-10, x.shape)
        y = y + np.random.normal(0, 1e-10, y.shape)

        n_samples = x.shape[0]
        xy = np.column_stack([x, y])

        # Chebyshev (p=inf) distance to the k-th neighbour; column 0 is the point itself
        dist_joint = cKDTree(xy).query(xy, k=k + 1, p=np.inf)[0][:, k]

        def _count_within(values, radii):
            """Count, per sample, the values lying within +/- radius (inclusive)."""
            ordered = np.sort(values)
            upper = np.searchsorted(ordered, values + radii, side='right')
            lower = np.searchsorted(ordered, values - radii, side='left')
            return upper - lower

        # Ensure counts are at least 1 so the logs stay finite
        nx = np.maximum(_count_within(x, dist_joint), 1)
        ny = np.maximum(_count_within(y, dist_joint), 1)

        # Calculate MI
        mi = np.mean(np.log(n_samples) + np.log(k) - np.log(nx) - np.log(ny))

        # Normalize and clamp
        return max(0.0, min(1.0, mi / np.log(n_samples)))
    except Exception as e:
        print(f"MI KSG error: {e}")
        return 0.5

def robust_total_correlation(z, num_factors, latent_dim, max_samples=5000):
    """Calculate total correlation between latent factors.

    The score is the mean, over all factor pairs, of a histogram-based
    mutual information between the flattened, min-max normalised activations
    of the two factors, divided by ``log2(bin_count)`` and clamped to
    [0, 0.95]. If no pair could be scored, a Gaussian log-determinant
    approximation is used instead.

    Args:
        z: Latent tensor; viewed as (-1, num_factors, latent_dim).
        num_factors: Number of factors.
        latent_dim: Latent dimensions per factor.
        max_samples: Maximum number of rows used (random subsample above this).

    Returns:
        Scalar tensor on the device of ``z`` (CPU if ``z`` is not a tensor):
        0.0 for a single factor, 0.5 when there are too few samples or an
        error occurs, otherwise the score described above.
    """
    # Resolve the device before the try block so the error handler below can
    # always build its fallback tensor (it used to hit an unbound `device`
    # when `z` was not a tensor).
    device = z.device if isinstance(z, torch.Tensor) else torch.device('cpu')

    try:
        # Check for NaNs or Infs
        if torch.isnan(z).any() or torch.isinf(z).any():
            print("TC Debug: Input contains NaN or Inf values!")
            z = torch.nan_to_num(z, nan=0.5, posinf=1.0, neginf=0.0)

        # Reshape correctly
        z_reshaped = z.view(-1, num_factors, latent_dim)
        n_samples = z_reshaped.shape[0]

        # Skip computation for num_factors=1 (TC is always 0 for single factor)
        if num_factors <= 1:
            print("TC Debug: Single factor, TC is 0")
            return torch.tensor(0.0, device=device)

        if n_samples < num_factors * 10:
            print(f"Total Correlation: Insufficient samples ({n_samples} < {num_factors * 10})")
            return torch.tensor(0.5, device=device)

        if n_samples > max_samples:
            indices = torch.randperm(n_samples, device=device)[:max_samples]
            z_reshaped = z_reshaped[indices]
            n_samples = max_samples

        # Min-max normalization (keeping on original device)
        z_min, _ = z_reshaped.min(dim=0, keepdim=True)
        z_max, _ = z_reshaped.max(dim=0, keepdim=True)
        z_range = z_max - z_min + 1e-6
        z_reshaped = (z_reshaped - z_min) / z_range

        # Add small noise to prevent exact zeros
        z_reshaped = z_reshaped + torch.randn_like(z_reshaped) * 1e-5

        # Move to CPU for histogram computation
        z_cpu = z_reshaped.detach().cpu()

        # Pairwise TC computation
        tc_scores = []
        # Use adaptive bin count based on sample size
        bin_count = min(30, max(10, int(np.sqrt(n_samples / 5))))
        #print(f"TC Debug: Using {bin_count} bins for histograms")

        for i in range(num_factors):
            for j in range(i + 1, num_factors):
                try:
                    z_i = z_cpu[:, i, :].flatten()
                    z_j = z_cpu[:, j, :].flatten()

                    # Calculate histograms on CPU
                    hist_i, _ = np.histogram(z_i.numpy(), bins=bin_count, range=(0.0, 1.0), density=True)
                    hist_j, _ = np.histogram(z_j.numpy(), bins=bin_count, range=(0.0, 1.0), density=True)

                    # Remove zeros for log stability
                    hist_i = hist_i + 1e-10
                    hist_j = hist_j + 1e-10

                    # Normalize
                    hist_i = hist_i / np.sum(hist_i)
                    hist_j = hist_j / np.sum(hist_j)

                    # Calculate entropies
                    entropy_i = -np.sum(hist_i * np.log2(hist_i))
                    entropy_j = -np.sum(hist_j * np.log2(hist_j))

                    # Joint histogram - simple 2D binning
                    joint_hist, _, _ = np.histogram2d(
                        z_i.numpy(), z_j.numpy(),
                        bins=bin_count,
                        range=[[0, 1], [0, 1]]
                    )

                    # Normalize and handle zeros
                    joint_hist = joint_hist / np.sum(joint_hist) + 1e-10

                    # Joint entropy
                    joint_entropy = -np.sum(joint_hist * np.log2(joint_hist))

                    # MI calculation
                    mi = entropy_i + entropy_j - joint_entropy

                    # Normalize to [0, 1]
                    max_entropy = np.log2(bin_count)
                    mi_normalized = mi / max_entropy

                    tc_pair = max(0.0, min(0.95, mi_normalized))
                    tc_scores.append(tc_pair)

                except Exception as e:
                    print(f"TC Debug: Error in pair ({i},{j}): {e}")
                    continue

        if tc_scores:
            tc = np.mean(tc_scores)
            return torch.tensor(tc, device=device)  # Return on the original device

        # Gaussian approximation fallback (only reached when every pair failed)
        #print("TC Debug: Using Gaussian approximation fallback")

        # Compute on GPU if possible
        try:
            z_flat = z_reshaped.reshape(n_samples, -1)

            # Add regularization for numerical stability
            eps = 1e-3 * torch.eye(z_flat.shape[1], device=device)

            # Compute covariance with explicit formula
            z_centered = z_flat - z_flat.mean(dim=0, keepdim=True)
            cov_matrix = (z_centered.T @ z_centered) / (n_samples - 1) + eps

            # Compute log determinant
            log_det_cov = torch.logdet(cov_matrix)

            # Compute marginal variances
            marginal_vars = torch.var(z_reshaped, dim=0, unbiased=True).flatten()
            marginal_vars = torch.clamp(marginal_vars, min=1e-5)
            log_det_marginals = torch.sum(torch.log(marginal_vars))

            # Calculate TC
            tc = 0.5 * (log_det_marginals - log_det_cov)

            # Scale to [0,1]
            tc_scaled = 0.95 * torch.tanh(tc / np.log(n_samples))
            tc_value = max(0.0, min(0.95, tc_scaled.item()))

            return torch.tensor(tc_value, device=device)

        except Exception as e:
            #print(f"TC Debug Gaussian fallback error: {e}")
            return torch.tensor(0.5, device=device)

    except Exception as e:
        print(f"TC error: {e}")
        return torch.tensor(0.5, device=device)


def robust_modularity_score(z, num_factors, latent_dim):
    """Modularity score for latent factors.

    For every pair of factors the flattened activations are standardised,
    clipped to [-10, 10], and their absolute mean product (a correlation
    estimate) is computed. The score is the mean of ``1 - correlation`` over
    all valid pairs, so higher means less linear dependence between factors.

    Args:
        z: Latent tensor; viewed as (-1, num_factors, latent_dim).
        num_factors: Number of factors.
        latent_dim: Latent dimensions per factor.

    Returns:
        Scalar tensor on the device of ``z`` (CPU if ``z`` is not a tensor):
        1.0 for a single factor, 0.5 when nothing could be computed or an
        error occurs, otherwise the score in [0, 1].
    """
    # Resolve the device outside the try block so the error handler can always
    # build its fallback tensor (previously `device` could be unbound there).
    device = z.device if isinstance(z, torch.Tensor) else torch.device('cpu')

    try:
        z_reshaped = z.view(-1, num_factors, latent_dim)

        # Special case for single factor
        if num_factors <= 1:
            return torch.tensor(1.0, device=device)

        # Check for NaNs or Infs
        if torch.isnan(z_reshaped).any() or torch.isinf(z_reshaped).any():
            print("Modularity Debug: Input contains NaN or Inf values!")
            z_reshaped = torch.nan_to_num(z_reshaped, nan=0.5, posinf=1.0, neginf=0.0)

        modularity = 0.0
        count = 0

        for i in range(num_factors):
            for j in range(i + 1, num_factors):
                z_i = z_reshaped[:, i, :].flatten()
                z_j = z_reshaped[:, j, :].flatten()

                # Robust normalization (epsilon guards against zero spread)
                z_i_mean, z_i_std = z_i.mean(), z_i.std() + 1e-8
                z_j_mean, z_j_std = z_j.mean(), z_j.std() + 1e-8

                z_i = (z_i - z_i_mean) / z_i_std
                z_j = (z_j - z_j_mean) / z_j_std

                # Clip to prevent extreme values
                z_i = torch.clamp(z_i, -10.0, 10.0)
                z_j = torch.clamp(z_j, -10.0, 10.0)

                # Compute correlation
                corr = torch.abs(torch.mean(z_i * z_j))

                if not torch.isnan(corr) and not torch.isinf(corr):
                    corr_val = corr.item()
                    # Correlation indicates dependence, so we take (1 - correlation) as modularity
                    modularity += 1.0 - min(1.0, max(0.0, corr_val))
                    count += 1

        if count == 0:
            print("Modularity Debug: No valid correlations computed")
            return torch.tensor(0.5, device=device)

        result = modularity / count

        # Final sanity check
        if result < 0 or result > 1 or np.isnan(result) or np.isinf(result):
            print(f"Modularity Debug: Invalid final result: {result}")
            return torch.tensor(0.5, device=device)

        return torch.tensor(result, device=device)

    except Exception as e:
        print(f"Modularity error: {e}")
        return torch.tensor(0.5, device=device)

def robust_factor_vae_score(z, num_factors, latent_dim, n_samples=2000):
    """Independence score based on cross-factor predictability (ElasticNet).

    Every latent dimension is regressed on the standardised dimensions of all
    *other* factors; the held-out R^2 is turned into ``1 - R^2`` and clamped
    to [0, 1]. The result is the mean over all latent dimensions, so higher
    means the factors are harder to predict from each other.

    Args:
        z: Latent tensor; viewed as (-1, num_factors, latent_dim).
        num_factors: Number of factors.
        latent_dim: Latent dimensions per factor.
        n_samples: Maximum number of rows used (random subsample above this).

    Returns:
        Mean score, or 0.5 for a single factor, an invalid shape, no scores,
        or any error.
    """
    try:
        # NumPy conversion below needs a CPU tensor
        if isinstance(z, torch.Tensor) and z.device.type != 'cpu':
            z = z.detach().cpu()

        # A single factor has nothing to be predicted from
        if num_factors <= 1:
            return 0.5  # Default value for single factor

        latents = z.view(-1, num_factors, latent_dim).detach().numpy()

        if latents.shape[0] < num_factors or latents.shape[2] != latent_dim:
            print("FactorVAE: Invalid shape")
            return 0.5

        row_count = latents.shape[0]
        if row_count > n_samples:
            keep = np.random.choice(row_count, n_samples, replace=False)
            latents = latents[keep]

        # Standardise each (factor, dim) column, then restore the 3-D layout
        layout = latents.shape
        flattened = latents.reshape(layout[0], -1)
        latents = StandardScaler().fit_transform(flattened).reshape(layout)

        independence = []
        for factor_idx in range(num_factors):
            other_ids = [i for i in range(num_factors) if i != factor_idx]
            for dim_idx in range(latent_dim):
                target = latents[:, factor_idx, dim_idx]
                # All dimensions of every other factor serve as predictors
                predictors = latents[:, other_ids, :].reshape(latents.shape[0], -1)

                if predictors.size == 0:
                    continue

                X_train, X_test, y_train, y_test = train_test_split(
                    predictors, target, test_size=0.2, random_state=42
                )

                regressor = ElasticNet(alpha=0.1, l1_ratio=0.5, max_iter=1000)
                regressor.fit(X_train, y_train)

                # Held-out R^2 measures predictability; invert it for independence
                r_squared = regressor.score(X_test, y_test)
                independence.append(max(0.0, min(1.0, 1.0 - r_squared)))

        if not independence:
            return 0.5
        return np.mean(independence)
    except Exception as e:
        print(f"FactorVAE error: {e}")
        return 0.5

def robust_sap_score(z, x_data, num_factors, latent_dim):
    """SAP-style score from mutual information with proxy input features.

    Up to 50 randomly chosen input features stand in for ground-truth
    factors. For each latent dimension the mutual information with every
    non-constant proxy is estimated; the relative gap between the two largest
    values (clamped to [0, 1]) is that dimension's score. The result is the
    mean over all latent dimensions.

    Args:
        z: Latent tensor; viewed as (-1, num_factors, latent_dim).
        x_data: Input tensor; flattened to (rows, features).
        num_factors: Number of factors.
        latent_dim: Latent dimensions per factor.

    Returns:
        Mean score, or 0.5 for a single factor, no scores, or any error.
    """
    try:
        # NumPy conversion below needs CPU tensors
        if isinstance(z, torch.Tensor) and z.device.type != 'cpu':
            z = z.detach().cpu()
        if isinstance(x_data, torch.Tensor) and x_data.device.type != 'cpu':
            x_data = x_data.detach().cpu()

        # Nothing to separate with a single factor
        if num_factors <= 1:
            return 0.5  # Default value for single factor

        latents = z.view(-1, num_factors, latent_dim).numpy()
        features = x_data.numpy().reshape(x_data.shape[0], -1)

        # Random subset of input features acting as proxy factors
        proxy_count = min(50, features.shape[1])
        proxy_columns = np.random.choice(features.shape[1], proxy_count, replace=False)

        gaps = []
        for factor_idx in range(num_factors):
            for dim_idx in range(latent_dim):
                code = latents[:, factor_idx, dim_idx]

                # MI against every proxy whose variance is not negligible
                mi_values = [
                    mutual_info_regression(code.reshape(-1, 1), features[:, col])[0]
                    for col in proxy_columns
                    if features[:, col].var() > 1e-6
                ]

                if len(mi_values) <= 1:
                    continue

                # Relative gap between the best and second-best proxy
                ranked = sorted(mi_values, reverse=True)
                top, runner_up = ranked[0], ranked[1]
                gap = (top - runner_up) / (top + 1e-8)
                gaps.append(max(0.0, min(1.0, gap)))

        if not gaps:
            return 0.5
        return np.mean(gaps)
    except Exception as e:
        print(f"SAP score error: {e}")
        return 0.5


# ===== MAIN EXPERIMENT FUNCTIONS =====

def _exp_build_mlp(input_dim, hidden_dims, output_dim, device):
    """Return a Linear/ReLU stack with the given hidden widths and a Linear head, on `device`."""
    layers = []
    in_features = input_dim
    for width in hidden_dims:
        layers.extend([nn.Linear(in_features, width), nn.ReLU()])
        in_features = width
    layers.append(nn.Linear(in_features, output_dim))
    return nn.Sequential(*layers).to(device)


def _exp_latent_loaders(k_matrix, splits, device, batch_size):
    """Encode the (train, val, test) splits with `k_matrix`; return their loaders and the flattened code width."""
    loaders = []
    widths = []
    for position, (x_split, y_split) in enumerate(splits):
        codes = encode_data(x_split.to(device), k_matrix)
        # The codes are fixed inputs for the downstream model, so no autograd
        # graph through the K matrix needs to be kept alive.
        codes = codes.reshape(codes.shape[0], -1).detach()
        widths.append(codes.shape[1])
        loaders.append(DataLoader(
            TensorDataset(codes, y_split.to(device)),
            batch_size=batch_size,
            shuffle=(position == 0),  # only the training split is shuffled
        ))
    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader, widths[0]


def _exp_fit_early_stopping(model, optimizer, train_loader, val_loader, device,
                            train_loss_fn, val_loss_fn, max_epochs=50, patience=5):
    """Train `model` until the mean validation loss has not improved for `patience` consecutive epochs."""
    best_val_loss = float('inf')
    epochs_without_improvement = 0

    for _ in range(max_epochs):
        # Training
        model.train()
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            loss = train_loss_fn(batch_x, batch_y)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                val_loss += val_loss_fn(batch_x, batch_y).item()
        val_loss /= len(val_loader)

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break


def _exp_run_k_config(method_name, method_func, num_factors, latent_dim, splits,
                      device, batch_size, is_classification, output_dim):
    """Initialise, refine and score one K matrix, then train/evaluate a task model on its codes."""
    (train_x, _), (val_x, _), _ = splits

    # Initialize K matrix
    start_time = time.time()
    k_matrix = method_func(train_x.to(device), num_factors, latent_dim, device)
    init_time = time.time() - start_time

    # Refine K matrix
    start_time = time.time()
    k_refined = refine_k_matrix(train_x.to(device), k_matrix, num_factors, latent_dim, device, epochs=100)
    refine_time = time.time() - start_time

    # Evaluate K matrix on the validation split
    metrics = evaluate_k_matrix(val_x.to(device), k_refined, num_factors, latent_dim, device)

    # Equal-weight combination; 'mi_ksg' and 'total_correlation' enter as (1 - value)
    combined_score = (
        (1.0 - metrics['mi_ksg']) * 0.2 +
        metrics['modularity'] * 0.2 +
        (1.0 - metrics['total_correlation']) * 0.2 +
        metrics['factor_vae_score'] * 0.2 +
        metrics['sap_score'] * 0.2
    )

    # Train and evaluate the task model on the latent codes
    latent_train_loader, latent_val_loader, latent_test_loader, input_dim = _exp_latent_loaders(
        k_refined, splits, device, batch_size
    )
    model = _exp_build_mlp(input_dim, (128, 64), output_dim, device)
    model = train_model(model, latent_train_loader, latent_val_loader, device, is_classification)
    task_metrics = evaluate_model_metrics(model, latent_test_loader, device, is_classification)
    performance = task_metrics['accuracy'] if is_classification else task_metrics['mse']

    return {
        'method': method_name,
        'num_factors': num_factors,
        'latent_dim': latent_dim,
        'metrics': metrics,
        'combined_score': combined_score,
        'init_time': init_time,
        'refine_time': refine_time,
        'teacher_performance': performance,
        'task_metrics': task_metrics,
        'k_matrix': k_refined.detach().cpu()
    }


def _exp_build_latent_baseline(method_name, input_dim, hidden_dim, total_latent_dim,
                               param_dict, is_classification, num_classes):
    """Construct a VIB / BetaVAE / SparseAutoencoder baseline, swapping in a classification predictor if needed."""
    if method_name == 'VIB':
        model = VIB(input_dim, hidden_dim, total_latent_dim, beta=param_dict.get('beta', 1.0))
    elif method_name == 'BetaVAE':
        model = BetaVAE(input_dim, hidden_dim, total_latent_dim, beta=param_dict.get('beta', 1.0))
    elif method_name == 'SparseAutoencoder':
        model = SparseAutoencoder(input_dim, hidden_dim, total_latent_dim,
                                  sparsity_weight=param_dict.get('sparsity_weight', 0.01))
    else:
        raise ValueError(f"No model builder for latent baseline '{method_name}'")

    if is_classification:
        # Adjust predictor for classification
        model.predictor = nn.Sequential(
            nn.Linear(total_latent_dim, hidden_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim // 2, num_classes)
        )
    return model


def _exp_run_distillation(best_k_method, best_k_config, kd_grid, splits, device,
                          batch_size, is_classification, output_dim):
    """Train a teacher on the best K matrix's codes and distil it into a fresh student per (temperature, alpha)."""
    k_matrix = best_k_config['k_matrix'].to(device)
    num_factors = best_k_config['num_factors']
    latent_dim = best_k_config['latent_dim']

    latent_train_loader, latent_val_loader, latent_test_loader, input_dim = _exp_latent_loaders(
        k_matrix, splits, device, batch_size
    )

    # Teacher model
    teacher_model = _exp_build_mlp(input_dim, (128, 64), output_dim, device)
    teacher_model = train_model(teacher_model, latent_train_loader, latent_val_loader, device, is_classification)
    # Evaluation kept from the original flow; its result is not recorded anywhere.
    teacher_metrics = evaluate_model_metrics(teacher_model, latent_test_loader, device, is_classification)

    hard_label_loss = F.cross_entropy if is_classification else F.mse_loss
    metric_key = 'accuracy' if is_classification else 'mse'
    metric_label = 'Accuracy' if is_classification else 'MSE'

    kd_results = []
    for temp in kd_grid['temperature']:
        for alpha in kd_grid['alpha']:
            config_info = f"[KD, Temp={temp}, Alpha={alpha}]"
            print(f"Processing {config_info}")

            try:
                # A new student (smaller than the teacher) for every setting:
                # reusing one student would let each configuration start from
                # weights trained under the previous ones, so the recorded
                # results would not be independent.
                student_model = _exp_build_mlp(input_dim, (64,), output_dim, device)
                optimizer = optim.Adam(student_model.parameters(), lr=0.001)
                teacher_model.eval()

                def distill_loss(batch_z, batch_y, student=student_model, temperature=temp, weight=alpha):
                    with torch.no_grad():
                        teacher_outputs = teacher_model(batch_z)
                    student_outputs = student(batch_z)
                    return knowledge_distillation_loss(
                        student_outputs, teacher_outputs, batch_y,
                        temperature=temperature, alpha=weight
                    )

                def student_val_loss(batch_z, batch_y, student=student_model):
                    return hard_label_loss(student(batch_z), batch_y)

                _exp_fit_early_stopping(
                    student_model, optimizer, latent_train_loader, latent_val_loader, device,
                    train_loss_fn=distill_loss, val_loss_fn=student_val_loss
                )

                # Evaluate student model
                student_metrics = evaluate_model_metrics(student_model, latent_test_loader, device, is_classification)

                kd_results.append({
                    'method': 'KnowledgeDistillation',
                    'temperature': temp,
                    'alpha': alpha,
                    'metrics': student_metrics,
                    'performance': student_metrics[metric_key],
                    'teacher_method': best_k_method,
                    'num_factors': num_factors,
                    'latent_dim': latent_dim
                })

                print(f"{config_info} - {metric_label}: {student_metrics[metric_key]:.4f}")

            except Exception as e:
                print(f"Error processing {config_info}: {e}")

    return kd_results


def run_comparison_experiment(dataset_names=None, output_dir='results'):
    """
    Run the Universal K Matrix comparison experiment.

    For every dataset the data is split 70/15/15 into train/val/test and three
    groups of models are evaluated:

    1. Universal K methods (Clustered, PCA, FactorAnalysis) over
       factors {3, 5} x latent dims {8, 16}: the K matrix is initialised,
       refined, scored, and an MLP is trained on the encoded data.
    2. Baselines (VIB, BetaVAE, SparseAutoencoder, DropoutRegularizedModel)
       over their hyperparameter grids.
    3. Knowledge distillation: a teacher is trained on codes from the K result
       with the highest combined score, and a freshly initialised student is
       distilled for every (temperature, alpha) setting.

    A failure inside a single configuration is printed and that configuration
    is skipped. Metrics are written with ``save_metrics_to_csv`` after each
    dataset and once more for all datasets at the end.

    Args:
        dataset_names: List of datasets to test. When empty or None, every
            entry of ``get_available_datasets()`` flagged ``'available'`` is used.
        output_dir: Directory to save results (created if missing).

    Returns:
        Dictionary mapping each dataset name to a dict with the keys
        'k_methods' and 'baseline_methods' (method name -> list of result
        dicts) and, when a best K result exists, 'KnowledgeDistillation'
        (list of result dicts).
    """
    import itertools  # local import: only this function needs it

    print("Starting Universal K Matrix Comparison Experiment...")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Set random seeds for reproducibility
    np.random.seed(42)
    torch.manual_seed(42)

    # Set up device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Get available datasets
    available_datasets = get_available_datasets()
    if not dataset_names:
        dataset_names = [name for name, info in available_datasets.items() if info['available']]

    # Define methods to test
    k_methods = [
        ('Clustered', create_clustered_k_matrix),
        ('PCA', create_pca_k_matrix),
        ('FactorAnalysis', create_factor_analysis_k_matrix)
    ]

    # Define baseline methods; the first three expose a latent space and are
    # therefore swept over the same factor/dimension grid as the K methods
    baseline_methods = [
        'VIB',
        'BetaVAE',
        'SparseAutoencoder',
        'DropoutRegularizedModel'
    ]
    latent_baselines = ('VIB', 'BetaVAE', 'SparseAutoencoder')

    # Define hyperparameters
    hyperparams = {
        'VIB': {'beta': [0.1, 1.0, 10.0]},
        'BetaVAE': {'beta': [0.1, 1.0, 10.0]},
        'SparseAutoencoder': {'sparsity_weight': [0.01, 0.1, 1.0]},
        'DropoutRegularizedModel': {'dropout_rate': [0.2, 0.5], 'weight_decay': [1e-4, 1e-5]},
        'KnowledgeDistillation': {'temperature': [2, 4], 'alpha': [0.3, 0.5, 0.7]}
    }

    # Factors and dimensions to try
    factors_to_try = [3, 5]
    dims_to_try = [8, 16]

    batch_size = 128
    hidden_dim = 256

    # Store all results
    results = {}

    # Process each dataset
    for dataset_name in dataset_names:
        print(f"\nProcessing dataset: {dataset_name}")

        # Load or create dataset
        x_data, y_data, is_classification = load_or_create_dataset(dataset_name, available_datasets)
        print(f"Dataset shape: {x_data.shape}, Classification: {is_classification}")

        # Split into train/val/test (70% / 15% / remainder) with a random permutation
        train_size = int(0.7 * len(x_data))
        val_size = int(0.15 * len(x_data))

        indices = torch.randperm(len(x_data))
        train_indices = indices[:train_size]
        val_indices = indices[train_size:train_size + val_size]
        test_indices = indices[train_size + val_size:]

        train_x, train_y = x_data[train_indices], y_data[train_indices]
        val_x, val_y = x_data[val_indices], y_data[val_indices]
        test_x, test_y = x_data[test_indices], y_data[test_indices]
        splits = ((train_x, train_y), (val_x, val_y), (test_x, test_y))

        train_loader = DataLoader(TensorDataset(train_x, train_y), batch_size=batch_size,
                                  shuffle=True, num_workers=4, pin_memory=True)
        val_loader = DataLoader(TensorDataset(val_x, val_y), batch_size=batch_size)
        test_loader = DataLoader(TensorDataset(test_x, test_y), batch_size=batch_size)

        # Task-dependent settings shared by every model of this dataset
        num_classes = len(torch.unique(train_y)) if is_classification else None
        output_dim = num_classes if is_classification else 1
        metric_key = 'accuracy' if is_classification else 'mse'
        metric_label = 'Accuracy' if is_classification else 'MSE'

        # Initialize results for this dataset
        dataset_results = {
            'k_methods': {},
            'baseline_methods': {}
        }

        # 1. Universal K Matrix Methods
        for method_name, method_func in k_methods:
            print(f"\nEvaluating Universal K method: {method_name}")
            method_results = []

            for num_factors, latent_dim in itertools.product(factors_to_try, dims_to_try):
                config_info = f"[{method_name}, Factors={num_factors}, Dims={latent_dim}]"
                print(f"Processing {config_info}")

                try:
                    result = _exp_run_k_config(
                        method_name, method_func, num_factors, latent_dim, splits,
                        device, batch_size, is_classification, output_dim
                    )
                    method_results.append(result)

                    # Log results
                    print(f"{config_info} - Combined score: {result['combined_score']:.4f}")
                    print(f"{config_info} - {metric_label}: {result['teacher_performance']:.4f}")

                except Exception as e:
                    print(f"Error processing {config_info}: {e}")

            # Store results for this method
            dataset_results['k_methods'][method_name] = method_results

        # 2. Baseline Methods
        for method_name in baseline_methods:
            print(f"\nEvaluating baseline method: {method_name}")
            method_results = []

            # Hyperparameter grid; product() over zero value lists yields one
            # empty combination, i.e. a single run with default parameters
            method_hyperparams = hyperparams.get(method_name, {})
            hyperparam_keys = list(method_hyperparams.keys())

            for params in itertools.product(*method_hyperparams.values()):
                param_dict = dict(zip(hyperparam_keys, params))

                if method_name in latent_baselines:
                    # For models with latent space, test with different latent dimensions
                    for num_factors, latent_dim in itertools.product(factors_to_try, dims_to_try):
                        total_latent_dim = num_factors * latent_dim

                        config_info = f"[{method_name}, Factors={num_factors}, Dims={latent_dim}, Params={param_dict}]"
                        print(f"Processing {config_info}")

                        try:
                            model = _exp_build_latent_baseline(
                                method_name, x_data.shape[1], hidden_dim, total_latent_dim,
                                param_dict, is_classification, num_classes
                            ).to(device)

                            # Train model
                            model = train_model(model, train_loader, val_loader, device, is_classification)

                            # Evaluate model
                            metrics = evaluate_baseline_method(method_name, model, test_loader, device,
                                                               is_classification, None, num_factors, latent_dim)
                            metrics['performance'] = metrics['task_metrics'][metric_key]

                            method_results.append({
                                'method': method_name,
                                'params': param_dict,
                                'num_factors': num_factors,
                                'latent_dim': latent_dim,
                                'metrics': metrics,
                                'performance': metrics['performance']
                            })

                            print(f"{config_info} - {metric_label}: {metrics['performance']:.4f}")

                        except Exception as e:
                            print(f"Error processing {config_info}: {e}")
                else:
                    # For models without latent space
                    config_info = f"[{method_name}, Params={param_dict}]"
                    print(f"Processing {config_info}")

                    try:
                        if method_name != 'DropoutRegularizedModel':
                            # Fail this configuration explicitly rather than
                            # training whatever model an earlier one left behind
                            raise ValueError(f"No model builder for baseline '{method_name}'")

                        model = DropoutRegularizedModel(
                            x_data.shape[1], hidden_dim, output_dim,
                            dropout_rate=param_dict.get('dropout_rate', 0.5)
                        ).to(device)

                        # Create optimizer with weight decay
                        optimizer = optim.Adam(model.parameters(), lr=0.001,
                                               weight_decay=param_dict.get('weight_decay', 0.0))
                        criterion = nn.CrossEntropyLoss() if is_classification else nn.MSELoss()

                        def supervised_loss(batch_x, batch_y, net=model, loss_fn=criterion):
                            return loss_fn(net(batch_x), batch_y)

                        # Up to 50 epochs, stopping after 5 epochs without improvement
                        _exp_fit_early_stopping(
                            model, optimizer, train_loader, val_loader, device,
                            train_loss_fn=supervised_loss, val_loss_fn=supervised_loss
                        )

                        # Evaluate model
                        task_metrics = evaluate_model_metrics(model, test_loader, device, is_classification)

                        method_results.append({
                            'method': method_name,
                            'params': param_dict,
                            'metrics': {'task_metrics': task_metrics},
                            'performance': task_metrics[metric_key]
                        })

                        print(f"{config_info} - {metric_label}: {task_metrics[metric_key]:.4f}")

                    except Exception as e:
                        print(f"Error processing {config_info}: {e}")

            # Store results for this method
            dataset_results['baseline_methods'][method_name] = method_results

        # 3. Knowledge Distillation (if applicable)
        if dataset_results['k_methods']:
            print("\nEvaluating Knowledge Distillation")

            # Find the K result with the highest combined score
            best_k_method = None
            best_k_score = -float('inf')
            best_k_config = None

            for k_name, k_results in dataset_results['k_methods'].items():
                for k_result in k_results:
                    if k_result['combined_score'] > best_k_score:
                        best_k_score = k_result['combined_score']
                        best_k_method = k_name
                        best_k_config = k_result

            if best_k_config:
                print(f"Using best K matrix method for distillation: {best_k_method}")
                dataset_results['KnowledgeDistillation'] = _exp_run_distillation(
                    best_k_method, best_k_config, hyperparams['KnowledgeDistillation'],
                    splits, device, batch_size, is_classification, output_dim
                )

        # Store results for this dataset
        results[dataset_name] = dataset_results

        # Save intermediate results to CSV
        save_metrics_to_csv({dataset_name: dataset_results}, output_dir)

    # Save final results to CSV
    save_metrics_to_csv(results, output_dir)

    return results

def _cmp_infer_is_classification(k_methods):
    """Decide from stored K-method results whether the task is classification.

    The first result (across all methods) that carries a 'teacher_performance'
    is inspected. If its 'task_metrics' dict contains exactly one of the keys
    'accuracy' / 'mse', that key decides. Otherwise the value-range heuristic
    is used: a performance within [0, 1] is taken to be an accuracy.
    """
    for method_results in k_methods.values():
        for result in method_results:
            perf = result.get('teacher_performance')
            if perf is None:
                continue
            task_metrics = result.get('task_metrics')
            if isinstance(task_metrics, dict):
                has_accuracy = 'accuracy' in task_metrics
                has_mse = 'mse' in task_metrics
                if has_accuracy != has_mse:
                    return has_accuracy
            # Fallback heuristic; ambiguous for regression errors <= 1
            return 0 <= perf <= 1
    return False


def _cmp_select_best(named_results, perf_key, higher_is_better):
    """Return (name, result, performance) of the best entry among (name, result) pairs."""
    best_perf = float('-inf') if higher_is_better else float('inf')
    best_name = None
    best_result = None
    for name, result in named_results:
        perf = result.get(perf_key, None)
        if perf is None:
            continue
        improved = perf > best_perf if higher_is_better else perf < best_perf
        if improved:
            best_perf = perf
            best_name = name
            best_result = result
    return best_name, best_result, best_perf


def compare_universal_k_with_baselines(results):
    """
    Compare performance of Universal K approach with baseline methods.

    For each dataset the best Universal K result ('teacher_performance'), the
    best baseline result ('performance') and the best knowledge-distillation
    result ('performance') are selected. Higher is better for classification
    (accuracy), lower is better for regression (MSE). The task type is taken
    from the stored 'task_metrics' of the K-method results when they identify
    it, and otherwise guessed from the range of the stored performance.

    Args:
        results: Dictionary of results, keyed by dataset name, with the
            optional per-dataset keys 'k_methods', 'baseline_methods' and
            'KnowledgeDistillation'.

    Returns:
        DataFrame with comparisons, one row per dataset. Difference columns
        are oriented so that a positive value means the Universal K result is
        better.
    """
    comparison_rows = []

    for dataset_name, dataset_results in results.items():
        k_methods = dataset_results.get('k_methods') or {}
        baseline_methods = dataset_results.get('baseline_methods') or {}
        kd_results = dataset_results.get('KnowledgeDistillation') or []

        # Determine if classification or regression task
        is_classification = _cmp_infer_is_classification(k_methods)
        task_metric = 'accuracy' if is_classification else 'mse'

        # Get best Universal K result
        best_k_method, best_k_config, best_k_performance = _cmp_select_best(
            ((name, result) for name, method_results in k_methods.items() for result in method_results),
            'teacher_performance', is_classification
        )

        # Get best baseline result
        best_baseline_method, best_baseline_config, best_baseline_performance = _cmp_select_best(
            ((name, result) for name, method_results in baseline_methods.items() for result in method_results),
            'performance', is_classification
        )

        # Get best distillation result
        _, best_kd_config, best_kd_performance = _cmp_select_best(
            (('KnowledgeDistillation', result) for result in kd_results),
            'performance', is_classification
        )

        # Add row to comparison (insertion order determines the column order)
        row = {
            'dataset': dataset_name,
            'is_classification': is_classification,
            'metric': task_metric
        }

        # Add Universal K details
        if best_k_method:
            row['best_k_method'] = best_k_method
            row['best_k_num_factors'] = best_k_config['num_factors']
            row['best_k_latent_dim'] = best_k_config['latent_dim']
            row['best_k_performance'] = best_k_performance
            row['best_k_combined_score'] = best_k_config['combined_score']

        # Add baseline details
        if best_baseline_method:
            row['best_baseline_method'] = best_baseline_method

            # Add hyperparameters if available
            if 'params' in best_baseline_config:
                for param_name, param_value in best_baseline_config['params'].items():
                    row[f'best_baseline_{param_name}'] = param_value

            row['best_baseline_performance'] = best_baseline_performance

        # Add distillation details
        if best_kd_config:
            row['best_kd_temperature'] = best_kd_config['temperature']
            row['best_kd_alpha'] = best_kd_config['alpha']
            row['best_kd_performance'] = best_kd_performance

        # Calculate performance differences
        if best_k_method and best_baseline_method:
            if is_classification:
                row['k_vs_baseline_diff'] = best_k_performance - best_baseline_performance
                row['k_better_than_baseline'] = best_k_performance > best_baseline_performance
            else:
                # Lower is better for regression
                row['k_vs_baseline_diff'] = best_baseline_performance - best_k_performance
                row['k_better_than_baseline'] = best_k_performance < best_baseline_performance

        if best_k_method and best_kd_config:
            if is_classification:
                row['k_vs_kd_diff'] = best_k_performance - best_kd_performance
                row['k_better_than_kd'] = best_k_performance > best_kd_performance
            else:
                # Lower is better for regression
                row['k_vs_kd_diff'] = best_kd_performance - best_k_performance
                row['k_better_than_kd'] = best_k_performance < best_kd_performance

        comparison_rows.append(row)

    # Create DataFrame
    comparison_df = pd.DataFrame(comparison_rows)
    return comparison_df

def create_random_k_matrix(x_data, num_factors, latent_dim, device):
    """Build a stack of random K matrices with orthonormal columns.

    Each factor starts as a Gaussian matrix, has its components along the
    previously generated factors projected out, and is then orthonormalized
    with a QR decomposition.

    Args:
        x_data: 2-D data whose second dimension gives the feature count.
        num_factors: Number of K matrices to generate.
        latent_dim: Number of columns per K matrix.
        device: Device on which the tensors are created.

    Returns:
        Tensor of shape (num_factors, n_features, latent_dim). If anything in
        the construction raises, an unorthogonalized Gaussian tensor of that
        shape is returned instead.
    """
    try:
        feature_count = x_data.shape[1]
        factors = []

        for _ in range(num_factors):
            candidate = torch.randn(feature_count, latent_dim, device=device)

            # Remove the part of the candidate lying in each earlier factor's span.
            for earlier in factors:
                overlap = torch.mm(earlier.t(), candidate)
                candidate = candidate - torch.mm(earlier, overlap)

            if torch.linalg.matrix_rank(candidate) > 0:
                # Orthonormalize the columns; only the Q factor is needed.
                q_factor, _ = torch.linalg.qr(candidate)
                candidate = q_factor[:, :latent_dim]
            else:
                # Projection wiped the candidate out: fall back to a fresh
                # random matrix with unit-norm columns.
                candidate = torch.randn(feature_count, latent_dim, device=device)
                column_norms = torch.norm(candidate, dim=0, keepdim=True)
                candidate = candidate / (column_norms + 1e-8)

            factors.append(candidate)

        return torch.stack(factors)
    except Exception as e:
        print(f"Random initialization error: {e}")
        # Last resort: plain Gaussian tensor of the requested shape.
        return torch.randn(num_factors, x_data.shape[1], latent_dim, device=device)

def sparsity_score(k_matrix):
    """Return a clamped sparsity score for ``k_matrix``.

    The score is ``1 - L1 / (L2 * sqrt(numel))`` computed on the matrix after
    a tiny random perturbation, then clamped to the range [0.1, 0.9]. Because
    of the perturbation, repeated calls can differ very slightly.

    Args:
        k_matrix: Tensor to score.

    Returns:
        0-dim tensor on the input's device. A tensor holding 0.5 is returned
        when the score is NaN/Inf or when the computation raises (for example
        when ``k_matrix`` is not a tensor); in the latter case the tensor is
        placed on the CPU if the input's device cannot be determined.
    """
    # Resolve the device before entering the try block so the except handler
    # can always refer to it, even when the input is not a tensor at all.
    device = getattr(k_matrix, 'device', None)

    try:
        # Add small noise for numerical stability
        k_matrix = k_matrix + 1e-6 * torch.randn_like(k_matrix)

        l1_norm = torch.sum(torch.abs(k_matrix))
        l2_norm = torch.sqrt(torch.sum(k_matrix ** 2) + 1e-6)
        n_elements = float(torch.numel(k_matrix))

        # Create tensor on the same device
        n_elements_tensor = torch.tensor(n_elements, dtype=torch.float, device=device)

        sparsity = 1.0 - (l1_norm / (l2_norm * torch.sqrt(n_elements_tensor) + 1e-6))

        # Create min/max tensors on the same device
        min_val = torch.tensor(0.1, device=device)
        max_val = torch.tensor(0.9, device=device)
        sparsity_val = torch.clamp(sparsity, min_val, max_val)

        if torch.isnan(sparsity_val) or torch.isinf(sparsity_val):
            print("Sparsity is NaN or Inf")
            return torch.tensor(0.5, device=device)

        return sparsity_val

    except Exception as e:
        print(f"Sparsity error: {e}")
        # Ensure return value is on the same device when one is known
        return torch.tensor(0.5, device=device if device is not None else 'cpu')



# ===== MODEL IMPLEMENTATIONS =====

class VIB(nn.Module):
    """Variational Information Bottleneck autoencoder with a task head.

    An MLP encoder feeds two linear layers that give the mean and
    log-variance of the latent code. A sampled code is passed to a decoder
    (reconstructing the input) and to a predictor that emits one value per
    sample.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of the widest hidden layer; the narrower hidden
            layers use ``hidden_dim // 2``.
        latent_dim: Size of the latent code.
        beta: Weight of the KL term in ``loss_function``.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, beta=1.0):
        super().__init__()

        narrow_dim = hidden_dim // 2
        self.beta = beta

        # Encoder: input -> hidden -> hidden // 2
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, narrow_dim),
            nn.LeakyReLU(0.2),
        )

        # Heads producing the latent mean and log-variance
        self.mu = nn.Linear(narrow_dim, latent_dim)
        self.log_var = nn.Linear(narrow_dim, latent_dim)

        # Decoder: mirror image of the encoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, narrow_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(narrow_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
        )

        # Task predictor with a single output unit
        self.predictor = nn.Sequential(
            nn.Linear(latent_dim, narrow_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(narrow_dim, 1),
        )

    def encode(self, x):
        """Return ``(mu, log_var)`` of the latent distribution for ``x``."""
        hidden = self.encoder(x)
        return self.mu(hidden), self.log_var(hidden)

    def reparameterize(self, mu, log_var):
        """Draw a latent sample as ``mu + eps * std`` with Gaussian ``eps``."""
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Reconstruct the input from latent code ``z``."""
        return self.decoder(z)

    def predict(self, z):
        """Run the task head on latent code ``z``."""
        return self.predictor(z)

    def forward(self, x):
        """Return ``(x_recon, mu, log_var, z, pred)`` for input ``x``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var, z, self.predict(z)

    def loss_function(self, x, x_recon, mu, log_var, y, pred):
        """Return ``(total_loss, recon_loss, kl_loss, task_loss)``.

        The reconstruction and task terms are mean-squared errors. The KL
        term is summed over every element of the batch (not averaged) and is
        weighted by ``self.beta`` in the total.
        """
        recon_loss = F.mse_loss(x_recon, x)

        kl_elements = 1 + log_var - mu.pow(2) - log_var.exp()
        kl_loss = -0.5 * kl_elements.sum()

        task_loss = F.mse_loss(pred, y)

        total_loss = task_loss + self.beta * kl_loss + recon_loss
        return total_loss, recon_loss, kl_loss, task_loss


class BetaVAE(nn.Module):
    """Beta-VAE with an attached task head.

    The encoder outputs the mean and log-variance of the latent code; a
    sampled code is decoded back to the input space and also passed to a
    predictor with a single output unit.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of the widest hidden layer; the narrower hidden
            layers use ``hidden_dim // 2``.
        latent_dim: Size of the latent code.
        beta: Weight of the KL term in ``loss_function``.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, beta=1.0):
        super().__init__()

        bottleneck_width = hidden_dim // 2
        self.beta = beta

        # Encoder stack
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, bottleneck_width),
            nn.LeakyReLU(0.2),
        )

        # Latent mean / log-variance heads
        self.mu = nn.Linear(bottleneck_width, latent_dim)
        self.log_var = nn.Linear(bottleneck_width, latent_dim)

        # Decoder stack
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, bottleneck_width),
            nn.LeakyReLU(0.2),
            nn.Linear(bottleneck_width, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
        )

        # Task predictor with one output unit
        self.predictor = nn.Sequential(
            nn.Linear(latent_dim, bottleneck_width),
            nn.LeakyReLU(0.2),
            nn.Linear(bottleneck_width, 1),
        )

    def encode(self, x):
        """Return the latent ``(mu, log_var)`` pair for ``x``."""
        features = self.encoder(x)
        return self.mu(features), self.log_var(features)

    def reparameterize(self, mu, log_var):
        """Sample the latent code via the reparameterization trick."""
        std = (0.5 * log_var).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent code ``z`` back to input space."""
        return self.decoder(z)

    def predict(self, z):
        """Apply the task head to latent code ``z``."""
        return self.predictor(z)

    def forward(self, x):
        """Return ``(x_recon, mu, log_var, z, pred)`` for input ``x``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var, z, self.predict(z)

    def loss_function(self, x, x_recon, mu, log_var, y, pred):
        """Return ``(total_loss, recon_loss, kl_loss, task_loss)``.

        Reconstruction and task losses are mean-squared errors. The KL term
        is summed over all batch elements and enters the total multiplied by
        ``self.beta``; the returned ``kl_loss`` itself is unweighted.
        """
        recon_loss = F.mse_loss(x_recon, x)

        kl_elements = 1 + log_var - mu.pow(2) - log_var.exp()
        kl_loss = -0.5 * kl_elements.sum()

        task_loss = F.mse_loss(pred, y)

        total_loss = task_loss + self.beta * kl_loss + recon_loss
        return total_loss, recon_loss, kl_loss, task_loss


class DropoutRegularizedModel(nn.Module):
    """Feed-forward network with dropout after each hidden activation.

    The module itself only applies dropout; it contains no weight-decay
    term.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of the first hidden layer; the second hidden layer
            uses ``hidden_dim // 2``.
        output_dim: Number of output units.
        dropout_rate: Dropout probability used in both dropout layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate=0.5):
        super().__init__()

        second_width = hidden_dim // 2
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, second_width),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_rate),
            nn.Linear(second_width, output_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for ``x``."""
        return self.model(x)


class SparseAutoencoder(nn.Module):
    """Autoencoder with an L1 penalty on the latent code and a task head.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of the encoder/decoder hidden layer; the predictor
            hidden layer uses ``hidden_dim // 2``.
        latent_dim: Size of the latent code.
        sparsity_weight: Weight of the L1 term in ``loss_function``.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, sparsity_weight=0.01):
        super().__init__()

        self.sparsity_weight = sparsity_weight

        # Encoder: input -> hidden -> latent
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, latent_dim),
        )

        # Decoder: latent -> hidden -> input
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
        )

        # Task predictor with a single output unit
        predictor_width = hidden_dim // 2
        self.predictor = nn.Sequential(
            nn.Linear(latent_dim, predictor_width),
            nn.LeakyReLU(0.2),
            nn.Linear(predictor_width, 1),
        )

    def forward(self, x):
        """Return ``(x_recon, z, pred)`` for input ``x``."""
        z = self.encoder(x)
        return self.decoder(z), z, self.predictor(z)

    def loss_function(self, x, x_recon, z, y, pred):
        """Return ``(total_loss, recon_loss, sparsity_loss, task_loss)``.

        ``sparsity_loss`` is the mean absolute value of ``z``; it is scaled
        by ``self.sparsity_weight`` only inside ``total_loss``.
        """
        recon_loss = F.mse_loss(x_recon, x)
        sparsity_loss = z.abs().mean()
        task_loss = F.mse_loss(pred, y)

        total_loss = task_loss + recon_loss + self.sparsity_weight * sparsity_loss
        return total_loss, recon_loss, sparsity_loss, task_loss


# ===== DISTILLATION METHODS =====

def knowledge_distillation_loss(student_logits, teacher_logits, y_true, temperature=4.0, alpha=0.5):
    """
    Blend a teacher-matching loss with the ordinary task loss.

    The task type is inferred from ``y_true``: a 1-D target is treated as
    class indices (classification), anything else as regression.

    Args:
        student_logits: Outputs of the student model
        teacher_logits: Outputs of the teacher model
        y_true: Ground truth targets
        temperature: Softmax temperature (classification branch only)
        alpha: Weight of the distillation term; the task term gets 1 - alpha

    Returns:
        ``alpha * distillation_loss + (1 - alpha) * task_loss``
    """
    if y_true.ndim != 1:
        # Regression: the student is pulled towards both the targets and the
        # teacher's outputs with plain MSE.
        hard_loss = F.mse_loss(student_logits, y_true)
        soft_loss = F.mse_loss(student_logits, teacher_logits)
        return alpha * soft_loss + (1 - alpha) * hard_loss

    # Classification: cross-entropy against the labels ...
    hard_loss = F.cross_entropy(student_logits, y_true)

    # ... plus cross-entropy against the teacher's temperature-softened
    # distribution, rescaled by temperature squared.
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    per_sample = -(teacher_probs * student_log_probs).sum(dim=1)
    soft_loss = per_sample.mean() * (temperature ** 2)

    return alpha * soft_loss + (1 - alpha) * hard_loss

def attention_transfer_loss(student_features, teacher_features, y_true, model_output, beta=0.5):
    """
    Task loss plus a penalty on mismatched student/teacher attention maps.

    An attention map is the mean over dimension 1 of the squared features,
    flattened per sample and L2-normalized. Feature pairs are matched
    positionally with ``zip``.

    Args:
        student_features: List of feature maps from student model
        teacher_features: List of feature maps from teacher model
        y_true: Ground truth targets; 1-D means class indices, otherwise regression
        model_output: Output from the student model
        beta: Weight of the attention term

    Returns:
        ``task_loss + beta * attention_loss``
    """
    def _attention_map(features):
        # Squared activations averaged over dim 1, one flat unit vector per sample.
        energy = features.pow(2).mean(1)
        return F.normalize(energy.view(features.size(0), -1), p=2, dim=1)

    task_is_classification = y_true.ndim == 1
    if task_is_classification:
        task_loss = F.cross_entropy(model_output, y_true)
    else:
        task_loss = F.mse_loss(model_output, y_true)

    attention_loss = 0.0
    for student_layer, teacher_layer in zip(student_features, teacher_features):
        student_map = _attention_map(student_layer)
        teacher_map = _attention_map(teacher_layer)
        attention_loss = attention_loss + F.mse_loss(student_map, teacher_map)

    return task_loss + beta * attention_loss


# ===== EVALUATION FUNCTIONS =====

def evaluate_model_metrics(model, dataloader, device, is_classification=True):
    """
    Evaluate a model and return comprehensive metrics.

    Args:
        model: The trained model to evaluate
        dataloader: DataLoader yielding ``(x, y)`` batches
        device: Device to run evaluation on
        is_classification: Whether this is a classification task

    Returns:
        Dictionary of metrics. Classification: ``accuracy``, ``precision``,
        ``recall``, ``f1_score`` (macro-averaged) and ``auc``. Regression:
        ``mse``, ``mae``, ``rmse`` and ``r2_score``.
    """
    model.eval()
    all_targets = []
    all_predictions = []
    all_probs = []  # For AUC calculation in classification

    with torch.no_grad():
        for batch in dataloader:
            x, y = batch
            x, y = x.to(device), y.to(device)

            # Forward pass depends on model type
            if isinstance(model, (VIB, BetaVAE)):
                _, _, _, _, pred = model(x)
            elif isinstance(model, SparseAutoencoder):
                _, _, pred = model(x)
            else:  # Standard models like DropoutRegularizedModel
                pred = model(x)

            # Store predictions and targets
            if is_classification:
                if pred.dim() == 1:
                    # Treat a flat output as one logit per sample.
                    pred = pred.unsqueeze(1)

                if pred.size(1) > 1:
                    # Multi-logit output: softmax probabilities, argmax label.
                    probs = F.softmax(pred, dim=1)
                    _, predicted = torch.max(pred, 1)
                    targets = y
                else:
                    # Single-logit output: threshold the sigmoid probability
                    # (not the raw logit) and flatten so that predictions and
                    # targets compare element-wise instead of broadcasting.
                    probs = torch.sigmoid(pred)
                    predicted = (probs > 0.5).long().reshape(-1)
                    targets = y.reshape(-1)

                all_probs.append(probs.cpu())
                all_predictions.append(predicted.cpu())
                all_targets.append(targets.cpu())
            else:
                all_predictions.append(pred.cpu())
                all_targets.append(y.cpu())

    # Concatenate results
    if is_classification:
        all_targets = torch.cat(all_targets).numpy()
        all_predictions = torch.cat(all_predictions).numpy()
        all_probs = torch.cat(all_probs).numpy()

        # Calculate classification metrics
        accuracy = (all_predictions == all_targets).mean()

        # For multi-class, calculate macro averages
        precision = precision_score(all_targets, all_predictions, average='macro', zero_division=0)
        recall = recall_score(all_targets, all_predictions, average='macro', zero_division=0)
        f1 = f1_score(all_targets, all_predictions, average='macro', zero_division=0)

        # AUC calculation (handle multi-class)
        if all_probs.shape[1] > 2:  # Multi-class
            # One-hot encode targets for multi-class AUC
            from sklearn.preprocessing import label_binarize
            classes = list(range(all_probs.shape[1]))
            all_targets_binary = label_binarize(all_targets, classes=classes)
            auc = roc_auc_score(all_targets_binary, all_probs, multi_class='ovr')
        elif all_probs.shape[1] > 1:  # Binary classification, two logits
            auc = roc_auc_score(all_targets, all_probs[:, 1])
        else:  # Binary classification, single logit
            auc = roc_auc_score(all_targets, all_probs[:, 0])

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'auc': auc
        }
    else:
        all_targets = torch.cat(all_targets).numpy()
        all_predictions = torch.cat(all_predictions).numpy()

        # Calculate regression metrics
        mse = mean_squared_error(all_targets, all_predictions)
        mae = mean_absolute_error(all_targets, all_predictions)
        r2 = r2_score(all_targets, all_predictions)

        return {
            'mse': mse,
            'mae': mae,
            'rmse': np.sqrt(mse),
            'r2_score': r2
        }


def evaluate_baseline_method(method_name, model, test_loader, device, is_classification=True, k_matrix=None, num_factors=None, latent_dim=None):
    """
    Evaluate a baseline method and compute comprehensive metrics.

    Args:
        method_name: Name of the method being evaluated
        model: The trained model to evaluate
        test_loader: DataLoader yielding ``(x, y)`` test batches
        device: Device to run evaluation on
        is_classification: Whether this is a classification task
        k_matrix: Optional K matrix for universal K methods
        num_factors: Number of factors for universal K methods
        latent_dim: Latent dimension per factor

    Returns:
        Dictionary with ``method`` and ``task_metrics``. For VIB/BetaVAE/
        SparseAutoencoder models it also holds ``recon_error`` averaged over
        batches, plus ``kl_divergence`` (VIB/BetaVAE) or ``sparsity``
        (SparseAutoencoder), likewise averaged over batches. ``disentanglement``
        and ``k_metrics`` are added when the respective inputs are supplied
        and their computation succeeds.
    """
    task_metrics = evaluate_model_metrics(model, test_loader, device, is_classification)

    # Add method-specific metrics
    metrics = {
        'method': method_name,
        'task_metrics': task_metrics
    }

    # For autoencoder-based methods, add reconstruction metrics
    if isinstance(model, (VIB, BetaVAE, SparseAutoencoder)):
        is_variational = isinstance(model, (VIB, BetaVAE))
        model.eval()

        num_batches = len(test_loader)
        recon_total = 0.0
        kl_total = 0.0
        sparsity_total = 0.0
        disentanglement = None

        with torch.no_grad():
            for x, _ in test_loader:
                x = x.to(device)

                if is_variational:
                    x_recon, mu, log_var, z, _ = model(x)

                    # Accumulate the per-batch KL so the reported value is a
                    # true average over batches, not just the last batch.
                    kl_total += -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()).item()

                    # Disentanglement metrics are recomputed on every batch;
                    # the most recent batch that succeeds is the one reported.
                    if num_factors is not None and latent_dim is not None:
                        try:
                            z_reshaped = z.view(-1, num_factors, latent_dim)
                            disentanglement = {
                                'mi_ksg': safe_mi_ksg_estimator(z_reshaped),
                                'total_correlation': robust_total_correlation(z_reshaped),
                                'modularity': robust_modularity_score(z_reshaped),
                                'factor_vae_score': robust_factor_vae_score(z_reshaped),
                                'sap_score': robust_sap_score(z_reshaped, x)
                            }
                        except Exception as e:
                            print(f"Error calculating disentanglement metrics: {e}")
                else:
                    x_recon, z, _ = model(x)

                    # Mean absolute latent activation, accumulated per batch.
                    sparsity_total += torch.mean(torch.abs(z)).item()

                recon_total += F.mse_loss(x_recon, x).item()

        if is_variational:
            metrics['kl_divergence'] = kl_total / num_batches
            if disentanglement is not None:
                metrics['disentanglement'] = disentanglement
        else:
            metrics['sparsity'] = sparsity_total / num_batches

        metrics['recon_error'] = recon_total / num_batches

    # For K-matrix methods, use the evaluate_k_matrix function if provided
    if k_matrix is not None and 'Universal_K' in method_name:
        try:
            # Only the inputs of the first test batch are passed on.
            k_metrics = evaluate_k_matrix(next(iter(test_loader))[0].to(device), k_matrix, num_factors, latent_dim, device)
            metrics['k_metrics'] = k_metrics
        except Exception as e:
            print(f"Error evaluating K matrix metrics: {e}")

    return metrics


def train_model(model, train_loader, val_loader, device, is_classification=True, epochs=50, patience=5):
    """
    Fit ``model`` with Adam, stopping early when validation stops improving.

    VIB, BetaVAE and SparseAutoencoder models are optimized with their own
    ``loss_function``; any other model is optimized with cross-entropy
    (classification) or MSE (regression). Validation always tracks the task
    loss only, averaged over validation batches.

    Args:
        model: The model to train (updated in place)
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        device: Device the batches are moved to
        is_classification: Selects the criterion used for standard models
        epochs: Maximum number of epochs to train
        patience: Consecutive non-improving epochs tolerated before stopping

    Returns:
        The same model object, in the state reached when training ended
    """
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    task_criterion = nn.CrossEntropyLoss() if is_classification else nn.MSELoss()

    def _batch_losses(inputs, targets):
        # Returns (loss to optimize, task-only loss) for one batch.
        if isinstance(model, (VIB, BetaVAE)):
            x_recon, mu, log_var, _, pred = model(inputs)
            total, _, _, task = model.loss_function(inputs, x_recon, mu, log_var, targets, pred)
            return total, task
        if isinstance(model, SparseAutoencoder):
            x_recon, z, pred = model(inputs)
            total, _, _, task = model.loss_function(inputs, x_recon, z, targets, pred)
            return total, task
        # Standard models like DropoutRegularizedModel
        plain_loss = task_criterion(model(inputs), targets)
        return plain_loss, plain_loss

    lowest_val_loss = float('inf')
    stale_epochs = 0

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            optimizer.zero_grad()
            train_loss, _ = _batch_losses(batch_x, batch_y)
            train_loss.backward()
            optimizer.step()

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                _, task_loss = _batch_losses(batch_x, batch_y)
                val_loss += task_loss.item()

        val_loss /= len(val_loader)

        # ---- early stopping bookkeeping ----
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            stale_epochs = 0
            continue

        stale_epochs += 1
        if stale_epochs >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

    return model


def save_metrics_to_csv(results, output_dir='.'):
    """
    Flatten experiment results into rows and write them out as CSV files.

    Written files (each only when it would have at least one row):
    ``all_task_metrics.csv``, ``all_disentanglement_metrics.csv`` and, per
    dataset, ``<dataset>_task_metrics.csv`` and
    ``<dataset>_disentanglement_metrics.csv``.

    Args:
        results: Dictionary keyed by dataset name. Each value may contain
            ``k_methods`` and ``baseline_methods`` (method name -> list of
            result dicts) and ``KnowledgeDistillation`` (list of result dicts).
        output_dir: Directory to save CSV files (created if missing)
    """
    os.makedirs(output_dir, exist_ok=True)

    def _store_performance(row, value):
        # Heuristic from the result value alone: a float within [0, 1] is
        # recorded as accuracy, everything else as MSE.
        if isinstance(value, float) and 0 <= value <= 1:
            row['accuracy'] = value
        else:
            row['mse'] = value

    def _rows_for(rows, name):
        return [row for row in rows if row['dataset'] == name]

    disentanglement_keys = ('mi_ksg', 'modularity', 'total_correlation',
                            'factor_vae_score', 'sap_score', 'sparsity', 'variance_ratio')

    task_rows = []
    disent_rows = []

    for dataset_name, dataset_results in results.items():
        # --- Universal K methods ---
        if 'k_methods' in dataset_results:
            for method_name, runs in dataset_results['k_methods'].items():
                for run in runs:
                    row = {
                        'dataset': dataset_name,
                        'method': f"{method_name}",
                        'method_type': 'Universal_K',
                        'num_factors': run['num_factors'],
                        'latent_dim': run['latent_dim'],
                    }

                    if 'teacher_performance' in run:
                        _store_performance(row, run['teacher_performance'])

                    has_metrics = 'metrics' in run
                    if has_metrics:
                        row.update(run['metrics'])

                    for optional_key in ('combined_score', 'init_time', 'refine_time'):
                        if optional_key in run:
                            row[optional_key] = run[optional_key]

                    task_rows.append(row)

                    # K-method runs that carry metrics also go to the
                    # disentanglement table.
                    if has_metrics:
                        disent_row = row.copy()
                        disent_row.update({
                            key: run['metrics'][key]
                            for key in disentanglement_keys
                            if key in run['metrics']
                        })
                        disent_rows.append(disent_row)

        # --- Baseline methods ---
        if 'baseline_methods' in dataset_results:
            for method_name, runs in dataset_results['baseline_methods'].items():
                for run in runs:
                    row = {
                        'dataset': dataset_name,
                        'method': method_name,
                        'method_type': 'Baseline',
                    }

                    if 'params' in run:
                        row.update(run['params'])

                    for optional_key in ('num_factors', 'latent_dim'):
                        if optional_key in run:
                            row[optional_key] = run[optional_key]

                    if 'metrics' in run:
                        # Nested dict values are left out of the flat row.
                        row.update({
                            key: value
                            for key, value in run['metrics'].items()
                            if not isinstance(value, dict)
                        })

                    if 'performance' in run:
                        _store_performance(row, run['performance'])

                    task_rows.append(row)

        # --- Knowledge distillation ---
        if 'KnowledgeDistillation' in dataset_results:
            for run in dataset_results['KnowledgeDistillation']:
                row = {
                    'dataset': dataset_name,
                    'method': 'KnowledgeDistillation',
                    'method_type': 'Distillation',
                    'temperature': run['temperature'],
                    'alpha': run['alpha'],
                    'teacher_method': run['teacher_method'],
                    'num_factors': run['num_factors'],
                    'latent_dim': run['latent_dim'],
                }

                if 'metrics' in run:
                    row.update(run['metrics'])

                if 'performance' in run:
                    _store_performance(row, run['performance'])

                task_rows.append(row)

    # Consolidated files across all datasets
    if task_rows:
        pd.DataFrame(task_rows).to_csv(f"{output_dir}/all_task_metrics.csv", index=False)
        print(f"Saved all task metrics to {output_dir}/all_task_metrics.csv")

    if disent_rows:
        pd.DataFrame(disent_rows).to_csv(f"{output_dir}/all_disentanglement_metrics.csv", index=False)
        print(f"Saved all disentanglement metrics to {output_dir}/all_disentanglement_metrics.csv")

    # One pair of files per dataset, kept for compatibility
    for dataset_name in results:
        dataset_task_rows = _rows_for(task_rows, dataset_name)
        dataset_disent_rows = _rows_for(disent_rows, dataset_name)

        if dataset_task_rows:
            pd.DataFrame(dataset_task_rows).to_csv(f"{output_dir}/{dataset_name}_task_metrics.csv", index=False)
            print(f"Saved {dataset_name} task metrics to {output_dir}/{dataset_name}_task_metrics.csv")

        if dataset_disent_rows:
            pd.DataFrame(dataset_disent_rows).to_csv(f"{output_dir}/{dataset_name}_disentanglement_metrics.csv", index=False)
            print(f"Saved {dataset_name} disentanglement metrics to {output_dir}/{dataset_name}_disentanglement_metrics.csv")


def run_distillation_experiment(best_k_method, best_k_matrix, num_factors, latent_dim,
                               x_train, y_train, x_val, y_val, x_test, y_test,
                               is_classification, device, alphas=(0.3, 0.5, 0.7)):
    """
    Run a knowledge distillation experiment on data encoded with a K matrix.

    A larger "teacher" MLP is trained once on the flattened latent codes. Then,
    for every value in ``alphas``, a smaller "student" MLP is trained with
    ``knowledge_distillation_loss`` against the frozen teacher, early-stopped on
    the validation split, and evaluated on the test split.

    Args:
        best_k_method: Name of the K matrix method (only copied into the results)
        best_k_matrix: The K matrix passed to ``encode_data``
        num_factors, latent_dim: K matrix configuration (only copied into the results)
        x_train, y_train, x_val, y_val, x_test, y_test: Data splits (tensors)
        is_classification: Whether this is a classification task
        device: Device to run on
        alphas: Iterable of alpha values forwarded to ``knowledge_distillation_loss``
            (presumably the weighting between distillation and task loss - confirm
            against that helper)

    Returns:
        List with one result dictionary per alpha value
    """
    max_epochs = 50
    patience = 5
    results = []

    # Move data to device
    best_k_matrix = best_k_matrix.to(device)
    x_train, y_train = x_train.to(device), y_train.to(device)
    x_val, y_val = x_val.to(device), y_val.to(device)
    x_test, y_test = x_test.to(device), y_test.to(device)

    # Encode every split with the K matrix and flatten to (num_samples, features)
    z_train = encode_data(x_train, best_k_matrix)
    z_val = encode_data(x_val, best_k_matrix)
    z_test = encode_data(x_test, best_k_matrix)

    z_train_flat = z_train.reshape(z_train.shape[0], -1)
    z_val_flat = z_val.reshape(z_val.shape[0], -1)
    z_test_flat = z_test.reshape(z_test.shape[0], -1)

    # Teacher and students share the same input/output widths
    input_dim = z_train_flat.shape[1]
    output_dim = len(torch.unique(y_train)) if is_classification else 1

    # Create teacher model (larger)
    teacher_model = nn.Sequential(
        nn.Linear(input_dim, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, output_dim)
    ).to(device)

    # Create latent datasets / loaders
    batch_size = 128
    latent_train_loader = DataLoader(TensorDataset(z_train_flat, y_train),
                                     batch_size=batch_size, shuffle=True)
    latent_val_loader = DataLoader(TensorDataset(z_val_flat, y_val), batch_size=batch_size)
    latent_test_loader = DataLoader(TensorDataset(z_test_flat, y_test), batch_size=batch_size)
    num_val_batches = len(latent_val_loader)

    # Train and evaluate the teacher once; it is shared by all students
    teacher_model = train_model(teacher_model, latent_train_loader, latent_val_loader, device, is_classification)
    teacher_metrics = evaluate_model_metrics(teacher_model, latent_test_loader, device, is_classification)

    # The teacher does not change below, so its size is loop-invariant
    teacher_size = sum(p.numel() for p in teacher_model.parameters())

    # For each alpha value, train a fresh student model
    for alpha in alphas:
        # Create student model (smaller)
        student_model = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        ).to(device)

        optimizer = optim.Adam(student_model.parameters(), lr=0.001)

        # Early-stopping state; best_state keeps a copy of the best weights so the
        # evaluated student is the best one, not the one `patience` epochs later
        best_val_loss = float('inf')
        best_state = None
        early_stop_counter = 0

        for epoch in range(max_epochs):
            # Training
            student_model.train()
            teacher_model.eval()

            for batch_z, batch_y in latent_train_loader:
                optimizer.zero_grad()

                # Teacher is frozen: no gradients needed for its outputs
                with torch.no_grad():
                    teacher_outputs = teacher_model(batch_z)

                student_outputs = student_model(batch_z)

                distill_loss = knowledge_distillation_loss(
                    student_outputs, teacher_outputs, batch_y,
                    temperature=4.0, alpha=alpha
                )

                distill_loss.backward()
                optimizer.step()

            # Without validation data there is nothing to early-stop on
            # (and averaging over zero batches would divide by zero)
            if num_val_batches == 0:
                continue

            # Validation uses the plain task loss, not the distillation loss
            student_model.eval()
            val_loss = 0.0

            with torch.no_grad():
                for batch_z, batch_y in latent_val_loader:
                    student_outputs = student_model(batch_z)

                    if is_classification:
                        loss = F.cross_entropy(student_outputs, batch_y.long())
                    else:
                        # Align (N,) targets with (N, 1) outputs; otherwise mse_loss
                        # broadcasts to (N, N) and the warning is suppressed file-wide
                        target = batch_y
                        if (target.shape != student_outputs.shape
                                and target.numel() == student_outputs.numel()):
                            target = target.reshape(student_outputs.shape)
                        loss = F.mse_loss(student_outputs, target)

                    val_loss += loss.item()

            val_loss /= num_val_batches

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_state = {name: tensor.detach().clone()
                              for name, tensor in student_model.state_dict().items()}
                early_stop_counter = 0
            else:
                early_stop_counter += 1
                if early_stop_counter >= patience:
                    print(f"Early stopping at epoch {epoch}")
                    break

        # Roll back to the weights with the lowest validation loss
        if best_state is not None:
            student_model.load_state_dict(best_state)

        # Evaluate student model
        student_metrics = evaluate_model_metrics(student_model, latent_test_loader, device, is_classification)

        student_size = sum(p.numel() for p in student_model.parameters())

        results.append({
            'method': 'KnowledgeDistillation',
            'alpha': alpha,
            'k_method': best_k_method,
            'num_factors': num_factors,
            'latent_dim': latent_dim,
            'teacher_metrics': teacher_metrics,
            'student_metrics': student_metrics,
            'teacher_size': teacher_size,
            'student_size': student_size,
            'compression_ratio': teacher_size / student_size
        })

    return results

def run_single_dataset_experiment(dataset_name, output_dir, gpu_id, results_queue, k_methods, epochs=100):
    """
    Run the K matrix experiment for a single dataset and push the results to a queue.

    For every (method, num_factors) pair the K matrix is initialised, refined and
    evaluated; a small MLP probe is then trained on the encoded data to measure
    how useful the representation is for the dataset's task.

    Args:
        dataset_name: Name of the dataset
        output_dir: Directory for output (accepted for interface symmetry; not
            used inside this function)
        gpu_id: GPU ID forwarded to ``setup_device``
        results_queue: Queue that receives ``(dataset_name, dataset_results)``
        k_methods: Iterable of ``(method_name, method_func)`` pairs to evaluate
        epochs: Number of epochs for K matrix refinement

    Returns:
        None (results are sent to ``results_queue``; on a dataset-level failure
        an empty ``{'k_methods': {}}`` payload is sent instead)
    """
    import traceback

    def _align_target(target, outputs):
        """Reshape a regression target to the output shape when only the layout differs."""
        # Guards against (N,) vs (N, 1): MSELoss would broadcast that to (N, N),
        # and the resulting warning is suppressed file-wide.
        if target.shape != outputs.shape and target.numel() == outputs.numel():
            return target.reshape(outputs.shape)
        return target

    device = None
    try:
        # Set up device
        device = setup_device(gpu_id)
        print(f"Processing {dataset_name} on {device}")

        # Load dataset
        available_datasets = get_available_datasets()
        if not available_datasets.get(dataset_name, {}).get('available', False):
            raise ValueError(f"Dataset {dataset_name} is not available")

        x_data, y_data, is_classification = load_or_create_dataset(dataset_name, available_datasets)
        print(f"Dataset shape: {x_data.shape}, Classification: {is_classification}")

        # Transfer the dataset once instead of once per call site and configuration
        x_device = x_data.to(device)
        y_device = y_data.to(device)
        num_classes = len(torch.unique(y_data)) if is_classification else None

        # Define factors to try - using only k=3 and k=5 as specified in the paper
        # Fix dimension to 8 as requested
        factors_to_try = [3, 5]
        latent_dim = 8  # Fixed dimension

        # Initialize results for this dataset
        dataset_results = {'k_methods': {}}

        # Process each K method
        for method_name, method_func in k_methods:
            print(f"\nEvaluating method: {method_name} for {dataset_name}")
            method_results = []

            for num_factors in factors_to_try:
                config_info = f"[{method_name}, Factors={num_factors}, Dims={latent_dim}]"
                print(f"Processing {config_info}")

                try:
                    # Initialize K matrix
                    start_time = time.time()
                    k_matrix = method_func(x_device, num_factors, latent_dim, device)
                    init_time = time.time() - start_time

                    # Refine K matrix
                    start_time = time.time()
                    k_refined = refine_k_matrix(x_device, k_matrix, num_factors, latent_dim, device, epochs=epochs)
                    refine_time = time.time() - start_time

                    # Evaluate K matrix
                    metrics = evaluate_k_matrix(x_device, k_refined, num_factors, latent_dim, device)

                    # Equal-weight combined score; mi_ksg and total_correlation
                    # enter as (1 - value), the other three metrics directly
                    combined_score = (
                        (1.0 - metrics['mi_ksg']) * 0.2 +
                        metrics['modularity'] * 0.2 +
                        (1.0 - metrics['total_correlation']) * 0.2 +
                        metrics['factor_vae_score'] * 0.2 +
                        metrics['sap_score'] * 0.2
                    )

                    result = {
                        'method': method_name,
                        'num_factors': num_factors,
                        'latent_dim': latent_dim,
                        'metrics': metrics,
                        'combined_score': combined_score,
                        'init_time': init_time,
                        'refine_time': refine_time,
                        'k_matrix': k_refined.detach().cpu(),
                        # Recorded so consumers need not guess the task type
                        # from the magnitude of 'teacher_performance'
                        'is_classification': bool(is_classification)
                    }

                    # Encode data with the refined K matrix and flatten per sample.
                    # detach(): the probe must not backpropagate into the K matrix
                    # graph (a no-op if encode_data already returns a detached tensor).
                    z = encode_data(x_device, k_refined)
                    z_flat = z.reshape(z.shape[0], -1).detach()

                    # Sequential 80/20 split (NOTE: not shuffled, so it relies on
                    # the dataset not being ordered by label)
                    train_size = int(0.8 * len(z_flat))
                    train_z, test_z = z_flat[:train_size], z_flat[train_size:]
                    train_y, test_y = y_device[:train_size], y_device[train_size:]

                    # Create prediction model
                    input_dim = z_flat.shape[1]
                    if is_classification:
                        model = nn.Sequential(
                            nn.Linear(input_dim, 64),
                            nn.ReLU(),
                            nn.Linear(64, num_classes)
                        ).to(device)
                        criterion = nn.CrossEntropyLoss()
                    else:
                        model = nn.Sequential(
                            nn.Linear(input_dim, 64),
                            nn.ReLU(),
                            nn.Linear(64, 1)
                        ).to(device)
                        criterion = nn.MSELoss()

                    # Train simple model to evaluate representation quality
                    optimizer = optim.Adam(model.parameters(), lr=0.001)
                    model.train()

                    # Train for 20 epochs over fixed-order mini-batches
                    batch_size = 128
                    for _ in range(20):
                        for i in range(0, len(train_z), batch_size):
                            batch_z = train_z[i:i+batch_size]
                            batch_y = train_y[i:i+batch_size]

                            optimizer.zero_grad()
                            outputs = model(batch_z)

                            if is_classification:
                                loss = criterion(outputs, batch_y.long())
                            else:
                                loss = criterion(outputs, _align_target(batch_y, outputs))

                            loss.backward()
                            optimizer.step()

                    # Evaluate model on the held-out tail of the data
                    model.eval()
                    with torch.no_grad():
                        outputs = model(test_z)

                        if is_classification:
                            _, predicted = torch.max(outputs, 1)
                            accuracy = (predicted == test_y.long()).float().mean().item()
                            result['teacher_performance'] = accuracy
                        else:
                            mse = criterion(outputs, _align_target(test_y, outputs)).item()
                            result['teacher_performance'] = mse

                    method_results.append(result)

                    # Log results
                    print(f"{config_info} - Combined score: {combined_score:.4f}")
                    performance_metric = "Accuracy" if is_classification else "MSE"
                    print(f"{config_info} - {performance_metric}: {result['teacher_performance']:.4f}")

                    # Clean up GPU memory
                    clean_gpu_memory(device)

                except Exception as e:
                    # A failing configuration is logged and skipped so the
                    # remaining configurations still run
                    print(f"Error processing {config_info}: {e}")
                    traceback.print_exc()
                    clean_gpu_memory(device)

            # Store results for this method
            dataset_results['k_methods'][method_name] = method_results

        # Put results in queue
        results_queue.put((dataset_name, dataset_results))
        print(f"Completed processing dataset: {dataset_name}")

    except Exception as e:
        print(f"Error processing dataset {dataset_name}: {e}")
        traceback.print_exc()
        # Return empty results on error so the collector still gets one item
        results_queue.put((dataset_name, {'k_methods': {}}))

    finally:
        # Clean up GPU memory (device stays None if setup_device itself failed)
        if device is not None:
            clean_gpu_memory(device)


def save_metrics_for_sota_comparison(results, output_dir='.'):
    """
    Save experiment results to CSV files formatted for SOTA comparison.

    Writes ``all_task_metrics.csv``, ``all_disentanglement_metrics.csv`` and
    ``sota_comparison.csv``, plus one ``<dataset>_metrics.csv`` per dataset and
    one ``<dataset>_<method>_metrics.csv`` per dataset/method pair. Files are
    only written when there is at least one row for them.

    Args:
        results: Dictionary mapping dataset name to a dict with a 'k_methods'
            entry, which maps method name to a list of result dicts
        output_dir: Directory to save CSV files (created if missing)
    """
    os.makedirs(output_dir, exist_ok=True)

    def _write_csv(rows, filename, label):
        """Dump a list of row dicts to ``output_dir/filename`` and log it."""
        path = f"{output_dir}/{filename}"
        pd.DataFrame(rows).to_csv(path, index=False)
        print(f"Saved {label} to {path}")

    all_task_metrics = []
    all_disent_metrics = []
    sota_comparison = []

    for dataset_name, dataset_results in results.items():
        for method_name, method_results in dataset_results.get('k_methods', {}).items():
            for result in method_results:
                # Basic information
                row = {
                    'dataset': dataset_name,
                    'method': method_name,
                    'num_factors': result['num_factors'],
                    'latent_dim': result['latent_dim'],
                    'config': f"f{result['num_factors']}_d{result['latent_dim']}"
                }

                # Performance metric: accuracy for classification, MSE otherwise
                if 'teacher_performance' in result:
                    performance = result['teacher_performance']
                    is_classification = result.get('is_classification')
                    if is_classification is None:
                        # Legacy results carry no task flag; guess from the value
                        # range. This misfiles a regression MSE that lies in [0, 1],
                        # which is why an explicit flag takes precedence.
                        is_classification = isinstance(performance, float) and 0 <= performance <= 1
                    row['accuracy' if is_classification else 'mse'] = performance

                # Add K matrix metrics
                has_metrics = 'metrics' in result
                if has_metrics:
                    row.update(result['metrics'])

                # Combined score and timing information, when present
                for key in ('combined_score', 'init_time', 'refine_time'):
                    if key in result:
                        row[key] = result[key]

                all_task_metrics.append(row)

                # SOTA comparison rows are the task rows tagged with the method family
                sota_row = row.copy()
                sota_row['method_type'] = 'Universal_K'
                sota_comparison.append(sota_row)

                # Disentanglement rows exist only for results with metrics; the row
                # already holds every metric, so a plain copy is sufficient
                if has_metrics:
                    all_disent_metrics.append(row.copy())

    # Save consolidated tables
    if all_task_metrics:
        _write_csv(all_task_metrics, "all_task_metrics.csv", "all task metrics")

    if all_disent_metrics:
        _write_csv(all_disent_metrics, "all_disentanglement_metrics.csv", "all disentanglement metrics")

    if sota_comparison:
        _write_csv(sota_comparison, "sota_comparison.csv", "SOTA comparison")

    # Group rows as lists of dicts (not via a DataFrame) so each per-dataset file
    # only contains the columns that dataset actually uses. Dicts preserve
    # insertion order, which keeps the output order deterministic.
    rows_by_dataset = {}
    for row in all_task_metrics:
        rows_by_dataset.setdefault(row['dataset'], []).append(row)

    for dataset_name, dataset_rows in rows_by_dataset.items():
        _write_csv(dataset_rows, f"{dataset_name}_metrics.csv", f"{dataset_name} metrics")

        rows_by_method = {}
        for row in dataset_rows:
            rows_by_method.setdefault(row['method'], []).append(row)

        for method_name, method_rows in rows_by_method.items():
            _write_csv(method_rows, f"{dataset_name}_{method_name}_metrics.csv",
                       f"{dataset_name}_{method_name} metrics")

def generate_stats_table(results, output_dir='.'):
    """
    Generate and save statistical tables for comparing methods.

    For every (num_factors, latent_dim) pair in a fixed grid, writes
    ``stats_f<factors>_d<dims>.csv`` with one row per dataset naming the method
    with the highest combined score for that configuration. Configurations with
    no matching results produce no file. Finally delegates to
    ``create_best_method_table`` for the overall table.

    Args:
        results: Dictionary of results by dataset and method
        output_dir: Directory to save tables (created if missing)
    """
    os.makedirs(output_dir, exist_ok=True)

    # Create tables for different hyperparameter configurations
    for factors in [3, 5]:
        for dims in [8, 16]:
            rows = []

            # Process each dataset
            for dataset_name, dataset_results in results.items():
                if 'k_methods' not in dataset_results:
                    continue

                # Find best method for this dataset and config
                best_score = -float('inf')
                best_method = None
                method_scores = {}
                method_metrics = {}

                for method_name, method_results in dataset_results['k_methods'].items():
                    for result in method_results:
                        if result['num_factors'] != factors or result['latent_dim'] != dims:
                            continue
                        # Skip unscored results instead of raising KeyError, in
                        # line with the guard in create_best_method_table
                        if 'combined_score' not in result:
                            continue

                        score = result['combined_score']
                        method_scores[method_name] = score

                        # Store metrics
                        if 'metrics' in result:
                            method_metrics[method_name] = result['metrics']

                        # Update best method
                        if score > best_score:
                            best_score = score
                            best_method = method_name

                if best_method is None:
                    continue

                row = {
                    'dataset': dataset_name,
                    'best_method': best_method,
                    'best_score': best_score
                }

                # One score column per reported method; None when it has no result
                for method_name in ['Clustered', 'PCA', 'FactorAnalysis']:
                    row[f"{method_name}_score"] = method_scores.get(method_name)

                # Add key metrics for best method
                best_metrics = method_metrics.get(best_method, {})
                for metric_name in ['mi_ksg', 'modularity', 'sparsity', 'recon_error']:
                    if metric_name in best_metrics:
                        row[metric_name] = best_metrics[metric_name]

                rows.append(row)

            # Create and save table
            if rows:
                df = pd.DataFrame(rows)
                filename = f"{output_dir}/stats_f{factors}_d{dims}.csv"
                df.to_csv(filename, index=False)
                print(f"Saved statistical table for f{factors}_d{dims} to {filename}")

    # Create overall best method table
    create_best_method_table(results, output_dir)

def create_best_method_table(results, output_dir='.'):
    """
    Create a table showing the best method for each dataset across all configurations.

    Writes ``best_methods_overall.csv`` (the best-scoring result per dataset with
    all of its metrics) and ``paper_table.csv`` (a formatted subset with values
    rendered to four decimals). Nothing is written when no dataset has a scored
    result.

    Args:
        results: Dictionary of results by dataset and method
        output_dir: Directory to save table (created if missing)
    """
    import numbers  # local import so the block does not depend on file-level imports

    def _format_metric(value):
        """Render a real number with four decimals; anything else as 'N/A'."""
        # numbers.Real also matches NumPy scalars such as np.float32, which
        # a plain isinstance(value, (int, float)) check rejects.
        return f"{value:.4f}" if isinstance(value, numbers.Real) else 'N/A'

    rows = []

    for dataset_name, dataset_results in results.items():
        if 'k_methods' not in dataset_results:
            continue

        # Find best method and configuration
        best_score = -float('inf')
        best_config = None

        for method_name, method_results in dataset_results['k_methods'].items():
            for result in method_results:
                if 'combined_score' not in result:
                    continue
                if not (result['combined_score'] > best_score):
                    continue

                best_score = result['combined_score']
                best_config = {
                    'dataset': dataset_name,
                    'method': method_name,
                    'factors': result['num_factors'],
                    'dims': result['latent_dim'],
                    'combined_score': result['combined_score']
                }

                # Add metrics of the new best result
                if 'metrics' in result:
                    best_config.update(result['metrics'])

        if best_config is not None:
            rows.append(best_config)

    if not rows:
        return

    os.makedirs(output_dir, exist_ok=True)

    # Create and save table
    df = pd.DataFrame(rows)
    filename = f"{output_dir}/best_methods_overall.csv"
    df.to_csv(filename, index=False)
    print(f"Saved best methods table to {filename}")

    # Also create a formatted table for the paper
    # NOTE(review): the 'MI (MINE)' column is filled from the 'mi_ksg' metric;
    # confirm the header matches the estimator actually used.
    paper_rows = [
        {
            'Dataset': row['dataset'],
            'Best Method': row['method'],
            'Combined Score': f"{row['combined_score']:.4f}",
            'Recon. Error': _format_metric(row.get('recon_error')),
            'MI (MINE)': _format_metric(row.get('mi_ksg')),
            'Sparsity': _format_metric(row.get('sparsity'))
        }
        for row in rows
    ]

    # Create and save paper table
    df_paper = pd.DataFrame(paper_rows)
    paper_filename = f"{output_dir}/paper_table.csv"
    df_paper.to_csv(paper_filename, index=False)
    print(f"Saved formatted paper table to {paper_filename}")

def run_parallel_experiments(dataset_names, output_dir, max_processes, results_queue=None, k_methods=None, epochs=100):
    """
    Run experiments in parallel across multiple GPUs/processes.

    Datasets are processed in batches of at most ``max_processes``; each dataset
    gets its own worker process running ``run_single_dataset_experiment``.
    After every batch the results collected so far are saved via
    ``save_metrics_for_sota_comparison``.

    Args:
        dataset_names: List of datasets to process
        output_dir: Directory for saving results
        max_processes: Maximum number of parallel processes; ``None`` or a value
            below 1 falls back to one process per visible GPU (at least one)
        results_queue: Optional Queue for collecting results; a Manager queue is
            created when omitted
        k_methods: List of ``(name, function)`` K matrix methods to evaluate;
            defaults to Clustered, PCA and FactorAnalysis
        epochs: Number of epochs for refining K matrices

    Returns:
        Combined results dictionary keyed by dataset name (datasets whose worker
        delivered no result are absent)
    """
    cuda_available = torch.cuda.is_available()
    gpu_slots = max(1, torch.cuda.device_count()) if cuda_available else 0

    # range() below needs a positive integer step
    if max_processes is None or max_processes < 1:
        max_processes = max(1, gpu_slots)

    if results_queue is None:
        manager = Manager()
        results_queue = manager.Queue()

    # Define only the top methods (Clustered, PCA, FactorAnalysis)
    if k_methods is None:
        k_methods = [
            ('Clustered', create_clustered_k_matrix),
            ('PCA', create_pca_k_matrix),
            ('FactorAnalysis', create_factor_analysis_k_matrix)
        ]

    all_results = {}

    # Process datasets in batches
    for i in range(0, len(dataset_names), max_processes):
        batch = dataset_names[i:i + max_processes]
        processes = []

        # Start one worker per dataset, spreading workers round-robin over GPUs
        for j, dataset_name in enumerate(batch):
            gpu_id = j % gpu_slots if cuda_available else None

            p = Process(
                target=run_single_dataset_experiment,
                args=(dataset_name, output_dir, gpu_id, results_queue, k_methods, epochs)
            )
            processes.append((dataset_name, p))
            p.start()
            print(f"Started processing {dataset_name} on GPU {gpu_id if gpu_id is not None else 'N/A'}")

        # Wait for all processes in this batch to complete.
        # NOTE: joining before draining is safe for Manager queues (the default
        # here); a plain multiprocessing.Queue with large payloads could block.
        for dataset_name, p in processes:
            p.join()
            if p.exitcode not in (0, None):
                print(f"Worker for {dataset_name} exited with code {p.exitcode}")

        # Collect results from this batch
        missing = set(batch)
        for _ in range(len(batch)):
            try:
                dataset_name, dataset_results = results_queue.get(timeout=10)
            except Exception as e:
                # All workers have exited, so nothing more can arrive; stop
                # instead of waiting out another timeout per missing dataset
                print(f"Error collecting results: {e}")
                break
            all_results[dataset_name] = dataset_results
            missing.discard(dataset_name)
            print(f"Collected results for {dataset_name}")

        if missing:
            print(f"No results received for: {', '.join(sorted(map(str, missing)))}")

        # Save intermediate results
        if all_results:
            save_metrics_for_sota_comparison(all_results, output_dir)

    return all_results

# ===== MAIN ENTRY POINT =====
def main(dataset_names=None, output_dir='results', max_processes=None, epochs=100):
    """
    Run the Universal K Matrix experiment focused on the top methods.

    Args:
        dataset_names: List of datasets to test (None for all available)
        output_dir: Directory to save results
        max_processes: Maximum number of parallel processes (None uses one per
            visible GPU, at least one)
        epochs: Number of epochs for refining K matrices

    Returns:
        Tuple ``(results, comparison_df)``. ``({}, None)`` when there is nothing
        to process and ``(None, None)`` when the run fails, so callers can always
        unpack two values.
    """
    import traceback

    try:
        print("Starting Universal K Matrix Analysis with Top Methods")

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        # Record start time
        start_time = time.time()

        # Set up multiprocessing with 'spawn' method for CUDA compatibility
        try:
            set_start_method('spawn', force=True)
        except RuntimeError:
            print("Context already set, continuing...")

        # Set up GPU management
        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
        print(f"Found {num_gpus} GPUs")

        # Determine maximum parallel processes
        if max_processes is None:
            max_processes = max(1, num_gpus)
        print(f"Using maximum of {max_processes} parallel processes")

        # Get available datasets
        available_datasets = get_available_datasets()
        if dataset_names is None:
            # .get(): an entry without an 'available' key counts as unavailable,
            # matching the check in run_single_dataset_experiment
            dataset_names = [name for name, info in available_datasets.items()
                             if info.get('available', False)]

        if not dataset_names:
            print("No datasets available for processing")
            return {}, None

        print(f"Processing datasets: {', '.join(dataset_names)}")

        # Define ONLY the top methods to test based on the paper results
        k_methods = [
            ('Clustered', create_clustered_k_matrix),
            ('PCA', create_pca_k_matrix),
            ('FactorAnalysis', create_factor_analysis_k_matrix)
        ]

        # The Manager owns a helper process; the context manager shuts it down
        # deterministically. Post-processing stays inside the block so objects
        # received through the queue are consumed while the manager is alive.
        with Manager() as manager:
            results_queue = manager.Queue()

            # Run experiment using parallel processes
            results = run_parallel_experiments(dataset_names, output_dir, max_processes,
                                               results_queue, k_methods, epochs=epochs)

            # Save results to CSV specifically formatted for SOTA comparison
            save_metrics_for_sota_comparison(results, output_dir)

            # Generate comparison between methods
            comparison_df = compare_universal_k_methods(results)

            # Save comparison to CSV
            comparison_path = os.path.join(output_dir, 'method_comparison.csv')
            comparison_df.to_csv(comparison_path, index=False)
            print(f"Saved method comparison to {comparison_path}")

            # Generate statistics table for the paper
            generate_stats_table(results, output_dir)

        # Print elapsed time
        elapsed_time = time.time() - start_time
        print(f"Experiment completed in {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")

        return results, comparison_df

    except Exception as e:
        # Top-level boundary: report the failure and signal it via (None, None)
        print(f"Error in main: {e}")
        traceback.print_exc()
        return None, None


# Script entry point: run the experiment with default arguments when this file
# is executed directly (not when it is imported).
if __name__ == "__main__":
    main()