<a href="https://colab.research.google.com/github/alexchen1999/deeplearning-from-scratch/blob/main/vae_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn.functional as F

from torch import nn

In [4]:
class VariationalAutoEncoder(nn.Module):
  def __init__(self, input_dim, h_dim=200, z_dim=20):
    super().__init__()

    # Encoder
    self.img_2hid = nn.Linear(input_dim, h_dim)
    self.hid_2mu = nn.Linear(h_dim, z_dim)
    self.hid_2sigma = nn.Linear(h_dim, z_dim)

    # Decoder
    self.z_2hid = nn.Linear(z_dim, h_dim)
    self.hid_2img = nn.Linear(h_dim, input_dim)

    self.relu = nn.ReLU()

  def encode(self, x):
    h = self.relu(self.img_2hid(x))
    mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
    return mu, sigma

  def decode(self, z):
    h = self.z_2hid(z)
    return torch.sigmoid(self.hid_2img(h)) # Scale from 0 to 1 because MNIST is standardized

  def forward(self, x):
    mu, sigma = self.encode(x)

    # Reparameterization trick
    epsilon = torch.randn_like(sigma) # see https://pytorch.org/docs/stable/generated/torch.randn_like.html: sampling from Gaussian with same tensor shape
    z_reparameterized = mu + sigma * epsilon

    x_reconstructed = self.decode(z_reparameterized)

    return x, mu, sigma

In [5]:
x = torch.randn(4, 28 * 28)
vae = VariationalAutoEncoder(input_dim=784)
x_recon, mu, sigma = vae(x)

print(x_recon.shape)
print(mu.shape)
print(sigma.shape)

torch.Size([4, 784])
torch.Size([4, 20])
torch.Size([4, 20])


In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

INPUT_DIM = 784
H_DIM = 200
Z_DIM = 20

NUM_EPOCHS = 10

BATCH_SIZE = 32

LEARNING_RATE = 0.001

In [8]:
import torchvision.datasets as datasets
from torchvision import transforms

dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
dataset

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 200097755.02it/s]


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 42728639.80it/s]


Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 189010423.52it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 13559095.21it/s]


Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



Dataset MNIST
    Number of datapoints: 60000
    Root location: dataset/
    Split: Train
    StandardTransform
Transform: ToTensor()

In [10]:
from torch.utils.data import DataLoader
from torch import optim
from tqdm import tqdm

train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

bce_loss = nn.BCELoss(reduction="sum")


for epoch in range(NUM_EPOCHS):
  loop = tqdm(enumerate(train_loader))

  for i, (x, _) in loop:
    x = x.to(device).view(x.shape[0], INPUT_DIM)

    x_recon, mu, sigma = model(x)

    recon_loss = bce_loss(x_recon, x)
    kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

    loss = recon_loss + kl_div
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loop.set_postfix(loss=loss.item())

1875it [00:27, 67.06it/s, loss=1.41e+3]
1875it [00:29, 62.52it/s, loss=1.5e+3]
1875it [00:28, 65.22it/s, loss=1.41e+3]
1875it [00:28, 64.95it/s, loss=1.45e+3]
1875it [00:28, 65.37it/s, loss=1.42e+3]
1875it [00:29, 64.25it/s, loss=1.48e+3]
1875it [00:29, 62.76it/s, loss=1.44e+3]
1875it [00:30, 62.14it/s, loss=1.39e+3]
1875it [00:30, 61.89it/s, loss=1.42e+3]
1875it [00:30, 60.60it/s, loss=1.5e+3]
