In [3]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import datasets
from sklearn.datasets import load_digits
from torch.utils.data import Dataset, DataLoader


In [4]:
# pi stored as a 0-dim float64 tensor.
PI = torch.tensor(np.pi, dtype=torch.float64)
# Margin used to clamp probabilities into [EPS, 1 - EPS] before taking logs.
EPS = 1e-5


def log_categorical(x, p, num_classes=256, reduction=None, dim=None):
    """Element-wise categorical log-likelihood of integer labels.

    Args:
        x: tensor of class indices; it is cast to ``long`` and one-hot encoded
            into a trailing axis of size ``num_classes``.
        p: class probabilities that broadcast against the one-hot encoding.
            Values are clamped to ``[EPS, 1 - EPS]`` before the log is taken.
        num_classes: size of the one-hot axis.
        reduction: ``"avg"`` for a mean over ``dim``, ``"sum"`` for a sum over
            ``dim``; any other value returns the unreduced tensor.
        dim: axis (or axes) handed to the reduction.

    Returns:
        ``one_hot(x) * log(p)``, optionally reduced. Entries of non-selected
        classes are zero.
    """
    selector = F.one_hot(x.long(), num_classes=num_classes)
    safe_p = torch.clamp(p, EPS, 1.0 - EPS)
    log_p = selector * torch.log(safe_p)

    if reduction == "sum":
        return torch.sum(log_p, dim)
    if reduction == "avg":
        return torch.mean(log_p, dim)
    return log_p


def log_bernoulli(x, p, reduction=None, dim=None):
    """Element-wise Bernoulli log-likelihood ``x*log(p) + (1-x)*log(1-p)``.

    Args:
        x: observed values (the formula weights the two log terms by ``x``
            and ``1 - x``).
        p: success probabilities; clamped to ``[EPS, 1 - EPS]`` so neither
            logarithm can see 0.
        reduction: ``"avg"`` for a mean over ``dim``, ``"sum"`` for a sum over
            ``dim``; any other value returns the unreduced tensor.
        dim: axis (or axes) handed to the reduction.

    Returns:
        The log-likelihood tensor, optionally reduced.
    """
    safe_p = torch.clamp(p, EPS, 1.0 - EPS)
    on_term = x * torch.log(safe_p)
    off_term = (1.0 - x) * torch.log(1.0 - safe_p)
    log_p = on_term + off_term

    if reduction == "sum":
        return torch.sum(log_p, dim)
    if reduction == "avg":
        return torch.mean(log_p, dim)
    return log_p


def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    """Element-wise log-density of a diagonal Gaussian.

    Each output element is the log-density of a one-dimensional normal,
    ``-0.5*log(2*pi) - 0.5*log_var - 0.5*(x - mu)**2 / exp(log_var)``.
    Summing over the feature axis therefore gives the log-density of the full
    diagonal Gaussian, with the ``log(2*pi)`` constant counted once per
    dimension.

    Args:
        x: points at which the density is evaluated.
        mu: means, broadcastable against ``x``.
        log_var: log-variances, broadcastable against ``x``.
        reduction: ``"avg"`` for a mean over ``dim``, ``"sum"`` for a sum over
            ``dim``; any other value returns the unreduced tensor.
        dim: axis (or axes) handed to the reduction.

    Returns:
        The log-density tensor, optionally reduced.
    """
    # The normalisation constant is applied per element. Scaling it by the
    # feature count here would count it D times per element, i.e. D**2 times
    # after a sum over D features.
    log_p = (
        -0.5 * math.log(2.0 * math.pi)
        - 0.5 * log_var
        - 0.5 * torch.exp(-log_var) * (x - mu) ** 2.0
    )
    if reduction == "avg":
        return torch.mean(log_p, dim)
    elif reduction == "sum":
        return torch.sum(log_p, dim)
    else:
        return log_p


