In [36]:
import shutil

import numpy as np
import torch
import torch.distributions
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

In [49]:
# Remove any previously downloaded copy so the MNIST(..., download=True) cell
# below fetches the dataset from scratch. Pure-Python equivalent of
# `rm -rf data`: portable across platforms and still valid when the notebook
# is exported to a plain .py script. A missing directory is not an error.
shutil.rmtree('data', ignore_errors=True)

In [57]:
# Compose applies its transforms one after another, so they must be passed as
# an ordered sequence. A set literal ({...}) has no defined order and could run
# Normalize before ToTensor, which fails on the PIL image the dataset yields.
# NOTE(review): Normalize((0.5,), (0.5,)) maps ToTensor's [0, 1] pixels to
# [-1, 1], while Decoder ends in a sigmoid (range [0, 1]) -- confirm which
# range is intended and align the decoder activation or drop Normalize.
img_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,),(0.5,))
])

In [58]:
# MNIST training split, stored under ./data (downloaded on first use); every
# image is passed through img_transform when it is read.
dataset = MNIST(
    root='./data',
    train=True,
    transform=img_transform,
    download=True,
)

In [59]:
# Mini-batches of 256 samples, reshuffled at the start of every epoch.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=256,
    shuffle=True,
)

In [41]:
class VAE_encoder(nn.Module):
  """Gaussian encoder q(z|x) = N(mu(x), sigma(x)^2) with a reparameterised sample.

  A shared hidden layer feeds two heads: ``linear2`` produces the mean and
  ``linear3`` produces log(sigma). ``forward`` returns one sample ``z`` and
  stores the KL divergence to the standard normal prior, summed over every
  element of the batch, in ``self.kl``.

  Args:
    in_dim: size of each flattened input vector.
    latent_dim: dimensionality of the latent code.
  """

  def __init__(self, in_dim, latent_dim):
    super().__init__()
    self.linear1 = nn.Linear(in_dim, 512)
    self.linear2 = nn.Linear(512, latent_dim)  # mean head
    self.linear3 = nn.Linear(512, latent_dim)  # log-std head

    # No device-specific state is created here, so the module can be built on
    # a CPU-only machine and moved later with .to(device).
    self.kl = 0

  def forward(self, x):
    """Encode ``x`` and draw z = mu + sigma * eps with eps ~ N(0, I).

    Args:
      x: tensor whose last dimension is ``in_dim``.

    Returns:
      Tensor with the same leading dimensions as ``x`` and last dimension
      ``latent_dim``. As a side effect ``self.kl`` is overwritten with the
      summed KL term for this call.
    """
    x = F.relu(self.linear1(x))
    mu = self.linear2(x)
    log_sigma = self.linear3(x)
    sigma = torch.exp(log_sigma)
    # randn_like draws the noise on mu's device and dtype, so the same code
    # runs on CPU or GPU without any hard-coded .cuda() call.
    z = mu + sigma * torch.randn_like(mu)
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per element:
    #   0.5 * (sigma^2 + mu^2 - 1) - log(sigma)
    # It is exactly zero when mu = 0 and sigma = 1. Using log_sigma directly
    # avoids the exp -> log round trip.
    self.kl = (0.5 * (sigma**2 + mu**2 - 1.0) - log_sigma).sum()
    return z

In [42]:
class Decoder(nn.Module):
  """Two-layer MLP mapping a latent code to an output vector in (0, 1).

  Args:
    latent_dim: size of the latent code fed to the decoder.
    out_dim: size of the produced output vector.
  """

  def __init__(self, latent_dim, out_dim):
    super().__init__()
    self.linear1 = nn.Linear(latent_dim, 512)
    self.linear2 = nn.Linear(512, out_dim)

  def forward(self, z):
    """Return sigmoid(linear2(relu(linear1(z))))."""
    hidden = F.relu(self.linear1(z))
    logits = self.linear2(hidden)
    # The sigmoid squashes every output element into the open interval (0, 1).
    return torch.sigmoid(logits)


In [43]:
class VAE(nn.Module):
  """Variational autoencoder: a VAE_encoder followed by a Decoder.

  Args:
    in_dim: size of each flattened input vector given to the encoder.
    latent_dim: size of the latent code shared by encoder and decoder.
    out_dim: size of the vector the decoder produces.
  """

  def __init__(self, in_dim, latent_dim, out_dim):
    super().__init__()
    self.encoder = VAE_encoder(in_dim=in_dim, latent_dim=latent_dim)
    self.decoder = Decoder(latent_dim=latent_dim, out_dim=out_dim)

  def forward(self, x):
    """Encode ``x`` to a sampled latent code and decode that code."""
    latent = self.encoder(x)
    reconstruction = self.decoder(latent)
    return reconstruction

In [82]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 784 = 28 * 28 flattened MNIST pixels; a 2-D latent space.
model = VAE(784, 2, 784).to(device)
epochs = 30
lr = 3e-3
opt = torch.optim.Adam(model.parameters(), lr=lr)
# model.encoder.kl is summed over the whole batch, so the reconstruction term
# must be summed too. The default 'mean' reduction would divide it by
# batch_size * 784 and let the KL term dominate the objective.
mse_loss = nn.MSELoss(reduction='sum')

In [1]:
# Train the VAE: each step minimises reconstruction error plus the KL term
# that the encoder stores on itself during the forward pass.
for epoch in range(epochs):
  running_loss = 0.0
  for img, _ in dataloader:  # labels are not needed for an autoencoder
    # Flatten every image to a vector and move it to the model's device.
    img = img.view(img.size(0), -1).to(device)
    opt.zero_grad()
    output = model(img)
    # The input never requires grad, so it can be used directly as the
    # target; the legacy `.data` accessor is unnecessary.
    loss = mse_loss(output, img) + model.encoder.kl
    loss.backward()
    opt.step()
    running_loss += loss.item()
  if epoch % 5 == 0:
    # Report the mean over the epoch's batches rather than whichever batch
    # happened to come last; max(..., 1) guards against an empty loader.
    mean_loss = running_loss / max(len(dataloader), 1)
    print(f"Epoch: {epoch}, loss: {mean_loss:.3f}")