In [None]:
# Import Lib
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision.datasets import MNIST

In [None]:
# Setting hyper-params
# Seed PyTorch's global RNG so repeated "Restart & Run All" runs are comparable.
seed = 42
torch.manual_seed(seed)

# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 128
lr = 1e-03
epochs = 20

# Model shape: image channels, latent size, and conv widths of the two stages.
in_channels = 1
z_dim = 128
hidden_dims = [32, 64]

### Step 1. Define Dataset

In [None]:
# Images are only converted to tensors; no further normalisation is applied.
transform = transforms.Compose([transforms.ToTensor()])

# Both splits share the same root directory, transform and download flag.
mnist_kwargs = dict(root=".", transform=transform, download=True)
train_dataset = MNIST(train=True, **mnist_kwargs)
test_dataset = MNIST(train=False, **mnist_kwargs)

# Shuffle only the training split; the test split keeps its stored order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Pull one training batch and report its tensor shapes.
images, labels = next(iter(train_loader))

print("Images shape: {}".format(images.shape))
print("labels shape: {}".format(labels.shape))

# Tile the first 64 images into an 8-per-row grid and show channel 0 of it.
grid = make_grid(images[:64], nrow=8, normalize=True)

fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(grid[0])
ax.axis("off")
plt.show()

### Step 2. Define Model

In [None]:
class Upsample(nn.Module):
    """Resize-then-convolve block.

    The input is first interpolated by ``scale_factor`` using ``mode`` and then
    passed through a 3x3, stride-1, padding-1 convolution that maps
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, scale_factor=2, mode="nearest"):
        super().__init__()
        resize = nn.Upsample(scale_factor=scale_factor, mode=mode)
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Kept as a Sequential named `up` so parameter names stay `up.1.*`.
        self.up = nn.Sequential(resize, conv)

    def forward(self, x):
        """Return the upsampled and convolved version of ``x``."""
        upsampled = self.up(x)
        return upsampled


class Downsample(nn.Module):
    """Single convolution used for downsampling.

    With the defaults (3x3 kernel, stride 2, padding 1) the spatial size of an
    even-sized input is halved.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()
        self.down = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        """Return the convolved version of ``x``."""
        reduced = self.down(x)
        return reduced


