In [1]:
import itertools
import jax
from jax import jit, vmap
import torch
import jax.numpy as jnp
from functools import partial
import matplotlib.pyplot as plt
from math import prod
from jax.experimental import jet

In [2]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="0.95" 

In [3]:
torch.cuda.is_available()

True

In [4]:
jax.devices()

[cuda(id=0)]

In [5]:
key = jax.random.PRNGKey(2)

2024-06-12 14:06:17.741838: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [6]:
import equinox as eqx

In [7]:
from vae_jax import VAE

In [8]:
img_size = 256
batch_size = 2
latent_img_size = 32
z_dim = 256

In [9]:
vae, vae_states = eqx.nn.make_with_state(VAE)(img_size, latent_img_size, z_dim, key)
init_vae_params, vae_static = eqx.partition(vae, eqx.is_inexact_array)
# x = jnp.zeros((1, 3, 256, 256)) # we must have a batch_dim, the model can only be applied afetr vmap because of BN layer
# batch_model = jax.vmap(model, in_axes=(0, None, None, None), out_axes=(0, None), axis_name="batch")
# x, state = batch_model(x, state, key, True)

In [10]:
from utils import get_train_dataloader, get_test_dataloader

**Note:** if we were doing this gradient based search of beta without using a validation dataset for the outer loss, should we expect to have a beta estimated to be 0 as it would be optimal to overfit and kill the regularization term right ? -> this seems to be the case emprically

