In [2]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

import numpy as np
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.utils import save_image, make_grid

In [3]:
# Config: tune these rather than editing the calls below.
DATA_ROOT = "../../data"
BATCH_SIZE = 32
SEED = 0

# Seed torch's global RNG (used later by weight init and the reparameterization
# noise) and give the loader its own seeded generator, so the shuffle order is
# reproducible on Restart & Run All regardless of other RNG consumers.
torch.manual_seed(SEED)

# ToTensor converts images to float tensors scaled to [0, 1], which the
# binary-cross-entropy reconstruction term in vae_loss requires of its target.
tensor_transform = transforms.ToTensor()
dataset = datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=tensor_transform
)
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
def vae_loss(recon, x, mu, log_var):
    """VAE objective: summed BCE reconstruction term plus KL to a standard normal.

    The KL term is the closed form of KL(N(mu, exp(log_var)) || N(0, I)).
    Both terms are summed over every element (batch included), not averaged.

    Args:
        recon: reconstructed images; values must lie in [0, 1] for
            F.binary_cross_entropy.
        x: target images, same shape as ``recon``, values in [0, 1].
        mu: posterior mean.
        log_var: posterior log-variance, same shape as ``mu``.

    Returns:
        Tuple ``(total, recon_loss, D_kl)`` of scalar tensors.
    """
    reconstruction = F.binary_cross_entropy(recon, x, reduction="sum")
    kl_divergence = -0.5 * (1 + log_var - mu ** 2 - torch.exp(log_var)).sum()
    total = reconstruction + kl_divergence
    return total, reconstruction, kl_divergence

In [None]:
class vae(nn.Module):
    """Convolutional variational autoencoder for 3-channel images in [0, 1].

    For a 3x32x32 input, the encoder downsamples 32 -> 16 -> 8 -> 4 with three
    stride-2 convolutions. Two 1x1 convolution heads then map the 128-channel
    feature map to the posterior mean and log-variance of a spatial latent of
    shape (latent_channels, 4, 4). The decoder mirrors this with three stride-2
    transposed convolutions (4 -> 8 -> 16 -> 32). A final Sigmoid keeps outputs
    in [0, 1].

    Args:
        latent_channels: channel count of the latent feature map. Defaults to
            128, the previously hard-coded width.
    """

    def __init__(self, latent_channels=128):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        # Posterior heads. 1x1 convs change only the channel count, so the latent
        # keeps the encoder's spatial grid, which is what the decoder upsamples.
        self.mu_encoder = nn.Conv2d(128, latent_channels, kernel_size=1)
        # Output is used as a log-variance (forward exponentiates it), so it is
        # left unconstrained with no activation.
        self.var_encoder = nn.Conv2d(128, latent_channels, kernel_size=1)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_channels, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # keeps reconstructions in [0, 1] for the BCE loss
        )

    def encode(self, x):
        """Map images to posterior parameters.

        Args:
            x: image batch with 3 channels (the first conv expects 3 inputs).

        Returns:
            ``(mean, log_var)``; for 32x32 inputs each has shape
            (B, latent_channels, 4, 4).
        """
        features = self.encoder(x)
        return self.mu_encoder(features), self.var_encoder(features)

    def decode(self, z):
        """Map a latent feature map back to image space (values in [0, 1])."""
        return self.decoder(z)

    def reparameter(self, mu, var):
        """Reparameterization trick: return ``mu + std * eps`` with ``eps ~ N(0, I)``.

        Args:
            mu: posterior mean.
            var: posterior *standard deviation*, despite the name. ``forward``
                passes ``exp(0.5 * log_var)``. The name is kept so existing
                keyword callers still work.
        """
        epsilon = torch.randn_like(var)
        return mu + var * epsilon

    def forward(self, x):
        """Encode, sample a latent via the reparameterization trick, and decode.

        Returns:
            ``(x, x_hat, z, mean, log_var)``. The input is echoed back first, and
            the last element is the log-variance that ``vae_loss`` takes as
            ``log_var``.
        """
        mean, log_var = self.encode(x)
        std = torch.exp(0.5 * log_var)
        z = self.reparameter(mean, std)
        x_hat = self.decode(z)
        return x, x_hat, z, mean, log_var