# VAE for MNIST

In [None]:
from AEML import AEML
import torch
from torch import nn
from VAEBase import VAEBase
import torchvision as tv

First, download our data to `./data`:

In [None]:
# Shuffled mini-batches of 128 MNIST images, loaded by 8 worker processes.
# download=True lets torchvision fetch the dataset into ./data, and
# ToTensor() converts each image to a torch.Tensor.
data = torch.utils.data.DataLoader(
    dataset=tv.datasets.MNIST(
        root='./data',
        transform=tv.transforms.ToTensor(),
        download=True,
    ),
    batch_size=128,
    num_workers=8,
    shuffle=True,
)

`VAENormal` is a base class of all our variational autoencoders. Write your code between the three pairs of `YOUR CODE HERE` comments.

- `VAENormal.forward()` receives an input image `x`, encodes it with `self.enc_mu()` and `self.enc_logsigma()`, draws a latent sample `z` from $q(z|x)$, and returns the result of `self.dec(z)`.
- `BernoulliLoss.__call__()` returns the scalar loss for the generated Bernoulli parameter array `xz` and the input image `x`. This is the negative ELBO: the Bernoulli reconstruction loss plus $\mathrm{KL}(q(z|x)\,\|\,p(z))$.

Both methods actually receive a *batch* of images in the form of a `torch.Tensor` whose first dimension runs over the instances. Except when reducing the loss to its final scalar value, your code should hardly need to care about this; it should mostly read as if it received individual images.

The same solution should work for all exercises.

Finally, the `VAE` class further down must be implemented with a suitable network architecture.

In [None]:
class VAENormal(VAEBase):
    """VAE with Gaussian posterior q(z|x) = N(mu(x), sigma(x)^2) and prior p(z) = N(0, I).

    Args:
        enc_mu: module mapping a flattened batch of images to posterior means,
            one row per instance.
        enc_logsigma: module mapping the same input to posterior log standard
            deviations (same shape as the output of ``enc_mu``).
        dec: module mapping latent codes ``z`` to the decoder output that
            ``BernoulliLoss`` treats as Bernoulli parameters.
    """

    def __init__(self, enc_mu, enc_logsigma, dec):
        super(VAENormal, self).__init__()
        # Standard normal: the prior p(z), and also the noise source for the
        # reparameterized samples from q(z|x) drawn in forward().
        self.distrib_pz = torch.distributions.Normal(0, 1)
        self.enc_mu = enc_mu
        self.enc_logsigma = enc_logsigma
        self.dec = dec
        # Posterior parameters of the most recent forward() call. The loss
        # reads them to compute the KL term for the same batch.
        self.mu = None
        self.logsigma = None

    def forward(self, x):
        """Encode ``x``, draw one reparameterized latent sample, and decode it.

        Args:
            x: batch of images; the first dimension runs over instances.

        Returns:
            The result of ``self.dec(z)`` for the sampled latent codes ``z``.
        """
        x = nn.Flatten()(x)
        ### BEGIN YOUR CODE HERE
        self.mu = self.enc_mu(x)
        self.logsigma = self.enc_logsigma(x)
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so gradients flow through mu and logsigma but not through the
        # sampling. .to(self.mu) matches the encoder output's device/dtype.
        eps = self.distrib_pz.sample(self.mu.shape).to(self.mu)
        z = self.mu + torch.exp(self.logsigma) * eps
        return self.dec(z)
        ### END YOUR CODE HERE

    def encode_mu(self, x):
        """Return the posterior means for a batch of images (no sampling)."""
        return self.enc_mu(nn.Flatten()(x))


