<a href="https://colab.research.google.com/github/angzhifan/Auto-Encoding_Variational_Bayes/blob/main/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive; the dataset, checkpoint and log paths
# used by the later cells all start with '/content/drive/My Drive/'.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
#VAE
#Author: Angzhi Fan fana@uchicago.edu
#Oct 11, 2020

import math
import os
import random
import time

import numpy as np
import scipy.io
import torch
import torch.nn as nn
import torch.nn.functional as F

def stdmtx(X):
    """Center every row of a 2-D array at zero.

    Args:
        X: 2-D numpy array; each row is centered independently.

    Returns:
        A new array of the same shape as ``X`` with the row mean subtracted
        from each row. NaN / inf entries in the result are replaced by
        ``np.nan_to_num`` (NaN -> 0, +/-inf -> largest finite values).

    Note:
        Despite the name, rows are NOT scaled to unit variance: the division
        by the row standard deviation was already disabled in this notebook,
        so only the centering step is performed. The per-row standard
        deviation is therefore no longer computed at all; it was unused and
        emitted RuntimeWarnings for single-column input (ddof=1).
    """
    row_means = X.mean(axis=1)
    centered = X - row_means[:, np.newaxis]
    return np.nan_to_num(centered)


In [None]:
# Load the raw MNIST digits. The .mat file stores one array per digit and
# split under the keys 'train0'..'train9' and 'test0'..'test9'; stack them in
# digit order into one train and one test matrix.
mnist = scipy.io.loadmat('/content/drive/My Drive/dataset/MNIST/mnist_all.mat')
mnist_train = np.concatenate([mnist['train' + str(digit)] for digit in range(10)], axis=0)
mnist_test = np.concatenate([mnist['test' + str(digit)] for digit in range(10)], axis=0)
print(mnist_train.shape)
print(mnist_test.shape)

# Global (scalar) mean and standard deviation over every pixel of train+test.
# These two definitions were commented out while the lines below still used
# them, so the cell raised NameError on a fresh kernel.
mnist_all = np.concatenate((mnist_train, mnist_test), axis=0)
mnist_mean = mnist_all.mean()
mnist_sd = np.sqrt(mnist_all.var())

# Standardise both splits with the same statistics.
mnist_train = (mnist_train - mnist_mean) / mnist_sd
mnist_test = (mnist_test - mnist_mean) / mnist_sd
print(mnist_train.var(0))


In [3]:
def load_data(data_dir='/content/drive/My Drive/dataset/BinaryMNIST'):
    """Load the binarized-MNIST text files into numpy arrays.

    Args:
        data_dir: Directory containing ``binarized_mnist_train.amat``,
            ``binarized_mnist_valid.amat`` and ``binarized_mnist_test.amat``.
            Defaults to the Google Drive location this notebook uses, so
            ``load_data()`` behaves exactly as before.

    Returns:
        ``(mnist_train, mnist_test)``: the train and validation files
        concatenated row-wise (train rows first), and the test file.
    """
    def _read_split(split):
        # Each split is a whitespace-separated text matrix read by np.loadtxt.
        return np.loadtxt(os.path.join(data_dir, 'binarized_mnist_' + split + '.amat'))

    mnist_train = np.concatenate([_read_split('train'), _read_split('valid')])
    mnist_test = _read_split('test')
    return mnist_train, mnist_test

# Read the binarized MNIST splits and show their shapes, one per line.
mnist_train, mnist_test = load_data()
print(mnist_train.shape, mnist_test.shape, sep='\n')

(60000, 784)
(10000, 784)


In [None]:
# Frey Face: load the 'ff' matrix, transpose it so that each row is one
# sample, and divide by 256 to rescale the values.
ff = scipy.io.loadmat('/content/drive/My Drive/dataset/Frey_Face/frey_rawface.mat')['ff'].transpose()/256

FF_TEST_SIZE = 281   # number of samples held out for testing
FF_SPLIT_SEED = 0    # fixed seed so the split is identical on every run

# Take the sample count from the data instead of hard-coding it, and draw the
# held-out indices from a private, seeded generator so the split is
# reproducible without touching the global `random` state.
n_faces = ff.shape[0]
test_index = random.Random(FF_SPLIT_SEED).sample(range(n_faces), FF_TEST_SIZE)
held_out = set(test_index)
# Ascending order; the previous set difference had no guaranteed order.
train_index = [i for i in range(n_faces) if i not in held_out]

ff_train = ff[train_index, :]
ff_test = ff[test_index, :]
print(ff_train.shape)
print(ff_test.shape)

(1684, 560)
(281, 560)


