<a href="https://colab.research.google.com/github/MathBorgess/into_pytorch/blob/master/generative_models/vaes_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn

In [2]:
class VarietionalAutoEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VarietionalAutoEncoder, self).__init__()
        self.img_2hid = nn.Linear(input_dim, hidden_dim)
        self.hid_2mu = nn.Linear(hidden_dim, latent_dim)
        self.hid_2sigma = nn.Linear(hidden_dim, latent_dim)

        self.z_2hid = nn.Linear(latent_dim, hidden_dim)
        self.hid_2img = nn.Linear(hidden_dim, input_dim)

        self.relu = nn.ReLU()

    def encode(self, x):
        x = self.relu(self.img_2hid(x))
        return self.hid_2mu(x), self.hid_2sigma(x)

    def decode(self, z):
        x_hat = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(x_hat))

    def forward(self, x):
        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma * epsilon

        x_hat = self.decode(z_reparametrized)

        return x_hat, mu, sigma

In [3]:
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_dim = 28*28
h_dim = 252
z_dim = 20
epochs = 10
batch_size = 32
lr = 3e-4 # karpathy constant

model = VarietionalAutoEncoder(input_dim, h_dim, z_dim).to(device)

dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(),lr=lr)
loss_fn = nn.BCELoss(reduction="sum")

In [5]:
for epoch in range(epochs):
    for _, (x, _) in enumerate(train_loader):

        x = x.to(device).view(x.shape[0], input_dim)
        x_reconstructed, mu, sigma = model(x)

        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

        loss = reconstruction_loss - kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'epoch: {epoch}/{epochs} loss: {loss}')

epoch: 0/10 loss: 4709.1533203125
epoch: 1/10 loss: 4646.25048828125
epoch: 2/10 loss: 4240.48828125
epoch: 3/10 loss: 4204.97412109375
epoch: 4/10 loss: 4413.064453125
epoch: 5/10 loss: 4084.40478515625
epoch: 6/10 loss: 4326.85205078125
epoch: 7/10 loss: 4229.59326171875
epoch: 8/10 loss: 4042.04931640625
epoch: 9/10 loss: 4367.39453125


In [10]:
def inference(digit, samples=1):
    images = []
    for x, y in dataset:
        if y == digit:
            images.append(x.to(device))
    mu, sigma = (0,0)
    with torch.no_grad():
        mu, sigma = model.encode(images[0].view(1, input_dim))

    for i in range(samples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f"generated_{digit}_ex{i}.png")

In [11]:
inference(3, 3)