In [2]:
import os

# Backend selection must happen before `jax` is imported.
# Use "cuda" to run on a GPU, "cpu" otherwise.
JAX_PLATFORM = "cpu"
os.environ["JAX_PLATFORMS"] = JAX_PLATFORM

from datasets import load_dataset

import grain

import jax 
from jax import numpy as jnp

from flax import nnx

from gensbi.models.autoencoders import AutoEncoder1D, AutoEncoderParams, vae_loss_fn
from gensbi.models.autoencoders.commons import Loss

from gensbi.recipes import VAE1DPipeline

import optax

from tqdm import tqdm

In [3]:
# Hugging Face dataset repository and the benchmark task (configuration) to load.
repo_name = "aurelio-amerio/SBI-benchmarks"
task_name = "gravitational_waves"

# Cache location: honour HF_DATASETS_CACHE when it is set so the notebook runs
# on other machines; otherwise fall back to the path this notebook used before.
cache_dir = os.environ.get("HF_DATASETS_CACHE", "/data/users/.cache")

# Load every split and have indexing return NumPy arrays.
dataset = load_dataset(repo_name, task_name, cache_dir=cache_dir).with_format("numpy")


In [4]:
# Split handles used to build the training and validation input pipelines.
df_train, df_val = dataset["train"], dataset["validation"]

In [5]:
# Show the (rows, columns) shape of the training split.
df_train.shape

(89488, 2)

In [6]:
def get_obs(batch):
    """Pull the observations out of a batch and cast them to bfloat16.

    Args:
        batch: Mapping with an ``"xs"`` entry holding array-like data.

    Returns:
        A JAX array built from ``batch["xs"]`` with dtype ``jnp.bfloat16``.
    """
    observations = batch["xs"]
    return jnp.array(observations, dtype=jnp.bfloat16)

In [7]:
def make_batched_iter_dataset(split, batch_size=128, seed=42):
    """Build an endlessly repeating, shuffled, batched observation stream.

    Args:
        split: Random-access data source (one dataset split) to read from.
        batch_size: Number of examples grouped into each batch.
        seed: Seed passed to the shuffle step.

    Returns:
        A grain iterable dataset whose elements are the result of applying
        ``get_obs`` to each batch.
    """
    return (
        grain.MapDataset.source(split)
        .shuffle(seed)
        .repeat()
        .to_iter_dataset()
        .batch(batch_size)
        .map(get_obs)
    )


# Both splits go through the exact same pipeline, so they share one helper.
train_dataset = make_batched_iter_dataset(df_train)
val_dataset = make_batched_iter_dataset(df_val)

In [8]:
# One iterator per pipeline, used below to pull batches by hand.
train_iter, val_iter = iter(train_dataset), iter(val_dataset)

In [9]:
# Sanity-check the batch shapes. Each `next` call consumes one batch from its iterator.
next(train_iter).shape, next(val_iter).shape

((128, 8192, 2), (128, 8192, 2))

In [10]:
# Hyperparameters for the autoencoder, gathered in one place.
# The number after each `ch_mult` entry is the original author's annotation
# (presumably the sequence length at that level -- TODO confirm against the
# model). The author's notes also listed further optional `16` levels
# annotated 64, 32, 16, 8 and 4; they are not enabled here.
ae_config = dict(
    resolution=8192,
    in_channels=2,
    ch=32,
    out_ch=2,
    ch_mult=[
        1,   # 4096
        2,   # 2048
        4,   # 1024
        8,   # 512
        16,  # 256
        16,  # 128
    ],
    num_res_blocks=1,
    z_channels=128,
    scale_factor=0.3611,
    shift_factor=0.1159,
    rngs=nnx.Rngs(42),
    param_dtype=jnp.bfloat16,
)
ae_params = AutoEncoderParams(**ae_config)

In [11]:
# Wire the two data streams and the model hyperparameters into the VAE pipeline.
pipeline = VAE1DPipeline(train_dataset=train_dataset, val_dataset=val_dataset, params=ae_params)

In [None]:
# Collect every `Loss` variable currently stored on the model and display it.
# NOTE(review): this runs before any forward pass, so presumably nothing has been recorded yet — confirm.
losses = nnx.state(pipeline.model, Loss)
losses

(State({}), State({}))

In [13]:
# Keep the first 10 examples of the next validation batch as a small probe batch.
batch = next(val_iter)[:10]

