In [1]:
from tensorflow.keras import datasets

(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
# Show the shape of each split: train images, train labels, test images, test labels.
for split_array in (train_images, train_labels, test_images, test_labels):
    print(split_array.shape)

2025-03-25 23:28:07.938101: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-25 23:28:08.044368: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742925488.095050   15081 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742925488.109160   15081 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-25 23:28:08.211533: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 0us/step
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)


In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset

# Scale pixels to [0, 1] (assumes 0-255 pixel values) and add a channel axis so each
# sample is (1, 28, 28), the layout nn.Conv2d expects.
# The results get NEW names on purpose: reassigning train_images/test_images in place made
# this cell non-idempotent (a re-run divided by 255 again and added a second channel axis).
train_data = torch.tensor(train_images / 255, dtype=torch.float32).unsqueeze(1)
test_data = torch.tensor(test_images / 255, dtype=torch.float32).unsqueeze(1)

# Images only: the VAE is unsupervised, so labels are not part of the batches.
train_loader = DataLoader(train_data, batch_size=1024, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1024, shuffle=False)

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Convolutional VAE encoder mapping 1-channel images to posterior parameters.

    For a 1x28x28 input the three stride-2 convolutions give 28 -> 14 -> 7 -> 5,
    i.e. a 64x5x5 feature map, which is flattened and projected to ``mu`` and ``logvar``.

    Args:
        latent_dim: size of the latent space (width of ``mu`` and ``logvar``).
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        # NOTE(review): with a 1x1 kernel, padding=1 only adds a border whose outputs depend
        # on the bias alone; it is kept because the 5x5 size below relies on it.
        self.conv3 = nn.Conv2d(32, 64, 1, stride=2, padding=1)
        self.linear1 = nn.Linear(64 * 5 * 5, latent_dim)  # -> mu
        self.linear2 = nn.Linear(64 * 5 * 5, latent_dim)  # -> logvar

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: image batch, expected (batch, 1, 28, 28).

        Returns:
            ``(mu, logvar)``, each of shape (batch, latent_dim).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # flatten(1) keeps the batch axis, so a mis-sized input fails in the linear layer
        # instead of view(-1, 64*5*5) silently folding extra features into the batch.
        x = x.flatten(1)
        mu = self.linear1(x)
        logvar = self.linear2(x)
        return mu, logvar

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Decoder(nn.Module):
    """Convolutional VAE decoder mapping latent vectors to 1x28x28 images in [0, 1].

    A linear layer lifts ``z`` to a 64x5x5 feature map; three transposed convolutions
    upsample it 5 -> 7 -> 14 -> 28, and a sigmoid squashes the result into [0, 1].

    Args:
        latent_dim: size of the incoming latent vectors.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the state_dict
        # keys and the order in which parameters are initialised.
        self.linear = nn.Linear(latent_dim, 64 * 5 * 5)
        self.deconv1 = nn.ConvTranspose2d(64, 32, kernel_size=1, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, z):
        """Decode a (batch, latent_dim) tensor into a (batch, 1, 28, 28) image batch."""
        hidden = self.linear(z).view(-1, 64, 5, 5)
        for upsample in (self.deconv1, self.deconv2):
            hidden = F.relu(upsample(hidden))
        return torch.sigmoid(self.deconv3(hidden))


In [5]:
class VAE(nn.Module):
    """Variational autoencoder wiring an ``Encoder`` to a ``Decoder``.

    Args:
        latent_dim: size of the latent vector shared by encoder and decoder.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * eps`` with ``sigma = exp(logvar / 2)`` and ``eps ~ N(0, I)``.

        Sampling through this deterministic transform keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``. A fresh sample is drawn on every call,
        including in eval mode.
        """
        sigma = logvar.mul(0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            ``(x_hat, mu, logvar)``: the reconstruction plus the posterior parameters.
        """
        posterior = self.encoder(x)
        latent = self.reparameterize(*posterior)
        return (self.decoder(latent), *posterior)

In [14]:
# Seed before building the model so weight initialisation is reproducible across runs.
torch.manual_seed(0)
# Fall back to CPU so this cell does not crash on machines without a CUDA device.
vae = VAE(latent_dim=16).to("cuda" if torch.cuda.is_available() else "cpu")

