**Note: this notebook was run on Google Colab.**

In [None]:
!pip install ml-collections

Collecting ml-collections
  Downloading ml_collections-1.1.0-py3-none-any.whl.metadata (22 kB)
Downloading ml_collections-1.1.0-py3-none-any.whl (76 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/76.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ml-collections
Successfully installed ml-collections-1.1.0


In [None]:
# ====== Imports ======
import os, json, time, pickle, shutil
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, value_and_grad, vmap
import flax
import flax.linen as nn
from flax.training import train_state
from flax.linen.initializers import glorot_normal, zeros, normal
import optax
import ml_collections
import orbax.checkpoint as ocp
jran = jax.random


# ====== NN definition ======
activation_fn = { "tanh": jnp.tanh, "sin": jnp.sin }

def _get_activation(s):
    if s in activation_fn:
        return activation_fn[s]
    raise NotImplementedError(f"Activation {s} not supported yet!")

def _weight_fact(init_fn, mean, stddev):
    def init(key, shape):
        key1, key2 = jran.split(key)
        w = init_fn(key1, shape)
        g = mean + normal(stddev)(key2, (shape[-1],))
        g = jnp.exp(g)
        v = w / g
        return g, v
    return init

class Dense(nn.Module):
    """Affine layer ``x @ kernel + bias`` with optional weight factorization.

    When ``reparam`` is ``None`` the kernel is a single parameter. When
    ``reparam["type"] == "weight_fact"`` the "kernel" parameter is stored as a
    pair ``(g, v)`` produced by ``_weight_fact`` and the effective kernel is
    ``g * v``.

    Raises:
        NotImplementedError: If ``reparam`` names an unsupported type.
    """
    # Number of output features.
    features: int
    # Initializer for the (base) kernel, called as kernel_init(key, shape).
    kernel_init: callable = glorot_normal()
    # Initializer for the bias vector.
    bias_init: callable = zeros
    # Optional mapping with keys "type", "mean", "stddev" (see ann_gen).
    reparam: dict | None = None

    @nn.compact
    def __call__(self, x):
        kernel_shape = (x.shape[-1], self.features)
        if self.reparam is None:
            kernel = self.param("kernel", self.kernel_init, kernel_shape)
        elif self.reparam["type"] == "weight_fact":
            g, v = self.param(
                "kernel",
                _weight_fact(self.kernel_init, mean=self.reparam["mean"], stddev=self.reparam["stddev"]),
                kernel_shape,
            )
            kernel = g * v
        else:
            # Fail loudly instead of hitting an unbound `kernel` below.
            reparam_type = self.reparam["type"]
            raise NotImplementedError(f"Reparameterization {reparam_type} not supported yet!")
        bias = self.param("bias", self.bias_init, (self.features,))
        return jnp.dot(x, kernel) + bias

class MLP(nn.Module):
    """Fully connected network built from ``Dense`` layers.

    Each width in ``hidden_dim`` gets a ``Dense`` layer followed by the
    configured activation; a final ``Dense`` layer of width ``out_dim`` has no
    activation.
    """
    # Architecture label (set from the config by ann_gen).
    arch_name: str = "MLP"
    # Widths of the hidden layers, in order.
    hidden_dim: tuple[int, ...] = (32, 16)
    # Width of the output layer.
    out_dim: int = 1
    # Name of the activation, resolved through _get_activation.
    activation: str = "tanh"
    # Accepted for config compatibility; not read by __call__.
    periodicity: dict | None = None
    # Accepted for config compatibility; not read by __call__.
    fourier_emb: dict | None = None
    # Forwarded unchanged to every Dense layer.
    reparam: dict | None = None

    def setup(self):
        # Resolve the activation name to a callable once per bound module.
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        hidden = x
        for width in self.hidden_dim:
            layer = Dense(features=width, reparam=self.reparam)
            hidden = self.activation_fn(layer(hidden))
        head = Dense(features=self.out_dim, reparam=self.reparam)
        return head(hidden)

def ann_gen(config):
    """Instantiate the network described by ``config``.

    Args:
        config: Object exposing ``ann_str``, ``ann_hidden_dim``, ``ann_out_dim``,
            ``ann_activation_str``, ``ann_periodicity``, ``ann_fourier_emb`` and,
            optionally, ``ann_reparam``.

    Returns:
        An ``MLP`` module (the only architecture handled here).

    Raises:
        NotImplementedError: If ``config.ann_str`` is not ``"MLP"``.
    """
    # Weight factorization is enabled only by the exact string "weight_fact".
    wants_weight_fact = getattr(config, "ann_reparam", False) == "weight_fact"
    reparam = (
        ml_collections.ConfigDict({"type": "weight_fact", "mean": 0.5, "stddev": 0.1})
        if wants_weight_fact
        else None
    )

    if config.ann_str != "MLP":
        raise NotImplementedError(f"Unknown arch {config.ann_str}")

    return MLP(
        arch_name=config.ann_str,
        hidden_dim=tuple(config.ann_hidden_dim),
        out_dim=config.ann_out_dim,
        activation=config.ann_activation_str,
        periodicity=config.ann_periodicity,
        fourier_emb=config.ann_fourier_emb,
        reparam=reparam,
    )


# ====== Derivatives ======
def derivatives_2d(fn2d, x):
    """First derivatives of a scalar model over a batch of 2-column inputs.

    Args:
        fn2d: Callable mapping a single input vector to an array whose first
            element is the scalar output.
        x: Batch of inputs, one row per point (columns: tau, x).

    Returns:
        ``(Gamma_tau, Gamma_x)``: per-point derivatives with respect to the
        first and second input columns.
    """
    scalar_out = lambda z: fn2d(z)[0]
    per_point_grad = vmap(grad(scalar_out), 0)(x)
    return per_point_grad[:, 0], per_point_grad[:, 1]

def derivatives_3d(fn3d, x):
    """First derivatives and the second S-derivative over 3-column inputs.

    Args:
        fn3d: Callable mapping a single input vector to an array whose first
            element is the scalar output.
        x: Batch of inputs, one row per point (columns: tau, X, S).

    Returns:
        ``(Gamma_tau, Gamma_X, Gamma_S, Gamma_SS)`` where ``Gamma_SS`` is the
        second derivative with respect to the third input column.
    """
    scalar_out = lambda z: fn3d(z)[0]
    point_grad = grad(scalar_out)
    # Derivative w.r.t. the third column, differentiated again below.
    d_third = lambda z: point_grad(z)[2]

    first_order = vmap(point_grad, 0)(x)
    second_order = vmap(grad(d_third), 0)(x)
    return first_order[:, 0], first_order[:, 1], first_order[:, 2], second_order[:, 2]


# ====== Data samplers ======
def get_data_2d(n_pde=10000, n_bc=500, n_term=100, T=5.0, x_range=(-10.,10.), seed=0):
    """Sample collocation points for the 2-input (tau, x) problem.

    Args:
        n_pde: Number of interior points.
        n_bc: Number of points on the ``x = 0`` line.
        n_term: Number of points on the ``tau = 0`` line.
        T: Upper bound of the tau interval ``[0, T]``.
        x_range: ``(low, high)`` bounds for x.
        seed: Seed for ``numpy.random.default_rng``.

    Returns:
        ``(x_tau_pde, x_bc_zero, x_term_cond, T)``; the three arrays are JAX
        arrays with columns ``[tau, x]``.
    """
    rng = np.random.default_rng(seed)
    x_lo, x_hi = x_range[0], x_range[1]

    # The draw order below fixes the random stream; do not reorder.
    x_interior = rng.uniform(x_lo, x_hi, size=(n_pde, 1))
    tau_interior = rng.uniform(0, T, size=(n_pde, 1))
    tau_boundary = rng.uniform(0, T, size=(n_bc, 1))
    x_terminal = rng.uniform(x_lo, x_hi, size=(n_term, 1))

    interior = np.concatenate([tau_interior, x_interior], axis=1)
    # x = 0 line, tau sampled over [0, T].
    boundary = np.concatenate([tau_boundary, np.zeros((n_bc, 1))], axis=1)
    # tau = 0 line, x sampled over x_range.
    terminal = np.concatenate([np.zeros((n_term, 1)), x_terminal], axis=1)

    return jnp.array(interior), jnp.array(boundary), jnp.array(terminal), T

def get_data_3d(n_pde=10000, n_bc=1000, n_term=100, T=5.0, x_range=(-10.,10.), s_range=(10,100), seed=0):
    """Sample collocation points for the 3-input (tau, x, S) problem.

    Args:
        n_pde: Number of interior points.
        n_bc: Number of points on the ``x = 0`` plane.
        n_term: Number of points on the ``tau = 0`` plane.
        T: Upper bound of the tau interval ``[0, T]``.
        x_range: ``(low, high)`` bounds for x.
        s_range: ``(low, high)`` bounds for S.
        seed: Seed for ``numpy.random.default_rng``.

    Returns:
        ``(x_tau_pde, x_bc_zero, x_term_cond, T)``; the three arrays are JAX
        arrays with columns ``[tau, x, S]``.
    """
    rng = np.random.default_rng(seed)
    x_lo, x_hi = x_range[0], x_range[1]
    s_lo, s_hi = s_range[0], s_range[1]

    # The draw order below fixes the random stream; do not reorder.
    x_interior = rng.uniform(x_lo, x_hi, size=(n_pde, 1))
    s_interior = rng.uniform(s_lo, s_hi, size=(n_pde, 1))
    tau_interior = rng.uniform(0, T, size=(n_pde, 1))
    tau_boundary = rng.uniform(0, T, size=(n_bc, 1))
    s_boundary = rng.uniform(s_lo, s_hi, size=(n_bc, 1))
    s_terminal = rng.uniform(s_lo, s_hi, size=(n_term, 1))
    x_terminal = rng.uniform(x_lo, x_hi, size=(n_term, 1))

    interior = np.concatenate([tau_interior, x_interior, s_interior], axis=1)
    # x = 0 plane, tau and S sampled.
    boundary = np.concatenate([tau_boundary, np.zeros((n_bc, 1)), s_boundary], axis=1)
    # tau = 0 plane, x and S sampled.
    terminal = np.concatenate([np.zeros((n_term, 1)), x_terminal, s_terminal], axis=1)

    return jnp.array(interior), jnp.array(boundary), jnp.array(terminal), T


# ====== Error terms ======
def error_2d(config, fn2d, data):
    """Pointwise squared errors and their means for the 2-input problem.

    Args:
        config: Object exposing ``kappa``.
        fn2d: Model callable; applied to single points (for derivatives) and
            to batches (for the other terms).
        data: ``(x_tau_pde, x_bc_zero, x_term_cond, T)`` as returned by
            ``get_data_2d``.

    Returns:
        ``(errors, metrics)``: ``errors`` maps each component name to its
        per-point squared error, ``metrics`` maps the same names to the mean.
    """
    x_tau_pde, x_bc_zero, x_term_cond, _ = data
    d_tau, d_x = derivatives_2d(fn2d, x_tau_pde)
    x_interior = x_tau_pde[:, 1]
    kappa = config.kappa

    # Squared PDE residual at the interior points.
    residual = d_tau - (kappa**2 * x_interior**2) + 0.25 * d_x**2
    e_pde = residual**2

    # Model output on the x = 0 line is pushed to zero.
    e_bc_zero = fn2d(x_bc_zero).squeeze()**2

    # Symmetry in x: compare the output at (tau, x) and (tau, -x).
    mirrored = x_tau_pde.at[:, 1].set(-x_tau_pde[:, 1])
    e_symm = (fn2d(x_tau_pde).squeeze() - fn2d(mirrored).squeeze())**2

    # tau = 0 condition: output should match 1e2 * x**2.
    x_at_zero = x_term_cond[:, 1]
    wanted = (x_at_zero**2) * 1e2
    e_term = (fn2d(x_term_cond).squeeze() - wanted)**2

    errors = {"e_pde": e_pde, "e_bc_zero": e_bc_zero, "e_symm": e_symm, "e_term": e_term}
    metrics = {name: jnp.mean(values) for name, values in errors.items()}
    return errors, metrics

def error_3d(config, fn3d, data):
    """Pointwise squared errors and their means for the 3-input problem.

    Args:
        config: Object exposing ``kappa``, ``sigma`` and ``lam``.
        fn3d: Model callable; applied to single points (for derivatives) and
            to batches (for the other terms).
        data: ``(x_tau_pde, x_bc_zero, x_term_cond, T)`` as returned by
            ``get_data_3d``.

    Returns:
        ``(errors, metrics)``: ``errors`` maps each component name to its
        per-point squared error, ``metrics`` maps the same names to the mean.
    """
    x_tau_pde, x_bc_zero, x_term_cond, _ = data
    d_tau, d_x, _, d_ss = derivatives_3d(fn3d, x_tau_pde)
    x_interior = x_tau_pde[:, 1]
    s_interior = x_tau_pde[:, 2]
    kappa, sigma, lam = config.kappa, config.sigma, config.lam

    # Squared PDE residual at the interior points.
    second_order = 0.5 * sigma**2 * s_interior**2 * d_ss
    source = kappa**2 * x_interior**2 + lam * s_interior * x_interior
    residual = d_tau - second_order - source + 0.25 * d_x**2
    e_pde = residual**2

    # One-sided penalty on the x = 0 plane: only outputs above 1e-6 count.
    e_bc_zero = jnp.maximum(fn3d(x_bc_zero).squeeze() - 1e-6, 0.0)**2

    # tau = 0 condition: output should match 1e2 * x**2.
    x_at_zero = x_term_cond[:, 1]
    wanted = (x_at_zero**2) * 1e2
    e_term = (fn3d(x_term_cond).squeeze() - wanted)**2

    # Symmetry under a joint sign flip of the x and S columns.
    mirrored = x_tau_pde.at[:, 1].set(-x_tau_pde[:, 1]).at[:, 2].set(-x_tau_pde[:, 2])
    e_symm = (fn3d(x_tau_pde).squeeze() - fn3d(mirrored).squeeze())**2

    errors = {"e_pde": e_pde, "e_bc_zero": e_bc_zero, "e_term": e_term, "e_symm": e_symm}
    metrics = {name: jnp.mean(values) for name, values in errors.items()}
    return errors, metrics

def error(config, fn, data):
    """Dispatch to ``error_2d`` or ``error_3d`` based on the input width.

    Args:
        config: Forwarded to the selected error function.
        fn: Model callable, forwarded unchanged.
        data: Data tuple whose first element has shape ``(n, d)``.

    Returns:
        Whatever the selected error function returns.

    Raises:
        ValueError: If ``d`` is neither 2 nor 3.
    """
    d = data[0].shape[1]   # number of input columns: 2 or 3
    by_dim = {2: error_2d, 3: error_3d}
    if d not in by_dim:
        raise ValueError(f"Unexpected input dimension d={d}")
    return by_dim[d](config, fn, data)



# ====== 2D→3D parameter inflation ======
def inflate_2d_to_3d_params(params_2d, params_3d_init, first_dense="Dense_0"):
    """
    Build 3D parameters from trained 2D parameters.

    Every 2D tensor whose shape matches its 3D counterpart is copied over. The
    kernel of ``first_dense`` is the 2D kernel with one extra row of zeros
    appended (the row for the new S input), and the bias of ``first_dense`` is
    taken from the 2D model.

    Args:
        params_2d: Trained parameters of the model with inputs [tau, x].
        params_3d_init: Freshly initialized parameters of the model with
            inputs [tau, x, S].
        first_dense: Name of the first Dense module in the parameter tree.

    Returns:
        The frozen 3D parameter tree.
    """
    source = flax.core.unfreeze(params_2d)["params"]
    inflated = flax.core.unfreeze(params_3d_init)
    target = inflated["params"]

    for layer, tensors in source.items():
        for tensor_name, value in tensors.items():
            # The first kernel differs in shape and is handled after the loop.
            is_first_kernel = layer == first_dense and tensor_name == "kernel"
            if not is_first_kernel and value.shape == target[layer][tensor_name].shape:
                target[layer][tensor_name] = value

    kernel_2d = source[first_dense]["kernel"]
    zero_row = jnp.zeros((1, kernel_2d.shape[1]))
    target[first_dense]["kernel"] = jnp.concatenate([kernel_2d, zero_row], axis=0)
    target[first_dense]["bias"] = source[first_dense]["bias"]
    return flax.core.freeze(inflated)


# ====== Loss assembly ======
def adj(loss, lw, m):
    """Return ``lw`` times the mean of the elementwise product ``m * loss``.

    Args:
        loss: Per-point loss values.
        lw: Scalar weight of this loss component.
        m: Per-point multipliers, broadcastable against ``loss``.
    """
    masked = m * loss
    return lw * jnp.mean(masked)

# Canonical order in which loss components are assembled.
ALL_COMPONENTS = ['e_pde', 'e_bc_zero', 'e_term', 'e_symm']

def make_loss_lb():
    """Create a function returning the weighted loss of each component.

    The returned ``loss_fn(config, fn, data, l_ws, params_sa)`` evaluates
    ``error`` and, for every name in ``ALL_COMPONENTS`` present in the result,
    computes ``adj(err, weight, multipliers)``. Missing weights default to 1.0
    and missing multipliers default to ones.

    Returns:
        A callable producing ``(loss_terms, metrics)``.
    """
    def loss_fn(config, fn, data, l_ws, params_sa):
        err, metrics = error(config, fn, data)
        loss_terms = {
            name: adj(
                err[name],
                l_ws.get(name, 1.0),
                params_sa.get(name, jnp.ones_like(err[name])),
            )
            for name in ALL_COMPONENTS
            if name in err
        }
        return loss_terms, metrics
    return loss_fn

# Per-component loss builders, keyed by loss name.
loss_fn_lb = {key: make_loss_lb() for key in ("MLP", "PINN")}

def make_loss_fn(components_key):
    """Create a scalar loss function for the given loss name.

    Args:
        components_key: Key into ``loss_fn_lb`` (looked up at call time).

    Returns:
        A callable ``loss_fn(config, fn, data, l_ws, params_sa)`` returning
        ``(total, metrics)`` where ``total`` is the sum of all weighted terms.
    """
    def loss_fn(config, fn, data, l_ws, params_sa):
        per_component = loss_fn_lb[components_key]
        loss_terms, metrics = per_component(config, fn, data, l_ws, params_sa)
        stacked = jnp.array(list(loss_terms.values()))
        return jnp.sum(stacked), metrics
    return loss_fn

# Scalar loss functions, keyed by loss name.
loss_fn = {key: make_loss_fn(key) for key in ("MLP", "PINN")}

def init_params_sa(loss_str, data):
    """Create all-ones per-point multipliers for every loss component.

    Args:
        loss_str: Loss name; accepted for interface compatibility, not used.
        data: ``(x_tau_pde, x_bc_zero, x_term_cond, T)``; the first array has
            shape ``(n, d)`` with ``d`` equal to 2 or 3.

    Returns:
        Dict mapping ``"e_pde"``, ``"e_bc_zero"``, ``"e_symm"`` and
        ``"e_term"`` to 1-D arrays of ones, one entry per point of the data
        set that component is evaluated on.

    Raises:
        ValueError: If ``d`` is neither 2 nor 3 (matches ``error``).
    """
    d = data[0].shape[1]
    if d not in (2, 3):
        # Returning None here would only fail later with a less useful error.
        raise ValueError(f"Unexpected input dimension d={d}")

    # The 2D and 3D layouts share the same tuple structure.
    x_tau_pde, x_bc_zero, x_term_cond, _ = data
    n_interior = len(x_tau_pde)
    return {
        "e_pde": jnp.ones(n_interior),
        "e_bc_zero": jnp.ones(len(x_bc_zero)),
        # The symmetry term is evaluated on the interior points.
        "e_symm": jnp.ones(n_interior),
        "e_term": jnp.ones(len(x_term_cond)),
    }

# Initial static weights per loss component, keyed by loss name.
# calibration() copies init_l_ws[config.loss_str] when no l_ws is passed in.
# Only "PINN" is defined here, so loss_str="MLP" needs an explicit l_ws.
init_l_ws = {
    "PINN": {'e_pde': 1.0, 'e_bc_zero': 0.1, 'e_symm': 0.5, 'e_term': 1.0}
}


# ====== Rebalancer ======
class Rebalancer:
    """Adapt loss-component weights from EMAs of their metrics.

    On each ``step`` the rebalancer updates an exponential moving average (EMA)
    of every tracked metric, optionally freezes components whose metric has
    stayed below ``freeze_tol`` for ``freeze_patience`` consecutive steps, and
    rescales the weights of the remaining components according to how large
    their (relative) EMA is compared with the geometric mean over components.

    Args:
        keys: Names of the components to track.
        ema_beta: EMA decay factor (weight of the previous EMA).
        power: Exponent applied to the EMA ratio; the multiplier is
            ``ratio ** (-power)``.
        w_clip: ``(min, max)`` bounds applied to each weight before the final
            mean normalization.
        ema_floor: Lower bound applied to metric values and relative EMAs used
            for the weight computation.
        allow_upweight_small: If False, components with ``ratio < 1`` keep
            their weight instead of being increased.
        freeze_tol: Threshold on the unclamped EMA below which a step counts
            towards freezing.
        freeze_patience: Number of consecutive below-threshold steps needed to
            freeze a component; values <= 0 disable freezing.
        base_weights: Weights assigned to frozen components.
        static_keys: Components whose weights are never adapted.
        relative_to_init: If True, EMAs are divided by their first value.
        eps: Small positive constant guarding logs and ratios.
    """

    def __init__(
        self, keys,
        ema_beta=0.95,
        power=0.2,
        w_clip=(0.5, 2.0),
        ema_floor=1e-4,
        allow_upweight_small=False,
        freeze_tol=1e-5,
        freeze_patience=10,
        base_weights=None,
        static_keys=None,
        relative_to_init=True,
        eps=1e-12,
    ):
        self.keys = list(keys)
        self.ema_beta = float(ema_beta)
        self.power = float(power)
        self.w_min, self.w_max = map(float, w_clip)
        self.ema_floor = float(ema_floor)
        self.allow_upweight_small = bool(allow_upweight_small)
        self.freeze_tol = float(freeze_tol)
        self.freeze_patience = int(freeze_patience)
        self.base_weights = dict(base_weights or {})
        self.static_keys = set(static_keys or ())
        self.relative_to_init = bool(relative_to_init)
        self.eps = float(eps)

        # EMA of the floored metric; drives the weight computation.
        self.ema = {k: None for k in self.keys}
        # EMA of the raw metric; drives the freeze decision only. The floored
        # EMA can never drop below ema_floor, so comparing it with a smaller
        # freeze_tol would make freezing unreachable.
        self.raw_ema = {k: None for k in self.keys}
        self.init_ema = {k: None for k in self.keys}
        self.below_ctr = {k: 0 for k in self.keys}
        self.frozen = {k: False for k in self.keys}

    def _smooth_update(self, old, new):
        """Return the EMA of ``old`` and ``new`` (``new`` if no history yet)."""
        if old is None:
            return new
        return self.ema_beta * old + (1.0 - self.ema_beta) * new

    def step(self, metrics, weights):
        """Update the EMAs and return a new weight dict.

        Args:
            metrics: Mapping from component name to its current scalar metric.
            weights: Current weights; not modified.

        Returns:
            A new dict of Python-float weights. Frozen components get their
            base weight, static or untracked components keep their weight, and
            the adapted components are rescaled so that their mean is 1.
        """
        for k in self.keys:
            if k not in metrics:
                continue
            raw = float(metrics[k])
            if not np.isfinite(raw):
                # A NaN/inf would stay in the EMA forever; ignore this sample.
                continue
            self.raw_ema[k] = self._smooth_update(self.raw_ema[k], raw)
            val = max(raw, self.ema_floor)
            self.ema[k] = self._smooth_update(self.ema[k], val)
            if self.init_ema[k] is None:
                self.init_ema[k] = self.ema[k]

            # Freeze test uses the unclamped EMA (see __init__).
            if self.raw_ema[k] < self.freeze_tol:
                self.below_ctr[k] += 1
            else:
                self.below_ctr[k] = 0
            if self.freeze_patience > 0 and self.below_ctr[k] >= self.freeze_patience:
                self.frozen[k] = True

        present = [
            k for k in self.keys
            if (k in metrics and self.ema[k] is not None and not self.frozen[k] and k not in self.static_keys)
        ]
        if not present:
            # Nothing to adapt: only pin frozen components to their base weight.
            new_w = dict(weights)
            for k in self.keys:
                if self.frozen[k]:
                    new_w[k] = float(self.base_weights.get(k, new_w.get(k, 1.0)))
            return new_w

        def rel_ema(k):
            if not self.relative_to_init:
                return max(self.ema[k], self.ema_floor)
            denom = max(self.init_ema[k] or self.ema_floor, self.ema_floor)
            return max(self.ema[k] / denom, self.ema_floor)

        rels = {k: rel_ema(k) for k in present}
        # Geometric mean of the relative EMAs over the adapted components.
        gmean = float(np.exp(np.mean([np.log(max(rels[k], self.eps)) for k in present])))

        new_w = dict(weights)
        for k in self.keys:
            if self.frozen[k]:
                new_w[k] = float(self.base_weights.get(k, new_w.get(k, 1.0)))
                continue
            if k in self.static_keys or k not in rels:
                new_w[k] = float(new_w.get(k, 1.0))
                continue

            ratio = max(rels[k] / gmean, self.eps)
            if not self.allow_upweight_small and ratio < 1.0:
                mult = 1.0
            else:
                mult = ratio ** (-self.power)

            new_w[k] = float(np.clip(new_w.get(k, 1.0) * mult, self.w_min, self.w_max))

        # Normalize the adapted weights to mean 1 (applied after clipping).
        mean_w = float(np.mean([new_w[k] for k in present])) if present else 1.0
        if mean_w > 0:
            for k in present:
                new_w[k] = new_w[k] / mean_w
        return new_w


# ====== Training ======
def calibration(config, data, params_init=None, l_ws=None, params_sa=None):
    """Train the network described by ``config`` on ``data``.

    Args:
        config: Config exposing the ``ann_*`` fields, ``loss_str``, ``seed``,
            ``ann_in_dim``, ``num_epochs`` and, optionally, a ``rebal``
            sub-config. Within ``rebal``, ``start_epoch`` (default 0) and
            ``every`` control when weights are rebalanced; if ``every`` is
            missing or zero, the weights stay fixed.
        data: Data tuple as returned by ``get_data_2d`` / ``get_data_3d``.
        params_init: Optional starting parameters; a fresh initialization is
            used when omitted.
        l_ws: Optional initial component weights; defaults to a copy of
            ``init_l_ws[config.loss_str]``.
        params_sa: Optional per-point multipliers; defaults to
            ``init_params_sa(config.loss_str, data)``.

    Returns:
        ``(model_fn, hist, params, ann)`` where ``model_fn(x)`` applies the
        trained parameters, ``hist`` is a list of
        ``(epoch, loss, metrics_dict)`` tuples, ``params`` are the trained
        parameters and ``ann`` is the network module.
    """
    ann = ann_gen(config)
    ofunc = loss_fn[config.loss_str]

    if l_ws is None:
        l_ws = dict(init_l_ws[config.loss_str])
    if params_sa is None:
        params_sa = init_params_sa(config.loss_str, data)

    # Resolve the rebalancer config unconditionally so that a caller-supplied
    # params_sa does not leave rebal_cfg undefined.
    rebal_cfg = getattr(config, "rebal", None)
    if rebal_cfg is None:
        rebal_cfg = ml_collections.ConfigDict({})

    rebal = Rebalancer(
        keys=[k for k in ALL_COMPONENTS if k in (params_sa.keys())],
        ema_beta=getattr(rebal_cfg, "ema_beta", 0.98),
        power=getattr(rebal_cfg, "power", 0.2),
        w_clip=tuple(getattr(rebal_cfg, "w_clip", (0.3, 3.0))),
        ema_floor=getattr(rebal_cfg, "ema_floor", 1e-3),
        allow_upweight_small=getattr(rebal_cfg, "allow_upweight_small", False),
        freeze_tol=getattr(rebal_cfg, "freeze_tol", 1e-4),
        freeze_patience=getattr(rebal_cfg, "freeze_patience", 10),
        base_weights=dict(getattr(rebal_cfg, "base_weights", {"e_bc_zero": 0.1})),
        static_keys=set(getattr(rebal_cfg, "static_keys", ())),
        relative_to_init=getattr(rebal_cfg, "relative_to_init", True),
    )

    # Schedule of the rebalancing steps. A missing or zero `every` disables
    # rebalancing instead of raising AttributeError / ZeroDivisionError.
    rebal_start = getattr(rebal_cfg, "start_epoch", 0)
    rebal_every = getattr(rebal_cfg, "every", None)

    key = jran.PRNGKey(config.seed)
    _, key_init = jran.split(key, 2)
    dummy = jnp.ones((1, config.ann_in_dim))
    params0 = ann.init(key_init, dummy) if params_init is None else params_init

    state = train_state.TrainState.create(
        apply_fn=ann.apply,
        params=params0,
        tx=optax.adamw(learning_rate=5e-4)
    )

    @jit
    def train_step(state, data, l_ws, params_sa):
        def wrapped_loss(params):
            def fn(x): return ann.apply(params, x)
            return ofunc(config, fn, data, l_ws, params_sa)
        (loss, metrics), grads = value_and_grad(wrapped_loss, has_aux=True)(state.params)
        state = state.apply_gradients(grads=grads)
        return state, loss, metrics

    hist = []
    for epoch in range(config.num_epochs):
        state, loss, metrics = train_step(state, data, l_ws, params_sa)

        if rebal_every and (epoch >= rebal_start) and ((epoch - rebal_start) % rebal_every == 0):
            l_ws = rebal.step(metrics, l_ws)

        if epoch % 100 == 0:
            w_str = "{"+", ".join(f"{k}:{l_ws.get(k,1.0):.3g}" for k in sorted(l_ws.keys()))+"}"
            keys = ['e_pde','e_bc_zero','e_term','e_symm']
            msg  = " | ".join(f"{k}={float(metrics[k]):.6f}" for k in keys if k in metrics)
            print(f"Epoch {epoch}: total={float(loss):.6f} | {msg} | w={w_str}")

        hist.append((int(epoch), float(loss), {k: float(metrics[k]) for k in metrics}))

    return (lambda x: ann.apply(state.params, x)), hist, state.params, ann


# ====== Curriculum runner ======
def save_params_overwrite(path, params):
    """Save ``params`` as an Orbax checkpoint at ``path``, replacing any old one.

    Args:
        path: Checkpoint directory; converted to an absolute path.
        params: Parameter pytree to save.
    """
    target = os.path.abspath(path)
    # Remove a previous checkpoint directory before writing the new one.
    if os.path.exists(target):
        shutil.rmtree(target)

    checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
    try:
        checkpointer.save(target, params, force=True)
    except TypeError:
        # Retry without `force` if the call above raised a TypeError.
        checkpointer.save(target, params)

def save_stage(run_dir, tag, *, params, config, history):
    """
    Persist one training stage under ``run_dir``.

    Writes:
      - {run_dir}/{tag}_params/        (checkpoint directory)
      - {run_dir}/{tag}_config.json    (config converted to JSON)
      - {run_dir}/{tag}_history.pkl    (pickled training history)
    """
    out_dir = os.path.abspath(run_dir)
    os.makedirs(out_dir, exist_ok=True)
    save_params_overwrite(os.path.join(out_dir, f"{tag}_params"), params)

    def _to_jsonable(obj):
        # Prefer the object's own to_dict(); fall back to generic handling.
        try:
            return obj.to_dict()
        except Exception:
            if obj is None or isinstance(obj, (dict, list, str, int, float, bool)):
                return obj
            if hasattr(obj, "items"):
                return {key: _to_jsonable(value) for key, value in obj.items()}
            return str(obj)

    config_path = os.path.join(out_dir, f"{tag}_config.json")
    with open(config_path, "w") as config_file:
        json.dump(_to_jsonable(config), config_file, indent=2)

    history_path = os.path.join(out_dir, f"{tag}_history.pkl")
    with open(history_path, "wb") as history_file:
        pickle.dump(history, history_file)

def restore_params(run_dir, tag):
    """Restore the parameters saved at ``{run_dir}/{tag}_params``.

    Args:
        run_dir: Directory that contains the stage outputs.
        tag: Stage tag used when the checkpoint was saved.

    Returns:
        The restored parameter pytree.
    """
    checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
    checkpoint_dir = os.path.abspath(os.path.join(run_dir, f"{tag}_params"))
    return checkpointer.restore(checkpoint_dir)

def run_curriculum(config2d, config3d, lam_star, workdir=None):
    """Train the 2D model, inflate it to 3D and train over increasing lambda.

    Both configs are modified in place (``dim``, ``ann_in_dim`` and ``lam``).

    Args:
        config2d: Config of the 2D stage.
        config3d: Config of the 3D stages; ``alphas`` lists the multipliers
            applied to ``lam_star``, one 3D stage per entry.
        lam_star: Target lambda; each stage uses ``alpha * lam_star``.
        workdir: Output directory; falls back to ``config3d.workdir`` and then
            to ``"runs/pinn_oe"``.

    Returns:
        ``(params, histories)``: the parameters of the last trained stage and
        a dict of per-stage histories. If ``config3d.alphas`` is None, only
        the 2D stage is run.
    """
    # ---- Phase A: 2D (lambda=0) ----
    for field, value in (("dim", 2), ("ann_in_dim", 2), ("lam", 0.0)):
        setattr(config2d, field, value)
    alphas = config3d.alphas
    out_dir = os.path.abspath(workdir or getattr(config3d, "workdir", "runs/pinn_oe"))
    os.makedirs(out_dir, exist_ok=True)

    data2d = get_data_2d(
        n_pde=config2d.pts_num,
        n_bc=config2d.n_bc_zero,
        n_term=config2d.n_term_zero,
        T=config2d.T,
        x_range=config2d.x_range,
        seed=config2d.seed,
    )
    _, hist2d, params_2d, _ = calibration(config2d, data2d)

    # Persist the 2D stage before anything else can fail.
    save_stage(out_dir, "2d_lam0", params=params_2d, config=config2d, history=hist2d)

    if alphas is None:
        return params_2d, {"2d": hist2d}

    # ---- Phase B: inflate to 3D, then curriculum over lambda ----
    config3d.dim = 3
    config3d.ann_in_dim = 3
    fresh_3d = ann_gen(config3d).init(jran.PRNGKey(1), jnp.ones((1,3)))
    stage_params = inflate_2d_to_3d_params(params_2d, fresh_3d)

    data3d = get_data_3d(
        n_pde=config3d.pts_num,
        n_bc=config3d.n_bc_zero,
        n_term=config3d.n_term_zero,
        T=config3d.T,
        x_range=config3d.x_range,
        s_range=config3d.s_range,
        seed=config3d.seed,
    )

    histories = {"2d": hist2d}
    t_start = time.time()
    for alpha in alphas:
        lam_stage = float(alpha * lam_star)
        print(f"\n=== 3D stage: alpha={alpha}, lambda={lam_stage} ===")
        config3d.lam = lam_stage
        # Each stage starts from the parameters of the previous one.
        _, stage_hist, stage_params, _ = calibration(config3d, data3d, params_init=stage_params)
        stage_key = f"alpha_{alpha}"
        histories[stage_key] = stage_hist
        # Dots are replaced in the on-disk tag only, not in the history key.
        save_stage(out_dir, stage_key.replace('.', 'p'), params=stage_params, config=config3d, history=stage_hist)

    print(f"Training completed in {time.time() - t_start:.2f} s total.")
    return stage_params, histories


# ====== Model builder ======
def build_model_fn(config, params):
    """Return a jitted function applying the configured network to ``params``.

    Args:
        config: Config passed to ``ann_gen`` to rebuild the network.
        params: Parameters to bind into the returned function.

    Returns:
        A ``jax.jit``-compiled callable ``x -> ann.apply(params, x)``.
    """
    network = ann_gen(config)
    forward = lambda x: network.apply(params, x)
    return jax.jit(forward)

# Upper bound of the tau interval, shared by both default configs.
T=5.0
# ====== Default configs ======
def default_configs():
    """Build the default 2D and 3D configs used by ``run_curriculum``.

    Returns:
        ``(config2d, config3d)`` as ``ml_collections.ConfigDict`` objects.
    """
    config2d = ml_collections.ConfigDict({
        "seed": 43,                          # PRNG seed for init and for the data sampler
        "loss_str": "PINN",                  # key into loss_fn / init_l_ws
        "ann_str": "MLP",                    # architecture selector for ann_gen
        "ann_hidden_dim": (500, 500, 500),   # widths of the hidden layers
        "ann_out_dim": 1,
        "ann_activation_str": "tanh",
        "ann_periodicity": None,             # forwarded to MLP; not read by its __call__
        "ann_fourier_emb": None,             # forwarded to MLP; not read by its __call__
        "ann_reparam": False,                # "weight_fact" enables weight factorization
        "dim": 2,                            # overwritten by run_curriculum
        "ann_in_dim": 2,                     # overwritten by run_curriculum
        "num_epochs": 30000,
        "pts_num": 30000,                    # number of interior points (n_pde)
        "x_range": (-10.0, 10.0),
        "T": T,
        "kappa": 0.1,
        "lam": 0.0,             # not used in 2D residual
        "n_bc_zero": 5000,      # number of x = 0 points (n_bc)
        "n_term_zero": 5000,    # number of tau = 0 points (n_term)
        "rebal": ml_collections.ConfigDict({
            "start_epoch": 500,   # first epoch at which weights are rebalanced
            "every": 5000,        # rebalance period in epochs after start_epoch
            "ema_beta": 0.95,
            "power": 0.3,
            "w_clip": (0.1, 2.0),
        }),
    })

    config3d = ml_collections.ConfigDict({
        "seed": 43,                          # PRNG seed for init and for the data sampler
        "loss_str": "PINN",                  # key into loss_fn / init_l_ws
        "ann_str": "MLP",                    # architecture selector for ann_gen
        "ann_hidden_dim": (500, 500, 500),   # widths of the hidden layers
        "ann_out_dim": 1,
        "ann_activation_str": "tanh",
        "ann_periodicity": None,             # forwarded to MLP; not read by its __call__
        "ann_fourier_emb": None,             # forwarded to MLP; not read by its __call__
        "ann_reparam": False,                # "weight_fact" enables weight factorization
        "dim": 3,                            # overwritten by run_curriculum
        "ann_in_dim": 3,                     # overwritten by run_curriculum
        "num_epochs": 5000,                  # epochs per 3D stage
        "pts_num": 30000,                    # number of interior points (n_pde)
        "x_range": (-10.0, 10.0),
        "s_range": (10.0, 100.0),
        "T": T,
        "kappa": 0.1,
        "sigma": 0.1,
        "lam": 0.1,              # overwritten per stage by run_curriculum
        "n_bc_zero": 5000,       # number of x = 0 points (n_bc)
        "n_term_zero": 5000,     # number of tau = 0 points (n_term)
        "alphas": (0.25, 0.5, 0.75, 0.9, 1.0),   # per-stage multipliers of lam_star
        "workdir": os.path.abspath("runs/pinn_oe"),   # where stage outputs are written
        "rebal": ml_collections.ConfigDict({
            "start_epoch": 500,   # first epoch at which weights are rebalanced
            "every": 1000,        # rebalance period in epochs after start_epoch
            "ema_beta": 0.95,
            "power": 0.3,
            "w_clip": (0.1, 2.0),
        }),
    })
    return config2d, config3d

# Build the default configs and run the full 2D -> 3D curriculum.
config2d, config3d = default_configs()
# lam_star is read from config3d.lam here, before run_curriculum overwrites
# config3d.lam with the per-stage value alpha * lam_star.
final_params, all_hist = run_curriculum(config2d, config3d, lam_star=config3d.lam, workdir=config3d.workdir)
# Jitted forward function bound to the parameters of the last stage.
model_PINN = build_model_fn(config3d, final_params)
print("Done. Checkpoints in:", config3d.workdir)


Epoch 0: total=20401458.000000 | e_pde=0.221659 | e_bc_zero=0.003863 | e_term=20401458.000000 | e_symm=0.559300 | w={e_bc_zero:0.1, e_pde:1, e_symm:0.5, e_term:1}
Epoch 100: total=20014684.000000 | e_pde=9.991587 | e_bc_zero=3084.725830 | e_term=20014366.000000 | e_symm=0.282889 | w={e_bc_zero:0.1, e_pde:1, e_symm:0.5, e_term:1}
Epoch 200: total=19824072.000000 | e_pde=93.262283 | e_bc_zero=4622.951660 | e_term=19823516.000000 | e_symm=1.464308 | w={e_bc_zero:0.1, e_pde:1, e_symm:0.5, e_term:1}
Epoch 300: total=19648372.000000 | e_pde=134.846771 | e_bc_zero=7940.327637 | e_term=19647444.000000 | e_symm=1.061868 | w={e_bc_zero:0.1, e_pde:1, e_symm:0.5, e_term:1}
Epoch 400: total=19478902.000000 | e_pde=267.266541 | e_bc_zero=9109.353516 | e_term=19477724.000000 | e_symm=0.908597 | w={e_bc_zero:0.1, e_pde:1, e_symm:0.5, e_term:1}
Epoch 500: total=19313382.000000 | e_pde=392.507080 | e_bc_zero=9318.698242 | e_term=19312058.000000 | e_symm=0.401422 | w={e_bc_zero:0.154, e_pde:1.54, e_symm:

In [None]:
# Recursively archive /content/runs into /content/runs.zip.
!zip -r /content/runs.zip /content/runs


  adding: content/runs/ (stored 0%)
  adding: content/runs/pinn_oe/ (stored 0%)
  adding: content/runs/pinn_oe/alpha_1p0_params/ (stored 0%)
  adding: content/runs/pinn_oe/alpha_1p0_params/array_metadatas/ (stored 0%)
  adding: content/runs/pinn_oe/alpha_1p0_params/array_metadatas/process_0 (deflated 84%)
  adding: content/runs/pinn_oe/alpha_1p0_params/manifest.ocdbt (deflated 5%)
  adding: content/runs/pinn_oe/alpha_1p0_params/_CHECKPOINT_METADATA (deflated 39%)
  adding: content/runs/pinn_oe/alpha_1p0_params/_METADATA (deflated 87%)
  adding: content/runs/pinn_oe/alpha_1p0_params/ocdbt.process_0/ (stored 0%)
  adding: content/runs/pinn_oe/alpha_1p0_params/ocdbt.process_0/manifest.ocdbt (stored 0%)
  adding: content/runs/pinn_oe/alpha_1p0_params/ocdbt.process_0/d/ (stored 0%)
  adding: content/runs/pinn_oe/alpha_1p0_params/ocdbt.process_0/d/24038bb191ad7856d1e76a991ad35ae9 (stored 0%)
  adding: content/runs/pinn_oe/alpha_1p0_params/ocdbt.process_0/d/82b029e07d48961736d72355cfdd648b (s