In [14]:
import itertools

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# import tsne
from sklearn.manifold import TSNE
from sklearn.mixture import GaussianMixture
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

In [15]:
# ------------------------------
# VaDE Model Definition
# ------------------------------

class VaDE(nn.Module):
    """Variational Deep Embedding: a VAE whose latent prior is a learnable
    Gaussian mixture p(z) = sum_c pi_c * N(z; mu_c, diag(exp(logvar_c))).

    Args:
        input_dim: size of the flattened input vector.
        hidden_dim: width of the hidden layers of encoder and decoder.
        latent_dim: dimensionality of the latent code z.
        n_clusters: number of mixture components.
    """

    def __init__(self, input_dim=784, hidden_dim=500, latent_dim=10, n_clusters=10):
        super(VaDE, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.n_clusters = n_clusters

        # Encoder network: shared trunk, then separate heads for mean / log-variance.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder network
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # outputs in [0, 1], as required by binary cross-entropy
        )

        # GMM prior parameters:
        # 1. Cluster prior logits; softmax gives pi (mixing coefficients).
        self.pi_logits = nn.Parameter(torch.zeros(n_clusters))
        # 2. Cluster means, shape (n_clusters, latent_dim). They start at distinct
        #    random points: if all means were equal (e.g. zeros) while variances
        #    and mixing weights are also equal, p(c|z) would be exactly uniform and
        #    every cluster would receive the same gradient, so the clusters could
        #    never separate. A GMM fit on pretrained codes may overwrite this.
        self.mu_c = nn.Parameter(torch.randn(n_clusters, latent_dim))
        # 3. Cluster log variances, shape (n_clusters, latent_dim). Init 0 => variance 1.
        self.logvar_c = nn.Parameter(torch.zeros(n_clusters, latent_dim))

    def encode(self, x):
        """Map inputs (..., input_dim) to the mean and log-variance of q(z|x)."""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I), differentiable in mu/logvar."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes (..., latent_dim) to reconstructions in [0, 1]."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample z, decode. Returns (x_recon, mu, logvar, z)."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar, z

    def gmm_params(self):
        """Return the mixing weights pi (n_clusters,), means and log-variances."""
        pi = F.softmax(self.pi_logits, dim=0)  # (n_clusters,)
        return pi, self.mu_c, self.logvar_c


In [16]:
# ------------------------------
# Helper Functions
# ------------------------------

def log_gaussian(z, mu, logvar):
    """
    Log-density of every row of ``z`` under every diagonal Gaussian component.

    Args:
        z: (batch_size, latent_dim) points.
        mu, logvar: (n_clusters, latent_dim) component means and log-variances.

    Returns:
        (batch_size, n_clusters) tensor whose entry [i, k] is
        log N(z_i; mu_k, diag(exp(logvar_k))).
    """
    d = z.size(1)
    # Broadcasting (B, 1, D) against (1, K, D) covers all sample/component pairs.
    centred = z[:, None, :] - mu[None, :, :]
    mahalanobis = (centred.pow(2) / logvar.exp()[None, :, :]).sum(dim=2)  # (B, K)
    log_det = logvar.sum(dim=1)[None, :]  # (1, K); broadcasts over the batch
    return -0.5 * (d * np.log(2 * np.pi) + log_det + mahalanobis)

def gaussian_kl(mu, logvar, mu_c, logvar_c):
    """
    Closed-form KL( N(mu, exp(logvar)) || N(mu_c, exp(logvar_c)) ) for every
    (sample, cluster) pair, with diagonal covariances.

    Args:
        mu, logvar: (batch_size, latent_dim) parameters of q(z|x).
        mu_c, logvar_c: (n_clusters, latent_dim) parameters of each p(z|c).

    Returns:
        (batch_size, n_clusters) tensor of KL values (summed over latent dims).
    """
    # Cluster axis on the posterior, batch axis on the components -> (B, K, D).
    post_mu, post_logvar = mu[:, None, :], logvar[:, None, :]
    prior_mu, prior_logvar = mu_c[None, :, :], logvar_c[None, :, :]

    ratio = (post_logvar.exp() + (post_mu - prior_mu).pow(2)) / prior_logvar.exp()
    per_dim = 0.5 * (prior_logvar - post_logvar + ratio - 1)
    return per_dim.sum(dim=2)

