In [1]:
!git clone https://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@github.com/albertaillet/vnca.git -b file-restructure

Cloning into 'vnca'...
remote: Enumerating objects: 62, done.[K
remote: Counting objects: 100% (62/62), done.[K
remote: Compressing objects: 100% (37/37), done.[K
remote: Total 62 (delta 19), reused 58 (delta 16), pack-reused 0[K
Unpacking objects: 100% (62/62), done.


In [None]:
!pip install equinox einops optax

In [5]:
%load_ext autoreload
%autoreload 2
%cd vnca

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
[Errno 2] No such file or directory: 'vnca'
/content/vnca


In [9]:
# Imports
import equinox as eqx
import jax.numpy as np
from jax.random import PRNGKey, split
from einops import rearrange, repeat
from optax import adam, exponential_decay
import matplotlib.pyplot as plt

from models import BaselineVAE
from data.mnist import load_mnist

# typing
from jax import Array, vmap
from equinox import Module
from typing import Optional, Any
from jax.random import PRNGKeyArray
from optax import GradientTransformation
from typing import Tuple

# Notebook-wide configuration.
# TARGET_SIZE is presumably the image side length in pixels; it is not used in the cells shown here.
TARGET_SIZE = 28
# Fixed PRNG seeds: MODEL_KEY is passed to the model constructor, DATA_KEY to data
# loading and to the derivation of the per-step training keys.
MODEL_KEY, DATA_KEY = PRNGKey(0), PRNGKey(1)

In [18]:
# Define the loss function and the training step
@eqx.filter_value_and_grad
def loss_fn(model: Module, x: Array, key: PRNGKeyArray, kl_weight: float = 100.0) -> Array:
    """Weighted VAE objective for one batch: reconstruction error plus a KL penalty.

    Because of the ``eqx.filter_value_and_grad`` decorator, calling ``loss_fn(...)``
    returns ``(loss, grads)``, with ``grads`` taken with respect to ``model``
    (the first argument).

    Args:
        model: Called once per example as ``model(x_i, key_i)``; must return
            ``(recon_x, mean, logvar)``.
        x: Batch of inputs; the leading axis is the batch axis that is vmapped over.
        key: PRNG key, split into one sub-key per example in the batch.
        kl_weight: Multiplier applied to the KL term. Defaults to 100, the value
            that used to be hard-coded in the return expression.

    Returns:
        The scalar ``recon_loss + kl_weight * kl_loss`` (before the decorator
        turns the result into a ``(loss, grads)`` pair).
    """
    # One independent key per example so that each forward pass samples separately.
    keys = split(key, len(x))
    recon_x, mean, logvar = vmap(model)(x, keys)
    # Mean squared reconstruction error over every element of the batch.
    recon_loss = np.mean(np.square(recon_x - x))
    # Closed-form KL(N(mean, exp(logvar)) || N(0, 1)), averaged (not summed) over all elements.
    kl_loss = -0.5 * np.mean(1 + logvar - np.square(mean) - np.exp(logvar))
    return recon_loss + kl_weight * kl_loss


@eqx.filter_jit
def make_step(model: Module, x: Array, key: PRNGKeyArray, opt_state: tuple, optim: GradientTransformation) -> Tuple[float, Module, Any]:
    """Perform a single optimisation step on one batch.

    Args:
        model: Current model; forwarded to ``loss_fn`` and updated with the result.
        x: Batch forwarded to ``loss_fn``.
        key: PRNG key forwarded to ``loss_fn``.
        opt_state: Optimiser state matching ``optim``.
        optim: Optax gradient transformation used to turn gradients into updates.

    Returns:
        ``(loss, model, opt_state)``: the loss on this batch (computed before the
        update), the updated model and the updated optimiser state.
    """
    # loss_fn is wrapped in filter_value_and_grad, so it yields the value and the gradients.
    batch_loss, gradients = loss_fn(model, x, key)
    param_updates, next_opt_state = optim.update(gradients, opt_state)
    stepped_model = eqx.apply_updates(model, param_updates)
    return batch_loss, stepped_model, next_opt_state

In [19]:
# Create model: instantiate BaselineVAE, passing MODEL_KEY as its PRNG key
# (presumably used for parameter initialisation -- confirm in models.BaselineVAE).
vae = BaselineVAE(key=MODEL_KEY)

In [None]:
batch_size = 32
# Staircase schedule: start at 3e-5 and multiply by 0.1 every 60 optimiser steps.
# NOTE(review): over 2000 steps this is 33 decays, i.e. roughly 3e-38 by the last step,
# so later updates are negligible. Confirm that a 60-step transition is intended.
lr = exponential_decay(3e-5, 60, 0.1, staircase=True)
opt = adam(lr)
# Optimiser state is built from the array leaves of the model only.
opt_state = opt.init(eqx.filter(vae, eqx.is_array))

# Split DATA_KEY once so that data loading and the training steps each get their own
# independent key; a JAX PRNG key should not be consumed more than once.
data_key, train_key = split(DATA_KEY)
train_data, test_data = load_mnist(batch_size=batch_size, key=data_key)

n_gradient_steps = 2000
steps = range(n_gradient_steps)
train_keys = split(train_key, n_gradient_steps)

# zip stops at the shortest iterable, so fewer than n_gradient_steps updates are run
# if train_data yields fewer batches than that.
for step, batch, key in zip(steps, train_data, train_keys):
    loss, vae, opt_state = make_step(vae, batch, key, opt_state, opt)
    # '\r' keeps the progress report on a single, continuously overwritten line.
    print(step, loss, end='\r')

In [None]:
# Display the first element of vae.center() as a grayscale image.
fig, ax = plt.subplots()
ax.imshow(vae.center()[0], cmap='gray')
plt.show()