In [36]:
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

In [37]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [38]:
# Input img -> Hidden dim -> (mean, std) -> parametrization trick -> Decoder -> Ouput img
class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        # encoder
        self.img2hid = nn.Linear(input_dim, h_dim)
        self.hid2mu = nn.Linear(h_dim, z_dim)
        self.hid2sigma = nn.Linear(h_dim, z_dim)
        
        # decoder
        self.z2hid = nn.Linear(z_dim, h_dim)
        self.hid2img = nn.Linear(h_dim, input_dim)
        
    
    def encode(self, x):
        h = F.relu(self.img2hid(x))
        mu, sigma = self.hid2mu(h), self.hid2sigma(h)
        return mu, sigma

    def decode(self, z):
        h = F.relu(self.z2hid(z))
        img = F.sigmoid(self.hid2img(h))
        return img

    def forward(self, x):
        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma)
        z_new = mu + sigma*epsilon
        x_reconstructed = self.decode(z_new)
        return x_reconstructed, mu, sigma

In [39]:
x = torch.randn(4,28*28)
vae = VariationalAutoEncoder(input_dim=784)
x_reconstracted, mu, sigma = vae(x)
x_reconstracted.shape, mu.shape, sigma.shape

(torch.Size([4, 784]), torch.Size([4, 20]), torch.Size([4, 20]))

In [40]:
INPUT_DIM = 784
H_DIM = 200
Z_DIM = 20
NUM_EPOCHS = 100
BATCH_SIZE = 32
LR_RATE = 3e-4 #Karpathy constant

In [41]:
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=False)

In [43]:
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

In [44]:
for epoch in range(NUM_EPOCHS):
    total_loss = 0
    for i, (x,_) in enumerate(train_loader):
        x = x.to(device)
        # Forward pass
        x = x.view(x.shape[0], INPUT_DIM)
        x_reconstracted, mu, sigma = model(x)
        
        # Compute loss
        reconstruction_loss = loss_fn(x_reconstracted, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
        
        # Backporp
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch: {epoch}: {total_loss}")



Epoch: 0: 12345927.324707031
Epoch: 1: 9492996.387207031
Epoch: 2: 8801365.610839844
Epoch: 3: 8436847.585449219
Epoch: 4: 8223414.5185546875
Epoch: 5: 8080029.681884766
Epoch: 6: 7979298.382568359
Epoch: 7: 7906166.759033203
Epoch: 8: 7850454.579833984
Epoch: 9: 7802331.801025391
Epoch: 10: 7764921.876464844
Epoch: 11: 7733058.186279297
Epoch: 12: 7709990.785644531
Epoch: 13: 7688750.896484375
Epoch: 14: 7671211.533203125
Epoch: 15: 7653422.94921875
Epoch: 16: 7639411.596923828
Epoch: 17: 7621540.289794922
Epoch: 18: 7613981.107421875
Epoch: 19: 7598936.225830078
Epoch: 20: 7591659.594482422
Epoch: 21: 7581853.657714844
Epoch: 22: 7573038.9990234375
Epoch: 23: 7565286.30859375
Epoch: 24: 7556259.333251953
Epoch: 25: 7550420.234863281
Epoch: 26: 7542877.338134766
Epoch: 27: 7536575.951904297
Epoch: 28: 7532053.425048828
Epoch: 29: 7527390.541259766
Epoch: 30: 7520139.095458984
Epoch: 31: 7511694.109619141
Epoch: 32: 7506814.652587891
Epoch: 33: 7502936.3701171875
Epoch: 34: 7498597.356

In [45]:
model = model.to("cpu")
def inference(digit, num_examples=1):
    """
    Generates (num_examples) of a particular digit.
    Specifically we extract an example of each digit,
    then after we have the mu, sigma representation for
    each digit we can sample from that.

    After we sample we can run the decoder part of the VAE
    and generate examples.
    """
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break

    encodings_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1, 784))
        encodings_digit.append((mu, sigma))

    mu, sigma = encodings_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f"generated_{digit}_ex{example}.png")


In [46]:
for idx in range(10):
    inference(idx, num_examples=5)

In [52]:
images = []
idx = 0
for x, y in dataset:
    if y == idx:
        images.append(x)
        idx += 1
    if idx == 10:
        break

In [53]:
images

[tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 