In [6]:
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms

In [None]:
class autoencoder(nn.Module):
  """Deterministic MLP autoencoder with one hidden layer on each side.

  Architecture: x_dim -> hidden_dim -> z_dim -> hidden_dim -> x_dim, with a ReLU
  after every linear layer. This includes the latent code and the reconstruction,
  so both are non-negative.

  Args:
    x_dim: size of the input's last dimension (what ``nn.Linear`` acts on).
    hidden_dim: width of the hidden layer in encoder and decoder.
    z_dim: size of the latent code (default 10).
  """

  def __init__(self, x_dim, hidden_dim, z_dim=10):
    super().__init__()
    # Encoder x -> hidden -> z. The layers are created in this exact order so
    # that parameter initialisation draws from the RNG in a fixed sequence.
    self.enc_layer1 = nn.Linear(x_dim, hidden_dim)
    self.enc_layer2 = nn.Linear(hidden_dim, z_dim)
    # The decoder mirrors the encoder: z -> hidden -> x.
    self.dec_layer1 = nn.Linear(z_dim, hidden_dim)
    self.dec_layer2 = nn.Linear(hidden_dim, x_dim)

  def encode(self, x):
    """Map inputs (last dim x_dim) to non-negative latent codes (last dim z_dim)."""
    hidden = F.relu(self.enc_layer1(x))
    return F.relu(self.enc_layer2(hidden))

  def decode(self, z):
    """Map latent codes (last dim z_dim) to non-negative reconstructions (last dim x_dim)."""
    return F.relu(self.dec_layer2(F.relu(self.dec_layer1(z))))

  def forward(self, x):
    """Reconstruct ``x`` by encoding it and then decoding the latent code."""
    return self.decode(self.encode(x))

def loss_function(output, x):
  """Reconstruction loss for the plain autoencoder: summed squared error.

  NOTE: a later cell redefines ``loss_function`` with a VAE signature. After that
  cell runs, this version is no longer reachable under this name.

  Returns:
    Scalar tensor equal to the sum over all elements of ``(output - x) ** 2``.
  """
  return F.mse_loss(input=output, target=x, reduction='sum')


In [9]:
# Bind the instance to its own name. Assigning it back to `autoencoder` shadowed
# the class, so re-running this cell tried to call the *instance* with (256, 32)
# (i.e. forward(256, 32)) and raised TypeError, breaking Restart & Run All.
ae_model = autoencoder(x_dim=256, hidden_dim=32)

In [None]:
class VAE(nn.Module):
  """Variational autoencoder with one hidden layer in the encoder and decoder.

  The encoder maps x -> hidden (ReLU) and then to two *linear* heads. These give
  the mean ``mu`` and log-variance ``logvar`` of a diagonal Gaussian q(z|x). The
  decoder maps z -> hidden -> x with ReLU activations, so reconstructions are
  non-negative.

  Args:
    x_dim: size of the input's last dimension (what ``nn.Linear`` acts on).
    hidden_dim: width of the hidden layer in encoder and decoder.
    z_dim: latent dimensionality (default 10).
  """

  def __init__(self, x_dim, hidden_dim, z_dim=10):
    super().__init__()
    self.enc_layer1 = nn.Linear(x_dim, hidden_dim)
    self.enc_layer2_mu = nn.Linear(hidden_dim, z_dim)
    self.enc_layer2_logvar = nn.Linear(hidden_dim, z_dim)

    self.dec_layer1 = nn.Linear(z_dim, hidden_dim)
    self.dec_layer2 = nn.Linear(hidden_dim, x_dim)

  def encode(self, x):
    """Return ``(mu, logvar)`` of q(z|x), each with last dimension z_dim.

    Both heads are linear. ``mu`` must be able to take any real value, and
    ``logvar`` must be able to go negative (variance < 1). The previous ReLU on
    these heads pinned mu >= 0 and variance >= 1, so the posterior could never
    be tighter than the prior.
    """
    hidden = F.relu(self.enc_layer1(x))
    mu = self.enc_layer2_mu(hidden)
    logvar = self.enc_layer2_logvar(hidden)
    return mu, logvar

  def reparameterize(self, mu, logvar):
    """Sample z = mu + std * eps with std = exp(logvar / 2) and eps ~ N(0, I).

    Drawing eps separately keeps z differentiable with respect to mu and logvar.
    This method samples unconditionally, in train and eval mode alike.
    """
    std = torch.exp(logvar / 2)
    eps = torch.randn_like(std)
    return mu + std * eps

  def decode(self, z):
    """Map latent codes (last dim z_dim) to non-negative reconstructions (last dim x_dim)."""
    output = F.relu(self.dec_layer1(z))
    return F.relu(self.dec_layer2(output))

  # Backward-compatible alias. This method was originally named ``decoder``,
  # while ``forward`` calls ``self.decode``, so every forward pass raised
  # AttributeError.
  decoder = decode

  def forward(self, x):
    """Encode ``x``, sample a latent code, and decode it.

    Returns:
      ``(output, z, mu, logvar)``: the reconstruction, the sampled latent code,
      and the posterior parameters that the loss's KL term needs.
    """
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)
    output = self.decode(z)
    return output, z, mu, logvar

def loss_function(output, x, mu, logvar, kl_weight=0.002):
  """Per-sample VAE loss: reconstruction MSE plus a weighted KL term.

  NOTE: this redefines the module-level ``loss_function`` from the plain
  autoencoder cell. Once this cell has run, that reconstruction-only version is
  no longer reachable under this name.

  Args:
    output: reconstruction, expected to have the same shape as ``x``.
    x: input batch. Dim 0 is assumed to be the batch dimension (TODO confirm for
      any non-2D inputs).
    mu: posterior means, as returned by ``VAE.encode``.
    logvar: posterior log-variances, as returned by ``VAE.encode``.
    kl_weight: weight on the KL term (previously a hard-coded 0.002).

  Returns:
    Scalar tensor ``recon + kl_weight * kl``. Each term is summed over elements
    and divided by the batch size, so the loss scale and the effective KL weight
    do not depend on how many samples are in the batch.
  """
  # Derive the batch size from the data. The original read an undefined global
  # ``batch_size`` and raised NameError. max(..., 1) guards against 0/0 on an
  # empty batch.
  batch_size = max(x.size(0), 1)
  recon_loss = F.mse_loss(output, x, reduction='sum') / batch_size
  # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims and
  # batch, then normalised per sample like the reconstruction term.
  kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
  return recon_loss + kl_weight * kl_loss



    