In [33]:
import equinox as eqx
import jax
import jax.numpy as jnp
from jax import random
import numpy as np

In [44]:
class HJBEquation():
    """HJB equation in PNAS paper doi.org/10.1073/pnas.1718942115

    Holds the problem constants and provides a sampler for the discretised
    driving noise and state paths, together with the generator ``f_tf`` and
    the terminal condition ``g_tf``.

    ``eqn_config`` is a mapping with the keys ``"dim"`` (state dimension),
    ``"total_time"`` (time horizon) and ``"Ndis"`` (number of time steps).
    """
    def __init__(self, eqn_config):
        self.dim = eqn_config["dim"]
        self.total_time = eqn_config["total_time"]
        self.Ndis = eqn_config["Ndis"]
        self.delta_t = self.total_time / self.Ndis
        self.sqrt_delta_t = np.sqrt(self.delta_t)
        self.x_init = np.zeros(self.dim)
        self.sigma = np.sqrt(2.0)
        self.lambd = 1.0
        # Root PRNG key; `sample` advances it so successive draws differ while
        # the whole sequence stays reproducible from this fixed seed.
        self.key = random.PRNGKey(758493)

    def sample(self, num_sample):
        """Draw ``num_sample`` noise-increment paths and the matching state paths.

        Returns a tuple ``(dw_sample, x_sample)`` of float64 NumPy arrays:

        * ``dw_sample`` has shape ``(num_sample, dim, Ndis)`` and holds
          standard-normal draws scaled by ``sqrt(delta_t)``.
        * ``x_sample`` has shape ``(num_sample, dim, Ndis + 1)``; slice 0 is
          ``x_init`` and slice ``i + 1`` is slice ``i`` plus
          ``sigma * dw_sample[:, :, i]``.

        The draws are generated from a subkey split off ``self.key`` (rather
        than NumPy's global RNG), so two instances built from the same config
        produce identical sample sequences.
        """
        self.key, subkey = random.split(self.key)
        shape = (num_sample, self.dim, self.Ndis)
        # Convert to a float64 NumPy array so callers keep receiving the same
        # array type and dtype as before.
        normals = np.asarray(random.normal(subkey, shape), dtype=np.float64)
        dw_sample = normals * self.sqrt_delta_t

        x_sample = np.empty((num_sample, self.dim, self.Ndis + 1))
        x_sample[:, :, 0] = self.x_init
        # Running sum over the time axis replaces the explicit step-by-step loop.
        x_sample[:, :, 1:] = self.x_init[:, None] + np.cumsum(self.sigma * dw_sample, axis=2)
        return dw_sample, x_sample

    def f_tf(self, t, x, y, z):
        """Generator term: ``-lambd * sum(z**2, axis=1) / 2`` with the summed axis kept.

        ``t``, ``x`` and ``y`` are accepted for a uniform signature but unused.
        """
        return -self.lambd * jnp.sum(jnp.square(z), axis=1, keepdims=True) / 2

    def g_tf(self, t, x):
        """Terminal condition: ``log((1 + sum(x**2, axis=1)) / 2)`` with the summed axis kept.

        ``t`` is accepted for a uniform signature but unused.
        """
        return jnp.log((1 + jnp.sum(jnp.square(x), axis=1, keepdims=True)) / 2)
    
