In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING']="1"
os.environ['TORCH_USE_CUDA_DSA'] = "1"

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision
import math

In [3]:
# Dataset
traindata = torchvision.datasets.MNIST('./', train=True, download=False, transform=torchvision.transforms.ToTensor())
testdata = torchvision.datasets.MNIST('./', train=False, download=False, transform=torchvision.transforms.ToTensor())


In [4]:
train_loader = DataLoader(traindata, batch_size=64, shuffle=True)
test_loader = DataLoader(testdata, batch_size=64, shuffle=False)

In [None]:
class GMMVAE(nn.Module):
    def __init__(self, input_dim=28*28, latent_dim=25, num_clusters=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_clusters = num_clusters

        # Encoder
        self.fc_enc = nn.Linear(input_dim, 250)
        self.logits = nn.Linear(250, num_clusters)
        self.means = nn.ModuleList([nn.Linear(250, latent_dim) for _ in range(num_clusters)])
        self.logvars = nn.ModuleList([nn.Linear(250, latent_dim) for _ in range(num_clusters)])

        # Decoder
        self.fc_dec = nn.Linear(latent_dim, 250)
        self.output_layer = nn.Linear(250, input_dim)

        # Priors
        self.mean_priors = nn.Parameter(torch.randn(num_clusters, latent_dim))
        self.logvar_priors = nn.Parameter(torch.zeros(num_clusters, latent_dim))  

    def decoder(self, z):
        h = F.relu(self.fc_dec(z))
        return torch.sigmoid(self.output_layer(h))

    def encoder(self, x):
        h = F.relu(self.fc_enc(x))
        logits = self.logits(h)
        mean_list = [mean(h) for mean in self.means]
        logvar_list = [logvar(h) for logvar in self.logvars]
        q_c = F.softmax(logits, dim=-1)
        return logits, q_c, mean_list, logvar_list

    def gumbel_softmax(self, logits, tau, train=True):
        return F.gumbel_softmax(logits, tau, hard=not train)

    def kl_categorical(self, q_c):
        # KL[q(c|x) || p(c)] where p(c) is uniform
        log_q = torch.log(q_c + 1e-10)
        log_uniform = math.log(1.0 / self.num_clusters)
        kl = torch.sum(q_c * (log_q - log_uniform), dim=1)  # [batch]
        return kl.mean()

    def kl_gaussian(self, m_q, logvar_q, m_p, logvar_p):
        return 0.5 * torch.sum(
            (torch.exp(logvar_q) + (m_q - m_p).pow(2)) / torch.exp(logvar_p)
            - 1 + logvar_p - logvar_q, dim=1
        )  # shape [batch]

    def forward(self, x, train=True):
        batch_size = x.size(0)
        x = x.view(batch_size, -1)

        logits, q_c, mean_list, logvar_list = self.encoder(x)
        c = self.gumbel_softmax(logits, tau=0.5, train=train)

        # KL divergence losses
        kl_c_loss = self.kl_categorical(q_c)
        kl_z_loss = torch.zeros(batch_size, device=x.device)

        z_samples = []
        for i in range(self.num_clusters):
            m_q = mean_list[i]
            logvar_q = logvar_list[i]
            std_q = torch.exp(0.5 * logvar_q)
            eps = torch.randn_like(std_q)
            z_i = m_q + std_q * eps
            z_samples.append(z_i)

            # KL(q || p)
            m_p = self.mean_priors[i].unsqueeze(0).expand_as(m_q)
            logvar_p = self.logvar_priors[i].unsqueeze(0).expand_as(logvar_q)
            kl_i = self.kl_gaussian(m_q, logvar_q, m_p, logvar_p)
            kl_z_loss += c[:, i] * kl_i

        # Combine latent variables
        z_stack = torch.stack(z_samples, dim=2)  # [batch, latent_dim, K]
        z = torch.bmm(z_stack, c.unsqueeze(2)).squeeze(2)  # [batch, latent_dim]

        # Decode
        x_recon = self.decoder(z)

        if x.shape!= x_recon.shape:
            print(x_recon.shape)
            print(x.shape)
        # Reconstruction loss
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum') / batch_size
        kl_z_loss = kl_z_loss.mean()

        total_loss = recon_loss + kl_c_loss + kl_z_loss
        #print(total_loss)
        return total_loss, x_recon


In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [27]:
model = GMMVAE().to(device)

In [28]:
model

GMMVAE(
  (fc_enc): Linear(in_features=784, out_features=250, bias=True)
  (logits): Linear(in_features=250, out_features=10, bias=True)
  (means): ModuleList(
    (0-9): 10 x Linear(in_features=250, out_features=25, bias=True)
  )
  (logvars): ModuleList(
    (0-9): 10 x Linear(in_features=250, out_features=25, bias=True)
  )
  (fc_dec): Linear(in_features=25, out_features=250, bias=True)
  (output_layer): Linear(in_features=250, out_features=784, bias=True)
)

In [29]:
optimizer = torch.optim.Adam(model.parameters(), lr=2e-2)

In [23]:
epochs = 10
for i in range(epochs):
    model.train()

    train_loss =0.0
    cosine_sims = []

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        loss, x_recon = model(data, train=True)
        
        flat_data = data.view(-1, 28*28)
        #recons_loss = F.binary_cross_entropy_with_logits(flat_data, x_recon)
        #loss = -recons_loss + KL_total
        #print(recons_loss)
        #print(KL_total)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()
        train_loss += loss.item()

        
        cos_sim = F.cosine_similarity(flat_data, x_recon, dim=1)  
        cosine_sims.append(cos_sim)
        
        # Aggregate cosine similarity across all batches
    epoch_cosine_similarity = torch.cat(cosine_sims).mean().item()

    print(f'Epoch {i + 1}, Loss: {train_loss / len(train_loader.dataset):.4f}, Cosine Similarity: {epoch_cosine_similarity:.4f}')

Epoch 1, Loss: 2.3189, Cosine Similarity: 0.8163
Epoch 2, Loss: 2.0956, Cosine Similarity: 0.8531
Epoch 3, Loss: 2.0849, Cosine Similarity: 0.8554
Epoch 4, Loss: 2.0797, Cosine Similarity: 0.8567
Epoch 5, Loss: 2.0823, Cosine Similarity: 0.8565
Epoch 6, Loss: 2.0799, Cosine Similarity: 0.8570
Epoch 7, Loss: 2.0812, Cosine Similarity: 0.8571
Epoch 8, Loss: 2.0892, Cosine Similarity: 0.8562
Epoch 9, Loss: 2.0906, Cosine Similarity: 0.8559
Epoch 10, Loss: 2.0933, Cosine Similarity: 0.8556


In [30]:
epochs = 10
for i in range(epochs):
    model.train()

    train_loss =0.0
    cosine_sims = []

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        loss, x_recon = model(data, train=True)
        
        flat_data = data.view(-1, 28*28)
        #recons_loss = F.binary_cross_entropy_with_logits(flat_data, x_recon)
        #loss = -recons_loss + KL_total
        #print(recons_loss)
        #print(KL_total)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.25)
        optimizer.step()
        train_loss += loss.item()

        
        cos_sim = F.cosine_similarity(flat_data, x_recon, dim=1)  
        cosine_sims.append(cos_sim)
        
        # Aggregate cosine similarity across all batches
    epoch_cosine_similarity = torch.cat(cosine_sims).mean().item()

    print(f'Epoch {i + 1}, Loss: {train_loss / len(train_loader.dataset):.4f}, Cosine Similarity: {epoch_cosine_similarity:.4f}')

Epoch 1, Loss: 2.2108, Cosine Similarity: 0.8389
Epoch 2, Loss: 2.0364, Cosine Similarity: 0.8683
Epoch 3, Loss: 2.0365, Cosine Similarity: 0.8692
Epoch 4, Loss: 2.0352, Cosine Similarity: 0.8698
Epoch 5, Loss: 2.0371, Cosine Similarity: 0.8698
Epoch 6, Loss: 2.0442, Cosine Similarity: 0.8685
Epoch 7, Loss: 2.0398, Cosine Similarity: 0.8693
Epoch 8, Loss: 2.0446, Cosine Similarity: 0.8694
Epoch 9, Loss: 2.0394, Cosine Similarity: 0.8700
Epoch 10, Loss: 2.0474, Cosine Similarity: 0.8692


In [31]:
model.eval()

test_cosine = []
test_loss = 0.0

for batch_idx, (data, _) in enumerate(test_loader):
    data = data.to(device)

    loss, x_recon = model(data, train=False)
    flat_data = data.view(-1, 28*28)
    test_loss+=loss.item()

    cos_sim = F.cosine_similarity(flat_data, x_recon, dim=1)  
    test_cosine.append(cos_sim)

cosine_similarity = torch.cat(test_cosine).mean().item()
print(f"Test_Cosine Similarity: {round(cosine_similarity, 4)}")

Test_Cosine Similarity: 0.8769