class BernoulliLoss: # p(x|z) is Bernoulli; q(z|x) and p(z) are Normal
    """Negative ELBO for a VAE with Bernoulli likelihood and Gaussian latents.

    Args:
        vae: the ``VAENormal`` whose most recent ``forward()`` produced the
            ``xz`` passed to ``__call__``; its ``mu`` / ``logsigma`` attributes
            supply the KL term.
    """

    def __init__(self, vae):
        self.vae = vae
        # Per-element BCE, so pixels can be summed per instance before the
        # loss is averaged over the batch.
        self.BCE = nn.BCELoss(reduction='none')

    def __call__(self, xz, x):
        """Return the batch-mean negative ELBO as a scalar tensor.

        Args:
            xz: Bernoulli parameters produced by the decoder (values in [0, 1]).
            x: the input images, with the same number of elements per instance
                as ``xz``.

        Raises:
            RuntimeError: if the VAE has not run ``forward()`` yet, so no
                posterior parameters are available.
        """
        ### BEGIN YOUR CODE HERE
        mu = getattr(self.vae, 'mu', None)
        logsigma = getattr(self.vae, 'logsigma', None)
        if mu is None or logsigma is None:
            raise RuntimeError(
                'BernoulliLoss needs the VAE to run forward() before the loss')
        # Flatten both arrays so the loss works whether the decoder emits flat
        # vectors or image-shaped tensors.
        xz = torch.flatten(xz, start_dim=1)
        x = torch.flatten(x, start_dim=1)
        # -log p(x|z): per-pixel BCE summed over each instance.
        reconstruction = self.BCE(xz, x).sum(dim=1)
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over latent dims:
        # 0.5 * (mu^2 + sigma^2 - 1 - log sigma^2), where log sigma^2 = 2*logsigma.
        kl = 0.5 * (mu.pow(2) + torch.exp(2 * logsigma) - 1 - 2 * logsigma).sum(dim=1)
        return (reconstruction + kl).mean()
        ### END YOUR CODE HERE

Define your VAE encoder and decoder networks, and generate an instance.

If you want to use a feature extractor that is common to both the $\mu$ and $\log\sigma$ parts of the encoder, you can avoid altering the `VAENormal` base class by overriding its `forward()` (calling `super().forward(self.enc_...(...))`) and `encode_mu()` methods.

In [None]:
zdim = 2


class VAE(VAENormal):
    """Fully connected VAE for flattened 28x28 MNIST images.

    Each encoder head (mean and log standard deviation) is an independent
    one-hidden-layer MLP. The decoder mirrors it and ends in a sigmoid, so its
    output is a valid array of Bernoulli parameters for ``BernoulliLoss``.

    Args:
        xdim: number of pixels per flattened input image.
        hdim: width of the hidden layers.
        zdim: latent dimensionality (defaults to the notebook-level ``zdim``).
    """
    ### BEGIN YOUR CODE HERE
    def __init__(self, xdim=28 * 28, hdim=400, zdim=zdim):
        super().__init__(
            enc_mu=nn.Sequential(
                nn.Linear(xdim, hdim), nn.ReLU(), nn.Linear(hdim, zdim)),
            enc_logsigma=nn.Sequential(
                nn.Linear(xdim, hdim), nn.ReLU(), nn.Linear(hdim, zdim)),
            # NOTE(review): the decoder emits flat xdim-vectors; confirm that
            # VAEBase's plotting helpers reshape them to 28x28 images.
            dec=nn.Sequential(
                nn.Linear(zdim, hdim), nn.ReLU(),
                nn.Linear(hdim, xdim), nn.Sigmoid()),
        )
    ### END YOUR CODE HERE


model = VAE()
# Uncomment to presumably restore weights written by model.save('VAE-MNIST').
# model.load('VAE-MNIST')
ml = AEML(data, model, BernoulliLoss(model),
          torch.optim.Adam(model.parameters(), lr=1e-3))

Now train the model. Call this cell repeatedly if you like.

In [None]:
# Train the model through the AEML helper (defined outside this notebook).
# NOTE(review): the argument presumably is the number of epochs — confirm in AEML.
ml.run(10)
# Uncomment to persist the trained weights; save() is inherited, not defined here.
#model.save('VAE-MNIST')

In [None]:
# Uncomment to visualize previously saved weights instead of the current ones.
#model.load('VAE-MNIST')
# Plotting helpers from AEML / VAEBase (not shown here). Presumably they show
# the encoded dataset in latent space, decodings of random latent samples,
# and decodings of a grid over the zdim-dimensional latent space — confirm.
ml.plotEncDataset()
model.plotDecRandom(zdim)
model.plotDecGrid(zdim)