In [22]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import torch

In [23]:
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of a Gaussian latent.

    Four stride-2 convolutions (kernel size 4, no padding) feed two linear
    heads that read a flattened ``2*2*256`` feature vector. The convolution
    stack yields that size for 64x64 inputs (64 -> 31 -> 14 -> 6 -> 2); other
    spatial sizes will fail in the linear layers.

    Args:
        img_channels: Number of channels of the input images.
        latent_size: Dimensionality of the latent space.
    """

    def __init__(self, img_channels, latent_size):
        super(Encoder, self).__init__()
        self.img_channels = img_channels
        self.latent_size = latent_size

        # Each convolution halves (roughly) the spatial size and doubles channels.
        self.conv1 = nn.Conv2d(in_channels=img_channels, out_channels=32, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2)

        # Two heads over the same flattened features: mean and log-sigma.
        self.mu = nn.Linear(in_features=2 * 2 * 256, out_features=latent_size)
        self.logsigma = nn.Linear(in_features=2 * 2 * 256, out_features=latent_size)

    def forward(self, x):
        """Encode an image batch into latent distribution parameters.

        Args:
            x: Tensor of shape ``(batch, img_channels, H, W)`` whose spatial
                size reduces to 2x2 after the four convolutions (e.g. 64x64).

        Returns:
            Tuple ``(mu, logsigma)``, each of shape ``(batch, latent_size)``.
            No latent sample is drawn here, so the call does not consume
            random numbers; sampling is left to the caller.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        # Flatten everything but the batch dimension for the linear heads.
        x = x.view(x.size(0), -1)

        return self.mu(x), self.logsigma(x)

In [24]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    The latent vector is projected to 1024 features, reshaped to a
    ``(batch, 1024, 1, 1)`` feature map and upsampled by four stride-2
    transposed convolutions (1 -> 5 -> 13 -> 30 -> 64), giving 64x64 outputs.

    Args:
        img_channels: Number of channels of the reconstructed images.
        latent_size: Dimensionality of the latent space.
    """

    def __init__(self, img_channels, latent_size):
        super(Decoder, self).__init__()
        self.img_channels = img_channels
        self.latent_size = latent_size

        self.linear1 = nn.Linear(latent_size, 1024)
        self.deconv1 = nn.ConvTranspose2d(in_channels=1024, out_channels=128, kernel_size=5, stride=2)
        self.deconv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=5, stride=2)
        self.deconv3 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=6, stride=2)
        self.deconv4 = nn.ConvTranspose2d(in_channels=32, out_channels=img_channels, kernel_size=6, stride=2)

    def forward(self, z):
        """Decode latent vectors into images.

        Args:
            z: Tensor of shape ``(batch, latent_size)``.

        Returns:
            Tensor of shape ``(batch, img_channels, 64, 64)`` with values in
            ``[0, 1]`` (sigmoid output).
        """
        z = F.relu(self.linear1(z))
        # Add two trailing singleton dims: (batch, 1024) -> (batch, 1024, 1, 1).
        z = z.unsqueeze(-1).unsqueeze(-1)
        z = F.relu(self.deconv1(z))
        z = F.relu(self.deconv2(z))
        z = F.relu(self.deconv3(z))
        # torch.sigmoid replaces the deprecated F.sigmoid spelling.
        return torch.sigmoid(self.deconv4(z))

In [25]:
class VAE(nn.Module):
    """Variational auto-encoder combining ``Encoder`` and ``Decoder``.

    Args:
        img_channels: Number of image channels, forwarded to both sub-modules.
        latent_size: Dimensionality of the latent space.
    """

    def __init__(self, img_channels, latent_size):
        super(VAE, self).__init__()
        self.latent_size = latent_size
        self.encoder = Encoder(img_channels, latent_size)
        self.decoder = Decoder(img_channels, latent_size)

    def forward(self, x):
        """Encode ``x``, sample a latent with the reparameterization trick, decode it.

        Args:
            x: Image batch accepted by ``self.encoder``.

        Returns:
            Tuple ``(recon_x, mu, logsigma)``: the decoder output for the
            sampled latent, plus the encoder outputs. ``logsigma`` is treated
            as the log of the standard deviation (``sigma = exp(logsigma)``).
        """
        mu, logsigma = self.encoder(x)

        # Reparameterization: z = mu + sigma * eps with eps ~ N(0, I), so the
        # sample stays differentiable with respect to mu and logsigma.
        sigma = logsigma.exp()
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon

        recon_x = self.decoder(z)
        return recon_x, mu, logsigma

$KL_{loss}=-\frac{1}{2}(2\log(\sigma_1)-\sigma_1^2-\mu_1^2+1)$  if σ is the standard deviation.   
Warning: if σ is the variance instead, $KL_{loss}=-\frac{1}{2}(\log(\sigma_1)-\sigma_1-\mu_1^2+1)$

In [26]:
class ConvVAE():
    """Training wrapper around ``VAE``: owns the model, its Adam optimizer and a loss log.

    Args:
        img_channels: Number of image channels, forwarded to ``VAE``.
        latent_size: Dimensionality of the latent space, forwarded to ``VAE``.
        learning_rate: Learning rate for the Adam optimizer.

    Attributes set here:
        vae: The wrapped ``VAE`` module.
        learning_rate: The learning rate passed in.
        optimizer: ``optim.Adam`` over ``vae.parameters()``.
        losses: List of per-step total losses, stored as Python floats.
    """

    def __init__(self, img_channels, latent_size, learning_rate):
        self.vae = VAE(img_channels, latent_size)
        self.learning_rate = learning_rate
        self.optimizer = optim.Adam(self.vae.parameters(), lr=learning_rate)
        self.losses = []

    def train(self, batch_img):
        """Run one optimization step on ``batch_img`` and log the loss.

        The loss is the sum-reduced squared reconstruction error plus the KL
        divergence between the encoder's Gaussian and a standard normal, with
        ``logsigma`` interpreted as the log standard deviation.

        Args:
            batch_img: Image batch fed to ``self.vae`` and used as the
                reconstruction target.
        """
        recon_x, mu, logsigma = self.vae(batch_img)

        # Summed (not averaged) squared error; `reduction='sum'` replaces the
        # deprecated `size_average=False`.
        recon_loss = F.mse_loss(recon_x, batch_img, reduction='sum')
        # If the training is bad, add a threshold to KLD
        # KL(N(mu, sigma^2) || N(0, 1)) with sigma = exp(logsigma).
        KLD = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) - (2 * logsigma).exp())
        loss = recon_loss + KLD

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Store a plain float so the log does not keep autograd graphs alive.
        self.losses.append(loss.item())