In [1]:
import jax
from jax import jit, vmap
import torch
import jax.numpy as jnp
from functools import partial
import matplotlib.pyplot as plt

In [2]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="0.95" 

In [3]:
torch.cuda.is_available()

True

In [4]:
jax.devices()

[cuda(id=0)]

In [5]:
key = jax.random.PRNGKey(2)

2024-06-03 17:49:47.102242: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [6]:
import equinox as eqx

In [7]:
from vae_jax import VAE

In [8]:
vae, vae_states = eqx.nn.make_with_state(VAE)(256, key)
init_vae_params, vae_static = eqx.partition(vae, eqx.is_inexact_array)
# x = jnp.zeros((1, 3, 256, 256)) # we must have a batch_dim, the model can only be applied afetr vmap because of BN layer
# batch_model = jax.vmap(model, in_axes=(0, None, None, None), out_axes=(0, None), axis_name="batch")
# x, state = batch_model(x, state, key, True)

In [9]:
from utils import get_train_dataloader, get_test_dataloader

In [10]:
img_size = 256
batch_size = 16

In [11]:
from torch.utils.data import Dataset, DataLoader
from datasets import LivestockTrainDataset
train_dataset = LivestockTrainDataset(
    img_size,
    fake_dataset_size=50000,
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [12]:
@partial(jit, static_argnums=(1, 5, 6))
def loss(params, static, states, x, key, train=True, beta=1.):
    """
    Parameters
    ----------
    params
        XXX
    static
        XXX
    states
        XXX
    x
        A batch of inputs
    key
        A JAX PRNGKey
    """
    vae = eqx.combine(params, static)
    if train:
        # make sure we are in training mode
        vae = eqx.nn.inference_mode(vae, value=False)
    else:
        vae = eqx.nn.inference_mode(vae)
    batched_vae = vmap(vae, in_axes=(0, None, None, None), out_axes=(0, None, 0, 0), axis_name="batch")

    key, subkey = jax.random.split(key, 2)

    x_rec, state, mu, logvar = batched_vae(x, states, key, train)
    print(x_rec)
    batched_elbo = vmap(vae.elbo, in_axes=(0, 0, 0, 0, None), out_axes=(0, 0, 0))

    elbo, rec_term, kld = batched_elbo(x_rec, x, mu, logvar, beta)
    elbo = jnp.mean(elbo) # avg over the batches
    rec_term = jnp.mean(rec_term)
    kld = jnp.mean(kld)

    x_rec = VAE.mean_from_lambda(jnp.mean(x_rec, axis=1))

    return -elbo, (x_rec, rec_term, kld, state)

Test the loss on one mini-batch

In [13]:
mini_batch = next(iter(train_dataloader))
loss_value, (_, _, _, vae_states) = loss(init_vae_params, vae_static, vae_states, mini_batch[0].numpy(), key, train=True)

(50, 256, 32, 32)
Traced<ShapedArray(float32[16,50,3,256,256])>with<DynamicJaxprTrace(level=1/0)>


In [14]:
import copy
init_params = init_vae_params
init_states = vae_states
static = vae_static

In [15]:
vae = eqx.combine(init_params, static)

In [16]:
import optax

n_steps = 50
learning_rate = 1e-3
print_every = 1
beta = 1

optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(init_params)


In [17]:
def train(n_steps, params, static, states, train_loader, opt_state, key, print_every):

    def infinite_train_loader():
        while True:
            yield from train_loader

    def make_step(x, params, states, opt_state, key):
        key, subkey = jax.random.split(key)
        (minus_elbo, aux_loss), grads = jax.value_and_grad(loss, has_aux=True)(params, static, states, x, subkey, train=True, beta=beta)
        x_rec, rec_term, kld, states = aux_loss
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, states, opt_state, key, x_rec, (minus_elbo, rec_term, kld)

    elbo_list = []
    rec_term_list = []
    kld_list = []


    for step, (x, _) in zip(range(n_steps), infinite_train_loader()): 
        x = x.numpy()

        params, states, opt_state, key, x_rec, losses = make_step(x, params, states, opt_state, key)
        elbo_list.append(-losses[0])
        rec_term_list.append(losses[1])
        kld_list.append(losses[2])
        
        if (step % print_every) == 0 or (step == n_steps - 1):
            print(
                f"{step=}, elbo_loss={elbo_list[-1]}, rec_term={rec_term_list[-1]}, kld_term={kld_list[-1]}"
            )
        
            
    return params, states, (elbo_list, rec_term_list, kld_list), opt_state

In [18]:
key, subkey = jax.random.split(key)
final_params, final_states, loss_lists, opt_state = train(
    n_steps, init_params, static, init_states, train_dataloader, opt_state, key, print_every=print_every
)

(50, 256, 32, 32)
Traced<ShapedArray(float32[16,50,3,256,256])>with<DynamicJaxprTrace(level=3/0)>


2024-06-03 17:50:01.909979: W external/xla/xla/service/hlo_rematerialization.cc:2948] Can't reduce memory use below -18.49GiB (-19854978070 bytes) by rematerialization; only reduced to 24.63GiB (26451676305 bytes), down from 24.63GiB (26451676305 bytes) originally
2024-06-03 17:50:13.054486: W external/tsl/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.56GiB (rounded to 1677721600)requested by op 
2024-06-03 17:50:13.055550: W external/tsl/tsl/framework/bfc_allocator.cc:494] *******************************************************************************************_________
E0603 17:50:13.056539  155235 pjrt_stream_executor_client.cc:2826] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1677721600 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   15.65MiB
              constant allocation:         6B
        maybe_live_out allocation:   23.07GiB


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1677721600 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   15.65MiB
              constant allocation:         6B
        maybe_live_out allocation:   23.07GiB
     preallocated temp allocation:    18.8KiB
                 total allocation:   23.09GiB
