In [17]:
!git clone -b IWELBO https://ghp_vrZ0h7xMpDhgmRaoktLtensorflowwUiFRqWACaj1dcqzL@github.com/albertaillet/vnca.git

Cloning into 'vnca'...
Password for 'https://<redacted>@github.com': 

In [6]:
%%capture
# Install/upgrade the libraries used by this notebook; %%capture hides pip's output.
# NOTE(review): no versions are pinned, so each run may pull different releases — consider pinning.
%pip install --upgrade numpy equinox einops optax distrax wandb tensorflow tensorflow_probability jax jaxlib

In [7]:
import os

# Configure JAX to use the attached TPU when the TPU_NAME environment variable is set.
if 'TPU_NAME' in os.environ:
    import requests

    # Send the driver-version request only once per kernel: the TPU_DRIVER_MODE
    # global acts as a "already requested" marker when the cell is re-run.
    if 'TPU_DRIVER_MODE' not in globals():
        # URL is built from the host part of TPU_NAME, port 8475.
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        # NOTE(review): the response is stored but never inspected, so a failed request goes unnoticed.
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1

    # NOTE(review): `jax.config.FLAGS` is a legacy API — confirm it still exists in the
    # JAX release installed by the unpinned `--upgrade` cell.
    from jax.config import config

    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

Registered TPU: grpc://10.0.0.2:8470


In [13]:
%cd /kaggle/working/vncays

[Errno 2] No such file or directory: '/kaggle/working/vnca'
/kaggle/working


In [10]:
# Imports
import equinox as eqx
import jax.numpy as np
from jax.random import PRNGKey, split, permutation
from jax import vmap, pmap, local_device_count, lax, local_devices, device_put_replicated, device_put, tree_map, devices, jit, nn
from jax.scipy.special import logsumexp
from einops import rearrange, repeat
from optax import adam, exponential_decay, clip_by_global_norm, chain
import matplotlib.pyplot as plt
from functools import partial
from tqdm import tqdm
from distrax import Normal, Bernoulli
from data.mnist import load_mnist_train_on_tpu, indicies_tpu_iterator, get_mnist

from models import BaselineVAE,DoublingVNCA

# Typing
from jax import Array
from equinox import Module
from typing import Optional, Any
from jax.random import PRNGKeyArray
from optax import GradientTransformation
from typing import Tuple

# NOTE(review): TARGET_SIZE is not referenced in the visible cells — confirm it is still needed.
TARGET_SIZE = 28
# Seed used to initialise the model parameters.
MODEL_KEY = PRNGKey(0)
# Seed used for the per-step training keys and the batch-index iterator.
DATA_KEY = PRNGKey(1)

ModuleNotFoundError: No module named 'data'

In [None]:
@eqx.filter_value_and_grad
def loss_fn(model: Module, x: Array, key: PRNGKeyArray) -> float:
    """Batch-averaged negative ELBO (Bernoulli reconstruction term + Gaussian KL).

    Because of the `eqx.filter_value_and_grad` decorator, calling this returns
    `(loss, grads)` where the gradients are taken with respect to `model`.

    Args:
        model: callable mapping `(example, key)` to `(logits, mean, logvar)`.
        x: batch of inputs; axis 0 is the batch axis and the log-likelihood is
            summed over axes 1-3.
        key: PRNG key, split into one sub-key per example in the batch.
    """
    batch_size = x.shape[0]
    per_example_keys = split(key, batch_size)
    logits, mu, log_sigma_sq = vmap(model)(x, per_example_keys)

    # KL(q(z|x) || N(0, 1)), summed over axis 1 of the posterior parameters.
    posterior = Normal(mu, np.exp(0.5 * log_sigma_sq))
    prior = Normal(0, 1)
    kl_per_example = np.sum(posterior.kl_divergence(prior), axis=1)

    # Negative Bernoulli log-likelihood of the inputs under the decoder logits.
    log_px = Bernoulli(logits=logits).log_prob(x)
    nll_per_example = -np.sum(log_px, axis=(1, 2, 3))

    return np.mean(nll_per_example + kl_per_example)

