# Generalizations of the normal distribution and Hyperbolic VAEs

Welcome to our fifth notebook for the ECCV 2022 Tutorial "[Hyperbolic Representation Learning for Computer Vision](https://sites.google.com/view/hyperbolic-tutorial-eccv22)"!

**Open notebook:**
[![View on Github](https://img.shields.io/static/v1.svg?logo=github&label=Repo&message=View%20On%20Github&color=lightgrey)](https://github.com/MinaGhadimiAtigh/hyperbolic_representation_learning/blob/main/notebooks/5_Hyperbolic_VAEs.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MinaGhadimiAtigh/hyperbolic_representation_learning/blob/main/notebooks/5_Hyperbolic_VAEs.ipynb) 

**Author:** Jeffrey Gu

In this tutorial, we will go through [A Wrapped Normal Distribution on Hyperbolic Space for Gradient-Based Learning](https://proceedings.mlr.press/v97/nagano19a.html) (Nagano et al. 2019), an ICML 2019 paper, and [Continuous Hierarchical Representations with Poincaré Variational Auto-Encoders](https://proceedings.neurips.cc/paper/2019/hash/0ec04cb3912c4f08874dd03716f80df1-Abstract.html) (Mathieu et al. 2019), a NeurIPS 2019 paper. The first paper introduces the wrapped normal distribution, a generalization of the Euclidean normal distribution to hyperbolic space, and uses it to build a hyperbolic variational autoencoder (VAE). The second paper builds on the first by proposing a new maximum-entropy generalization of the normal distribution, which the authors call the Riemannian normal, and by giving reparametrizable sampling schemes and algorithms to compute the probability density function for both generalizations. 

Let's start with installing and importing libraries. Also, we set a manual seed using `set_seed`.

In [None]:
## standard libraries
import numpy as np
import math
import warnings
from IPython.display import clear_output

## Imports for plotting
import matplotlib.pyplot as plt

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## PyTorch Torchvision
import torchvision

warnings.filterwarnings('ignore')

In [None]:
# Function for setting the seed
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
set_seed(42)

# Ensure that all operations are deterministic on GPU (if using) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Fetching the device that will be used throughout this notebook
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")
print("Using device", device)

For the hyperbolic layers and functions, we will use the geoopt library in this notebook.

In [None]:
!pip install -q git+https://github.com/geoopt/geoopt.git

Here, we define the paths which will be used in this notebook.

In [None]:
DATA_PATH = './data'

Let's start by setting up the dataset. In this notebook, you will work with the `MNIST` dataset. MNIST consists of 70,000 small (28×28) grayscale images of handwritten digits from zero to nine. The goal is to achieve a high log-likelihood on held-out data.

In [None]:
tx=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(lambda p: p.clamp(1e-5, 1 - 1e-5))
])

# Download the MNIST training and test splits; `tx` clamps pixel values away from exactly 0 and 1.
train_dataset = torchvision.datasets.MNIST(root=DATA_PATH, train=True, download=True, transform=tx)
test_dataset = torchvision.datasets.MNIST(root=DATA_PATH, train=False, download=True, transform=tx)

## Notation

We will follow the notation of Mathieu et al. 2019. Unless otherwise stated, you may assume that all operations, such as the norm $||\cdot||$, are the usual Euclidean operations. We denote the Poincaré ball of curvature $-c$ (with $c > 0$) in dimension $d$ by $\mathbb{B}_c^d$. Recall that the Poincaré ball is a model of hyperbolic geometry in which all points lie in an open ball of radius $1/\sqrt{c}$. The distance between two points of the ball is
\begin{align}
    d_p^c(z, y) = \frac{1}{\sqrt{c}} \cosh^{-1} \left(1 + 2c\frac{||z - y||^2}{(1 - c||z||^2)(1 - c||y||^2)} \right)
\end{align}
We will also use the conformal factor $\lambda_z^c$, the factor by which the Poincaré metric at $z$ rescales the Euclidean metric (lengths of tangent vectors at $z$ are multiplied by $\lambda_z^c$):
\begin{align}
    \lambda_z^c = \frac{2}{1 - c||z||^2}
\end{align}
Finally, we will define the exponential and logarithmic maps in terms of gyrovector space addition. Gyrovector addition (also called Möbius addition) was introduced by Ungar 2008 and acts as a kind of hyperbolic translation. It is defined as
\begin{align}
    z \oplus_c y = \frac{(1 + 2c \langle z, y \rangle + c||y||^2)z + (1 - c||z||^2)y}{1 + 2c \langle z, y \rangle + c^2||z||^2||y||^2}
\end{align}
As one might expect, we recover Euclidean vector addition as $c \to 0$ (recall that Euclidean space has curvature 0). Ganea et al. 2018 then derived formulas for the exponential map $\exp_z^c$, which maps a tangent vector $v$ at $z$ (the tangent space is a copy of Euclidean space) onto the ball, and the logarithmic map $\log_z^c$, its inverse, which maps a point $y$ of the ball back to the tangent space at $z$:
\begin{align}
    \exp_z^c(v) &= z \oplus_c \left( \tanh \left( \sqrt{c} \frac{\lambda_z^c||v||}{2} \right) \frac{v}{\sqrt{c}||v||} \right) \\
    \log_z^c(y) &= \frac{2}{\sqrt{c} \lambda_z^c} \tanh^{-1} (\sqrt{c}||-z \oplus_c y||) \frac{-z \oplus_c y }{||-z \oplus_c y||}
\end{align}
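To make these formulas concrete, here is a minimal NumPy sketch of them. It is our own illustration, not code from either paper; the notebook itself relies on geoopt's tested `PoincareBall` methods (`mobius_add`, `expmap`, `logmap`, `dist`, `lambda_x`). The usage lines check three properties: the log map undoes the exp map, the geodesic distance from $x$ to $\exp_x^c(v)$ equals $\lambda_x^c||v||$, and Möbius addition reduces to ordinary addition as $c \to 0$.

```python
import numpy as np

def mobius_add(x, y, c):
    """Mobius addition x (+)_c y, applied along the last axis."""
    xy = np.sum(x * y, axis=-1, keepdims=True)
    x2 = np.sum(x * x, axis=-1, keepdims=True)
    y2 = np.sum(y * y, axis=-1, keepdims=True)
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    den = 1 + 2 * c * xy + c ** 2 * x2 * y2
    return num / den

def lambda_x(x, c):
    """Conformal factor lambda_x^c = 2 / (1 - c ||x||^2), keeping the last axis."""
    return 2.0 / (1 - c * np.sum(x * x, axis=-1, keepdims=True))

def dist(x, y, c):
    """Geodesic distance d_p^c(x, y) using the arccosh formula."""
    diff2 = np.sum((x - y) ** 2, axis=-1)
    x2, y2 = np.sum(x * x, axis=-1), np.sum(y * y, axis=-1)
    return np.arccosh(1 + 2 * c * diff2 / ((1 - c * x2) * (1 - c * y2))) / np.sqrt(c)

def expmap(x, v, c):
    """exp_x^c(v): tangent vector v at x -> point on the ball."""
    sc = np.sqrt(c)
    vn = np.maximum(np.linalg.norm(v, axis=-1, keepdims=True), 1e-15)  # avoid 0/0
    return mobius_add(x, np.tanh(sc * lambda_x(x, c) * vn / 2) * v / (sc * vn), c)

def logmap(x, y, c):
    """log_x^c(y): point y on the ball -> tangent vector at x."""
    sc = np.sqrt(c)
    w = mobius_add(-x, y, c)
    wn = np.maximum(np.linalg.norm(w, axis=-1, keepdims=True), 1e-15)
    return 2 / (sc * lambda_x(x, c)) * np.arctanh(sc * wn) * w / wn

c = 0.7
x, v = np.array([0.2, -0.3]), np.array([0.5, 0.1])
y = expmap(x, v, c)
print(np.allclose(logmap(x, y, c), v))                                 # True
print(np.allclose(dist(x, y, c), lambda_x(x, c) * np.linalg.norm(v)))  # True
print(np.allclose(mobius_add(x, v, 1e-12), x + v))                     # True
```

The round trip works because Möbius addition satisfies the left-cancellation law $-x \oplus_c (x \oplus_c u) = u$.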

Now, we will define our Poincare ball manifold, using code adapted from Mathieu et al. 2019's official Github repository:

In [None]:
from geoopt import ManifoldParameter  # used by RiemannianLayer when over_param=True
from geoopt.manifolds import PoincareBall as PoincareBallParent
from geoopt.manifolds.stereographic.math import _lambda_x, arsinh, tanh

In [None]:
class PoincareBall(PoincareBallParent):

    def __init__(self, dim, c=1.0):
        super().__init__(c)
        self.register_buffer("dim", torch.as_tensor(dim, dtype=torch.int))
        
    @property
    def coord_dim(self):
        return int(self.dim)
    
    @property
    def zero(self):
        return torch.zeros(1, self.dim).to(self.device)

    def logdetexp(self, x, y, is_vector=False, keepdim=False):
        """
        The log-determinant of the exponential map. This is used for calculating the 
        log-probability (PDF) of the wrapped normal. 
        """
        d = self.norm(x, y, keepdim=keepdim) if is_vector else self.dist(x, y, keepdim=keepdim)
        return (self.dim - 1) * (torch.sinh(self.c.sqrt()*d) / self.c.sqrt() / d).log()
    
    def normdist2plane(self, x, a, p, keepdim: bool = False, signed: bool = False, dim: int = -1, norm: bool = False):
        """
        Finds the distance of a point to a plane in hyperbolic space. Used to implement the gyroplane layer.
        """
        c = self.c
        sqrt_c = c ** 0.5
        diff = self.mobius_add(-p, x, dim=dim)
        diff_norm2 = diff.pow(2).sum(dim=dim, keepdim=keepdim).clamp_min(1e-15)
        sc_diff_a = (diff * a).sum(dim=dim, keepdim=keepdim)
        if not signed:
            sc_diff_a = sc_diff_a.abs()
        a_norm = a.norm(dim=dim, keepdim=keepdim, p=2).clamp_min(1e-15)
        # computing the numerator (see below)
        num = 2 * sqrt_c * sc_diff_a
        # computing the denominator (see below)
        denom = (1 - c * diff_norm2) * a_norm
        res = arsinh(num / denom.clamp_min(1e-15)) / sqrt_c
        if norm:
            res = res * a_norm # * self.lambda_x(a, dim=dim, keepdim=keepdim)
        return res

### The Wrapped Normal Distribution

One common strategy to generalize the Euclidean normal distribution to an arbitrary manifold is to map it onto the manifold with the manifold's exponential map. We call the result the wrapped normal distribution (of the manifold) and denote it $\mathcal{N}^{\mathrm{W}}$. This construction induces a probability measure on the manifold, which mathematicians call the pushforward measure. Its density with respect to the Riemannian volume of $\mathbb{B}_c^d$ can be computed as
\begin{align}
    \mathcal{N}^{\mathrm{W}}(z|\mu, \Sigma) = \mathcal{N}(\lambda_\mu^c \log_\mu^c(z)|0, \Sigma) \left(\frac{\sqrt{c}\, d_p^c(\mu, z)}{\sinh(\sqrt{c}\, d_p^c(\mu, z))} \right)^{d-1}
\end{align}
Mathieu et al. 2019 simplified the sampling scheme of Nagano et al. 2019 to the following reparametrisable sampling scheme (Algorithm 1 of Mathieu et al. 2019):

<img src="hyperbolic_normal_sampling.png" width="300" height="300" align="center"/>
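The sampling steps used by `rsample` below can be sketched in NumPy as follows. This is an illustrative re-implementation that assumes an isotropic scale $\Sigma = \sigma^2 I$. It uses the fact that parallel transport on the Poincaré ball from the origin to $\mu$ is a rescaling by $\lambda_0^c/\lambda_\mu^c$, so dividing by $\lambda_0^c$ and then transporting is the same as dividing by $\lambda_\mu^c$. As a check, $d_p^c(\mu, \exp_\mu^c(u)) = \lambda_\mu^c||u|| = ||v||$, so the mean squared distance of the samples from $\mu$ should be close to $d\sigma^2$.

```python
import numpy as np

def mobius_add(x, y, c):
    """Mobius addition x (+)_c y along the last axis."""
    xy = np.sum(x * y, axis=-1, keepdims=True)
    x2 = np.sum(x * x, axis=-1, keepdims=True)
    y2 = np.sum(y * y, axis=-1, keepdims=True)
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    return num / (1 + 2 * c * xy + c ** 2 * x2 * y2)

def lambda_x(x, c):
    """Conformal factor lambda_x^c, keeping the last axis."""
    return 2.0 / (1 - c * np.sum(x * x, axis=-1, keepdims=True))

def expmap(x, u, c):
    """exp_x^c(u) for tangent vectors u at x (one per row)."""
    sc = np.sqrt(c)
    un = np.maximum(np.linalg.norm(u, axis=-1, keepdims=True), 1e-15)
    return mobius_add(x, np.tanh(sc * lambda_x(x, c) * un / 2) * u / (sc * un), c)

def dist(x, y, c):
    """Geodesic distance d_p^c(x, y) via the arccosh formula."""
    diff2 = np.sum((x - y) ** 2, axis=-1)
    x2, y2 = np.sum(x * x, axis=-1), np.sum(y * y, axis=-1)
    return np.arccosh(1 + 2 * c * diff2 / ((1 - c * x2) * (1 - c * y2))) / np.sqrt(c)

def sample_wrapped_normal(mu, sigma, c, n, rng):
    """Draw n samples from N^W(mu, sigma^2 I) on B_c^d (isotropic scale assumed)."""
    v = sigma * rng.standard_normal((n, mu.shape[-1]))  # 1. v ~ N(0, sigma^2 I) at the origin
    u = v / lambda_x(mu, c)   # 2. equals transporting v / lambda_0 from 0 to mu
    return expmap(mu, u, c)   # 3. z = exp_mu^c(u)

rng = np.random.default_rng(0)
c, sigma, mu = 0.7, 0.8, np.array([0.5, -0.4])
z = sample_wrapped_normal(mu, sigma, c, 100_000, rng)
print(np.all(c * np.sum(z ** 2, axis=-1) < 1))  # True: every sample lies inside the ball
print(np.mean(dist(mu, z, c) ** 2))             # close to d * sigma^2 = 1.28
```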

Calculating the PDF (likelihood) can be done using a change-of-variables formula, which allows us to calculate the PDF of an induced distribution given the original distribution (Nagano et al. 2019).
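As a sanity check on the wrapped normal density, the sketch below evaluates $\log \mathcal{N}^{\mathrm{W}}$ with the correction factor $\left(\sqrt{c}\, d_p^c(\mu, z) / \sinh(\sqrt{c}\, d_p^c(\mu, z))\right)^{d-1}$, which matches `logdetexp` in the `PoincareBall` class. It assumes an isotropic scale $\Sigma = \sigma^2 I$, in which case only $||\lambda_\mu^c \log_\mu^c(z)|| = d_p^c(\mu, z)$ is needed (this identity follows from the log-map formula). The density is then integrated over the 2-dimensional ball against the Riemannian volume element $(\lambda_z^c)^d\,dz$; if the density is correctly normalized, the total mass is 1.

```python
import numpy as np

def mobius_add(x, y, c):
    """Mobius addition x (+)_c y along the last axis."""
    xy = np.sum(x * y, axis=-1, keepdims=True)
    x2 = np.sum(x * x, axis=-1, keepdims=True)
    y2 = np.sum(y * y, axis=-1, keepdims=True)
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    return num / (1 + 2 * c * xy + c ** 2 * x2 * y2)

def wrapped_normal_log_prob(z, mu, sigma, c):
    """log N^W(z | mu, sigma^2 I) on B_c^d, assuming an isotropic scale and z != mu."""
    d = z.shape[-1]
    sc = np.sqrt(c)
    # ||lambda_mu^c log_mu^c(z)|| equals the geodesic distance rho = d_p^c(mu, z)
    rho = 2 / sc * np.arctanh(sc * np.linalg.norm(mobius_add(-mu, z, c), axis=-1))
    log_normal = -0.5 * d * np.log(2 * np.pi * sigma ** 2) - rho ** 2 / (2 * sigma ** 2)
    # log of the correction factor (sqrt(c) rho / sinh(sqrt(c) rho))^(d - 1)
    log_corr = (d - 1) * (np.log(sc * rho) - np.log(np.sinh(sc * rho)))
    return log_normal + log_corr

# Total mass for d = 2 with mu at the origin. The density is radial, so
#   mass = int_0^{1/sqrt(c)} p(r) * lambda(r)^2 * 2*pi*r dr
c, sigma = 0.7, 0.5
r = np.linspace(1e-6, (1 - 1e-9) / np.sqrt(c), 400_001)
z = np.stack([r, np.zeros_like(r)], axis=-1)      # the points (r, 0)
lam = 2 / (1 - c * r ** 2)                        # conformal factor at (r, 0)
f = np.exp(wrapped_normal_log_prob(z, np.zeros(2), sigma, c)) * lam ** 2 * 2 * np.pi * r
mass = np.sum((f[1:] + f[:-1]) * np.diff(r)) / 2  # trapezoidal rule
print(mass)  # very close to 1
```

Replacing the exponent $d-1$ by $-1$ would break the normalization, which is a quick way to catch sign errors in the change-of-variables term.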

We now show how sampling and log-density evaluation are implemented for the wrapped normal distribution, using code from Mathieu et al. 2019's official implementation:

In [None]:
from numbers import Number

class WrappedNormal(torch.distributions.Distribution):

    arg_constraints = {'loc': torch.distributions.constraints.real,
                       'scale': torch.distributions.constraints.positive}
    support = torch.distributions.constraints.real
    has_rsample = True
    _mean_carrier_measure = 0

    @property
    def mean(self):
        return self.loc

    @property
    def stddev(self):
        raise NotImplementedError

    @property
    def scale(self):
        return F.softplus(self._scale) if self.softplus else self._scale

    def __init__(self, loc, scale, manifold, validate_args=None, softplus=False):
        self.dtype = loc.dtype
        self.softplus = softplus
        self.loc, self._scale = torch.distributions.utils.broadcast_all(loc, scale)
        self.manifold = manifold
        self.manifold.assert_check_point_on_manifold(self.loc)
        self.device = loc.device
        if isinstance(loc, Number) and isinstance(scale, Number):
            batch_shape, event_shape = torch.Size(), torch.Size()
        else:
            batch_shape = self.loc.shape[:-1]
            event_shape = torch.Size([self.manifold.dim])
        super(WrappedNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def sample(self, shape=torch.Size()):
        with torch.no_grad():
            return self.rsample(shape)

    def rsample(self, sample_shape=torch.Size()):
        """
        Implementation of the above reparametrizable sampling scheme. 
        """
        shape = self._extended_shape(sample_shape)
        # 1. sample a standard normal in the tangent space at the origin and scale it by the standard deviation
        v = self.scale * torch.distributions.utils._standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        self.manifold.assert_check_vector_on_tangent(self.manifold.zero, v)
        # 2. divide by lambda_0 and parallel transport from the origin to loc; together these
        #    amount to dividing by lambda_loc, as in the algorithm above
        v = v / self.manifold.lambda_x(self.manifold.zero, keepdim=True)
        u = self.manifold.transp(self.manifold.zero, self.loc, v)
        # 3. calculate expmap
        z = self.manifold.expmap(self.loc, u)
        return z

    def log_prob(self, x):
        """
        Computes the log-density (log-PDF) of x. The calculation follows the algorithm 
        of Nagano et al. 2019 (Algorithm 2 of Nagano et al. 2019). For more details, 
        see the paper.
        """
        shape = x.shape
        loc = self.loc.unsqueeze(0).expand(x.shape[0], *self.batch_shape, self.manifold.coord_dim)
        if len(shape) < len(loc.shape): x = x.unsqueeze(1)
        # 1. take the inverse exponential map (log map)
        v = self.manifold.logmap(loc, x)
        # 2. parallel transport the tangent vector from loc back to the origin
        v = self.manifold.transp(loc, self.manifold.zero, v)
        # 3. calculate log-pdf using change of variables (Eqn 7 of Nagano et al. 2019)
        u = v * self.manifold.lambda_x(self.manifold.zero, keepdim=True)
        norm_pdf = torch.distributions.Normal(torch.zeros_like(self.scale), self.scale).log_prob(u).sum(-1, keepdim=True)
        logdetexp = self.manifold.logdetexp(loc, x, keepdim=True)
        result = norm_pdf - logdetexp
        return result

### The Riemannian Normal Distribution

The Riemannian normal, denoted $\mathcal{N}^{\mathrm{R}}$, generalizes the normal distribution by viewing it as the distribution that maximizes entropy for a given mean and variance. Mathieu et al. 2019 derive a reparametrizable sampling scheme for it based on rejection sampling, together with its PDF:
\begin{align}
    \mathcal{N}^{\mathrm{R}}(z|\mu, \sigma^2) = \frac{1}{Z^\mathrm{R}} \exp \left(-\frac{d_p^c(\mu, z)^2}{2\sigma^2} \right)
\end{align}
where $\mu$ is the mean, $\sigma$ is a dispersion parameter analogous to the standard deviation, and $Z^{\mathrm{R}}$ is a normalizing constant (for a derivation, see Appendix B.4.3 of Mathieu et al. 2019). We will not implement the Riemannian normal here, but the code is available on the official Github repository. 
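For $d = 2$ the normalizing constant can be worked out by hand, which makes the role of $Z^{\mathrm{R}}$ concrete. In geodesic polar coordinates around $\mu$, a circle of radius $\rho$ has circumference $2\pi \sinh(\sqrt{c}\rho)/\sqrt{c}$, so $Z^{\mathrm{R}} = \frac{2\pi}{\sqrt{c}} \int_0^\infty e^{-\rho^2/2\sigma^2} \sinh(\sqrt{c}\rho)\, d\rho = \frac{2\pi}{\sqrt{c}}\, \sigma \sqrt{\pi/2}\, e^{c\sigma^2/2} \mathrm{erf}(\sqrt{c}\sigma/\sqrt{2})$. This closed form is our own derivation for $d = 2$ only, not code from the paper. The sketch compares it with numerical quadrature; as $c \to 0$ it tends to the Euclidean value $2\pi\sigma^2$.

```python
import math
import numpy as np

def riemannian_normal_Z_2d(sigma, c):
    """Closed-form Z^R for d = 2 (derived above, not taken from the paper)."""
    sc = math.sqrt(c)
    return (2 * math.pi / sc) * sigma * math.sqrt(math.pi / 2) \
        * math.exp(c * sigma ** 2 / 2) * math.erf(sc * sigma / math.sqrt(2))

def riemannian_normal_Z_2d_numeric(sigma, c, rho_max=50.0, n=200_001):
    """The same integral by the trapezoidal rule in geodesic polar coordinates."""
    sc = math.sqrt(c)
    rho = np.linspace(0.0, rho_max, n)
    # unnormalized density times the circumference 2*pi*sinh(sqrt(c) rho)/sqrt(c)
    f = np.exp(-rho ** 2 / (2 * sigma ** 2)) * 2 * np.pi * np.sinh(sc * rho) / sc
    return float(np.sum((f[1:] + f[:-1]) * np.diff(rho)) / 2)

print(riemannian_normal_Z_2d(1.0, 0.7), riemannian_normal_Z_2d_numeric(1.0, 0.7))
```

Unlike the Euclidean case, $Z^{\mathrm{R}}$ depends on the curvature and grows like $e^{c\sigma^2/2}$, because the volume of hyperbolic balls grows exponentially with their radius.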

## From Euclidean VAE to Hyperbolic VAE

Recall that a VAE has an encoder-decoder structure, where the encoder produces the parameters of a chosen distribution (the approximate posterior) given an input $x$. We then sample a latent code $z$ from this distribution and feed it into the decoder, which reconstructs $x$. The VAE is trained by maximizing the evidence lower bound (ELBO), which combines a reconstruction term with a KL-divergence between the approximate posterior and a chosen prior distribution. In Euclidean space, the typical choice for both the prior and the posterior is the normal distribution.

To create a hyperbolic VAE, we then only need to choose the prior and posterior to be one of the hyperbolic normal distributions we just defined! The only caveat is that if we replace only the prior and posterior distributions, the encoder and decoder are still fully Euclidean networks, while some parameters of our distribution, as well as the samples, live in hyperbolic space. Nagano et al. 2019 resolve this in a simple way: apply an exponential map at the end of the encoder and a logarithmic map at the start of the decoder. The decoder of Mathieu et al. 2019 is a bit more complicated: its first layer is an additional gyroplane layer, which the paper argues better handles the geometry of the hyperbolic latent space. We implement both decoders below.
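At the encoder/decoder boundary the maps are taken at the origin, where $\lambda_0^c = 2$ and Möbius addition with $0$ is trivial, so the general formulas reduce to $\exp_0^c(v) = \tanh(\sqrt{c}||v||)\, v/(\sqrt{c}||v||)$ and $\log_0^c(y) = \tanh^{-1}(\sqrt{c}||y||)\, y/(\sqrt{c}||y||)$. The NumPy sketch below is an illustration (the notebook itself calls geoopt's `expmap0` and `logmap0`); the random array `h` is a hypothetical stand-in for the encoder's Euclidean mean-head outputs.

```python
import numpy as np

def expmap0(v, c):
    """exp_0^c(v) = tanh(sqrt(c)||v||) v / (sqrt(c)||v||): Euclidean vector -> ball."""
    sc = np.sqrt(c)
    n = np.maximum(np.linalg.norm(v, axis=-1, keepdims=True), 1e-15)
    return np.tanh(sc * n) * v / (sc * n)

def logmap0(y, c):
    """log_0^c(y) = artanh(sqrt(c)||y||) y / (sqrt(c)||y||): ball -> Euclidean vector."""
    sc = np.sqrt(c)
    n = np.maximum(np.linalg.norm(y, axis=-1, keepdims=True), 1e-15)
    return np.arctanh(sc * n) * y / (sc * n)

rng = np.random.default_rng(0)
c = 0.7
h = 2.0 * rng.standard_normal((5, 2))  # hypothetical Euclidean mean-head outputs
mu = expmap0(h, c)                     # "end of encoder": latent means on the ball
print(np.all(np.sqrt(c) * np.linalg.norm(mu, axis=-1) < 1))  # True: inside the ball
print(np.allclose(logmap0(mu, c), h))                         # True: "start of decoder" inverts it
```

Because $\tanh$ saturates, any Euclidean output lands strictly inside the ball in exact arithmetic. For very large $||h||$, however, $\tanh$ rounds to exactly 1 in floating point and the point reaches the boundary, one source of the numerical instability mentioned in the training loop.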

Now we define our hyperbolic VAE model. First, we define the encoder, closely adapted from Mathieu et al. 2019's official implementation:

In [None]:
from numpy import prod

In [None]:
def extra_hidden_layer(hidden_dim, non_lin):
    return nn.Sequential(nn.Linear(hidden_dim, hidden_dim), non_lin)

In [None]:
class Enc(nn.Module):
    """
    The usual Euclidean encoder, with an exponential map on the mean head to 
    produce a mean in the correct latent space.
    """
    def __init__(self, manifold, data_size, non_lin, num_hidden_layers, hidden_dim):
        super(Enc, self).__init__()
        self.manifold = manifold
        self.data_size = data_size
        modules = []
        modules.append(nn.Sequential(nn.Linear(prod(data_size), hidden_dim), non_lin))
        modules.extend([extra_hidden_layer(hidden_dim, non_lin) for _ in range(num_hidden_layers - 1)])
        self.enc = nn.Sequential(*modules)
        self.mean_head = nn.Linear(hidden_dim, manifold.coord_dim)
        self.sigma_head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        e = self.enc(x.view(*x.size()[:-len(self.data_size)], -1))
        mu = self.mean_head(e)
        mu = self.manifold.expmap0(mu)
        return mu, F.softplus(self.sigma_head(e)) + 1e-5,  self.manifold # want to ensure sigma is non-zero

Now we define our decoder, which is just the usual VAE decoder. As described above, there is a log map to map from the hyperbolic latent space to the Euclidean space expected by the linear decoder layers.

In [None]:
class Dec(nn.Module):
    """
    The usual Euclidean decoder, with a logarithm map at the beginning in 
    order to map the latent code to Euclidean space. 
    """
    def __init__(self, manifold, data_size, non_lin, num_hidden_layers, hidden_dim):
        super(Dec, self).__init__()
        self.data_size = data_size
        self.manifold = manifold
        modules = []
        modules.append(nn.Sequential(nn.Linear(manifold.coord_dim, hidden_dim), non_lin))
        modules.extend([extra_hidden_layer(hidden_dim, non_lin) for _ in range(num_hidden_layers - 1)])
        self.dec = nn.Sequential(*modules)
        self.output = nn.Linear(hidden_dim, np.prod(data_size))

    def forward(self, z):
        z = self.manifold.logmap0(z)
        d = self.dec(z)
        mu = self.output(d).view(*z.size()[:-1], *self.data_size)
        return torch.tensor(1.0).to(z.device), mu

One contribution of Mathieu et al. 2019 is to replace the first layer of the decoder with a hyperbolic layer, which better respects the geometry of the hyperbolic latent space. This layer generalizes the Euclidean affine transform, which can be written as
\begin{align}
    f_{a, p}(z) = \langle a, z - p \rangle = \mathrm{sign}(\langle a, z - p \rangle) ||a|| d_E(z, H_{a, p})
\end{align}
where $H_{a, p} = \{z \in \mathbb{R}^d : \langle a, z - p \rangle = 0 \}$ is the decision hyperplane with normal $a$ passing through the point $p$. The corresponding hyperbolic layer, called the gyroplane layer, replaces $H_{a, p}$ with the gyroplane $H_{a, p}^c = \{z \in \mathbb{B}_c^d : \langle -p \oplus_c z, a \rangle = 0 \}$ and is a map from hyperbolic space to Euclidean space with the formula
\begin{align}
    f_{a, p}^c(z) = \mathrm{sign}(\langle a, \log_p^c(z)\rangle_p) ||a||_p d_p^c(z, H_{a, p}^c)
\end{align}
This operation was first introduced by Ganea et al. 2018, who also derived a closed-form formula for $d_p^c(z, H_{a, p}^c)$:
\begin{align}
    d_p^c(z, H_{a, p}^c) = \frac{1}{\sqrt{c}} \sinh^{-1} \left( \frac{2\sqrt{c}|\langle -p \oplus_c z, a \rangle|}{(1 -c||-p \oplus_c z ||^2)||a||} \right)
\end{align}
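The closed-form distance to a gyroplane is what `normdist2plane` computes below. The NumPy sketch re-implements it as an illustration (not the library code) and checks it three ways. First, for $p = 0$ and $a = (1, 0)$ the gyroplane is the diameter $\{(0, t)\}$ of the disk, so the closed form should match a brute-force minimum of $d_p^c$ over that diameter. Second, points $p \oplus_c u$ with $u \perp a$ lie on the gyroplane and have distance 0. Third, as $c \to 0$ the distance tends to $2|\langle z - p, a\rangle|/||a||$, twice the Euclidean point-to-hyperplane distance, because $\lambda_z^c \to 2$.

```python
import numpy as np

def mobius_add(x, y, c):
    """Mobius addition x (+)_c y along the last axis."""
    xy = np.sum(x * y, axis=-1, keepdims=True)
    x2 = np.sum(x * x, axis=-1, keepdims=True)
    y2 = np.sum(y * y, axis=-1, keepdims=True)
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    return num / (1 + 2 * c * xy + c ** 2 * x2 * y2)

def dist(x, y, c):
    """Geodesic distance d_p^c(x, y) via the arccosh formula."""
    diff2 = np.sum((x - y) ** 2, axis=-1)
    x2, y2 = np.sum(x * x, axis=-1), np.sum(y * y, axis=-1)
    return np.arccosh(1 + 2 * c * diff2 / ((1 - c * x2) * (1 - c * y2))) / np.sqrt(c)

def dist_to_gyroplane(z, a, p, c):
    """d_p^c(z, H_{a,p}^c) via the closed form of Ganea et al. 2018."""
    sc = np.sqrt(c)
    w = mobius_add(-p, z, c)  # -p (+)_c z
    num = 2 * sc * np.abs(np.sum(w * a, axis=-1))
    den = (1 - c * np.sum(w * w, axis=-1)) * np.linalg.norm(a, axis=-1)
    return np.arcsinh(num / den) / sc

c = 0.7
z, a = np.array([0.4, 0.3]), np.array([1.0, 0.0])

# 1. p = 0: the gyroplane is the diameter {(0, t)}; compare with a brute-force minimum.
t = np.linspace(-1, 1, 400_001) * (1 - 1e-6) / np.sqrt(c)
diameter = np.stack([np.zeros_like(t), t], axis=-1)
print(np.isclose(dist(z, diameter, c).min(), dist_to_gyroplane(z, a, np.zeros(2), c)))  # True

# 2. Points p (+)_c u with u orthogonal to a lie on the gyroplane through p.
p, u = np.array([0.1, -0.2]), np.array([0.0, 0.5])
print(np.isclose(dist_to_gyroplane(mobius_add(p, u, c), a, p, c), 0.0))  # True

# 3. As c -> 0 the distance tends to 2 |<z - p, a>| / ||a||.
print(np.isclose(dist_to_gyroplane(z, a, p, 1e-10), 2 * abs((z - p) @ a) / np.linalg.norm(a)))  # True
```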

In [None]:
class RiemannianLayer(nn.Module):
    def __init__(self, in_features, out_features, manifold, over_param, weight_norm):
        super(RiemannianLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.manifold = manifold

        self._weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.over_param = over_param
        self.weight_norm = weight_norm
        if self.over_param:
            self._bias = ManifoldParameter(torch.Tensor(out_features, in_features), manifold=manifold)
        else:
            self._bias = nn.Parameter(torch.Tensor(out_features, 1))
        self.reset_parameters()

    @property
    def weight(self):
        return self.manifold.transp0(self.bias, self._weight) # weight \in T_0 => weight \in T_bias

    @property
    def bias(self):
        if self.over_param:
            return self._bias
        else:
            return self.manifold.expmap0(self._weight * self._bias) # reparameterisation of a point on the manifold

    def reset_parameters(self):
        torch.nn.init.kaiming_normal_(self._weight, a=math.sqrt(5))
        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self._weight)
        bound = 4 / math.sqrt(fan_in)
        torch.nn.init.uniform_(self._bias, -bound, bound)
        if self.over_param:
            with torch.no_grad(): self._bias.set_(self.manifold.expmap0(self._bias))

In [None]:
class GyroplaneLayer(RiemannianLayer):
    def __init__(self, in_features, out_features, manifold, over_param=False, weight_norm=False):
        super(GyroplaneLayer, self).__init__(in_features, out_features, manifold, over_param, weight_norm)

    def forward(self, input):
        input = input.unsqueeze(-2)
        input = input.expand(*input.shape[:-(len(input.shape) - 2)], self.out_features, self.in_features)
        # Compute the gyroplane layer using the distance formula of Ganea et al. 2018
        res = self.manifold.normdist2plane(input, self.bias, self.weight,
                                           signed=True, norm=self.weight_norm)
        return res

In [None]:
class GyroDec(nn.Module):
    """ First layer is a Hypergyroplane followed by usual decoder """
    def __init__(self, manifold, data_size, non_lin, num_hidden_layers, hidden_dim):
        super(GyroDec, self).__init__()
        self.data_size = data_size
        modules = []
        # The decoder is the same except the first layer is replaced with a Gyroplane layer
        modules.append(nn.Sequential(GyroplaneLayer(manifold.coord_dim, hidden_dim, manifold), non_lin))
        modules.extend([extra_hidden_layer(hidden_dim, non_lin) for _ in range(num_hidden_layers - 1)])
        self.dec = nn.Sequential(*modules)
        self.output = nn.Linear(hidden_dim, prod(data_size))

    def forward(self, z):
        d = self.dec(z)
        mu = self.output(d).view(*z.size()[:-1], *self.data_size)  # reshape data
        return torch.tensor(1.0).to(z.device), mu

## Training

### Model

We now create our final hyperbolic VAE model using the encoder and decoder that we just defined.

In [None]:
class VAE(nn.Module):
    def __init__(self, prior_dist, posterior_dist, likelihood_dist, enc, dec, params):
        super(VAE, self).__init__()
        self.pz = prior_dist
        self.px_z = likelihood_dist
        self.qz_x = posterior_dist
        self.enc = enc
        self.dec = dec
        self.modelName = None
        self.params = params
        self.data_size = params.data_size
        self.prior_std = params.prior_std

        if self.px_z == torch.distributions.RelaxedBernoulli:
            self.px_z.log_prob = lambda self, value: \
                -F.binary_cross_entropy_with_logits(
                    self.logits if value.dim() <= self.logits.dim() else self.logits.expand_as(value),
                    value.expand(self.batch_shape) if value.dim() <= self.logits.dim() else value,
                    reduction='none'
                )
            
    def generate(self, N, K):
        self.eval()
        with torch.no_grad():
            mean_pz = get_mean_param(self.pz_params)
            mean = get_mean_param(self.dec(mean_pz))
            px_z_params = self.dec(self.pz(*self.pz_params).sample(torch.Size([N])))
            px_z_tmp, px_z_mu = px_z_params
            means = get_mean_param(px_z_params)
            samples = self.px_z(px_z_tmp, logits=px_z_mu).sample(torch.Size([K]))

        return mean, \
            means.view(-1, *means.size()[2:]), \
            samples.view(-1, *samples.size()[3:])

    def reconstruct(self, data):
        self.eval()
        with torch.no_grad():
            qz_x = self.qz_x(*self.enc(data))
            px_z_params = self.dec(qz_x.rsample(torch.Size([1])).squeeze(0))
        return get_mean_param(px_z_params)
            
    def forward(self, x, K=1):
        qz_x = self.qz_x(*self.enc(x))
        zs = qz_x.rsample(torch.Size([K]))
        temp, mu = self.dec(zs)
        px_z = self.px_z(temp, logits=mu)
        return qz_x, px_z, zs
    
    @property
    def pz_params(self):
        return self._pz_mu.mul(1), F.softplus(self._pz_logvar).div(math.log(2)).mul(self.prior_std)

Here, we set our model parameters.

In [None]:
class Config:
    def __init__(self, manifold='PoincareBall', latent_dim=2, c=0.7, prior='WrappedNormal', posterior='WrappedNormal', 
                 prior_std=1., num_hidden_layers=1, hidden_dim=600, nl='ReLU', enc='Enc', dec='GyroDec', beta=1.0, K=1, 
                 epochs=80, batch_size=128, lr=5e-4, beta1=0.9, beta2=0.999, data_size=torch.Size([1, 28, 28])):
        self.manifold = manifold                           # Manifold: Euclidean or Hyperbolic
        self.latent_dim = latent_dim                       # Latent dimension of manifold
        self.c = c                                         # Curvature of manifold
        self.prior = prior                                 # VAE prior distribution
        self.posterior = posterior                         # VAE posterior distribution
        self.prior_std = prior_std                         # Standard dev. of prior distribution
        
        self.num_hidden_layers = num_hidden_layers         # Num hidden layers in encoder and decoder
        self.hidden_dim = hidden_dim                       # Hidden dimension
        self.nl = nl                                       # Non-linearity for encoder and decoder
        self.enc = enc                                     # VAE encoder                         
        self.dec = dec                                     # VAE decoder
        
        self.beta = beta                                   # Beta parameter for beta-VAE
        self.K = K                                         # Number of samples for ELBO
        
        self.epochs = epochs                               # Epochs
        self.batch_size = batch_size                       # Batch size
        self.lr = lr                                       # Learning rate
        self.beta1 = beta1                                 # beta1 for Adam optimizer
        self.beta2 = beta2                                 # beta2 for Adam optimizer
        
        self.data_size = data_size                         # Data size of input data

In [None]:
params = Config()

Here, following Mathieu et al. 2019, we create an MNIST-specific VAE:

In [None]:
class Mnist(VAE):
    def __init__(self, params):
        c = nn.Parameter(params.c * torch.ones(1), requires_grad=False)
        manifold = eval(params.manifold)(params.latent_dim, c)
        super(Mnist, self).__init__(
            eval(params.prior),                       # prior distribution
            eval(params.posterior),                   # posterior distribution
            torch.distributions.RelaxedBernoulli,     # likelihood distribution
            eval(params.enc)(
                manifold, 
                params.data_size, 
                getattr(nn, params.nl)(), 
                params.num_hidden_layers, 
                params.hidden_dim
            ),
            eval(params.dec)(
                manifold, 
                params.data_size, 
                getattr(nn, params.nl)(), 
                params.num_hidden_layers, 
                params.hidden_dim
            ),
            params
        )
        self.manifold = manifold
        self.c = c
        self._pz_mu = nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False)
        self._pz_logvar = nn.Parameter(torch.zeros(1, 1), requires_grad=False)
        self.modelName = 'Mnist'

    def init_last_layer_bias(self, train_loader):
        if not hasattr(self.dec.output, 'bias'): return
        with torch.no_grad():
            p = torch.zeros(prod(params.data_size[1:]), device=self._pz_mu.device)
            N = 0
            for i, (data, _) in enumerate(train_loader):
                data = data.to(self._pz_mu.device)
                B = data.size(0)
                N += B
                p += data.view(-1, prod(params.data_size[1:])).sum(0)
            p /= N
            p += 1e-4
            self.dec.output.bias.set_(p.log() - (1 - p).log())

    @property
    def pz_params(self):
        return self._pz_mu.mul(1), F.softplus(self._pz_logvar).div(math.log(2)).mul(self.prior_std), self.manifold

In [None]:
model = Mnist(params)
model.to(device)

### Training objective

Both Nagano et al. 2019 and Mathieu et al. 2019 use a $\beta$-VAE (Higgins et al. 2017), whose objective applies a scalar weight of $\beta$ to the KL-divergence term.
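In `vae_objective`, the KL term is not computed in closed form: with $z$ drawn from $q(z|x)$ by `rsample`, `qz_x.log_prob(zs) - pz.log_prob(zs)` is a Monte Carlo estimate of $\mathrm{KL}(q \| p)$, averaged over the `K` samples. The sketch below checks this estimator in the Euclidean case, using hypothetical one-dimensional Gaussians $q = \mathcal{N}(1, 0.5^2)$ and $p = \mathcal{N}(0, 1)$, where the exact KL is available for comparison.

```python
import numpy as np

def gaussian_log_prob(x, mean, std):
    """Log-density of a univariate normal N(mean, std^2)."""
    return -0.5 * np.log(2 * np.pi * std ** 2) - (x - mean) ** 2 / (2 * std ** 2)

rng = np.random.default_rng(0)
mu_q, std_q = 1.0, 0.5  # hypothetical posterior q(z|x)
mu_p, std_p = 0.0, 1.0  # hypothetical standard-normal prior p(z)

# Monte Carlo estimate: average of log q(z) - log p(z) over z ~ q, as in vae_objective
z = rng.normal(mu_q, std_q, size=200_000)
kl_mc = np.mean(gaussian_log_prob(z, mu_q, std_q) - gaussian_log_prob(z, mu_p, std_p))

# Closed-form KL(q || p) between two univariate normals, for comparison
kl_exact = np.log(std_p / std_q) + (std_q ** 2 + (mu_q - mu_p) ** 2) / (2 * std_p ** 2) - 0.5
print(kl_exact, kl_mc)  # kl_exact = log(2) + 0.125, about 0.8181
```

The same estimator applies unchanged to the wrapped normal, since it only needs `log_prob` and samples, which is why the objective can use hyperbolic distributions without a closed-form KL.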

In [None]:
def vae_objective(model, x, K=1, beta=1.0, components=False, **kwargs):
    """
    The beta-VAE objective. 
    """
    qz_x, px_z, zs = model(x, K)
    _, B, D = zs.size()
    flat_rest = torch.Size([*px_z.batch_shape[:2], -1])
    lpx_z = px_z.log_prob(x.expand(px_z.batch_shape)).view(flat_rest).sum(-1)

    pz = model.pz(*model.pz_params)
    kld = qz_x.log_prob(zs).sum(-1) - pz.log_prob(zs).sum(-1)

    obj = -lpx_z.mean(0).sum() + beta * kld.mean(0).sum()
    return (qz_x, px_z, lpx_z, kld, obj) if components else obj

In [None]:
loss_function = vae_objective

### Optimizer

We will use the Adam optimizer, which is used by both Nagano et al. 2019 and Mathieu et al. 2019; here we enable its AMSGrad variant (`amsgrad=True`).

In [None]:
optimizer = optim.Adam(model.parameters(), lr=params.lr, amsgrad=True, betas=(params.beta1, params.beta2))

### Prepare dataloaders

In [None]:
train_loader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=params.batch_size, 
    shuffle=True, 
    num_workers=1, 
    pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, 
    batch_size=params.batch_size, 
    shuffle=True, 
    num_workers=1, 
    pin_memory=True
)

### Training loop

Here, we train the hyperbolic VAE. Note that we do not need Riemannian SGD (as some hyperbolic models do): the model's hyperbolic parameters, such as the gyroplane offsets, are obtained by applying the exponential map to Euclidean parameters, so a standard Euclidean optimizer such as Adam suffices. 

In [None]:
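def probe_infnan(v, name, extras=None):
    """Stop training if `v` contains NaN or inf values, printing some diagnostics."""
    bad = ~torch.isfinite(v)
    s = bad.sum().item()
    if s > 0:
        print('>>> {} >>>'.format(name))
        print(name, s)
        print(v[bad])
        for k, val in (extras or {}).items():
            print(k, val, val.sum().item())
        # raise instead of quit() so the notebook kernel is not shut down
        raise RuntimeError('{} contains {} non-finite value(s)'.format(name, s))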

In [None]:
from collections import defaultdict

# The training loop
model.init_last_layer_bias(train_loader)
agg = defaultdict(list)
for epoch in range(1, params.epochs + 1):
    model.train()
    b_loss, b_recon, b_kl = 0., 0., 0.
    for i, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        qz_x, px_z, lik, kl, loss = loss_function(model, data, K=params.K, beta=params.beta, components=True)
        # The Poincare ball model can have numerical instability close to the boundary of the ball
        probe_infnan(loss, "Training loss:") 
        loss.backward()
        optimizer.step()

        b_loss += loss.item()
        b_recon += -lik.mean(0).sum().item()
        b_kl += kl.sum(-1).mean(0).sum().item()

    agg['train_loss'].append(b_loss / len(train_loader.dataset))
    agg['train_recon'].append(b_recon / len(train_loader.dataset))
    agg['train_kl'].append(b_kl / len(train_loader.dataset))
    if epoch % 1 == 0:
        print('====> Epoch: {:03d} Loss: {:.2f} Recon: {:.2f} KL: {:.2f}'.format(
            epoch, agg['train_loss'][-1], agg['train_recon'][-1], agg['train_kl'][-1])
        )

## Testing

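Here, we evaluate the trained model on the test set. The reported test loss is a sampled estimate of the negative $\beta$-ELBO per example; with $\beta = 1$ (the default in `Config`) it is an upper bound on the negative log-likelihood of the test data.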

In [None]:
model.eval()
b_loss = 0.
with torch.no_grad():
    for i, (data, labels) in enumerate(test_loader):
        data = data.to(device)
        qz_x, px_z, lik, kl, loss = loss_function(model, data, K=params.K, beta=params.beta, components=True)
        b_loss += loss.item()

agg['test_loss'].append(b_loss / len(test_loader.dataset))
print('====>             Test loss: {:.4f}'.format(agg['test_loss'][-1]))

## Visualization

Here, we decode random samples from the prior and display the decoder means for eight of them as MNIST digits.

In [None]:
def get_mean_param(params):
    """Return the parameter used to show reconstructions or generations.
    For example, the mean for Normal, or probs for Bernoulli.
    For Bernoulli, skip first parameter, as that's (scalar) temperature
    """
    if params[0].dim() == 0:
        return params[1]
    # elif len(params) == 3:
    #     return params[1]
    else:
        return params[0]

In [None]:
mean, means, samples = model.generate(64, 9)

In [None]:
vis = []
for i in range(8):
    vis.append(means[i].squeeze().cpu())
vis = torch.cat(vis, dim=1)
plt.axis('off')
plt.imshow(vis.numpy())
plt.show()