In [11]:
from torch.utils.data import Dataset, DataLoader
from datasets import LivestockTrainDataset
train_dataset = LivestockTrainDataset(
    img_size,
    fake_dataset_size=1500,
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [12]:
val_dataset = LivestockTrainDataset(
    img_size,
    fake_dataset_size=1500,
    offset_idx=train_dataset.__len__()
)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [13]:
@partial(jit, static_argnums=(2, 6))
def loss(params, beta, static, states, x, key, train=True):
    """
    Parameters
    ----------
    params
        XXX
    static
        XXX
    states
        XXX
    x
        A batch of inputs
    key
        A JAX PRNGKey
    """
    vae = eqx.combine(params, static)
    if train:
        # make sure we are in training mode
        vae = eqx.nn.inference_mode(vae, value=False)
    else:
        vae = eqx.nn.inference_mode(vae)
    batched_vae = vmap(vae, in_axes=(0, None, None, None), out_axes=(0,  None, 0, 0), axis_name="batch")

    key, subkey = jax.random.split(key, 2)

    x_rec, states, mu, logvar = batched_vae(x, states, key, train)
    batched_elbo = vmap(vae.elbo, in_axes=(0, 0, 0, 0, None), out_axes=(0, 0, 0))

    elbo, rec_term, kld = batched_elbo(x_rec, x, mu, logvar, beta)

    elbo = jnp.mean(elbo) # avg over the batches
    rec_term = jnp.mean(rec_term)
    kld = jnp.mean(kld)

    x_rec = VAE.mean_from_lambda(x_rec)

    # elbo = jnp.array(0.)
    # rec_term = jnp.array(0.)
    # kld = jnp.array(0.)
    # x_rec = jnp.array(0.)
    # logvar = jnp.array(0.)
    # mu = jnp.array(0.)

    return -elbo, (x_rec, rec_term, kld, states, mu, logvar)

Test the loss on one mini-batch

In [14]:
mini_batch = next(iter(train_dataloader))
#loss_value, (_, _, _, vae_states, _, _) = loss(init_vae_params, 1., vae_static, vae_states, mini_batch[0].numpy(), key, train=True)

In [15]:
import copy
init_params = init_vae_params
init_states = vae_states
static = vae_static

In [16]:
vae = eqx.combine(init_params, static)

In [17]:
import optax

n_steps_inner = 30
n_steps_outer = 100

lr_inner = 1e-3
lr_outer = 1e-2

print_every = 1

init_beta = jnp.array(0.1)

optimizer_inner = optax.adam(lr_inner)
opt_state_inner = optimizer_inner.init(init_params)
optimizer_outer = optax.adam(lr_outer)
opt_state_outer = optimizer_outer.init(init_beta)

In [18]:
def train_inner(params, beta, static, states, train_data, opt_state, key, print_every):
    """
    must not necesarily be jittable since we a priori only need to get grad
    but we provide a jittable train_inner for speed
    """

    def make_step(carry, x):
        params, beta, states, opt_state, key = carry
        key, subkey = jax.random.split(key)
        (minus_elbo, aux_loss), grads = jax.value_and_grad(loss, argnums=0, has_aux=True)(params, beta, static, states, x, subkey, train=True)
        _, rec_term, kld, states, _, _ = aux_loss
        updates, opt_state = optimizer_inner.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return (params, beta, states, opt_state, key), jnp.array([minus_elbo, rec_term, kld])

    carry_init = (
        params, beta, states, opt_state, key
    )
    carry, losses = jax.lax.scan(
        make_step,
        carry_init,
        train_data
    )
    params = carry[0]
    states = carry[2]
    opt_state = carry[3]
    return params, beta, states, (losses[:, 0], losses[:, 1], losses[:, 2]), opt_state

In [19]:
@partial(jax.custom_jvp, nondiff_argnums=(0, 2, 3, 4, 5, 6)) # we will not differentiate wrt something else than beta
def train_inner_(params, beta, static, states, inner_train_data, opt_state_inner, key):
    return train_inner(params, beta, static, states, inner_train_data, opt_state_inner, key, print_every)
    
@train_inner_.defjvp
def train_inner_jvp(params0, static, states, x_data, opt_state_inner, key, primals, tangents):
    """
    Note that nondiff arg placed at the start of the signature of the corresponding JVP rule
    """
    print("here")
    beta0, = primals
    v, = tangents
    key, subkey = jax.random.split(key)
    # # compute x_0(beta_0) (for a given beta_0). The gives the couple (x0, beta_0) around which we are authorized to use th IFT formula
    # params_0_beta, _, _, (_, _, _), _ = train_inner_(
    #     params, beta, static, states, inner_train_data, opt_state_inner, subkey
    # )
    # subkeys not considered below since train=False
    Ax = lambda x:-jax.jacfwd(lambda params_:jax.grad(loss, argnums=0, has_aux=True)(params_, beta0, static, states, x_data, subkey, train=False)[0])(params0) @ x
    b = jax.jacfwd(lambda beta_:jax.grad(loss, argnums=0, has_aux=True)(params0, beta_, static, states, x_data, subkey, train=False)[0])(beta0) # second diff wrt lambda
    #b = jax.jacfwd(lambda beta_:jax.grad(loss, argnums=0, has_aux=True)(params0, beta_, static, states, x_data, subkey, train=False))(b'0')# second diff wrt lambda

    return params0, jax.scipy.sparse.linalg.cg(Ax, b) * v

In [96]:
def train_outer(n_steps_outer, n_steps_inner, params, beta, static, states, train_loader, val_loader, opt_state_outer, opt_state_inner, key, print_every):
    """
    """
    
    def infinite_val_loader():
        while True:
            yield from val_loader

    def make_step(outer_val_batch, params, beta0, states, inner_train_data, opt_state_outer, opt_state_inner, key):
        key, subkey = jax.random.split(key)

        # this call to train_inner_ gives a x_0, theta_0 (the resulting couple (params, beta)
        # this couple is a root of F(=dELBO(params, beta)/dparams) (for a fixed theta_0 (beta)),
        # using a SGD that starts at a x_init (params)
        params0, _, states, (elbo_list, rec_term_list, kld_list), opt_state_inner = train_inner_(
            params, beta0, static, states, inner_train_data, opt_state_inner, subkey
        )
        # ... around this x_0, theta_0 we know that we have the right to compute dx*(theta)/dtheta and we have done so
        # in the jvp defined in the previous cells

        # key, subkey = jax.random.split(key)
        # grads_inner = jax.jacfwd(train_inner_, argnums=1)(params, beta, static, states, x, opt_state_inner, subkey)


        # Reconstruct the gradient here ?
        # Ax = lambda x:-(jax.jacfwd(
        #         lambda params_:jax.grad(
        #             loss, argnums=0, has_aux=True
        #         )(params_, beta0, static, states, outer_val_batch, subkey, train=False)[0]
        #     )(params0) @ x)
        # def mul_(x, y):
        #     print(x)
        #     print(y)
        #     return x @ y
        # Ax  = lambda x:jax.tree.map(mul_, #lambda x_, y:x_ @ y,
        #         x,
        #         jax.jacfwd(
        #             lambda params_:jax.grad(
        #                 loss, argnums=0, has_aux=True
        #             )(params_, beta0, static, states, outer_val_batch, subkey, train=False)[0]
        #         )(params0),
        #         is_leaf=lambda x_: eqx.is_inexact_array(x_))
        # Ax = 

        def Ax(x):
            # trying jax.jet with https://github.com/google/jax/discussions/9598

            loss_ = lambda p: loss(p, beta0, static, states, outer_val_batch, subkey, train=False)[0]
            def loss_wrapper(*true_params, tree):
                """To be able to differentiate the certain elements of the pytree
                https://github.com/google/jax/discussions/12765

                true_params is a list of the arrays with respect to which we want to differentiate inside the eqx.Module
                defined by params0

                Note that all that goes after *args are keyword only arguments
                """
                p = jax.tree.unflatten(tree, true_params)
                # p = jax.tree_util.tree_map_with_path(
                #     lambda kp, p: dict_path_true_params[kp], # kp will nec be in dict_path_true_params because of the following is_leaf function
                #     params0,
                #     is_leaf=eqx.is_inexact_array
                # )
                return loss_(p)

            # params0_fl = jax.tree.map(
            #     lambda x_: x_.flatten(),
            #     params0,
            #     is_leaf=eqx.is_inexact_array
            # )
            true_params, tree = jax.tree_util.tree_flatten(params0) # in the children non static content (param_fl) only the jnp arrays go
            # all None and integer constant go to tree so a tree_flatten suffices to get the true params
            # we really optimize upon
            

            # true_params_identity_filled = jax.tree.map(
            #     lambda x_:jnp.eye(x_.shape[0]),
            #     true_params,
            # )
            # print(true_params)
            # jet.jet(loss_, (true_params,), ((true_params_identity_filled,)))[1][0]
            # fs
            # # jet.jet(fun, (x,), ((v, jnp.zeros_like(x)),))[1][1]
            # #     return jnp.sum(hvv(jnp.eye(x.shape[0], dtype=x.dtype)))
            
            # jet.jet(loss_, (params0,), )[1][1]

            # # dL_dtheta = lambda params_:jax.grad(
            # #         loss, argnums=0, has_aux=True
            # #     )(params_, beta0, static, states, outer_val_batch, subkey, train=False)[0]
            
            
            pytree_grad = [
                jax.jacfwd(
                    lambda *params_:jax.grad(loss_wrapper, argnums=i)(*params_, tree=tree),
                    argnums=i
                )(*true_params)
                for i in range(len(true_params))
            ]
            fs
            pytree_grad = jax.hessian(
                loss, argnums=0, has_aux=True
            )(params0, beta0, static, states, outer_val_batch, subkey, train=False)[0]

            # for each inexact field (x_) of the VAE params we only compute the @ (and let a non None value)
            # with the array (y_) at the same position of the nested VAE params located (at the x_ inexact field of the outer VAE params)
            # this can be seen as retrieving the diagonal of the Hessian matrix
            nested_vaes = jax.tree_util.tree_map_with_path(
                lambda key_path_x_, x_, y: jax.tree_util.tree_map_with_path(
                    lambda key_path_y_, y_: (y_.reshape((prod(x_.shape), prod(x_.shape))) @ x_.flatten()).reshape(x_.shape) if key_path_x_ == key_path_y_ else None,
                    y,
                    is_leaf=eqx.is_inexact_array
                ),
                x,
                pytree_grad,
                is_leaf=eqx.is_inexact_array
            )
                
            outer_vae = jax.tree_util.tree_map_with_path(
                lambda key_path_x_, x_, y: jax.tree.leaves(y, is_leaf=eqx.is_inexact_array)[0],
                x,
                nested_vaes,
                is_leaf=eqx.is_inexact_array
            )
            #outer_vae = jax.tree.leaves(nested_vaes, is_leaf=eqx.is_inexact_array)
            #print(outer_vae)
                
            #print(outer_vae)
            return outer_vae
                
            
        b = jax.jacfwd(lambda beta_:jax.grad(loss, argnums=0, has_aux=True)(params0, beta_, static, states, outer_val_batch, subkey, train=False)[0])(beta0)
        #b = jax.tree.leaves(b, is_leaf=eqx.is_inexact_array)
        #print("B", b)
        grads_inner = jax.scipy.sparse.linalg.cg(Ax, b)
        print(grads_inner.shape)


        key, subkey = jax.random.split(key)
        # Now we want to compute the ELBO for the outer loss and take gradient wrt to theta here.
        # We have defined a rule to backpropagate through beta i.e.
        # through train_inner (dx*(theta)/dtheta) in the previous cells
        (minus_elbo, aux_loss), grads_outer = jax.value_and_grad(loss, argnums=1, has_aux=True)(params, beta, static, states, outer_val_batch, subkey, train=False)
        x_rec, rec_term, kld, _, _, _ = aux_loss
        print(grads_outer.shape)
        grads = grads_outer @ grads_inner
        print(grads.shape)
        updates, opt_state_outer = optimizer_outer.update(grads, opt_state_outer, beta)
        beta = optax.apply_updates(beta, updates)
        return params, beta, states, opt_state_outer, opt_state_inner, key, x_rec, (elbo_list, rec_term_list, kld_list)

    elbo_list = []
    rec_term_list = []
    kld_list = []
    beta_list = []


    for step, (x, _) in zip(range(n_steps_outer), infinite_val_loader()): 
        x = x.numpy()
        inner_train_data = jnp.asarray(
            list(map(lambda x:x[0].numpy(), list(itertools.islice(train_loader, n_steps_inner)))) # get the next n_steps_inner elements from train_loader
        )
        params, beta, states, opt_state_outer, opt_state_inner, key, x_rec, losses = make_step(
            x, params, beta, states, inner_train_data, opt_state_outer, opt_state_inner, key
        )
        elbo_list.extend(-losses[0])
        rec_term_list.extend(losses[1])
        kld_list.extend(losses[2])
        # elbo_list.append(-losses[0])
        # rec_term_list.append(losses[1])
        # kld_list.append(losses[2])
        beta_list.append(jnp.full((n_steps_inner,), beta))
        
        if (step % print_every) == 0 or (step == n_steps - 1):
            print(
                f"{step=}, elbo_loss={elbo_list[-1]}, rec_term={rec_term_list[-1]}, kld_term={kld_list[-1]}, beta={beta_list[-1][0]}"
            )
            
    return params, beta, states, (elbo_list, rec_term_list, kld_list), (opt_state_outer, opt_state_inner)

In [97]:
jax.tree.map(lambda x, y: jax.tree.map(lambda y_:x+y_, y, is_leaf=lambda x:(isinstance(x, float))),
             (0., 0., 0.),
             ((1., None), (2., None), (3., None, 4)),
             is_leaf=lambda x:(isinstance(x, float)))# and isinstance(y, tuple)))

((1.0, None), (2.0, None), (3.0, None, 4.0))

In [98]:
key, subkey = jax.random.split(key)
final_params, final_beta, final_states, loss_lists, _ = train_outer(
    n_steps_outer, n_steps_inner, init_params, init_beta, static, init_states, train_dataloader, val_dataloader, opt_state_outer, opt_state_inner, key, print_every
)

NameError: name 'fs' is not defined

In [None]:
plt.plot(loss_lists[0], label="elbo")
plt.plot(loss_lists[1], label="rec_term")
plt.plot(loss_lists[2], label="kld")
plt.legend()
plt.show()

In [None]:
from datasets import LivestockTestDataset
test_dataset = LivestockTestDataset(
    img_size,
    fake_dataset_size=1024,
)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
x_test = next(iter(test_dataloader))
x_test =  x_test[0].numpy()

In [None]:
key, subkey = jax.random.split(key)
(_, aux_loss) = loss(final_params, static, final_states, x_test, subkey, train=False, beta=beta)
x_rec_test = aux_loss[0]

vae_mu = aux_loss[-1]
mad = jnp.mean(jnp.abs(vae_mu - jnp.mean(vae_mu, axis=(1), keepdims=True)), axis=(1)) # mean on latent dims for all batches

In [None]:
figure, axes = plt.subplots(2, 4)
axes[0, 0].imshow(jnp.moveaxis(x_test[1],0, 2))
axes[1, 0].imshow(jnp.moveaxis(x_rec_test[1], 0, 2))
axes[0, 1].imshow(mad[1])
plt.show()