In [3]:
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise).
# This assignment deliberately sits ahead of `import jax` further down, so the
# variable is already in the environment when JAX is imported.
os.environ["JAX_PLATFORMS"] = "cpu"

from datasets import load_dataset

import grain

import jax 
from jax import numpy as jnp

from flax import nnx

from gensbi.models.autoencoders import AutoEncoder1D, AutoEncoderParams, vae_loss_fn

import optax

In [4]:
# Dataset repository identifier and the configuration (task) requested from it.
repo_name = "aurelio-amerio/SBI-benchmarks"
task_name = "gravitational_waves"

# Load the requested configuration and ask for NumPy-formatted output.
dataset = load_dataset(path=repo_name, name=task_name).with_format(type="numpy")

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

In [5]:
# Select the "train" split; the grain pipeline further down uses it as its source.
# NOTE(review): despite the `df_` prefix this is whatever `dataset["train"]`
# returns, not necessarily a DataFrame — confirm before relying on the name.
df_train = dataset["train"]

In [6]:
def get_obs(batch):
    """Return the value stored under the ``"xs"`` key of ``batch``.

    Args:
        batch: a mapping-like object supporting ``batch["xs"]``.

    Returns:
        Whatever object is stored at ``batch["xs"]``, unchanged.

    Raises:
        KeyError: if ``batch`` is a dict that has no ``"xs"`` entry.
    """
    observations = batch["xs"]
    return observations

In [7]:
# Input pipeline over the training split. Each stage is applied in order:
dataset_grain = (
    grain.MapDataset.source(df_train)
    .shuffle(seed=42)        # fixed seed
    .repeat()                # no epoch count given — presumably loops indefinitely; confirm
    .to_iter_dataset()       # switch from the map-style to the iterator-style dataset
    .batch(batch_size=128)   # group 128 records per element
    .map(get_obs)            # keep only the "xs" entry of every batch
)

In [8]:
# Create an iterator over the pipeline and pull one batch from it.
# This advances `dset_iter`, so any later `next(dset_iter)` starts from the
# following batch.
dset_iter = iter(dataset_grain)
batch = next(dset_iter)

In [9]:
# Hyper-parameters handed to AutoEncoder1D.
# NOTE(review): the meaning of each field is defined by AutoEncoderParams in
# gensbi; the comments below only group them, they do not document semantics.
ae_params = AutoEncoderParams(
    # input / output layout
    resolution=8192,
    in_channels=2,
    ch=128,
    out_ch=2,
    # network depth and width
    ch_mult=[1, 2, 2, 2, 4],
    num_res_blocks=2,
    z_channels=32,
    # hard-coded constants — origin not shown in this notebook; TODO document
    scale_factor=0.3611,
    shift_factor=0.1159,
    # RNG seed and parameter precision
    rngs=nnx.Rngs(42),
    param_dtype=jnp.float32,
)

In [10]:
# Instantiate the autoencoder from the parameter object `ae_params`.
ae_model = AutoEncoder1D(ae_params)

In [12]:
@nnx.jit
def train_step(model: AutoEncoder1D, optimizer: nnx.Optimizer, x: jax.Array):
    """Perform one optimisation step on ``model`` and return the loss.

    Args:
        model: autoencoder passed to ``vae_loss_fn`` and updated by the optimizer.
        optimizer: optimizer whose ``update`` is called with the model and gradients.
        x: input batch forwarded to ``vae_loss_fn``.

    Returns:
        The loss from ``vae_loss_fn(model, x)``, evaluated before the
        optimizer update is applied.
    """
    loss_and_grad_fn = nnx.value_and_grad(vae_loss_fn)
    loss, grads = loss_and_grad_fn(model, x)
    optimizer.update(model, grads)
    return loss

In [13]:
# AdamW (learning rate 1e-3) wrapped in an NNX optimizer for `ae_model`,
# with `wrt=nnx.Param` selecting which variables it acts on.
optimizer = nnx.Optimizer(
    ae_model,
    optax.adamw(learning_rate=1e-3),
    wrt=nnx.Param,
)



In [None]:
# Run a single training step on the next batch from the iterator; as the
# last expression in the cell, the returned loss is what gets displayed.
train_step(ae_model, optimizer, next(dset_iter))