In [14]:
# Switch the model to training mode with `update_KL=True`.
# NOTE(review): presumably this flag makes the forward pass record the KL term as a `Loss` variable — confirm in gensbi.
pipeline.model.train(update_KL=True)

In [15]:
# Forward pass of the untrained model on the probe batch; the returned array is displayed.
pipeline.model(batch)

Array([[[-0.271484, 0.00445557],
        [-0.136719, -0.263672],
        [-0.052002, -0.322266],
        ...,
        [0.0869141, 0.0537109],
        [0.200195, 0.00340271],
        [0.0668945, -0.00123596]],

       [[-0.298828, 0.275391],
        [-0.182617, 0.0159912],
        [0.0603027, -0.263672],
        ...,
        [0.206055, -0.0195312],
        [0.245117, -0.0483398],
        [0.108887, 0.0654297]],

       [[-0.263672, 0.0032196],
        [-0.0549316, -0.0693359],
        [0.045166, -0.164062],
        ...,
        [-0.00552368, 0.0805664],
        [0.188477, 0.124512],
        [0.0839844, 0.0130615]],

       ...,

       [[0.111816, -0.0266113],
        [0.269531, -0.185547],
        [0.455078, 0.0258789],
        ...,
        [0.195312, 0.057373],
        [0.130859, 0.15625],
        [0.235352, 0.102051]],

       [[0.00141144, 0.0249023],
        [-0.136719, -0.0878906],
        [-0.186523, -0.333984],
        ...,
        [0.0932617, -0.0976562],
        [0.431641, -0.

In [16]:
# Re-read the `Loss` variables now that a forward pass has been run in training mode.
losses = nnx.state(pipeline.model, Loss)
losses

State({
  'reg': {
    'kl_loss': Loss( # 1 (2 B)
      value=Array(0.296875, dtype=bfloat16)
    )
  }
})

In [17]:
# Flatten the loss state into its leaf arrays and add them up, starting from 0.0.
loss_leaves = jax.tree_util.tree_leaves(losses)
sum(loss_leaves, 0.0)

Array(0.296875, dtype=bfloat16)

In [18]:
# Short smoke-test run: the positional `2` matches the 2 iterations shown by the progress bar.
# `save_model=True` writes a checkpoint once training finishes.
pipeline.train(nnx.Rngs(0), 2, save_model=True)

100%|██████████| 2/2 [00:46<00:00, 23.22s/it]


Saved model to checkpoint


([], [])

In [22]:
# Reload the model from the checkpoint saved by the training cell.
pipeline.restore_model()





Restored model from checkpoint


In [20]:
# Run the same probe batch through the model again to compare with the pre-training output.
# NOTE(review): the execution counts show this cell was last run before the restore cell above; re-run the notebook top to bottom.
pipeline.model(batch)

Array([[[-0.02699093,  0.14120202],
        [-0.14797217,  0.0440407 ],
        [-0.13716231,  0.06781854],
        ...,
        [-0.16043693,  0.02135202],
        [-0.14379624,  0.00975631],
        [-0.02119375,  0.21261992]],

       [[-0.22071975,  0.10075741],
        [-0.33332667,  0.1215784 ],
        [-0.29402235,  0.11830261],
        ...,
        [-0.32479218,  0.0025952 ],
        [-0.32971165, -0.02891255],
        [-0.06240482,  0.3217828 ]],

       [[-0.07455317,  0.24592632],
        [-0.24471566,  0.12325878],
        [-0.25449184,  0.15765694],
        ...,
        [-0.21531966,  0.11880031],
        [-0.22289711,  0.05020047],
        [-0.00130046,  0.3634961 ]],

       ...,

       [[ 0.07507195,  0.23252553],
        [ 0.04847581,  0.14074957],
        [ 0.05756085,  0.14732203],
        ...,
        [ 0.00859745,  0.12317373],
        [-0.02417461,  0.08864112],
        [ 0.07095455,  0.22526038]],

       [[ 0.05196337,  0.17241219],
        [-0.02159689,  0.09

In [23]:
# Inspect the `Loss` variables once more after training and the latest forward pass.
losses = nnx.state(pipeline.model, Loss)
losses

State({
  'reg': {
    'kl_loss': Loss( # 1 (2 B)
      value=Array(0.511719, dtype=bfloat16)
    )
  }
})