class GlobalModel(eqx.Module):
    """Unrolled model that propagates ``y`` along a sampled path.

    Holds a trainable initial value ``y_0``, a trainable initial ``z_0`` and
    one ``FF_subnet`` per interior time step (``Ndis - 1`` in total).

    NOTE(review): the class derives from ``eqx.Module`` but its body uses the
    PyTorch API (``nn.Parameter``, ``torch.*``, a ``forward`` method), and
    neither ``torch`` nor ``nn`` is imported in the visible import cell. This
    looks like a half-finished port -- confirm the intended framework.
    """
    def __init__(self, net_config, eqn):
        """Store the configuration/equation and create the trainable pieces."""
        super(GlobalModel, self).__init__()
        self.net_config = net_config
        self.eqn = eqn
        # Scalar initial value, initialised from torch.rand(1).
        self.y_0=nn.Parameter(torch.rand(1))
        # Row vector of shape (1, dim), initialised as rand * 0.2 - 0.1.
        self.z_0=nn.Parameter((torch.rand((1,self.eqn.dim))*0.2)-0.1)
        # NOTE(review): if this is meant to be a torch nn.Module, a plain Python
        # list presumably would not register the subnets as submodules -- confirm.
        self.subnet = [FF_subnet(self.eqn,net_config) for _ in range(self.eqn.Ndis-1)]

    def forward(self, inputs):
        """Roll ``y`` forward over all time steps and return its final value.

        ``inputs`` is unpacked as ``(dw, x)``; both are indexed as
        ``[:, :, t]`` along their last axis, matching what
        ``HJBEquation.sample`` returns.
        """
        dw, x = inputs
        time_stamp = np.arange(0, self.eqn.Ndis) * self.eqn.delta_t
        all_one_vec = torch.ones((dw.shape[0], 1))
        # Broadcast the trainable initial values to one row per sample.
        y = all_one_vec * self.y_0
        z = torch.matmul(all_one_vec, self.z_0)

        for t in range(0, self.eqn.Ndis-1):
            y = y - self.eqn.delta_t * (
                self.eqn.f_tf(time_stamp[t], x[:, :, t], y, z)
            ) + torch.sum(z * dw[:, :, t], 1, keepdims=True)
            # z for the next step comes from that step's subnet, scaled by 1/dim.
            z = self.subnet[t](x[:, :, t + 1]) / self.eqn.dim
        # terminal time
        y = y - self.eqn.delta_t * self.eqn.f_tf(time_stamp[-1], x[:, :, -2], y, z) + \
            torch.sum(z * dw[:, :, -1], 1, keepdims=True)
        return y


class FF_subnet(eqx.Module):
    """Feed-forward block mapping a ``dim``-wide input to a ``dim``-wide output.

    Layer stack: BatchNorm -> Linear(dim, dim+10) -> BatchNorm -> ReLU ->
    Linear(dim+10, dim+10) -> BatchNorm -> ReLU -> Linear(dim+10, dim) ->
    BatchNorm. All linear layers are bias-free and all batch-norm layers are
    non-affine.

    NOTE(review): derives from ``eqx.Module`` but is built from ``torch.nn``
    layers; ``nn`` and ``types`` are not defined in the visible cells --
    confirm where they come from.
    NOTE(review): ``momentum=0.99`` looks like a Keras-style value; PyTorch's
    BatchNorm presumably weights the *new* batch statistic by ``momentum``, so
    the equivalent would be 0.01 -- confirm against the torch documentation.
    """
    def __init__(self, eqn,net_config):
        """Build the layer stack; only ``eqn.dim`` and ``net_config["num_hiddens"]`` are read."""
        super(FF_subnet, self).__init__()
        self.dim = eqn.dim
        # Stored but not used by the layer stack below (widths are hard-coded as dim+10).
        self.num_hiddens = net_config["num_hiddens"]
        self.net=nn.Sequential(
                            nn.BatchNorm1d(self.dim, eps=1e-06, momentum=0.99,affine=False,dtype=types),
                            nn.Linear(self.dim,self.dim+10,bias=False),
                            nn.BatchNorm1d(self.dim+10, eps=1e-06, momentum=0.99,affine=False,dtype=types),
                            nn.ReLU(),
                            nn.Linear(self.dim+10,self.dim+10,bias=False),
                            nn.BatchNorm1d(self.dim+10, eps=1e-06, momentum=0.99,affine=False,dtype=types),
                            nn.ReLU(),
                            nn.Linear(self.dim+10,self.dim,bias=False),
                            nn.BatchNorm1d(self.dim, eps=1e-06, momentum=0.99,affine=False,dtype=types)
                            )
    def forward(self,x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.net(x)    
    
class BSDESolver(object):
    """Training driver that fits a ``GlobalModel`` to the terminal condition of ``eqn``.

    Builds the model and an Adam optimiser (fixed learning rate 0.01) and
    exposes ``train`` and ``loss_fn``.

    NOTE(review): this class relies on ``time``, ``torch`` and ``optim`` being
    in scope; none of them is imported in the visible import cell. It also
    mixes torch ops with ``eqn.g_tf``, which is written with ``jax.numpy`` --
    confirm the intended framework before running.
    """
    def __init__(self,eqn, net_config):
        """Create the model and optimiser for ``eqn`` using ``net_config``."""
        self.net_config = net_config
        self.eqn = eqn

        self.model = GlobalModel(net_config, eqn)
        # Alias to the model's trainable initial value, read when logging.
        self.y_0 = self.model.y_0
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.01)

    def train(self, num_iterations=2001, logging_frequency=200, batch_size=64, valid_size=256):
        """Run the optimisation loop and return the logged history.

        Args:
            num_iterations: number of optimiser steps (default 2001).
            logging_frequency: evaluate and print every this many steps (default 200).
            batch_size: number of paths sampled per training step (default 64).
            valid_size: number of paths in the fixed validation batch (default 256).

        Returns:
            ``np.array`` with one row ``[step, loss, y_init, elapsed_time]``
            per logged step.
        """
        start_time = time.time()
        training_history = []
        valid_data = self.eqn.sample(valid_size)

        # begin sgd iteration
        for step in range(num_iterations):
            inputs = self.eqn.sample(batch_size)
            results = self.model(inputs)
            loss = self.loss_fn(inputs, results)
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()

            if step % logging_frequency == 0:
                # Validation needs no gradients; skip building the autograd graph.
                with torch.no_grad():
                    loss = self.loss_fn(valid_data, self.model(valid_data)).detach().numpy()
                y_init = self.y_0.detach().numpy()[0]
                elapsed_time = time.time() - start_time
                training_history.append([step, loss, y_init, elapsed_time])
                print("Epoch ",step, " y_0 ",y_init," time ", elapsed_time," loss ", loss)

        return np.array(training_history)

    def loss_fn(self, inputs,results):
        """Clipped squared error between ``results`` and the terminal condition.

        ``inputs`` is unpacked as ``(dw, x)``; the target is
        ``eqn.g_tf(eqn.total_time, x[:, :, -1])``. Where ``|delta| < 50`` the
        per-sample loss is ``delta**2``; elsewhere it is the linear
        continuation ``2 * 50 * |delta| - 50**2``. The mean is returned.

        The model is not evaluated here: ``results`` already carries the
        forward-pass output, so no second forward pass is run.
        """
        DELTA_CLIP = 50.0
        _, x = inputs
        delta = results - self.eqn.g_tf(self.eqn.total_time, x[:, :, -1])
        # use linear approximation outside the clipped range
        loss = torch.mean(torch.where(torch.abs(delta) < DELTA_CLIP, torch.square(delta),
                                       2 * DELTA_CLIP * torch.abs(delta) - DELTA_CLIP ** 2))

        return loss