def log_standard_normal(x, reduction=None, dim=None):
    """Element-wise log-density of a standard normal, ``-0.5*log(2*pi) - 0.5*x**2``.

    Uses the same per-element normalisation as ``log_normal_diag`` so that
    differences between the two (as in a KL estimate) carry no spurious
    constant.

    Args:
        x: points at which the density is evaluated (any shape).
        reduction: ``"avg"`` for a mean over ``dim``, ``"sum"`` for a sum over
            ``dim``; any other value returns the unreduced tensor.
        dim: axis (or axes) handed to the reduction.

    Returns:
        The log-density tensor, optionally reduced.
    """
    log_p = -0.5 * math.log(2.0 * math.pi) - 0.5 * x**2.0
    if reduction == "avg":
        return torch.mean(log_p, dim)
    elif reduction == "sum":
        return torch.sum(log_p, dim)
    else:
        return log_p

In [None]:
class Digits(Dataset):
    """Fixed split of the scikit-learn Digits data, served row by row as float32.

    ``mode="train"`` keeps the first 1000 rows, ``mode="val"`` keeps rows
    1000-1349, and any other value keeps rows 1350 onward. An optional
    ``transforms`` callable is applied to every row on access.
    """

    def __init__(self, mode="train", transforms=None):
        # Pick the row range for the requested split; unknown modes fall
        # through to the final (test) range.
        if mode == "train":
            rows = slice(None, 1000)
        elif mode == "val":
            rows = slice(1000, 1350)
        else:
            rows = slice(1350, None)

        self.data = load_digits().data[rows].astype(np.float32)
        self.transforms = transforms

    def __len__(self):
        """Number of rows in the selected split."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return row ``idx``, passed through ``transforms`` when one is set."""
        item = self.data[idx]
        return self.transforms(item) if self.transforms else item

