**Warning: Work in progress.** Please keep in mind that the following is preliminary. Over the coming weeks, we will keep refining the current results, adding further results, and cleaning up the code and presentation up until the camera-ready date, so please check back later. Should the paper be accepted, we will include this material in a revision.


# Neural Latent Variable models for image synthesis and in-painting

In this example we consider the problem of training neural latent variable models for image synthesis and in-painting tasks. 

## Dataset description

Our datasets are comprised of $M$ images $y = (y^{m})_{m=1}^M$. We consider two datasets:

- MNIST containing $70,000$ $d_y:=28\times 28$ images of hand-written digits: http://yann.lecun.com/exdb/mnist/

- CelebA containing $202,599$ $d_y:=32\times 32$ images of faces of celebrities: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html

In either case, we do not use the entire dataset but a random subset of it. In what follows, $M$ denotes the size of this training set. Furthermore, all images' pixel values are normalized so that they lie in $[0,1]$.
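
As a rough illustration of this preprocessing, the sketch below draws a random subset of $M$ images and rescales their pixel values to $[0,1]$. The array `raw_images`, the helper name `subsample_and_normalize`, and the assumption that raw pixels are 8-bit integers in $\{0,\dots,255\}$ are ours and purely illustrative; the notebook's own data loader may proceed differently.

```python
import numpy as np

def subsample_and_normalize(raw_images, M, seed=0):
    """Pick M images uniformly at random (without replacement) and map pixels to [0, 1]."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(raw_images.shape[0], size=M, replace=False)
    # Assumption: 8-bit pixel intensities in {0, ..., 255}.
    return raw_images[idx].astype(np.float64) / 255.0

# Stand-in data: 50 synthetic 28x28 "images" (not a real dataset).
raw_images = np.random.default_rng(1).integers(0, 256, size=(50, 28, 28), dtype=np.uint8)
y = subsample_and_normalize(raw_images, M=10)
```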

## Model description

The model assumes that each image $y^m$ was generated independently of the others by:

1. drawing a latent variable $x^m$ from a zero-mean unit-variance Gaussian distribution $p(x):=\mathcal{N}(x;0,I_{d_x})$ on a low dimensional latent space ($\mathbb{R}^{d_x}$ with $d_x$ ranging from $5$ to no more than $100$);
2. mapping $x^m$ to the image space via a neural network $f_\theta$ parametrized by some parameters $\theta$ in $\mathbb{R}^{D_\theta}$;
3. adding zero-mean $\sigma^2$-variance Gaussian noise: $y^m=f_\theta(x^m)+\epsilon^m$ where $(\epsilon^m)_{m=1}^M$ is an i.i.d. sequence with law $\mathcal{N}(0,\sigma^2 I_{d_y})$.
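
The three generative steps above can be sketched in a few lines of NumPy. The decoder used here, `f`, is a toy stand-in (a random linear map followed by $\tanh$) rather than the convolutional network $f_\theta$, and the sizes and the function name `sample_images` are hypothetical choices made only for illustration.

```python
import numpy as np

def sample_images(f, d_x, d_y, sigma, M, rng):
    """Draw M images by following the three generative steps."""
    x = rng.standard_normal((M, d_x))             # step 1: x^m ~ N(0, I)
    mean = f(x)                                   # step 2: map latents to image space
    eps = sigma * rng.standard_normal((M, d_y))   # step 3: noise with variance sigma^2
    return x, mean + eps

rng = np.random.default_rng(0)
d_x, d_y, sigma = 5, 28 * 28, 0.15
A = rng.standard_normal((d_x, d_y)) / np.sqrt(d_x)  # hypothetical decoder weights
f = lambda x: np.tanh(x @ A)                        # toy stand-in for f_theta
x, y = sample_images(f, d_x, d_y, sigma, M=4, rng=rng)
```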

In full, the model's density is given by
$$
p_\theta (x,y) = \prod_{m=1}^M p_\theta(x^{m}, y^{m})\qquad\qquad(1)$$
where
$$
p_\theta(x^m,y^m)= p_\theta(y^m|x^m)p(x^m),\quad\textrm{with}\quad p_\theta(y^m|x^m) := \mathcal{N}(y^m|f_\theta(x^m), \sigma^2 I_{d_y}).
$$
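
For concreteness, the log of the joint density above can be evaluated as follows. This is a minimal sketch under our own conventions: latents and images are stored row-wise as arrays of shape `(M, d_x)` and `(M, d_y)`, and `f` is any callable standing in for $f_\theta$ (here a toy map, not the actual network).

```python
import numpy as np

def log_joint(x, y, f, sigma2):
    """Sum over m of log p_theta(x^m, y^m) for x of shape (M, d_x), y of shape (M, d_y)."""
    d_x, d_y = x.shape[1], y.shape[1]
    # log N(x; 0, I_{d_x})
    log_prior = -0.5 * np.sum(x**2, axis=1) - 0.5 * d_x * np.log(2 * np.pi)
    # log N(y | f(x), sigma^2 I_{d_y})
    resid = y - f(x)
    log_lik = -0.5 * np.sum(resid**2, axis=1) / sigma2 - 0.5 * d_y * np.log(2 * np.pi * sigma2)
    return float(np.sum(log_prior + log_lik))

A = np.ones((2, 4))                  # hypothetical decoder weights
f = lambda x: np.tanh(x @ A)         # toy stand-in for f_theta
x = np.zeros((3, 2))
y = f(x)                             # zero residual, so only normalizing constants remain
value = log_joint(x, y, f, sigma2=1.0)
```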

For $f_\theta$ we use a convolutional neural network with an architecture emulating that used in \[[1](https://link.springer.com/chapter/10.1007/978-3-030-58539-6_22)\], see below for details. In total, it has $3068$ parameters ($D_\theta=3068$).

### Network architecture

The neural network is composed of $4$ basic types of layers:

*   $l_\theta$: fully-connected linear layers,
*   $c_\theta$: convolutional layers,
*   $c_\theta^T$: transpose convolutional layers,
*   $b_\theta$: batch normalization layers.

These are interleaved with GELU activation functions $\phi$. In particular, they are assembled to create $2$ further types of layers:

*   'projection' layers $\pi_\theta:=\phi \circ b_\theta\circ c_\theta \circ \phi\circ b_\theta\circ l_\theta$;
*   'deterministic' layers $d_\theta:=\phi \circ b_\theta \circ c_\theta \circ \phi\circ b_\theta \circ c_\theta + I$ where $I$ denotes the identity operator (in other words, the layer has a skip connection).

The network itself then consists of a projection layer, followed by two deterministic layers, followed by a transpose convolutional layer and a $\tanh$ activation:

$$f_\theta =  \tanh\circ c_\theta^T\circ d_\theta \circ d_\theta \circ \pi_\theta$$
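
To make the order of composition and the skip connection concrete, here is a structural sketch in NumPy. It is not the network in `model.py`: random dense maps replace the linear, convolutional, and transpose convolutional layers, a parameter-free normalization replaces batch normalization, and all sizes and helper names (`dense`, `compose`, `with_skip`) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def gelu(z):
    """GELU activation (tanh approximation)."""
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))

def dense(n_in, n_out):
    """Random linear map; a stand-in for l_theta, c_theta and c_theta^T alike."""
    W = rng.standard_normal((n_in, n_out)) / np.sqrt(n_in)
    return lambda z: z @ W

def batch_norm(z):
    """Parameter-free normalization over the batch axis; a stand-in for b_theta."""
    return (z - z.mean(axis=0)) / (z.std(axis=0) + 1e-5)

def compose(*layers):
    """compose(g, h)(z) = g(h(z)): the rightmost layer is applied first."""
    def composed(z):
        for layer in reversed(layers):
            z = layer(z)
        return z
    return composed

def with_skip(inner):
    """Return z -> inner(z) + z, i.e. 'inner + I'."""
    return lambda z: inner(z) + z

d_x, width, d_y = 4, 8, 16  # hypothetical sizes
projection = compose(gelu, batch_norm, dense(width, width),
                     gelu, batch_norm, dense(d_x, width))

def deterministic():
    return with_skip(compose(gelu, batch_norm, dense(width, width),
                             gelu, batch_norm, dense(width, width)))

f = compose(np.tanh, dense(width, d_y), deterministic(), deterministic(), projection)
images = f(rng.standard_normal((3, d_x)))
```

Because the final activation is $\tanh$, the outputs of this sketch lie in $[-1,1]$.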

For more details, please see the code in [model.py]().

## Model training

Training the model entails searching for parameters $\theta_*$ maximizing the marginal likelihood $\theta\mapsto p_\theta(y):=\int p_\theta(x,y)dx$ (or, at least, for a local maximum thereof). To do so, we use PGA, slightly modified to better cope with the high cost of evaluating the gradients of the log-likelihood $\ell(\theta,x):=\log(p_\theta(x,y))$. In particular, in the $\theta$-update we replace $\nabla_{\theta} \ell(\theta,x)$ with an unbiased estimator thereof obtained by subsampling the training set:

\begin{align*}
\nabla_{\theta} \ell(\theta,x)&=\sum_{m=1}^M \nabla_\theta\log(p_\theta(x^m,y^m))=M\left[\frac{1}{M}\sum_{m=1}^M \nabla_\theta\log(p_\theta(x^m,y^m))\right]\\
&\approx M\left[\frac{1}{|\mathcal{B}|}\sum_{m\in\mathcal{B}}\nabla_\theta\log(p_\theta(x^m,y^m))\right]=\frac{M}{|\mathcal{B}|}\sum_{m\in\mathcal{B}}\nabla_\theta\log(p_\theta(x^m,y^m)),
\end{align*}

where $\mathcal{B}$ denotes a random subset of $[M]:=\{1,\dots, M\}$ and $|\mathcal{B}|$ its cardinality. To mitigate the varying magnitudes among the entries of $\nabla_\theta\log(p_\theta(x^m,y^m))$ and improve the learning, we use a modified version of the 'heuristic fix' discussed in Section 2.1 of the manuscript: we rescale each entry by a scalar, only this time we allow the scalars to vary with the iteration count, choosing them as in RMSprop \[[2]()\]. In full, we update the parameter estimates $\theta_k$ using

\begin{align*}
    \theta_{k+1} &= \theta_k + \frac{h}{N}\sum_{n=1}^N \sum_{m\in\mathcal{B}_k} \lambda\Lambda_k\nabla_\theta \log p_{\theta_k}
(X^{n,m}_k, y^{m}),
\end{align*}

where $(X^n_k)_{n=1}^N=((X^{n,m}_k)_{m=1}^M)_{n=1}^N$ denotes the particle cloud at the $k^{th}$ iteration, $\mathcal{B}_k$ indexes the minibatch used in the $k^{th}$ iteration, $\Lambda_k$ is a diagonal matrix containing the RMSprop step sizes, and $\lambda$ is a scalar we tune by hand to mitigate differences between the scales of the log-likelihood's $\theta$- and $x$-gradients (it ranges between $0.01$ and $1$).
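
The two ingredients of this update can be illustrated on synthetic gradients. The sketch below first checks, by averaging over every minibatch of a fixed size, that the subsampled estimator is unbiased, and then applies one RMSprop-style rescaled ascent step to a gradient estimate. The function names, the decay `beta`, the constant `eps`, and the values of `h` and `lam` are illustrative assumptions; the notebook itself delegates this step to the `PGA` class and a PyTorch optimiser.

```python
import itertools
import numpy as np

def subsampled_grad(per_datum_grads, batch):
    """(M / |B|) * sum_{m in B} g_m, an unbiased estimate of sum_m g_m."""
    M = per_datum_grads.shape[0]
    return (M / len(batch)) * per_datum_grads[list(batch)].sum(axis=0)

def rmsprop_theta_step(theta, grad, v, h=1e-3, lam=0.1, beta=0.9, eps=1e-8):
    """Ascent step theta + h * lam * Lambda @ grad, Lambda = diag(1 / (sqrt(v) + eps))."""
    v = beta * v + (1.0 - beta) * grad**2      # running average of squared gradients
    return theta + h * lam * grad / (np.sqrt(v) + eps), v

rng = np.random.default_rng(0)
g = rng.standard_normal((6, 3))                # synthetic per-datum gradients: M = 6, D_theta = 3
full = g.sum(axis=0)

# Averaging the estimator over all minibatches of size 2 recovers the full gradient.
batches = list(itertools.combinations(range(6), 2))
avg = np.mean([subsampled_grad(g, b) for b in batches], axis=0)

theta, v = rmsprop_theta_step(np.zeros(3), subsampled_grad(g, batches[0]), np.zeros(3))
```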

Because the dimensionality of the latent variables is $30$--$600$ times smaller than that of the parameters, the cost of the particle updates is $30$--$600$ times smaller than that of the $\theta$-updates (without subsampling) and, so, we do not have to subsample the $x$-gradients. In particular, we update the particles just as in standard PGA. Given (1), these updates read

\begin{align*}
X^{n,m}_{k+1}&=X^{n,m}_k + h\nabla_x \log p_{\theta_k}
(X^{n,m}_k, y^{m}) + \sqrt{2h} W^{n,m}_k \quad \forall m\in[M],\enskip n\in [N],
\end{align*}

where the $W^{n,m}_k$ denote independent standard normal random vectors in $\mathbb{R}^{d_x}$.
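
A particle update of this form can be sketched as follows. To keep the gradient in closed form, the sketch assumes a toy linear decoder $f(x)=xA$ in place of $f_\theta$, so that $\nabla_x \log p(x,y) = -x + (y - xA)A^T/\sigma^2$; the array layout (particles stored with shape `(N, M, d_x)`) and all sizes are our own choices.

```python
import numpy as np

def grad_x_log_joint(X, y, A, sigma2):
    """Gradient in x of log p(x, y) for the toy linear decoder f(x) = x @ A."""
    resid = y - X @ A                      # shape (N, M, d_y); y broadcasts over particles
    return -X + (resid @ A.T) / sigma2     # prior term plus likelihood term

def particle_step(X, y, A, sigma2, h, rng):
    """One unadjusted Langevin step applied to every particle and datum."""
    W = rng.standard_normal(X.shape)       # independent standard normal noise
    return X + h * grad_x_log_joint(X, y, A, sigma2) + np.sqrt(2.0 * h) * W

rng = np.random.default_rng(0)
N, M, d_x, d_y = 2, 3, 4, 5                # hypothetical sizes
A = rng.standard_normal((d_x, d_y))
X = rng.standard_normal((N, M, d_x))
y = rng.standard_normal((M, d_y))
X_new = particle_step(X, y, A, sigma2=0.15**2, h=1e-4, rng=rng)
```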



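We stop training early: after each epoch we record the average loss over that epoch's minibatches and, once past the first few epochs, halt as soon as more than $10$ epochs have elapsed since the epoch with the lowest average loss.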
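
The early-stopping rule used in this notebook's training loop can be isolated as a small helper. The function name `should_stop` and its keyword arguments are ours; the thresholds ($10$ epochs of patience, checked only after epoch $2$) mirror those in the loop.

```python
import numpy as np

def should_stop(losses, patience=10, warmup=2):
    """True once more than `patience` epochs have passed since the lowest average loss."""
    epoch = len(losses) - 1                     # index of the epoch just finished
    return epoch > warmup and epoch - int(np.argmin(losses)) > patience

# Best loss at epoch 2, followed by 11 epochs without improvement.
losses = [5.0, 3.0, 2.0] + [2.5] * 11
stop_now = should_stop(losses)
stop_one_epoch_earlier = should_stop(losses[:-1])
```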


## Code links

## Image synthesis



## In-painting

In [None]:
# Declare dictionary-like object for storing config variables:
import argparse
args = argparse.Namespace()

args.seed = 1  # Seed for PRNs

# Data settings
args.n_images = 1000  # M

# Training settings
args.n_epochs = 500  # K
args.n_batch = 128  # M_b
args.n_sampler_batch = 750
args.early_stopping = True  # Turn on early stopping

# Model Settings
args.x_dim = 64  # D_x
args.theta_opt = 'rmsprop'  # Lambda premultiplying matrix
args.likelihood_var = 0.15 ** 2  # \sigma^2

# EM Settings
args.theta_step_size = 1e-3  # h_\theta
args.q_step_size = 1e-4  # h_q
args.clip_grad = True
args.n_particles = 10  # N

# Import modules

First, we load the modules we need:

In [None]:
%%capture
# Install missing modules (the %%capture cell magic must be the first line of the cell)
!pip install torchtyping
!pip install functorch

In [None]:
# Import standard modules
import torch
import numpy as np
import sys

In [None]:
# Import custom modules
!rm -rf ParEM_VAE
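# Credentials are read from the environment variables GITHUB_USER and GITHUB_TOKEN
# (names assumed here); never hard-code an access token in a notebook.
!git clone https://$GITHUB_USER:$GITHUB_TOKEN@github.com/ParticleEM/ParEM_VAE.git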
sys.path.append("/content/ParEM_VAE/")
from parem.model import G
from parem.pga import PGA, optimisers

Cloning into 'ParEM_VAE'...
remote: Enumerating objects: 279, done.
remote: Counting objects: 100% (49/49), done.
remote: Compressing objects: 100% (31/31), done.
remote: Total 279 (delta 20), reused 44 (delta 18), pack-reused 230
Receiving objects: 100% (279/279), 41.77 KiB | 4.64 MiB/s, done.
Resolving deltas: 100% (128/128), done.


# Set paths

In [None]:
# Mounts drive to VM in colab.
from pathlib import Path
from google.colab import drive
drive.mount("/content/gdrive", force_remount=False)

# Path where dataset will be stored:
GDRIVE_DATASET_PATH = Path("/content/gdrive/MyDrive/data/vae")

# Path where checkpoints will be saved:
CHECKPOINT_DIR = Path("/content/gdrive/MyDrive/particle-em/svhn")
CHECKPOINT_DIR.mkdir(exist_ok=True, parents=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

Mounted at /content/gdrive


# Load dataset

In [None]:
#@title Load dataset
from parem.svhn import get_svhn

dataset = get_svhn(GDRIVE_DATASET_PATH, args.n_images)

  warn(f"Failed to load image Python extension: {e}")


Using downloaded and verified file: /content/gdrive/MyDrive/data/vae/train_32x32.mat


In [None]:
#@title Divvy up dataset in batches for training.

train = torch.utils.data.DataLoader(dataset, batch_size=args.n_batch, shuffle=True, pin_memory=True)
larger_batch_train = torch.utils.data.DataLoader(dataset, batch_size=args.n_sampler_batch, shuffle=True, pin_memory=True)

In [None]:
model = G(args.x_dim, sigma2=args.likelihood_var).to(DEVICE)
pga = PGA(model,
          args.n_images,
          larger_batch_train,
          device=DEVICE,
          theta_step_size=args.theta_step_size,
          q_step_size=args.q_step_size,
          n_particles=args.n_particles,
          clip_grad=args.clip_grad,
          theta_opt=args.theta_opt,
          )

In [None]:
%%capture
# Import modules necessary for the training loop
!pip install wandb
import wandb
import pickle
from torchvision.utils import make_grid
import time

In [None]:
#@title Main training loop
to_range_0_1 = lambda x: (x + 1.) / 2.

wandb.login()
wandb.init(
    project="particle-em-svhm",
    config = vars(args),
)

wandb.watch(model, log="all", log_freq=10)

losses = []
for epoch in range(args.n_epochs):
  model.train()
  avg_loss = 0
  start = time.time()
  for imgs, idx in train:
      imgs = imgs.to(device=DEVICE)
      loss = pga.step(imgs, idx)
      avg_loss += loss
      print(".", end='')
  end = time.time()
  avg_loss = avg_loss / len(train)

  losses.append(avg_loss)
  if epoch > 2 and args.early_stopping:
    if epoch - np.argmin(losses) > 10:
      print("Early Stop")
      break

  print(f"Epoch {epoch}: {end - start:2f}: Loss {avg_loss}")

  # Save model
  torch.save(model.state_dict(), CHECKPOINT_DIR / f"{epoch}_model")
  torch.save(pga.theta_opt.state_dict(), CHECKPOINT_DIR / f"{epoch}_opt")
  with open(CHECKPOINT_DIR / f"{epoch}_particles", 'wb') as f:
    pickle.dump(pga._particles, f)
  

  with torch.no_grad():
    model.eval()
    torch.random.manual_seed(1)
    sample = to_range_0_1(model(torch.randn(10, args.x_dim,1,1).to(DEVICE)))
    grid = make_grid(sample)
    samples = wandb.Image(grid)

    original_img = to_range_0_1(train.dataset[0][0].unsqueeze(0))
    particle_img = to_range_0_1(model(pga._particles[0, :10].to(DEVICE))).to(original_img.device)
    grid = make_grid(torch.concat([original_img, particle_img], dim=0))
    particles = wandb.Image(grid)

    mse_n_samples = 100
    mse_n_particles = 10
    original_img = to_range_0_1(dataset[:mse_n_samples][0].unsqueeze(1))
    particle_img = to_range_0_1(model(pga._particles[:mse_n_samples, :mse_n_particles].to(DEVICE))).to(original_img.device)
    assert original_img.shape == torch.Size([mse_n_samples, 1, 3, 32, 32])
    assert particle_img.shape == torch.Size([mse_n_samples, mse_n_particles, 3, 32, 32])
    mse = (((particle_img - original_img) ** 2).sum([-1, -2, -3]).mean()).item()
  wandb.log({'particles': particles,
             'samples': samples,
             "loss" : avg_loss,
             'mse': mse,
             'theta_step_size' : pga.theta_opt.param_groups[0]['lr'],
             })
wandb.finish()

........Epoch 0: 4.206381: Loss 46709947.875
........Epoch 1: 3.946499: Loss 19369643.375
........Epoch 2: 3.942541: Loss 11710681.625
........Epoch 3: 3.935411: Loss 9691469.5625
........Epoch 4: 3.942672: Loss 7203934.9375
........Epoch 5: 3.921455: Loss 5638473.375
........Epoch 6: 3.942654: Loss 5491085.0625
........Epoch 7: 3.936682: Loss 4494481.84375
........Epoch 8: 3.943934: Loss 4375911.15625
........Epoch 9: 3.939325: Loss 4216165.15625
........Epoch 10: 3.919437: Loss 3672923.21875
........Epoch 11: 3.936026: Loss 3611881.6875
........Epoch 12: 3.936669: Loss 3467806.5625
........Epoch 13: 3.933230: Loss 3397499.40625
........Epoch 14: 3.941407: Loss 3179980.125
........Epoch 15: 3.920175: Loss 3127742.96875
........Epoch 16: 3.935094: Loss 3063833.71875
........Epoch 17: 3.937531: Loss 2834152.96875
........Epoch 18: 3.932847: Loss 2747621.65625
........Epoch 19: 3.982827: Loss 2754412.90625
........Epoch 20: 3.914003: Loss 2764356.25
........Epoch 21: 3.939615: Loss 28178

KeyboardInterrupt: ignored