In [1]:
import jax
from jax import jit, vmap
import torch
import jax.numpy as jnp
from functools import partial
import matplotlib.pyplot as plt

In [2]:
import os
# Ask the XLA GPU client to use up to 95% of device memory for JAX.
# NOTE(review): the variable is set after `import jax` in the first cell; this relies on
# it being read when the GPU backend is first initialised, not at import time — confirm.
# NOTE(review): PyTorch is loaded in the same process, so little memory is left for it.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="0.95" 

In [3]:
# Sanity check: does PyTorch see a CUDA device?
torch.cuda.is_available()

True

In [4]:
# Sanity check: list the devices visible to JAX.
jax.devices()

[cuda(id=0)]

In [5]:
# Root PRNG key (seed 2). Later cells use it for model construction, the loss and training.
key = jax.random.PRNGKey(2)

2024-06-04 13:26:54.931660: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [6]:
import equinox as eqx

In [7]:
from vae_jax import VAE

In [8]:
# Use a dedicated sub-key for model construction so that the root `key` is never
# consumed twice (JAX keys must not be reused).
key, init_key = jax.random.split(key)
# NOTE(review): 256 is the first constructor argument of VAE; it matches `img_size`
# defined further down, but its meaning is not visible here — confirm in vae_jax.
vae, vae_states = eqx.nn.make_with_state(VAE)(256, init_key)
# Split the model into its inexact-array leaves (`init_vae_params`) and everything
# else (`vae_static`); `loss` recombines the two halves with `eqx.combine`.
init_vae_params, vae_static = eqx.partition(vae, eqx.is_inexact_array)
# Inputs must carry a batch dimension: because of the BatchNorm layers the model can only
# be applied after `vmap` with `axis_name="batch"` (see the batched call inside `loss`).

In [9]:
from utils import get_train_dataloader, get_test_dataloader

In [10]:
# `img_size` is the first argument of the dataset constructors below
# (presumably the image side length in pixels — confirm in `datasets`).
img_size = 256
# Mini-batch size used by both the train and the test DataLoader.
batch_size = 16

In [11]:
from torch.utils.data import Dataset, DataLoader
from datasets import LivestockTrainDataset
# Training set, built from `img_size`.
# NOTE(review): `fake_dataset_size` presumably sets the reported dataset length — confirm in `datasets`.
train_dataset = LivestockTrainDataset(
    img_size,
    fake_dataset_size=50000,
)
# Shuffled mini-batches; the cells below read the inputs from element 0 of each batch.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [12]:
@partial(jit, static_argnums=(1, 6))
def loss(params, static, states, x, key, beta, train=True):
    """
    Negative ELBO of the VAE, averaged over a mini-batch.

    Parameters
    ----------
    params
        Trainable half of the model (as produced by ``eqx.partition``); it is
        recombined with ``static`` inside the function. This is the argument
        to differentiate with respect to.
    static
        The other half of the partitioned model. It is a static ``jit``
        argument and therefore has to be hashable.
    states
        Model state (e.g. the one returned by ``eqx.nn.make_with_state`` or by
        a previous call to this function). It is passed unbatched to the model.
    x
        A batch of inputs; axis 0 is the batch axis that is vmapped over.
    key
        A JAX PRNGKey
    beta
        Forwarded unchanged as the last argument of ``vae.elbo``
        (presumably the weight of the KL term -- confirm in ``vae_jax``).
    train
        Static flag. ``True`` puts the model in training mode, ``False`` in
        inference mode; it is also forwarded to the model call.

    Returns
    -------
    tuple
        ``(-elbo, (x_rec, rec_term, kld, states))`` where ``elbo``,
        ``rec_term`` and ``kld`` are means over the batch, ``x_rec`` is the
        batched first output of the model and ``states`` is the updated state.
    """
    vae = eqx.combine(params, static)
    # `train` is static, so the mode is fixed at trace time:
    # training mode <=> inference flag set to False.
    vae = eqx.nn.inference_mode(vae, value=not train)

    # Map the model over the batch axis of `x` only. The state comes back unbatched
    # (out_axes=None); the axis name "batch" is needed by the BatchNorm layers.
    batched_vae = vmap(vae, in_axes=(0, None, None, None), out_axes=(0, None, 0, 0), axis_name="batch")

    # Only the first half of the split is used; the split itself is kept so that
    # results for a given `key` stay the same as before.
    # NOTE(review): with in_axes=None every example of the batch receives this same key.
    noise_key, _ = jax.random.split(key, 2)

    x_rec, states, mu, logvar = batched_vae(x, states, noise_key, train)

    # Per-example ELBO terms; `beta` is shared by all examples.
    batched_elbo = vmap(vae.elbo, in_axes=(0, 0, 0, 0, None), out_axes=(0, 0, 0))
    elbo, rec_term, kld = batched_elbo(x_rec, x, mu, logvar, beta)

    # Average over the batch.
    elbo = jnp.mean(elbo)
    rec_term = jnp.mean(rec_term)
    kld = jnp.mean(kld)

    # `x_rec` is returned exactly as produced by the model (no post-processing).
    return -elbo, (x_rec, rec_term, kld, states)

