In [0]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import utils



if torch.cuda.is_available():
    device = torch.device("cuda")
    use_cuda = True
else:
    device = torch.device("cpu")
    use_cuda = False
    
print(device)

print(f"Your version of Pytorch is {torch.__version__}. You should use a version >0.4.") #Check

cuda
Your version of Pytorch is 1.0.1.post2. You should use a version >0.4.


In [0]:
### Importing the data with the given loader
def get_data_loader(dataset_location, batch_size):
    URL = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"
    # start processing
    def lines_to_np_array(lines):
        return np.array([[int(i) for i in line.split()] for line in lines])
    splitdata = []
    for splitname in ["train", "valid", "test"]:
        filename = "binarized_mnist_%s.amat" % splitname
        filepath = os.path.join(dataset_location, filename)
        utils.download_url(URL + filename, dataset_location)
        with open(filepath) as f:
            lines = f.readlines()
        x = lines_to_np_array(lines).astype('float32')
        x = x.reshape(x.shape[0], 1, 28, 28)
        # pytorch data loader
        dataset = torch.utils.data.TensorDataset(torch.from_numpy(x))
        print(splitname+" : "+str(x.shape))
        dataset_loader = torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=splitname == "train")
        splitdata.append(dataset_loader)
    return splitdata
  
train, valid, test = get_data_loader("binarized_mnist", 64)

  0%|          | 0/78400000 [00:00<?, ?it/s]

Downloading http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_train.amat to binarized_mnist/binarized_mnist_train.amat


78405632it [00:04, 16697557.34it/s]                              
  0%|          | 16384/15680000 [00:00<01:37, 160597.80it/s]

train : (50000, 1, 28, 28)
Downloading http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_valid.amat to binarized_mnist/binarized_mnist_valid.amat


15687680it [00:01, 9217571.81it/s]                              
  0%|          | 16384/15680000 [00:00<01:39, 157575.53it/s]

valid : (10000, 1, 28, 28)
Downloading http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_test.amat to binarized_mnist/binarized_mnist_test.amat


15687680it [00:03, 4601117.93it/s]                              


test : (10000, 1, 28, 28)


In [0]:
class VAE_100dim(nn.Module):
    def __init__(self):
        super(VAE_100dim, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3),
            nn.ELU(),
            nn.AvgPool2d(2, stride=2),
            nn.Conv2d(32, 64, 3),
            nn.ELU(),
            nn.AvgPool2d(2, stride=2),
            nn.Conv2d(64, 256, 5),
            nn.ELU())
        self.h1=nn.Linear(256,200)                   #output mu and logvar
        
        self.h2=nn.Linear(100,256)
        self.decoder = nn.Sequential(
            nn.ELU(),
            nn.Conv2d(256, 64, 5, padding=4),
            nn.ELU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(64, 32, 3, padding=2),
            nn.ELU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(32, 16, 3, padding=2),
            nn.ELU(),
            nn.Conv2d(16, 1, 3, padding=2))
        
    def forward(self,x):
      x = self.encoder(x)
      q = self.h1(x.view(-1,256))
      mu,logvar = torch.split(q,100,dim=1)
      
      std = (logvar*0.5).exp()
      eps = torch.randn_like(mu).to(device)
      z = mu + (eps*std)
      
      x_ = self.h2(z)
      x_ = self.decoder(x_.view(-1,256,1,1))
      
      return x_, mu, logvar
      
    def generate_new_data(self,z):
      """ The following function take the z sampled from distribution and generate new examples
              
      """
      with torch.no_grad():    #no gradients to accumulate here, i.e. faster
        return self.decoder(self.h2(z).view(-1,256,1,1))
      
def loss_function(x_,x,mu,logvar):
  """ The Loss is a combination of the reconstruction loss (here binary cross entropy)
      and the KLD.
  """
  BCE = F.binary_cross_entropy_with_logits(x_.view(-1, 784), x.view(-1, 784), reduction='sum')
  
  KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
  
  return (BCE+KLD)/x.shape[0]
        

In [0]:
clf = VAE_100dim()
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    clf = clf.cuda()
optimizer = optim.Adam(clf.parameters(),lr=3e-4)
criterion = loss_function

True


In [0]:
## Let's train for 20 epochs

tracking_loss=[]
for epoch in range(20):
  
  training_loss=0
  for batch_idx, x in enumerate(train):
    x = x.to(device)
    optimizer.zero_grad()
    
    #Proper training
    x_, mu, logvar = clf(x)

    loss = criterion(x_, x, mu, logvar)
    loss.backward()
    training_loss += loss.item()
    optimizer.step()
    
    if (batch_idx+1) % 78 == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tELBO: {:.6f}'.format(
            epoch, batch_idx * len(x), len(train.dataset),
            100. * batch_idx / len(train),
            -1*loss.item()))
  train_batches = batch_idx

  tracking_loss.append(training_loss)
  clf.eval()
  valid_loss=0
  for batch_idx, x in enumerate(valid):
    x = x.to(device)
    x_, mu, logvar = clf(x)
    loss = criterion(x_, x, mu, logvar)
    valid_loss += loss.item()
  valid_batches = batch_idx
  clf.train()
  
  print('====> Epoch: {} Avg ELBO on Train: {:.4f} || Avg ELBO on Valid: {:.4f}'.format(
      epoch, -1*training_loss/train_batches, -1*valid_loss/valid_batches ))
  



