In [1]:
import pandas as pd
import numpy as np
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms

In [2]:
mnist_trainset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ]))
mnist_testset = torchvision.datasets.MNIST(root="./data",train=False,download=True,transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ]))

### Dataset
$28\times28$ images with labels 0-9

image $i$: mnist_trainset[$i$][0][0]

label $i$: mnist_trainset[$i$][1]

In [3]:
PI = torch.from_numpy(np.asarray(np.pi))
EPS = 1.e-5

def log_categorical(x, p, num_classes=10):
    x_one_hot = F.one_hot(x.long(), num_classes=num_classes)
    log_p = x_one_hot * torch.log(torch.clamp(p, EPS, 1. - EPS))
    return log_p

def log_normal_diag(x, mu, log_var):
    D = x.shape[0]
    log_p = -0.5 * D * torch.log(2. * PI) - 0.5 * log_var - 0.5 * torch.exp(-log_var) * (x - mu)**2.
    return log_p

def log_standard_normal(x):
    D = x.shape[0]
    log_p = -0.5 * D * torch.log(2. * PI) - 0.5 * x**2.
    return log_p

### Encoder

In [99]:
class Encoder(nn.Module):
    def __init__(self, encoder_net):
        super(Encoder,self).__init__()
        self.encoder = encoder_net

    def reparameterization(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return mu+std*eps
    
    def encode(self,x):
        h_e = self.encoder(x.view(x.shape[0],28**2))
        mu_e, log_var_e = torch.chunk(h_e,chunks=2,dim=1)
        return mu_e, log_var_e

    def sample(self, mu_e, log_var_e):
        z = self.reparameterization(mu_e,log_var_e)
        return z
    
    def log_prob(self, x=None, mu_e=None, log_var_e=None, z=None):
        if x is not None:
            mu_e, log_var_e = self.encode(x)
            z = self.sample(mu_e=mu_e, log_var_e=log_var_e)
        else:
            if (mu_e is None) or (log_var_e is None) or (z is None):
                raise ValueError('mu, log-var and z can`t be None!')

        return log_normal_diag(z, mu_e, log_var_e)
    
    def forward(self, x):
        return self.log_prob(x) 

### Decoder

In [100]:
class Decoder(nn.Module):
    def __init__(self,decoder_net,num_vals):
        super(Decoder,self).__init__()
        self.num_vals = num_vals
        self.decoder = decoder_net
    
    def decode(self,z):
        h_d = self.decoder(z)
        h_d = h_d.view(z.shape[0],28,28,self.num_vals)
        mu_d = torch.softmax(h_d,dim=2)
        return mu_d

    def sample(self,z):
        mu_d = self.decode(z)
        mu_d = mu_d.view(28, -1, self.num_vals)
        p = mu_d.view(-1, self.num_vals)
        x_new = torch.multinomial(p, num_samples = 1).view(28, 28)
        return x_new
        
    def log_prob(self, x, z):
        mu_d = self.decode(z)
        log_p = torch.sum(log_categorical(x, mu_d, num_classes=self.num_vals),dim=-1)
        return log_p
    
    def forward(self, z, x=None):
        return self.log_prob(x, z)        
        
   

### Prior

In [101]:
class Prior(nn.Module):
    def __init__(self,L):
        super(Prior, self).__init__()
        self.L = L
    
    def sample(self, batchsize=1):
        z = torch.randn((batchsize,self.L))
        return z
    
    def log_prob(self, z):
        return log_standard_normal(z)

### VAE

In [102]:
class VAE(nn.Module):
    def __init__(self,encoder_net,decoder_net,L,num_vals):
        super(VAE,self).__init__()
        self.encoder = Encoder(encoder_net)
        self.decoder = Decoder(decoder_net,num_vals)
        self.Prior = Prior(L)
    
    def forward(self, x):
        mu_e, log_var_e = self.encoder.encode(x)
        z = self.encoder.sample(mu_e,log_var_e)
        RE = self.decoder.log_prob(x,z)
        KL = self.Prior.log_prob(z)
        KL -= self.encoder.log_prob(mu_e=mu_e,log_var_e=log_var_e,z=z)
        ELBO = -(RE+KL.sum(-1)).sum()
        return ELBO.mean()
    
    def sample(self, batchsize = 1):
        z = self.Prior.sample(batchsize)
        return self.decoder.sample(z)

In [125]:
def evaluation(test_loader, model):
    loss = 0; N = 0
    for batch, num in test_loader:
            loss_it = model.forward(batch)
            loss += loss_it
            N += batch.shape[0]
    return loss / N

In [126]:
def training(num_epochs, model, optimizer, training_loader, val_loader):
    nll_val = []
    for e in range(num_epochs):
        model.train()
        for indx_batch, (batch,target) in enumerate(training_loader):
            loss = model.forward(batch)
            optimizer.zero_grad()
            loss.backward(retain_graph = True)
            optimizer.step()
            if indx_batch % 1000 == 0:
                print(indx_batch)
                loss_val = evaluation(val_loader,model)
                print("Epoch: " + str(e) + " Training loss: " + str(loss) + " Validation loss: "+ str(loss_val))
                nll_val.append(loss_val)

    return nll_val

In [127]:
from torch.utils.data import DataLoader, random_split
train_set, val_set, _ = random_split(mnist_trainset,[9000,1000,50000])
training_loader = DataLoader(train_set,batch_size = 1,shuffle=True)
validation_loader = DataLoader(val_set,batch_size = 1,shuffle=True)
test_loader = DataLoader(mnist_testset,batch_size = 1,shuffle=False)

In [128]:
D = 28**2
hidden_layer = 128
L = 10
num_values = 10
encoder = nn.Sequential(nn.Linear(D,hidden_layer),nn.ReLU(),nn.Linear(hidden_layer,2*L))
decoder = nn.Sequential(nn.Linear(L,hidden_layer),nn.ReLU(),nn.Linear(hidden_layer,D*L))
model = VAE(encoder, decoder, L, num_values)
lr = 1e-3
optimizer = torch.optim.Adamax([p for p in model.parameters() if p.requires_grad == True], lr=lr)
nll_val = training(num_epochs=1,model=model,training_loader=training_loader,val_loader=validation_loader,optimizer=optimizer)

0
Epoch: 0loss: tensor(2754.2993, grad_fn=<MeanBackward0>)val loss: tensor(3032.0784, grad_fn=<DivBackward0>)
1000
Epoch: 0loss: tensor(2197.4233, grad_fn=<MeanBackward0>)val loss: tensor(2602.1475, grad_fn=<DivBackward0>)
2000
Epoch: 0loss: tensor(2596.1167, grad_fn=<MeanBackward0>)val loss: tensor(2609.1353, grad_fn=<DivBackward0>)
3000
Epoch: 0loss: tensor(2412.6475, grad_fn=<MeanBackward0>)val loss: tensor(2560.6807, grad_fn=<DivBackward0>)
4000
Epoch: 0loss: tensor(2486.8999, grad_fn=<MeanBackward0>)val loss: tensor(2564.7078, grad_fn=<DivBackward0>)
5000
Epoch: 0loss: tensor(2485.1074, grad_fn=<MeanBackward0>)val loss: tensor(2545.8787, grad_fn=<DivBackward0>)
6000
Epoch: 0loss: tensor(2733.9297, grad_fn=<MeanBackward0>)val loss: tensor(2550.6716, grad_fn=<DivBackward0>)
7000
Epoch: 0loss: tensor(2369.0459, grad_fn=<MeanBackward0>)val loss: tensor(2552.0242, grad_fn=<DivBackward0>)
8000
Epoch: 0loss: tensor(2624.0723, grad_fn=<MeanBackward0>)val loss: tensor(2550.5723, grad_fn=<D