### Clustering in feature space is already good, but can it be better in latent space?

<div style="font-size: 12px;">

- Consider a (variational) autoencoder:
    - Our inputs are noisy, and we do not have access to the clean signals.
        - A VAE (regularized by a Gaussian prior) is therefore a better fit than a plain AE.
    - A plain (denoising) AE is better suited to learning representations of clean signals.
        - It is trained by adding noise to a clean signal and reconstructing the clean signal, which is impossible in our case.

</div>
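
The distinction above can be sketched numerically. The snippet below is a minimal NumPy illustration, not the notebook's model: `dae_loss` needs a clean target that we do not have, whereas `vae_loss` reconstructs the noisy observation itself and adds the closed-form KL term to a standard Gaussian prior. The function names, the squared-error reconstruction term, and the `beta` weight are assumptions made for illustration.

```python
import numpy as np


def dae_loss(x_recon, x_clean):
    """Denoising-AE objective: compares the reconstruction with the CLEAN signal."""
    return float(np.mean(np.sum((x_recon - x_clean) ** 2, axis=1)))


def kl_to_standard_normal(z_mean, z_log_var):
    """KL( N(z_mean, diag(exp(z_log_var))) || N(0, I) ), one value per sample."""
    return 0.5 * np.sum(np.exp(z_log_var) + z_mean ** 2 - 1.0 - z_log_var, axis=1)


def vae_loss(x_recon, x_noisy, z_mean, z_log_var, beta=1.0):
    """Negative-ELBO-style objective: only the noisy observation is required."""
    recon = np.sum((x_recon - x_noisy) ** 2, axis=1)
    return float(np.mean(recon + beta * kl_to_standard_normal(z_mean, z_log_var)))


rng = np.random.default_rng(0)
x_noisy = rng.normal(size=(4, 3))   # observed (noisy) inputs
z_mean = rng.normal(size=(4, 2))    # stand-in for encoder means
z_log_var = np.zeros((4, 2))        # stand-in for encoder log-variances
kl = kl_to_standard_normal(z_mean, z_log_var)
loss = vae_loss(x_noisy, x_noisy, z_mean, z_log_var)  # pretend reconstruction is perfect
```

With a perfect reconstruction, the remaining loss is the KL term alone, which is what pulls the latent codes toward the prior.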

<div style="font-size: 12px;">

| **Main Type**                       | **Sub-variant Name(s)**                                         | **Core Idea**                                           | **Loss / Regularization**                   | **Applications**                     |
| ----------------------------------- | --------------------------------------------------------------- | ------------------------------------------------------- | ------------------------------------------- | ------------------------------------ |
| **1. Basic Autoencoder**            | -                                                               | Learn to reconstruct input                              | MSE, BCE                                    | Dimensionality reduction             |
| **2. Undercomplete AE**             | -                                                               | Latent dim < input dim                                  | MSE                                         | Compression, denoising               |
| **3. Overcomplete AE**              | -                                                               | Latent dim > input dim                                  | L1, L2 regularization                       | Feature extraction                   |
| **4. Denoising AE**                 | -                                                               | Corrupt input → recover original                        | MSE, Gaussian/Bernoulli noise               | Robust encoding                      |
| **5. Sparse AE**                    | -                                                               | Enforce sparsity on latent units                        | L1 norm, KL divergence                      | Part-based features                  |
| **6. Contractive AE**               | -                                                               | Penalize sensitivity of encoding to input               | $\lambda \lVert \nabla_x f(x) \rVert^2$     | Robust manifold learning             |
| **7. Variational AE (VAE)**         | β-VAE, CVAE, VQ-VAE, InfoVAE, FactorVAE, DIP-VAE, VampPrior-VAE | Probabilistic latent space; KL-regularized              | ELBO loss                                   | Generative modeling, disentanglement |
| **8. Adversarial AE (AAE)**         | Semi-supervised AAE, Domain AAE                                 | Match latent posterior with prior via discriminator     | Reconstruction + adversarial loss           | Domain adaptation, generative        |
| **9. Wasserstein AE (WAE)**         | WAE-MMD, WAE-GAN                                                | Wasserstein distance instead of KL in latent            | $\mathcal{L}_{rec} + \lambda D_Z$           | Stable generative models             |
| **10. Robust AE**                   | Robust DAE, L21-AE, Huber AE                                    | Handle noisy or outlier data                            | L1, L21 norm, Huber loss                    | Anomaly detection                    |
| **11. Convolutional AE (CAE)**      | -                                                               | CNN-based encoder/decoder                               | MSE                                         | Image compression, vision tasks      |
| **12. Sequence AE**                 | RNN-AE, LSTM-AE, GRU-AE, BiLSTM-AE                              | Encode and decode sequential data                       | Seq MSE, teacher forcing                    | Time-series, NLP                     |
| **13. Attention-based AE**          | Transformer-AE, MAE (Masked AE), Perceiver-AE                   | Use self-attention to model input                       | Masked MSE or BCE                           | NLP, Vision, Speech                  |
| **14. Structured AE**               | Graph AE, GCN-AE, Tree AE, Set AE                               | Works on structured inputs like graphs or trees         | Graph recon loss, node-wise loss            | Graph learning, structure modeling   |
| **15. Energy-based AE**             | Energy AE, EBM-AE                                               | Latent code assigned energy; contrastive regularization | Energy loss, contrastive divergence         | Out-of-distribution detection        |
| **16. Latent-space Regularized AE** | Latent Consistency AE, Triplet AE, Center AE                    | Regularize relationships in latent space                | Latent MSE, Triplet loss                    | Discriminative embeddings            |
| **17. Self-supervised AE**          | MAE, SimMIM, BYOL-AE, DINO-AE                                   | Predict missing parts / self-predict                    | Masked loss, contrastive loss               | Representation learning              |
| **18. Hybrid AE Models**            | VAE-GAN, AAE-VAE, AE-GAN, VAE-AAE                               | Combine strengths of multiple AE types                  | Hybrid loss (ELBO + GAN)                    | Generative, synthesis                |
| **19. Disentangled AE**             | TC-VAE, FactorVAE, Ada-GVAE, StyleVAE                           | Latent variables encode distinct generative factors     | Total correlation regularization            | Causal and interpretable models      |
| **20. Koopman AE / Dynamics AE**    | LinearDynamicalAE, LatentODE-AE, Neural ODE-AE                  | Impose linear or known dynamics in latent space         | Latent prediction error                     | Dynamical systems, control           |
| **21. Metric AE**                   | Siamese AE, Triplet AE, Contrastive AE                          | Supervised AE using pairwise or triplet relationships   | Contrastive / Triplet loss + Recon          | Face ID, retrieval                   |
| **22. Bayesian AE**                 | Bayes-AE, Probabilistic AE, Stochastic-AE                       | Bayesian treatment of encoder/decoder weights           | VI loss, ELBO                               | Uncertainty estimation               |
| **23. Physics-Informed AE**         | PI-AE, PDE-AE, PINN-AE                                          | Encode physics via constraints in loss                  | PDE residual loss + MSE                     | Scientific ML                        |
| **24. Federated AE**                | FedAE, SplitAE, Differentially Private AE                       | Train across clients; preserve data privacy             | FedAvg, noise injection                     | Federated, edge ML                   |
| **25. Diffusion AE**                | Diff-AE, D3PM-AE, Denoise-AE                                    | Train AE as part of diffusion process                   | Denoising score matching, variational bound | Generation, denoising                |
| **26. Quantized AE**                | VQ-AE, VQ-VAE, Discrete AE                                      | Latent is discrete via quantization                     | Commitment loss + Recon loss                | Speech, compression                  |
| **27. Equivariant AE**              | G-AE, SE(3)-AE, E(2)-AE                                         | Preserve group symmetries (rotation, etc.)              | Group-equivariant loss                      | Molecular modeling, physics          |
| **28. Cross-modal AE**              | Audio-Visual AE, Text-Image AE, Multimodal AE                   | Encode multiple modalities to shared latent             | Multi-modal recon + alignment loss          | Cross-modal retrieval                |
| **29. Swapped / Mixup AE**          | Mixup-AE, Swapped-AE, Latent Interpolation AE                   | Swap or interpolate latent codes                        | Mixup loss + AE loss                        | Regularization, generalization       |
| **30. Contrastive AE**              | SimCLR-AE, MoCo-AE, CL-AE                                       | Contrastive loss in latent + reconstruction             | Contrastive + MSE                           | Self-supervised learning             |


| **Aspect / Property**            | **SimCLR**                           | **MoCo v2**                 | **BYOL**                      | **SimSiam**               | **Barlow Twins**             | **VICReg**                         | **SwAV**                  | **Triplet Loss**         | **SupCon**                   |
| -------------------------------- | ------------------------------------ | --------------------------- | ----------------------------- | ------------------------- | ---------------------------- | ---------------------------------- | ------------------------- | ------------------------ | ---------------------------- |
| **Requires Negative Samples**    | ✅ Yes                                | ✅ Yes                       | ❌ No                          | ❌ No                      | ❌ No                         | ❌ No                               | ❌ No                      | ✅ Yes                    | ✅ Yes                        |
| **Positive Pair Source**         | Augmented views                      | Augmented views             | Augmented views               | Augmented views           | Augmented views              | Augmented views                    | Cluster assignments       | Triplet sampling         | Same-class views             |
| **Negative Source**              | In-batch                             | Memory queue                | None                          | None                      | None                         | None                               | N/A                       | Explicit sampling        | In-batch                     |
| **Memory Bank / Queue Used**     | ❌ No                                 | ✅ Yes                       | ❌ No                          | ❌ No                      | ❌ No                         | ❌ No                               | ❌ No                      | ❌ No                     | ❌ No                         |
| **Momentum Encoder**             | ❌ No                                 | ✅ Yes                       | ✅ Yes                         | ❌ No                      | ❌ No                         | ❌ No                               | ✅ Optional                | ❌ No                     | ❌ No                         |
| **Asymmetric Architecture**      | ❌ No                                 | ✅ Slight                    | ✅ Yes                         | ✅ Yes                     | ❌ No                         | ❌ No                               | ✅ Yes                     | ❌ No                     | ❌ No                         |
| **Uses Stop-Gradient**           | ❌ No                                 | ❌ No                        | ✅ Yes                         | ✅ Yes                     | ❌ No                         | ❌ No                               | ❌ No                      | ❌ No                     | ❌ No                         |
| **Requires Large Batch**         | ✅ Yes                                | ❌ No                        | ❌ No                          | ❌ No                      | ❌ No                         | ❌ No                               | ❌ No                      | ⚠ Moderate               | ✅ Yes                        |
| **Collapse Risk Without Tricks** | ❌ No                                 | ❌ No                        | ✅ Yes                         | ✅ Yes                     | ❌ No                         | ❌ No                               | ✅ Yes                     | ✅ Yes                    | ❌ No                         |
| **Batch Size Sensitivity**       | 🔺 High                              | 🔻 Low                      | 🔻 Low                        | 🔻 Low                    | 🔻 Low                       | 🔻 Low                             | 🔻 Low                    | ⚠ Medium                 | 🔺 High                      |
| **Augmentation Sensitivity**     | ✅ Strong                             | ✅ Strong                    | ✅ Strong                      | ✅ Strong                  | ✅ Strong                     | ✅ Medium                           | ✅ Strong                  | ⚠ Varies                 | ✅ Strong                     |
| **Architecture Simplicity**      | ✅ Simple                             | ⚠ Slightly complex          | ✅ Simple                      | ✅ Simple                  | ✅ Simple                     | ✅ Simple                           | ⚠ Mid                     | ⚠ Needs sampler          | ✅ Simple                     |
| **Projection Head Used**         | ✅ Yes                                | ✅ Yes                       | ✅ Yes                         | ✅ Yes                     | ✅ Yes                        | ✅ Yes                              | ✅ Yes                     | ❌ No                     | ✅ Yes                        |
| **Prediction Head (Extra MLP)**  | ❌ No                                 | ❌ No                        | ✅ Yes                         | ✅ Yes                     | ❌ No                         | ❌ No                               | ❌ No                      | ❌ No                     | ❌ No                         |
| **Loss Function Type**           | Contrastive (NT-Xent)                | InfoNCE                     | Asym. regressive              | Asym. regressive          | Redundancy reduction         | Variance & Cov reg                 | Cluster assignment        | Margin-based             | Contrastive (NT-Xent)        |
| **Learning Stability**           | ⚠ Sensitive                          | ✅ Stable                    | ✅ Very stable                 | ✅ Very stable             | ✅ Very stable                | ✅ Very stable                      | ✅ Stable                  | ⚠ Depends on mining      | ✅ Stable                     |
| **Computational Cost**           | 🔺 High                              | ⚠ Medium                    | 🔻 Low                        | 🔻 Low                    | 🔻 Low                       | 🔻 Low                             | ⚠ Medium                  | 🔺 High                  | 🔺 High                      |
| **Supervision Used**             | ❌ Unsupervised                       | ❌ Unsupervised              | ❌ Unsupervised                | ❌ Unsupervised            | ❌ Unsupervised               | ❌ Unsupervised                     | ❌ Unsupervised            | ❌ (optional)             | ✅ (Supervised)               |
| **Use Case Domain**              | Vision, audio                        | Vision, general             | Vision, low-resources         | Vision                    | General representation       | General (simple)                   | Vision (clustering)       | Metric learning          | Semi/fully supervised        |
| **Best For**                     | Large compute setups                 | Small batch efficiency      | Efficient self-supervision    | Minimal contrastive setup | Negative-free training       | Simplified contrastive pretraining | Online clustering         | Face / text similarity   | Label-guided contrastive     |
| **Unique Trait**                 | High performance with many negatives | Queue-based contrast        | No negatives needed           | Simple + no collapse      | Correlation decorrelation    | Explicit variance regularization   | Self-labeling w/ Sinkhorn | Explicit metric learning | Leverages label info         |
| **Common Pitfalls**              | Needs large batch and careful aug    | Needs tuning queue/momentum | Risk of collapse if not tuned | May underperform SimCLR   | Limited gains w/o strong aug | May not scale well to hard tasks   | Sinkhorn instability      | Poor if mining weak      | Needs labels and large batch |

</div>
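
The "In-batch" negative source and the NT-Xent loss listed for SimCLR in the table above can be sketched as follows. This is a minimal NumPy illustration under assumptions: the function name `nt_xent`, the default `temperature`, and the toy embeddings are hypothetical, and no encoder or augmentation pipeline is included.

```python
import numpy as np


def nt_xent(z_a, z_b, temperature=0.5):
    """NT-Xent loss for N positive pairs (z_a[i], z_b[i]); all other rows act as negatives."""
    n = z_a.shape[0]
    z = np.concatenate([z_a, z_b], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)    # unit vectors -> cosine similarity
    sim = z @ z.T / temperature                          # [2N, 2N] scaled similarities
    np.fill_diagonal(sim, -np.inf)                       # a sample is never compared with itself
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])  # index of each row's positive
    # Cross-entropy per row: log-sum-exp over all candidates minus the positive logit
    row_max = sim.max(axis=1, keepdims=True)
    log_norm = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    return float(np.mean(log_norm - sim[np.arange(2 * n), pos]))


rng = np.random.default_rng(0)
z_a = rng.normal(size=(8, 16))
aligned = nt_xent(z_a, z_a + 0.01 * rng.normal(size=(8, 16)))   # two views that agree
unrelated = nt_xent(z_a, rng.normal(size=(8, 16)))              # "views" that do not
```

Because every other sample in the batch serves as a negative, the number of negatives grows with the batch size, which is one way to read the "Requires Large Batch" row.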


![image.png](attachment:image.png)

![image.png](attachment:image.png)
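
The VaDE code that follows assigns each latent point to a cluster through responsibilities γ, a softmax over log p(z|c) + log π_c for a diagonal-covariance Gaussian mixture. A minimal NumPy sketch of that computation is given here; the names `gmm_responsibilities`, `means`, `variances`, and `weights` are hypothetical, and the arrays are laid out as [clusters, latent] rather than the [latent, clusters] layout used in the class.

```python
import numpy as np


def gmm_responsibilities(z, means, variances, weights):
    """Soft assignments gamma[n, c] proportional to weights[c] * N(z[n] | means[c], diag(variances[c]))."""
    diff = z[:, None, :] - means[None, :, :]                  # [n, c, d]
    log_pdf = -0.5 * (
        z.shape[1] * np.log(2 * np.pi)
        + np.sum(np.log(variances), axis=1)[None, :]
        + np.sum(diff ** 2 / variances[None, :, :], axis=2)
    )                                                         # [n, c]
    log_joint = log_pdf + np.log(weights)[None, :]
    log_joint -= log_joint.max(axis=1, keepdims=True)         # stabilize the softmax
    gamma = np.exp(log_joint)
    return gamma / gamma.sum(axis=1, keepdims=True)


means = np.array([[-3.0, 0.0], [3.0, 0.0]])
variances = np.ones((2, 2))
weights = np.array([0.5, 0.5])
z = np.array([[-2.5, 0.1], [2.8, -0.2], [0.0, 0.0]])
gamma = gmm_responsibilities(z, means, variances, weights)
labels = gamma.argmax(axis=1)   # hard cluster labels from the soft assignments
```

A point that lies midway between two equally weighted components receives equal responsibility from both, so its hard label is essentially arbitrary.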

In [None]:
# -*- coding: utf-8 -*-
'''
VaDE (Variational Deep Embedding: A Generative Approach to Clustering)
Clean implementation with standard sklearn-like interface
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from sklearn import mixture
from scipy.optimize import linear_sum_assignment
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import os
import json


class VaDE(nn.Module):
    def __init__(self, n_clusters, input_dim, hidden_dims, latent_dim=2,
                 batch_size=100, epochs=100,
                 lr=0.002, lr_decay=0.9, decay_every=10,
                 alpha=1.0, activation='sigmoid', use_dropout=True, use_pretraining=True,
                 plot_every=5):
        """
        VaDE model with clean interface

        Args:
            n_clusters: Number of clusters
            input_dim: Input feature dimension
            hidden_dims: List of hidden layer dimensions
            latent_dim: Latent space dimension (default: 2 for 2D visualization)
            batch_size: Training batch size (smaller -> slower training, more noisy gradient, easier to escape local minima)
            epochs: Number of training epochs
            lr: Learning rate
            lr_decay: Learning rate decay factor
            decay_every: Decay learning rate every N epochs
            alpha: Reconstruction loss weight
            activation: 'sigmoid' or 'linear' for decoder output (if inputs are already normalized in [0,1], use 'sigmoid')
            use_dropout: Whether to use dropout
            use_pretraining: Whether to use pretraining for GMM initialization (recommended)
            plot_every: Save latent space plot every N epochs (default: 5)
        """
        super(VaDE, self).__init__()

        self.n_clusters = n_clusters
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dims = hidden_dims
        self.batch_size = batch_size
        self.max_epochs = epochs
        self.lr = lr
        self.lr_decay = lr_decay
        self.decay_every = decay_every
        self.alpha = alpha
        self.activation = activation
        self.use_dropout = use_dropout
        self.use_pretraining = use_pretraining
        self.plot_every = plot_every

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Create latent space folder (clear old one first)
        self.latent_space_dir = "latent_space"
        if os.path.exists(self.latent_space_dir):
            import shutil
            shutil.rmtree(self.latent_space_dir)
        os.makedirs(self.latent_space_dir, exist_ok=True)

        # Build networks
        self._build_encoder()
        self._build_decoder()

        # GMM parameters - following paper's notation:
        # π (theta_p): prior cluster probabilities
        # μ (u_p): cluster means in latent space
        # σ² (lambda_p): per-dimension cluster variances in latent space
        #                (diagonal covariance, i.e. axis-aligned elliptical clusters)
        self.theta_p = nn.Parameter(torch.ones(n_clusters) / n_clusters)
        self.u_p = nn.Parameter(torch.zeros(latent_dim, n_clusters))
        self.lambda_p = nn.Parameter(torch.ones(latent_dim, n_clusters))

        # Move the model to the target device only after all submodules and
        # parameters exist; calling .to() earlier would leave them on the CPU
        self.to(self.device)

        # # If you want, you can use dual optimizers as per paper methodology:
        # # 1. Neural network optimizer: for encoder/decoder parameters
        # # 2. GMM optimizer: for GMM parameters (π, μ, σ²)
        # # This separation allows different learning rates and schedules
        # self.nn_optimizer = None
        # self.gmm_optimizer = None

        self.optimizer = None

        # Training history
        self.history = {
            'loss': [],
            'accuracy': [],
            'lr': []
        }

        # Current epoch tracking for external progress monitoring
        self.epochs = 0

    def _build_encoder(self):
        """Build encoder network"""
        layers = []
        prev_dim = self.input_dim

        # BatchNorm is essential here: the pretrained encoder and the encoder trained
        # jointly with the GMM produce very different latent spaces, so the initial GMM
        # centers must match the samples in the joint space, at least in scale
        for hidden_dim in self.hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(hidden_dim))
            if self.use_dropout:
                layers.append(nn.Dropout(0.2))
            prev_dim = hidden_dim

        self.encoder = nn.Sequential(*layers)
        self.z_mean = nn.Linear(prev_dim, self.latent_dim)
        self.z_log_var = nn.Linear(prev_dim, self.latent_dim)

    def _build_decoder(self):
        """Build decoder network"""
        layers = []
        prev_dim = self.latent_dim

        for hidden_dim in reversed(self.hidden_dims):
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            prev_dim = hidden_dim

        self.decoder = nn.Sequential(*layers)

        if self.activation == 'sigmoid':
            self.output_layer = nn.Sequential(
                nn.Linear(prev_dim, self.input_dim),
                nn.Sigmoid()
            )
        else:
            self.output_layer = nn.Linear(prev_dim, self.input_dim)

    def encode(self, x):
        """Encode input to latent parameters"""
        h = self.encoder(x)
        z_mean = self.z_mean(h)
        z_log_var = self.z_log_var(h)
        return z_mean, z_log_var

    def reparameterize(self, z_mean, z_log_var):
        """Reparameterization trick"""
        std = torch.exp(0.5 * z_log_var)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def decode(self, z):
        """Decode latent representation

        Note: The decoder takes the sampled z (which incorporates both mean and variance 
        through the reparameterization trick)
        """
        h = self.decoder(z)
        return self.output_layer(h)

    def forward(self, x):
        """Forward pass"""
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var)
        x_recon = self.decode(z)
        return x_recon, z_mean, z_log_var, z

    def get_gamma(self, z_mean):
        """Compute cluster assignment probabilities q(c|x) = E_{q(z|x)}[p(c|z)], approximated by p(c|z) at z = z_mean"""
        batch_size = z_mean.shape[0]
        # Expand dimensions for broadcasting [batch, latent, clusters]
        Z = z_mean.unsqueeze(2).expand(-1, -1, self.n_clusters)
        u_tensor3 = self.u_p.unsqueeze(0).expand(batch_size, -1, -1)
        lambda_tensor3 = self.lambda_p.unsqueeze(0).expand(batch_size, -1, -1)

        # Compute log p(z|c) (Appendix B, Lemma 1)
        log_two_pi = np.log(2 * np.pi)
        log_p_z_given_c = -0.5 * (
            self.latent_dim * log_two_pi +
            torch.sum(torch.log(lambda_tensor3), dim=1) +
            torch.sum((Z - u_tensor3).pow(2) / lambda_tensor3, dim=1)
        )

        # Add log p(c) and normalize (Appendix A)
        log_p_c = torch.log(self.theta_p + 1e-10)
        log_p = log_p_z_given_c + log_p_c.unsqueeze(0)
        gamma = F.softmax(log_p, dim=-1)

        return gamma

    def compute_loss(self, x, x_recon, z_mean, z_log_var, z):
        """Compute VaDE loss per Appendix C"""
        batch_size = x.shape[0]

        # Get gamma using variational mean (Appendix A)
        gamma = self.get_gamma(z_mean)

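        # Reconstruction loss (Eq 20, first term): summed over input features.
        # With reduction='none' followed by a sum, no extra input_dim factor is needed;
        # multiplying by input_dim only applies when the loss is averaged over features.
        if self.activation == 'sigmoid':
            recon_loss = self.alpha * F.binary_cross_entropy(x_recon, x, reduction='none')
        else:
            recon_loss = self.alpha * F.mse_loss(x_recon, x, reduction='none')
        recon_loss = torch.sum(recon_loss, dim=1)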

        # KL divergences (Eq 20)
        # Expand dimensions for broadcasting
        u_tensor3 = self.u_p.unsqueeze(0).expand(batch_size, -1, -1)  # [batch_size, latent_dim, n_clusters]
        lambda_tensor3 = self.lambda_p.unsqueeze(0).expand(batch_size, -1, -1)  # [batch_size, latent_dim, n_clusters]
        z_mean_expanded = z_mean.unsqueeze(2).expand(-1, -1, self.n_clusters)  # [batch_size, latent_dim, n_clusters]
        z_var = torch.exp(z_log_var)  # [batch_size, latent_dim]
        z_var_expanded = z_var.unsqueeze(2).expand(-1, -1, self.n_clusters)  # [batch_size, latent_dim, n_clusters]

        # Term 1: -E[log p(z|c)] (Appendix C), expectation under the Gaussian mixture
        log_lambda = torch.sum(torch.log(lambda_tensor3), dim=1)  # [batch_size, n_clusters]
        kl1 = 0.5 * torch.sum(gamma * (
            self.latent_dim * np.log(2 * np.pi) +
            log_lambda +
            torch.sum(z_var_expanded / lambda_tensor3, dim=1) +
            torch.sum((z_mean_expanded - u_tensor3).pow(2) / lambda_tensor3, dim=1)
        ), dim=1)

        # Term 2: E[log q(z|x)] (Appendix C), negative entropy of q(z|x); the constant log(2π) part is omitted
        kl2 = -0.5 * torch.sum(1 + z_log_var, dim=1)

        # Term 3: -E[log p(c)] (Appendix C), prior probability term
        log_theta = torch.log(self.theta_p + 1e-10)
        kl3 = -torch.sum(gamma * log_theta, dim=1)

        # Term 4: E[log q(c|x)] (Appendix C), negative entropy of q(c|x), with an epsilon for numerical stability
        kl4 = torch.sum(gamma * torch.log(gamma + 1e-10), dim=1)

        total_loss = recon_loss + kl1 + kl2 + kl3 + kl4
        return torch.mean(total_loss)

    def _initialize_gmm_params(self, X, use_pretraining=False):
        """Initialize GMM parameters using sklearn with optional pretraining"""
        if use_pretraining:
            # Pretrain autoencoder first
            print("Pretraining autoencoder...")
            X_tensor = torch.FloatTensor(X).to(self.device) if not isinstance(X, torch.Tensor) else X.to(self.device)
            n_samples = X_tensor.shape[0]

            # Simple pretraining with reconstruction loss only
            optimizer = optim.Adam(self.parameters(), lr=0.0001)

            # Early stopping parameters
            patience = 4
            min_delta = 1 - 1e-4  # count an epoch as improved only if the loss drops by at least 0.01%
            best_loss = float('inf')
            patience_counter = 0
            epoch = 0

            while epoch < self.max_epochs:
                self.train()  # Enable dropout during pretraining
                indices = torch.randperm(n_samples)
                total_loss = 0
                num_batches = 0

                for i in range(0, n_samples, self.batch_size):
                    end_idx = min(i + self.batch_size, n_samples)
                    batch_x = X_tensor[indices[i:end_idx]]

                    x_recon, z_mean, z_log_var, z = self.forward(batch_x)

                    # Simple reconstruction loss (no KL divergence during pretraining)
                    if self.activation == 'sigmoid':
                        loss = F.binary_cross_entropy(x_recon, batch_x, reduction='mean')
                    else:
                        loss = F.mse_loss(x_recon, batch_x)

                    # Backward pass
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    num_batches += 1

                avg_loss = total_loss / num_batches
                print(f"Pretrain Epoch {epoch:3d} | Loss: {avg_loss:.6f}")

                # Early stopping check
                if avg_loss < best_loss * min_delta:
                    best_loss = avg_loss
                    patience_counter = 0
                else:
                    patience_counter += 1

                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch}")
                    break

                epoch += 1

            # Get pretrained representations using sampled z
            # Use train() mode to match real training behavior (important when BatchNorm is removed)
            self.train()  # Keep dropout enabled to match real training
            with torch.no_grad():
                # Get sampled z for GMM initialization
                z_mean, z_log_var = self.encode(X_tensor)
                z = self.reparameterize(z_mean, z_log_var)
                sample = z.cpu().numpy()
        else:
            # Use current encoder (no pretraining)
            # Use train() mode to match real training behavior (important when BatchNorm is removed)
            self.train()  # Keep dropout enabled to match real training
            with torch.no_grad():
                X_tensor = torch.FloatTensor(X).to(self.device) if not isinstance(X, torch.Tensor) else X.to(self.device)
                # Get sampled z for GMM initialization
                z_mean, z_log_var = self.encode(X_tensor)
                z = self.reparameterize(z_mean, z_log_var)
                sample = z.cpu().numpy()

        # Fit GMM and initialize parameters
        gmm = mixture.GaussianMixture(n_components=self.n_clusters, covariance_type='diag')
        gmm.fit(sample)

        # Initialize GMM parameters
        assert gmm.means_ is not None, "GMM means should not be None"
        assert gmm.covariances_ is not None, "GMM covariances should not be None"

        self.u_p.data = torch.FloatTensor(gmm.means_.T).to(self.device)
        self.lambda_p.data = torch.FloatTensor(gmm.covariances_.T).to(self.device)
        self.theta_p.data = torch.FloatTensor(gmm.weights_).to(self.device)

        print('GMM parameters initialized!')

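    def cluster_accuracy(self, y_pred, y_true):
        """Compute clustering accuracy using the Hungarian algorithm"""
        # Labels index the contingency matrix, so they must be flat integer arrays
        y_pred = np.asarray(y_pred).astype(np.int64).ravel()
        y_true = np.asarray(y_true).astype(np.int64).ravel()
        assert y_pred.size == y_true.size
        D = max(y_pred.max(), y_true.max()) + 1
        w = np.zeros((D, D), dtype=np.int64)

        # w[i, j] counts samples predicted as cluster i whose true label is j
        for i in range(y_pred.size):
            w[y_pred[i], y_true[i]] += 1

        # Maximize the matched counts by minimizing the complementary cost
        row_ind, col_ind = linear_sum_assignment(w.max() - w)
        accuracy = w[row_ind, col_ind].sum() / y_pred.size
        return accuracy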

    def fit(self, X, y=None):
        """
        Train the VaDE model.

        A single Adam optimizer updates the neural network parameters
        (encoder/decoder) and the GMM parameters (π, μ, σ²) jointly.
        Splitting them into two optimizers would allow different learning
        rates and schedules for each parameter group.

        Args:
            X: Training data of shape (n_samples, n_features)
            y: Ground truth labels (optional, used for accuracy computation)
        """
        # Convert to numpy if needed
        if isinstance(X, torch.Tensor):
            X = X.cpu().numpy()
        if isinstance(y, torch.Tensor):
            y = y.cpu().numpy()

        n_samples = X.shape[0]
        X_tensor = torch.FloatTensor(X).to(self.device)

        # Initialize GMM parameters
        self._initialize_gmm_params(X, use_pretraining=self.use_pretraining)

        # Initialize the optimizer (shared by the network and GMM parameters)
        optimizer = optim.Adam(self.parameters(), lr=self.lr, eps=1e-4)

        # Early stopping parameters
        patience = 10
        min_delta = 1 + 1e-10
        best_loss = float('inf')
        patience_counter = 0
        self.epochs = 0

        print("Training VaDE model...")
        print("=" * 60)

        while self.epochs < self.max_epochs:
            self.train()

            # Shuffle data
            indices = torch.randperm(n_samples)
            total_loss = 0
            num_batches = 0

            # Training loop
            for i in range(0, n_samples, self.batch_size):
                end_idx = min(i + self.batch_size, n_samples)
                batch_indices = indices[i:end_idx]
                batch_x = X_tensor[batch_indices]

                # Forward pass
                x_recon, z_mean, z_log_var, z = self.forward(batch_x)
                loss = self.compute_loss(batch_x, x_recon, z_mean, z_log_var, z)

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

            avg_loss = total_loss / num_batches
            current_lr = optimizer.param_groups[0]['lr']

            # Compute accuracy if labels provided
            accuracy = 0.0
            if y is not None:
                self.eval()
                with torch.no_grad():
                    z_mean, _ = self.encode(X_tensor)
                    gamma = self.get_gamma(z_mean)
                    y_pred = np.argmax(gamma.cpu().numpy(), axis=1)
                    accuracy = self.cluster_accuracy(y_pred, y)

            # Store history, skipping epochs 0-5 (so plotted curves start at epoch 6)
            if self.epochs > 5:
                self.history['loss'].append(avg_loss)
                self.history['accuracy'].append(accuracy)
                self.history['lr'].append(current_lr)

            # Learning rate decay
            if self.epochs > 0 and self.epochs % self.decay_every == 0:
                new_lr = max(current_lr * self.lr_decay, 0.0002)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = new_lr

            # Early stopping check
            if avg_loss < best_loss * min_delta:
                best_loss = avg_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {self.epochs}")
                break

            # Print progress
            if y is not None:
                print(f"Epoch {self.epochs:3d} | Loss: {avg_loss:.6f} | Accuracy: {accuracy:.6f} | LR: {current_lr:.6f}")
            else:
                print(f"Epoch {self.epochs:3d} | Loss: {avg_loss:.6f} | LR: {current_lr:.6f}")

            # Save latent space plot every N epochs
            if self.epochs % self.plot_every == 0:
                self.save_latent_space(X, y, self.epochs)

            # Increment epoch counter
            self.epochs += 1

        print("=" * 60)
        print("Training completed!")

        return self

    def predict(self, X):
        """
        Predict cluster assignments

        Args:
            X: Input data of shape (n_samples, n_features)

        Returns:
            Cluster assignments of shape (n_samples,)
        """
        if isinstance(X, torch.Tensor):
            X = X.cpu().numpy()

        X_tensor = torch.FloatTensor(X).to(self.device)

        self.eval()
        with torch.no_grad():
            z_mean, _ = self.encode(X_tensor)
            gamma = self.get_gamma(z_mean)

        return np.argmax(gamma.cpu().numpy(), axis=1)

    def predict_proba(self, X):
        """
        Predict cluster assignment probabilities

        Args:
            X: Input data of shape (n_samples, n_features)

        Returns:
            Cluster probabilities of shape (n_samples, n_clusters)
        """
        if isinstance(X, torch.Tensor):
            X = X.cpu().numpy()

        X_tensor = torch.FloatTensor(X).to(self.device)

        self.eval()
        with torch.no_grad():
            z_mean, _ = self.encode(X_tensor)
            gamma = self.get_gamma(z_mean)

        return gamma.cpu().numpy()

    def transform(self, X):
        """
        Transform data to latent space

        Args:
            X: Input data of shape (n_samples, n_features)

        Returns:
            Latent representations of shape (n_samples, latent_dim)
        """
        if isinstance(X, torch.Tensor):
            X = X.cpu().numpy()

        X_tensor = torch.FloatTensor(X).to(self.device)

        self.eval()
        with torch.no_grad():
            z_mean, _ = self.encode(X_tensor)

        return z_mean.cpu().numpy()

    def plot_training_history(self):
        """Plot training history using plotly"""
        if not self.history['loss']:
            print("No training history available. Run fit() first.")
            return

        fig = make_subplots(
            rows=1, cols=3,
            subplot_titles=('Training Loss', 'Clustering Accuracy', 'Learning Rate'),
            specs=[[{"secondary_y": False}, {"secondary_y": False}, {"secondary_y": False}]]
        )

        # Plot loss
        fig.add_trace(
            go.Scatter(y=self.history['loss'], mode='lines', name='Loss'),
            row=1, col=1
        )

        # Plot accuracy if available
        if any(acc > 0 for acc in self.history['accuracy']):
            fig.add_trace(
                go.Scatter(y=self.history['accuracy'], mode='lines', name='Accuracy'),
                row=1, col=2
            )

        # Plot learning rate
        fig.add_trace(
            go.Scatter(y=self.history['lr'], mode='lines', name='Learning Rate'),
            row=1, col=3
        )

        fig.update_layout(
            title="Training History",
            width=1200,
            height=400
        )

        fig.update_xaxes(title_text="Epoch", row=1, col=1)
        fig.update_xaxes(title_text="Epoch", row=1, col=2)
        fig.update_xaxes(title_text="Epoch", row=1, col=3)

        fig.update_yaxes(title_text="Loss", row=1, col=1)
        fig.update_yaxes(title_text="Accuracy", row=1, col=2)
        fig.update_yaxes(title_text="Learning Rate", row=1, col=3, type="log")

        fig.show()

    def save_latent_space(self, X, y=None, epoch=None):
        """Save latent space plot with true labels"""
        if isinstance(X, torch.Tensor):
            X = X.cpu().numpy()

        X_tensor = torch.FloatTensor(X).to(self.device)

        self.eval()
        with torch.no_grad():
            z_mean, _ = self.encode(X_tensor)
            z_mean_np = z_mean.cpu().numpy()

        # Create plot
        fig = go.Figure()

        if y is not None:
            # Plot with true labels
            unique_labels = np.unique(y)
            # Use a more robust color scheme that can handle any number of labels
            colors = px.colors.qualitative.Set1 + px.colors.qualitative.Set2 + px.colors.qualitative.Set3 + px.colors.qualitative.Pastel1 + px.colors.qualitative.Pastel2

            for i, label in enumerate(unique_labels):
                mask = y == label
                color_idx = i % len(colors)  # Use modulo to avoid index out of range
                fig.add_trace(go.Scatter(
                    x=z_mean_np[mask, 0],
                    y=z_mean_np[mask, 1],
                    mode='markers',
                    marker=dict(size=5, color=colors[color_idx]),
                    name=f'Class {label}',
                    showlegend=True
                ))
        else:
            # Plot without labels
            fig.add_trace(go.Scatter(
                x=z_mean_np[:, 0],
                y=z_mean_np[:, 1],
                mode='markers',
                marker=dict(size=5, color='blue'),
                name='Data Points'
            ))

        # Add cluster centers
        if self.latent_dim == 2:
            fig.add_trace(go.Scatter(
                x=self.u_p[0, :].detach().cpu().numpy(),
                y=self.u_p[1, :].detach().cpu().numpy(),
                mode='markers',
                marker=dict(size=12, color='red', symbol='x', line=dict(width=2, color='black')),
                name='Cluster Centers'
            ))

        # Add GMM elliptical contours
        if self.latent_dim == 2:
            # Get GMM parameters
            u_p_np = self.u_p.detach().cpu().numpy()  # [latent_dim, n_clusters]
            lambda_p_np = self.lambda_p.detach().cpu().numpy()  # [latent_dim, n_clusters]

            # Create confidence ellipses for each cluster
            for k in range(self.n_clusters):
                # Get cluster center and variances
                center = u_p_np[:, k]  # [latent_dim]
                variances = lambda_p_np[:, k]  # [latent_dim]

                # Create ellipse points for 95% confidence interval
                # For 2D Gaussian, 95% confidence ellipse corresponds to chi-square with 2 degrees of freedom
                # chi2.ppf(0.95, 2) ≈ 5.991
                confidence_level = 5.991

                # Generate ellipse points
                t = np.linspace(0, 2*np.pi, 100)
                x_ellipse = center[0] + np.sqrt(confidence_level * variances[0]) * np.cos(t)
                y_ellipse = center[1] + np.sqrt(confidence_level * variances[1]) * np.sin(t)

                # Add ellipse contour
                fig.add_trace(go.Scatter(
                    x=x_ellipse,
                    y=y_ellipse,
                    mode='lines',
                    line=dict(color='red', width=2, dash='dash'),
                    name=f'Cluster {k} Contour' if k == 0 else None,  # Only show legend for first cluster
                    showlegend=k == 0,
                    hoverinfo='skip'
                ))

        fig.update_layout(
            title=f'Latent Space - Epoch {epoch}' if epoch is not None else 'Latent Space',
            xaxis_title='Latent Dimension 1',
            yaxis_title='Latent Dimension 2',
            width=800,
            height=600
        )

        # Save plot as PNG image (fig.write_image requires a static-export backend such as kaleido)
        epoch_str = f"epoch_{epoch:04d}" if epoch is not None else "final"
        filename = f"{self.latent_space_dir}/{epoch_str}_latent_space.png"
        fig.write_image(filename)
        print(f"Saved latent space plot: {filename}")
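
The ellipse construction in `save_latent_space` can be checked in isolation. The sketch below is a minimal standalone version, assuming a diagonal-covariance 2D Gaussian; the helper name `diag_gaussian_ellipse` and the example center and variances are hypothetical and not part of the class above. It derives the constant 5.991 from `scipy.stats.chi2` instead of hard-coding it, then verifies that every boundary point sits at the same squared Mahalanobis distance.

```python
import numpy as np
from scipy.stats import chi2


def diag_gaussian_ellipse(center, variances, level=0.95, n_points=100):
    """Boundary of the `level` probability region of a 2D diagonal Gaussian."""
    center = np.asarray(center, dtype=float)
    variances = np.asarray(variances, dtype=float)
    # Squared Mahalanobis radius: chi-square quantile with 2 degrees of freedom
    r2 = chi2.ppf(level, df=2)
    t = np.linspace(0.0, 2.0 * np.pi, n_points)
    # Semi-axis along each dimension is sqrt(r2 * variance)
    x = center[0] + np.sqrt(r2 * variances[0]) * np.cos(t)
    y = center[1] + np.sqrt(r2 * variances[1]) * np.sin(t)
    return x, y, r2


# Hypothetical cluster: center (1, -2), variances (4, 0.25)
x, y, r2 = diag_gaussian_ellipse(center=[1.0, -2.0], variances=[4.0, 0.25])

# Every boundary point has the same squared Mahalanobis distance r2
d2 = (x - 1.0) ** 2 / 4.0 + (y + 2.0) ** 2 / 0.25
```

Because the covariance is diagonal, the ellipse is axis-aligned and no rotation is needed; a full covariance matrix would need an eigendecomposition to orient the axes.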


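The accuracy reported during training relies on a best-match mapping between cluster ids and class labels. As a hedged illustration, the following standalone function mirrors the counting and `linear_sum_assignment` steps visible at the end of `cluster_accuracy`; the square matrix sizing and the toy label arrays are assumptions made for this sketch.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def cluster_accuracy(y_pred, y_true):
    """Accuracy under the best one-to-one mapping of cluster ids to labels."""
    y_pred = np.asarray(y_pred, dtype=np.int64)
    y_true = np.asarray(y_true, dtype=np.int64)
    # Assumption: ids are non-negative integers, so a square matrix suffices
    n = max(y_pred.max(), y_true.max()) + 1
    # w[i, j] counts samples placed in cluster i whose true class is j
    w = np.zeros((n, n), dtype=np.int64)
    for p, t in zip(y_pred, y_true):
        w[p, t] += 1
    # The Hungarian solver minimises cost, so turn counts into costs
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() / y_pred.size


# Cluster ids are a relabelling of the classes: the match is perfect
acc_perm = cluster_accuracy([1, 1, 0, 0, 2, 2], [0, 0, 1, 1, 2, 2])

# One class-1 sample falls into the cluster dominated by class 0
acc_one_off = cluster_accuracy([1, 1, 0, 1, 2, 2], [0, 0, 1, 1, 2, 2])
```

The metric is invariant to how clusters are numbered, which is why it suits unsupervised models whose cluster ids carry no inherent meaning.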
In [None]:
# -*- coding: utf-8 -*-
'''
Theoretically-Grounded Contrastive Autoencoder for Tabular Data (ContrastiveAE)
- Implements tabular-appropriate contrastive learning with VICReg and Barlow Twins
- Uses feature-type-aware augmentation strategies
- Simplified architecture optimized for tabular data structure
- Theoretically-motivated training strategies
- Designed for unsupervised clustering of heterogeneous tabular features
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import shutil
import math
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.manifold import TSNE
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
from scipy.optimize import linear_sum_assignment
from typing import Optional, List, Tuple, Union
from sklearn.cluster import KMeans


class TabularBlock(nn.Module):
    """Optimized block for tabular data without spatial assumptions."""

    def __init__(self, in_dim: int, out_dim: int, use_dropout: bool = True, dropout_rate: float = 0.1):
        super(TabularBlock, self).__init__()

        self.linear = nn.Linear(in_dim, out_dim)
        self.norm = nn.LayerNorm(out_dim)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate if use_dropout else 0.0)

        # Optional skip connection only when dimensions match and theoretically justified
        self.use_skip = (in_dim == out_dim) and (out_dim >= 64)  # Only for larger representations

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.linear(x)
        h = self.norm(h)
        h = self.activation(h)
        h = self.dropout(h)

        # Conditional skip connection based on representation size
        if self.use_skip:
            return h + x
        return h


class ContrastiveAE(nn.Module):
    """
    Contrastive Autoencoder for representation clustering.

    - No special structure; a general (tabular) architecture
        - No convolutional operations (i.e., no local receptive fields or spatial structure)
        - No positional encodings
        - No inductive priors for spatial locality

    - Recommended config:
        - 64    -> [128, 128, 128] -> [max(128, 64*4) -> 64] -> ([..,..,..])
          input    encoder_backbone   projection_head           decoder(optional)
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dims: List[int],
                 latent_dim: int = 2,
                 use_drop: bool = True,
                 warmup_epochs: int = 20,
                 ):
        """
        Initialize theoretically-grounded Contrastive Autoencoder.

        Args:
            input_dim: Input dimension
            hidden_dims: List of hidden layer dimensions for backbone/decoder
            latent_dim: Latent space dimension (projection head output)
            use_drop: Whether to use Dropout
            warmup_epochs: Number of initial epochs during which the uniformity and SwAV losses are disabled
        """
        super(ContrastiveAE, self).__init__()

        assert len(hidden_dims) > 1, "hidden_dims must have at least 2 layers"

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims[:-1]
        self.backbone_dim = hidden_dims[-1]
        self.head_dim = latent_dim
        self.use_drop = use_drop
        self.dropout_rate = 0.15  # More conservative for tabular data

        # Warm-up training strategy
        self.warmup_epochs = warmup_epochs

        # Feature type awareness for proper augmentation
        self.categorical_features = []
        self.continuous_features = list(range(input_dim))
        self.sequential_feature_blocks = list(range(input_dim))

        print(f"architecture: {self.input_dim} → [{self.hidden_dims}, {self.backbone_dim}] → [{max(self.backbone_dim, self.head_dim * 4)}, {self.head_dim}]")
        print(f"Feature types - Categorical: {len(self.categorical_features)}, Continuous: {len(self.continuous_features)}")
        print(f"Warm-up epochs: {self.warmup_epochs} (excluding uniformity & SwAV losses)")

        # Theoretically-motivated training parameters
        self.batch_size = 1024  # some losses are sensitive to batch size
        self.max_epochs = 80
        self.lr = 0.005

        # Force CPU usage
        self.device = torch.device('cpu')
        self.to(self.device)

        # Create latent space folder
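        # __file__ is undefined when this runs in a notebook cell, so fall back to the working directory
        base_dir = os.path.dirname(os.path.abspath(__file__)) if '__file__' in globals() else os.getcwd()
        self.latent_space_dir = os.path.join(base_dir, "latent_space")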
        if os.path.exists(self.latent_space_dir):
            shutil.rmtree(self.latent_space_dir)
        os.makedirs(self.latent_space_dir, exist_ok=True)

        # Build theoretically-motivated architecture
        self._build_encoder_backbone()
        self._build_projection_head()
        self._build_decoder()

        # Move to device
        self.to(self.device)

        # Training state tracking
        self.epochs = 0

        # --- K-means based clustering (direct optimal center finding) ---
        self.n_prototypes = 15  # number of clusters
        self.temperature = 0.1  # temperature for soft assignment
        self.cluster_on_repre = True  # whether to cluster on representation space (h) or projection space (z)

        # Initialize cluster centers (will be updated by K-means)
        prototype_dim = self.backbone_dim if self.cluster_on_repre else self.head_dim
        self.cluster_centers = torch.randn(self.n_prototypes, prototype_dim) * 2.0

        # No learnable prototypes: centers are found directly by K-means.
        # Sinkhorn-Knopp is still used for the SwAV loss against these fixed centers.
        self.sinkhorn_epsilon = 0.5  # Sinkhorn-Knopp epsilon for balanced assignments

        # Explicit loss weights
        self.L_recon_weight = 1.0
        self.L_var_weight = 1.0
        self.L_diag_weight = 10.0
        self.L_offdiag_weight = 10.0
        self.L_unif_weight = 1.0
        self.L_swav_weight = 0.3

        # Epoch loss accumulators (sums)
        self.epoch_sum_L_total = 0.0
        self.epoch_sum_L_recon = 0.0
        self.epoch_sum_L_var = 0.0
        self.epoch_sum_L_diag = 0.0
        self.epoch_sum_L_offdiag = 0.0
        self.epoch_sum_L_unif = 0.0
        self.epoch_sum_L_swav = 0.0

        # --- Diagnostics & history ---
        self.training_history = {
            'epochs': [],
            'loss_total': [],
            'loss_recon': [],
            'loss_var': [],
            'loss_diag': [],
            'loss_offdiag': [],
            'loss_unif': [],
            'loss_swav': [],
            'cluster_accuracy': [],
            'n_clusters': []
        }

        self.last_validation_results = None

    def _build_encoder_backbone(self) -> None:
        layers = []

        # Progressive dimensionality reduction
        prev_dim = self.input_dim
        for hidden_dim in self.hidden_dims + [self.backbone_dim]:
            layers.append(TabularBlock(prev_dim, hidden_dim,
                                       use_dropout=self.use_drop,
                                       dropout_rate=self.dropout_rate))
            prev_dim = hidden_dim

        self.encoder = nn.Sequential(*layers)

    def _build_projection_head(self) -> None:
        # - Without g (the projection head), the encoder itself directly optimizes the contrastive loss, which can:
        #   1. Harm the geometry of h (the backbone representation)
        #   2. Force h to be optimized for the contrastive objective only
        #   3. Reduce transferability (downstream task performance)
        # - Acting as an information bottleneck, the head passes through only the features the contrastive
        #   loss needs, leaving other features useful for downstream tasks in the encoder backbone

        # Single hidden layer projection - sufficient for tabular data
        hidden_proj = max(self.backbone_dim, self.head_dim * 4)

        self.projection_head = nn.Sequential(
            nn.Linear(self.backbone_dim, hidden_proj),
            nn.LayerNorm(hidden_proj),
            nn.GELU(),
            nn.Dropout(self.dropout_rate * 0.5),  # Reduced dropout in projection
            nn.Linear(hidden_proj, self.head_dim)
        )

    def _build_decoder(self) -> None:
        layers = []

        # Symmetric decoder
        prev_dim = self.backbone_dim
        for hidden_dim in reversed(self.hidden_dims):
            layers.append(TabularBlock(prev_dim, hidden_dim,
                                       use_dropout=self.use_drop,
                                       dropout_rate=self.dropout_rate * 0.5))  # Reduced dropout in decoder
            prev_dim = hidden_dim

        # Final output layer
        layers.append(nn.Linear(prev_dim, self.input_dim))

        self.decoder = nn.Sequential(*layers)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode input to backbone representation."""
        return self.encoder(x)

    def project(self, z: torch.Tensor) -> torch.Tensor:
        """Project backbone representation for contrastive learning."""
        return self.projection_head(z)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode backbone representation to input space."""
        return self.decoder(z)

    def forward_half(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass returning representation h and projection z."""
        h = self.encode(x)               # Representation space (encoder output)
        z = self.project(h)              # Projection space (head output)
        return h, z

    def forward_full(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward through encoder, projection head, then decoder."""
        h = self.encode(x)               # Representation
        z = self.project(h)              # Projection
        x_recon = self.decode(h)         # Decode from representation
        return x_recon, h, z

    def get_representation(self, x: torch.Tensor) -> torch.Tensor:
        """Get encoder backbone representation (before projection head) - better for clustering."""
        self.eval()
        with torch.no_grad():
            h = self.encode(x)
            return h

    def get_projection(self, x: torch.Tensor) -> torch.Tensor:
        """Get projected representation for contrastive learning."""
        self.eval()
        with torch.no_grad():
            h = self.encode(x)
            z = self.project(h)
            return z

    def get_cluster_labels(self, x: torch.Tensor) -> torch.Tensor:
        """Return hard cluster ids for each sample based on prototype similarity."""
        self.eval()
        with torch.no_grad():
            h = self.get_representation(x.to(self.device))
            z = self.get_projection(x.to(self.device))
            c = h if self.cluster_on_repre else z
            c = F.normalize(c, p=2, dim=1)
            clusters_norm = F.normalize(self.cluster_centers, p=2, dim=1)
            sims = torch.matmul(c, clusters_norm.T)  # [B, K]
            return torch.argmax(sims, dim=1)

    def _tabular_augment(self, x: torch.Tensor) -> torch.Tensor:
        """
        Theoretically-grounded augmentation for tabular data.

        Key principles:
        - Preserve semantic meaning of features
        - Respect feature types (categorical vs continuous)
        - Minimal distortion while maintaining class-relevant information
        """
        batch_size = x.shape[0]
        augmented = x.clone()

        # Progressive augmentation strength
        progress = min(1.0, self.epochs / (self.max_epochs * 0.7))
        base_strength = 0.05 + 0.1 * progress

        # Continuous feature augmentation: small gaussian noise
        if self.continuous_features:
            continuous_mask = torch.zeros(self.input_dim, dtype=torch.bool)
            continuous_mask[self.continuous_features] = True

            # Adaptive noise based on feature variance
            feature_stds = x[:, continuous_mask].std(dim=0, keepdim=True)
            noise_scale = feature_stds * base_strength
            noise = torch.randn_like(x[:, continuous_mask]) * noise_scale

            augmented[:, continuous_mask] += noise

        # Categorical feature augmentation: minimal dropout only
        if self.categorical_features:
            categorical_mask = torch.zeros(self.input_dim, dtype=torch.bool)
            categorical_mask[self.categorical_features] = True

            # Very conservative dropout for categorical features
            dropout_prob = base_strength * 0.1  # Much smaller than continuous
            dropout_mask = torch.rand_like(x[:, categorical_mask]) > dropout_prob
            augmented[:, categorical_mask] *= dropout_mask.float()

        # Sample mixing (mixup-style): blend each sample with a random partner from the batch
        if batch_size > 1 and torch.rand(1).item() < 0.3:  # 30% chance
            # Random pairwise mixing (partners are not guaranteed to share a cluster)
            indices = torch.randperm(batch_size)
            mix_ratio = torch.rand(batch_size, 1) * 0.2  # Max 20% mixing
            augmented = (1 - mix_ratio) * augmented + mix_ratio * augmented[indices]

        # Smoothed noise for sequential feature blocks (skipped if sequential_feature_blocks is empty)
        if self.sequential_feature_blocks:
            seq = augmented[:, self.sequential_feature_blocks]  # shape [B, T]
            # Apply small drift or smooth noise
            seq_noise = torch.randn_like(seq) * base_strength * 0.5
            seq_noise = F.avg_pool1d(seq_noise.unsqueeze(1), kernel_size=3, stride=1, padding=1).squeeze(1)
            augmented[:, self.sequential_feature_blocks] = seq + seq_noise

        return augmented

    def _update_cluster_centers(self, X: torch.Tensor) -> None:
        """Update cluster centers using K-means with K-means++ initialization."""
        self.eval()
        with torch.no_grad():
            h = self.get_representation(X)
            z = self.get_projection(X)
            c = h if self.cluster_on_repre else z
            c_np = c.cpu().numpy()

            # Run K-means with K-means++ initialization for better initial centers
            kmeans = KMeans(n_clusters=self.n_prototypes, init='k-means++',
                            random_state=42, n_init='auto', max_iter=100)
            kmeans.fit(c_np)

            # Update cluster centers
            self.cluster_centers = torch.tensor(kmeans.cluster_centers_,
                                                dtype=torch.float32, device=self.device)

            # Print diagnostics
            center_distances = torch.pdist(self.cluster_centers)
            # print(f"Updated K-means centers - avg distance: {center_distances.mean().item():.3f}, "
            #       f"min distance: {center_distances.min().item():.3f}")

    def _sinkhorn_knopp(self, scores: torch.Tensor, num_iters: int = 3) -> torch.Tensor:
        """Approximate balanced assignments via the Sinkhorn-Knopp algorithm.
        Args:
            scores: similarity matrix [B, K]
            num_iters: number of normalisation iterations
        Returns:
            Soft assignment matrix Q of shape [B, K]; each row sums to 1 and
            the total mass per cluster (column) is approximately equal
        """
        with torch.no_grad():
            # Clamp exponent to avoid overflow/underflow
            Q = torch.exp(torch.clamp(scores / self.sinkhorn_epsilon, min=-20, max=20))
            Q = Q.T  # [K, B]
            sum_Q = Q.sum()
            Q /= sum_Q
            K, B = Q.shape
            r = torch.ones(K, device=Q.device) / K
            c = torch.ones(B, device=Q.device) / B
            for _ in range(num_iters):
                u = Q.sum(dim=1)
                Q *= (r / u).unsqueeze(1)
                v = Q.sum(dim=0)
                Q *= (c / v)
            return (Q / Q.sum(dim=0, keepdim=True)).T  # [B, K]

    def _loss(self, x: torch.Tensor, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # Representation(latent) space: h = f(x)
        # Projection space: z = g(h)
        # NOTE: applying losses to the representation space may be detrimental to semantic feature extraction; use with caution

        # VICReg: Variance(prevent collapse) + Diagonal/Invariance(match views) + off-Diagonal/Covariance/De-correlation
        # Barlow Twins: Diagonal + off-Diagonal
        # ours = Reconstruction(x->h->x') + Barlow(z) + Variance(z) + Uniformity(z) + clustering(h)

        # Check if we're in warm-up phase
        in_warmup = self.epochs < self.warmup_epochs

        # Forward pass for contrastive learning
        h1, z1 = self.forward_half(x1)   # h: representation, z: projection
        h2, z2 = self.forward_half(x2)

        # Forward pass for reconstruction
        x_recon, _, _ = self.forward_full(x)

        # =========================================================================================
        # Compute reconstruction loss
        recon_loss = F.mse_loss(x_recon, x, reduction='mean')

        # =========================================================================================
        # Variance loss (VICReg) (projection space)
        eps = 1e-6
        std_z1 = torch.sqrt(z1.var(dim=0, unbiased=False) + eps)
        std_z2 = torch.sqrt(z2.var(dim=0, unbiased=False) + eps)
        variance_loss = torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2))

        # =========================================================================================
        # Barlow Twins loss (projection space)
        batch_size, dim = z1.shape
        # Zero-mean, unit-variance normalisation across batch for each dimension
        z1_norm = (z1 - z1.mean(dim=0, keepdim=True)) / (z1.std(dim=0, keepdim=True, unbiased=False) + eps)
        z2_norm = (z2 - z2.mean(dim=0, keepdim=True)) / (z2.std(dim=0, keepdim=True, unbiased=False) + eps)

        # Cross-correlation matrix c_{ij} = E_b[ z1_i * z2_j ]
        cross_corr = torch.matmul(z1_norm.T, z2_norm) / batch_size  # [dim, dim]

        # On-diagonal loss: (c_{ii} − 1)^2 averaged over dimensions
        on_diag = torch.diagonal(cross_corr)
        diagonal_loss = ((on_diag - 1.0) ** 2).mean()

        # Off-diagonal loss: sum_{i≠j} c_{ij}^2, normalised by #elements
        off_diag_sum = cross_corr.pow(2).sum() - on_diag.pow(2).sum()
        off_diagonal_loss = off_diag_sum / max(1, dim * (dim - 1))  # guard against division by zero when dim == 1

        # =========================================================================================
        # Uniformity loss (after Wang & Isola 2020) - ONLY after warm-up
        # Computed here as the mean Gaussian potential E[exp(-t * d^2)], without the outer log
        # of the original formulation. It steadily approaches a non-zero limit of about e^-c
        # (c ≈ expected squared distance), not zero.
        if not in_warmup:
            z_all = torch.cat([z1, z2], dim=0)  # [2B, D]
            z_all = F.normalize(z_all, p=2, dim=1)  # Unit-norm embeddings

            if z_all.size(0) <= 1:
                uniformity_loss = torch.tensor(0.0, device=z_all.device)
            else:
                # Pairwise Euclidean distances for all unique pairs (upper-triangular vector)
                pairwise_dist = torch.pdist(z_all, p=2)  # [N*(N−1)/2]
                # Uniformity regulariser: E[ exp( −t · d² ) ] with t = 2, scaled by dimension
                t = 2.0 / dim  # Scale temperature by embedding dimension
                uniformity_loss = torch.exp(-t * pairwise_dist.pow(2)).mean()
        else:
            uniformity_loss = torch.tensor(0.0, device=z1.device)

        # =========================================================================================
        # SwAV clustering loss with fixed K-means centers - ONLY after warm-up
        # Unlike the other terms, SwAV has no explicit pull on positive pairs; it aligns views through
        # cluster-assignment consistency, which gives less sharp gradient signals and longer training.
        # It is a complementary term, so training does not need to wait for it to fully converge.
        if not in_warmup:
            c1 = h1 if self.cluster_on_repre else z1
            c2 = h2 if self.cluster_on_repre else z2
            c1_norm = F.normalize(c1, p=2, dim=1)
            c2_norm = F.normalize(c2, p=2, dim=1)

            # Use fixed K-means centers (updated each epoch)
            centers_norm = F.normalize(self.cluster_centers, p=2, dim=1)  # [K, D]

            # Compute similarities/logits between representations and fixed centers
            logits1 = torch.matmul(c1_norm, centers_norm.T) / self.temperature  # [B, K]
            logits2 = torch.matmul(c2_norm, centers_norm.T) / self.temperature

            # Balanced assignments using Sinkhorn-Knopp (prevents cluster collapse)
            with torch.no_grad():
                q1 = self._sinkhorn_knopp(logits1)  # [B, K]
                q2 = self._sinkhorn_knopp(logits2)  # [B, K]

            # SwAV loss: cross-entropy between assignments of one view and logits of the other
            # This enforces consistency between augmented views
            swav_loss = (- torch.mean(torch.sum(q1 * F.log_softmax(logits2, dim=1), dim=1))
                         - torch.mean(torch.sum(q2 * F.log_softmax(logits1, dim=1), dim=1))) * 0.5
        else:
            swav_loss = torch.tensor(0.0, device=z1.device)

        # -----------------------------------------------------------------------------------------
        # Weight & sum losses
        recon_loss = recon_loss * self.L_recon_weight
        variance_loss = variance_loss * self.L_var_weight
        diagonal_loss = diagonal_loss * self.L_diag_weight
        off_diagonal_loss = off_diagonal_loss * self.L_offdiag_weight
        uniformity_loss = uniformity_loss * self.L_unif_weight
        swav_loss = swav_loss * self.L_swav_weight

        total_loss = (recon_loss + variance_loss + diagonal_loss + off_diagonal_loss + uniformity_loss + swav_loss)

        # Accumulate epoch sums
        self.epoch_sum_L_total += total_loss.item()
        self.epoch_sum_L_recon += recon_loss.item()
        self.epoch_sum_L_var += variance_loss.item()
        self.epoch_sum_L_diag += diagonal_loss.item()
        self.epoch_sum_L_offdiag += off_diagonal_loss.item()
        self.epoch_sum_L_unif += uniformity_loss.item()
        self.epoch_sum_L_swav += swav_loss.item()

        return total_loss

    def fit(self, X: torch.Tensor, y: Optional[torch.Tensor] = None) -> None:
        """Train the contrastive autoencoder."""
        assert isinstance(X, torch.Tensor), "X must be a torch.Tensor"
        n_samples = X.shape[0]
        X = X.to(self.device)

        if y is not None:
            if not isinstance(y, torch.Tensor):
                y = torch.tensor(y)
            y = y.to(self.device)

        # AdamW with decoupled weight decay; betas and eps are the PyTorch defaults
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-4,
                                betas=(0.9, 0.999), eps=1e-8)

        # Linear LR ramp during the warm-up phase, capped at 1/8 of the total epochs
        warmup_lr_epochs = min(self.warmup_epochs, self.max_epochs // 8)

        def lr_lambda(epoch: int) -> float:
            if epoch < warmup_lr_epochs:
                # Linear ramp: 0.1 + 0.9 / warmup_lr_epochs at epoch 0, reaching 1.0 at the last ramp epoch
                return 0.1 + 0.9 * (epoch + 1) / warmup_lr_epochs
            else:
                return 1.0

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

        self.epochs = 0

        while self.epochs < self.max_epochs:
            # Check if we're transitioning out of warm-up
            if self.epochs == self.warmup_epochs:
                print("\n🔄 Warm-up completed! Resetting optimizer and enabling uniformity & SwAV losses...")

                # Reset optimizer with fresh state
                optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-4,
                                        betas=(0.9, 0.999), eps=1e-8)

                # A newly created LambdaLR restarts its own epoch counter at 0, so `epoch`
                # already counts epochs since the end of warm-up. Subtracting
                # self.warmup_epochs again would make the multiplier negative.
                def lr_lambda_post_warmup(epoch: int) -> float:
                    if epoch < warmup_lr_epochs:
                        # Second linear ramp up to 1.0 after the optimizer reset
                        return 0.1 + 0.9 * (epoch + 1) / warmup_lr_epochs
                    else:
                        # Step decay: multiply by 0.95 every `decay_epochs` epochs, floored at 0.1
                        decay_epochs = max(1, ((self.max_epochs - self.warmup_epochs) - warmup_lr_epochs) // 5)
                        return max(0.1, 0.95 ** ((epoch - warmup_lr_epochs) // decay_epochs))

                scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_post_warmup)

            # Update K-means centers ONLY after warm-up (when clustering losses are active)
            if self.epochs >= self.warmup_epochs:
                self._update_cluster_centers(X)

            self.train()

            indices = torch.randperm(n_samples, device=self.device)
            num_batches = 0

            # Reset epoch accumulators
            self.epoch_sum_L_total = 0.0
            self.epoch_sum_L_recon = 0.0
            self.epoch_sum_L_var = 0.0
            self.epoch_sum_L_diag = 0.0
            self.epoch_sum_L_offdiag = 0.0
            self.epoch_sum_L_unif = 0.0
            self.epoch_sum_L_swav = 0.0

            for i in range(0, n_samples, self.batch_size):
                end_idx = min(i + self.batch_size, n_samples)
                batch_indices = indices[i:end_idx]
                x = X[batch_indices]

                # Create two augmented views
                x1 = self._tabular_augment(x)
                x2 = self._tabular_augment(x)

                # Compute loss (updates self.epoch_sum_L_*)
                loss = self._loss(x, x1, x2)

                optimizer.zero_grad()
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0) # clip gradient norm to prevent exploding gradients
                optimizer.step()

                num_batches += 1

            scheduler.step()

            # Average loss components over batches
            avg_L_total = self.epoch_sum_L_total / num_batches
            avg_L_recon = self.epoch_sum_L_recon / num_batches
            avg_L_var = self.epoch_sum_L_var / num_batches
            avg_L_diag = self.epoch_sum_L_diag / num_batches
            avg_L_offdiag = self.epoch_sum_L_offdiag / num_batches
            avg_L_unif = self.epoch_sum_L_unif / num_batches
            avg_L_swav = self.epoch_sum_L_swav / num_batches

            current_lr = optimizer.param_groups[0]['lr']

            # --------------------------------------------------------------
            # Prototype-based clustering accuracy (only compute after warm-up)
            # --------------------------------------------------------------
            if y is not None and self.epochs >= self.warmup_epochs:
                # Evaluate in eval mode (no dropout / batch-norm updates);
                # self.train() is called again at the start of the next epoch
                self.eval()
                with torch.no_grad():
                    h = self.get_representation(X)             # [N, D_h]
                    z = self.get_projection(X)                 # [N, D_z]
                    c = h if self.cluster_on_repre else z
                    c_norm = F.normalize(c, p=2, dim=1)
                    proto_norm = F.normalize(self.cluster_centers, p=2, dim=1)
                    # Use cosine similarity (same as get_cluster_labels method)
                    sims = torch.matmul(c_norm, proto_norm.T)       # [N, K]
                    cluster_labels = torch.argmax(sims, dim=1)      # [N] - most similar prototype

                    # Map each cluster to the majority ground-truth label among its members
                    majority_labels = torch.empty_like(cluster_labels)
                    for cluster_id in cluster_labels.unique():
                        mask = cluster_labels == cluster_id
                        majority = torch.bincount(y[mask].to(torch.int64)).argmax()
                        majority_labels[mask] = majority
                    cluster_accuracy = (majority_labels == y).float().mean().item()
            else:
                cluster_accuracy = 0.0

            n_clusters = self.n_prototypes if self.epochs >= self.warmup_epochs else 0

            # Print loss components with warm-up indicator
            warmup_indicator = " [WARM-UP]" if self.epochs < self.warmup_epochs else ""
            print(f"Epoch {self.epochs:3d}{warmup_indicator} | L_total: {avg_L_total:7.4f} | "
                  f"L_recon: {avg_L_recon:7.4f} | "
                  f"L_var: {avg_L_var:7.4f} | "
                  f"L_diag: {avg_L_diag:7.4f} | "
                  f"L_offdiag: {avg_L_offdiag:7.4f} | "
                  f"L_unif: {avg_L_unif:7.4f} | "
                  f"L_swav: {avg_L_swav:7.4f} | "
                  f"LR: {current_lr:.5f} | "
                  f"Acc: {cluster_accuracy:.3f} | Clusters: {n_clusters}")

            # Save latent space every 10 epochs
            if self.epochs % 10 == 0:
                self.save_latent_space(X, y, self.epochs)

            # Record training history
            self.training_history['epochs'].append(self.epochs)
            self.training_history['loss_total'].append(avg_L_total)
            self.training_history['loss_recon'].append(avg_L_recon)
            self.training_history['loss_var'].append(avg_L_var)
            self.training_history['loss_diag'].append(avg_L_diag)
            self.training_history['loss_offdiag'].append(avg_L_offdiag)
            self.training_history['loss_unif'].append(avg_L_unif)
            self.training_history['loss_swav'].append(avg_L_swav)
            self.training_history['cluster_accuracy'].append(cluster_accuracy)
            self.training_history['n_clusters'].append(n_clusters)

            self.epochs += 1

        print("🚀 Training completed!")

        # Final diagnostics after training (recompute instead of reusing results cached before this run)
        self.run_diagnostics(X, y, use_cached=False)

    # --------------------------------------------------
    # Latent space visualisation & diagnostics utilities
    # --------------------------------------------------

    def save_latent_space(self, X: torch.Tensor, y: Optional[torch.Tensor], epoch: Optional[int] = None) -> None:
        """Save t-SNE visualisations of backbone & projection spaces."""
        if X.size(0) > 10000:
            return  # skip very large datasets to save time/memory

        self.eval()
        with torch.no_grad():
            h = self.get_representation(X)
            z = self.get_projection(X)

        h_np = h.cpu().numpy()
        z_np = z.cpu().numpy()
        y_np = None
        if y is not None:
            y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y

        # Cluster labels from learned prototypes
        proto_norm = F.normalize(self.cluster_centers, p=2, dim=1)
        c = h if self.cluster_on_repre else z
        c_norm = F.normalize(c, p=2, dim=1)
        # Use cosine similarity (same as get_cluster_labels method)
        sims = torch.matmul(c_norm, proto_norm.T)
        cluster_labels = torch.argmax(sims, dim=1).cpu().numpy()

        n_clusters = self.n_prototypes

        # Get actual prototype positions in the space used for clustering
        prototypes_np = self.cluster_centers.detach().cpu().numpy()  # [K, D]

        # Combine data points with prototypes for joint t-SNE embedding
        if self.cluster_on_repre:
            # Prototypes are in representation space, combine with h
            combined_data = np.vstack([h_np, prototypes_np])  # [N+K, D_h]
        else:
            # Prototypes are in projection space, combine with z
            combined_data = np.vstack([z_np, prototypes_np])   # [N+K, D_z]

        # Compute t-SNE embeddings for combined data (points + prototypes)
        perplexity = min(30, max(5, X.size(0)//4))
        tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42, max_iter=300, early_exaggeration=12.0)
        combined_tsne = tsne.fit_transform(combined_data)

        # Split back into data points and prototype positions.
        # Note: despite its name, h_tsne holds the embedding of z when cluster_on_repre is False.
        h_tsne = combined_tsne[:X.size(0)]          # Data points in t-SNE space
        center_tsne = combined_tsne[X.size(0):]     # Actual prototype positions in t-SNE space

        # Compute separate t-SNE for projection space visualization
        z_tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42, max_iter=300, early_exaggeration=12.0).fit_transform(z_np)

        # Determine correctness of cluster assignments if labels available
        correctly_identified = None
        false_identified = None
        if y_np is not None:
            majority_labels = np.full_like(cluster_labels, -1)
            for cluster_id in np.unique(cluster_labels):
                mask_cl = cluster_labels == cluster_id
                if np.any(mask_cl):
                    true_labels, counts = np.unique(y_np[mask_cl], return_counts=True)
                    majority_labels[mask_cl] = true_labels[np.argmax(counts)]
            correctly_identified = (y_np == majority_labels)
            false_identified = ~correctly_identified

        colors = px.colors.qualitative.Set1 + px.colors.qualitative.Set2 + px.colors.qualitative.Set3

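        # Create two sub-plots: (1) the space used for clustering (with prototypes), (2) projection space
        left_title = "Representation space (h)" if self.cluster_on_repre else "Projection space (z) with prototypes"
        fig = make_subplots(rows=1, cols=2,
                            subplot_titles=[left_title, "Projection space (z)"])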

        # Plot samples
        if y_np is not None:
            unique_labels = np.unique(y_np)
            # Correct samples
            for i, label in enumerate(unique_labels):
                mask = (y_np == label) & correctly_identified
                if np.any(mask):
                    color = colors[i % len(colors)]
                    fig.add_trace(go.Scatter(x=h_tsne[mask, 0], y=h_tsne[mask, 1], mode='markers',
                                             marker=dict(size=6, color=color, opacity=0.8, line=dict(width=0)),
                                             name=f"Class {label} (Correct)", legendgroup=f"class_{label}_correct"),
                                  row=1, col=1)

            # Misassigned samples: open black circles (no fill, no alpha blending)
            for i, label in enumerate(unique_labels):
                mask = (y_np == label) & false_identified
                if np.any(mask):
                    fig.add_trace(go.Scatter(x=h_tsne[mask, 0], y=h_tsne[mask, 1], mode='markers',
                                             marker=dict(size=6, symbol='circle-open', color='black',
                                                         line=dict(width=1, color='black')),
                                             name=f"Class {label} (False)", legendgroup=f"class_{label}_false"),
                                  row=1, col=1)
        else:
            fig.add_trace(go.Scatter(x=h_tsne[:, 0], y=h_tsne[:, 1], mode='markers',
                                     marker=dict(size=5, color='blue', opacity=0.7), name='Samples'),
                          row=1, col=1)

        # Plot actual learned prototypes (cluster centres) on the left (clustering-space) subplot
        fig.add_trace(go.Scatter(x=center_tsne[:, 0], y=center_tsne[:, 1], mode='markers',
                                 marker=dict(symbol='cross', color='red', size=10, line=dict(width=2)),
                                 name='Learned Prototypes'),
                      row=1, col=1)

        # ------------------------- PROJECTION SPACE PLOT -------------------------
        if y_np is not None:
            unique_labels = np.unique(y_np)
            for i, label in enumerate(unique_labels):
                mask = (y_np == label)
                if np.any(mask):
                    color = colors[i % len(colors)]
                    fig.add_trace(go.Scatter(x=z_tsne[mask, 0], y=z_tsne[mask, 1], mode='markers',
                                             marker=dict(size=6, color=color, opacity=0.8, line=dict(width=0)),
                                             name=f"Class {label} (Proj)", legendgroup=f"class_{label}_proj"),
                                  row=1, col=2)
        else:
            fig.add_trace(go.Scatter(x=z_tsne[:, 0], y=z_tsne[:, 1], mode='markers',
                                     marker=dict(size=5, color='blue', opacity=0.7), name='Samples (Proj)'),
                          row=1, col=2)

        # Axis titles for both sub-plots
        fig.update_xaxes(title_text="t-SNE Dim 1", row=1, col=1)
        fig.update_yaxes(title_text="t-SNE Dim 2", row=1, col=1)
        fig.update_xaxes(title_text="t-SNE Dim 1", row=1, col=2)
        fig.update_yaxes(title_text="t-SNE Dim 2", row=1, col=2)

        epoch_str = f"epoch_{epoch:04d}" if epoch is not None else "final"
        fig.update_layout(title=f"Representation vs Projection t-SNE ({epoch_str}) | Clusters: {n_clusters}",
                          width=1200, height=600)

        outfile = os.path.join(self.latent_space_dir, f"{epoch_str}_cluster_space.png")
        fig.write_image(outfile)

    def run_diagnostics(self, X: torch.Tensor, y: Optional[torch.Tensor] = None, use_cached: bool = True) -> None:
        """Compute and print diagnostic metrics after training."""
        if use_cached and self.last_validation_results is not None:
            validation = self.last_validation_results
        else:
            self.eval()
            with torch.no_grad():
                h = self.get_representation(X)
                z = self.get_projection(X)

                x1 = self._tabular_augment(X)
                x2 = self._tabular_augment(X)
                _, z1 = self.forward_half(x1)
                _, z2 = self.forward_half(x2)

            h_np = h.cpu().numpy()
            z_np = z.cpu().numpy()
            z1_np = z1.cpu().numpy()
            z2_np = z2.cpu().numpy()

            validation = {}

            # Representation quality
            validation['representation_quality'] = {
                'backbone_variance': float(h_np.var()),
                'projection_variance(>1)': float(z_np.var())
            }

            # Invariance
            mse_err = np.mean((z1_np - z2_np) ** 2)
            validation['augmentation_invariance'] = {
                'mse_error': float(mse_err)
            }

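            # Decorrelation: largest absolute off-diagonal correlation between projection features.
            # np.nanmax ignores NaN entries produced by constant (zero-variance) features.
            corr = np.corrcoef(z_np.T)
            off_diag = corr - np.diag(np.diag(corr))
            validation['feature_decorrelation'] = {
                'max_correlation(<0.5)': float(np.nanmax(np.abs(off_diag)))
            }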

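            # Dimensionality: participation ratio (sum(lambda))^2 / sum(lambda^2) of the covariance spectrum.
            # The covariance matrix is symmetric, so eigvalsh returns real eigenvalues directly.
            eigvals = np.linalg.eigvalsh(np.atleast_2d(np.cov(z_np.T)))
            eigvals = eigvals[eigvals > 1e-10]
            pr = (eigvals.sum() ** 2) / (np.square(eigvals).sum()) if eigvals.size > 0 else 0.0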
            validation['dimensionality_analysis'] = {
                'effective_dimensions(> 0.5 * proj_dim)': float(pr)
            }

            # Clustering validation based on prototypes
            if y is not None:
                y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y
                c = h if self.cluster_on_repre else z
                c_norm = F.normalize(c, p=2, dim=1)
                proto_norm = F.normalize(self.cluster_centers, p=2, dim=1)
                # Use cosine similarity (same as get_cluster_labels method)
                sims = torch.matmul(c_norm, proto_norm.T)
                cluster_labels = torch.argmax(sims, dim=1).cpu().numpy()

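                # Silhouette of the ground-truth classes in projection space
                # (class separability; independent of the prototype assignments)
                sil = silhouette_score(z_np, y_np)
                # Agreement between ground-truth labels and prototype assignments
                ari = adjusted_rand_score(y_np, cluster_labels)
                nmi = normalized_mutual_info_score(y_np, cluster_labels)

                validation['clustering_validation'] = {
                    'silhouette_projection (> 0.3)': float(sil),
                    'ari_score (> 0.5)': float(ari),
                    'nmi_score (> 0.5)': float(nmi)
                }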

            self.last_validation_results = validation

        # Pretty print
        print("\nDiagnostic Metrics:\n" + "-" * 60)
        for k, v in validation.items():
            print(f"{k}: {v}")
