<a href="https://colab.research.google.com/github/ParticleEM/ParEM_neural_latent_variable_model/blob/master/notebooks/MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Description

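This notebook trains a latent-variable generative model on MNIST with particle gradient ascent (PGA), a particle-based EM-style algorithm. Each iteration combines two steps: an optimiser step on the model parameters $\theta$, and a Langevin step on a cloud of $N$ latent particles per image. Checkpoints, sample grids and inpainting results are written to Google Drive and logged to Weights & Biases.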

### Dataset description:

The dataset consists of $M$ MNIST images $y = (y^{m})_{m=1}^M$.

### Model description:

$$
p_\theta (x,y) = \prod_{m=1}^M p_\theta(x^{m}, y^{m})
$$
where
$$
p_\theta(x,y)= p_\theta(y|x)p(x)
$$
with
$$
p_\theta(y|x) = \mathcal{N}(y|\mu_\theta(x), \sigma^2 I),
$$
where $\mu_\theta(\cdot)$ is a neural network parameterised by $\theta$, and $p(x) = \mathcal{N}(x|0,I)$. 

The neural net consists of 

$$\mu_\theta =  \tanh\circ c_\theta\circ d_\theta \circ d_\theta \circ proj_\theta$$

where


*   $\phi$ is a GELU activation function.
*   $c_\theta$ is a transpose convolutional layer.
*   $proj_\theta=\phi \circ c_\theta \circ \phi\circ l_\theta$. It maps $\mathbb{R}^{D_x}$ to $\mathbb{R}^{4\times 16\times 16}$, where $D_x$ is the dimension of the latent variable.
*   $l_\theta$ is a linear layer.
*   $d_\theta=\phi \circ conv_\theta \circ \phi\circ conv_\theta + I$ is a residual block: the identity $I$ is a skip connection that adds the block's input to its output.
*   $conv_\theta$ is a convolutional layer.


### Algorithm description:

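For $k=1,\dots,K$:
\begin{align*}
    \theta_{k+1} &= \theta_k + \frac{h_\theta}{N}\sum_{n=1}^N \sum_{m\in\mathcal{I}_k} \nabla_\theta \log p_{\theta_k}(X^{n,m}_k, y^{m}), \\
    X^{n,m}_{k+1} &= X^{n,m}_k + h_q\nabla_x \log p_{\theta_k}(X^{n,m}_k, y^{m}) + \sqrt{2h_q}\, W^{n,m}_k \quad \forall\, m = 1, \dots, M,\ n = 1, \dots, N,
\end{align*}

where:

*   $\mathcal{I}_k$ is a random subset of $M_b$ indices from $\{1, \dots, M\}$.
*   The $W^{n,m}_k$ are i.i.d. standard Gaussian vectors.
*   $h_\theta$ and $h_q$ are the step sizes for $\theta$ and for the particles, respectively.

In the code, the $\theta$-update is premultiplied by the optimiser selected in `args.theta_opt` (RMSprop by default) instead of being a plain gradient step.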


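Stopping criterion: training runs for at most $K$ epochs. If `args.early_stopping` is set, training stops early once more than 20 epochs have passed since the epoch with the lowest average loss.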


# Import modules

First, we set the configuration, then install and load the modules we need:

In [None]:
# Declare a dictionary-like object for storing config variables:
import argparse
args = argparse.Namespace()
args.seed = 1  # Seed for the pseudo-random number generators (applied below)

# Data settings
args.n_images = 10000  # M

# Training settings
args.n_epochs = 500  # K
args.n_batch = 128  # M_b: mini-batch size for the theta update
args.n_sampler_batch = 750  # batch size used when PGA sweeps over all particles
args.early_stopping = True  # Turn on early stopping

# Model settings
args.x_dim = 10  # D_x
args.theta_opt = 'rmsprop'  # Lambda premultiplying matrix
args.likelihood_var = 0.3 ** 2  # \sigma^2

# EM settings
args.theta_step_size = 1e-3  # h_\theta
args.q_step_size = 5e-5  # h_q
args.clip_grad = False  # clip the theta-gradient norm (to 100) if True
args.n_particles = 10  # N

# Synthesis settings
args.corrupt_std = 1  # not read by any cell shown here; kept in the wandb config

# Apply args.seed to every PRNG the notebook relies on (Python, NumPy, torch),
# so that particle initialisation, DataLoader shuffling and Langevin noise are
# reproducible across runs.
import random
import numpy as np
import torch
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

In [None]:
# Install missing modules
%%capture
!pip install torchtyping

In [None]:
# Import standard modules
import torch
import numpy as np
import sys
import matplotlib.pyplot as plt

In [None]:
# Import custom modules
!rm -rf ParEM_VAE
!git clone https://pareem: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@github.com/ParticleEM/ParEM_VAE.git
sys.path.append("/content/ParEM_VAE/")
from parem.model import G
from parem.pga import PGA, optimisers

Cloning into 'ParEM_VAE'...
remote: Enumerating objects: 291, done.[K
remote: Counting objects: 100% (61/61), done.[K
remote: Compressing objects: 100% (38/38), done.[K
remote: Total 291 (delta 26), reused 56 (delta 23), pack-reused 230[K
Receiving objects: 100% (291/291), 42.93 KiB | 4.77 MiB/s, done.
Resolving deltas: 100% (134/134), done.


# Set paths

In [None]:
# Mount Google Drive into the Colab VM so that outputs are written to Drive.
from pathlib import Path
from google.colab import drive

drive.mount("/content/gdrive", force_remount=False)

# Root directory for checkpoints, particle snapshots and figures of every run.
CHECKPOINT_DIR = Path("/content/gdrive") / "MyDrive" / "particle-em" / "mnist"
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


# Load dataset

In [None]:
#@title Load dataset
from parem.mnist import get_mnist

# Project helper: given a data directory and the number of images M to use.
MNIST_ROOT = '/content/mnist'
dataset = get_mnist(MNIST_ROOT, args.n_images)



In [None]:
#@title Divvy up dataset in batches for training.

# Both loaders reshuffle every epoch and pin host memory.
loader_kwargs = dict(shuffle=True, pin_memory=True)
# Mini-batches of M_b images drive the theta update ...
train = torch.utils.data.DataLoader(dataset, batch_size=args.n_batch, **loader_kwargs)
# ... while larger batches are used by PGA to sweep over all particles.
larger_batch_train = torch.utils.data.DataLoader(dataset, batch_size=args.n_sampler_batch, **loader_kwargs)

In [None]:
import torch
from torchtyping import TensorType
from torch.nn.utils import clip_grad_norm_

# Optimisers that PGA can build for theta from a string name.
optimisers = {'sgd': torch.optim.SGD,
              'adagrad': torch.optim.Adagrad,
              'rmsprop': torch.optim.RMSprop,
              }


class PGA:
    """Particle gradient ascent (PGA) for a latent-variable model.

    Each call to :meth:`step` does two things, both evaluated at the current
    theta:
    - an optimiser step on the model parameters theta;
    - an unadjusted Langevin step on every latent particle.

    NOTE: this definition shadows the ``PGA``/``optimisers`` imported from
    ``parem.pga`` in an earlier cell.
    """

    def __init__(self,
                 model,
                 n_images: int,
                 dl,
                 q_step_size: float = 1e-2,
                 theta_step_size: float = 1e-2,
                 n_particles: int = 30,
                 device="cpu",
                 clip_grad=False,
                 theta_opt='sgd',):
        """
        Args:
            model: generative model exposing ``init_x``, ``log_p_v`` and
                ``parameters``.
            n_images: number of training images M (one particle set each).
            dl: iterable of ``(images, indices)`` batches used for the particle
                update; ``indices`` select rows of the particle tensor.
            q_step_size: particle (Langevin) step size h_q.
            theta_step_size: learning rate for theta; only used when
                ``theta_opt`` is a string.
            n_particles: number of particles N per image.
            device: device on which the particle tensor is stored.
            clip_grad: if True, clip the theta-gradient norm to 100.
            theta_opt: a key of ``optimisers`` or a ``torch.optim.Optimizer``.

        Raises:
            KeyError: ``theta_opt`` is a string that is not in ``optimisers``.
            TypeError: ``theta_opt`` is neither a string nor an optimiser.
        """
        self.n_particles = n_particles
        self.model = model
        self.q_step_size = q_step_size
        self.device = device
        self.clip_grad = clip_grad
        self.dl = dl
        self.n_images = n_images

        # Initialise n_particles latent particles for each of the n_images images.
        self._particles = model.init_x([n_images, n_particles],
                                       device=self.device)

        # Declare theta optimiser: build it by name or use the one supplied.
        if isinstance(theta_opt, str):
            self.theta_opt = optimisers[theta_opt](model.parameters(),
                                                   lr=theta_step_size)
        elif isinstance(theta_opt, torch.optim.Optimizer):
            self.theta_opt = theta_opt
        else:
            # Fail fast: otherwise self.theta_opt is never set and step()
            # would later die with an unhelpful AttributeError.
            raise TypeError("theta_opt must be a key of `optimisers` or a "
                            f"torch.optim.Optimizer, got {type(theta_opt).__name__}")

    def loss(self,
             images,
             particles: torch.Tensor,
             ) -> torch.Tensor:
        r"""Theta objective for one mini-batch (to be minimised).

        Returns ``-(1 / |images|) * mean(log p_theta(particles, images))``,
        where the mean runs over every entry returned by ``model.log_p_v``.
        ``particles`` holds the particles of the images in the batch, i.e.
        ``self._particles[idx]``. Presumably ``log_p_v`` returns one value per
        (image, particle) pair; confirm against the model.
        """
        log_p = self.model.log_p_v(images, particles)
        assert not log_p.isnan().any(), "log_p is nan."
        return - (1. / images.shape[0]) * log_p.mean()

    def step(self,
             img_batch,
             idx: torch.Tensor):
        """Run one PGA iteration and return the theta loss on ``img_batch``.

        1. Compute the theta gradient on ``img_batch``, whose particles are
           ``self._particles[idx]``.
        2. Take a gradient step on every particle, batch by batch over
           ``self.dl``, with the same (not yet updated) theta.
        3. Add ``sqrt(2 * q_step_size)`` Gaussian noise to every particle.
        4. Apply the theta optimiser step using the gradient from 1.
        """
        device = img_batch.device

        # (1) Theta gradients. Parameter gradients were switched off at the end
        # of the previous call (for the particle update), so switch them back on.
        self.model.train()
        self.theta_opt.zero_grad()
        self.model = self.model.requires_grad_(True)

        loss = self.loss(img_batch, self._particles[idx].to(device))
        loss.backward()

        if self.clip_grad:
            clip_grad_norm_(self.model.parameters(), 100)

        # (2) Particle gradient step, batch by batch so that device memory is not
        # exceeded. Disabling parameter grads means autograd only builds
        # particle gradients; the theta gradients from (1) stay in ``.grad``.
        self.model.eval()
        self.model = self.model.requires_grad_(False)
        for imgs, batch_idx in self.dl:
            sub_particles = (self._particles[batch_idx].detach().clone()
                             .to(device).requires_grad_(True))
            log_p_v = self.model.log_p_v(imgs.to(device), sub_particles).sum()
            x_grad = torch.autograd.grad(log_p_v, sub_particles)[0]
            self._particles[batch_idx] += (self.q_step_size
                                           * x_grad.to(self._particles.device))

        # (3) Langevin noise for all particles.
        self._particles += ((2 * self.q_step_size) ** 0.5
                            * torch.randn_like(self._particles))

        # (4) Update theta.
        self.theta_opt.step()

        return loss.item()


In [None]:
# Generator mu_theta on DEVICE. nc=1 / use_bn=True presumably mean one image
# channel and BatchNorm layers; confirm against parem.model.G.
model = G(args.x_dim, sigma2=args.likelihood_var, nc=1, use_bn=True).to(DEVICE)

# The particle cloud lives on the CPU (device='cpu'); PGA moves each batch of
# particles to the images' device when it needs them.
pga = PGA(model,
          args.n_images,
          larger_batch_train,
          theta_opt=args.theta_opt,
          clip_grad=args.clip_grad,
          n_particles=args.n_particles,
          q_step_size=args.q_step_size,
          theta_step_size=args.theta_step_size,
          device='cpu',
          )

In [None]:
# Import modules necessary for training loop
%%capture
!pip install wandb
import wandb
import pickle
from torchvision.utils import make_grid
import time

In [None]:
#@title Plotting function
import torchvision.transforms.functional as F

def show(imgs):
    """Plot one image tensor, or a list of them, side by side without axes.

    Each tensor is detached and converted with ``F.to_pil_image`` before it is
    drawn. Returns the matplotlib Figure (created at dpi=400).
    """
    img_list = imgs if isinstance(imgs, list) else [imgs]
    fig, axs = plt.subplots(ncols=len(img_list), squeeze=False, dpi=400)
    for ax, img in zip(axs[0], img_list):
        pil_img = F.to_pil_image(img.detach())
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    return fig

In [None]:
#@title Main training loop
def to_range_0_1(x):
    """Map tensors from [-1, 1] (the tanh output range of mu_theta) to [0, 1]."""
    return (x + 1.) / 2.


def _run_dir(*parts):
    """Return ``CHECKPOINT_DIR/<wandb run name>/<parts...>``, creating it if needed."""
    path = CHECKPOINT_DIR.joinpath(wandb.run.name, *parts)
    path.mkdir(exist_ok=True, parents=True)
    return path


def _save_checkpoint(epoch):
    """Save the generator weights and the particle cloud for ``epoch``."""
    torch.save(model.state_dict(), _run_dir("model") / f"{epoch}_model")
    # NOTE: pickle files must only ever be loaded from trusted locations.
    with open(_run_dir("particles") / f"{epoch}_particles", "wb") as f:
        pickle.dump(pga._particles, f)


@torch.no_grad()
def _sample_grid(epoch, n_rows=8, n_cols=8):
    """Decode latents drawn from a Gaussian fitted to all particles.

    The image grid is saved to ``grid/<epoch>_samples.png`` and returned as a
    ``wandb.Image``.
    """
    # Fit N(mean, cov) to the particles pooled over images and particles. The
    # reductions assume particles are shaped (n_images, n_particles, x_dim, 1, 1).
    mean = torch.mean(pga._particles, [0, 1, 3, 4])
    particles_2d = pga._particles.flatten(0, 1).flatten(1, 3)
    cov = torch.cov(particles_2d.transpose(0, 1))
    normal_approx = torch.distributions.multivariate_normal.MultivariateNormal(
        loc=mean, covariance_matrix=cov)
    z = normal_approx.sample(sample_shape=torch.Size([n_cols * n_rows]))
    z = z.unsqueeze(-1).unsqueeze(-1)
    grid = make_grid(to_range_0_1(model(z.to(DEVICE))))
    fig = show(grid)
    fig.savefig(_run_dir("grid") / f"{epoch}_samples.png", bbox_inches="tight")
    plt.close(fig)
    return wandb.Image(grid)


@torch.no_grad()
def _particle_grid(n_show=10):
    """Return image 0 next to decodings of its first ``n_show`` particles (wandb.Image)."""
    original_img = to_range_0_1(train.dataset[0][0].unsqueeze(0))
    particle_img = to_range_0_1(model(pga._particles[0, :n_show].to(DEVICE))).to(original_img.device)
    return wandb.Image(make_grid(torch.concat([original_img, particle_img], dim=0)))


@torch.no_grad()
def _reconstruction_mse(n_samples=100, n_particles=None):
    """Squared reconstruction error of the particles, on the [0, 1] pixel scale.

    Sums the squared error over each image's pixels, then averages over the
    first ``n_samples`` images and their first ``n_particles`` particles
    (default ``args.n_particles``). Returns a Python float.
    """
    if n_particles is None:
        n_particles = args.n_particles
    original_img = to_range_0_1(dataset[:n_samples][0].unsqueeze(1))
    particle_img = to_range_0_1(
        model(pga._particles[:n_samples, :n_particles].contiguous().to(DEVICE))
    ).to(original_img.device)
    assert original_img.shape == torch.Size([n_samples, 1, 1, 32, 32])
    assert particle_img.shape == torch.Size([n_samples, n_particles, 1, 32, 32])
    return ((particle_img - original_img) ** 2).sum([-1, -2, -3]).mean().item()


def _inpaint(epoch, n_imgs=10, n_iters=1000, lr=1e-2, hole=slice(10, 22), seed=1):
    """Fill a square hole in the first ``n_imgs`` images by optimising latents.

    For each image, saves a PNG with three panels (masked input | reconstruction |
    original) under ``impaint/<epoch>/``.
    """
    missing_imgs = dataset[:n_imgs][0]
    # A dedicated generator makes the latent initialisation reproducible without
    # re-seeding the global RNG, which would also reset the DataLoader shuffle
    # order and the Langevin noise in pga.step every epoch.
    gen = torch.Generator().manual_seed(seed)
    init_x = torch.randn(n_imgs, args.x_dim, 1, 1, generator=gen, requires_grad=True)
    opt = torch.optim.Adam([init_x], lr)
    criterion = torch.nn.MSELoss()
    missing_mask = torch.zeros_like(missing_imgs, dtype=torch.bool)
    missing_mask[..., hole, hole] = True

    for _ in range(n_iters):
        opt.zero_grad()
        # model.forward (not model(...)) bypasses nn.Module hooks.
        filled_imgs = model.forward(init_x.to(DEVICE)).to('cpu')
        inpaint_loss = criterion(filled_imgs[~missing_mask], missing_imgs[~missing_mask])
        inpaint_loss.backward()
        opt.step()

    filled_imgs = to_range_0_1(filled_imgs).expand(-1, 3, -1, -1)
    target_imgs = to_range_0_1(missing_imgs).expand(-1, 3, -1, -1)
    masked_input = target_imgs.detach().clone()
    masked_input[missing_mask.expand(-1, 3, -1, -1)] = 0.2  # grey out the hole

    out_dir = _run_dir("impaint", f"{epoch}")
    for i in range(n_imgs):
        grid = make_grid(torch.concat([masked_input[[i]], filled_imgs[[i]], target_imgs[[i]]], dim=0))
        fig = show(grid)
        fig.savefig(out_dir / f"{i}.png", bbox_inches="tight")
        plt.close(fig)


wandb.login()
wandb.init(
    project="particle-em-mnist",
    config=vars(args),
)
wandb.watch(model, log="all", log_freq=10)

EARLY_STOP_PATIENCE = 20  # epochs allowed since the best average loss
INPAINT_EVERY = 5  # run the inpainting demo every this many epochs

losses = []
for epoch in range(args.n_epochs):
    model.train()
    start = time.time()
    total_loss = 0.
    for imgs, idx in train:
        total_loss += pga.step(imgs.to(device=DEVICE), idx)
        print(".", end='')
    elapsed = time.time() - start
    avg_loss = total_loss / len(train)
    losses.append(avg_loss)
    print(f"Epoch {epoch}: {elapsed:.2f}s: Loss {avg_loss}")

    _save_checkpoint(epoch)

    # pga.step leaves the model in eval mode; make that explicit for evaluation.
    model.eval()
    samples = _sample_grid(epoch)
    particles = _particle_grid()
    mse = _reconstruction_mse()
    if epoch % INPAINT_EVERY == 0:
        _inpaint(epoch)

    wandb.log({'particles': particles,
               'samples': samples,
               "loss": avg_loss,
               'mse': mse,
               'theta_step_size': pga.theta_opt.param_groups[0]['lr'],
               })
    plt.close('all')

    # Check after logging so the final epoch's metrics are recorded too.
    if args.early_stopping and epoch > 2 and epoch - np.argmin(losses) > EARLY_STOP_PATIENCE:
        print("Early Stop")
        break

[34m[1mwandb[0m: Currently logged in as: [33mjenninglim[0m. Use [1m`wandb login --relogin`[0m to force relogin


...............................................................................Epoch 0: 372.041398: Loss 1824.8615521781053
...............................................................................Epoch 1: 369.544169: Loss 1215.1942076864123
...............................................................................Epoch 2: 369.358910: Loss 1045.6545549223695
...............................................................................Epoch 3: 369.395399: Loss 886.3660278320312
...............................................................................Epoch 4: 369.397306: Loss 778.2756177684929
...............................................................................Epoch 5: 369.296647: Loss 692.9689045193829
...............................................................................Epoch 6: 369.483078: Loss 638.0797242756131
...............................................................................Epoch 7: 369.405639: Loss 590.8647793154173
.............

In [None]:
# Release every open matplotlib figure left over from training.
plt.close(fig='all')