====> Epoch: 0 Avg ELBO on Train: -99.0257 || Avg ELBO on Valid: -99.5495
====> Epoch: 1 Avg ELBO on Train: -98.1449 || Avg ELBO on Valid: -98.5810
====> Epoch: 2 Avg ELBO on Train: -97.4599 || Avg ELBO on Valid: -98.2793
====> Epoch: 3 Avg ELBO on Train: -96.8526 || Avg ELBO on Valid: -97.4045
====> Epoch: 4 Avg ELBO on Train: -96.3483 || Avg ELBO on Valid: -96.8523
====> Epoch: 5 Avg ELBO on Train: -95.9207 || Avg ELBO on Valid: -96.7931
====> Epoch: 6 Avg ELBO on Train: -95.5146 || Avg ELBO on Valid: -96.0883
====> Epoch: 7 Avg ELBO on Train: -95.0871 || Avg ELBO on Valid: -96.3194
====> Epoch: 8 Avg ELBO on Train: -94.7804 || Avg ELBO on Valid: -95.5461
====> Epoch: 9 Avg ELBO on Train: -94.4900 || Avg ELBO on Valid: -95.4545
====> Epoch: 10 Avg ELBO on Train: -94.2002 || Avg ELBO on Valid: -94.9767
====> Epoch: 11 Avg ELBO on Train: -93.9941 || Avg ELBO on Valid: -94.7306
====> Epoch: 12 Avg ELBO on Train: -93.6922 || Avg ELBO on Valid: -94.8787
====> Epoch: 13 Avg ELBO on Train: 

In [0]:
torch.save(clf.state_dict(),"../Vae.pt")

In [0]:
#Loading the trained model to recuparate param:
clf.load_state_dict(torch.load("../content/2_Vae.pt"))
clf.eval()


#Evaluating log-likelihood VAE

## Here, we want to compute log(p(x|z)), log(p(z)) and log(q(z|x)) in order to use the LogSumExp trick

##### Note: We have log(p(x|z)) from BCE calculation. We need to compute log(p(z)/q(z|x)), modelised with multivariate gaussians.
##### Some terms canceled out, such as sqrt(2*pi)^100

#Given x and z:
def logP_xi(AE,x):
  """ Calculate the log-likelihood for a given batch of data
  
      Call it M time (once per batch) to get
      (log p(x_1),...,log(px_M)) estimates of size (M,)
  """
  
  with torch.no_grad():
    x_, mu, logvar = AE(x)
    
    ### Let's sample the z_i from q_phi:
    mu, logvar = mu.repeat(1,200), logvar.repeat(1,200)        #K=200 samples per x_i
    z = torch.normal(mu,(logvar*0.5).exp())                  #z_i^k sampled from q(z|x_i)!!
    z = z.reshape(-1,100)
    #print(z.size())
    x_= AE.generate_new_data(z)                        #getting g(z) for log(p(x|z))
    x_ = x_.reshape(x.shape[0],200,784)
    
    x = x.view(-1,784).repeat(1,200)                   #Need to compare each x_i to the 200 z_i samples
    x = x.reshape(x.shape[0],200,784)
        
    logP_xi_zi = -(F.binary_cross_entropy_with_logits(x_, x, reduction='none').sum(dim=2))
    
  
  z = z.reshape(x.shape[0],200,100)
  mu = mu.reshape(x.shape[0],200,100)
  logvar = logvar.reshape(x.shape[0],200,100)
  logP_z__Q_z_x = (0.5*((z-mu)**2/logvar.exp() - z**2).sum(dim=-1) + 0.5*logvar.sum(dim=-1))  #log of two gaussians
  
  #LogSumExpTrick:
  pi_max = torch.max(logP_xi_zi + logP_z__Q_z_x, dim=1,keepdim=True)[0]  #Rescaling with pi_max
  
  logP_xi = pi_max.view(-1) + torch.log((logP_xi_zi + logP_z__Q_z_x - pi_max).exp().mean(dim=-1))
  
  return logP_xi

  



In [0]:
## Performing the calculation with the logP_xi function
## on the valid and test set:

logP_valid=[]     #stocking (log p(x_1),...,log(px_M)) estimates of size (M,) for valid
logP_test = []    #stocking (log p(x_1),...,log(px_M)) estimates of size (M,) for test

for batch_idx, x in enumerate(valid):
  
  x = x.to(device)
  logP = logP_xi(clf,x)
  logP_valid.append(logP)
  if (batch_idx+1) % 15 ==0:
    print(batch_idx/157)     # tracking progress
  
for batch_idx, x in enumerate(test):
  
  x = x.to(device)
  logP = logP_xi(clf,x)
  logP_test.append(logP)
  if (batch_idx+1) % 15 == 0:
    print(batch_idx/157)   # tracking progress
  

estimate_valid = 0
estimate_test = 0
   
for estimate in logP_valid:          
  estimate_valid += estimate.sum()  
  
for estimate in logP_test:          
  estimate_test += estimate.sum()
  
estimate_valid /= 10000             #Average on the dataset (N = 10 000)
estimate_test /= 10000

print(estimate_valid,estimate_test)



0.08917197452229299
0.18471337579617833
0.2802547770700637
0.37579617834394907
0.4713375796178344
0.5668789808917197
0.6624203821656051
0.7579617834394905
0.8535031847133758
0.9490445859872612
0.08917197452229299
0.18471337579617833
0.2802547770700637
0.37579617834394907
0.4713375796178344
0.5668789808917197
0.6624203821656051
0.7579617834394905
0.8535031847133758
0.9490445859872612
tensor(-88.2604, device='cuda:0') tensor(-87.6443, device='cuda:0')


In [0]:
clf.eval()
test_loss=0
for batch_idx, x in enumerate(test):
  x = x.to(device)
  x_, mu, logvar = clf(x)
  loss = criterion(x_, x, mu, logvar)
  test_loss += loss.item()
test_batches = batch_idx
clf.train()
  
print('====> ELBO on Test set: {:.4f}'.format(
      -1*test_loss/test_batches ))



====> ELBO on Test set: -93.4580
