In [1]:
from torch import nn
import torch.nn.functional as F
import torch
import numpy as np

# pi as a 0-dim float64 tensor, built from the NumPy constant.
PI = torch.from_numpy(np.array(np.pi, dtype=np.float64))
# Clamp margin: the log-likelihood helpers below restrict probabilities to
# [EPS, 1 - EPS] before taking logarithms.
EPS = 1e-5

def log_categorical(x, p, num_classes=256, reduction=None, dim=None):
    """Log-probability of class indices ``x`` under categorical probabilities ``p``.

    ``x`` is cast to ``long`` and one-hot encoded over a new trailing axis of
    size ``num_classes``; ``p`` is clamped to ``[EPS, 1 - EPS]`` before the log
    and must broadcast against that one-hot tensor.

    Args:
        x: tensor of class indices.
        p: tensor of class probabilities (classes on the last axis).
        num_classes: number of classes used for the one-hot encoding.
        reduction: ``"avg"`` for a mean over ``dim``, ``"sum"`` for a sum over
            ``dim``; any other value returns the unreduced tensor.
        dim: axis (or axes) forwarded to ``torch.mean`` / ``torch.sum``.

    Returns:
        The masked log-probabilities, reduced as requested. Unreduced, every
        entry that does not correspond to the observed class is zero.
    """
    selector = F.one_hot(x.long(), num_classes=num_classes)
    clipped_log = torch.log(torch.clamp(p, EPS, 1.0 - EPS))
    masked = selector * clipped_log

    if reduction == "avg":
        return torch.mean(masked, dim)
    if reduction == "sum":
        return torch.sum(masked, dim)
    return masked


def log_bernoulli(x, p, reduction=None, dim=None):
    """Element-wise Bernoulli log-likelihood of ``x`` given probabilities ``p``.

    ``p`` is clamped to ``[EPS, 1 - EPS]`` so neither logarithm can see 0.

    Args:
        x: observed values (the formula treats them as weights on the
            "success" term, with ``1 - x`` on the "failure" term).
        p: success probabilities, broadcastable against ``x``.
        reduction: ``"avg"`` for a mean over ``dim``, ``"sum"`` for a sum over
            ``dim``; any other value returns the unreduced tensor.
        dim: axis (or axes) forwarded to ``torch.mean`` / ``torch.sum``.

    Returns:
        ``x * log(p) + (1 - x) * log(1 - p)``, reduced as requested.
    """
    probs = torch.clamp(p, EPS, 1.0 - EPS)
    success_term = x * torch.log(probs)
    failure_term = (1.0 - x) * torch.log(1.0 - probs)
    log_lik = success_term + failure_term

    if reduction == "avg":
        return torch.mean(log_lik, dim)
    if reduction == "sum":
        return torch.sum(log_lik, dim)
    return log_lik

