**RAN ON GOOGLE COLAB**

In [2]:
!pip install ml-collections

Collecting ml-collections
  Downloading ml_collections-1.1.0-py3-none-any.whl.metadata (22 kB)
Downloading ml_collections-1.1.0-py3-none-any.whl (76 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/76.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ml-collections
Successfully installed ml-collections-1.1.0


In [7]:
# ====== Imports ======
import os, json, time, pickle, shutil
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, value_and_grad, vmap, random as jran
import flax
import flax.linen as nn
from flax.training import train_state
from flax.linen.initializers import glorot_normal, zeros, normal
import optax
import ml_collections
import orbax.checkpoint as ocp
jran = jax.random

# ====== NN definition ======
activation_fn = { "tanh": jnp.tanh, "sin": jnp.sin }

def _get_activation(s):
    if s in activation_fn:
        return activation_fn[s]
    raise NotImplementedError(f"Activation {s} not supported yet!")

def _weight_fact(init_fn, mean, stddev):
    def init(key, shape):
        key1, key2 = jran.split(key)
        w = init_fn(key1, shape)
        g = mean + normal(stddev)(key2, (shape[-1],))
        g = jnp.exp(g)
        v = w / g
        return g, v
    return init

class Dense(nn.Module):
    features: int
    kernel_init: callable = glorot_normal()
    bias_init: callable = zeros
    reparam: dict | None = None

    @nn.compact
    def __call__(self, x):
        if self.reparam is None:
            kernel = self.param("kernel", self.kernel_init, (x.shape[-1], self.features))
        elif self.reparam["type"] == "weight_fact":
            g, v = self.param(
                "kernel",
                _weight_fact(self.kernel_init, mean=self.reparam["mean"], stddev=self.reparam["stddev"]),
                (x.shape[-1], self.features),
            )
            kernel = g * v
        else:
            raise ValueError(f"Unknown reparam type: {self.reparam}")
        bias = self.param("bias", self.bias_init, (self.features,))
        return jnp.dot(x, kernel) + bias

class MLP(nn.Module):
    arch_name: str = "MLP"
    hidden_dim: tuple[int, ...] = (32, 16)
    out_dim: int = 1
    activation: str = "tanh"
    periodicity: dict | None = None
    fourier_emb: dict | None = None
    reparam: dict | None = None

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        for h in self.hidden_dim:
            x = Dense(features=h, reparam=self.reparam)(x)
            x = self.activation_fn(x)
        x = Dense(features=self.out_dim, reparam=self.reparam)(x)
        return x

def ann_gen(config):
    reparam = None
    if getattr(config, "ann_reparam", False) == "weight_fact":
        reparam = ml_collections.ConfigDict({"type": "weight_fact", "mean": 0.5, "stddev": 0.1})
    if config.ann_str == "MLP":
        return MLP(
            arch_name=config.ann_str,
            hidden_dim=tuple(config.ann_hidden_dim),
            out_dim=config.ann_out_dim,
            activation=config.ann_activation_str,
            periodicity=config.ann_periodicity,
            fourier_emb=config.ann_fourier_emb,
            reparam=reparam,
        )
    raise NotImplementedError(f"Unknown arch {config.ann_str}")

# ====== Derivatives ======
def derivatives_2d(fn2d, x):
    def f(z): return jnp.squeeze(fn2d(z))
    dx = vmap(grad(f), 0)(x)
    Gamma_tau = dx[:, 0]
    Gamma_X   = dx[:, 1]
    return Gamma_tau, Gamma_X

def derivatives_3d(fn3d, x):
    def f(z): return jnp.squeeze(fn3d(z))
    def f_dx(z): return grad(f)(z)[2]
    dx  = vmap(grad(f), 0)(x)
    d2x = vmap(grad(f_dx), 0)(x)
    Gamma_tau = dx[:, 0]
    Gamma_X   = dx[:, 1]
    Gamma_S   = dx[:, 2]
    Gamma_SS  = d2x[:, 2]
    return Gamma_tau, Gamma_X, Gamma_S, Gamma_SS

# ====== Data samplers ======
def get_data_2d(n_pde=10000, n_bc=500, n_term=100, T=5.0, x_range=(-10.,10.), seed=0):
    rng = np.random.default_rng(seed)
    X_pde = rng.uniform(x_range[0], x_range[1], size=(n_pde, 1))
    tau   = rng.uniform(0, T, size=(n_pde, 1))
    x_tau_pde = np.hstack([tau, X_pde])
    # x=0 boundary along tau
    tau_bc = rng.uniform(0, T, size=(n_bc, 1))
    x_bc   = np.zeros((n_bc, 1))
    x_bc_zero = np.hstack([tau_bc, x_bc])
    # terminal condition samples at tau=0
    tau_term = np.zeros((n_term, 1))
    x_term = rng.uniform(x_range[0], x_range[1], size=(n_term, 1))
    x_term_cond = np.hstack([tau_term, x_term])
    return jnp.array(x_tau_pde), jnp.array(x_bc_zero), jnp.array(x_term_cond), T

def get_data_3d(n_pde=10000, n_bc=1000, n_term=100, T=5.0, x_range=(-10.,10.), s_range=(10,100), seed=0):
    rng = np.random.default_rng(seed)
    X_pde = rng.uniform(x_range[0], x_range[1], size=(n_pde, 1))
    S_pde = rng.uniform(s_range[0], s_range[1], size=(n_pde, 1))
    tau   = rng.uniform(0, T, size=(n_pde, 1))
    x_tau_pde = np.hstack([tau, X_pde, S_pde])
    # IC
    tau_bc = rng.uniform(0, T, size=(n_bc, 1))
    x_bc   = np.zeros((n_bc, 1))
    S_bc   = rng.uniform(s_range[0], s_range[1], size=(n_bc, 1))
    x_bc_zero = np.hstack([tau_bc, x_bc, S_bc])
    # Terminal
    tau_term = np.zeros((n_term, 1))
    S_term   = rng.uniform(s_range[0], s_range[1], size=(n_term, 1))
    x_term   = rng.uniform(x_range[0], x_range[1], size=(n_term, 1))
    x_term_cond = np.hstack([tau_term, x_term, S_term])
    return jnp.array(x_tau_pde), jnp.array(x_bc_zero), jnp.array(x_term_cond), T

# ====== Error terms ======
def error_2d(config, fn2d, data):
    x_tau_pde, x_bc_zero, x_term_cond, T = data
    Gamma_tau, Gamma_X = derivatives_2d(fn2d, x_tau_pde)
    X_vals = x_tau_pde[:,1]
    kappa = float(config.kappa)

    # PDE
    e_pde = (Gamma_tau - (kappa**2 * X_vals**2) + 0.25 * Gamma_X**2)**2

    # IC
    e_bc_zero = fn2d(x_bc_zero).squeeze()**2

    # Symmetry 
    z = x_tau_pde
    z_flip = z.at[:, 1].set(-z[:, 1])
    e_symm = (jnp.squeeze(fn2d(z)) - jnp.squeeze(fn2d(z_flip)))**2

    # Terminal
    x = x_term_cond[:, 1]
    target = (x**2) * 1e2
    pred   = jnp.squeeze(fn2d(x_term_cond))
    e_term = (pred - target)**2

    metrics = {
        "e_pde": jnp.mean(e_pde),
        "e_bc_zero": jnp.mean(e_bc_zero),
        "e_symm": jnp.mean(e_symm),
        "e_term": jnp.mean(e_term),
    }
    return {"e_pde": e_pde, "e_bc_zero": e_bc_zero, "e_symm": e_symm, "e_term": e_term}, metrics

def error_3d(config, fn3d, data):
    x_tau_pde, x_bc_zero, x_term_cond, T = data
    Gamma_tau, Gamma_X, _, Gamma_SS = derivatives_3d(fn3d, x_tau_pde)
    X_vals = x_tau_pde[:,1]
    S_vals = x_tau_pde[:,2]

    kappa = float(config.kappa)
    sigma = float(config.sigma)
    lam   = float(config.lam)

    # PDE
    e_pde = (
        Gamma_tau
        - 0.5 * sigma**2 * S_vals**2 * Gamma_SS
        - (kappa**2 * X_vals**2 + lam * S_vals * X_vals)
        + 0.25 * Gamma_X**2
    )**2

    # IC
    e_bc_zero = jnp.maximum(jnp.squeeze(fn3d(x_bc_zero)) - 1e-6, 0.0)**2

    # Terminal
    x = x_term_cond[:, 1]
    target = (x**2) * 1e2
    pred   = jnp.squeeze(fn3d(x_term_cond))
    e_term = (pred - target)**2

    # Symmetry
    z = x_tau_pde
    z_flip = z.at[:, 1].set(-z[:, 1]).at[:, 2].set(-z[:, 2])
    e_symm = (jnp.squeeze(fn3d(z)) - jnp.squeeze(fn3d(z_flip)))**2

    metrics = {
        "e_pde": jnp.mean(e_pde),
        "e_bc_zero": jnp.mean(e_bc_zero),
        "e_term": jnp.mean(e_term),
        "e_symm": jnp.mean(e_symm),
    }
    return {"e_pde": e_pde, "e_bc_zero": e_bc_zero, "e_term": e_term, "e_symm": e_symm}, metrics

def error(config, fn, data):
    d = data[0].shape[1]   # 2 or 3
    if d == 2:
        return error_2d(config, fn, data)
    elif d == 3:
        return error_3d(config, fn, data)
    else:
        raise ValueError(f"Unexpected input dimension d={d}")

# ====== Loss assembly ======
def adj(loss, lw, m):  
    return lw * jnp.mean(m * loss)

ALL_COMPONENTS = ['e_pde', 'e_bc_zero', 'e_term', 'e_symm']

def make_loss_lb():
    def loss_fn(config, fn, data, l_ws, params_sa):
        err, metrics = error(config, fn, data)
        loss_terms = {}
        for k in ALL_COMPONENTS:
            if k in err:
                w_k = l_ws.get(k, 1.0)
                m_k = params_sa.get(k, jnp.ones_like(err[k]))
                loss_terms[k] = adj(err[k], w_k, m_k)
        return loss_terms, metrics
    return loss_fn

loss_fn_lb = { "MLP": make_loss_lb(), "PINN": make_loss_lb() }

def make_loss_fn(components_key):
    def loss_fn(config, fn, data, l_ws, params_sa):
        loss_terms, metrics = loss_fn_lb[components_key](config, fn, data, l_ws, params_sa)
        total = jnp.sum(jnp.array(list(loss_terms.values())))
        return total, metrics
    return loss_fn

loss_fn = { "MLP": make_loss_fn("MLP"), "PINN": make_loss_fn("PINN") }

def init_params_sa(loss_str, data):
    x_tau_pde, x_bc_zero, x_term_cond, _ = data
    return {
        "e_pde": jnp.ones(len(x_tau_pde)),
        "e_bc_zero": jnp.ones(len(x_bc_zero)),
        "e_symm": jnp.ones(len(x_tau_pde)),
        "e_term": jnp.ones(len(x_term_cond)),
    }

# Initial static weights
init_l_ws = {
    "PINN": {'e_pde': 1.0, 'e_bc_zero': 0.1, 'e_symm': 0.5, 'e_term': 1.0}
}

# ====== Rebalancer ======
class Rebalancer:
    def __init__(
        self, keys,
        ema_beta=0.95,
        power=0.2,
        w_clip=(0.5, 2.0),
        ema_floor=1e-4,
        allow_upweight_small=False,
        freeze_tol=1e-5,
        freeze_patience=10,
        base_weights=None,
        static_keys=None,
        relative_to_init=True,
        eps=1e-12,
    ):
        self.keys = list(keys)
        self.ema_beta = float(ema_beta)
        self.power = float(power)
        self.w_min, self.w_max = map(float, w_clip)
        self.ema_floor = float(ema_floor)
        self.allow_upweight_small = bool(allow_upweight_small)
        self.freeze_tol = float(freeze_tol)
        self.freeze_patience = int(freeze_patience)
        self.base_weights = dict(base_weights or {})
        self.static_keys = set(static_keys or ())
        self.relative_to_init = bool(relative_to_init)
        self.eps = float(eps)
        self.ema = {k: None for k in self.keys}
        self.init_ema = {k: None for k in self.keys}
        self.below_ctr = {k: 0 for k in self.keys}
        self.frozen = {k: False for k in self.keys}

    def _smooth_update(self, old, new):
        if old is None:
            return new
        return self.ema_beta * old + (1.0 - self.ema_beta) * new

    def step(self, metrics, weights):
        for k in self.keys:
            if k not in metrics:
                continue
            val = float(metrics[k])
            val = max(val, self.ema_floor)
            self.ema[k] = self._smooth_update(self.ema[k], val)
            if self.init_ema[k] is None:
                self.init_ema[k] = self.ema[k]
            if self.ema[k] < self.freeze_tol:
                self.below_ctr[k] += 1
            else:
                self.below_ctr[k] = 0
            if self.freeze_patience > 0 and self.below_ctr[k] >= self.freeze_patience:
                self.frozen[k] = True

        present = [k for k in self.keys
                   if (k in metrics and self.ema[k] is not None and not self.frozen[k] and k not in self.static_keys)]
        if not present:
            new_w = dict(weights)
            for k in self.keys:
                if self.frozen[k]:
                    new_w[k] = float(self.base_weights.get(k, new_w.get(k, 1.0)))
            return new_w

        def rel_ema(k):
            if not self.relative_to_init:
                return max(self.ema[k], self.ema_floor)
            denom = max(self.init_ema[k] or self.ema_floor, self.ema_floor)
            return max(self.ema[k] / denom, self.ema_floor)

        rels = {k: rel_ema(k) for k in present}
        gmean = float(np.exp(np.mean([np.log(max(rels[k], self.eps)) for k in present])))

        new_w = dict(weights)
        for k in self.keys:
            if self.frozen[k]:
                new_w[k] = float(self.base_weights.get(k, new_w.get(k, 1.0)))
                continue
            if k in self.static_keys or k not in rels:
                new_w[k] = float(new_w.get(k, 1.0))
                continue
            ratio = max(rels[k] / gmean, self.eps)
            mult = 1.0 if (not self.allow_upweight_small and ratio < 1.0) else ratio ** (-self.power)
            new_w[k] = float(np.clip(new_w.get(k, 1.0) * mult, self.w_min, self.w_max))

        mean_w = float(np.mean([new_w[k] for k in present])) if present else 1.0
        if mean_w > 0:
            for k in present:
                new_w[k] = new_w[k] / mean_w
        return new_w

# ====== Training ======
def calibration(config, data, params_init=None, l_ws=None, params_sa=None):
    ann = ann_gen(config)
    ofunc = loss_fn[config.loss_str]

    if l_ws is None:
        l_ws = dict(init_l_ws[config.loss_str])
    if params_sa is None:
        params_sa = init_params_sa(config.loss_str, data)

    # FIXED: rebal_cfg indentation + safe defaults
    rebal_cfg = getattr(config, "rebal", None)
    if rebal_cfg is None:
        rebal_cfg = ml_collections.ConfigDict({})

    rebal = Rebalancer(
        keys=[k for k in ALL_COMPONENTS if k in (params_sa.keys())],
        ema_beta=getattr(rebal_cfg, "ema_beta", 0.98),
        power=getattr(rebal_cfg, "power", 0.2),
        w_clip=tuple(getattr(rebal_cfg, "w_clip", (0.3, 3.0))),
        ema_floor=getattr(rebal_cfg, "ema_floor", 1e-3),
        allow_upweight_small=getattr(rebal_cfg, "allow_upweight_small", False),
        freeze_tol=getattr(rebal_cfg, "freeze_tol", 1e-4),
        freeze_patience=getattr(rebal_cfg, "freeze_patience", 10),
        base_weights=dict(getattr(rebal_cfg, "base_weights", {"e_bc_zero": 0.1})),
        static_keys=set(getattr(rebal_cfg, "static_keys", ())),
        relative_to_init=getattr(rebal_cfg, "relative_to_init", True),
    )

    key = jran.PRNGKey(config.seed)
    _, key_init = jran.split(key, 2)
    dummy = jnp.ones((1, config.ann_in_dim))
    params0 = ann.init(key_init, dummy) if params_init is None else params_init

    state = train_state.TrainState.create(
        apply_fn=ann.apply,
        params=params0,
        tx=optax.adamw(learning_rate=5e-4)
    )

    @jit
    def train_step(state, data, l_ws, params_sa):
        def wrapped_loss(params):
            def fn(x): return ann.apply(params, x)
            return ofunc(config, fn, data, l_ws, params_sa)
        (loss, metrics), grads = value_and_grad(wrapped_loss, has_aux=True)(state.params)
        state = state.apply_gradients(grads=grads)
        return state, loss, metrics

    hist = []
    start_epoch = getattr(rebal_cfg, "start_epoch", 0)
    every       = getattr(rebal_cfg, "every", 1000)

    for epoch in range(config.num_epochs):
        state, loss, metrics = train_step(state, data, l_ws, params_sa)
        if (epoch >= start_epoch) and ((epoch - start_epoch) % every == 0):
            l_ws = rebal.step(metrics, l_ws)
        if epoch % 100 == 0:
            w_str = "{"+", ".join(f"{k}:{l_ws.get(k,1.0):.3g}" for k in sorted(l_ws.keys()))+"}"
            keys = ['e_pde','e_bc_zero','e_term','e_symm']
            msg  = " | ".join(f"{k}={float(metrics[k]):.6f}" for k in keys if k in metrics)
            print(f"Epoch {epoch}: total={float(loss):.6f} | {msg} | w={w_str}")
        hist.append((int(epoch), float(loss), {k: float(metrics[k]) for k in metrics}))

    return (lambda x: ann.apply(state.params, x)), hist, state.params, ann

# ====== Saving helpers ======
def save_params_overwrite(path, params):
    path = os.path.abspath(path)
    if os.path.exists(path):
        shutil.rmtree(path)
    ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
    try:
        ckptr.save(path, params, force=True)
    except TypeError:
        ckptr.save(path, params)

def _to_jsonable(obj):
    if isinstance(obj, ml_collections.ConfigDict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, dict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    if isinstance(obj, (str, int, float, bool)) or obj is None:
        return obj
    return str(obj)

def save_stage(run_dir, tag, *, params, config, history):
    run_dir = os.path.abspath(run_dir)
    os.makedirs(run_dir, exist_ok=True)
    save_params_overwrite(os.path.join(run_dir, f"{tag}_params"), params)
    with open(os.path.join(run_dir, f"{tag}_config.json"), "w") as f:
        json.dump(_to_jsonable(config), f, indent=2)
    with open(os.path.join(run_dir, f"{tag}_history.pkl"), "wb") as f:
        pickle.dump(history, f)

def restore_params(run_dir, tag):
    ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
    return ckptr.restore(os.path.abspath(os.path.join(run_dir, f"{tag}_params")))

# ====== Single-path runner  ======
def run_experiment(config, workdir=None):
    wd = os.path.abspath(workdir or getattr(config, "workdir", "runs/pinn_oe"))
    os.makedirs(wd, exist_ok=True)

    if float(config.lam) == 0.0:
        # ---- 2D baseline ----
        config.dim = 2
        config.ann_in_dim = 2
        data = get_data_2d(
            n_pde=getattr(config, "pts_num", 30000),
            n_bc =getattr(config, "n_bc_2d", getattr(config, "n_bc", getattr(config, "n_term", 2000))),
            n_term=getattr(config, "n_term_2d", getattr(config, "n_term", 500)),
            T=config.T, x_range=config.x_range, seed=config.seed
        )
        _, hist, params, _ = calibration(config, data)
        tag = "2d_lam0"
    else:
        # ---- 3D full PDE ----
        config.dim = 3
        config.ann_in_dim = 3
        data = get_data_3d(
            n_pde=getattr(config, "pts_num", 30000),
            n_bc =getattr(config, "n_bc_zero", 2000),
            n_term=getattr(config, "n_term_zero", 200),
            T=config.T, x_range=config.x_range, s_range=config.s_range, seed=config.seed
        )
        _, hist, params, _ = calibration(config, data)
        tag = f"3d_lam{float(config.lam)}".replace(".", "p")

    save_stage(wd, tag, params=params, config=config, history=hist)
    return params, {tag: hist}

# ====== Model builder ======
def build_model_fn(config, params):
    ann = ann_gen(config)
    return jax.jit(lambda x: ann.apply(params, x))

# ====== Default config & run ======
T = 5.0
config = ml_collections.ConfigDict({
    "seed": 43,
    "loss_str": "PINN",
    "ann_str": "MLP",
    "ann_hidden_dim": (500, 500, 500),
    "ann_out_dim": 1,
    "ann_activation_str": "tanh",
    "ann_periodicity": None,
    "ann_fourier_emb": None,
    "ann_reparam": False,
    "num_epochs": 55000,
    "pts_num": 30000,
    "x_range": (-10.0, 10.0),
    "s_range": (10.0, 100.0),    # used only if lam != 0
    "T": T,
    "kappa": 0.1,
    "sigma": 0.1,                # used only if lam != 0
    "lam": 0.1,                  # 0 -> 2D; nonzero -> 3D
    "n_bc": 5000,            
    "n_term": 5000,           
    "workdir": os.path.abspath("runs/pinn_oe"),
    "rebal": ml_collections.ConfigDict({
        "start_epoch": 500,
        "every": 5000,
        "ema_beta": 0.9,
        "power": 0.3,
        "w_clip": (0.1, 2.0),
    }),
})

final_params, all_hist = run_experiment(config, workdir=config.workdir)
model_PINN = build_model_fn(config, final_params)
print("Done. Checkpoints in:", config.workdir)


Epoch 0: total=20568564.000000 | e_pde=1233.092896 | e_bc_zero=0.766778 | e_term=20567328.000000 | e_symm=2.800812 | w={e_bc_zero:0.1, e_pde:1, e_symm:0.5, e_term:1}
Epoch 100: total=20248744.000000 | e_pde=1232.901611 | e_bc_zero=2579.071533 | e_term=20242106.000000 | e_symm=10295.749023 | w={e_bc_zero:0.1, e_pde:1, e_symm:0.5, e_term:1}
Epoch 200: total=20093874.000000 | e_pde=1232.901855 | e_bc_zero=5761.765137 | e_term=20080570.000000 | e_symm=22986.324219 | w={e_bc_zero:0.1, e_pde:1, e_symm:0.5, e_term:1}
Epoch 300: total=19945880.000000 | e_pde=1232.901855 | e_bc_zero=10105.311523 | e_term=19923486.000000 | e_symm=40301.117188 | w={e_bc_zero:0.1, e_pde:1, e_symm:0.5, e_term:1}
Epoch 400: total=19803792.000000 | e_pde=1232.901978 | e_bc_zero=15568.184570 | e_term=19769964.000000 | e_symm=62074.515625 | w={e_bc_zero:0.1, e_pde:1, e_symm:0.5, e_term:1}
Epoch 500: total=19667174.000000 | e_pde=1232.901855 | e_bc_zero=22109.556641 | e_term=19619658.000000 | e_symm=88143.546875 | w={e_

In [8]:
!zip -r /content/runs.zip /content/runs


  adding: content/runs/ (stored 0%)
  adding: content/runs/pinn_oe/ (stored 0%)
  adding: content/runs/pinn_oe/3d_lam0p1_history.pkl (deflated 62%)
  adding: content/runs/pinn_oe/3d_lam0p1_config.json (deflated 53%)
  adding: content/runs/pinn_oe/3d_lam0p1_params/ (stored 0%)
  adding: content/runs/pinn_oe/3d_lam0p1_params/array_metadatas/ (stored 0%)
  adding: content/runs/pinn_oe/3d_lam0p1_params/array_metadatas/process_0 (deflated 84%)
  adding: content/runs/pinn_oe/3d_lam0p1_params/manifest.ocdbt (deflated 7%)
  adding: content/runs/pinn_oe/3d_lam0p1_params/_CHECKPOINT_METADATA (deflated 39%)
  adding: content/runs/pinn_oe/3d_lam0p1_params/_METADATA (deflated 87%)
  adding: content/runs/pinn_oe/3d_lam0p1_params/ocdbt.process_0/ (stored 0%)
  adding: content/runs/pinn_oe/3d_lam0p1_params/ocdbt.process_0/manifest.ocdbt (stored 0%)
  adding: content/runs/pinn_oe/3d_lam0p1_params/ocdbt.process_0/d/ (stored 0%)
  adding: content/runs/pinn_oe/3d_lam0p1_params/ocdbt.process_0/d/8e13c54d53