In [1]:
import pandas as pd
import numpy as np
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms

#### Categorical distribution

In [2]:
EPS = 1.e-5
PI = torch.from_numpy(np.asarray(np.pi))

def log_categorical(x,p,num_classes=256):
    x_one_hot = F.one_hot(x.long(),num_classes=num_classes)
    log_p = x_one_hot * torch.log(torch.clamp(p, EPS, 1. - EPS))
    return log_p

def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    D = x.shape[1]
    log_p = -0.5 * D * torch.log(2. * PI) - 0.5 * log_var - 0.5 * torch.exp(-log_var) * (x - mu)**2.
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p
    
def log_standard_normal(x, reduction=None, dim=None):
    D = x.shape[1]
    log_p = -0.5 * D * torch.log(2. * PI) - 0.5 * x**2.
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p


#### Encoder

In [3]:
class Encoder(torch.nn.Module):
    def __init__(self,encoder_net):
        super(Encoder,self).__init__()
        self.encoder = encoder_net
    
    def encode(self, x):
        h_e = self.encoder(x)
        mu_e, log_var_e = torch.chunk(h_e, 2, dim=1)
        return mu_e, log_var_e

    def sample(self, x=None, mu_e=None, log_var_e=None):
        if (mu_e is None) and (log_var_e is None):
            mu_e, log_var_e = self.encode(x)
        else:
            if (mu_e is None) or (log_var_e is None):
                raise ValueError('mu and log-var can`t be None!')
        z = self.reparameterization(mu_e, log_var_e)
        return z

    def log_prob(self, x=None, mu_e=None, log_var_e=None, z=None):
        if x is not None:
            mu_e, log_var_e = self.encode(x)
            z = self.sample(mu_e=mu_e, log_var_e=log_var_e)
        else:
            if (mu_e is None) or (log_var_e is None) or (z is None):
                raise ValueError('mu, log-var and z can`t be None!')

        return log_normal_diag(z, mu_e, log_var_e)

    def forward(self, x, type='log_prob'):
        assert type in ['encode', 'log_prob'], 'Type could be either encode or log_prob'
        if type == 'log_prob':
            return self.log_prob(x)
        else:
            return self.sample(x)

#### Decoder

In [4]:
class Decoder(torch.nn.Module):
    def __init__(self, decoder_net, num_vals=None):
        super(Decoder, self).__init__()

        self.decoder = decoder_net
        self.num_vals=num_vals

    def decode(self, z):
        h_d = self.decoder(z)

        b = h_d.shape[0]
        d = h_d.shape[1]//self.num_vals
        h_d = h_d.view(b, d, self.num_vals)
        mu_d = torch.softmax(h_d, 2)
        return [mu_d]
    
    def sample(self, z):
        outs = self.decode(z)
        mu_d = outs[0]
        b = mu_d.shape[0]
        m = mu_d.shape[1]
        mu_d = mu_d.view(mu_d.shape[0], -1, self.num_vals)
        p = mu_d.view(-1, self.num_vals)
        x_new = torch.multinomial(p, num_samples=1).view(b, m)
        return x_new

    def log_prob(self, x, z):
        outs = self.decode(z)
        mu_d = outs[0]
        log_p = log_categorical(x, mu_d, num_classes=self.num_vals, reduction='sum', dim=-1).sum(-1)
        return log_p
    
    def forward(self, z, x=None, type='log_prob'):
        assert type in ['decoder', 'log_prob'], 'Type could be either decode or log_prob'
        if type == 'log_prob':
            return self.log_prob(x, z)
        else:
            return self.sample(x)

In [5]:
class Prior(torch.nn.Module):
    def __init__(self, L):
        super(Prior, self).__init__()
        self.L = L

    def sample(self, batch_size):
        z = torch.randn((batch_size, self.L))
        return z

    def log_prob(self, z):
        return log_standard_normal(z)

In [6]:
class VAE(torch.nn.Module):
    def __init__(self, encoder_net, decoder_net, num_vals=256, L=16, likelihood_type='categorical'):
        super(VAE, self).__init__()
        self.encoder = Encoder(encoder_net=encoder_net)
        self.decoder = Decoder(decoder_net=decoder_net, num_vals=num_vals)
        self.prior = Prior(L=L)

        self.num_vals = num_vals

        self.likelihood_type = likelihood_type

    def forward(self, x, reduction='avg'):
        # encoder
        mu_e, log_var_e = self.encoder.encode(x)
        z = self.encoder.sample(mu_e=mu_e, log_var_e=log_var_e)

        # ELBO
        RE = self.decoder.log_prob(x, z)
        KL = (self.prior.log_prob(z) - self.encoder.log_prob(mu_e=mu_e, log_var_e=log_var_e, z=z)).sum(-1)

        if reduction == 'sum':
            return -(RE + KL).sum()
        else:
            return -(RE + KL).mean()

    def sample(self, batch_size=64):
        z = self.prior.sample(batch_size=batch_size)
        return self.decoder.sample(z)

