In [None]:
import torch
import torchvision
from  torchvision import transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import random
import math
import numpy as np

## MNIST data prep

In [None]:
# --- Run configuration -------------------------------------------------------
# Seed PyTorch's global RNG so that weight initialisation and noise sampling
# are repeatable after a kernel restart.
random_seed = 0
torch.manual_seed(random_seed)

batch_size = 100   # images per mini-batch (shared by the train and test loaders)
num_classes = 10   # not referenced in the visible cells
epochs = 500       # intended number of training epochs

# Layer widths handed to the encoder/decoder constructors.
lat_inp = 100      # latent width: encoder output size / decoder input size
lat_hid = 200      # width of every hidden layer
lat_out = 784      # flattened image width (1 x 28 x 28)

# NOTE(review): the four values below are not referenced in the visible cells;
# their purpose cannot be confirmed from here.
def_img_out = 540
def_img_hid = 50

gen_lr = .0001
dis_lr = .001

batch_size_train = batch_size
batch_size_test = batch_size

In [None]:
# Both splits use the same pipeline: convert the image to a tensor, then
# standardise it with mean 0.1307 and std 0.3081. The data is downloaded to
# '/files/' when it is not already there.
train_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(
        root='/files/',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ]),
    ),
    batch_size=batch_size_train,
    shuffle=True,
)

# Held-out split; it is shuffled as well.
test_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(
        root='/files/',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ]),
    ),
    batch_size=batch_size_test,
    shuffle=True,
)

## Networks

In [None]:
class encoder(nn.Module):
    """Fully connected encoder with two output heads.

    The input passes through four LeakyReLU hidden layers of width
    ``lat_hid``; two separate linear heads, each followed by tanh, then
    produce the pair ``(means, var)``.

    Args:
        lat_inp: size of the last dimension of the input.
        lat_hid: width of every hidden layer.
        lat_out: size of the last dimension of each returned tensor.

    Note:
        ``self.batch_size`` is copied from the notebook-level ``batch_size``
        variable, which therefore has to exist when the object is built.
    """

    def __init__(self, lat_inp, lat_hid, lat_out):
        super().__init__()

        # Hidden stack.
        self.fc1 = nn.Linear(lat_inp, lat_hid)
        self.fc2 = nn.Linear(lat_hid, lat_hid)
        self.fc3 = nn.Linear(lat_hid, lat_hid)
        self.fc4 = nn.Linear(lat_hid, lat_hid)
        # Output heads: fc5 feeds `means`, fc6 feeds `var`.
        self.fc5 = nn.Linear(lat_hid, lat_out)
        self.fc6 = nn.Linear(lat_hid, lat_out)

        # Activations (act1 is not used by forward()).
        self.act = nn.LeakyReLU()
        self.act1 = nn.Sigmoid()
        self.act2 = nn.Tanh()

        self.batch_size = batch_size

    def forward(self, img):
        """Return ``(means, var)`` for ``img``; both are tanh outputs in [-1, 1]."""
        features = img
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            features = self.act(layer(features))

        means = self.act2(self.fc5(features))
        var = self.act2(self.fc6(features))
        return means, var

In [None]:
class decoder(nn.Module):
    """Fully connected decoder.

    The input passes through four LeakyReLU hidden layers of width
    ``lat_hid`` and a final linear layer followed by tanh.

    Args:
        lat_inp: size of the last dimension of the input.
        lat_hid: width of every hidden layer.
        lat_out: size of the last dimension of the output.

    Note:
        ``self.batch_size`` is copied from the notebook-level ``batch_size``
        variable, which therefore has to exist when the object is built.
    """

    def __init__(self, lat_inp, lat_hid, lat_out):
        super().__init__()

        # Hidden stack followed by the output projection.
        self.fc1 = nn.Linear(lat_inp, lat_hid)
        self.fc2 = nn.Linear(lat_hid, lat_hid)
        self.fc3 = nn.Linear(lat_hid, lat_hid)
        self.fc4 = nn.Linear(lat_hid, lat_hid)
        self.fc5 = nn.Linear(lat_hid, lat_out)

        # Activations (act1 is not used by forward()).
        self.act = nn.LeakyReLU()
        self.act1 = nn.Sigmoid()
        self.act2 = nn.Tanh()

        self.batch_size = batch_size

    def forward(self, latent):
        """Map ``latent`` to an output squashed into [-1, 1] by tanh."""
        activations = latent
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            activations = self.act(layer(activations))
        return self.act2(self.fc5(activations))

