In [1]:
# Imports
import torch
cuda = torch.cuda.is_available()
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import sys
sys.path.append("../../semi-supervised")

# Auxiliary Deep Generative Model

The Auxiliary Deep Generative Model [[Maaløe, 2016]](https://arxiv.org/abs/1602.05473) posits a model that with an auxiliary latent variable $a$ that infers the variables $z$ and $y$. This helps in terms of semi-supervised learning by delegating causality to their respective variables. This model was state-of-the-art in semi-supervised until 2017, and is still very powerful with an MNIST accuracy of *99.4%* using just 10 labelled examples per class.

<img src="../images/adgm.png" width="400px"/>


In [2]:
from models import AuxiliaryDeepGenerativeModel

y_dim = 10
z_dim = 32
a_dim = 32
h_dim = [256, 128]

model = AuxiliaryDeepGenerativeModel([784, y_dim, z_dim, a_dim, h_dim])
model

AuxiliaryDeepGenerativeModel(
  (encoder): Encoder(
    (hidden): ModuleList(
      (0): Linear(in_features=826, out_features=256, bias=True)
      (1): Linear(in_features=256, out_features=128, bias=True)
    )
    (sample): GaussianSample(
      (mu): Linear(in_features=128, out_features=32, bias=True)
      (log_var): Linear(in_features=128, out_features=32, bias=True)
    )
  )
  (decoder): Decoder(
    (hidden): ModuleList(
      (0): Linear(in_features=42, out_features=128, bias=True)
      (1): Linear(in_features=128, out_features=256, bias=True)
    )
    (reconstruction): Linear(in_features=256, out_features=784, bias=True)
    (output_activation): Sigmoid()
  )
  (classifier): Classifier(
    (dense): Linear(in_features=816, out_features=256, bias=True)
    (logits): Linear(in_features=256, out_features=10, bias=True)
  )
  (aux_encoder): Encoder(
    (hidden): ModuleList(
      (0): Linear(in_features=784, out_features=256, bias=True)
      (1): Linear(in_features=256, out_f

## Training

The lower bound we derived in the notebook for the **deep generative model** is similar to the one for the ADGM. Here, we also need to integrate over a continuous auxiliary variable $a$.

For labelled data, the lower bound is given by.
\begin{align}
\log p(x,y) &= \log \int \int p(x, y, a, z) \ dz \ da\\
&\geq \mathbb{E}_{q(a,z|x,y)} \bigg [\log \frac{p(x,y,a,z)}{q(a,z|x,y)} \bigg ] = - \mathcal{L}(x,y)
\end{align}

Again when no label information is available we sum out all of the labels.

\begin{align}
\log p(x) &= \log \int \sum_{y} \int p(x, y, a, z) \ dz \ da\\
&\geq \mathbb{E}_{q(a,y,z|x)} \bigg [\log \frac{p(x,y,a,z)}{q(a,y,z |x)} \bigg ] = - \mathcal{U}(x)
\end{align}

Where we decompose the q-distribution into its constituent parts. $q(a, y, z|x) = q(z|a,y,x)q(y|a,x)q(a|x)$, which is also what can be seen in the figure.

The distribution over $a$ is similar to $z$ in the sense that it is also a diagonal Gaussian distribution. However by introducing the auxiliary variable we allow for $z$ to become arbitrarily complex - something we can also see when using normalizing flows.

In [3]:
from datautils import get_mnist

# Only use 10 labelled examples per class
# The rest of the data is unlabelled.
labelled, unlabelled, validation = get_mnist(location="./", batch_size=64, labels_per_class=10)
alpha = 0.1 * (len(unlabelled) + len(labelled)) / len(labelled)

def binary_cross_entropy(r, x):
    return -torch.sum(x * torch.log(r + 1e-8) + (1 - x) * torch.log(1 - r + 1e-8), dim=-1)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.999))



In [4]:
from itertools import cycle
from inference import SVI, DeterministicWarmup

# We will need to use warm-up in order to achieve good performance.
# Over 200 calls to SVI we change the autoencoder from
# deterministic to stochastic.
beta = DeterministicWarmup(n=200)


if cuda: model = model.cuda()
elbo = SVI(model, likelihood=binary_cross_entropy, beta=beta)

The library is conventially packed with the `SVI` method that does all of the work of calculating the lower bound for both labelled and unlabelled data depending on whether the label is given. It also manages to perform the enumeration of all the labels.

Remember that the labels have to be in a *one-hot encoded* format in order to work with SVI.

In [5]:
from torch.autograd import Variable

for epoch in range(10):
    model.train()
    total_loss, accuracy = (0, 0)
    for (x, y), (u, _) in zip(cycle(labelled), unlabelled):
        # Wrap in variables
        x, y, u = Variable(x), Variable(y), Variable(u)
        if cuda:
            # They need to be on the same device and be synchronized.
            x, y = x.cuda(device=0), y.cuda(device=0)
            u = u.cuda(device=0)

        L = -elbo(x, y)
        print(L)
        U = -elbo(u)

        # Add auxiliary classification loss q(y|x)
        logits = model.classify(x)
        
        # Regular cross entropy
        classication_loss = torch.sum(y * torch.log(logits + 1e-8), dim=1).mean()

        J_alpha = L - alpha * classication_loss + U

        J_alpha.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += J_alpha.item()
        accuracy += torch.mean((torch.max(logits, 1)[1].data == torch.max(y, 1)[1].data).float())
        
    if epoch % 1 == 0:
        model.eval()
        m = len(unlabelled)
        print("Epoch: {}".format(epoch))
        print("[Train]\t\t J_a: {:.2f}, accuracy: {:.2f}".format(total_loss / m, accuracy / m))

        total_loss, accuracy = (0, 0)
        for x, y in validation:
            x, y = Variable(x), Variable(y)

            if cuda:
                x, y = x.cuda(device=0), y.cuda(device=0)

            L = -elbo(x, y)
            U = -elbo(x)

            logits = model.classify(x)
            classication_loss = -torch.sum(y * torch.log(logits + 1e-8), dim=1).mean()

            J_alpha = L + alpha * classication_loss + U

            total_loss += J_alpha.data[0]

            _, pred_idx = torch.max(logits, 1)
            _, lab_idx = torch.max(y, 1)
            accuracy += torch.mean((torch.max(logits, 1)[1].data == torch.max(y, 1)[1].data).float())

        m = len(validation)
        print("[Validation]\t J_a: {:.2f}, accuracy: {:.2f}".format(total_loss / m, accuracy / m))

elbo before torch.Size([64])
elbo after torch.Size([64])
tensor(547.0621, grad_fn=<NegBackward>)
elbo before torch.Size([640])
elbo after torch.Size([640])
elbo before torch.Size([36])
elbo after torch.Size([36])
tensor(544.2576, grad_fn=<NegBackward>)
elbo before torch.Size([640])
elbo after torch.Size([640])
elbo before torch.Size([64])
elbo after torch.Size([64])
tensor(540.6196, grad_fn=<NegBackward>)
elbo before torch.Size([640])
elbo after torch.Size([640])
elbo before torch.Size([36])
elbo after torch.Size([36])
tensor(538.0085, grad_fn=<NegBackward>)
elbo before torch.Size([640])
elbo after torch.Size([640])
elbo before torch.Size([64])
elbo after torch.Size([64])
tensor(534.1075, grad_fn=<NegBackward>)
elbo before torch.Size([640])
elbo after torch.Size([640])
elbo before torch.Size([36])
elbo after torch.Size([36])
tensor(531.1522, grad_fn=<NegBackward>)
elbo before torch.Size([640])
elbo after torch.Size([640])
elbo before torch.Size([64])
elbo after torch.Size([64])
tensor(

elbo before torch.Size([640])
elbo after torch.Size([640])


EOFError: 

KeyboardInterrupt: 

## Conditional generation

When the model is done training you can generate samples conditionally given some normal distributed noise $z$ and a label $y$.

*The model below has only trained for 10 iterations, so the perfomance is not representative*.

In [None]:
from utils import onehot
model.eval()

z = Variable(torch.randn(16, 32))

# Generate a batch of 5s
y = Variable(onehot(10)(5).repeat(16, 1))

x_mu = model.sample(z, y)

In [None]:
f, axarr = plt.subplots(1, 16, figsize=(18, 12))

samples = x_mu.data.view(-1, 28, 28).numpy()

for i, ax in enumerate(axarr.flat):
    ax.imshow(samples[i])
    ax.axis("off")

In [None]:
import numpy as np
import torch

In [None]:
array_2d = torch.Tensor(np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]]))
array_len = torch.Tensor([6,5,3])

m = ~(torch.ones(array_2d.size()).cumsum(dim=1).t() > array_len).t()
array_2d[m]

In [None]:
len(array_len)

In [None]:
def sequence_mask(lengths, maxlen, dtype=torch.bool):
    if maxlen is None:
        maxlen = lengths.max()
    mask = ~(torch.ones((len(lengths), maxlen)).cumsum(dim=1).t() > lengths).t()
    mask.type(dtype)
    return mask

sequence_mask(array_len, 9)

In [None]:
def log_poisson_loss(targets, log_input, compute_full_loss=False):
    """Computes log Poisson loss given `log_input`.
    Gives the log-likelihood loss between the prediction and the target under the
    assumption that the target has a Poisson distribution.
    Caveat: By default, this is not the exact loss, but the loss minus a
    constant term [log(z!)]. That has no effect for optimization, but
    does not play well with relative loss comparisons. To compute an
    approximation of the log factorial term, specify
    compute_full_loss=True to enable Stirling's Approximation.
    For brevity, let `c = log(x) = log_input`, `z = targets`.  The log Poisson
    loss is
        -log(exp(-x) * (x^z) / z!)
      = -log(exp(-x) * (x^z)) + log(z!)
      ~ -log(exp(-x)) - log(x^z) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
          [ Note the second term is the Stirling's Approximation for log(z!).
            It is invariant to x and does not affect optimization, though
            important for correct relative loss comparisons. It is only
            computed when compute_full_loss == True. ]
      = x - z * log(x) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
      = exp(c) - z * c [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
    Args:
    targets: A `Tensor` of the same type and shape as `log_input`.
    log_input: A `Tensor` of type `float32` or `float64`.
    compute_full_loss: whether to compute the full loss. If false, a constant
      term is dropped in favor of more efficient optimization.
    name: A name for the operation (optional).
    Returns:
    A `Tensor` of the same shape as `log_input` with the componentwise
    logistic losses.
    Raises:
    ValueError: If `log_input` and `targets` do not have the same shape.
    """
    if targets.size() != log_input.size():
        raise ValueError(
            "log_input and targets must have the same shape (%s vs %s)" %
            (log_input.size(), targets.size()))

    result = torch.exp(log_input) - log_input * targets
    if compute_full_loss:
        # need to create constant tensors here so that their dtypes can be matched
        # to that of the targets.
        point_five = 0.5  # constant_op.constant(0.5, dtype=targets.dtype)
        two_pi = 2 * math.pi  # constant_op.constant(2 * math.pi, dtype=targets.dtype)

        stirling_approx = (targets * torch.log(targets)) - targets + (
                point_five * torch.log(two_pi * targets))
        zeros = torch.zeros_like(targets, dtype=targets.dtype)
        ones = torch.ones_like(targets, dtype=targets.dtype)
        cond = (targets >= zeros) & (targets <= ones)  # math_ops.logical_and(targets >= zeros, targets <= ones)
        result += torch.where(cond, zeros, stirling_approx)
    return result

In [None]:
targets = torch.Tensor([[1,2,3,4], [5,6,7,8]])
log_input = torch.Tensor([[2,3,4,5,5], [6,7,8,9,5]])
log_poisson_loss(targets, log_input)

In [None]:
import tensorflow as tf

In [None]:
tf.__version__

In [None]:
tf.enable_eager_execution()

In [None]:
tf.nn.log_poisson_loss(targets.numpy(), log_input.numpy())

In [None]:
A = np.random.randn(6,4) + np.random.uniform(6,4) + np.random.gamma(6,4) / 5

In [None]:
A.mean(), A.var()

In [None]:
Ap = torch.Tensor(A)
At = tf.convert_to_tensor(A)

In [None]:
torch.nn.BatchNorm1d(4)(Ap)

In [None]:
m, v = tf.nn.moments(At, 0)
tf.nn.batch_normalization(At, m, v, offset=0.0, scale=1.0, variance_epsilon=1e-6)