In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp

In [2]:
class Equation:
    """Common interface for a PDE problem on a uniform time grid.

    The constructor stores the problem size and time discretisation read
    from ``eqn_config``; subclasses supply the sampling routine, the
    generator and the terminal condition.
    """

    def __init__(self, eqn_config):
        """Read the problem size and time grid from ``eqn_config``.

        Args:
            eqn_config: Mapping with the keys ``'dim'``, ``'total_time'``
                and ``'Ndis'``.
        """
        dim = eqn_config['dim']
        horizon = eqn_config['total_time']
        num_steps = eqn_config['Ndis']
        step = horizon / num_steps

        self.dim = dim
        self.total_time = horizon
        self.Ndis = num_steps
        # Length of one time step and its square root (a JAX scalar).
        self.delta_t = step
        self.sqrt_delta_t = jnp.sqrt(step)
        # Left unset here; nothing in this class assigns a value to it.
        self.y_init = None

    def sample(self, num_sample):
        """Draw ``num_sample`` realisations of the forward SDE (abstract)."""
        raise NotImplementedError

    def f_tf(self, t, x, y, z):
        """Evaluate the PDE's generator term (abstract)."""
        raise NotImplementedError

    def g_tf(self, t, x):
        """Evaluate the PDE's terminal condition (abstract)."""
        raise NotImplementedError

In [15]:
class HJBLQ(Equation):
    """HJB equation in PNAS paper doi.org/10.1073/pnas.1718942115

    The forward state starts at ``x_init`` (the origin) and is driven by
    Gaussian increments scaled by ``sigma = sqrt(2)``; ``lambd`` weights the
    quadratic term of the generator.
    """

    def __init__(self, eqn_config):
        """Set up the time grid (via the base class) and the model constants.

        Args:
            eqn_config: Mapping with the keys ``'dim'``, ``'total_time'``
                and ``'Ndis'``.
        """
        super().__init__(eqn_config)
        self.x_init = jnp.zeros(self.dim)
        self.sigma = jnp.sqrt(2.0)
        self.lambd = 1.0

    def sample(self, num_sample, key=None):
        """Simulate increments and the corresponding forward state paths.

        Args:
            num_sample: Number of paths to draw.
            key: Optional JAX PRNG key. When omitted, ``PRNGKey(23)`` is used,
                so repeated calls without a key return identical draws.

        Returns:
            A tuple ``(dw_sample, x_sample)`` where ``dw_sample`` has shape
            ``(num_sample, dim, Ndis)`` and ``x_sample`` has shape
            ``(num_sample, dim, Ndis + 1)``; ``x_sample[:, :, 0]`` is
            ``x_init`` broadcast over the samples.
        """
        if key is None:
            key = jax.random.PRNGKey(23)
        dw_sample = (
            jax.random.normal(key, shape=(num_sample, self.dim, self.Ndis))
            * self.sqrt_delta_t
        )
        # JAX arrays are immutable: `.at[...].set(...)` returns a new array
        # rather than updating in place, so the path is assembled
        # functionally instead of by writing into a preallocated buffer.
        x_start = (jnp.ones([num_sample, self.dim]) * self.x_init)[:, :, None]
        # x_{i+1} = x_i + sigma * dw_i  <=>  x_i = x_0 + sigma * sum_{j<i} dw_j
        x_later = x_start + self.sigma * jnp.cumsum(dw_sample, axis=2)
        x_sample = jnp.concatenate([x_start, x_later], axis=2)
        return dw_sample, x_sample

    def f_tf(self, t, x, y, z):
        """Generator: ``-lambd * sum(z**2) / 2``.

        The sum runs over every element of ``z`` (no axis is given).
        NOTE(review): if ``z`` carries a batch axis this also sums across
        samples; presumably this is meant to be applied per sample (e.g.
        under ``jax.vmap``) -- confirm against the caller.
        """
        return -self.lambd * jnp.sum(z**2) / 2

    def g_tf(self, t, x):
        """Terminal condition: ``log((1 + sum(x**2)) / 2)``.

        The sum runs over every element of ``x`` (no axis is given); the same
        per-sample caveat as in ``f_tf`` applies.
        """
        # The whole of (1 + |x|^2) is halved, as in the terminal condition
        # g(x) = ln((1 + |x|^2) / 2) of the referenced paper.
        return jnp.log((1 + jnp.sum(x**2)) / 2)

