In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [23]:
# Convert each PIL image into a float tensor. torchvision's ToTensor also
# rescales pixel values to [0, 1], which is the target range expected by the
# binary-cross-entropy reconstruction term used in the loss further down.
transform = transforms.ToTensor()

In [24]:
# Both splits share the same storage location, download policy and
# preprocessing; only the `train` flag differs between them.
mnist_kwargs = dict(root='./data', download=True, transform=transform)

train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

In [25]:
# Small shuffled mini-batches for training; large, fixed-order batches for
# evaluation, where no gradients are needed.
BATCH_SIZE = 32
EVAL_BATCH_SIZE = 1000

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)

In [26]:
# Axis 0 of the raw data tensor indexes the samples; axes 1 and 2 hold the
# spatial size of a single image.
img_height, img_width = train_dataset.data.shape[1:3]
num_channels = 1  # Grayscale


In [27]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for 28x28 images.

    The encoder maps an image batch to the mean and log-variance of a diagonal
    Gaussian over the latent space; the decoder maps latent vectors back to
    images with values in (0, 1).

    Args:
        latent_dim: Dimensionality of the latent space.
        num_channels: Number of image channels consumed by the encoder and
            produced by the decoder (1 for grayscale).

    Note:
        ``conv_shape`` is fixed to ``(64, 14, 14)``, i.e. the feature-map size
        obtained from 28x28 inputs after the single stride-2 convolution.
        Other spatial sizes are not supported by the dense layers.
    """

    def __init__(self, latent_dim=2, num_channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        # Stored on the instance so the model no longer depends on a
        # notebook-level global for its output channel count.
        self.num_channels = num_channels

        self.encoder_conv = nn.Sequential(
            # Input (N, num_channels, 28, 28)
            nn.Conv2d(num_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # -> (N, 64, 14, 14)
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten()  # -> (N, 64 * 14 * 14) = (N, 12544)
        )

        # Shape of the last convolutional feature map (channels, height, width);
        # the decoder un-flattens back to this shape.
        self.conv_shape = (64, 14, 14)
        flat_features = self.conv_shape[0] * self.conv_shape[1] * self.conv_shape[2]

        self.encoder_dense = nn.Sequential(
            nn.Linear(flat_features, 32),
            nn.ReLU()
        )

        # Two heads on the shared 32-d representation: posterior mean and
        # log-variance.
        self.fc_mu = nn.Linear(32, latent_dim)
        self.fc_log_var = nn.Linear(32, latent_dim)

        self.decoder_dense = nn.Sequential(
            nn.Linear(latent_dim, flat_features),
            nn.ReLU()
        )

        self.decoder_conv = nn.Sequential(
            nn.Unflatten(1, self.conv_shape),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> (N, 32, 28, 28)
            nn.ReLU(),
            nn.ConvTranspose2d(32, num_channels, kernel_size=3, padding=1),  # -> (N, num_channels, 28, 28)
            nn.Sigmoid()
        )

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for a batch ``x``."""
        hidden = self.encoder_dense(self.encoder_conv(x))
        return self.fc_mu(hidden), self.fc_log_var(hidden)

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` of shape (N, latent_dim) to images."""
        return self.decoder_conv(self.decoder_dense(z))

    def forward(self, x):
        """Encode ``x``, sample a latent vector and reconstruct the input.

        Returns:
            Tuple ``(x_recon, mu, log_var)``.
        """
        mu, log_var = self.encode(x)

        # Reparameterization
        z = self.reparameterize(mu, log_var)

        # Decoder path
        x_recon = self.decode(z)

        return x_recon, mu, log_var


In [28]:
def vae_loss(recon_x, x, mu, log_var, kl_weight=0.0005):
    """Per-sample VAE objective: reconstruction error plus a weighted KL term.

    Args:
        recon_x: Reconstructed batch with values in [0, 1], same shape as ``x``.
        x: Target batch with values in [0, 1].
        mu: Posterior means, shape (N, latent_dim).
        log_var: Posterior log-variances, shape (N, latent_dim).
        kl_weight: Multiplier applied to ``-sum(1 + log_var - mu^2 - exp(log_var))``.
            The closed-form KL divergence to a standard normal corresponds to
            ``0.5``; the default ``0.0005`` weights the regulariser far below
            the summed reconstruction term.

    Returns:
        Scalar tensor: (summed BCE + weighted KL term) divided by the batch size.
    """
    recon_loss = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')

    # Summed over both the batch and the latent dimensions.
    kl_sum = torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    kld_loss = -kl_weight * kl_sum

    return (recon_loss + kld_loss) / x.size(0)


In [29]:
# Seed torch's global RNG so weight initialisation, batch shuffling and the
# reparameterisation noise are repeatable across "Restart & Run All".
# (Some GPU kernels may still introduce small run-to-run differences.)
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim=2).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
EPOCHS = 10

print(f"\n--- Training VAE on {device} ---")
for epoch in range(1, EPOCHS + 1):
    model.train()
    train_loss = 0.0
    # Labels are ignored: the VAE is trained without supervision.
    for data, _ in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = model(data)
        loss = vae_loss(recon_batch, data, mu, log_var)

        loss.backward()
        optimizer.step()

        # .item() detaches the scalar so the autograd graph is not kept alive.
        train_loss += loss.item()

    # vae_loss is already a per-sample mean, so average over batches.
    avg_loss = train_loss / len(train_loader)
    print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')

print("Training complete.")



--- Training VAE on cpu ---
====> Epoch: 1 Average loss: 177.6869


KeyboardInterrupt: 

In [None]:
# Switch to inference mode. This has no effect on the layers currently in the
# model, but keeps this cell correct if dropout or batch-norm is added later.
model.eval()

all_mu = []
all_labels = []
with torch.no_grad():
    for data, labels in test_loader:
        data = data.to(device)
        # Only the posterior mean is plotted, so run just the encoder and the
        # mean head instead of the full sample-and-decode forward pass.
        hidden = model.encoder_dense(model.encoder_conv(data))
        all_mu.append(model.fc_mu(hidden).cpu())
        all_labels.append(labels.cpu())

all_mu = torch.cat(all_mu, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# Hand matplotlib plain NumPy arrays rather than torch tensors.
mu_np = all_mu.numpy()
labels_np = all_labels.numpy()

plt.figure(figsize=(10, 10))
plt.scatter(mu_np[:, 0], mu_np[:, 1], c=labels_np, cmap='brg')
plt.xlabel('Latent Dim 1 (mu)')
plt.ylabel('Latent Dim 2 (mu)')
plt.colorbar()
plt.title('Test Set Images Mapped to Latent Space')
plt.show()



In [None]:
n = 20  # Number of images per side of the grid

# Latent coordinates to decode. The y axis is reversed so that the top row of
# the figure corresponds to the largest y value.
grid_x = np.linspace(-5, 5, n)
grid_y = np.linspace(-5, 5, n)[::-1]

# Row-major list of (x, y) latent points: entry i * n + j pairs grid_y[i]
# with grid_x[j].
z_grid = torch.tensor(
    [[xi, yi] for yi in grid_y for xi in grid_x],
    device=device,
    dtype=torch.float32,
)

# Decode the whole grid in one batch through the decoder half of the VAE
# instead of issuing n * n single-sample calls.
model.eval()
with torch.no_grad():
    decoded = model.decoder_conv(model.decoder_dense(z_grid)).cpu().numpy()

# decoded has shape (n * n, channels, height, width).
channels, height, width = decoded.shape[1:]

# Tile the images into one (n * height, n * width, channels) canvas:
# (row, col, C, H, W) -> (row, H, col, W, C) -> merge row/H and col/W.
# The explicit axis permutation keeps pixels in place for any channel count,
# and rows are sized by height and columns by width.
figure = (
    decoded.reshape(n, n, channels, height, width)
    .transpose(0, 3, 1, 4, 2)
    .reshape(n * height, n * width, channels)
)

plt.figure(figsize=(10, 10))
plt.imshow(figure.squeeze(), cmap='gnuplot2')
plt.title('Manifold of Generated Digits from Latent Space')
plt.axis('off')
plt.show()