In [22]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [31]:
class VariationalAutoencoder(nn.Module):
  def __init__(self, input_dim, latent_dim):
    super().__init__()
    self.layer1 = nn.Linear(input_dim, 256)
    self.layer2 = nn.Linear(256, 128)
    self.layer3 = nn.Linear(128, 64)
    self.layer4_mu = nn.Linear(64, latent_dim)
    self.layer4_sigma = nn.Linear(64, latent_dim)
    self.layer5 = nn.Linear(latent_dim, 64)
    self.layer6 = nn.Linear(64, 128)
    self.layer7 = nn.Linear(128, 256)
    self.layer8 = nn.Linear(256, input_dim)
    self.relu = nn.ReLU()
    self.sigmoid = nn.Sigmoid()

  def encode(self, x):
    h1 = self.relu(self.layer1(x))
    h2 = self.relu(self.layer2(h1))
    h3 = self.relu(self.layer3(h2))
    mu = self.layer4_mu(h3)
    sigma = self.layer4_sigma(h3)
    return mu, sigma

  def decode(self, mu, sigma):
    epsilon = torch.randn_like(sigma)
    z = mu + sigma*epsilon
    g1 = self.relu(self.layer5(z))
    g2 = self.relu(self.layer6(g1))
    g3 = self.relu(self.layer7(g2))
    x_reconstructed = self.sigmoid(self.layer8(g3))

    return x_reconstructed

  def forward(self, x):
    mu, sigma = self.encode(x)
    x_reconstructed = self.decode(mu, sigma)
    return mu, sigma, x_reconstructed

In [32]:
input_dim = 784
latent_dim_t = 16
batch_size = 32
num_epochs = 5
lr = 3e-4

In [33]:
my_transforms = transforms.Compose(
    [transforms.ToTensor()]
)

dataset = datasets.MNIST(root = 'dataset1/', train = True, transform=my_transforms, download = True)
loader = DataLoader(dataset, batch_size = batch_size, shuffle=True)

In [34]:
model = VariationalAutoencoder(input_dim, latent_dim_t)
optimizer = optim.Adam(model.parameters(), lr = lr)
loss_criterion = nn.BCELoss()

In [35]:
from tqdm import tqdm

for epoch in range(num_epochs):
  for index, (img, label) in tqdm(enumerate(loader)):
    img = img.view(img.shape[0], input_dim)
    mu, sigma, x_reconstructed = model(img)
    reconstruction_loss = loss_criterion(x_reconstructed, img)
    kl_div_loss = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

    loss = reconstruction_loss + kl_div_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

1875it [00:33, 55.47it/s]
1875it [00:35, 53.57it/s]
1875it [00:35, 52.28it/s]
1875it [00:36, 51.26it/s]
1875it [00:36, 51.64it/s]


In [None]:
x = torch.randn_like()