def vae_loss(x, x_recon, mu, logvar, z, model):
    """
    Negative VaDE ELBO for one batch.

    Terms:
      - Reconstruction: binary cross-entropy summed over pixels, averaged over the batch.
      - KL: with responsibilities gamma = p(c|z),
            sum_c gamma_c * log(gamma_c / pi_c)
          + sum_c gamma_c * KL(q(z|x) || N(mu_c, sigma_c^2)),
        averaged over the batch.

    Returns:
        (total_loss, recon_loss, kl_term), all scalar tensors.
    """
    n = x.size(0)
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum') / n

    pi, mu_c, logvar_c = model.gmm_params()
    log_pi = torch.log(pi + 1e-10)  # epsilon guards against log(0)

    # Unnormalised log joint log p(z, c) = log N(z | mu_c, sigma_c^2) + log pi_c.
    log_joint = log_gaussian(z, mu_c, logvar_c) + log_pi
    gamma = F.softmax(log_joint, dim=1)  # responsibilities p(c|z), (batch, n_clusters)

    cluster_kl = (gamma * (torch.log(gamma + 1e-10) - log_pi)).sum(dim=1)
    latent_kl = (gamma * gaussian_kl(mu, logvar, mu_c, logvar_c)).sum(dim=1)
    kl_term = (cluster_kl + latent_kl).mean()

    return recon_loss + kl_term, recon_loss, kl_term