Test the loss on one mini-batch

In [13]:
# Evaluate the loss once on a single mini-batch (training mode, beta = 1).
mini_batch = next(iter(train_dataloader))
# Use a fresh sub-key for this call instead of consuming the root key directly.
key, subkey = jax.random.split(key)
# NOTE: `vae_states` is rebound to the state returned by this forward pass.
loss_value, (_, _, _, vae_states) = loss(init_vae_params, vae_static, vae_states, mini_batch[0].numpy(), subkey, 1., train=True)

State(
  0x87f3c8=bool[],
  0x87f3e8=(f32[64], f32[64]),
  0x87f408=bool[],
  0x87f428=(f32[64], f32[64]),
  0x87f448=bool[],
  0x87f468=(f32[64], f32[64]),
  0x87f488=bool[],
  0x87f4a8=(f32[64], f32[64]),
  0x87f4c8=bool[],
  0x87f4e8=(f32[64], f32[64]),
  0x87f508=bool[],
  0x87f528=(f32[128], f32[128]),
  0x87f548=bool[],
  0x87f568=(f32[128], f32[128]),
  0x87f588=bool[],
  0x87f5a8=(f32[128], f32[128]),
  0x87f5c8=bool[],
  0x87f5e8=(f32[128], f32[128]),
  0x87f608=bool[],
  0x87f628=(f32[128], f32[128]),
  0x87f648=bool[],
  0x87f668=(f32[256], f32[256]),
  0x87f688=bool[],
  0x87f6a8=(f32[256], f32[256]),
  0x87f6c8=bool[],
  0x87f6e8=(f32[256], f32[256]),
  0x87f708=bool[],
  0x87f728=(f32[256], f32[256]),
  0x87f748=bool[],
  0x87f768=(f32[256], f32[256]),
  0x87f788=bool[],
  0x87f7a8=(f32[512], f32[512]),
  0x87f7c8=bool[],
  0x87f7e8=(f32[512], f32[512]),
  0x87f808=bool[],
  0x87f828=(f32[512], f32[512]),
  0x87f848=bool[],
  0x87f868=(f32[512], f32[512]),
  0x87f888=bool

2024-06-04 13:27:05.980566: W external/tsl/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.02GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-06-04 13:27:06.994681: W external/tsl/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.02GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-06-04 13:27:11.164617: W external/tsl/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 10.02GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-06-04 13:27:13.725578: E external/xla/xla/service/slow_operation_alarm.cc:65] T

In [14]:
# NOTE(review): `copy` is not used anywhere in the visible part of the notebook.
import copy
# Short aliases used by the training and evaluation cells.
init_params = init_vae_params
# `vae_states` was rebound by the single-batch loss evaluation above, so this is the
# state returned by that call rather than the freshly constructed one.
init_states = vae_states
static = vae_static

In [15]:
# Rebuild the full model from its two partitions (rebinds the notebook-level `vae`).
vae = eqx.combine(init_params, static)

In [16]:
import optax

# Training hyper-parameters consumed by the `train(...)` call further down.
n_steps = 100
learning_rate = 1e-3
print_every = 1
# Passed to `train` as `beta` and forwarded from there to `loss`.
beta = jnp.array(0.00001)

# Adam optimiser; its state is initialised from the trainable partition of the model.
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(init_params)

