<a href="https://colab.research.google.com/github/SonnetSaif/VAE-from-scratch_PyTorch/blob/main/VAE_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
import torchvision.datasets as datasets
from torchvision.utils import save_image
from torch.utils.data import DataLoader

In [19]:
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 20
h_dim = 200
input_dim = 784
batch_size = 128
num_epochs = 20

In [20]:
class VariationalAutoEncoder(nn.Module):
  def __init__(self, input_dim, h_dim, z_dim):
    super().__init__()

    self.img_dim_to_h_dim = nn.Linear(input_dim, h_dim)
    self.h_dim_to_mean = nn.Linear(h_dim, z_dim)
    self.h_dim_to_deviation = nn.Linear(h_dim, z_dim)

    self.z_dim_to_hid_dim = nn.Linear(z_dim, h_dim)
    self.hid_dim_to_img_dim = nn.Linear(h_dim, input_dim)

    self.ReLU = nn.ReLU()

  def encoder(self, x):
    h = self.ReLU(self.img_dim_to_h_dim(x))
    mean = self.h_dim_to_mean(h)
    deviation = self.h_dim_to_deviation(h)
    return mean, deviation

  def decoder(self, z):
    h = self.ReLU(self.z_dim_to_hid_dim(z))
    img = torch.sigmoid(self.hid_dim_to_img_dim(h))
    return img

  def forward(self, x):
    mean, deviation = self.encoder(x)
    z = mean + deviation * (torch.randn_like(deviation))
    x_new = self.decoder(z)
    return x_new, mean, deviation

In [21]:
transforms = transforms.Compose([
    transforms.ToTensor()
])
dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms, download=True)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
model = VariationalAutoEncoder(input_dim, h_dim, z_dim).to(device)
optim = optim.Adam(model.parameters(), lr=lr)
criterion = nn.BCELoss()

In [None]:
# if __name__ == "__main__":
#   x = torch.randn(4, 28*28)
#   vae = VariationalAutoEncoder(input_dim, h_dim, z_dim)
#   x_mod, mean, deviation = vae(x)

In [22]:
for epoch in range(num_epochs):
  model.train()
  running_loss = 0.0
  # loop = tqdm(enumerate(train_loader))
  for idx, (img, _) in enumerate(train_loader):
    batch_size = img.shape[0]
    img = img.to(device).view(batch_size, input_dim)
    img_mod, mean, deviation = model(img)

    loss = criterion(img_mod, img)
    kl_divergence = - torch.sum(1 + torch.log(deviation.pow(2)) - mean.pow(2) - deviation.pow(2))

    loss = loss + kl_divergence
    optim.zero_grad()
    loss.backward()
    optim.step()
    running_loss += loss.item()

    # if (idx + 1) % 10 == 0:
    #         print(f"Epoch [{epoch + 1}/{num_epochs}], Iteration [{idx + 1}/{len(train_loader)}], Loss: {loss.item():.4f}")

  print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")


Epoch [1/20], Loss: 1862.3969
Epoch [2/20], Loss: 343.8366
Epoch [3/20], Loss: 246.7746
Epoch [4/20], Loss: 163.0211
Epoch [5/20], Loss: 88.2698
Epoch [6/20], Loss: 38.2365
Epoch [7/20], Loss: 14.9988
Epoch [8/20], Loss: 7.6251
Epoch [9/20], Loss: 5.1603
Epoch [10/20], Loss: 3.9766
Epoch [11/20], Loss: 3.2346
Epoch [12/20], Loss: 2.7067
Epoch [13/20], Loss: 2.2946
Epoch [14/20], Loss: 1.9640
Epoch [15/20], Loss: 1.6810
Epoch [16/20], Loss: 1.4466
Epoch [17/20], Loss: 1.2532
Epoch [18/20], Loss: 1.0687
Epoch [19/20], Loss: 0.9285
Epoch [20/20], Loss: 0.7886


In [23]:
model = model.to("cpu")
def inference(digit, num_examples):
  images = []
  counter = 0
  # taking digits from dataset
  for i,j in dataset:
    if j == counter:
      images.append(i)
      counter += 1
    if counter == 10:
      break

  # calculating mu, sigma through encoder for those digits from dataset
  encoding_digit = []
  for idx in range(10):
    with torch.no_grad():
      mean, deviation = model.encoder(images[idx].view(1, 784))
    encoding_digit.append((mean, deviation))

  # generating new digit through decoder from mu, sigma and epsilon
  mean, deviation = encoding_digit[digit]
  for idx in range(num_examples):
    z = mean + deviation * (torch.randn_like(deviation))
    x_new = model.decoder(z)
    x_new = x_new.view(-1, 1, 28, 28)
    save_image(x_new, f"generated_{digit}_ex{idx}.png")

In [26]:
for idx in range(5):
  inference(idx, 5)