# Fuck VAE

<img src="assets/vae.png"  width="600" />

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
import torch
import torch.nn as nn
import torch.nn.functional as F

## Fuck VAE Loss

$$
\begin{aligned}
\mathcal{L} & =-\frac{1}{n} \sum_{i=1}^{n} \ell\left(p_{\theta}, q_{\phi}\right) \\
& =\frac{1}{n} \sum_{i=1}^{n} D_{K L}\left(q_{\phi}, p\right)-\frac{1}{n} \sum_{i=1}^{n} \mathbb{E}_{q_{\phi}}\left[\log p_{\theta}\left(x_{i} \mid z\right)\right] \\
& =\frac{1}{n} \sum_{i=1}^{n} D_{K L}\left(q_{\phi}, p\right)-\frac{1}{n m} \sum_{i=1}^{n} \sum_{j=1}^{m} \log p_{\theta}\left(x_{i} \mid z_{j}\right)
\end{aligned}
$$

In [None]:
def vae_loss(recons, input, mu, log_var, kld_weight):
    """Combine reconstruction error and a weighted KL term into one loss.

    Args:
        recons: Reconstructed tensor, compared against ``input``.
        input: Target tensor for the mean-squared reconstruction error.
        mu: Latent means; reduced over ``dim=1`` and then over ``dim=0``,
            so it must be at least 2-D.
        log_var: Latent log-variances, combined element-wise with ``mu``.
        kld_weight: Multiplier applied to the KL term before it is added.

    Returns:
        Tuple ``(loss, recons_loss, kld_loss)`` where
        ``loss = recons_loss + kld_weight * kld_loss``.
    """
    # Closed-form KL divergence between N(mu, exp(log_var)) and N(0, I),
    # summed over dim=1 for each sample, then averaged over dim=0.
    kl_per_sample = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(dim=1)
    kld_loss = kl_per_sample.mean(dim=0)

    # Element-wise mean-squared error between reconstruction and target.
    recons_loss = F.mse_loss(recons, input)

    return recons_loss + kld_weight * kld_loss, recons_loss, kld_loss

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder is a stack of stride-2 convolutions (one per entry of
    ``hidden_dims``), each halving an even spatial size. The flattened
    encoder output feeds two linear heads sized for a 2x2 feature map, so
    inputs must be square with side ``2 ** (len(hidden_dims) + 1)``
    (64x64 for the default five layers). The decoder mirrors the encoder
    with transposed convolutions and ends in ``Tanh``.

    Args:
        input_dim: Number of channels of the input (and reconstructed) image.
        latent_dim: Size of the latent vector ``z``.
        hidden_dims: Channel counts of the encoder layers, in order. ``None``
            selects ``[32, 64, 128, 256, 512]``. The sequence is copied and
            never modified.
        device: Device on which ``sample`` places the latent noise. The
            module itself is not moved; call ``.to(device)`` separately.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``hidden_dims`` is empty.
    """

    def __init__(self, input_dim, latent_dim, hidden_dims=None, device='cuda', **kwargs):
        super().__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.device = device

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Private copy: the decoder needs the reversed order, and reversing
        # in place would silently modify the caller's list.
        hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one channel size")

        # Channel count of the 2x2 bottleneck feature map; decode() needs it
        # to un-flatten the output of decoder_input.
        self._bottleneck_channels = hidden_dims[-1]

        # build encoder: each block halves the (even) spatial resolution
        encoder_blocks = []
        in_channels = input_dim
        for h_dim in hidden_dims:
            encoder_blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU()
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*encoder_blocks)
        # "* 4" is the flattened 2x2 spatial map left after the encoder.
        self.fc_mu = nn.Linear(self._bottleneck_channels * 4, latent_dim)
        self.fc_var = nn.Linear(self._bottleneck_channels * 4, latent_dim)

        # build decoder: mirror of the encoder, each block doubles resolution
        self.decoder_input = nn.Linear(latent_dim, self._bottleneck_channels * 4)

        decoder_dims = hidden_dims[::-1]
        decoder_blocks = []
        for c_in, c_out in zip(decoder_dims, decoder_dims[1:]):
            decoder_blocks.append(
                nn.Sequential(
                    nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU()
                )
            )

        self.decoder = nn.Sequential(*decoder_blocks)

        # Last upsampling step keeps the channel count so that BatchNorm2d and
        # the following Conv2d see the channels they were built for; the
        # Conv2d then projects to the image channels and Tanh bounds the
        # output to [-1, 1].
        last_channels = decoder_dims[-1]
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(last_channels, last_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(last_channels),
            nn.LeakyReLU(),
            nn.Conv2d(last_channels, input_dim, kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input):
        """Map a batch of images to the parameters of q(z | x).

        Args:
            input: Tensor of shape (batch_size, input_dim, height, width).

        Returns:
            Tuple ``(mu, log_var)``, each of shape (batch_size, latent_dim).
        """
        x = self.encoder(input)
        # Flatten everything except the batch axis for the linear heads.
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)

        return mu, log_var

    def decode(self, z):
        """Map latent vectors back to image space.

        Args:
            z: Tensor of shape (batch_size, latent_dim).

        Returns:
            Tensor produced by the decoder stack followed by ``final_layer``.
        """
        x = self.decoder_input(z)
        # Un-flatten to (batch, channels, 2, 2); the channel axis is the
        # bottleneck width, the batch axis is inferred.
        x = x.view(-1, self._bottleneck_channels, 2, 2)
        x = self.decoder(x)
        x = self.final_layer(x)
        return x

    # reparameterization trick
    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling the noise separately keeps the result differentiable with
        respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)

        return eps * std + mu

    def forward(self, input):
        """Encode, sample a latent, and decode.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

    def loss_function(self, input, recon_x, mu, log_var, **kwargs):
        """Compute the training loss via ``vae_loss``.

        Args:
            input: Original batch.
            recon_x: Reconstruction of ``input``.
            mu: Latent means.
            log_var: Latent log-variances.
            **kwargs: Must contain ``'M_N'``, the weight of the KL term.

        Returns:
            Dict with keys ``'loss'``, ``'Reconstruction_Loss'`` and ``'KLD'``;
            the last two are detached from the graph.

        Raises:
            KeyError: If ``'M_N'`` is not supplied.
        """
        loss, recon_loss, kld_loss = vae_loss(recon_x, input, mu, log_var, kwargs['M_N'])
        # NOTE(review): 'KLD' is reported with its sign flipped relative to
        # the term used inside 'loss'; kept as-is so existing logging and
        # callers see the same values.
        return {'loss': loss, 'Reconstruction_Loss': recon_loss.detach(), 'KLD': -kld_loss.detach()}

    def sample(self, num_samples):
        """Decode ``num_samples`` latents drawn from a standard normal.

        The noise is moved to ``self.device``; the module's parameters must
        already live on that device.
        """
        z = torch.randn(num_samples, self.latent_dim).to(self.device)
        samples = self.decode(z)
        return samples

    def generate(self, x):
        """Return only the reconstruction of ``x``.

        Args:
            x: Tensor of shape (batch_size, input_dim, height, width).
        """
        return self.forward(x)[0]