In [17]:
def train_vade(model, device, train_loader, optimizer, epoch):
    """Run one epoch of VaDE training on the full objective (``vae_loss``).

    Args:
        model: the VaDE model; put into training mode here.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, label)`` batches with flattened data.
        optimizer: optimizer stepping the model parameters.
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        Average loss per sample over the epoch (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for data, _ in tqdm(train_loader, desc=f"VaDE epoch {epoch}"):
        data = data.to(device)  # data is already flattened by the transform
        optimizer.zero_grad()
        x_recon, mu, logvar, z = model(data)
        loss, _, _ = vae_loss(data, x_recon, mu, logvar, z, model)
        loss.backward()
        optimizer.step()
        # `loss` is already a per-sample mean over this batch; weighting by the
        # batch size keeps the epoch average per-sample even if the last batch is short.
        batch_n = data.size(0)
        total_loss += loss.item() * batch_n
        n_seen += batch_n
    return total_loss / max(n_seen, 1)


In [18]:
def pretrain(model, device, train_loader, optimizer, epoch):
    """Run one epoch of plain autoencoder pretraining.

    The encoder mean is decoded directly, with no sampling and no KL term, and
    trained with per-pixel mean binary cross-entropy.

    Args:
        model: model exposing ``encode`` (returning mean, logvar) and ``decode``.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, label)`` batches with flattened data.
        optimizer: optimizer stepping the model parameters.
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        Average reconstruction loss per sample over the epoch (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for data, _ in tqdm(train_loader, desc=f"pretrain epoch {epoch}"):
        data = data.to(device)  # data is already flattened by the transform
        optimizer.zero_grad()
        mu, _ = model.encode(data)  # deterministic: use the mean, ignore logvar
        x_recon = model.decode(mu)
        loss = F.binary_cross_entropy(x_recon, data)
        loss.backward()
        optimizer.step()
        # `loss` is a mean over this batch; weight by batch size for a per-sample epoch average.
        batch_n = data.size(0)
        total_loss += loss.item() * batch_n
        n_seen += batch_n
    return total_loss / max(n_seen, 1)

In [19]:
def t_sne_on_test(model, test_loader, device, path='vade.png', n_batches=1, random_state=None):
    """Embed latent codes of test samples in 2-D with t-SNE and save a scatter plot.

    Args:
        model: model exposing ``encode`` and ``reparameterize``.
        test_loader: iterable of ``(data, label)`` batches.
        device: device the model lives on.
        path: output image file (overwritten on every call).
        n_batches: number of leading batches of ``test_loader`` to embed.
        random_state: seed passed to t-SNE for a reproducible layout.

    Raises:
        ValueError: if fewer than two samples are available to embed.
    """
    was_training = model.training
    model.eval()
    latents, labels = [], []
    try:
        with torch.no_grad():
            for data, target in itertools.islice(test_loader, n_batches):
                mu, logvar = model.encode(data.to(device))
                # A sample from q(z|x), the same way z is drawn in forward().
                latents.append(model.reparameterize(mu, logvar).cpu())
                labels.append(target.cpu())
    finally:
        model.train(was_training)  # leave the model in the mode we found it

    if not latents:
        raise ValueError("test_loader yielded no batches to embed")
    latent_np = torch.cat(latents, dim=0).numpy()
    label_np = torch.cat(labels, dim=0).numpy()
    n_samples = latent_np.shape[0]
    if n_samples < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {n_samples}")

    # sklearn requires perplexity < n_samples; 30.0 is TSNE's default.
    tsne = TSNE(n_components=2, perplexity=min(30.0, n_samples - 1), random_state=random_state)
    embedded = tsne.fit_transform(latent_np)

    fig, ax = plt.subplots(figsize=(10, 10))
    points = ax.scatter(embedded[:, 0], embedded[:, 1], c=label_np, cmap='tab10')
    fig.colorbar(points, ax=ax, label='label')
    ax.set_title('t-SNE of latent codes')
    fig.savefig(path)
    plt.close(fig)  # free the figure; this is called repeatedly during training



In [20]:
# Hyperparameters
batch_size = 128
test_batch_size = 512
epochs = 200
lr = 2e-3
latent_dim = 10
n_clusters = 10
seed = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs behind weight init, data shuffling and reparameterisation
# sampling so that "Restart & Run All" reproduces the same run.
torch.manual_seed(seed)
np.random.seed(seed)

# MNIST dataset: transform flattens the 28x28 image into a 784-dim vector.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

# Model and optimizer (created after seeding so the initial weights are reproducible)
model = VaDE(input_dim=784, hidden_dim=500, latent_dim=latent_dim, n_clusters=n_clusters).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
# Stage 1: pretrain encoder/decoder as a deterministic autoencoder.
pretrain_epochs = 10
for pretrain_epoch in range(pretrain_epochs):
    pretrain(model, device, train_loader, optimizer, pretrain_epoch)
# One t-SNE after pretraining is enough; the figure file is overwritten on every call.
t_sne_on_test(model, test_loader, device)

# Stage 2: initialise the GMM prior from the pretrained latent means (VaDE paper).
# Otherwise the mixture keeps whatever values the constructor set, which bear no
# relation to where the data actually lies in latent space.
model.eval()
with torch.no_grad():
    train_latents = torch.cat(
        [model.encode(data.to(device))[0].cpu() for data, _ in train_loader]
    ).numpy()
gmm = GaussianMixture(n_components=n_clusters, covariance_type='diag', random_state=0)
gmm.fit(train_latents)
with torch.no_grad():
    model.mu_c.copy_(torch.as_tensor(gmm.means_, dtype=torch.float32))
    model.logvar_c.copy_(torch.log(torch.as_tensor(gmm.covariances_, dtype=torch.float32)))
    # softmax(log w) == w, so log-weights reproduce the fitted mixing proportions.
    model.pi_logits.copy_(torch.log(torch.as_tensor(gmm.weights_, dtype=torch.float32)))

# Fresh optimizer: Adam's moment estimates from pretraining belong to a different objective.
optimizer = optim.Adam(model.parameters(), lr=lr)


  6%|▌         | 26/469 [00:00<00:05, 81.40it/s]

In [None]:
# Training loop: optimise the full VaDE objective.
# t-SNE is slow relative to an epoch and always writes the same file, so the
# plot is refreshed periodically (and after the final epoch) instead of every epoch.
tsne_every = 10
for epoch in range(1, epochs + 1):
    train_vade(model, device, train_loader, optimizer, epoch)
    if epoch % tsne_every == 0 or epoch == epochs:
        t_sne_on_test(model, test_loader, device)

100%|██████████| 469/469 [00:06<00:00, 69.57it/s]
100%|██████████| 469/469 [00:06<00:00, 69.20it/s]
100%|██████████| 469/469 [00:06<00:00, 68.71it/s]
100%|██████████| 469/469 [00:06<00:00, 68.18it/s]
100%|██████████| 469/469 [00:06<00:00, 68.46it/s]
100%|██████████| 469/469 [00:07<00:00, 63.58it/s]
100%|██████████| 469/469 [00:07<00:00, 66.88it/s]
100%|██████████| 469/469 [00:06<00:00, 68.05it/s]
100%|██████████| 469/469 [00:07<00:00, 61.46it/s]
100%|██████████| 469/469 [00:06<00:00, 70.34it/s]
100%|██████████| 469/469 [00:07<00:00, 64.87it/s]
 71%|███████   | 331/469 [00:05<00:02, 64.49it/s]


KeyboardInterrupt: 