In [7]:
def evaluation(test_loader, name=None, model_best=None, epoch=None):
    # EVALUATION
    if model_best is None:
        # load best performing model
        model_best = torch.load(name + '.model')

    model_best.eval()
    loss = 0.
    N = 0.
    for indx_batch, test_batch in enumerate(test_loader):
        loss_t = model_best.forward(test_batch, reduction='sum')
        loss = loss + loss_t.item()
        N = N + test_batch.shape[0]
    loss = loss / N

    if epoch is None:
        print(f'FINAL LOSS: nll={loss}')
    else:
        print(f'Epoch: {epoch}, val nll={loss}')

    return loss

def training(num_epochs, model, optimizer, training_loader):
    # Main loop
    for e in range(num_epochs):
        # TRAINING
        model.train()
        for indx_batch, batch in enumerate(training_loader):
            batch = batch[0]
            if hasattr(model, 'dequantization'):
                if model.dequantization:
                    batch = batch + torch.rand(batch.shape)
            loss = model.forward(batch)

            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

In [8]:
D = 28*28
L = 16
M = 50
lr = 1e-3
num_epochs = 100

In [9]:
mnist_trainset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ]))
mnist_testset = torchvision.datasets.MNIST(root="./data",train=False,download=True,transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ]))

In [10]:
train_loader = torch.utils.data.DataLoader(mnist_trainset,batch_size=10,shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset,batch_size=10,shuffle=True)

In [11]:
num_vals = 256
encoder = nn.Sequential(nn.Linear(D, M), nn.LeakyReLU(),
                        nn.Linear(M, M), nn.LeakyReLU(),
                        nn.Linear(M, 2 * L))

decoder = nn.Sequential(nn.Linear(L, M), nn.LeakyReLU(),
                        nn.Linear(M, M), nn.LeakyReLU(),
                        nn.Linear(M, num_vals * D))

prior = torch.distributions.MultivariateNormal(torch.zeros(L), torch.eye(L))
model = VAE(encoder_net=encoder, decoder_net=decoder, num_vals=num_vals, L=L)


VAE by JT.


In [12]:
optimizer = torch.optim.Adamax([p for p in model.parameters() if p.requires_grad == True], lr=lr)

In [13]:
training(num_epochs=num_epochs, model=model, optimizer=optimizer,training_loader=train_loader)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (280x28 and 784x50)

In [81]:
def test(training_loader):
        for indx_batch, batch in enumerate(training_loader):
            print(len(batch))
            print(batch[0].size())
            print(batch[1])

In [82]:
test(train_loader)

2
torch.Size([10, 1, 28, 28])
tensor([8, 7, 4, 9, 7, 4, 1, 2, 8, 1])
2
torch.Size([10, 1, 28, 28])
tensor([7, 2, 7, 8, 4, 4, 6, 9, 6, 9])
2
torch.Size([10, 1, 28, 28])
tensor([6, 8, 1, 1, 9, 4, 2, 6, 1, 5])
2
torch.Size([10, 1, 28, 28])
tensor([5, 4, 8, 4, 9, 4, 7, 2, 7, 1])
2
torch.Size([10, 1, 28, 28])
tensor([8, 1, 9, 1, 0, 8, 4, 3, 2, 2])
2
torch.Size([10, 1, 28, 28])
tensor([8, 2, 4, 3, 6, 6, 2, 9, 4, 5])
2
torch.Size([10, 1, 28, 28])
tensor([1, 0, 2, 9, 7, 1, 8, 8, 1, 7])
2
torch.Size([10, 1, 28, 28])
tensor([1, 9, 0, 4, 7, 2, 3, 5, 9, 1])
2
torch.Size([10, 1, 28, 28])
tensor([8, 5, 3, 3, 1, 6, 0, 1, 8, 1])
2
torch.Size([10, 1, 28, 28])
tensor([7, 4, 3, 4, 7, 9, 2, 9, 4, 0])
2
torch.Size([10, 1, 28, 28])
tensor([6, 0, 9, 6, 4, 6, 6, 6, 4, 8])
2
torch.Size([10, 1, 28, 28])
tensor([1, 4, 2, 2, 1, 3, 3, 6, 1, 7])
2
torch.Size([10, 1, 28, 28])
tensor([2, 9, 9, 4, 7, 4, 1, 5, 1, 9])
2
torch.Size([10, 1, 28, 28])
tensor([9, 2, 6, 3, 2, 6, 7, 2, 9, 2])
2
torch.Size([10, 1, 28, 28])
tens

KeyboardInterrupt: 