In [1]:
import torch
from torch import nn

INPUT_DIM = 784
H_DIM = 200
Z_DIM = 2

class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        # encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

        self.relu = nn.ReLU()

    def encode(self, x):
        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x):
        mu, sigma = self.encode(x)
        # reparametrization trick
        epsilon = torch.randn_like(sigma)
        z_new = mu + sigma*epsilon
        x_reconstructed = self.decode(z_new)
        return x_reconstructed, mu, sigma


In [2]:
import torch
import torchvision.datasets as datasets  # Standard datasets
from tqdm import tqdm
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

# Set random seed for reproducibility
seed = 42
torch.manual_seed(seed)

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 30
BATCH_SIZE = 32
LR_RATE = 3e-4  

# Dataset Loading
dataset = datasets.FashionMNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

In [3]:
# Training
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    for i, (x, _) in loop:
        # forward pass
        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)

        # compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

        # backprop
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

0it [00:00, ?it/s]

1875it [00:08, 230.78it/s, loss=1.04e+4]
1875it [00:08, 231.83it/s, loss=1.01e+4]
1875it [00:08, 233.33it/s, loss=9.13e+3]
1875it [00:08, 232.52it/s, loss=9.98e+3]
1875it [00:08, 232.89it/s, loss=9.5e+3] 
1875it [00:08, 233.51it/s, loss=9.46e+3]
1875it [00:08, 234.33it/s, loss=9.22e+3]
1875it [00:08, 232.83it/s, loss=9.37e+3]
1875it [00:08, 232.80it/s, loss=1.05e+4]
1875it [00:08, 232.67it/s, loss=9.18e+3]
1875it [00:08, 234.21it/s, loss=8.62e+3]
1875it [00:08, 229.89it/s, loss=9.08e+3]
1875it [00:08, 231.28it/s, loss=1.37e+4]
1875it [00:08, 233.22it/s, loss=9.68e+3]
1875it [00:08, 232.57it/s, loss=1.28e+4]
1875it [00:07, 234.92it/s, loss=9.2e+3] 
1875it [00:07, 242.55it/s, loss=9.64e+3]
1875it [00:07, 244.58it/s, loss=1.02e+4]
1875it [00:07, 242.19it/s, loss=1.28e+4]
1875it [00:07, 241.84it/s, loss=1.03e+4]
1875it [00:07, 242.35it/s, loss=8.4e+3] 
1875it [00:07, 237.47it/s, loss=1.32e+4]
1875it [00:07, 235.12it/s, loss=9.68e+3]
1875it [00:07, 245.66it/s, loss=1.1e+4] 
1875it [00:07, 2

In [5]:
torch.save(model.state_dict(), f'output/FashionMNIST/model_{Z_DIM}.pth')

In [9]:
# not useful

def inference(label, num_examples=1):
    """
    Generates (num_examples) of a particular digit.
    Specifically we extract an example of each digit,
    then after we have the mu, sigma representation for
    each digit we can sample from that.

    After we sample we can run the decoder part of the VAE
    and generate examples.
    """
    for x, y in dataset:
        if y == label:
            images = x
            break
    
    mu, sigma = model.encode(images.view(1, 784))

    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f"output/FashionMNIST/generated_{label}_ex{example}.png")

for label in range(10):
    inference(label, num_examples = 4)