# Linear schedule from 0.1 to 10 over `transition_steps` steps, wrapped as an optax
# transformation that `train` applies to `beta` while the step index is below
# `transition_steps`.
# NOTE(review): the visible `train(...)` call passes `beta_sche_state=None`, so this
# schedule is built but not used there — confirm whether that is intended.
transition_steps = 100
schedule_fn = optax.linear_schedule(init_value=0.1, end_value=10, transition_steps=transition_steps)
beta_scheduler = optax.scale_by_schedule(schedule_fn)
beta_sche_state = beta_scheduler.init(beta)

In [17]:
def train(n_steps, params, static, states, train_loader, opt_state, beta, key, print_every, beta_sche_state=None):
    """
    Run ``n_steps`` optimisation steps on the negative ELBO.

    The function relies on the notebook-level names ``loss``, ``optimizer``,
    ``beta_scheduler`` and ``transition_steps``.

    Parameters
    ----------
    n_steps
        Number of optimisation steps.
    params
        Trainable partition of the model; updated at every step.
    static
        Static partition of the model, forwarded to ``loss``.
    states
        Model state, threaded through the successive calls to ``loss``.
    train_loader
        Iterable yielding ``(x, _)`` pairs where ``x`` has a ``.numpy()``
        method. It is iterated again from the start whenever it is exhausted.
    opt_state
        State of the notebook-level ``optimizer``.
    beta
        Value forwarded to ``loss`` as ``beta``.
    key
        A JAX PRNGKey; it is split once per step.
    print_every
        Progress is printed every ``print_every`` steps and at the last step.
    beta_sche_state
        State of the notebook-level ``beta_scheduler``, or ``None`` (default)
        to keep ``beta`` constant.

    Returns
    -------
    tuple
        ``(params, states, (elbo_list, rec_term_list, kld_list), opt_state)``
        with one list entry per step.

    Raises
    ------
    ValueError
        If ``train_loader`` yields no batch at all.
    """

    def infinite_train_loader():
        # Cycle through the loader; fail loudly instead of spinning forever if it is empty.
        while True:
            is_empty = True
            for batch in train_loader:
                is_empty = False
                yield batch
            if is_empty:
                raise ValueError("train_loader yielded no batches")

    @jax.jit
    def _grad_step(params, states, opt_state, x, key, beta):
        # Loss, gradient and optimiser update are compiled as one program instead of
        # differentiating the jitted loss eagerly and dispatching the update op by op.
        # The reconstruction returned by `loss` is not needed here and is dropped.
        (minus_elbo, aux_loss), grads = jax.value_and_grad(loss, has_aux=True)(
            params, static, states, x, key, beta, True
        )
        _, rec_term, kld, states = aux_loss
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, states, opt_state, (minus_elbo, rec_term, kld)

    def make_step(x, params, states, opt_state, beta, beta_sche_state, i, key):
        key, subkey = jax.random.split(key)
        if beta_sche_state is not None and i < transition_steps:
            # NOTE(review): `apply_updates` adds the scheduled value to the current
            # beta, so beta accumulates over steps — confirm this is intended.
            updates, beta_sche_state = beta_scheduler.update(jnp.ones_like(beta), beta_sche_state, beta)
            beta = optax.apply_updates(beta, updates)
        params, states, opt_state, losses = _grad_step(params, states, opt_state, x, subkey, beta)
        # The scheduler state is returned as well: without it the schedule would be
        # evaluated at its first step forever.
        return params, states, opt_state, key, losses, beta, beta_sche_state

    elbo_list = []
    rec_term_list = []
    kld_list = []

    beta_list = []

    # `range` comes first in the zip, so no extra batch is pulled after the last step.
    for step, (x, _) in zip(range(n_steps), infinite_train_loader()):
        x = x.numpy()

        params, states, opt_state, key, losses, beta, beta_sche_state = make_step(
            x, params, states, opt_state, beta, beta_sche_state, step, key
        )
        elbo_list.append(-losses[0])
        rec_term_list.append(losses[1])
        kld_list.append(losses[2])
        beta_list.append(beta)

        if (step % print_every) == 0 or (step == n_steps - 1):
            print(
                f"{step=}, elbo_loss={elbo_list[-1]}, rec_term={rec_term_list[-1]}, kld_term={kld_list[-1]}, beta={beta_list[-1]}"
            )

    return params, states, (elbo_list, rec_term_list, kld_list), opt_state

