In [1]:
# -*- coding: utf-8 -*-
# VAE (Variational Autoencoder) based on MNIST

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

image_size = [1, 28, 28]  # Size of each image
latent_dim = 20
batch_size = 64
use_gpu = torch.cuda.is_available()

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(np.prod(image_size), 400),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(400, latent_dim)
        self.fc_logvar = nn.Linear(400, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 400),
            nn.ReLU(),
            nn.Linear(400, np.prod(image_size)),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x_hat = self.decoder(z)
        x_hat = x_hat.view(-1, *image_size)  # Reshape decoded output
        return x_hat

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return x_hat, mu, logvar

# Data loading
dataset = torchvision.datasets.MNIST("mnist_data", train=True, download=True,
                                     transform=torchvision.transforms.Compose(
                                         [torchvision.transforms.ToTensor()]
                                     )
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

vae = VAE()
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

if use_gpu:
    print("Using GPU for training")
    vae = vae.cuda()

num_epochs = 10
for epoch in range(num_epochs):
    vae.train()
    train_loss = 0
    for i, (data, _) in enumerate(dataloader):
        if use_gpu:
            data = data.cuda()
        optimizer.zero_grad()
        x_hat, mu, logvar = vae(data)
        # Calculate loss
        recon_loss = F.binary_cross_entropy(x_hat, data, reduction='sum')
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = recon_loss + kld_loss
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item() / batch_size:.4f}")

    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {train_loss / len(dataloader.dataset):.4f}")

    # Generate and save images after each epoch
    with torch.no_grad():
        sample = torch.randn(16, latent_dim)  # Generate 16 samples
        if use_gpu:
            sample = sample.cuda()
        sample = vae.decode(sample).cpu()

        # Save as a 2x8 image grid
        torchvision.utils.save_image(sample, f'sample_epoch_{epoch+1}.png', nrow=8, padding=2)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:11<00:00, 886kB/s] 


Extracting mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 126kB/s]


Extracting mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:01<00:00, 1.20MB/s]


Extracting mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 8.52MB/s]


Extracting mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist_data/MNIST/raw

Using GPU for training
Epoch [1/10], Step [0/937], Loss: 549.3596
Epoch [1/10], Step [100/937], Loss: 180.3835
Epoch [1/10], Step [200/937], Loss: 166.5256
Epoch [1/10], Step [300/937], Loss: 143.1723
Epoch [1/10], Step [400/937], Loss: 137.7251
Epoch [1/10], Step [500/937], Loss: 131.4553
Epoch [1/10], Step [600/937], Loss: 125.1710
Epoch [1/10], Step [700/937], Loss: 120.7971
Epoch [1/10], Step [800/937], Loss: 121.4058
Epoch [1/10], Step [900/937], Loss: 118.9433
Epoch [1/10], Average Loss: 147.3761
Epoch [2/10], Step [0/937], Loss: 121.3193
Epoch [2/10], Step [100/937], Loss: 115.8706
Epoch [2/10], Step [200/937], Loss: 121.2374
Epoch [2/10], Step [300/937], Loss: 114.1101
Epoch [2/10], Step [400/937], Loss: 115.2225
Epoch [2/10], Step [500/937], Loss: 116.6241
Epoch [2/10], Step [600/937], Loss: 109.2436
Epoch [2/10], Step [700/937], Loss: 121.8980
Epoch [2/10], Step [800/937], Loss: 114.1192
Epoch 