In [1]:
from torch.utils.data import DataLoader

import torch
import torch.nn.functional as F


from utils_mnist import (
    test,
    visualize_gen_image,
    visualize_gmm_latent_representation,
    non_iid_train_iid_test_6789,
    subset_alignment_dataloader,
    train_align,
    eval_reconstrution,
)
from utils_mnist import VAE
import os
import numpy as np
import matplotlib.pyplot as plt
from utils_mnist import VAE

# --- Experiment configuration ---------------------------------------------
NUM_CLIENTS = 2
NUM_CLASSES = 4
samples_per_class = 200
SEED = 42  # fixed seed so Restart & Run All reproduces data subset, init and training

# Seed before the dataloader is built: any subset sampling / shuffling done by
# subset_alignment_dataloader, plus model initialisation and training, draw
# from these generators. (torch.manual_seed also seeds CUDA devices.)
torch.manual_seed(SEED)
np.random.seed(SEED)

# batch_size = samples_per_class * NUM_CLASSES; presumably the subset holds
# exactly that many samples, i.e. one batch per epoch — confirm in utils_mnist.
alignment_dataloader = subset_alignment_dataloader(
    samples_per_class=samples_per_class,
    batch_size=samples_per_class * NUM_CLASSES,
)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Display the device chosen in the config cell (cuda:0 when CUDA is available, else CPU).
DEVICE

device(type='cuda', index=0)

In [3]:
def vae_loss(recon_img, img, mu, logvar, beta=5.0):
    """Beta-weighted VAE loss: summed BCE reconstruction + ``beta`` * KL divergence.

    Args:
        recon_img: decoder output; values must lie in [0, 1] because
            ``F.binary_cross_entropy`` requires it (it raises otherwise).
            Presumably shape (N, H*W) — confirm against the VAE decoder.
        img: target images. Reshaped to ``recon_img``'s shape, so it must hold
            the same number of elements (e.g. (N, 1, H, W) against (N, H*W)).
        mu: means of the approximate posterior.
        logvar: log-variances of the approximate posterior, same shape as ``mu``.
        beta: weight on the KL term. Defaults to 5.0, the value that used to be
            hard-coded here, so existing callers are unaffected.

    Returns:
        Scalar tensor ``recon_loss + beta * kld_loss``, summed (not averaged)
        over the batch, so its magnitude scales with batch size.
    """
    # Flatten the target to the reconstruction's layout. reshape_as also handles
    # non-contiguous targets and any channel count, unlike view(-1, H*W).
    target = img.reshape_as(recon_img)
    recon_loss = F.binary_cross_entropy(recon_img, target, reduction="sum")
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over all elements.
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kld_loss

In [4]:
# Train the reference VAE (2-D latent space) on the alignment subset.
NUM_EPOCHS = 5000  # lower this for a quick Restart & Run All
LOG_EVERY = 100    # print the epoch loss every LOG_EVERY epochs

ref_model = VAE(z_dim=2).to(DEVICE)
opt_ref = torch.optim.Adam(ref_model.parameters(), lr=1e-3)
ref_model.train()  # explicit, since the evaluation cell switches the model to .eval()
for ep in range(NUM_EPOCHS):
    # Total (summed) loss over all batches of this epoch. Accumulated as a
    # detached tensor so the device is only synchronised when we log; with a
    # single batch per epoch this equals that batch's loss.
    epoch_loss = 0.0
    for images, _ in alignment_dataloader:
        images = images.to(DEVICE)
        opt_ref.zero_grad()
        recon_images, mu, logvar = ref_model(images)
        vae_loss1 = vae_loss(recon_images, images, mu, logvar)
        vae_loss1.backward()
        opt_ref.step()
        epoch_loss += vae_loss1.detach()

    if ep % LOG_EVERY == 0:
        # float() works whether epoch_loss is a tensor or still 0.0 (empty loader).
        print(f"Epoch {ep}, Loss {float(epoch_loss)}")
        print("--------------------------------------------------")




Epoch 0, Loss 434872.9375
--------------------------------------------------
Epoch 100, Loss 146453.640625
--------------------------------------------------
Epoch 200, Loss 139193.765625
--------------------------------------------------
Epoch 300, Loss 131410.9375
--------------------------------------------------
Epoch 400, Loss 128147.421875
--------------------------------------------------
Epoch 500, Loss 124872.140625
--------------------------------------------------
Epoch 600, Loss 123540.96875
--------------------------------------------------
Epoch 700, Loss 121272.7265625
--------------------------------------------------
Epoch 800, Loss 120235.328125
--------------------------------------------------
Epoch 900, Loss 119476.59375
--------------------------------------------------
Epoch 1000, Loss 118929.9375
--------------------------------------------------
Epoch 1100, Loss 118031.484375
--------------------------------------------------
Epoch 1200, Loss 117590.2421875
---

In [6]:
# Encode the alignment subset with the trained reference VAE and plot the
# posterior means coloured by class. Note: this is the same data the model was
# trained on in the cell above, not a held-out test set.
from sklearn.decomposition import PCA

ref_model.eval()
with torch.no_grad():
    test_latents = []
    test_labels = []
    for x, labels in alignment_dataloader:
        x = x.to(DEVICE)
        recon_images, mu, logvar = ref_model(x)
        test_latents.append(mu.cpu().numpy())
        # Labels are only needed on the host for plotting — no device round trip.
        test_labels.append(labels.cpu().numpy())

test_latents = np.concatenate(test_latents, axis=0)
test_labels = np.concatenate(test_labels, axis=0)

# The reference VAE uses z_dim=2, so the means are already 2-D and are plotted
# directly; the PCA projection is only plotted when the latent space is wider.
# PCA is still fitted so `pca` / `test_latents_pca` remain available downstream.
pca = PCA(n_components=2)
test_latents_pca = pca.fit_transform(test_latents)
if test_latents.shape[1] > 2:
    points_2d = test_latents_pca
    x_label, y_label = "PCA Component 1", "PCA Component 2"
else:
    points_2d = test_latents
    x_label, y_label = "Latent Dimension 1", "Latent Dimension 2"

fig, ax = plt.subplots(figsize=(8, 6))
for label in np.unique(test_labels):
    mask = test_labels == label  # boolean mask; assumes 1-D integer class labels
    ax.scatter(points_2d[mask, 0], points_2d[mask, 1], label=label, alpha=1)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_title("Latent Space Visualization of Alignment Set with Class Highlight")
ax.legend()
fig.savefig("latent_space_visualization.png")  # Save the figure as PNG

plt.show()