In [18]:
# Give `train` its own sub-key. Passing the parent `key` would let the next
# `jax.random.split(key)` in the notebook reproduce the first training-step sub-key.
key, subkey = jax.random.split(key)
# NOTE(review): `beta_sche_state=None` keeps beta constant; the schedule state built
# in the configuration cell is not used by this call.
# NOTE(review): `opt_state` is rebound here, so re-running this cell restarts from
# `init_params` with an already advanced optimiser state.
final_params, final_states, loss_lists, opt_state = train(
    n_steps, init_params, static, init_states, train_dataloader, opt_state, beta, subkey, print_every=print_every, beta_sche_state=None
)

State(
  0x87f3c8=bool[],
  0x87f3e8=(f32[64], f32[64]),
  0x87f408=bool[],
  0x87f428=(f32[64], f32[64]),
  0x87f448=bool[],
  0x87f468=(f32[64], f32[64]),
  0x87f488=bool[],
  0x87f4a8=(f32[64], f32[64]),
  0x87f4c8=bool[],
  0x87f4e8=(f32[64], f32[64]),
  0x87f508=bool[],
  0x87f528=(f32[128], f32[128]),
  0x87f548=bool[],
  0x87f568=(f32[128], f32[128]),
  0x87f588=bool[],
  0x87f5a8=(f32[128], f32[128]),
  0x87f5c8=bool[],
  0x87f5e8=(f32[128], f32[128]),
  0x87f608=bool[],
  0x87f628=(f32[128], f32[128]),
  0x87f648=bool[],
  0x87f668=(f32[256], f32[256]),
  0x87f688=bool[],
  0x87f6a8=(f32[256], f32[256]),
  0x87f6c8=bool[],
  0x87f6e8=(f32[256], f32[256]),
  0x87f708=bool[],
  0x87f728=(f32[256], f32[256]),
  0x87f748=bool[],
  0x87f768=(f32[256], f32[256]),
  0x87f788=bool[],
  0x87f7a8=(f32[512], f32[512]),
  0x87f7c8=bool[],
  0x87f7e8=(f32[512], f32[512]),
  0x87f808=bool[],
  0x87f828=(f32[512], f32[512]),
  0x87f848=bool[],
  0x87f868=(f32[512], f32[512]),
  0x87f888=bool

2024-06-04 13:27:31.331088: W external/xla/xla/service/hlo_rematerialization.cc:2948] Can't reduce memory use below -1.35GiB (-1452923329 bytes) by rematerialization; only reduced to 6.43GiB (6904438181 bytes), down from 6.43GiB (6904438181 bytes) originally
2024-06-04 13:27:41.906927: W external/tsl/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 80.00MiB (rounded to 83886080)requested by op 
2024-06-04 13:27:41.907715: W external/tsl/tsl/framework/bfc_allocator.cc:494] ***************************************************************************************************x
E0604 13:27:41.908344   57459 pjrt_stream_executor_client.cc:2826] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 83886080 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   15.62MiB
              constant allocation:         1B
        maybe_live_out allocation:    5.03GiB
     prea

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 83886080 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   15.62MiB
              constant allocation:         1B
        maybe_live_out allocation:    5.03GiB
     preallocated temp allocation:    1.28GiB
  preallocated temp fragmentation:         0B (0.00%)
                 total allocation:    6.33GiB
              total fragmentation: 1003.42MiB (15.49%)
