<a href="https://colab.research.google.com/github/ParticleEM/ParEM_neural_latent_variable_model/blob/master/notebooks/MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import modules

In [1]:
# Install missing modules
%%capture
!pip install torchtyping

In [2]:
# Import standard modules
import argparse
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from google.colab import drive

In [3]:
# Import custom modules
!rm -rf ParEM_neural_latent_variable_model
!git clone https://pareem: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@github.com/ParticleEM/ParEM_neural_latent_variable_model.git
sys.path.append("/content/ParEM_neural_latent_variable_model/")
from parem.model import G
from parem.pga import PGA
from parem.dataset_loaders import get_mnist

Cloning into 'ParEM_neural_latent_variable_model'...
remote: Enumerating objects: 74, done.[K
remote: Counting objects: 100% (74/74), done.[K
remote: Compressing objects: 100% (66/66), done.[K
remote: Total 74 (delta 27), reused 28 (delta 7), pack-reused 0[K
Unpacking objects: 100% (74/74), done.


# Set config variables

In [4]:
# Declare dictionary-like object for storing config variables:
args = argparse.Namespace()

# Data settings
args.n_images = 1000  # M: training set size

# Training settings
args.n_epochs = 500  # K: total number of iterations
args.n_batch = 128  # M_b: batch size for theta updates
args.seed = 1  # Seed for PRNG
args.early_stopping = False  # Read by the early-stopping check further down
# Device on which to carry out computations:
args.device = "cuda" if torch.cuda.is_available() else "cpu"

# Model Settings
args.x_dim = 10  # d_x: dimension of latent space
args.likelihood_var = 0.3 ** 2  # sigma^2

# PGA Settings
args.h = 5e-5  # h: step size
args.lambd = 1e-3 / (args.h * args.n_images)  # lambda
args.n_particles = 10  # N

# Actually apply the seed: declaring args.seed alone does not make runs
# reproducible. Seed every PRNG this notebook imports.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Load dataset

In [5]:
# Mount Google Drive into the Colab VM (an already-mounted drive is left as is).
gdrive_mount_point = "/content/gdrive"
drive.mount(gdrive_mount_point, force_remount=False)

# Load the dataset, passing the storage directory and the training-set size.
mnist_root = '/content/mnist'
dataset = get_mnist(mnist_root, args.n_images)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).




# Define and train model

In [None]:
# Define model:
model = G(args.x_dim, sigma2=args.likelihood_var, nc=1).to(args.device)

# Define training algorithm:
pga = PGA(model, dataset, args.h, args.lambd, args.n_particles)

# Split dataset into batches for training. Pinned host memory only speeds up
# host-to-GPU copies, so it is requested only when training on CUDA.
training_batches = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.n_batch,
    shuffle=True,
    pin_memory=(args.device == "cuda"),
)

# Train:
losses = []  # mean loss of each epoch, stored as plain Python floats
for epoch in range(args.n_epochs):
    epoch_loss = 0.0
    for imgs, idx in training_batches:
        imgs = imgs.to(device=args.device)
        # pga.step's return type is not visible from this notebook; float()
        # accepts both Python numbers and 0-dim tensors, so the loss history
        # never retains tensors (or device memory) and stays usable with NumPy.
        epoch_loss += float(pga.step(imgs, idx))
        print(".", end='')
    avg_loss = epoch_loss / len(training_batches)
    losses.append(avg_loss)
    print(f"Epoch {epoch}: Loss {avg_loss}")

In [None]:
# Import modules necessary for training loop
%%capture
import pickle
from torchvision.utils import make_grid
import time

In [None]:

def to_range_0_1(x):
    """Affinely rescale ``x`` via ``(x + 1) / 2`` (maps [-1, 1] onto [0, 1])."""
    return (x + 1.) / 2.


def _save_grid(grid, path):
    """Render ``grid`` with the notebook's ``show`` helper and save it to ``path``."""
    path.parent.mkdir(exist_ok=True, parents=True)
    fig = show(grid.detach().cpu())
    fig.savefig(path, bbox_inches='tight')
    plt.close(fig)


def _save_inpaintings(epoch, out_dir, device, n_steps):
    """Fill in a masked square of a few training images and save the results.

    Latent codes are fitted with Adam so that the decoded images match the
    pixels outside the masked square (rows/columns 10..21). One PNG is written
    per image, showing the masked input, the reconstruction and the original.
    """
    n_missing_img = 10
    # Assumes dataset[:k][0] is a CPU batch of images -- TODO confirm against
    # parem.dataset_loaders.
    missing_imgs = dataset[:n_missing_img][0]

    # A private generator keeps the initialisation reproducible without
    # resetting the global PRNG (which would make later training steps replay
    # the same random stream after every call).
    rng = torch.Generator().manual_seed(1)
    init_x = torch.randn(n_missing_img, args.x_dim, 1, 1,
                         generator=rng, requires_grad=True)
    opt = torch.optim.Adam([init_x], 1e-2)
    mse_loss = torch.nn.MSELoss()

    missing_mask = torch.zeros_like(missing_imgs, dtype=torch.bool)
    missing_mask[..., 10:22, 10:22] = True

    for _ in range(n_steps):
        opt.zero_grad()
        filled_imgs = model.forward(init_x.to(device)).to('cpu')
        loss = mse_loss(filled_imgs[~missing_mask], missing_imgs[~missing_mask])
        # Differentiate w.r.t. the latent codes only, so that no gradient is
        # accumulated in the model parameters by this diagnostic.
        (init_x.grad,) = torch.autograd.grad(loss, init_x)
        opt.step()

    filled_vis = to_range_0_1(filled_imgs.detach()).expand(-1, 3, -1, -1)
    target_vis = to_range_0_1(missing_imgs).expand(-1, 3, -1, -1)
    masked_vis = target_vis.detach().clone()
    masked_vis[missing_mask.expand(-1, 3, -1, -1)] = 0.2  # grey out the hidden square

    for i in range(n_missing_img):
        grid = make_grid(torch.concat([masked_vis[[i]], filled_vis[[i]], target_vis[[i]]], dim=0))
        _save_grid(grid, out_dir / "impaint" / f"{epoch}" / f"{i}.png")


def run_epoch_diagnostics(epoch, losses, out_dir=Path("checkpoints"), patience=20,
                          inpaint_every=5, inpaint_steps=1000):
    """Produce per-epoch diagnostics for the model being trained.

    Intended to be called once per epoch from the training loop, e.g.::

        mse, stop = run_epoch_diagnostics(epoch, losses)
        if stop:
            break

    Uses the notebook-level ``model``, ``pga``, ``dataset`` and ``args``, and
    the ``show`` plotting helper, all of which must exist when this is called.

    Args:
        epoch: index of the epoch that has just finished.
        losses: per-epoch losses recorded so far (numbers or 0-dim tensors).
        out_dir: directory under which the PNG files are written.
        patience: epochs without a new minimum loss tolerated before stopping.
        inpaint_every: the in-painting diagnostic runs when
            ``epoch % inpaint_every == 0``.
        inpaint_steps: number of Adam steps used to fit the latent codes.

    Returns:
        ``(mse, stop)``: the reconstruction error (squared error summed over
        each image and averaged over images and particles) and whether the
        early-stopping criterion is met.
    """
    out_dir = Path(out_dir)
    device = args.device
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            particles = pga._particles

            # Samples decoded from a Gaussian fitted to the particle cloud.
            n_cols = 8
            n_rows = 8
            mean = torch.mean(particles, [0, 1, 3, 4])
            cov = torch.cov(particles.flatten(0, 1).flatten(1, 3).transpose(0, 1))
            normal_approx = torch.distributions.multivariate_normal.MultivariateNormal(
                loc=mean, covariance_matrix=cov)
            z = normal_approx.sample(
                sample_shape=torch.Size([n_cols * n_rows])).unsqueeze(-1).unsqueeze(-1)
            samples = to_range_0_1(model(z.to(device)))
            _save_grid(make_grid(samples), out_dir / "grid" / f"{epoch}_samples.png")

            # First training image next to the decodings of its first particles.
            original_img = to_range_0_1(dataset[0][0].unsqueeze(0))
            particle_img = to_range_0_1(
                model(particles[0, :10].to(device))).to(original_img.device)
            _save_grid(make_grid(torch.concat([original_img, particle_img], dim=0)),
                       out_dir / "particles" / f"{epoch}.png")

            # Reconstruction error over the first images and their particles.
            mse_n_samples = 100
            mse_n_particles = args.n_particles
            original_img = to_range_0_1(dataset[:mse_n_samples][0].unsqueeze(1))
            particle_img = to_range_0_1(
                model(particles[:mse_n_samples, :mse_n_particles].contiguous().to(device))
            ).to(original_img.device)
            assert original_img.shape == torch.Size([mse_n_samples, 1, 1, 32, 32])
            assert particle_img.shape == torch.Size([mse_n_samples, mse_n_particles, 1, 32, 32])
            recon_mse = (((particle_img - original_img) ** 2).sum([-1, -2, -3]).mean()).item()

        # Needs gradients w.r.t. the latent codes, hence outside no_grad().
        if epoch % inpaint_every == 0:
            _save_inpaintings(epoch, out_dir, device, inpaint_steps)

        stop = False
        # getattr: the flag is optional, a missing attribute means "disabled".
        if epoch > 2 and getattr(args, "early_stopping", False):
            best_epoch = int(np.argmin([float(value) for value in losses]))
            if epoch - best_epoch > patience:
                print("Early Stop")
                stop = True
        return recon_mse, stop
    finally:
        model.train(was_training)  # leave the model in the mode it was found in
        plt.close('all')

In [None]:
#@title Plotting function
import torchvision.transforms.functional as F

def show(imgs):
    """Draw one image tensor, or a list of them, side by side on a single row.

    Args:
        imgs: an image tensor accepted by ``F.to_pil_image``, or a list of such
            tensors. A non-list argument is treated as a one-element list.

    Returns:
        The matplotlib figure holding one tick-less axis per image.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    fig, axs = plt.subplots(ncols=len(images), squeeze=False, dpi=400)
    # squeeze=False always yields a 2-D axes array; row 0 is the only row.
    for ax, tensor in zip(axs[0], images):
        pil_img = F.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    return fig

In [None]:
# Close every matplotlib figure that is still open.
plt.close('all')