In [15]:
class VAELoss(nn.Module):
    """Negative ELBO for a Bernoulli decoder and a diagonal-Gaussian posterior.

    ``forward(x_hat, x, mu, logvar)`` returns ``(total, reconstruction, kl)``, where every
    term is summed (not averaged) over the batch and all elements.
    """

    def forward(self, x_hat, x, mu, logvar):
        # x_hat must already be probabilities (the decoder ends in a sigmoid).
        reconstruction = F.binary_cross_entropy(x_hat, x, reduction="sum")
        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)).
        kl_divergence = torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) * -0.5
        return reconstruction + kl_divergence, reconstruction, kl_divergence

In [24]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

def _run_vae_epoch(model, loader, criterion, device, optimizer=None, desc="Training"):
    """Make one pass over ``loader`` and return per-sample mean ``(loss, recon, kl)``.

    With an ``optimizer`` the model is put in train mode and updated after every batch;
    without one it is put in eval mode and autograd is disabled.

    Args:
        model: module whose forward returns ``(x_hat, mu, logvar)``.
        loader: iterable of input batches. A ``(inputs, ...)`` list/tuple batch (e.g. from
            a ``TensorDataset``) is also accepted; only its first item is used.
        criterion: callable ``(x_hat, x, mu, logvar) -> (loss, recon, kl)``. The per-sample
            normalisation below assumes it sums over the batch (VAELoss uses reduction='sum').
        device: device each batch is moved to.
        optimizer: optimizer to step, or ``None`` for evaluation.
        desc: tqdm progress-bar label.

    Returns:
        Tuple of Python floats: epoch totals divided by the number of samples seen.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = total_recon = total_kl = 0.0
    n_samples = 0
    progress = tqdm(loader, desc=desc, leave=False)
    # Without this, evaluation would build (and keep alive) an autograd graph per batch.
    with torch.set_grad_enabled(training):
        for batch in progress:
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            data = batch.to(device)
            if training:
                optimizer.zero_grad()
            x_hat, mu, logvar = model(data)
            loss, recon, kl = criterion(x_hat, data, mu, logvar)
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            total_recon += recon.item()
            total_kl += kl.item()
            n_samples += data.size(0)
            progress.set_postfix(loss=total_loss / n_samples)
    if n_samples == 0:
        raise ValueError(f"{desc}: loader yielded no samples")
    return total_loss / n_samples, total_recon / n_samples, total_kl / n_samples


def train(model, train_loader, epochs, lr, device, val_loader=None):
    """Train a VAE with Adam, evaluating and logging to TensorBoard after every epoch.

    All reported losses are epoch-level, per-sample means. Train and validation runs are
    logged under ``Loss/Train*`` and ``Loss/Val*`` respectively.

    Args:
        model: VAE whose forward returns ``(x_hat, mu, logvar)``.
        train_loader: loader of training batches.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: device to train on (e.g. ``"cuda"`` or ``"cpu"``).
        val_loader: loader evaluated after each epoch. If omitted, the notebook-global
            ``test_loader`` is used (the previous behaviour); if that does not exist
            either, validation is skipped.
    """
    if val_loader is None:
        # Backward compatibility: earlier versions silently read the global test_loader.
        val_loader = globals().get("test_loader")
    model.to(device)
    criterion = VAELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # (tag used in logs, progress-bar label, loader, optimizer or None for evaluation)
    phases = [("Train", "Training", train_loader, optimizer)]
    if val_loader is not None:
        phases.append(("Val", "Validation", val_loader, None))

    writer = SummaryWriter()
    try:
        for epoch in range(epochs):
            for tag, desc, loader, phase_optimizer in phases:
                loss, recon, kl = _run_vae_epoch(
                    model, loader, criterion, device, phase_optimizer, desc
                )
                print(
                    f"Epoch {epoch + 1}/{epochs}, {tag} Loss: {loss:.4f}, "
                    f"Recon Loss: {recon:.4f}, KL Loss: {kl:.4f}"
                )
                writer.add_scalar(f"Loss/{tag}", loss, epoch)
                writer.add_scalar(f"Loss/{tag}Recon", recon, epoch)
                writer.add_scalar(f"Loss/{tag}KL", kl, epoch)
    finally:
        # Flush event files even if training is interrupted.
        writer.close()

In [None]:
# NOTE(review): lr=0.01 is 10x Adam's default (1e-3); lower it if the loss is unstable.
# Fall back to CPU instead of crashing when no CUDA device is available.
train(vae, train_loader, epochs=100, lr=0.01, device="cuda" if torch.cuda.is_available() else "cpu")