Peak buffers:
	Buffer 1:
		Size: 1.28GiB
		Operator: op_name="jit(loss)/jit(main)/vmap(vmap(eqx.nn.ConvTranspose))/reshape[new_sizes=(160, 32, 128, 128) dimensions=None]" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=88
		XLA Label: fusion
		Shape: f32[160,32,259,259]
		==========================

	Buffer 2:
		Size: 320.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(vmap(jit(relu)))/max" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=87
		XLA Label: fusion
		Shape: f32[16,10,32,128,128]
		==========================

	Buffer 3:
		Size: 320.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(vmap(jit(relu)))/max" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=87
		XLA Label: fusion
		Shape: f32[160,32,128,128]
		==========================

	Buffer 4:
		Size: 320.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(vmap(jit(relu)))/max" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=87
		XLA Label: fusion
		Shape: f32[16,10,32,128,128]
		==========================

	Buffer 5:
		Size: 320.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(vmap(jit(relu)))/max" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=87
		XLA Label: fusion
		Shape: f32[16,10,32,128,128]
		==========================

	Buffer 6:
		Size: 320.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(vmap(eqx.nn.BatchNorm))/sub" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=86
		XLA Label: fusion
		Shape: f32[160,32,128,128]
		==========================

	Buffer 7:
		Size: 320.00MiB
		XLA Label: fusion
		Shape: f32[16,10,32,128,128]
		==========================

	Buffer 8:
		Size: 160.00MiB
		Operator: op_name="jit(loss)/jit(main)/add" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=98
		XLA Label: fusion
		Shape: f32[16,10,256,32,32]
		==========================

	Buffer 9:
		Size: 160.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(vmap(jit(relu)))/max" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=84
		XLA Label: fusion
		Shape: f32[16,10,64,64,64]
		==========================

	Buffer 10:
		Size: 160.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(vmap(jit(relu)))/max" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=84
		XLA Label: fusion
		Shape: f32[160,64,64,64]
		==========================

	Buffer 11:
		Size: 160.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(vmap(jit(relu)))/max" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=84
		XLA Label: fusion
		Shape: f32[16,10,64,64,64]
		==========================

	Buffer 12:
		Size: 160.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(vmap(jit(relu)))/max" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=84
		XLA Label: fusion
		Shape: f32[16,10,64,64,64]
		==========================

	Buffer 13:
		Size: 160.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(vmap(eqx.nn.BatchNorm))/sub" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=83
		XLA Label: fusion
		Shape: f32[160,64,64,64]
		==========================

	Buffer 14:
		Size: 160.00MiB
		XLA Label: fusion
		Shape: f32[16,10,64,64,64]
		==========================

	Buffer 15:
		Size: 120.00MiB
		Operator: op_name="jit(loss)/jit(main)/reduce_sum[axes=(1, 2, 3, 4)]" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=180
		XLA Label: fusion
		Shape: f32[16,10,3,256,256]
		==========================



In [None]:
# Training curves, one value per optimisation step.
fig, ax = plt.subplots()
ax.plot(loss_lists[0], label="elbo")
ax.plot(loss_lists[1], label="rec_term")
ax.plot(loss_lists[2], label="kld")
ax.set_xlabel("training step")
ax.set_title("Training curves")
ax.legend()
plt.show()

In [None]:
from datasets import LivestockTestDataset
# Test set, built from the same `img_size` as the training set.
# NOTE(review): `fake_dataset_size` presumably sets the reported dataset length — confirm in `datasets`.
test_dataset = LivestockTestDataset(
    img_size,
    fake_dataset_size=1024,
)
# No shuffling here: batches come in the loader's default order.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Inputs of the first test batch, converted to a NumPy array.
x_test = next(iter(test_dataloader))[0].numpy()
# Fresh sub-key for the evaluation pass.
key, subkey = jax.random.split(key)
# Inference-mode forward pass with the trained parameters and state.
_, aux_loss = loss(final_params, static, final_states, x_test, subkey, beta=beta, train=False)
# First auxiliary output of `loss` is the batch of reconstructions.
x_rec_test = aux_loss[0]

In [None]:
# Top row: test inputs; bottom row: their reconstructions. All columns of the grid
# are filled (previously only column 0 was drawn and the other axes stayed empty).
# NOTE(review): assumes x_rec_test[i] is channel-first like x_test[i] — confirm.
n_cols = 4
figure, axes = plt.subplots(2, n_cols, figsize=(3 * n_cols, 6))
for col in range(n_cols):
    for ax in axes[:, col]:
        ax.set_axis_off()
    if col >= len(x_test):
        # Fewer images than columns: leave the remaining panels blank.
        continue
    # Move the channel axis last, as expected by imshow.
    axes[0, col].imshow(jnp.moveaxis(x_test[col], 0, 2))
    axes[0, col].set_title(f"input {col}")
    axes[1, col].imshow(jnp.moveaxis(x_rec_test[col], 0, 2))
    axes[1, col].set_title(f"reconstruction {col}")
plt.show()