@eqx.filter_value_and_grad
def iwelbo_loss(model: Module, x: Array, key: PRNGKeyArray, M: int = 1) -> float:
    """Negative importance-weighted bound used as the training loss.

    Because of the `eqx.filter_value_and_grad` decorator, calling this returns
    `(loss, grads)` where the gradients are taken with respect to `model`.

    Args:
        model: callable invoked as `model(example, key=...)`, returning
            `(logits, mean, logvar)`.
        x: batch of inputs; axis 0 is the batch axis.
        key: PRNG key, split into one sub-key per example in the batch.
        M: number of importance samples used in the `- log(M)` normalisation.
            NOTE(review): `M` is not forwarded to `model`, so it only matches the
            real sample count if the model draws exactly `M` samples on its own.
    """
    keys = split(key, x.shape[0])
    recon_x, mean, logvar = vmap(model)(x, key=keys)
    # KL(q(z|x) || N(0, 1)), summed over axis 1.
    kl = np.sum(Normal(mean, np.exp(1 / 2 * logvar)).kl_divergence(Normal(0, 1)), axis=1)
    # Summing over axes (2, 3, 4) requires a log-prob with at least 5 axes, i.e. the
    # model output presumably carries an extra (importance-sample?) axis — TODO confirm.
    # NOTE(review): `x` is not expanded to match that extra axis; verify the broadcast.
    like = np.sum(Bernoulli(logits=recon_x).log_prob(x), axis=(2, 3, 4))
    # NOTE(review): with vmap's default out_axes, axis 0 is the batch axis, so this
    # logsumexp appears to reduce over examples rather than importance samples — confirm.
    iw_loss = np.mean(logsumexp(-kl + like,axis=0)-np.log(M))
    return -iw_loss
    
    
@partial(pmap, axis_name='num_devices', static_broadcasted_argnums=(3, 6), out_axes=(None, 0, 0))
def make_step(data: Array, index: Array, params, static, key: PRNGKeyArray, opt_state: tuple, optim: GradientTransformation) -> Tuple[float, Module, Any]:
    """Run several optimisation steps on every device, synchronising via `pmean`.

    `static` (arg 3) and `optim` (arg 6) are broadcast as static arguments; all
    other arguments are mapped over their leading axis by `pmap`.

    Args:
        data: dataset shard for this device, indexed with each entry of `index`.
        index: sequence of batch indices; one optimisation step is run per entry.
        params: array leaves of the model (combined with `static` inside).
        static: non-array part of the model.
        key: PRNG key for this device; split once per step.
        opt_state: optimiser state matching `params`.
        optim: optimiser whose `update` produces the parameter updates.

    Returns:
        `(losses, params, opt_state)`: the device-averaged loss of every scanned
        step, and the parameters / optimiser state after the last step.
    """
    def scan_body(state, batch_idx):
        cur_params, cur_opt_state, rng = state
        rng, sample_rng = split(rng)

        # Rebuild the full model and evaluate loss + gradients on this mini-batch.
        batch = data[batch_idx]
        value, grad_tree = iwelbo_loss(eqx.combine(cur_params, static), batch, sample_rng)

        # Average loss and gradients over all devices so every replica applies the same update.
        value = lax.pmean(value, axis_name='num_devices')
        grad_tree = lax.pmean(grad_tree, axis_name='num_devices')

        delta, next_opt_state = optim.update(grad_tree, cur_opt_state)
        next_params = eqx.apply_updates(cur_params, delta)
        return (next_params, next_opt_state, rng), value

    final_state, losses = lax.scan(scan_body, (params, opt_state, key), index)
    new_params, new_opt_state, _ = final_state
    return losses, new_params, new_opt_state

In [None]:
# Create model and define parameters
model = DoublingVNCA(key=MODEL_KEY)
# Number of local accelerator devices and the device handles themselves.
n_tpus = local_device_count()
device = local_devices()
# Training set loaded onto the local devices; presumably one shard per device on
# axis 0 (data.shape[1] is later used as the per-device dataset size) — TODO confirm.
data = load_mnist_train_on_tpu(devices=device)

In [None]:
# Weights & Biases run setup: all tunable hyper-parameters are stored on wandb.config
# and read back from there by the training cell.
WANDB_MODE = "offline"
import wandb
wandb.init(project="vnca", entity="albertaillet",mode=WANDB_MODE)

