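**Ran on Google Colab.** The package-install and `zip` shell cells, and the `/content/...` paths in their output, assume the Colab runtime.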

In [None]:
# Pinned to the version recorded in this notebook's install log; %pip targets the kernel's env.
%pip install -q ml-collections==1.1.0

Collecting ml-collections
  Downloading ml_collections-1.1.0-py3-none-any.whl.metadata (22 kB)
Downloading ml_collections-1.1.0-py3-none-any.whl (76 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ml-collections
Successfully installed ml-collections-1.1.0


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import grad, vmap
from flax import linen as nn
import optax
import matplotlib.pyplot as plt
import numpy as np
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
from jax.nn.initializers import glorot_normal, normal, zeros
from jax import jit, random as jran
import ml_collections
from flax.training import train_state
import pickle
import time
from jax import jit, value_and_grad
import pandas as pd
import seaborn as sns
import orbax.checkpoint as ocp
import os, json, time, pickle, shutil



In [None]:
import jax.numpy as jnp

# Working precision for every matrix built in this cell.
dtype=jnp.float32
# Total state dimension; the system is assembled from 4×4 prototype blocks below.
dim = 20
# Block j (0-based) later gets A_j = A_base - shift_factor*(j+1)*I so blocks differ.
shift_factor= 0.1
# Base 4×4 prototype (Ntogramatzidis, 2003)
# A_base: 4×4 state matrix of the prototype system (used in x' = A x + B u).
A_base = jnp.array([
    [-8.95,  -6.45,   0.00,   0.00],
    [ 2.15,  -0.35,   0.00,   0.00],
    [-10.89, -40.94, -16.10, -7.95],
    [ 8.17,  28.87,   7.07,  -0.20]
], dtype=dtype)

# B_base: 4×2 input matrix.
B_base = jnp.array([
    [1., 0.],
    [0., 0.],
    [0., 1.],
    [1., 1.]
], dtype=dtype)

# Q_base: 4×4 state-cost matrix.
# Note: Q_base equals S_base @ S_base.T entry-for-entry, so with R_base = I the
# Popov matrix Q - S R^{-1} S^T is exactly zero (PSD, with equality).
Q_base = jnp.array([
    [ 5.,  4., 13., 16.],
    [ 4.,  5., 11., 14.],
    [13., 11., 34., 42.],
    [16., 14., 42., 52.]
], dtype=dtype)

# R_base: 2×2 control-cost matrix (identity).
R_base = jnp.eye(2, dtype=dtype)

# S_base: 4×2 state/control cross-cost matrix.
S_base = jnp.array([
    [1., 2.],
    [2., 1.],
    [3., 5.],
    [4., 6.]
], dtype=dtype)

def blkdiag(*blocks, dtype=None):
    """Assemble a block-diagonal matrix from 2-D (not necessarily square) blocks.

    Block ``i`` occupies its own row range and column range; every
    off-diagonal position is zero.

    Args:
        *blocks: 2-D arrays.
        dtype: dtype of the zero padding. Defaults to the common result dtype
            of the blocks, so the padding never changes the output dtype.

    Returns:
        Array of shape ``(sum of block rows, sum of block cols)``, or an empty
        ``(0, 0)`` array when called with no blocks.
    """
    if not blocks:
        return jnp.zeros((0, 0), dtype=dtype)
    pad_dtype = jnp.result_type(*blocks) if dtype is None else dtype
    n = len(blocks)
    return jnp.block([
        [blocks[i] if i == j
         else jnp.zeros((blocks[i].shape[0], blocks[j].shape[1]), dtype=pad_dtype)
         for j in range(n)]
        for i in range(n)
    ])

full = dim // 4         # number of full 4×4 blocks
rem  = dim % 4          # size of the final truncated block (0 if there is none)
n_blocks = full + (1 if rem > 0 else 0)

A_blocks, B_blocks, Q_blocks, R_blocks, S_blocks = [], [], [], [], []

for j in range(n_blocks):
    # Block size: 4 for the full blocks. The extra last block only exists
    # when rem > 0, so it always has size `rem`.
    r = 4 if j < full else rem

    # Top-left sub-blocks of the prototype, sized to r.
    A_j = A_base[:r, :r]
    B_j = B_base[:r, :]          # r×2
    Q_j = Q_base[:r, :r]
    S_j = S_base[:r, :]          # r×2
    R_j = R_base                 # 2×2

    # Vary the dynamics slightly per block with a diagonal shift.
    shift = shift_factor * (j + 1)
    A_j = A_j - shift * jnp.eye(r, dtype=dtype)

    A_blocks.append(A_j)
    B_blocks.append(B_j)
    Q_blocks.append(Q_j)
    R_blocks.append(R_j)
    S_blocks.append(S_j)

# Assemble the big block-diagonal matrices.
A = blkdiag(*A_blocks)                      # (dim, dim)
B = blkdiag(*B_blocks)                      # (dim, 2*n_blocks)
Q = blkdiag(*Q_blocks)                      # (dim, dim)
R = blkdiag(*R_blocks)                      # (2*n_blocks, 2*n_blocks)
S = blkdiag(*S_blocks)                      # (dim, 2*n_blocks)

# Shape checks: fail loudly instead of printing a boolean next to a message.
m_total = 2 * n_blocks
assert A.shape == (dim, dim), f"A shape {A.shape} != {(dim, dim)}"
assert Q.shape == (dim, dim), f"Q shape {Q.shape} != {(dim, dim)}"
assert B.shape == (dim, m_total), f"B shape {B.shape} != {(dim, m_total)}"
assert S.shape == (dim, m_total), f"S shape {S.shape} != {(dim, m_total)}"
assert R.shape == (m_total, m_total), "R must be 2×2 per block"

# Diagnostics (printed, not asserted).
print("Shapes -> A:", A.shape, "B:", B.shape, "Q:", Q.shape, "R:", R.shape, "S:", S.shape)
print("R SPD?:", bool(jnp.all(jnp.linalg.eigvalsh(R) > 0)))
popov = Q - S @ jnp.linalg.solve(R, S.T)
print("Popov (Q - S R^{-1} S^T) >= 0 ?", bool(jnp.min(jnp.linalg.eigvalsh(popov)) >= -1e-6))


Shapes -> A: (20, 20) B: (20, 10) Q: (20, 20) R: (10, 10) S: (20, 10)
R SPD?: True
Popov (Q - S R^{-1} S^T) >= 0 ? True
x0 in [-4,4] and integer? True int32


In [None]:
"""Step 1: NN definition (see Kentaro DC_PINNs_IVS_calibration.ipynb)"""
# Registry mapping an activation name to its element-wise JAX function.
activation_fn = dict(
    tanh=jnp.tanh,
    sin=jnp.sin,
)

def _get_activation(str):
    """Return the activation registered under ``str`` in ``activation_fn``.

    Raises:
        NotImplementedError: if no activation with that name is registered.
    """
    # NOTE: the parameter name shadows the builtin ``str``; it is kept so that
    # keyword callers keep working.
    if str not in activation_fn:
        raise NotImplementedError(f"Activation {str} not supported yet!")
    return activation_fn[str]

def _weight_fact(init_fn, mean, stddev):
    def init(key, shape):
        key1, key2 = jran.split(key)
        w = init_fn(key1, shape)
        g = mean + normal(stddev)(key2, (shape[-1],))
        g = jnp.exp(g)
        v = w / g
        return g, v

    return init

class Dense(nn.Module):
    """Affine layer ``y = x @ kernel + bias`` with optional weight factorization.

    Attributes:
        features: number of output features.
        kernel_init: initializer for the kernel, or for its base ``w`` when
            weight factorization is used.
        bias_init: initializer for the bias.
        reparam: ``None`` for a plain kernel, or a mapping with
            ``type == "weight_fact"`` plus ``mean`` / ``stddev``, which stores
            the kernel as the pair ``(g, v)`` built by ``_weight_fact``.
    """
    features: int
    kernel_init: Callable = glorot_normal()
    bias_init: Callable = zeros
    reparam: Union[None, Dict] = None

    @nn.compact
    def __call__(self, x):
        if self.reparam is None:
            kernel = self.param(
                "kernel", self.kernel_init, (x.shape[-1], self.features)
            )
        elif self.reparam["type"] == "weight_fact":
            # The stored parameter is the tuple (g, v); the effective kernel is g * v.
            g, v = self.param(
                "kernel",
                _weight_fact(
                    self.kernel_init,
                    mean=self.reparam["mean"],
                    stddev=self.reparam["stddev"],
                ),
                (x.shape[-1], self.features),
            )
            kernel = g * v
        else:
            # Without this branch `kernel` would be unbound (UnboundLocalError).
            raise NotImplementedError(
                f"Reparametrization {self.reparam['type']} not supported yet!"
            )
        bias = self.param("bias", self.bias_init, (self.features,))
        return jnp.dot(x, kernel) + bias

class MLP(nn.Module):
    """Fully-connected network: ``[Dense -> activation] * len(hidden_dim) -> Dense``.

    Attributes:
        arch_name: label only; not read by the forward pass.
        hidden_dim: widths of the hidden layers, applied in order.
        out_dim: width of the final (linear, no activation) output layer.
        activation: name resolved with ``_get_activation`` (e.g. "tanh", "sin").
        periodicity: accepted for config compatibility; currently ignored.
        fourier_emb: accepted for config compatibility; currently ignored.
        reparam: forwarded unchanged to every ``Dense`` layer.
    """
    arch_name: Optional[str] = "MLP"
    # Variable-length tuple of ints (``Tuple[int]`` would denote a 1-tuple).
    hidden_dim: Tuple[int, ...] = (32, 16)
    out_dim: int = 1
    activation: str = "tanh"
    periodicity: Union[None, Dict] = None
    fourier_emb: Union[None, Dict] = None
    reparam: Union[None, Dict] = None

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        for width in self.hidden_dim:
            x = self.activation_fn(Dense(features=width, reparam=self.reparam)(x))
        return Dense(features=self.out_dim, reparam=self.reparam)(x)



def ann_gen(config):
    """Instantiate the network described by ``config``.

    Reads ``config.ann_str``, ``ann_hidden_dim``, ``ann_out_dim``,
    ``ann_activation_str``, ``ann_periodicity``, ``ann_fourier_emb`` and
    ``ann_reparam``. When ``ann_reparam == "weight_fact"``, the factorization
    uses ``config.ann_reparam_mean`` / ``config.ann_reparam_stddev`` if they are
    set, and 0.5 / 0.1 otherwise.

    Raises:
        ValueError: if ``config.ann_str`` names an unsupported architecture.
            Returning ``None`` here would only fail later, at ``ann.init``.
    """
    reparam = None
    if config.ann_reparam == "weight_fact":
        reparam = ml_collections.ConfigDict({
            "type": "weight_fact",
            "mean": getattr(config, "ann_reparam_mean", 0.5),
            "stddev": getattr(config, "ann_reparam_stddev", 0.1),
        })

    if config.ann_str == "MLP":
        return MLP(arch_name=config.ann_str,
                   hidden_dim=config.ann_hidden_dim,
                   out_dim=config.ann_out_dim,
                   activation=config.ann_activation_str,
                   periodicity=config.ann_periodicity,
                   fourier_emb=config.ann_fourier_emb,
                   reparam=reparam)

    raise ValueError(f"Unsupported ann_str {config.ann_str!r}; only 'MLP' is implemented.")


# Optimizer hyper-parameters: Adam with a smooth exponential learning-rate decay
# (lr = learning_rate * decay_rate ** (step / decay_steps)).
# NOTE(review): `calibration` builds its own optax.adamw optimizer and does not
# reference `tx` or `lr`. `optimizer` and `grad_accum_steps` are not read by
# any visible cell either. Confirm whether these are still needed.
optimizer = "Adam"
beta1, beta2 = 0.9, 0.999
eps = 1e-8
learning_rate = 1e-3
decay_rate, decay_steps = 0.9, 2000
grad_accum_steps = 0

lr = optax.exponential_decay(
    init_value=learning_rate,
    transition_steps=decay_steps,
    decay_rate=decay_rate,
)
tx = optax.adam(learning_rate=lr, b1=beta1, b2=beta2, eps=eps)


#####################################################################################
#####################################################################################


def derivatives(fn, x):
    """Per-sample input gradient of a scalar-output network.

    Args:
        fn: callable mapping one input row ``(1 + d,)`` to an output that
            squeezes to a scalar.
        x: ``(batch, 1 + d)`` inputs whose first column is the time-like input.

    Returns:
        ``(d_first, d_rest)``: the derivative w.r.t. column 0, shape ``(batch,)``,
        and w.r.t. the remaining columns, shape ``(batch, d)``.
    """
    def scalar_fn(point):
        return jnp.squeeze(fn(point))

    jac = jax.vmap(jax.jacrev(scalar_fn))(x)
    d_first, d_rest = jac[:, 0], jac[:, 1:]
    return d_first, d_rest


def quad_form(v, M):
    """Batched quadratic form ``v[b] @ M @ v[b]``.

    Args:
        v: ``(batch, d)`` vectors.
        M: ``(d, d)`` matrix.

    Returns:
        ``(batch,)`` array of quadratic-form values.
    """
    spec = 'bi,ij,bj->b'
    return jnp.einsum(spec, v, M, v)

def bilinear(u, M, x):
    """Batched bilinear form ``u[b] @ M @ x[b]``; u, x: (batch, d), M: (d, d) -> (batch,)."""
    spec = 'bi,ij,bj->b'
    return jnp.einsum(spec, u, M, x)

def u_star_batch(fn, t_batch, X_batch, A, B, R, S):
    """Hamiltonian-minimizing feedback control for a batch of states.

    Evaluates ``u* = -1/2 R^{-1} (B^T grad_x V + 2 S^T x)`` row-wise.

    Args:
        fn: network V; called on rows of ``[t_batch, X_batch]``.
        t_batch: ``(batch, 1)`` first network input (callers pass time-to-go).
        X_batch: ``(batch, d)`` states.
        A: unused here; kept for the uniform (A, B, R, S) signature.
        B, R, S: input matrix, control-cost matrix and cross-cost matrix.

    Returns:
        ``(batch, m)`` controls.
    """
    _, grad_x = derivatives(fn, jnp.hstack([t_batch, X_batch]))
    rhs = grad_x @ B + 2.0 * (X_batch @ S)
    return -0.5 * jnp.linalg.solve(R, rhs.T).T

def quad_form_solve(v, M):
    """Batched ``v[b] @ M^{-1} @ v[b]``, computed with a linear solve (no explicit inverse).

    v: (batch, m), M: (m, m) -> (batch,)
    """
    minv_vt = jnp.linalg.solve(M, v.T)
    return jnp.einsum('bi,bi->b', v, minv_vt.T)

def single_path_penalty(fn, *, X0, T, Ndt, A, B, R, S):
    """Terminal-state penalty of closed-loop rollouts driven by the learned V.

    Each row of ``X0`` is integrated with forward Euler over ``[0, T]`` in
    ``Ndt`` equal steps, using ``x' = A x + B u*``. Here ``u*`` comes from
    ``u_star_batch`` evaluated at time-to-go ``T - t``.

    Args:
        fn: network V(tau, x).
        X0: ``(P, d)`` initial states; a single ``(d,)`` state becomes ``(1, d)``.
        T: horizon (may be a traced scalar, e.g. under ``vmap``).
        Ndt: number of Euler steps; must be convertible to a Python int.
        A, B, R, S: system and cost matrices.

    Returns:
        Scalar mean over paths of ``r if r <= 1 else r**2``, with ``r = ||x(T)||``.
    """
    dt = T / Ndt  # scalar step size
    x_init = jnp.atleast_2d(jnp.asarray(X0, dtype=jnp.float32))
    n_paths, _ = x_init.shape
    t_init = jnp.zeros((n_paths, 1), dtype=x_init.dtype)

    def euler_step(state, _unused):
        x, t = state                               # x: (P, d), t: (P, 1)
        u = u_star_batch(fn, T - t, x, A, B, R, S)  # (P, m)
        x_dot = jnp.einsum('bi,ji->bj', x, A) + jnp.einsum('bi,ji->bj', u, B)
        return (x + dt * x_dot, t + dt), None

    (x_final, _), _ = jax.lax.scan(euler_step, (x_init, t_init), None, length=int(Ndt))

    radius = jnp.linalg.norm(x_final, axis=1)      # (P,)
    return jnp.mean(jnp.where(radius <= 1.0, radius, radius ** 2))


def multi_traj_penalty(fn, *, X0, T_list, A, B, R, S, Ndt=50, weights=None):
    """Combine ``single_path_penalty`` over several horizons.

    The per-horizon penalties are computed in parallel with ``jax.vmap`` over
    ``T_list``. Returns their plain mean, or a weighted sum with ``weights``
    normalized to sum to 1 when ``weights`` is given.
    """
    horizons = jnp.atleast_1d(jnp.array(T_list))

    def penalty_for(horizon):
        return single_path_penalty(fn, X0=X0, T=horizon, Ndt=Ndt, A=A, B=B, R=R, S=S)

    per_horizon = jax.vmap(penalty_for)(horizons)

    if weights is None:
        return jnp.mean(per_horizon)
    w = jnp.array(weights)
    w = w / jnp.sum(w)
    return jnp.sum(w * per_horizon)

def error(fn, data, Q, A, B, R, S):
    """Pointwise PINN errors for the LQR value function V(tau, x).

    ``tau = T - t`` is the time-to-go: the rollouts feed ``tau = T - t`` and
    the terminal points sit at ``tau = 0``. With the Hamiltonian minimized by
    ``u* = -1/2 R^{-1} (B^T V_x + 2 S^T x)`` (as in ``u_star_batch``), the HJB
    equation in tau is

        V_tau = x'Qx + V_x' A x - 1/4 W R^{-1} W',   W = V_x' B + 2 x' S.

    Note the sign: V_tau = -V_t, so the t-form ``V_t + H = 0`` becomes
    ``V_tau - H = 0``.

    Args:
        fn: network V; input rows are ``[tau, x]``.
        data: ``(x_tau_pde, x_term_zero, T, X0_batch)`` as built by ``get_data``.
        Q, A, B, R, S: cost and system matrices.

    Returns:
        ``(errors, metrics)``. ``errors`` maps "e_pde" and "e_term_zero" to
        per-point squared errors and "e_traj" to a scalar. ``metrics`` holds
        the corresponding means.
    """
    x_tau_pde, x_term_zero, T, X0_batch = data
    X = x_tau_pde[:, 1:]

    V_tau, V_x = derivatives(fn, x_tau_pde)
    term_AX = bilinear(V_x, A, X)
    term_Q = quad_form(X, Q)
    W = V_x @ B + 2.0 * (X @ S)
    term_R = 0.25 * quad_form_solve(W, R)

    # PDE residual of the tau-form HJB: V_tau = min_u Hamiltonian.
    hamiltonian = term_Q + term_AX - term_R
    e_pde = (V_tau - hamiltonian) ** 2

    # Terminal condition V = 0 at the terminal points.
    e_term_zero = jnp.squeeze(fn(x_term_zero)) ** 2

    # Closed-loop trajectory penalty on the terminal state.
    e_traj = multi_traj_penalty(fn, X0=X0_batch, T_list=jnp.array([T]),
                                A=A, B=B, R=R, S=S, Ndt=50)
    return {"e_pde": e_pde, "e_term_zero": e_term_zero, "e_traj": e_traj}, {
        "e_pde": jnp.mean(e_pde), "e_term_zero": jnp.mean(e_term_zero), "e_traj": e_traj
    }

#####################################################################################
#####################################################################################


def adj(loss, lw, m):
    """Weighted, mask-scaled mean loss: ``lw * mean(m * loss)``."""
    masked = m * loss
    return lw * jnp.mean(masked)

# Loss-component names, in the order they are combined and logged.
ALL_COMPONENTS = ["e_pde", "e_term_zero", "e_traj"]

def make_loss_lb():
    """Build a loss function that returns per-component weighted losses.

    The returned callable ``(config, fn, data, l_ws, params_sa)`` evaluates
    ``error`` with the module-level Q, A, B, R, S. For every name in
    ``ALL_COMPONENTS`` that ``error`` produced, it applies ``adj`` with weight
    ``l_ws[name]`` (default 1.0) and mask ``params_sa[name]`` (default all-ones).
    It returns ``(loss_terms, metrics)``. ``config`` is accepted but unused.
    """
    def component_loss(config, fn, data, l_ws, params_sa):
        err, metrics = error(fn, data, Q, A, B, R, S)
        loss_terms = {
            name: adj(err[name],
                      l_ws.get(name, 1.0),
                      params_sa.get(name, jnp.ones_like(err[name])))
            for name in ALL_COMPONENTS
            if name in err
        }
        return loss_terms, metrics

    return component_loss

# Both architectures currently share the same component-loss builder.
loss_fn_lb = {name: make_loss_lb() for name in ("MLP", "PINN")}

def make_loss_fn(components_key):
    """Wrap ``loss_fn_lb[components_key]`` so it returns ``(total_loss, metrics)``.

    The total is the sum of all per-component weighted losses.
    """
    def total_loss(config, fn, data, l_ws, params_sa):
        terms, metrics = loss_fn_lb[components_key](config, fn, data, l_ws, params_sa)
        stacked = jnp.array(list(terms.values()))
        return jnp.sum(stacked), metrics

    return total_loss

# Total-loss callables keyed by `config.loss_str`.
loss_fn = {name: make_loss_fn(name) for name in ("MLP", "PINN")}

def init_params_sa(loss_str, data):
    """All-ones per-point loss multipliers for the components used by ``loss_str``.

    Args:
        loss_str: "MLP" (PDE term only) or "PINN" (PDE, terminal, trajectory).
        data: 4-tuple whose first two entries are the PDE and terminal point sets.

    Raises:
        KeyError: for any other ``loss_str``.
    """
    x_tau_pde, x_term_zero, _, _ = data
    n_pde, n_term = len(x_tau_pde), len(x_term_zero)

    per_loss = {
        "MLP": dict(e_pde=jnp.ones(n_pde)),
        "PINN": dict(
            e_pde=jnp.ones(n_pde),
            e_term_zero=jnp.ones(n_term),
            e_traj=jnp.ones(1),
        ),
    }
    return per_loss[loss_str]

# Initial loss-component weights per loss_str (the Rebalancer adjusts them later).
init_l_ws = {
    "PINN": dict(
        e_pde=1.,
        e_term_zero=0.1,
        e_traj=1.,
    ),
}


#####################################################################################
#####################################################################################

def sample_X0_near(key, x0, n_samples, radius=1, low=-2, high=2):
    """Draw integer-offset perturbations of ``x0``.

    Each coordinate gets an independent integer offset uniform in
    ``[-radius, radius]``. The result is clipped to ``[low, high]``.

    Args:
        key: JAX PRNG key.
        x0: ``(d,)`` reference vector.
        n_samples: number of perturbed vectors to draw.
        radius: maximum absolute integer change per coordinate.
        low, high: bounds applied to every coordinate after perturbation.

    Returns:
        ``(n_samples, d)`` float32 array.
    """
    center = jnp.asarray(x0)
    offsets = jax.random.randint(key, (n_samples, center.shape[0]), -radius, radius + 1)
    perturbed = center + offsets
    return jnp.clip(perturbed, low, high).astype(jnp.float32)

def get_data(n_pde=1000, n_term_zero=50, T=1.0, x_ranges=((-4, 4),) * 4, seed=0, *,
             x0_test=None, n_X0=100, X0_radius=1, X0_low=-2, X0_high=2):
    """Build the training tuple ``(x_tau_pde, x_term_zero, T, X0_batch)``.

    Args:
        n_pde: number of PDE collocation points.
        n_term_zero: number of terminal-condition rows.
        T: horizon; collocation tau is drawn from ``U(1e-3, T)``.
        x_ranges: per-coordinate ``(low, high)`` for the collocation states.
            Its length fixes the state dimension ``d``.
        seed: seeds both the NumPy generator (collocation points) and the JAX
            key (initial states).
        x0_test: reference initial state for the trajectory penalty. Defaults
            to the fixed 20-D vector below.
        n_X0, X0_radius, X0_low, X0_high: forwarded to ``sample_X0_near``.

    Returns:
        ``x_tau_pde``: ``(n_pde, 1 + d)`` rows ``[tau, x]``.
        ``x_term_zero``: ``(n_term_zero, 1 + d)`` zeros, i.e. tau = 0 at x = 0
            (presumably encoding the terminal target x(T) = 0 that the
            trajectory penalty also pushes toward; confirm).
        ``T``: unchanged.
        ``X0_batch``: ``(n_X0, len(x0_test))`` float32 initial states.

    NOTE(review): the default ``x_ranges`` is 4-D while the default ``x0_test``
    is 20-D. Pass values that agree with the system dimension.
    """
    rng = np.random.default_rng(seed)
    d = len(x_ranges)

    # PDE collocation points: tau in (0, T], each x_i uniform in its range.
    tau_pde = rng.uniform(1e-3, T, size=(n_pde, 1))
    x_pde = np.hstack([rng.uniform(lo, hi, (n_pde, 1)) for lo, hi in x_ranges])
    x_tau_pde = np.hstack([tau_pde, x_pde])

    # Terminal points: tau = 0 and x = 0 for every row.
    x_term_zero = np.hstack([np.zeros((n_term_zero, 1)), np.zeros((n_term_zero, d))])

    if x0_test is None:
        x0_test = (
            1, -2,  0,  2,
            -1,  2, -2,  2,
            0, -2,  1, -1,
            2,  0,  2, -2,
            2, -2,  1,  0,
        )
    x0_test = jnp.asarray(x0_test, dtype=jnp.float32)

    key = jax.random.PRNGKey(seed)
    # Initial states for the trajectory penalty, sampled near x0_test.
    X0_batch = sample_X0_near(key, x0_test, n_samples=n_X0, radius=X0_radius,
                              low=X0_low, high=X0_high)
    return jnp.array(x_tau_pde), jnp.array(x_term_zero), T, X0_batch



#####################################################################################
#####################################################################################

# Rebalancer 
class Rebalancer:
    """Host-side (Python/NumPy) adaptive re-weighting of loss components.

    Each :meth:`step` updates an exponential moving average (EMA) of every
    component's metric, floored at ``ema_floor``. The EMA is optionally
    expressed relative to its first value (``relative_to_init``). Each
    component's weight is then multiplied by
    ``(rel_k / geometric_mean(rel)) ** (-power)``. Multipliers for ratios
    below 1 stay at 1 unless ``allow_upweight_small``. Weights are clipped to
    ``w_clip`` and finally rescaled to mean 1 over the active components, so
    the final values can lie outside ``w_clip``.

    Components in ``static_keys`` keep their weight. A component is frozen
    once the EMA of its raw (un-floored) metric stays below ``freeze_tol`` for
    ``freeze_patience`` consecutive steps. A frozen component's weight is then
    set to ``base_weights[k]``, or kept as-is when no base weight is given.
    """

    def __init__(
        self, keys,
        ema_beta=0.95,
        power=0.2,
        w_clip=(0.5, 2.0),
        ema_floor=1e-4,
        allow_upweight_small=False,
        freeze_tol=1e-5,
        freeze_patience=10,
        base_weights=None,
        static_keys=None,
        relative_to_init=True,
        eps=1e-12,
    ):
        self.keys = list(keys)
        self.ema_beta = float(ema_beta)
        self.power = float(power)
        self.w_min, self.w_max = map(float, w_clip)
        self.ema_floor = float(ema_floor)
        self.allow_upweight_small = bool(allow_upweight_small)
        self.freeze_tol = float(freeze_tol)
        self.freeze_patience = int(freeze_patience)
        self.base_weights = dict(base_weights or {})
        self.static_keys = set(static_keys or ())
        self.relative_to_init = bool(relative_to_init)
        self.eps = float(eps)
        # Floored EMA per component; drives the re-weighting ratios.
        self.ema = {k: None for k in self.keys}
        # First EMA value per component (reference for relative_to_init).
        self.init_ema = {k: None for k in self.keys}
        # Un-floored EMA per component; used only for the freeze test.
        self.raw_ema = {k: None for k in self.keys}
        self.below_ctr = {k: 0 for k in self.keys}
        self.frozen = {k: False for k in self.keys}

    def _smooth_update(self, old, new):
        """EMA update; the first observation initializes the average."""
        if old is None:
            return new
        return self.ema_beta * old + (1.0 - self.ema_beta) * new

    def step(self, metrics, weights):
        """Update the EMAs from ``metrics`` and return a new weight dict.

        Args:
            metrics: mapping component name -> scalar (anything ``float()`` accepts).
            weights: current weights; it is copied, not mutated.

        Returns:
            New dict of weights.
        """
        for k in self.keys:
            if k not in metrics:
                continue
            raw = float(metrics[k])
            self.ema[k] = self._smooth_update(self.ema[k], max(raw, self.ema_floor))
            if self.init_ema[k] is None:
                self.init_ema[k] = self.ema[k]
            # The floored EMA can never drop below ema_floor. Comparing it with a
            # smaller freeze_tol (every default configuration) would never
            # freeze, so the freeze test uses the un-floored EMA instead.
            self.raw_ema[k] = self._smooth_update(self.raw_ema[k], raw)
            if self.raw_ema[k] < self.freeze_tol:
                self.below_ctr[k] += 1
            else:
                self.below_ctr[k] = 0
            if self.freeze_patience > 0 and self.below_ctr[k] >= self.freeze_patience:
                self.frozen[k] = True

        present = [k for k in self.keys
                   if (k in metrics and self.ema[k] is not None and not self.frozen[k] and k not in self.static_keys)]
        if not present:
            new_w = dict(weights)
            for k in self.keys:
                if self.frozen[k]:
                    new_w[k] = float(self.base_weights.get(k, new_w.get(k, 1.0)))
            return new_w

        def rel_ema(k):
            if not self.relative_to_init:
                return max(self.ema[k], self.ema_floor)
            denom = max(self.init_ema[k] or self.ema_floor, self.ema_floor)
            return max(self.ema[k] / denom, self.ema_floor)

        rels = {k: rel_ema(k) for k in present}
        gmean = float(np.exp(np.mean([np.log(max(rels[k], self.eps)) for k in present])))

        new_w = dict(weights)
        for k in self.keys:
            if self.frozen[k]:
                new_w[k] = float(self.base_weights.get(k, new_w.get(k, 1.0)))
                continue
            if k in self.static_keys or k not in rels:
                new_w[k] = float(new_w.get(k, 1.0))
                continue
            ratio = max(rels[k] / gmean, self.eps)
            mult = 1.0 if (not self.allow_upweight_small and ratio < 1.0) else ratio ** (-self.power)
            new_w[k] = float(np.clip(new_w.get(k, 1.0) * mult, self.w_min, self.w_max))

        # Rescale so the active components average to weight 1.
        mean_w = float(np.mean([new_w[k] for k in present])) if present else 1.0
        if mean_w > 0:
            for k in present:
                new_w[k] = new_w[k] / mean_w
        return new_w


#####################################################################################
#####################################################################################

# ====== Training ======
def _make_rebalancer(rebal_cfg, keys):
    """Construct the loss Rebalancer from ``rebal_cfg``, with a default for each missing field."""
    return Rebalancer(
        keys=keys,
        ema_beta=getattr(rebal_cfg, "ema_beta", 0.98),
        power=getattr(rebal_cfg, "power", 0.2),
        w_clip=tuple(getattr(rebal_cfg, "w_clip", (0.3, 3.0))),
        ema_floor=getattr(rebal_cfg, "ema_floor", 1e-3),
        allow_upweight_small=getattr(rebal_cfg, "allow_upweight_small", False),
        freeze_tol=getattr(rebal_cfg, "freeze_tol", 1e-4),
        freeze_patience=getattr(rebal_cfg, "freeze_patience", 10),
        static_keys=set(getattr(rebal_cfg, "static_keys", ())),
        relative_to_init=getattr(rebal_cfg, "relative_to_init", True),
    )


def calibration(config, data, params_init=None, l_ws=None, params_sa=None):
    """Train the value-function network with periodic loss re-weighting.

    Args:
        config: ConfigDict with the ``ann_*`` fields used by ``ann_gen`` plus
            ``loss_str``, ``seed``, ``ann_in_dim``, ``num_epochs`` and an
            optional ``rebal`` sub-config (``start_epoch``, ``every`` and the
            Rebalancer settings). Optional ``learning_rate`` and
            ``weight_decay`` (both default 1e-4) configure AdamW. Optional
            ``log_every`` (default 100) sets the progress-print interval.
        data: ``(x_tau_pde, x_term_zero, T, X0_batch)`` as built by ``get_data``.
        params_init: optional initial parameters; skips ``ann.init`` if given.
        l_ws: optional initial loss weights (default ``init_l_ws[loss_str]``).
        params_sa: optional per-point loss multipliers
            (default ``init_params_sa(loss_str, data)``).

    Returns:
        ``(predict_fn, history, params, ann)``. ``history`` is a list of
        ``(epoch, total_loss, metrics_dict)`` tuples, one per epoch.
    """
    ann = ann_gen(config)
    ofunc = loss_fn[config.loss_str]

    if l_ws is None:
        l_ws = dict(init_l_ws[config.loss_str])
    if params_sa is None:
        params_sa = init_params_sa(config.loss_str, data)

    rebal_cfg = getattr(config, "rebal", None)
    if rebal_cfg is None:
        rebal_cfg = ml_collections.ConfigDict({})
    rebal = _make_rebalancer(rebal_cfg, [k for k in ALL_COMPONENTS if k in params_sa])

    key = jran.PRNGKey(config.seed)
    _, key_init = jran.split(key, 2)
    dummy = jnp.ones((1, config.ann_in_dim))
    params0 = ann.init(key_init, dummy) if params_init is None else params_init

    # Defaults (1e-4, 1e-4) reproduce optax.adamw(learning_rate=1e-4).
    # NOTE: the module-level `tx` (Adam + exponential decay) is not used here.
    opt = optax.adamw(
        learning_rate=getattr(config, "learning_rate", 1e-4),
        weight_decay=getattr(config, "weight_decay", 1e-4),
    )
    state = train_state.TrainState.create(apply_fn=ann.apply, params=params0, tx=opt)

    @jit
    def train_step(state, data, l_ws, params_sa):
        def wrapped_loss(params):
            def fn(x): return ann.apply(params, x)
            return ofunc(config, fn, data, l_ws, params_sa)
        (loss, metrics), grads = value_and_grad(wrapped_loss, has_aux=True)(state.params)
        state = state.apply_gradients(grads=grads)
        return state, loss, metrics

    hist = []
    start_epoch = getattr(rebal_cfg, "start_epoch", 0)
    every = getattr(rebal_cfg, "every", 1000)
    log_every = getattr(config, "log_every", 100)

    for epoch in range(config.num_epochs):
        state, loss, metrics = train_step(state, data, l_ws, params_sa)
        # Host-side re-weighting at start_epoch, start_epoch + every, ...
        if epoch >= start_epoch and (epoch - start_epoch) % every == 0:
            l_ws = rebal.step(metrics, l_ws)
        if epoch % log_every == 0:
            w_str = "{" + ", ".join(f"{k}:{l_ws.get(k, 1.0):.3g}" for k in sorted(l_ws.keys())) + "}"
            msg = " | ".join(f"{k}={float(metrics[k]):.6f}" for k in ALL_COMPONENTS if k in metrics)
            print(f"Epoch {epoch}: total={float(loss):.6f} | {msg} | w={w_str}")
        hist.append((int(epoch), float(loss), {k: float(v) for k, v in metrics.items()}))

    return (lambda x: ann.apply(state.params, x)), hist, state.params, ann


# saving helpers
def save_params_overwrite(path, params):
    """Write ``params`` as an Orbax PyTree checkpoint at ``path``, replacing any existing directory."""
    target = os.path.abspath(path)
    if os.path.exists(target):
        shutil.rmtree(target)
    checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
    try:
        checkpointer.save(target, params, force=True)
    except TypeError:
        # Presumably an Orbax version whose save() has no `force` keyword.
        checkpointer.save(target, params)

def _to_jsonable(obj):
    """Recursively convert a config-like object into JSON-serializable values.

    Mappings (ConfigDict or dict) become dicts and sequences become lists.
    Primitives and ``None`` pass through; anything else is stringified.
    """
    if isinstance(obj, (ml_collections.ConfigDict, dict)):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(_to_jsonable, obj))
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    return str(obj)

def save_stage(run_dir, tag, *, params, config, history):
    """Persist one training stage under ``run_dir``.

    Writes ``<tag>_params`` (Orbax checkpoint), ``<tag>_config.json`` and
    ``<tag>_history.pkl``, creating ``run_dir`` if needed.
    """
    base = os.path.abspath(run_dir)
    os.makedirs(base, exist_ok=True)

    def path_for(suffix):
        return os.path.join(base, f"{tag}_{suffix}")

    save_params_overwrite(path_for("params"), params)
    with open(path_for("config.json"), "w") as fh:
        json.dump(_to_jsonable(config), fh, indent=2)
    with open(path_for("history.pkl"), "wb") as fh:
        pickle.dump(history, fh)

def restore_params(run_dir, tag):
    """Load the parameters written by ``save_stage(run_dir, tag, ...)``."""
    checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
    ckpt_dir = os.path.abspath(os.path.join(run_dir, f"{tag}_params"))
    return checkpointer.restore(ckpt_dir)



def run_experiment(config, workdir=None):
    """Generate data, train with ``calibration`` and save the result.

    Args:
        config: experiment ConfigDict (data sizes, horizon, ranges, seed,
            network and training settings).
        workdir: output directory. Falls back to ``config.workdir``, then to
            ``"runs/pinn_LQR"``.

    Returns:
        ``(params, {"MT-PINN_LQR": history})``.
    """
    out_dir = os.path.abspath(workdir or getattr(config, "workdir", "runs/pinn_LQR"))
    os.makedirs(out_dir, exist_ok=True)

    data = get_data(
        n_pde=config.pts_num,
        n_term_zero=config.n_term_zero,
        T=config.T,
        x_ranges=config.x_ranges,
        seed=config.seed,
    )
    _, hist, params, _ = calibration(config, data)

    tag = "MT-PINN_LQR"
    save_stage(out_dir, tag, params=params, config=config, history=hist)
    return params, {tag: hist}



#####################################################################################
#####################################################################################

# Config
config_PINN = ml_collections.ConfigDict({
    "pts_num": 30000,               # PDE collocation points
    "n_term_zero": 1000,            # terminal-condition points
    # State box for collocation; tied to `dim` so it cannot drift from the
    # assembled system matrices.
    "x_ranges": [(-2, 2)] * dim,
    "T": 1.0,
    "seed": 42,
    "ann_str": "MLP",
    "loss_str": "PINN",
    "num_epochs": 50000,
    "ann_in_dim": dim + 1,          # network input is [tau, x]
    "ann_out_dim": 1,
    "ann_activation_str": "tanh",
    "ann_periodicity": None,
    "ann_fourier_emb": None,
    "ann_reparam": False,
    "ann_hidden_dim": (128, 128, 128),
    "workdir": os.path.abspath("runs/pinn_LQR"),
    "rebal": ml_collections.ConfigDict({
        "start_epoch": 500,
        "every": 5000,
        "ema_beta": 0.9,
        "power": 0.3,
        "w_clip": (0.1, 5.0),
    }),

})

def build_model_fn(config, params):
    """Return a jitted predictor ``x -> V(x)`` for the network described by ``config``."""
    model = ann_gen(config)

    def predict(x):
        return model.apply(params, x)

    return jax.jit(predict)

# Train for config_PINN.num_epochs steps and save to config_PINN.workdir,
# then wrap the trained parameters in a jitted predictor.
final_params, all_hist = run_experiment(
    config_PINN,
    workdir=config_PINN.workdir,
)
model_PINN = build_model_fn(config_PINN, final_params)

Epoch 0: total=154.313370 | e_pde=153.679535 | e_term_zero=0.000000 | e_traj=0.633828 | w={e_pde:1, e_term_zero:0.1, e_traj:1}
Epoch 100: total=12.104628 | e_pde=11.405900 | e_term_zero=0.002344 | e_traj=0.698493 | w={e_pde:1, e_term_zero:0.1, e_traj:1}
Epoch 200: total=6.148886 | e_pde=5.443377 | e_term_zero=0.000187 | e_traj=0.705491 | w={e_pde:1, e_term_zero:0.1, e_traj:1}
Epoch 300: total=3.965752 | e_pde=3.257679 | e_term_zero=0.000001 | e_traj=0.708073 | w={e_pde:1, e_term_zero:0.1, e_traj:1}
Epoch 400: total=2.961757 | e_pde=2.252255 | e_term_zero=0.000047 | e_traj=0.709497 | w={e_pde:1, e_term_zero:0.1, e_traj:1}
Epoch 500: total=2.417915 | e_pde=1.707290 | e_term_zero=0.000074 | e_traj=0.710618 | w={e_pde:1.43, e_term_zero:0.143, e_traj:1.43}
Epoch 600: total=2.838758 | e_pde=1.275808 | e_term_zero=0.000055 | e_traj=0.711317 | w={e_pde:1.43, e_term_zero:0.143, e_traj:1.43}
Epoch 700: total=2.468810 | e_pde=1.016337 | e_term_zero=0.000030 | e_traj=0.711827 | w={e_pde:1.43, e_te

In [None]:
# Bundle everything under /content/runs (saved params, config, history) into one zip for download.
!zip -r /content/runs.zip /content/runs


  adding: content/runs/ (stored 0%)
  adding: content/runs/pinn_LQR/ (stored 0%)
  adding: content/runs/pinn_LQR/MT-PINN_LQR_history.pkl (deflated 62%)
  adding: content/runs/pinn_LQR/MT-PINN_LQR_config.json (deflated 74%)
  adding: content/runs/pinn_LQR/MT-PINN_LQR_params/ (stored 0%)
  adding: content/runs/pinn_LQR/MT-PINN_LQR_params/d/ (stored 0%)
  adding: content/runs/pinn_LQR/MT-PINN_LQR_params/d/0484d54260de310bf7a848ce13260504 (stored 0%)
  adding: content/runs/pinn_LQR/MT-PINN_LQR_params/_CHECKPOINT_METADATA (deflated 38%)
  adding: content/runs/pinn_LQR/MT-PINN_LQR_params/ocdbt.process_0/ (stored 0%)
  adding: content/runs/pinn_LQR/MT-PINN_LQR_params/ocdbt.process_0/d/ (stored 0%)
  adding: content/runs/pinn_LQR/MT-PINN_LQR_params/ocdbt.process_0/d/32273d0bbdaca8e19e0aad517180846b (stored 0%)
  adding: content/runs/pinn_LQR/MT-PINN_LQR_params/ocdbt.process_0/d/e87bc091ee1128e4d4fee3ce35b653ce (stored 0%)
  adding: content/runs/pinn_LQR/MT-PINN_LQR_params/ocdbt.process_0/d/017