In [1]:
from torch.utils.data import DataLoader

import torch
import torch.nn.functional as F


from utils_mnist import (
    test,
    visualize_gen_image,
    visualize_gmm_latent_representation,
    non_iid_train_iid_test_6789,
    subset_alignment_dataloader,
    train_align,
    eval_reconstrution,
)
import os
import numpy as np
import matplotlib.pyplot as plt
from utils_mnist import VAE

NUM_CLIENTS = 2
NUM_CLASSES = 4
samples_per_class = 200

# Seed the RNGs this notebook uses directly (torch for weight init and the
# CVAE reparameterisation noise, numpy for completeness) so that
# Restart & Run All reproduces the logged losses and plots. Whether
# subset_alignment_dataloader draws from these RNGs is not visible here —
# TODO confirm in utils_mnist.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

alignment_dataloader = subset_alignment_dataloader(
    samples_per_class=samples_per_class,
    # Batch size equals samples_per_class * NUM_CLASSES — presumably the whole
    # alignment subset in a single batch; verify against utils_mnist.
    batch_size=samples_per_class * NUM_CLASSES,
)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class CVAE(nn.Module):
    """Conditional VAE over flattened inputs, conditioned on a class label.

    The encoder receives the flattened input concatenated with the one-hot
    label. The decoder receives the sampled latent vector plus a learned
    linear projection of the one-hot label into latent space.

    Args:
        x_dim: Number of features per sample after flattening.
        h_dim1, h_dim2, h_dim3: Widths of the three hidden layers (the decoder
            mirrors the encoder).
        z_dim: Latent dimensionality.
        num_classes: Number of label classes used for one-hot conditioning.
    """

    def __init__(
        self, x_dim=784, h_dim1=512, h_dim2=256, h_dim3=128, z_dim=10, num_classes=10
    ):
        super(CVAE, self).__init__()
        # Kept so forward() flattens to the configured size instead of a
        # hard-coded 784.
        self.x_dim = x_dim
        self.num_classes = num_classes

        # Encoder part: input features + one-hot label -> (mu, log_var)
        self.fc1 = nn.Linear(x_dim + self.num_classes, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc3 = nn.Linear(h_dim2, h_dim3)
        self.fc41 = nn.Linear(h_dim3, z_dim)  # mu
        self.fc42 = nn.Linear(h_dim3, z_dim)  # log_var

        # Decoder part: (z + projected label) -> reconstruction
        self.label_projection = nn.Linear(num_classes, z_dim)
        self.fc5 = nn.Linear(z_dim, h_dim3)
        self.fc6 = nn.Linear(h_dim3, h_dim2)
        self.fc7 = nn.Linear(h_dim2, h_dim1)
        self.fc8 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x, y):
        """Map flattened inputs ``x`` and one-hot labels ``y`` to (mu, log_var)."""
        input_cat = torch.cat((x, y), dim=1)
        h = F.relu(self.fc1(input_cat))
        h = F.relu(self.fc2(h))
        h = F.relu(self.fc3(h))
        return self.fc41(h), self.fc42(h)  # mu, log_var

    def sampling(self, mu, log_var):
        """Reparameterisation trick: return ``mu + eps * exp(0.5 * log_var)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)  # return z sample

    def decoder(self, z):
        """Decode latent vectors to per-feature values in (0, 1) via sigmoid."""
        h = F.relu(self.fc5(z))
        h = F.relu(self.fc6(h))
        h = F.relu(self.fc7(h))
        return torch.sigmoid(self.fc8(h))

    def forward(self, x, y):
        """Encode, sample and decode a labelled batch.

        Args:
            x: Input batch whose elements flatten to ``x_dim`` per sample
                (e.g. (batch, 1, 28, 28) for the default ``x_dim=784``).
            y: Integer class labels, one per sample, in ``[0, num_classes)``.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``. The reconstruction has
            shape (batch, x_dim); mu and log_var have shape (batch, z_dim).
        """
        y_onehot = F.one_hot(y, self.num_classes).float()
        # reshape (not view) also accepts non-contiguous inputs.
        mu, log_var = self.encoder(x.reshape(-1, self.x_dim), y_onehot)
        z = self.sampling(mu, log_var)

        # Project one-hot encoded conditional information to latent space
        y_proj = self.label_projection(y_onehot)

        # Combine latent vector z with projected conditional information y
        z_combined = z + y_proj

        output = self.decoder(z_combined)
        return output, mu, log_var

In [21]:
def vae_loss(recon_img, img, mu, logvar, kld_weight=0.01):
    """Weighted VAE objective: summed BCE reconstruction + scaled KL divergence.

    Args:
        recon_img: Decoder output with values in [0, 1] (e.g. a sigmoid
            output), shape (batch, features).
        img: Target batch. It is reshaped to ``recon_img.shape``, so any layout
            with the same number of elements works (e.g. (batch, 1, 28, 28)
            against a (batch, 784) reconstruction).
        mu: Latent means produced by the encoder.
        logvar: Latent log-variances produced by the encoder.
        kld_weight: Multiplier applied to the KL term. The default 0.01 is the
            value this function previously hard-coded.

    Returns:
        Scalar tensor ``recon_loss + kld_weight * kld_loss``, where both terms
        are summed over the batch.
    """
    # Reconstruction loss: binary cross-entropy summed over features and batch.
    # F.binary_cross_entropy requires recon_img values in [0, 1].
    recon_loss = F.binary_cross_entropy(
        recon_img, img.reshape(recon_img.shape), reduction="sum"
    )
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over all elements.
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld_weight * kld_loss

In [22]:
# Train the reference CVAE (2-D latent space) on the alignment subset.
ref_model = CVAE(z_dim=2).to(DEVICE)
opt_ref = torch.optim.Adam(ref_model.parameters(), lr=1e-3)

NUM_REF_EPOCHS = 5000  # full passes over alignment_dataloader
LOG_EVERY = 100  # print the loss every LOG_EVERY epochs

for ep in range(NUM_REF_EPOCHS):
    for images, labels in alignment_dataloader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        opt_ref.zero_grad()
        recon_images, mu, logvar = ref_model(images, labels)
        vae_loss1 = vae_loss(recon_images, images, mu, logvar)
        vae_loss1.backward()
        opt_ref.step()

    if ep % LOG_EVERY == 0:
        # vae_loss1 is the summed loss of the last batch seen in this epoch.
        print(f"Epoch {ep}, Loss {vae_loss1.item()}")
        print("--------------------------------------------------")


Epoch 0, Loss 434892.8125
--------------------------------------------------
Epoch 100, Loss 134671.4375
--------------------------------------------------
Epoch 200, Loss 112528.421875
--------------------------------------------------
Epoch 300, Loss 105083.7265625
--------------------------------------------------
Epoch 400, Loss 101315.2265625
--------------------------------------------------
Epoch 500, Loss 98174.7734375
--------------------------------------------------
Epoch 600, Loss 95701.7421875
--------------------------------------------------
Epoch 700, Loss 94276.34375
--------------------------------------------------
Epoch 800, Loss 92183.3203125
--------------------------------------------------
Epoch 900, Loss 90803.0234375
--------------------------------------------------
Epoch 1000, Loss 89651.8515625
--------------------------------------------------
Epoch 1100, Loss 88031.140625
--------------------------------------------------
Epoch 1200, Loss 87063.4921875
--

In [24]:
# Continue training ref_model with the existing optimizer (its Adam state
# carries over from the first training cell).
# Restore training mode explicitly: the latent-visualisation cell calls
# ref_model.eval(), and the execution counts show it ran before this cell.
# CVAE has no dropout/batch-norm today, so this is a guard against hidden
# state rather than a numerical change.
ref_model.train()
for ep in range(5000):
    for images, labels in alignment_dataloader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        opt_ref.zero_grad()
        recon_images, mu, logvar = ref_model(images, labels)
        vae_loss1 = vae_loss(recon_images, images, mu, logvar)
        vae_loss1.backward()
        opt_ref.step()

    if ep % 100 == 0:
        # vae_loss1 is the summed loss of the last batch seen in this epoch.
        print(f"Epoch {ep}, Loss {vae_loss1.item()}")
        print("--------------------------------------------------")

Epoch 0, Loss 70338.3203125
--------------------------------------------------
Epoch 100, Loss 70393.8046875
--------------------------------------------------
Epoch 200, Loss 70176.6796875
--------------------------------------------------
Epoch 300, Loss 70327.703125
--------------------------------------------------
Epoch 400, Loss 69082.7265625
--------------------------------------------------
Epoch 500, Loss 69001.875
--------------------------------------------------
Epoch 600, Loss 68813.0625
--------------------------------------------------
Epoch 700, Loss 69031.7578125
--------------------------------------------------
Epoch 800, Loss 68271.25
--------------------------------------------------
Epoch 900, Loss 68457.6328125
--------------------------------------------------
Epoch 1000, Loss 68207.9453125
--------------------------------------------------
Epoch 1100, Loss 68065.2265625
--------------------------------------------------
Epoch 1200, Loss 68406.3125
-------------

In [23]:
# Encode the alignment subset with the trained reference CVAE and plot the
# latent means (mu), coloured by class.
# NOTE: this iterates alignment_dataloader, i.e. the same data ref_model was
# trained on, not a held-out test set.
from sklearn.decomposition import PCA

ref_model.eval()
with torch.no_grad():
    test_latents = []
    test_labels = []  # labels aligned with test_latents
    for x, labels in alignment_dataloader:
        x = x.to(DEVICE)
        labels = labels.to(DEVICE)
        recon_images, mu, logvar = ref_model(x, labels)
        test_latents.append(mu.cpu().numpy())
        test_labels.append(labels.cpu().numpy())

# Stack per-batch arrays: test_latents is (num_samples, z_dim).
test_latents = np.concatenate(test_latents, axis=0)
test_labels = np.concatenate(test_labels, axis=0)

# Previously the PCA projection was computed but never plotted. Use it when the
# latent space is wider than 2-D; plot the raw dimensions otherwise.
pca = PCA(n_components=2)
test_latents_pca = pca.fit_transform(test_latents)
if test_latents.shape[1] > 2:
    plot_points, axis_prefix = test_latents_pca, "Principal Component"
else:
    plot_points, axis_prefix = test_latents, "Latent Dimension"

fig, ax = plt.subplots(figsize=(8, 6))
for label in np.unique(test_labels):
    mask = test_labels == label  # boolean mask selects rows of this class
    ax.scatter(
        plot_points[mask, 0],
        plot_points[mask, 1],
        label=label,
        alpha=1,
    )
ax.set_xlabel(f"{axis_prefix} 1")
ax.set_ylabel(f"{axis_prefix} 2")
ax.set_title("Latent Space Visualization of Alignment Set with Class Highlight")
ax.legend()
fig.savefig("latent_space_visualization.png")  # Save the figure as PNG

plt.show()