Peak buffers:
	Buffer 1:
		Size: 1.56GiB
		XLA Label: fusion
		Shape: f32[50,16,32,128,128]
		==========================

	Buffer 2:
		Size: 1.56GiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=116 deduplicated_name="loop_dynamic_update_slice_fusion.4"
		XLA Label: fusion
		Shape: f32[50,16,32,128,128]
		==========================

	Buffer 3:
		Size: 1.56GiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=116
		XLA Label: fusion
		Shape: f32[50,16,32,128,128]
		==========================

	Buffer 4:
		Size: 1.56GiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=116
		XLA Label: fusion
		Shape: f32[50,16,32,128,128]
		==========================

	Buffer 5:
		Size: 1.56GiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=116 deduplicated_name="loop_dynamic_update_slice_fusion.4"
		XLA Label: fusion
		Shape: f32[50,16,32,128,128]
		==========================

	Buffer 6:
		Size: 800.00MiB
		XLA Label: fusion
		Shape: f32[50,16,64,64,64]
		==========================

	Buffer 7:
		Size: 800.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=116
		XLA Label: fusion
		Shape: f32[50,16,256,32,32]
		==========================

	Buffer 8:
		Size: 800.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=116 deduplicated_name="loop_dynamic_update_slice_fusion.12"
		XLA Label: fusion
		Shape: f32[50,16,64,64,64]
		==========================

	Buffer 9:
		Size: 800.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=116
		XLA Label: fusion
		Shape: f32[50,16,64,64,64]
		==========================

	Buffer 10:
		Size: 800.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=116
		XLA Label: fusion
		Shape: f32[50,16,64,64,64]
		==========================

	Buffer 11:
		Size: 800.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=116 deduplicated_name="loop_dynamic_update_slice_fusion.12"
		XLA Label: fusion
		Shape: f32[50,16,64,64,64]
		==========================

	Buffer 12:
		Size: 600.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=167
		XLA Label: fusion
		Shape: f32[50,16,3,256,256]
		==========================

	Buffer 13:
		Size: 600.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=167
		XLA Label: fusion
		Shape: f32[50,16,3,256,256]
		==========================

	Buffer 14:
		Size: 600.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=167
		XLA Label: fusion
		Shape: f32[50,16,3,256,256]
		==========================

	Buffer 15:
		Size: 600.00MiB
		Operator: op_name="jit(loss)/jit(main)/vmap(while)/body/dynamic_update_slice" source_file="/home/hugo/Documents/writings/vae_grf_isprs/vae_grf/vae_jax.py" source_line=116
		XLA Label: fusion
		Shape: f32[50,16,3,256,256]
		==========================



In [None]:
plt.plot(loss_lists[0], label="elbo")
plt.plot(loss_lists[1], label="rec_term")
plt.plot(loss_lists[2], label="kld")
plt.legend()
plt.show()

In [None]:
from datasets import LivestockTestDataset
test_dataset = LivestockTestDataset(
    img_size,
    fake_dataset_size=1024,
)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
x_test = next(iter(train_dataloader))
x_test =  x_test[0].numpy()
key, subkey = jax.random.split(key)
(_, aux_loss) = loss(final_params, static, final_states, x_test, subkey, train=False, beta=beta)
x_rec_test = aux_loss[0]
print(jnp.mean(x_rec_test[0,0,10:-10,10:-10]), jnp.mean(x_test[0,0,10:-10,10:-10]))
print(jnp.mean((x_rec_test[0,0,10:-10,10:-10]-x_test[0,0,10:-10,10:-10])**2))
print(x_rec_test[0, 0,10:-10,10:-10], x_test[0,0,10:-10,10:-10])

In [None]:
figure, axes = plt.subplots(2, 4)
axes[0, 0].imshow(jnp.moveaxis(x_test[0],0, 2))
axes[1, 0].imshow(jnp.moveaxis(x_rec_test[0], 0, 2))
plt.show()