In [4]:
import os
import jax
import jax.numpy as jnp
import numpy as np

In [2]:
def encode(encoder_coeff, encoder_exp, y):
    """
    Encode from observation to reduced state, with batch dimension last.

    Each reduced coordinate is a polynomial in the observation:

        x[k, b] = sum_j encoder_coeff[k, j] * prod_i y[i, b] ** encoder_exp[j, i]

    Args:
        encoder_coeff: (n, n_coeff) coefficient matrix, one row per reduced
            coordinate and one column per monomial.
        encoder_exp: (n_coeff, p) exponent matrix, one row per monomial; p is
            the observation dimension y.shape[0].
        y: a single observation of shape (p,) or a batch of shape (p, batch).

    Returns:
        Reduced state of shape (n,) for a 1-D input, otherwise (n, batch).
    """
    y = jnp.asarray(y)
    single_input = y.ndim == 1
    if single_input:
        y = y[:, None]  # treat a single observation as a batch of one

    exps = jnp.asarray(encoder_exp)
    # Evaluate every monomial for every sample at once: (n_coeff, p, 1)
    # exponents broadcast against (1, p, batch) data, then multiply across the
    # p input dimensions -> (n_coeff, batch).
    monomials = jnp.prod(y[None, :, :] ** exps[:, :, None], axis=1)
    # A single matrix product replaces the per-term Python loop; with zero
    # monomials it yields a zero state instead of failing in concatenate.
    x = jnp.asarray(encoder_coeff) @ monomials

    return x[:, 0] if single_input else x

In [10]:
def decode(decoder_coeff, decoder_exp, x):
    """
    Decode from reduced state to observation, with batch dimension last.

    Each observation coordinate is a polynomial in the reduced state:

        y[k, b] = sum_j decoder_coeff[k, j] * prod_i x[i, b] ** decoder_exp[j, i]

    Args:
        decoder_coeff: (p, p_coeff) coefficient matrix, one row per
            observation coordinate and one column per monomial.
        decoder_exp: (p_coeff, n) exponent matrix, one row per monomial; n is
            the reduced-state dimension x.shape[0].
        x: a single reduced state of shape (n,) or a batch of shape (n, batch).

    Returns:
        Observation of shape (p,) for a 1-D input, otherwise (p, batch).
    """
    x = jnp.asarray(x)
    single_input = x.ndim == 1
    if single_input:
        x = x[:, None]  # treat a single state as a batch of one

    exps = jnp.asarray(decoder_exp)
    # All monomials for all samples at once: (p_coeff, n, 1) exponents
    # broadcast against (1, n, batch) states, product over n -> (p_coeff, batch).
    monomials = jnp.prod(x[None, :, :] ** exps[:, :, None], axis=1)
    # One matmul replaces the per-term Python loop; with zero monomials it
    # yields a zero observation instead of failing in concatenate.
    y = jnp.asarray(decoder_coeff) @ monomials

    return y[:, 0] if single_input else y

In [12]:
def reduced_dynamics(dynamics_coeff, dynamics_exp, x):
    """
    Evaluate the continuous-time dynamics of the reduced system, with batch dimension last.

    Each component of the time derivative is a polynomial in the reduced state:

        x_dot[k, b] = sum_j dynamics_coeff[k, j] * prod_i x[i, b] ** dynamics_exp[j, i]

    Args:
        dynamics_coeff: (n, n_coeff) coefficient matrix, one row per state
            component and one column per monomial.
        dynamics_exp: (n_coeff, n) exponent matrix, one row per monomial.
        x: a single reduced state of shape (n,) or a batch of shape (n, batch).

    Returns:
        x_dot of shape (n,) for a 1-D input, otherwise (n, batch).
    """
    x = jnp.asarray(x)
    single_input = x.ndim == 1
    if single_input:
        x = x[:, None]  # treat a single state as a batch of one

    exps = jnp.asarray(dynamics_exp)
    # All monomials for all samples at once: (n_coeff, n, 1) exponents
    # broadcast against (1, n, batch) states, product over n -> (n_coeff, batch).
    monomials = jnp.prod(x[None, :, :] ** exps[:, :, None], axis=1)
    # One matmul replaces the per-term Python loop; with zero monomials it
    # yields a zero derivative instead of failing in concatenate.
    x_dot = jnp.asarray(dynamics_coeff) @ monomials

    return x_dot[:, 0] if single_input else x_dot

In [7]:
model_name = 'origin_ssmr_200g'
# The data directory can be overridden with the TRUNK_DATA environment variable;
# the fallback is the absolute path this notebook was originally run with.
data_dir = os.getenv('TRUNK_DATA', '/home/trunk/Documents/trunk-stack/stack/main/data')
npz_filepath = os.path.join(data_dir, f'models/ssmr/{model_name}.npz')
# np.load returns a lazy NpzFile that holds the archive open and re-reads a
# member on every subscript. Copy all arrays into a plain dict inside `with`
# so the file is closed right away; later cells only index `data[...]`.
with np.load(npz_filepath) as npz_file:
    data = dict(npz_file)
encoder_coeff, encoder_exp = data['Vfinal'], data['exps_V']

In [9]:
# Hand-picked 12-dimensional observation used to exercise encode/decode/dynamics.
y = jnp.array([0.1, 0.2, -0.6, 0.2, 0.5, -0.2, 0.5, 0.8, -0.2, 0.1, 0.05, -0.1])
xi = encode(encoder_coeff, encoder_exp, y)
# An assignment alone displays nothing; end with the reduced state so a fresh
# Restart & Run All reproduces the output saved for this cell.
xi

Array([ 0.4902135 ,  0.74074817,  0.5681536 , -0.34081072, -0.00973913],      dtype=float32)

In [11]:
# Decoder polynomial: coefficient matrix 'M' and monomial exponents 'exps'.
decoder_coeff = data['M']
decoder_exp = data['exps']
# Map the reduced state back to observation space (shown as the cell output).
# NOTE(review): the saved reconstruction differs noticeably from y
# (e.g. index 10: ~0.516 vs 0.05) — confirm y lies in the model's valid range.
decode(decoder_coeff, decoder_exp, xi)

Array([ 2.4772573e-01,  2.9616034e-01, -5.0710428e-01,  1.6720746e-01,
        3.4380329e-01, -3.1249958e-01,  1.3570189e-01,  4.1823369e-01,
       -2.7795431e-01,  3.5379633e-01,  5.1612693e-01, -3.8000019e-04],      dtype=float32)

In [13]:
# Reduced dynamics polynomial: coefficient matrix 'R' and exponents 'exps_r'.
dynamics_coeff = data['R']
dynamics_exp = data['exps_r']
# Continuous-time derivative of the reduced state at xi (shown as the cell output).
reduced_dynamics(dynamics_coeff, dynamics_exp, xi)

Array([-30.684515  ,  -1.8484488 ,  -2.1302652 ,   0.74905264,
        -0.5363502 ], dtype=float32)

In [15]:
B_r = data['B_red']
# NOTE(review): this multiplies the *transposed* matrix by the reduced state xi,
# producing a 6-vector (see the saved output). If B_r is the reduced input
# matrix of x_dot = r(x) + B_r u, the expected product would be B_r @ u —
# confirm the transpose against xi is intended.
jnp.matmul(B_r.T, xi)

Array([-0.14280866,  0.03240603, -0.25879735,  0.22471327,  0.1076438 ,
        0.09430611], dtype=float32)