In [4]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# VAE with one stochastic layer z
class VAE(nn.Module):
    """Variational auto-encoder with one stochastic layer z.

    Encoder: ``fc1`` + tanh, then two linear heads giving the mean and the
    log-variance of q(z|x). Decoder: ``fc3`` + tanh, then either a Gaussian
    head (mean and log-variance of p(x|z)) or a single linear head.

    Args:
        t: Dimension of the latent variable z.
        d: Dimension of a flattened input sample.
        h_num: Width of the hidden layers of encoder and decoder.
        out_type: ``'gaussian'`` builds the Gaussian decoder; any other value
            builds the single-head decoder whose output the notebook's loss
            treats as Bernoulli logits.
    """

    def __init__(self, t, d, h_num, out_type):
        super(VAE, self).__init__()
        self.dim = d
        self.Nz = t
        self.hid_num = h_num
        self.output_type = out_type
        # Encoder layers.
        self.fc1 = nn.Linear(d, h_num)
        self.fc2_mu = nn.Linear(h_num, t)
        self.fc2_sigma = nn.Linear(h_num, t)  # outputs log(sigma^2) of q(z|x)
        # Decoder layers.
        self.fc3 = nn.Linear(t, h_num)
        if out_type == 'gaussian':
            self.fc4_mu = nn.Linear(h_num, d)
            self.fc4_sigma = nn.Linear(h_num, d)  # outputs log(sigma^2) of p(x|z)
        else:
            self.fc4 = nn.Linear(h_num, d)

    def forward(self, x):
        """Encode ``x``, draw one z by reparameterisation, and decode it.

        Args:
            x: Input tensor; it is reshaped to ``(-1, self.dim)``.

        Returns:
            A 4-tuple ``(out, mu_z, log_sigma2, log_sigma2_z)``:
            ``out`` is the decoder mean (Gaussian) or the raw ``fc4`` output;
            ``mu_z`` / ``log_sigma2_z`` are the mean and log-variance of
            q(z|x); ``log_sigma2`` is the decoder log-variance for the
            Gaussian decoder and ``None`` otherwise.
        """
        x = x.view(-1, self.dim)
        hidden = torch.tanh(self.fc1(x))
        mu_z = self.fc2_mu(hidden)
        log_sigma2_z = self.fc2_sigma(hidden)

        # Reparameterisation trick: z = mu + sigma * eps. The encoder head
        # yields log(sigma^2) (the KL term in the loss uses it that way), so
        # the standard deviation is exp(0.5 * log_sigma2_z), not
        # exp(log_sigma2_z), which would be the variance.
        eps = torch.randn_like(mu_z)
        z = mu_z + torch.exp(0.5 * log_sigma2_z) * eps

        hidden = torch.tanh(self.fc3(z))
        if self.output_type == 'gaussian':
            if self.dim == 560:
                # 560 is the input dimension this notebook uses for the Frey
                # Face data, which is rescaled by /256; the sigmoid keeps the
                # decoder mean inside (0, 1).
                mu = torch.sigmoid(self.fc4_mu(hidden))
            else:
                mu = self.fc4_mu(hidden)
            log_sigma2 = self.fc4_sigma(hidden)
            return mu, mu_z, log_sigma2, log_sigma2_z

        # There is no decoder variance in this branch. Return an explicit
        # None placeholder so the tuple keeps four slots; the previous code
        # returned the undefined name `_`.
        logits = self.fc4(hidden)
        return logits, mu_z, None, log_sigma2_z

    


cpu


In [5]:
def test_function(net, test_n, dataset, out_type):
    if dataset == 'ff':
        testloader = torch.utils.data.DataLoader(ff_test, batch_size=100,shuffle=False)
    else:
        testloader = torch.utils.data.DataLoader(mnist_test, batch_size=100,shuffle=False)
    ll = 0.0
    for i,data in enumerate(testloader, 0):
      with torch.no_grad():
        test = data.to(device)
        output = net(test.float())

        # the negative KL term
        negtive_KL = (torch.ones_like(output[3])+output[3]-output[1]*output[1]-torch.exp(output[3])).sum(1)/2


        # the log conditional prob term
        if out_type =='gaussian':
          test_minus_mu = test-output[0]
          log_p_x_given_z = -torch.ones_like(test).sum(1)*np.log(2*math.pi)/2-output[2].sum(1)/2-(test_minus_mu*test_minus_mu/(2*torch.exp(output[2]))).sum(1) 
        else:
          log_p_x_given_z = torch.sum(output[0]*test-torch.log(1+torch.exp(output[0])), 1)


        # sum of the lower bound
        L = negtive_KL.sum()+log_p_x_given_z.sum()

        ll += L.item()
    return ll/test_n

print("Finished loading test function")

Finished loading test function