In [45]:
# Problem instance: 100 state dimensions, horizon 1.0, 20 time steps.
eqn = HJBEquation(dict(dim=100, total_time=1.0, Ndis=20))

In [46]:
# Smoke-test the sampler: draw 10 paths and display the returned (dw, x) tuple.
eqn.sample(10)

(array([[[-0.13107446,  0.02267763, -0.03232467, ..., -0.00603461,
           0.05157448, -0.05713857],
         [ 0.17831186, -0.22858924,  0.03958036, ..., -0.19402191,
           0.23676424, -0.01694009],
         [-0.28806086, -0.25812933, -0.04439494, ..., -0.41843162,
           0.03761025,  0.34126305],
         ...,
         [-0.09962596,  0.02040972,  0.18814015, ..., -0.19727514,
           0.17905706, -0.05603662],
         [-0.06603878,  0.04707486,  0.31875176, ..., -0.06467424,
          -0.41317003, -0.2223989 ],
         [ 0.09717593,  0.00421439, -0.26804713, ...,  0.22567109,
           0.12604418,  0.30575494]],
 
        [[-0.00561743, -0.02884686,  0.20773038, ...,  0.12909941,
          -0.15114446,  0.16043468],
         [-0.13266738, -0.35348947, -0.08268627, ...,  0.18449425,
           0.36994297, -0.29058522],
         [-0.16980532, -0.21608765, -0.00421219, ...,  0.21589302,
           0.27888215,  0.41125244],
         ...,
         [ 0.25579357, -0.0596635

In [21]:
# Split a root key built from seed 0 into three subkeys.
key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), num=3)

In [23]:
# Display the second subkey to inspect its raw value.
key2

Array([3186719485, 3840466878], dtype=uint32)