<a href="https://colab.research.google.com/github/SonnetSaif/VAE-from-scratch_PyTorch/blob/main/VAE_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

In [14]:
# Hyperparameters for the MNIST VAE run below.
if torch.cuda.is_available():
  device = "cuda"  # train on the GPU when one is visible
else:
  device = "cpu"

# Model architecture.
input_dim = 28 * 28  # flattened image size (784)
h_dim = 200          # hidden-layer width shared by encoder and decoder
z_dim = 20           # latent-space dimensionality

# Optimisation settings.
lr = 3e-4
batch_size = 32
num_epochs = 10

In [15]:
class VariationalAutoEncoder(nn.Module):
  """Single-hidden-layer VAE with a diagonal-Gaussian latent space.

  The encoder maps a flattened input to the mean and (strictly positive)
  standard deviation of q(z | x). The decoder maps a latent sample back to
  the input space through a sigmoid, so reconstructions lie in [0, 1].

  Args:
    input_dim: size of the flattened input vector.
    h_dim: width of the hidden layer in both encoder and decoder.
    z_dim: dimensionality of the latent space.
  """

  def __init__(self, input_dim, h_dim, z_dim):
    super().__init__()

    # Encoder: input -> hidden -> (mean, log-variance) heads.
    self.img_dim_to_h_dim = nn.Linear(input_dim, h_dim)
    self.h_dim_to_mean = nn.Linear(h_dim, z_dim)
    # Output of this head is interpreted as log(sigma^2); see encoder().
    self.h_dim_to_deviation = nn.Linear(h_dim, z_dim)

    # Decoder: latent -> hidden -> reconstructed input.
    self.z_dim_to_hid_dim = nn.Linear(z_dim, h_dim)
    self.hid_dim_to_img_dim = nn.Linear(h_dim, input_dim)

    self.ReLU = nn.ReLU()

  def encoder(self, x):
    """Return ``(mean, deviation)`` of q(z | x); ``deviation`` is always > 0."""
    h = self.ReLU(self.img_dim_to_h_dim(x))
    mean = self.h_dim_to_mean(h)
    # A raw linear output can be zero or negative, which makes log(sigma^2)
    # in the KL term blow up. Predict log-variance instead and map it to a
    # standard deviation with exp(0.5 * logvar), which is strictly positive.
    log_var = self.h_dim_to_deviation(h)
    deviation = torch.exp(0.5 * log_var)
    return mean, deviation

  def decoder(self, z):
    """Map latent codes ``z`` to reconstructions in [0, 1]."""
    h = self.ReLU(self.z_dim_to_hid_dim(z))
    img = torch.sigmoid(self.hid_dim_to_img_dim(h))
    return img

  def forward(self, x):
    """Encode, sample z with the reparameterization trick, and decode.

    Returns:
      ``(x_new, mean, deviation)``: the reconstruction plus the posterior
      parameters needed for the KL term.
    """
    mean, deviation = self.encoder(x)
    # z = mu + sigma * eps with eps ~ N(0, I), so gradients flow to mu/sigma.
    z = mean + deviation * torch.randn_like(deviation)
    x_new = self.decoder(z)
    return x_new, mean, deviation

In [16]:
# Named `transform` (not `transforms`) so the torchvision `transforms` module
# is not shadowed; otherwise re-running this cell fails on `.Compose`.
# ToTensor scales pixels to [0, 1], the target range nn.BCELoss requires.
transform = transforms.Compose([
    transforms.ToTensor()
])
dataset = datasets.MNIST(root='dataset/', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
model = VariationalAutoEncoder(input_dim, h_dim, z_dim).to(device)
# The training loop refers to the optimizer as `optim`, which rebinds the
# `optim` module alias. Resolving Adam through `torch.optim` keeps this cell
# re-runnable (`optim.Adam` would hit the optimizer instance on a second run).
optim = torch.optim.Adam(model.parameters(), lr=lr)
# Sum over pixels and batch so the reconstruction term is on the same scale
# as the KL term, which the training loop sums over batch and latent dims.
# With the default mean reduction the KL term dominates and the posterior
# collapses.
criterion = nn.BCELoss(reduction="sum")

In [None]:
# if __name__ == "__main__":
#   x = torch.randn(4, 28*28)
#   vae = VariationalAutoEncoder(input_dim, h_dim, z_dim)
#   x_mod, mean, deviation = vae(x)

In [17]:
for epoch in range(num_epochs):
  # Wrap the loader with tqdm (imported at the top) for per-batch progress.
  loop = tqdm(train_loader, desc=f"epoch {epoch + 1}/{num_epochs}")
  for img, _ in loop:
    # Flatten each image into a vector. The batch size is read locally so the
    # global `batch_size` hyperparameter is not overwritten by a short final
    # batch.
    img = img.to(device).view(img.shape[0], input_dim)
    img_mod, mean, deviation = model(img)

    # Reconstruction term of the negative ELBO.
    reconstruction_loss = criterion(img_mod, img)
    # Closed-form KL(N(mean, deviation^2) || N(0, I)):
    #   -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # The small epsilon keeps the log finite if a deviation reaches 0.
    var = deviation.pow(2)
    kl_divergence = -0.5 * torch.sum(1 + torch.log(var + 1e-8) - mean.pow(2) - var)

    loss = reconstruction_loss + kl_divergence
    optim.zero_grad()
    loss.backward()
    optim.step()
    loop.set_postfix(loss=loss.item())

KeyboardInterrupt: ignored