## Initializations

In [None]:
# Loss criteria. The reductions are spelled out; 'mean' is the library default.
# Only `mse` is referenced by the visible training cell.
mse = nn.MSELoss(reduction='mean')
kldiv = nn.KLDivLoss(reduction='batchmean')
ce = nn.CrossEntropyLoss(reduction='mean')
bce = nn.BCELoss(reduction='mean')

In [None]:
# The encoder maps a flattened image (lat_out wide) to the latent size
# (lat_inp); the decoder maps the latent size back to a flattened image.
# Keyword arguments make the swapped dimensions explicit.
enc = encoder(lat_inp=lat_out, lat_hid=lat_hid, lat_out=lat_inp)
dec = decoder(lat_inp=lat_inp, lat_hid=lat_hid, lat_out=lat_out)

In [None]:
# Fresh networks plus one Adam optimizer per network, sharing one learning rate.
learning_rate = 2e-4

# Encoder: flattened image (lat_out) -> latent (lat_inp); the decoder reverses it.
enc = encoder(lat_inp=lat_out, lat_hid=lat_hid, lat_out=lat_inp)
dec = decoder(lat_inp=lat_inp, lat_hid=lat_hid, lat_out=lat_out)

enc_op, dec_op = (
    optim.Adam(network.parameters(), lr=learning_rate) for network in (enc, dec)
)

## Training

In [None]:
# Train encoder and decoder jointly on reconstruction error.
#
# NOTE(review): the objective is reconstruction MSE only -- no KL term is
# applied -- and `var` is a tanh output, so the noise scale can be negative.
# NOTE(review): the targets are normalised with mean 0.1307 / std 0.3081 (they
# can exceed 1), while the decoder ends in tanh (bounded to [-1, 1]), so bright
# pixels cannot be matched exactly. Confirm this is intended.
for epoch in range(epochs):
    for real_imgs, _labels in train_loader:
        # Take the size from the batch itself so a short final batch still works.
        n_imgs = real_imgs.size(0)

        enc_op.zero_grad()
        dec_op.zero_grad()

        means, var = enc(real_imgs.view(n_imgs, -1))

        # Reparameterised sample: z = means + var * eps, with eps ~ N(0, 1).
        epsilon = torch.randn_like(var)
        z = means + var * epsilon

        pred_imgs = dec(z).view(n_imgs, 1, 28, 28)
        vae_loss = mse(pred_imgs, real_imgs)
        vae_loss.backward()
        enc_op.step()
        dec_op.step()

    # Loss of the last batch of the epoch, printed as a plain float.
    print("epoch {}  and {}".format(epoch, vae_loss.item()))

## Plotting

In [None]:
# Show a rows x columns grid of reconstructions taken from one training batch.
columns = 4
rows = 4

# Draw a single batch and run it through the networks once; no gradients are
# needed for plotting.
real_imgs, target = next(iter(train_loader))
n_imgs = real_imgs.size(0)
with torch.no_grad():
    means, var = enc(real_imgs.view(n_imgs, -1))
    epsilon = torch.randn_like(var)
    z = means + var * epsilon
    pred_imgs = dec(z).view(n_imgs, 1, 28, 28).numpy()

fig = plt.figure(figsize=(15, 15))
# Subplot indices are 1-based and image indices 0-based; stop early when the
# batch holds fewer images than the grid has cells.
for i in range(1, min(columns * rows, n_imgs) + 1):
    ax = fig.add_subplot(rows, columns, i)
    ax.imshow(pred_imgs[i - 1][0], cmap='gray')
plt.show()