# AGNFinder CVAE Framework Test (2)

**Author:** Maxime Robeyns (2021) <maximerobeyns@gmail.com>

**Digit Classification**

In this notebook, we use the AGNFinder CVAE as a 'classifier' (a discriminative model): we condition on the pixel data and generate the most likely digit label. This complements the first test, where we conditioned on the one-hot encoded labels rather than on the pixel data.

This is merely intended to test the CVAE implementation and demonstrate that it works correctly, not to reach SOTA on MNIST classification, which is awfully *passé*...

In [None]:
import os
import torch as t
import torch.nn as nn

from IPython import display
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Subset

import agnfinder.inference as inf

from agnfinder import config as cfg
from agnfinder.types import arch_t
from agnfinder.inference.utils import _load_mnist, _onehot, Squareplus
from agnfinder.inference.base import CVAE, CVAEParams

In [None]:
try: # One-time setup
    assert(_SETUP)
except NameError:
    import warnings
    warnings.filterwarnings('ignore', category=UserWarning)  # see torchvision pr #4184
    cfg.configure_logging()
    while '.git' not in os.listdir():
        os.chdir("../")
    dtype = t.float64
    device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
    if device == t.device("cuda"):
        print('Training on GPU')
        # !nvidia-smi
    else:
        print('CUDA is unavailable; training on CPU.')
    _SETUP = True

This ugly class creates a pretty visualisation of the test accuracy and some classification examples during training.

In [None]:
class LogVisual():
    def __init__(self, fig, display, test_loader):
        self.fig = fig
        self.gs = self.fig.add_gridspec(2, 10)
        self.testset = test_loader.dataset
        self.disp = display
        self.evals = []  # TODO include uncertainty estimates

    def __call__(self, cvae: CVAE):
        self.fig.clear()
        with t.no_grad():
            bs = 1000
            ts = Subset(self.testset, t.randint(len(self.testset), (bs,)))
            ax1 = self.fig.add_subplot(self.gs[0, :])
            ax1.set_xlabel('Logging Iteration')
            ax1.set_ylabel('Classification Accuracy')
            ax1.set_ylim(0.5, 1.)
            for x, y in DataLoader(ts, batch_size=bs):
                x = x.view(-1, 28*28).to(cvae.device, cvae.dtype)
                y = y.to(cvae.device, cvae.dtype)
                z = cvae.prior.get_dist(None).sample((bs,))
                ys = cvae.decoder.forward(t.cat((z, x), -1)).argmax(1)
                acc = (t.sum(y == ys) / bs).item()
                self.evals.append(acc)
                ax1.set_title(f'{bs} sample test accuracy: {acc*100:.2f}%', color='black')
                ax1.plot(self.evals)
                for i in range(10):
                    ax_tmp = self.fig.add_subplot(self.gs[1, i])
                    ax_tmp.axis('off')
                    ax_tmp.imshow(-x[i].view(28, 28).cpu().data.numpy(), 
                                  cmap=plt.get_cmap('PuBu'))
                    ax_tmp.set_title(f'prediction: {ys[i].item()}', fontsize=15, color='black')
            self.fig.tight_layout()
            self.disp.update(self.fig)
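The prediction rule inside `LogVisual.__call__` (sample $z$ from the prior, concatenate it with the pixels, take the argmax of the decoder's class probabilities) can be isolated as a minimal sketch. The `decoder` here is a hypothetical, untrained stand-in rather than the AGNFinder `MultinomialDecoder`, and averaging over `n_samples` latent draws is an optional extension; `n_samples=1` reproduces what the callback does.

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)
latent_dim, cond_dim, n_classes = 2, 28 * 28, 10

# Hypothetical stand-in for the trained decoder p(y | z, x): any network that
# maps [z, x] to class probabilities illustrates the prediction rule.
decoder = nn.Sequential(
    nn.Linear(latent_dim + cond_dim, 256),
    nn.SiLU(),
    nn.Linear(256, n_classes),
    nn.Softmax(dim=-1),
)

@t.no_grad()
def predict(x: t.Tensor, n_samples: int = 1) -> t.Tensor:
    """Predict digit labels for flattened images x of shape (B, 784)."""
    B = x.shape[0]
    # Draw z ~ N(0, I) from the unconditional prior, n_samples per image.
    z = t.randn(n_samples, B, latent_dim)
    # Condition on the pixels by concatenating them onto every latent sample.
    xs = x.unsqueeze(0).expand(n_samples, B, cond_dim)
    probs = decoder(t.cat((z, xs), -1))  # (n_samples, B, n_classes)
    # Average class probabilities over the latent samples, then take the mode.
    return probs.mean(0).argmax(-1)

x = t.rand(5, cond_dim)
preds = predict(x, n_samples=8)
print(preds.shape)  # torch.Size([5])
```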

## CVAE Setup

The ELBO is

\begin{align*}
    \mathcal{L}_{\text{CVAE}}(\theta, \phi; x, y) &= 
    \mathbb{E}_{q_{\phi}(z \vert y, x)}\left[\log p_{\theta}(y \vert z, x)\right]
     - D_{\text{KL}}\left[q_{\phi}(z \vert y, x) \Vert p_{\theta}(z \vert x)\right] \\
       &= \mathbb{E}_{q_{\phi}(z \vert y, x)}\big[\log p_{\theta}(y \vert z, x) + \log p_{\theta}(z \vert x) - \log q_{\phi}(z \vert y, x)\big] \\
       &\doteq \mathbb{E}\big[\mathcal{L}_{\text{logpy}} +
       \mathcal{L}_{\text{logpz}} - \mathcal{L}_{\text{logqz}} \big].
\end{align*}
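A minimal sketch of the single-sample Monte Carlo estimate in the last line, assuming a batch of encoder outputs (`mu`, `logvar`) and decoder probabilities that are random placeholders here. It uses `torch.distributions` directly rather than AGNFinder's own distribution wrappers, and draws $z$ with the reparametrisation trick (`rsample`) so that gradients could flow back into the encoder.

```python
import torch as t
from torch.distributions import Independent, Multinomial, Normal

t.manual_seed(0)
B, latent_dim, n_classes = 4, 2, 10

# Placeholder encoder outputs for q(z | y, x): per-dimension mean and log-variance.
mu, logvar = t.randn(B, latent_dim), t.randn(B, latent_dim)
q = Independent(Normal(mu, t.exp(0.5 * logvar)), 1)
# Standard Gaussian prior p(z | x) = N(0, I), which ignores x.
p_z = Independent(Normal(t.zeros(B, latent_dim), t.ones(B, latent_dim)), 1)

# Reparametrised sample z = mu + sigma * eps.
z = q.rsample()  # (B, latent_dim)

# Placeholder decoder output for p(y | z, x) and one-hot labels y.
probs = t.softmax(t.randn(B, n_classes), -1)
y = t.nn.functional.one_hot(t.randint(n_classes, (B,)), n_classes).float()

logpy = Multinomial(1, probs=probs).log_prob(y)  # (B,)
logpz = p_z.log_prob(z)                          # (B,)
logqz = q.log_prob(z)                            # (B,)

# Single-sample Monte Carlo ELBO estimate, averaged over the batch.
elbo = (logpy + logpz - logqz).mean()
print(elbo)
```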

Here we use a standard Gaussian prior that is not conditioned on anything (in particular, not on $x$): $p_{\theta}(z \vert x) = \mathcal{N}(z; 0, \mathbf{I})$. We also use a factorised Gaussian (i.e. one with diagonal covariance) for the recognition model $q_{\phi}(z \vert y, x)$.
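Because both $q_{\phi}$ and the prior are Gaussian, the KL term also has a closed form, $\frac{1}{2}\sum_j \left(\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2\right)$. The sketch below (made-up values for `mu` and `logvar`, using `torch.distributions` rather than AGNFinder's classes) checks this against PyTorch's `kl_divergence` and against the sampled estimate $\mathbb{E}_q[\log q - \log p]$ corresponding to the $\mathcal{L}_{\text{logqz}}$ and $\mathcal{L}_{\text{logpz}}$ terms above. Which of the two the CVAE base class uses is not assumed here.

```python
import torch as t
from torch.distributions import Independent, Normal, kl_divergence

t.manual_seed(0)
latent_dim = 2

# Made-up encoder outputs for a single example.
mu = t.tensor([0.5, -1.0])
logvar = t.tensor([-0.2, 0.3])
sigma = t.exp(0.5 * logvar)

q = Independent(Normal(mu, sigma), 1)
p = Independent(Normal(t.zeros(latent_dim), t.ones(latent_dim)), 1)

# Closed form KL[N(mu, diag(sigma^2)) || N(0, I)].
kl_closed = 0.5 * (logvar.exp() + mu.pow(2) - 1.0 - logvar).sum()

# Monte Carlo estimate E_q[log q(z) - log p(z)].
z = q.sample((100_000,))
kl_mc = (q.log_prob(z) - p.log_prob(z)).mean()

print(kl_closed.item(), kl_divergence(q, p).item(), kl_mc.item())
```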

For the generator model $p_{\theta}(y \vert z, x)$, we use a multinomial distribution, which is implemented in `agnfinder.inference.distributions` as a thin wrapper around the native PyTorch `Multinomial` distribution.
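Since each label is one-hot encoded, a multinomial with a single trial reduces to a categorical likelihood: $\log p_{\theta}(y \vert z, x)$ is just the log of the decoder's probability for the true class. A minimal check with PyTorch's native `Multinomial` (made-up probabilities; this does not exercise the AGNFinder wrapper itself):

```python
import torch as t
from torch.distributions import Multinomial

# Class probabilities as a softmax output layer would produce them (made-up values).
probs = t.softmax(t.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]), -1)
labels = t.tensor([0, 2])
y = t.nn.functional.one_hot(labels, 3).float()

# With total_count=1 a one-hot y is a single draw, so the multinomial
# log-likelihood is the log-probability of the observed class.
logpy = Multinomial(total_count=1, probs=probs).log_prob(y)
expected = probs.log().gather(1, labels.unsqueeze(1)).squeeze(1)
print(t.allclose(logpy, expected))  # True
```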

The parameters of this CVAE are fairly similar to those of the first test; however, `cond_dim` is now 784 since we condition on the flattened $28 \times 28$ pixel data to obtain our 'discriminative model'. The `data_dim` is 10, the length of a one-hot encoded digit label.

We use a softmax activation function on the output of the decoder network, since this output parametrises the multinomial distribution: the class probabilities must be non-negative and sum to one.
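As a rough plain-PyTorch picture of the shapes involved, the sketch below builds networks with the same input and output sizes as `enc_arch` and `dec_arch` in the next cell. How `arch_t` actually assembles its layers (for instance, where batch normalisation is placed) is an assumption here, so treat this as illustrative only.

```python
import torch as t
import torch.nn as nn

cond_dim, data_dim, latent_dim, hidden = 28 * 28, 10, 2, 256

# Encoder q(z | y, x): input is [y, x] (794 values); two heads give the mean
# and log-variance of the factorised Gaussian over z.
enc_body = nn.Sequential(nn.Linear(data_dim + cond_dim, hidden),
                         nn.BatchNorm1d(hidden), nn.SiLU())
enc_mu, enc_logvar = nn.Linear(hidden, latent_dim), nn.Linear(hidden, latent_dim)

# Decoder p(y | z, x): input is [z, x] (786 values); the softmax output gives
# the multinomial class probabilities.
decoder = nn.Sequential(nn.Linear(latent_dim + cond_dim, hidden),
                        nn.BatchNorm1d(hidden), nn.SiLU(),
                        nn.Linear(hidden, data_dim), nn.Softmax(dim=-1))

B = 8
x = t.rand(B, cond_dim)
y = nn.functional.one_hot(t.randint(data_dim, (B,)), data_dim).float()

h = enc_body(t.cat((y, x), -1))
mu, logvar = enc_mu(h), enc_logvar(h)
z = mu + t.exp(0.5 * logvar) * t.randn_like(mu)  # reparametrised sample
probs = decoder(t.cat((z, x), -1))
print(probs.shape, probs.sum(-1))
```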

In [None]:
class MNIST_label_params(CVAEParams):
    cond_dim = 28*28  # x; dimension of MNIST image pixel data
    data_dim = 10  # y; size of one-hot encoded digit labels
    latent_dim = 2  # z

    prior = inf.StandardGaussianPrior
    prior_arch = None

    encoder = inf.FactorisedGaussianEncoder
    enc_arch = arch_t([data_dim + cond_dim, 256], [latent_dim, latent_dim],
                      nn.SiLU(), batch_norm=True)

    decoder = inf.MultinomialDecoder
    dec_arch = arch_t([latent_dim + cond_dim, 256], [data_dim],
                      nn.SiLU(), [nn.Softmax(dim=-1)], batch_norm=True)  # softmax over the 10 class scores

In [None]:
class MNIST_label_cvae(CVAE):

    def preprocess(self, x: t.Tensor, y: t.Tensor) -> tuple[t.Tensor, t.Tensor]:
        if x.dim() > 2:
            x = x.view(-1, 28*28)
        x = x.to(self.device, self.dtype)
        y = _onehot(y, 10).to(self.device, self.dtype)
        return x, y
    
    # trainmodel and ELBO methods kept as in base class
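For reference, the transformation that `preprocess` applies can be reproduced in isolation. This hypothetical standalone version drops the device handling and assumes that `_onehot(y, 10)` produces the same encoding as `torch.nn.functional.one_hot`.

```python
import torch as t
import torch.nn.functional as F

def preprocess(x: t.Tensor, y: t.Tensor,
               dtype: t.dtype = t.float64) -> tuple[t.Tensor, t.Tensor]:
    # Flatten (B, 1, 28, 28) image batches to (B, 784) pixel vectors.
    if x.dim() > 2:
        x = x.view(-1, 28 * 28)
    # One-hot encode the integer labels; F.one_hot stands in for _onehot here.
    return x.to(dtype), F.one_hot(y, num_classes=10).to(dtype)

x, y = preprocess(t.rand(3, 1, 28, 28), t.tensor([7, 0, 3]))
print(x.shape, y.shape)  # torch.Size([3, 784]) torch.Size([3, 10])
```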

In [None]:
f1 = plt.figure(figsize=(20, 5), dpi=200, facecolor=(1, 1., 1., .8))
disp = display.display("", display_id=True)

train_loader, test_loader = _load_mnist()
cvae = MNIST_label_cvae(MNIST_label_params(), device=device, dtype=dtype, 
                        logging_callbacks=[LogVisual(f1, disp, test_loader)])

cvae.trainmodel(train_loader, epochs=1)
plt.close()