class VAE(nn.Module):
    """Convolutional variational autoencoder for 28x28 images.

    The encoder applies two ``Downsample`` stages (28 -> 14 -> 7) and projects
    the flattened features to the mean and log-variance of the latent Gaussian.
    The decoder maps a latent vector back to a 7x7 feature map and applies two
    ``Upsample`` stages (7 -> 14 -> 28) followed by a sigmoid.

    Args:
        in_channels: Number of channels of the input (and reconstructed) image.
        z_dim: Size of the latent vector.
        hidden_dims: Channel widths of the two convolutional stages. Only the
            first two entries are used.
    """

    def __init__(self, in_channels=1, z_dim=128, hidden_dims=(32, 64)):
        super().__init__()
        first_width, last_width = hidden_dims[0], hidden_dims[1]

        # Shape of the encoder output / decoder input. It is derived from
        # `hidden_dims` so the linear layers always match the conv widths
        # instead of assuming 64 channels.
        self._feature_shape = (last_width, 7, 7)
        flat_dim = last_width * 7 * 7

        # Encoder
        self.encoder = nn.Sequential(
            Downsample(in_channels, first_width),
            nn.BatchNorm2d(first_width),
            nn.ReLU(),
            Downsample(first_width, last_width),
            nn.BatchNorm2d(last_width),
            nn.ReLU()
        )
        self.flatten = nn.Flatten()
        self.fc_mu = nn.Linear(flat_dim, z_dim)
        self.fc_logvar = nn.Linear(flat_dim, z_dim)

        # Decoder
        self.decoder_input = nn.Linear(z_dim, flat_dim)

        self.decoder = nn.Sequential(
            Upsample(last_width, first_width),
            nn.BatchNorm2d(first_width),
            nn.ReLU(),
            Upsample(first_width, in_channels),
            nn.Sigmoid() # value to [0, 1]
        )

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

        Drawing the noise separately keeps the sample differentiable with
        respect to ``mu`` and ``logvar``.
        """
        sigma = torch.exp(0.5 * logvar)
        # randn_like already matches sigma's device and dtype.
        eps = torch.randn_like(sigma)
        return mu + sigma * eps

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent distribution for images ``x``."""
        features = self.flatten(self.encoder(x))
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar

    def decode(self, z):
        """Map latent vectors ``z`` of shape (batch, z_dim) to images in [0, 1]."""
        features = self.decoder_input(z)
        features = features.view(-1, *self._feature_shape)
        out = self.decoder(features)
        return out

    def forward(self, x):
        """Encode ``x``, sample a latent vector and return ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        out = self.decode(z)
        return out, mu, logvar

# Shape sanity check on the configured architecture: the reconstruction must
# have exactly the shape of the input batch.
x = torch.randn(4, in_channels, 28, 28)
model = VAE(in_channels=in_channels, z_dim=z_dim, hidden_dims=hidden_dims)
with torch.no_grad():  # no gradients are needed for a shape check
    recon_x, _, _ = model(x)
print("x:", x.shape, "recon_x:", recon_x.shape)
assert recon_x.shape == x.shape, "VAE output shape must match its input shape"

### Step 3. Define Loss Function, Model, and Optimizer

In [None]:
# loss function
def loss_function(pred, target, mu, logvar):
    """Negative ELBO of a Bernoulli-decoder VAE, averaged over the batch.

    Both terms are summed over each sample's own dimensions (pixels for the
    reconstruction term, latent dimensions for the KL term) and then averaged
    over the batch. Taking an element-wise mean of each term instead would
    divide them by different constants and change their relative weight.

    Args:
        pred: Reconstructed images with values in [0, 1]; first dim is batch.
        target: Input images with the same shape as ``pred``.
        mu: Latent means, shape (batch, z_dim).
        logvar: Latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: reconstruction loss plus KL divergence per sample.
    """
    batch = pred.size(0)
    recon_loss = F.binary_cross_entropy(pred, target, reduction="sum") / batch
    # Closed-form KL( N(mu, sigma^2) || N(0, I) ).
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch
    return recon_loss + kl_loss
    
# Build the VAE from the configured sizes, move it to the target device and
# hand its parameters to Adam.
model = VAE(
    in_channels=in_channels,
    z_dim=z_dim,
    hidden_dims=hidden_dims,
)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=lr)

### Step 4. Train VAE

In [None]:
def _show_sample(image):
    """Display one (C, H, W) image tensor in its own figure, without axes."""
    plt.figure(figsize=(6,6))
    sample = image.permute(1, 2, 0).detach().cpu().numpy()
    plt.imshow(sample)
    plt.axis("off")
    plt.show()
    plt.close()


def train_epoch(model, dataloader, criterion, optimizer, epoch, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module whose forward returns ``(out, mu, logvar)``.
        dataloader: Yields ``(data, label)`` batches; labels are ignored. Must
            expose ``.dataset`` (used for the progress message).
        criterion: Callable ``criterion(out, data, mu, logvar)`` returning a
            scalar loss. It is assumed to be a per-batch average, so it is
            weighted by batch size when accumulating.
        optimizer: Optimizer updating ``model``'s parameters.
        epoch: Epoch number, used only in log messages.
        device: Device the input batches are moved to.

    Returns:
        The sample-weighted average loss over the epoch.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    total = 0
    total_loss = 0.
    out = None

    model.train()
    for batch_idx, (data, _) in enumerate(dataloader):
        data = data.to(device)
        n_samples = data.size(0)
        total += n_samples

        optimizer.zero_grad()
        out, mu, logvar = model(data)
        loss = criterion(out, data, mu, logvar)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += loss.item() * n_samples

        if batch_idx % 100 == 0:
            print(f"[Train] Epoch {epoch} [{total}/{len(dataloader.dataset)}]\tAvg Loss: {total_loss/total:.5f}")

    if total == 0:
        raise ValueError("dataloader yielded no samples; nothing to train on")

    # Sample plot: first reconstruction of the last batch.
    _show_sample(out[0])

    return total_loss / total

# Run the full schedule; epochs are numbered from 1 in the log output.
for epoch in range(1, epochs + 1):
    train_loss = train_epoch(
        model=model,
        dataloader=train_loader,
        criterion=loss_function,
        optimizer=optimizer,
        epoch=epoch,
        device=device,
    )
    print(f"[Train] Epoch: {epoch}\tTotal Avg Loss: {train_loss:.5f}")

### Step 5. Reconstruct Images from the Test Dataset

In [None]:
# Take the first test batch; the labels are not needed for reconstruction.
images, label = next(iter(test_loader))

model.eval()
with torch.no_grad():
    images = images.to(device)
    recon_images, _, _ = model(images)

# Build 8-per-row grids of the first 64 reconstructions and originals.
recon_grid = make_grid(recon_images[:64].detach().cpu(), nrow=8, normalize=True)
gt_grid = make_grid(images[:64].detach().cpu(), nrow=8, normalize=True)

# Side-by-side panels; channel 0 of each grid is displayed.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(18, 9))
panels = ((recon_grid, "Generation"), (gt_grid, "Ground Truth"))
for axis, (panel, title) in zip(ax, panels):
    axis.imshow(panel[0])
    axis.set_title(title)
    axis.set_axis_off()
plt.show()

### Step 6. Generate Images from Noise vectors

In [None]:
# Draw standard-normal latent vectors and decode them without the encoder.
noise = torch.randn(batch_size, z_dim).to(device)

model.eval()
with torch.no_grad():
    generated_images = model.decode(noise)

# Show the first 64 samples as an 8-per-row grid (channel 0 of the grid).
grid = make_grid(generated_images[:64].detach().cpu(), nrow=8, normalize=True)
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(grid[0])
ax.axis("off")
plt.show()