In [4]:
class HierarchicalVAE(nn.Module):
    """Two-level hierarchical VAE whose ``forward`` returns the negative ELBO.

    Inference first runs bottom-up (``nn_r_1`` then ``nn_r_2``) to obtain the
    ``delta`` terms, then top-down: ``z_2`` is sampled from
    ``N(delta_mu_2, exp(delta_log_var_2))``, ``nn_z_1(z_2)`` gives the
    conditional prior ``N(mu_1, exp(log_var_1))`` of ``z_1``, and the posterior
    of ``z_1`` is that prior shifted by ``delta_1`` (means are added and
    log-variances are added).

    Args:
        nn_r_1: module applied to the input ``x``.
        nn_r_2: module applied to the output of ``nn_r_1``.
        nn_delta_1: module applied to the output of ``nn_r_1``; its output is
            split in half along dim 1 into ``(delta_mu_1, delta_log_var_1)``.
        nn_delta_2: module applied to the output of ``nn_r_2``; its output is
            split in half along dim 1 into ``(delta_mu_2, delta_log_var_2)``.
        nn_z_1: module mapping ``z_2`` to a tensor that is split in half along
            dim 1 into ``(mu_1, log_var_1)``.
        nn_x: decoder mapping ``z_1`` to logits (softmax or sigmoid is applied
            here, depending on ``likelihood_type``).
        D: input dimensionality (stored only; no method of this class reads it).
        L: width of ``z_2``; used when drawing prior samples in ``sample``.
        num_vals: number of classes per dimension for the categorical
            likelihood.
        likelihood_type: ``'categorical'`` or ``'bernoulli'``.
    """

    def __init__(
        self,
        nn_r_1,
        nn_r_2,
        nn_delta_1,
        nn_delta_2,
        nn_z_1,
        nn_x,
        D,
        L,
        num_vals,
        likelihood_type,
    ):
        super().__init__()
        # Bottom-up (deterministic) path.
        self.nn_r_1 = nn_r_1
        self.nn_r_2 = nn_r_2

        # Heads producing the posterior corrections for z_1 and z_2.
        self.nn_delta_1 = nn_delta_1
        self.nn_delta_2 = nn_delta_2

        # Top-down (generative) path: z_2 -> z_1 -> x.
        self.nn_z_1 = nn_z_1
        self.nn_x = nn_x

        self.D = D  # input dimensionality
        self.L = L  # second layer dimensionality
        self.num_vals = num_vals
        self.likelihood_type = likelihood_type

    def reparameterization(self, mu, log_var):
        """Draw ``mu + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def _decode(self, z_1):
        """Run the decoder on ``z_1`` and return the likelihood parameters.

        Returns class probabilities of shape ``(batch, d, num_vals)`` for the
        categorical likelihood, or element-wise success probabilities for the
        Bernoulli likelihood.

        Raises:
            ValueError: if ``likelihood_type`` is neither supported value.
        """
        h_d = self.nn_x(z_1)
        if self.likelihood_type == "categorical":
            b = h_d.shape[0]
            d = h_d.shape[1] // self.num_vals
            return torch.softmax(h_d.view(b, d, self.num_vals), 2)
        if self.likelihood_type == "bernoulli":
            return torch.sigmoid(h_d)
        raise ValueError(
            "Unsupported likelihood_type {!r}; expected 'categorical' or "
            "'bernoulli'.".format(self.likelihood_type)
        )

    def forward(self, x, reduction='avg'):
        """Compute the negative ELBO for a batch ``x``.

        Args:
            x: input batch, batch dimension first. For the categorical
                likelihood it holds class indices.
            reduction: ``'sum'`` sums the per-sample losses; any other value
                averages them.

        Returns:
            A scalar tensor: ``-(RE - KL)`` reduced over the batch.

        Raises:
            ValueError: if ``likelihood_type`` is neither supported value.
        """
        # Bottom-up pass.
        r_1 = self.nn_r_1(x)
        r_2 = self.nn_r_2(r_1)

        delta_mu_1, delta_log_var_1 = torch.chunk(self.nn_delta_1(r_1), 2, dim=1)
        delta_log_var_1 = F.hardtanh(delta_log_var_1, -7.0, 2.0)

        delta_mu_2, delta_log_var_2 = torch.chunk(self.nn_delta_2(r_2), 2, dim=1)
        delta_log_var_2 = F.hardtanh(delta_log_var_2, -7.0, 2.0)

        # Top-down pass: z_2 first, then z_1 conditioned on z_2.
        z_2 = self.reparameterization(delta_mu_2, delta_log_var_2)

        mu_1, log_var_1 = torch.chunk(self.nn_z_1(z_2), 2, dim=1)
        z_1 = self.reparameterization(mu_1 + delta_mu_1, log_var_1 + delta_log_var_1)

        mu_d = self._decode(z_1)

        # Reconstruction term, one value per sample.
        if self.likelihood_type == "categorical":
            RE = log_categorical(
                x, mu_d, num_classes=self.num_vals, reduction='sum', dim=-1
            ).sum(-1)
        else:
            # Sum over every non-batch dimension. Summing over -1 twice would
            # collapse the batch axis for a (batch, D) input and mis-weight
            # the reconstruction term against the per-sample KL.
            log_p = log_bernoulli(x, mu_d)
            RE = log_p.flatten(start_dim=1).sum(-1)

        # KL(q(z_2|x) || N(0, I)).
        KL_z_2 = 0.5 * (
            delta_mu_2**2 + torch.exp(delta_log_var_2) - delta_log_var_2 - 1
        ).sum(-1)
        # KL(q(z_1|x, z_2) || p(z_1|z_2)). The mean difference is delta_mu_1
        # and the log-variance difference is delta_log_var_1, so the squared
        # mean term must be scaled by the prior variance exp(log_var_1).
        KL_z_1 = 0.5 * (
            delta_mu_1**2 / torch.exp(log_var_1)
            + torch.exp(delta_log_var_1)
            - delta_log_var_1
            - 1
        ).sum(-1)

        KL = KL_z_1 + KL_z_2

        if reduction == 'sum':
            loss = -(RE - KL).sum()
        else:
            loss = -(RE - KL).mean()
        return loss

    def sample(self, batch_size=64):
        """Draw ``batch_size`` samples from the generative model.

        ``z_2`` is drawn from a standard normal of width ``L``, ``z_1`` from
        the conditional prior given by ``nn_z_1(z_2)``, and ``x`` from the
        decoded likelihood.

        Returns:
            Class indices of shape ``(batch_size, d)`` for the categorical
            likelihood, or a 0/1 tensor shaped like the decoder output for the
            Bernoulli likelihood.

        Raises:
            ValueError: if ``likelihood_type`` is neither supported value.
        """
        # Create the noise on the model's device/dtype so sampling also works
        # after ``.to(device)`` or ``.double()``; fall back to the defaults
        # when the model has no parameters.
        ref = next(self.parameters(), None)
        if ref is None:
            z_2 = torch.randn(batch_size, self.L)
        else:
            z_2 = torch.randn(batch_size, self.L, device=ref.device, dtype=ref.dtype)

        mu_1, log_var_1 = torch.chunk(self.nn_z_1(z_2), 2, dim=1)
        z_1 = self.reparameterization(mu_1, log_var_1)

        mu_d = self._decode(z_1)

        if self.likelihood_type == "categorical":
            d = mu_d.shape[1]
            # One multinomial draw per (sample, dimension) pair.
            p = mu_d.view(-1, self.num_vals)
            x_new = torch.multinomial(p, num_samples=1).view(batch_size, d)
        else:
            x_new = torch.bernoulli(mu_d)

        return x_new
        