wandb.config.model_type = model.__class__.__name__
# Per-device batch size: a total of 32 split across the local devices.
wandb.config.batch_size = 32 // n_tpus
wandb.config.n_gradient_steps = 1000
wandb.config.n_tpus = n_tpus
wandb.config.lr = 4e-5
# wandb.config.lr_init_value = 3e-4 # when using exponential_decay
# wandb.config.lr_transition_steps = 100_000
# wandb.config.lr_decay_rate = 0.3
# wandb.config.lr_staircase = True
wandb.config.grad_norm_clip = 10.0
# Passed as the last argument to indicies_tpu_iterator; presumably the number of
# scanned steps per make_step call — TODO confirm.
wandb.config.l = 250
wandb.config.model_key = MODEL_KEY
wandb.config.data_key = DATA_KEY


In [None]:
# One PRNG key per (outer step, device): shape (n_gradient_steps, n_tpus, key_size).
train_keys = split(DATA_KEY, wandb.config.n_gradient_steps * n_tpus)
train_keys = rearrange(train_keys, "(n t) k -> n t k", t=n_tpus, n=wandb.config.n_gradient_steps)

# Separate the trainable arrays from the static (non-array) part of the model.
params, static = eqx.partition(model, eqx.is_array)

# Clip the raw gradients *before* Adam. In an optax chain the transformations are
# applied in order, so placing the clip after Adam would clip the already
# learning-rate-scaled updates, which makes `grad_norm_clip` effectively a no-op.
opt = chain(clip_by_global_norm(wandb.config.grad_norm_clip), adam(wandb.config.lr))
opt_state = opt.init(params)

# Replicate parameters and optimiser state on every local device for pmap.
params = device_put_replicated(params, device)
opt_state = device_put_replicated(opt_state, device)

pbar = tqdm(
    zip(indicies_tpu_iterator(n_tpus, wandb.config.batch_size, data.shape[1], wandb.config.n_gradient_steps, DATA_KEY, wandb.config.l), train_keys),
    total=wandb.config.n_gradient_steps,
)

for i, key in pbar:
    # Each call scans over all index batches in `i` on every device.
    loss, params, opt_state = make_step(data, i, params, static, key, opt_state, opt)
    pbar.set_postfix({'loss': f"{np.mean(loss):.3}"})
    wandb.log({'loss': float(np.mean(loss))})

# Collapse the replicated parameters (averaging over the device axis) and rebuild the model.
model = eqx.combine(tree_map(partial(np.mean, axis=0), params), static)

In [None]:
# Show the model's reconstruction of one training image followed by the original.
# data[0][201]: presumably image 201 of the first device's shard — TODO confirm.
fig = data[0][201]
# Sigmoid turns the model's first output (treated as logits) into pixel intensities;
# the trailing [0] presumably selects the first channel — TODO confirm.
plt.imshow(nn.sigmoid(model(fig, DATA_KEY)[0][0]), cmap='gray')
plt.show()
# Pad the last two axes by 2 on each side before displaying the original image.
fig = np.pad(fig, ((0, 0), (2, 2), (2, 2)))
plt.imshow(fig[0], cmap='gray')
plt.show()

In [None]:
# Load the MNIST test split, scale pixel values to [0, 1] and add a channel axis.
# NOTE(review): `MNIST` is not imported in any visible cell (it looks like
# torchvision.datasets.MNIST) — add the import or this cell raises NameError.
test_dataset = MNIST(root='./data', train=False, download=True, transform=None)
# NOTE(review): `np` is jax.numpy here; confirm it accepts the dataset's `.data` tensor type.
test_dataset = np.array(np.float32(test_dataset.data / 255.0))
test_dataset = rearrange(test_dataset, 'n h w -> n 1 h w')

In [None]:
# Show the model's reconstruction of one held-out test image followed by the original.
fig = test_dataset[400]
# Sigmoid turns the model's first output (treated as logits) into pixel intensities.
plt.imshow(nn.sigmoid(model(fig, DATA_KEY)[0][0]), cmap='gray')
plt.show()
plt.imshow(fig[0], cmap='gray')
plt.show()