In [5]:
class Encoder(nn.Module):
    """Gaussian encoder q(z|x) built on top of an arbitrary network.

    ``encoder_net`` must produce an output whose axis 1 can be split into two
    chunks; the first is used as the mean and the second as the log-variance
    of a diagonal Gaussian over ``z``.
    """

    def __init__(self, encoder_net):
        super().__init__()
        self.encoder = encoder_net

    @staticmethod
    def reparameterizing(mu, log_var):
        """Draw ``mu + std * eps`` with ``eps ~ N(0, I)`` and ``std = exp(0.5 * log_var)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)

        return mu + std * eps

    def encode(self, x):
        """Run the network and split its output along axis 1 into ``(mu, log_var)``."""
        h_e = self.encoder(x)

        mu_e, log_var_e = torch.chunk(h_e, 2, dim=1)
        return mu_e, log_var_e

    def sample(self, x=None, mu_e=None, log_var_e=None):
        """Draw a reparameterized sample ``z``.

        Either pass ``x`` (the distribution parameters are computed by
        ``encode``) or pass both ``mu_e`` and ``log_var_e`` directly.

        Raises:
            ValueError: if only one of ``mu_e`` / ``log_var_e`` is given, or
                if neither the parameters nor ``x`` are given.
        """
        if mu_e is None and log_var_e is None:
            if x is None:
                raise ValueError("Provide either x or both mu_e and log_var_e.")
            mu_e, log_var_e = self.encode(x)
        elif mu_e is None or log_var_e is None:
            # Half-specified parameters would otherwise reach the arithmetic
            # in reparameterizing() as None and fail with an opaque TypeError.
            raise ValueError("mu_e and log_var_e must be provided together.")
        return self.reparameterizing(mu_e, log_var_e)

    def log_prob(self, x, z=None):
        """Element-wise ``log q(z|x)``.

        Args:
            x: input passed to ``encode``.
            z: optional latent at which to evaluate the density. When omitted
                a fresh sample is drawn from q(z|x), as before; pass ``z``
                explicitly to score a sample that is also used elsewhere.

        Returns:
            The unreduced output of ``log_normal_diag(z, mu, log_var)``.
        """
        mu_e, log_var_e = self.encode(x)
        if z is None:
            z = self.sample(mu_e=mu_e, log_var_e=log_var_e)

        return log_normal_diag(z, mu_e, log_var_e)

    def forward(self, x, type='log_prob'):
        """Dispatch on ``type``.

        ``'log_prob'`` returns ``log_prob(x)``. ``'encode'`` returns a
        reparameterized sample ``z`` drawn via ``sample(x)`` (not the raw
        ``(mu, log_var)`` pair).
        """
        assert type in ['encode', 'log_prob'], 'Type could be either encode or log_prob'

        if type == 'log_prob':
            return self.log_prob(x)
        else:
            return self.sample(x)

In [6]:
class Decoder(nn.Module):
    """Conditional likelihood p(x|z) with a categorical or Bernoulli output.

    Args:
        decoder_net: network mapping ``z`` to raw outputs.
        distribution: ``'categorical'`` or ``'bernoulli'``.
        num_vals: number of values per dimension of ``x`` (used by the
            categorical branch to reshape the network output).
    """

    def __init__(self, decoder_net, distribution, num_vals):
        super().__init__()
        self.decoder = decoder_net
        self.distribution = distribution
        self.num_vals = num_vals

    def decode(self, z):
        """Return the distribution parameters for ``z`` as a one-element list.

        Categorical: the network output is viewed as
        ``(batch, out_features // num_vals, num_vals)`` and soft-maxed over the
        last axis. Bernoulli: the network output is passed through a sigmoid.

        Raises:
            ValueError: for any other ``distribution``.
        """
        h_d = self.decoder(z)
        if self.distribution == 'categorical':
            b = h_d.shape[0]  # Batch size
            d = h_d.shape[1] // self.num_vals  # Dimensionality of x

            h_d = h_d.view(b, d, self.num_vals)
            mu_d = torch.softmax(h_d, 2)
            return [mu_d]

        elif self.distribution == 'bernoulli':
            mu_d = torch.sigmoid(h_d)
            return [mu_d]

        else:
            raise ValueError('Only: categorical or bernulli')

    def sample(self, z):
        """Draw one ``x`` per row of ``z`` from the decoded distribution."""
        # decode() raises for unsupported distributions, so exactly one of the
        # branches below runs.
        mu_d = self.decode(z)[0]

        if self.distribution == "categorical":
            b = mu_d.shape[0]  # Batch size
            m = mu_d.shape[1]  # Dimensionality of x

            # One multinomial draw per (sample, dimension) pair.
            p = mu_d.reshape(-1, self.num_vals)
            x_new = torch.multinomial(p, num_samples=1).view(b, m)

        elif self.distribution == "bernoulli":
            x_new = torch.bernoulli(mu_d)

        return x_new

    def log_prob(self, x, z):
        """Per-sample log-likelihood ``log p(x|z)``.

        Both branches return one value per batch element: the element-wise
        log-likelihood is summed over every non-batch axis.
        """
        mu_d = self.decode(z)[0]

        if self.distribution == "categorical":
            log_p = log_categorical(x, mu_d, num_classes=self.num_vals, reduction='sum', dim=-1).sum(-1)
        elif self.distribution == "bernoulli":
            # Sum over feature axes only. Reducing dim=-1 and then calling
            # .sum(-1) again would also sum across the batch for (batch, D)
            # inputs and return a scalar.
            log_p = log_bernoulli(x, mu_d).flatten(1).sum(-1)
        return log_p

    def forward(self, z, x=None, type='log_prob'):
        """Dispatch on ``type``.

        ``'log_prob'`` returns ``log_prob(x, z)`` and requires ``x``.
        ``'decode'`` (or the legacy spelling ``'decoder'``) returns
        ``sample(z)``.

        Raises:
            ValueError: if ``type`` is ``'log_prob'`` and ``x`` is ``None``.
        """
        # 'decoder' is kept for backward compatibility; 'decode' is the value
        # the assertion message asks for.
        assert type in ['decode', 'decoder', 'log_prob'], 'Type could be either decode or log_prob'

        if type == 'log_prob':
            if x is None:
                raise ValueError("x must be provided when type='log_prob'.")
            return self.log_prob(x, z)
        else:
            return self.sample(z)