In [1]:
import torch
from torch import nn

In [None]:
class VarietionalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder (VAE).

    Encoder: input_dim -> hidden_dim (ReLU) -> (mu, sigma), each of size latent_dim.
    Decoder: latent_dim -> hidden_dim (ReLU) -> input_dim (sigmoid).

    Args:
        input_dim: size of each flattened input vector.
        hidden_dim: width of the single hidden layer in encoder and decoder.
        latent_dim: dimensionality of the latent code z.

    The class name keeps its original spelling ("Varietional") so existing
    callers keep working.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Encoder layers.
        self.img_2hid = nn.Linear(input_dim, hidden_dim)
        self.hid_2mu = nn.Linear(hidden_dim, latent_dim)
        self.hid_2sigma = nn.Linear(hidden_dim, latent_dim)

        # Decoder layers.
        self.z_2hid = nn.Linear(latent_dim, hidden_dim)
        self.hid_2img = nn.Linear(hidden_dim, input_dim)

        # torch.nn spells this activation `ReLU`; `nn.ReLu` does not exist and
        # would raise AttributeError when the model is constructed.
        self.relu = nn.ReLU()

    def encode(self, x):
        """Map inputs x to the parameters (mu, sigma) of q(z | x).

        Returns:
            Tuple (mu, sigma), each with last dimension latent_dim.

        NOTE(review): sigma is the raw output of a linear layer, so it is not
        constrained to be positive and can be ~0, which makes log(sigma^2) in
        the KL term blow up. Predicting log-variance (and exponentiating) is
        the usual, numerically safer parameterization — consider switching.
        """
        hidden = self.relu(self.img_2hid(x))
        return self.hid_2mu(hidden), self.hid_2sigma(hidden)

    def decode(self, z):
        """Map latent codes z back to input space; sigmoid keeps outputs in [0, 1]."""
        hidden = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(hidden))

    def forward(self, x):
        """Encode x, sample z with the reparameterization trick, and decode it.

        Returns:
            Tuple (x_hat, mu, sigma): the reconstruction and the encoder outputs.
        """
        mu, sigma = self.encode(x)
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so sampling stays differentiable with respect to mu and sigma.
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma * epsilon

        x_hat = self.decode(z_reparametrized)

        return x_hat, mu, sigma

In [None]:
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters.
input_dim = 28*28  # MNIST images are 28x28; they are flattened before entering the model
h_dim = 252
z_dim = 20
epochs = 10
batch_size = 32
lr = 3e-4  # informally nicknamed the "Karpathy constant": a common Adam default
seed = 0

# Seed torch before building the model so weight init, loader shuffling and
# reparameterization noise are reproducible across Restart & Run All.
torch.manual_seed(seed)

model = VarietionalAutoEncoder(input_dim, h_dim, z_dim).to(device)

# The dataset class is `datasets.MNIST`; `datasets.mnist` is the submodule and
# is not callable.
dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Summed (not averaged) BCE so the reconstruction term is on the same per-batch
# scale as the summed KL term in the training loop.
loss_fn = nn.BCELoss(reduction="sum")

In [None]:
model.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    for x, _ in train_loader:  # labels are unused: the VAE is trained unsupervised
        # Flatten each image into a vector of length input_dim.
        x = x.to(device).view(x.shape[0], input_dim)
        x_reconstructed, mu, sigma = model(x)

        # Loss = negative ELBO = reconstruction term + KL(q(z|x) || N(0, I)).
        reconstruction_loss = loss_fn(x_reconstructed, x)
        # Closed-form KL between a diagonal Gaussian and the standard normal:
        #   KL = -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        # The -1/2 factor matters: without it the KL term is weighted 2x.
        # NOTE(review): log(sigma^2) -> -inf if the encoder emits sigma == 0.
        kl_div = -0.5 * torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Both loss terms are summed over the batch, so dividing the epoch total by
    # the dataset size gives the mean loss per sample.
    print(f"epoch {epoch + 1}/{epochs}: mean loss per sample {epoch_loss / len(train_loader.dataset):.4f}")