In [None]:
# ---- Experiment configuration ---------------------------------------------
dataset = 'ff'            # 'ff' or 'mnist'
batch_size = 100
continued = 0             # 0 = train from scratch; otherwise the epoch tag of the checkpoint to resume from
Nz = 20                   # dimension of the latent variable z
output_type = 'gaussian'  # 'gaussian' decoder; any other value uses the Bernoulli-logit likelihood

if dataset == 'ff':
    dim = 560
    hid_num = 200
    # Sample counts come from the arrays themselves so the reported averages
    # stay correct if the split changes.
    train_num = len(ff_train)
    test_num = len(ff_test)
    trainloader = torch.utils.data.DataLoader(ff_train, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
elif dataset == 'mnist':
    dim = 28*28
    hid_num = 500
    train_num = len(mnist_train)
    test_num = len(mnist_test)
    trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
else:
    raise Exception("Invalid Dataset")

# Every output file shares the same run tag; build the paths once.
run_tag = '_Nz'+str(Nz)+'_dataset'+dataset+'_'
log_path = '/content/drive/My Drive/VAE/outfile'+run_tag+'.txt'
ckpt_prefix = '/content/drive/My Drive/VAE/model/vae_net'+run_tag

# ---- Model ----------------------------------------------------------------
net = VAE(Nz, dim, hid_num, output_type)
if continued == 0:
    # Fresh start: draw every parameter (weights and biases) from N(0, 0.1^2).
    for p in net.parameters():
        torch.nn.init.normal_(p,0,0.1)
elif continued < 10000:
    # map_location lets a checkpoint written on a GPU be resumed on a
    # CPU-only runtime (and vice versa).
    net.load_state_dict(torch.load(ckpt_prefix+str(continued)+'.pth', map_location=device))
else:
    raise Exception("Invalid Starting Epoch")

net.to(device)

# Start a new log for a fresh run, but append when resuming so the history of
# the earlier epochs is not truncated.
with open(log_path, 'w' if continued == 0 else 'a') as outfile:
    outfile.write('output of the code '+'\n'+'author:Angzhi Fan fana@uchicago.edu'+'\n')

start = time.time()

optimizer = torch.optim.Adagrad(net.parameters())

# ---- Training -------------------------------------------------------------
for epoch in range(continued, 171):
    # At the start of the listed epochs (before that epoch's updates), save a
    # checkpoint and report the test lower bound.
    if epoch in [1, 3, 10, 17, 33, 66, 100, 140, 170]:
        PATH = ckpt_prefix+str(epoch)+'.pth'
        torch.save(net.state_dict(), PATH)
        elbo = test_function(net, test_num, dataset, output_type)
        print('test average L(x)=', elbo)
        with open(log_path, 'a') as outfile:
            outfile.write('test average L(x)='+str(elbo)+'\n')

    running_loss = 0.0
    for i,data in enumerate(trainloader, 0):
        train = data.to(device)
        optimizer.zero_grad()

        output = net(train.float())
        mu_z, log_sigma2_z = output[1], output[3]

        # the negative KL term: 0.5 * sum(1 + log s^2 - mu^2 - s^2)
        negtive_KL = (1 + log_sigma2_z - mu_z*mu_z - torch.exp(log_sigma2_z)).sum(1)/2

        # the log conditional prob term
        if output_type =='gaussian':
            train_minus_mu = train - output[0]
            log_norm_const = 0.5*math.log(2*math.pi)*train.size(1)
            log_p_x_given_z = (-log_norm_const
                               - output[2].sum(1)/2
                               - (train_minus_mu*train_minus_mu/(2*torch.exp(output[2]))).sum(1))
        else:
            # x*logit - log(1 + exp(logit)); softplus avoids the overflow to
            # inf that torch.log(1 + torch.exp(.)) hits for large logits.
            log_p_x_given_z = torch.sum(output[0]*train - F.softplus(output[0]), 1)

        # train the model: minimise the batch-mean negative lower bound
        loss = -negtive_KL.mean()-log_p_x_given_z.mean()
        loss.backward()
        optimizer.step()

        # Accumulate the summed negative lower bound for the epoch report.
        running_loss -= negtive_KL.sum().item()
        running_loss -= log_p_x_given_z.sum().item()

    epoch_report = '[%d] loss: %.3f' % (epoch+1, running_loss/train_num)
    print(epoch_report)
    with open(log_path, 'a') as outfile:
        outfile.write(epoch_report+'\n')

# Final weights are stored under the run tag without an epoch number.
PATH = ckpt_prefix+'.pth'
torch.save(net.state_dict(), PATH)

print('Finished Training')
with open(log_path, 'a') as outfile:
    outfile.write('Finished Training'+'\n'+'time cost:'+str(time.time()-start)+'\n')