In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Load the MNIST dataset, binarizing every pixel to 0./1. after ToTensor.
def _binarize(img):
    """Threshold a float image tensor at 0.5, returning a float tensor of 0s and 1s."""
    return (img > 0.5).float()


transform = transforms.Compose([transforms.ToTensor(), _binarize])
# Shared constructor arguments for both splits; only `train` differs.
mnist_kwargs = dict(root='./data', transform=transform, download=True)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

In [None]:
# Create dataloaders. The training set is reshuffled every epoch. The test set is
# iterated in a fixed order, because shuffling held-out data gains nothing and
# makes evaluation/visualisation over it order-nondeterministic.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# architecture
class CVAE(nn.Module):
    def __init__(self, latent_dim):
        super(CVAE, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(16, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, 3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(7*7*64, latent_dim * 2)  # Mean and log-variance
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 7*7*32), nn.ReLU(),
            nn.Unflatten(1, (32, 7, 7)),
            nn.ConvTranspose2d(32, 64, 3, stride=1, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 64, 3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 32, 3, stride=1, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 32, 3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 1, 3, stride=1, padding=1), nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        mean, logvar = h[:, :self.latent_dim], h[:, self.latent_dim:]
        return mean, logvar

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar) # == torch.sqrt(torch.exp(logvar))
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        x_recon = self.decode(z)
        return x_recon, mean, logvar

In [None]:
# Define the loss function
def loss_function(x_recon, x, mean, logvar):
    """Negative ELBO summed over the batch: binary cross-entropy plus KL divergence.

    Args:
        x_recon: decoder output, probabilities in [0, 1], same shape as ``x``.
        x: target images with values in [0, 1].
        mean: posterior means, shape (batch, latent_dim).
        logvar: posterior log-variances, same shape as ``mean``.

    Returns:
        Scalar tensor: summed BCE + KL(N(mean, exp(logvar)) || N(0, I)).
    """
    reconstruction_term = nn.functional.binary_cross_entropy(
        input=x_recon, target=x, reduction='sum'
    )
    # Closed-form KL divergence of a diagonal Gaussian from the standard normal.
    kl_term = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).sum()
    return reconstruction_term + kl_term

In [None]:
# Create the model on the GPU when one is available, otherwise on the CPU.
# torch.manual_seed seeds the generators on all devices, so weight init and the
# later DataLoader shuffling/reparameterization noise repeat across
# Restart & Run All. GPU kernels may still add nondeterminism.
torch.manual_seed(0)
latent_dim = 2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CVAE(latent_dim).to(device)

In [None]:
# Adam over every model parameter with a fixed learning rate (no scheduler).
learning_rate = 1e-4
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Training loop: minimise loss_function (summed BCE + KL) over the training set.
# The printed value is the epoch's total loss divided by the number of training images.
epochs = 230
# Move batches to wherever the model's parameters live (GPU or CPU) instead of
# hard-coding 'cuda', so the cell also runs on machines without a GPU.
device = next(model.parameters()).device
for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    for x, _ in train_loader:
        x = x.to(device)
        optimizer.zero_grad()
        x_recon, mean, logvar = model(x)
        loss = loss_function(x_recon, x, mean, logvar)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {train_loss/len(train_loader.dataset)}')

In [None]:
# Save the trained model's weights.
# NOTE(review): despite the file name, this is the state_dict of the WHOLE CVAE
# (encoder and decoder parameters plus BatchNorm buffers), not only the encoder.
trained_weights = model.state_dict()
torch.save(trained_weights, 'mnist_cvae_encoder.pth')

In [None]:
# Save the decoder as a pickled nn.Module (architecture + weights), not a state_dict.
# NOTE(review): loading it needs torch.nn importable, `weights_only=False` on PyTorch
# versions where torch.load defaults to weights_only=True, and `map_location` if the
# tensors were saved from a GPU that is unavailable at load time.
decoder_module = model.decoder
torch.save(decoder_module, 'mnist_cvae_decoder.pth')

In [None]:
# How to use the decoder in PyTorch: feed it a (batch, latent_dim) tensor of latent codes.
# The training loop leaves the model in train mode. Calling the decoder in train mode
# would make its BatchNorm layers update their running statistics from this single
# dummy sample, corrupting the stats used later at inference. Switch to eval mode and
# skip autograd.
model.eval()
device = next(model.parameters()).device
with torch.no_grad():
    dummy_input = torch.zeros(1, model.latent_dim, dtype=torch.float32, device=device)
    print(model.decoder(dummy_input).shape)  # torch.Size([1, 1, 28, 28])

In [None]:
# Plot latent space
def plot_latent_space(model, n=20, figsize=15):
    global grid_x, grid_y
    norm = torch.distributions.Normal(0, 1)
    grid_x = norm.icdf(torch.linspace(0.05, 0.95, n-1))
    grid_y = norm.icdf(torch.linspace(0.05, 0.95, n-1))
    figure = np.zeros((28 * (n-1), 28 * (n-1)))

    model.eval()
    with torch.no_grad():
        for i, yi in enumerate(grid_x):
            for j, xi in enumerate(grid_y):
                z = torch.tensor([[xi, yi]]).float().to('cuda')
                x_decoded = model.decode(z)
                digit = x_decoded[0].reshape(28, 28).cpu().numpy()
                figure[i * 28: (i + 1) * 28, j * 28: (j + 1) * 28] = digit

    plt.figure(figsize=(figsize, figsize))
    plt.imshow(figure, cmap='Greys_r')
    plt.axis('off')
    plt.savefig('latent_space.png')

# Render and save the latent-space grid (defaults written out explicitly).
plot_latent_space(model, n=20, figsize=15)