In [None]:
class FeedForwardSubNet(eqx.Module):
    """Feed-forward sub-network built from batch-norm and dense layers.

    NOTE(review): this cell looks like an unfinished port from
    TensorFlow/Keras to Equinox and was never executed (``In [None]``):

    * ``tf`` is used throughout but is not imported in the visible cells.
    * Only ``layers`` is declared as a field, yet ``__init__`` also assigns
      ``bn_layers`` and ``dense_layers``. Equinox modules are dataclasses, so
      assigning undeclared attributes is expected to fail -- confirm.
    * ``layers`` holds a single Equinox batch-norm layer that ``call`` never
      uses.
    * The forward pass is named ``call`` (the Keras convention); Equinox
      modules are normally invoked through ``__call__`` -- confirm intent.
    """
    layers:  list
    
    def __init__(self, key,eqn_config,config_subnet):
        """Create the batch-norm and dense layers.

        Args:
            key: Not used anywhere in this constructor.
            eqn_config: Mapping; only ``eqn_config["dim"]`` is read.
            config_subnet: Object exposing a ``num_hiddens`` attribute, used
                as a sequence of hidden-layer widths (it is indexed and its
                length is taken).
        """
        dim = eqn_config["dim"]
        num_hiddens = config_subnet.num_hiddens
        # NOTE(review): input_size is hard-coded to 4 rather than derived
        # from `dim` or `num_hiddens` -- confirm the intended size.
        self.layers = [
            eqx.experimental.BatchNorm(input_size=4, axis_name="batch",momentum=0.99,eps=1e-6)
            
            
        ]
        # One batch-norm per dense layer plus one applied to the raw input:
        # len(num_hiddens) + 2 in total.
        self.bn_layers = [
            tf.keras.layers.BatchNormalization(
                momentum=0.99,
                epsilon=1e-6,
                beta_initializer=tf.random_normal_initializer(0.0, stddev=0.1),
                gamma_initializer=tf.random_uniform_initializer(0.1, 0.5)
            )
            for _ in range(len(num_hiddens) + 2)]
        # Hidden dense layers carry no bias and no activation; `call` applies
        # batch-norm and ReLU after each of them.
        self.dense_layers = [tf.keras.layers.Dense(num_hiddens[i],
                                                   use_bias=False,
                                                   activation=None)
                             for i in range(len(num_hiddens))]
        # final output should be gradient of size dim
        self.dense_layers.append(tf.keras.layers.Dense(dim, activation=None))

    def call(self, x, training):
        """Run the forward pass.

        structure: bn -> (dense -> bn -> relu) * len(num_hiddens) -> dense -> bn

        Args:
            x: Input passed to the first batch-norm layer.
            training: Forwarded as the second positional argument to every
                batch-norm layer.

        Returns:
            The output of the last batch-norm layer.
        """
        x = self.bn_layers[0](x, training)
        # All dense layers except the last are followed by batch-norm + ReLU.
        for i in range(len(self.dense_layers) - 1):
            x = self.dense_layers[i](x)
            x = self.bn_layers[i+1](x, training)
            x = tf.nn.relu(x)
        # Output layer: dense then batch-norm, with no activation.
        x = self.dense_layers[-1](x)
        x = self.bn_layers[-1](x, training)
        return x

In [16]:
# Problem setup: state dimension 1, total time 1.0, split into 20 steps.
eqn_config = {
    "dim": 1,
    "total_time": 1.0,
    "Ndis": 20,
}
eqn = HJBLQ(eqn_config)

In [21]:
# Split the root key (seed 0) into three subkeys.
key1, key2, key3 = jax.random.split(
    jax.random.PRNGKey(0),
    num=3,
)

In [23]:
# Display the second of the three subkeys produced by the split above.
key2

Array([3186719485, 3840466878], dtype=uint32)