In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [4]:

# imput img -> hidden dim -> mean, std -> parametrization trick -> decoder -> output img
class VariantionalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        #encoder
        self.img_2_hid = nn.Linear(input_dim, h_dim)
        self.hid_2_mu = nn.Linear(h_dim, z_dim)
        self.hid_2_sigma = nn.Linear(h_dim, z_dim)
        
        #decoder
        self.z_2_hid = nn.Linear(z_dim, h_dim)
        self.hid_2_img = nn.Linear(h_dim, input_dim)
        
        self.relu = nn.ReLU()
        
    def encode(self, x):
        #q_phy(z|x)
        h = self.relu(self.img_2_hid(x))
        
        mu = self.hid_2_mu(h)
        sigma = self.hid_2_sigma(h)
        return mu, sigma
    
    def decode(self, z):
        # p_theta(x|z)
        h = self.z_2_hid(z)
        return torch.sigmoid(self.hid_2_img(h))
    
    def forward(self, x):
        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma * epsilon
        x_reconstructed = self.decode(z_reparametrized)
        return x_reconstructed, mu, sigma
    
x = torch.randn(4, 784) # 28x28 -> 784
vae = VariantionalAutoEncoder(input_dim=784)
print(
    vae(x)[0].shape,
    vae(x)[1].shape,
    vae(x)[2].shape
)
    

torch.Size([4, 784]) torch.Size([4, 20]) torch.Size([4, 20])


In [21]:
import torchvision.datasets as datasets
from tqdm import tqdm
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch import optim

# configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
INPUT_DIM = 784
Z_DIM = 20
H_DIM = 200
NUM_EPOCHS = 10
BATCH_SIZE = 128
LR_RATE = 3e-4 # Karpathy constant

# Dataset loading
dataset = datasets.MNIST(root="kaggle/working/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, shuffle=True, batch_size=BATCH_SIZE)
model = VariantionalAutoEncoder(input_dim=INPUT_DIM)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

# Training
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    for i, (x, _) in loop:
        # forward
        x = x.to(DEVICE).view(-1, INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)
        
        # Loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
        
        # Backprop
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        loop.set_postfix(loss=loss.item())
        

        

469it [00:11, 39.30it/s, loss=2.03e+4]
469it [00:11, 40.43it/s, loss=1.74e+4]
469it [00:11, 39.38it/s, loss=1.62e+4]
469it [00:11, 40.29it/s, loss=1.54e+4]
469it [00:11, 40.23it/s, loss=1.49e+4]
469it [00:11, 39.49it/s, loss=1.41e+4]
469it [00:11, 40.38it/s, loss=1.38e+4]
469it [00:11, 40.24it/s, loss=1.49e+4]
469it [00:11, 39.41it/s, loss=1.43e+4]
469it [00:11, 40.18it/s, loss=1.4e+4] 


In [22]:

def inference(digit, num_examples=1):
    """
    Generates (num_examples) of a particular digit.
    Specifically we extract an example of each digit,
    then after we have the mu, sigma representation for
    each digit we can sample from that.

    After we sample we can run the decoder part of the VAE
    and generate examples.
    """
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break

    encodings_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1, 784))

        encodings_digit.append((mu, sigma))

    mu, sigma = encodings_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f"generated_{digit}_ex{example}.png")

for idx in range(10):
    inference(idx, num_examples=5)