In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision

In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# Batched loader over torchvision's MNIST dataset; files are downloaded to / read from ./data.
# Each image goes through ToTensor() and batches hold up to 128 samples.
# No `train` or `shuffle` argument is passed, so the library defaults apply
# NOTE(review): presumably the training split in a fixed order — confirm whether shuffling is wanted.
MNIST = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST('./data', transform=torchvision.transforms.ToTensor(), download=True),
            batch_size=128,
)

In [4]:
class Encoder(nn.Module):
    """Convolutional encoder producing two ``latent_dims``-wide vectors per image.

    Three unpadded 3x3 convolutions (each followed by ReLU) feed two parallel
    linear heads. Both heads are plain linear layers, so neither output is
    constrained in sign or range.

    Args:
        latent_dims: Width of each of the two output vectors.
    """

    def __init__(self, latent_dims):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=10, kernel_size=3)
        # 10 channels * 22 * 22: each unpadded 3x3 conv trims 2 pixels per side
        # length, so a 28x28 input shrinks 28 -> 26 -> 24 -> 22.
        flat_features = 4840
        self.linear1 = nn.Linear(flat_features, latent_dims)
        self.linear2 = nn.Linear(flat_features, latent_dims)

    def forward(self, x):
        """Encode a batch of single-channel images.

        Args:
            x: Tensor of shape (batch, 1, H, W); H = W = 28 is required for
                the flattened size to match the linear heads.

        Returns:
            Tuple ``(mu, sigma)``: outputs of ``linear1`` and ``linear2``,
            each of shape (batch, latent_dims).
        """
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        features = features.flatten(start_dim=1)
        return self.linear1(features), self.linear2(features)

In [5]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors to images.

    A linear layer expands the latent vector to a 10x22x22 feature map, which
    three unpadded 3x3 transposed convolutions grow to 1x28x28
    (22 -> 24 -> 26 -> 28). The final activation is a sigmoid, so every output
    value lies between 0 and 1.

    Args:
        latent_dims: Width of the latent vectors accepted by ``forward``.
    """

    def __init__(self, latent_dims):
        super().__init__()
        self.latent_dims = latent_dims
        self.conv1 = nn.ConvTranspose2d(in_channels=10, out_channels=16, kernel_size=3)
        self.conv2 = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=3)
        self.conv3 = nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=3)
        # 4840 = 10 * 22 * 22, the feature-map size that forward() reshapes to.
        self.linear = nn.Linear(latent_dims, 4840)

    def forward(self, x):
        """Decode latent vectors.

        Args:
            x: Tensor whose last dimension is ``latent_dims`` and whose first
                dimension is the batch size.

        Returns:
            Tensor of shape (batch, 1, 28, 28) with sigmoid-activated values.
        """
        expanded = self.linear(x)
        feature_map = expanded.reshape(expanded.shape[0], 10, 22, 22)
        for deconv in (self.conv1, self.conv2):
            feature_map = F.relu(deconv(feature_map))
        return torch.sigmoid(self.conv3(feature_map))

In [6]:
class VAE(nn.Module):
    """Autoencoder wrapper that also reports a KL-divergence regulariser.

    Args:
        latent_dims: Size of the latent space shared by encoder and decoder.

    Attributes:
        enc: ``Encoder(latent_dims)``.
        dec: ``Decoder(latent_dims)``.
        N: Standard normal distribution object; not used by ``forward``.
    """

    def __init__(self, latent_dims):
        super(VAE, self).__init__()
        self.enc = Encoder(latent_dims)
        self.dec = Decoder(latent_dims)
        self.N = torch.distributions.Normal(0, 1)

    def forward(self, x):
        """Reconstruct ``x`` and compute KL(N(mu, sigma^2) || N(0, 1)).

        Args:
            x: Input batch accepted by the encoder.

        Returns:
            Tuple ``(x_hat, kl_divergence)``: the decoder output and a scalar
            tensor holding the KL divergence summed over every batch element
            and latent dimension.
        """
        # The encoder's second head is an unconstrained linear output, so it
        # can be zero or negative. It is therefore read as log(sigma): taking
        # torch.log of it directly would yield NaN/-inf for such values.
        mu, log_sigma = self.enc(x)

        # The decoder is fed the posterior mean; the reparameterised sample
        # z = mu + sigma * eps is not drawn here, so the pass is deterministic.
        x_hat = self.dec(mu)

        # Closed form for a diagonal Gaussian against a standard normal:
        #   0.5 * (sigma^2 + mu^2 - 1) - log(sigma),  with sigma^2 = exp(2 * log_sigma)
        kl_divergence = torch.sum(
            0.5 * (torch.exp(2 * log_sigma) + mu ** 2 - 1) - log_sigma
        )
        return x_hat, kl_divergence

In [7]:
def train(data, latent_dims, epochs=1):
    """Train a fresh ``VAE`` with Adam on a mean-squared reconstruction loss.

    Only the reconstruction error is optimised; the KL term returned by the
    model is not added to the loss.

    Args:
        data: Iterable yielding ``(x, y)`` batches; ``y`` is ignored. It is
            iterated once per epoch, so it must be re-iterable when
            ``epochs > 1``.
        latent_dims: Latent size passed to ``VAE``.
        epochs: Number of passes over ``data`` (default 1).

    Returns:
        The trained ``VAE`` instance.
    """
    vae = VAE(latent_dims)
    opt = torch.optim.Adam(vae.parameters())
    for _ in range(epochs):
        for x, _labels in data:
            opt.zero_grad()
            x_hat, _kl_divergence = vae(x)
            loss = ((x - x_hat) ** 2).mean()
            loss.backward()
            opt.step()
    return vae

In [None]:
# Train a 10-dimensional model for 10 passes over the MNIST loader.
# Only the binary cross-entropy reconstruction loss is optimised; the KL term
# returned by the model is not added to the loss.
vae = VAE(10)
opt = torch.optim.Adam(vae.parameters())
losses = []  # one Python float per optimisation step
for epoch in range(10):
    for i, (x, y) in enumerate(MNIST):
        opt.zero_grad()
        x_hat, kl_divergence = vae(x)
        reconstruction_loss = F.binary_cross_entropy(x_hat, x)
        loss = reconstruction_loss
        loss.backward()
        opt.step()
        # Store a plain float: appending the tensor itself keeps a reference to
        # its autograd history and cannot be converted to NumPy for plotting.
        losses.append(loss.item())
        print(f'{i}/{len(MNIST)}: {loss.detach().numpy()}', end='\r')

112/469: 0.25742208957672127

In [None]:
# Per-step training loss. float() accepts both 0-d tensors (even ones that
# require grad) and plain numbers, so the plot works either way.
plt.plot([float(step_loss) for step_loss in losses])

In [None]:
# Display element 20 of the current batch `x` as a 28x28 image.
# `x` is whatever batch was last assigned, i.e. the final batch of the training loop.
plt.imshow(x[20].reshape(28, 28))

In [None]:
# Reconstruct the current batch for inspection. This is inference only, so
# gradient tracking is switched off to avoid building an autograd graph.
with torch.no_grad():
    x_hat, _ = vae(x)

In [None]:
# Display the reconstruction of batch element 20 (same index as the input shown above),
# converted to a NumPy array and reshaped to 28x28.
plt.imshow(x_hat.detach().numpy()[20].reshape(28, 28))