In [17]:
import os

# Select the JAX backend: use 'cuda' for GPU, 'cpu' otherwise.
# NOTE(review): this assignment is placed before `import jax` below; presumably
# JAX reads the variable when it initializes, so keep this ordering — confirm.
os.environ["JAX_PLATFORMS"] = "cuda"

from datasets import load_dataset

import grain

import jax 
from jax import numpy as jnp

from flax import nnx

from gensbi.models.autoencoders import AutoEncoder1D, AutoEncoderParams, vae_loss_fn

import optax

from tqdm import tqdm

In [2]:
# Hugging Face Hub repository holding the benchmarks, and the task to fetch from it.
repo_name = "aurelio-amerio/SBI-benchmarks"
task_name = "gravitational_waves"

# Load the task with an explicit cache directory and request NumPy-formatted
# output from the resulting dataset object.
# NOTE: the cache path below is machine-specific; omit `cache_dir` to fall back
# to the library's default cache location.
dataset = load_dataset(
    repo_name,
    task_name,
    cache_dir="/data/users/.cache",
).with_format("numpy")


In [3]:
# Pick the "train" split out of the loaded dataset; it is the data source for
# the input pipeline built below.
df_train = dataset["train"]

In [4]:
def get_obs(batch):
    """Extract the observations from a batch as a bfloat16 JAX array.

    Args:
        batch: Mapping with an ``"xs"`` entry holding array-like data.

    Returns:
        A ``jax.Array`` built from ``batch["xs"]`` with dtype ``jnp.bfloat16``.
    """
    observations = batch["xs"]
    return jnp.array(observations, dtype=jnp.bfloat16)

In [5]:
# Input pipeline: training split -> shuffled, endlessly repeated stream of
# 128-element batches, each converted by `get_obs`.
dataset_grain = (
    grain.MapDataset.source(df_train)  # wrap the training split as a Grain source
    .shuffle(42)                       # shuffle, passing 42 as the seed
    .repeat()                          # no epoch count given
    .to_iter_dataset()                 # switch from map-style to iterator-style
    .batch(128)                        # group elements into batches of 128
    .map(get_obs)                      # apply `get_obs` to every batch
)

In [6]:
# Create the iterator that the rest of the notebook draws batches from.
dset_iter = iter(dataset_grain)
# Pull one batch up front (handy for inspecting what the pipeline yields).
batch=next(dset_iter)

In [7]:
# Time how long fetching one batch takes. Every timed call advances `dset_iter`,
# the same iterator the training loop uses later.
%timeit next(dset_iter)

219 ms ± 22.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
# Hyper-parameters for the 1D autoencoder.
ae_params = AutoEncoderParams(
    resolution=8192,
    in_channels=2,
    ch=32,
    out_ch=2,
    # Channel multipliers, one entry per stage. The original annotated these
    # entries with 4096, 2048, 1024, 512, 256, 128 — presumably the sequence
    # length after each stage (TODO confirm). Up to five further `16` entries
    # (annotated 64, 32, 16, 8, 4) were left disabled.
    ch_mult=[1, 2, 4, 8, 16, 16],
    num_res_blocks=1,
    z_channels=128,
    scale_factor=0.3611,
    shift_factor=0.1159,
    rngs=nnx.Rngs(42),
    param_dtype=jnp.bfloat16,
)

In [9]:
# Instantiate the 1D autoencoder from the parameter set defined above.
ae_model = AutoEncoder1D(ae_params)

In [10]:
@nnx.jit
def train_step(model: AutoEncoder1D, optimizer: nnx.Optimizer, x: jax.Array):
    """Run one optimization step on a batch and return the loss.

    Args:
        model: Autoencoder whose loss is evaluated and differentiated.
        optimizer: Optimizer that receives the model and its gradients.
        x: Input batch passed to ``vae_loss_fn`` together with the model.

    Returns:
        The loss value computed by ``vae_loss_fn(model, x)``.
    """
    loss_and_grad_fn = nnx.value_and_grad(vae_loss_fn)
    loss_value, gradients = loss_and_grad_fn(model, x)

    # Hand the gradients to the optimizer so it can update the model.
    optimizer.update(model, gradients)
    return loss_value

In [None]:
# Wrap an optax AdamW transformation (learning rate 1e-3) in an nnx.Optimizer
# for `ae_model`; `wrt=nnx.Param` names the variable type the optimizer targets.
optimizer = nnx.Optimizer(ae_model, optax.adamw(1e-3), wrt=nnx.Param)



In [19]:
# Train for 10,000 steps, showing a smoothed loss in the progress bar.
pbar = tqdm(range(10_000))
l_train = None  # exponential moving average of the loss; set on the first step

for step in pbar:
    batch = next(dset_iter)
    loss = train_step(ae_model, optimizer, batch)

    # The first loss seeds the average; afterwards blend 90% old with 10% new.
    l_train = loss if step == 0 else 0.9 * l_train + 0.1 * loss

    # Only refresh the progress-bar readout on every 10th step (never on step 0).
    if step == 0 or step % 10 != 0:
        continue
    pbar.set_postfix(loss=f"{l_train:.4f}")

  3%|▎         | 317/10000 [03:59<2:01:45,  1.33it/s, loss=-1388.4187]


KeyboardInterrupt: 