In [1]:
import numpy as np
import torch
from torch import nn
from torch import optim
import time
types=torch.float32
torch.set_default_dtype(types)

In [2]:
class HJBEquation():
    """HJB equation in PNAS paper doi.org/10.1073/pnas.1718942115"""
    def __init__(self, eqn_config):
        self.dim=eqn_config["dim"]
        self.total_time=eqn_config["total_time"]
        self.Ndis=eqn_config["Ndis"]
        self.delta_t=self.total_time/self.Ndis
        self.sqrt_delta_t = np.sqrt(self.delta_t)
        self.x_init = torch.zeros(self.dim,requires_grad=False)
        self.sigma = np.sqrt(2.0)
        self.lambd = 1.0

    def sample(self, num_sample):
        dw_sample =torch.tensor( np.random.normal(size=[num_sample, self.dim, self.Ndis]) * self.sqrt_delta_t,requires_grad=False)
        x_sample = torch.zeros([num_sample, self.dim, self.Ndis + 1],requires_grad=False)
        x_sample[:, :, 0] = torch.ones([num_sample, self.dim]) * self.x_init
        for i in range(self.Ndis):
            x_sample[:, :, i + 1] = x_sample[:, :, i] + self.sigma * dw_sample[:, :, i]
        return dw_sample, x_sample

    def f_tf(self, t, x, y, z):
        return -self.lambd * torch.sum(torch.square(z), 1, keepdims=True) / 2

    def g_tf(self, t, x):
        return torch.log((1 + torch.sum(torch.square(x), 1, keepdims=True)) / 2)

In [3]:
class GlobalModel(nn.Module):
    def __init__(self, net_config, eqn):
        super(GlobalModel, self).__init__()
        self.net_config = net_config
        self.eqn = eqn
        self.y_0=nn.Parameter(torch.rand(1))
        self.z_0=nn.Parameter((torch.rand((1,self.eqn.dim))*0.2)-0.1)
        self.subnet = [FF_subnet(self.eqn,net_config) for _ in range(self.eqn.Ndis-1)]

    def forward(self, inputs):
        dw, x = inputs
        time_stamp = np.arange(0, self.eqn.Ndis) * self.eqn.delta_t
        all_one_vec = torch.ones((dw.shape[0], 1))
        y = all_one_vec * self.y_0
        z = torch.matmul(all_one_vec, self.z_0)

        for t in range(0, self.eqn.Ndis-1):
            y = y - self.eqn.delta_t * (
                self.eqn.f_tf(time_stamp[t], x[:, :, t], y, z)
            ) + torch.sum(z * dw[:, :, t], 1, keepdims=True)
            z = self.subnet[t](x[:, :, t + 1]) / self.eqn.dim
        # terminal time
        y = y - self.eqn.delta_t * self.eqn.f_tf(time_stamp[-1], x[:, :, -2], y, z) + \
            torch.sum(z * dw[:, :, -1], 1, keepdims=True)
        return y


class FF_subnet(nn.Module):
    def __init__(self, eqn,net_config):
        super(FF_subnet, self).__init__()
        self.dim = eqn.dim
        self.num_hiddens = net_config["num_hiddens"]
        self.net=nn.Sequential(
                            nn.BatchNorm1d(self.dim, eps=1e-06, momentum=0.99,affine=False,dtype=types),
                            nn.Linear(self.dim,self.dim+10,bias=False),
                            nn.BatchNorm1d(self.dim+10, eps=1e-06, momentum=0.99,affine=False,dtype=types),
                            nn.ReLU(),
                            nn.Linear(self.dim+10,self.dim+10,bias=False),
                            nn.BatchNorm1d(self.dim+10, eps=1e-06, momentum=0.99,affine=False,dtype=types),
                            nn.ReLU(),
                            nn.Linear(self.dim+10,self.dim,bias=False),
                            nn.BatchNorm1d(self.dim, eps=1e-06, momentum=0.99,affine=False,dtype=types)
                            )
    def forward(self,x):
        return self.net(x)

In [4]:
class BSDESolver(object):
    """The fully connected neural network model."""
    def __init__(self,eqn, net_config):
        self.net_config = net_config
        self.eqn = eqn

        self.model = torch.compile(GlobalModel(net_config, eqn), mode="max-autotune")
        #self.model = GlobalModel(net_config, eqn)
        self.y_0 = self.model.y_0
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.01)
        #lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            #self.net_config.lr_boundaries, self.net_config.lr_values)
        #self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=1e-8)

    def train(self):
        start_time = time.time()
        training_history = []
        valid_data = self.eqn.sample(256)

        # begin sgd iteration
        for step in range(2001):
            #print(step)
            inputs=self.eqn.sample(64)
            results=self.model(inputs)
            loss=self.loss_fn(inputs,results)
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()
            
            if step % 200==0:
                loss = self.loss_fn(valid_data, self.model(valid_data)).detach().numpy()
                y_init = self.y_0.detach().numpy()[0]
                elapsed_time = time.time() - start_time
                training_history.append([step, loss, y_init, elapsed_time])
                print("Epoch ",step, " y_0 ",y_init," time ", elapsed_time," loss ", loss)

        return np.array(training_history)

    def loss_fn(self, inputs,results):
        DELTA_CLIP = 50.0
        dw, x = inputs
        y_terminal = self.model(inputs)
        delta = results - self.eqn.g_tf(self.eqn.total_time, x[:, :, -1])
        # use linear approximation outside the clipped range
        loss = torch.mean(torch.where(torch.abs(delta) < DELTA_CLIP, torch.square(delta),
                                       2 * DELTA_CLIP * torch.abs(delta) - DELTA_CLIP ** 2))

        return loss


In [5]:
eqn=HJBEquation({"dim":100,"total_time":1.0,"Ndis":20})

In [6]:
bsde_solver = BSDESolver(eqn,{"num_hiddens":2,"dtype":'float64'})

In [7]:
bsde_solver.train()

Epoch  0  y_0  0.5370314  time  32.45413875579834  loss  16.475947029590582
Epoch  200  y_0  2.2282743  time  35.9521369934082  loss  4.223927288148565
Epoch  400  y_0  3.0626101  time  39.43185591697693  loss  2.2454835320105024
Epoch  600  y_0  3.7620978  time  42.91301870346069  loss  0.9255054097937361
Epoch  800  y_0  4.309346  time  46.612215518951416  loss  0.1655064708540866
Epoch  1000  y_0  4.5362577  time  50.06212258338928  loss  0.034993309278345756
Epoch  1200  y_0  4.587101  time  53.530511140823364  loss  0.030660184217406756
Epoch  1400  y_0  4.5949526  time  57.07760787010193  loss  0.030481080378026434
Epoch  1600  y_0  4.5955005  time  60.54098987579346  loss  0.030733773640204964
Epoch  1800  y_0  4.5934854  time  64.00586438179016  loss  0.029663833406598176
Epoch  2000  y_0  4.5949993  time  67.45974493026733  loss  0.030241390493593784


array([[0.00000000e+00, 1.64759470e+01, 5.37031412e-01, 3.24541388e+01],
       [2.00000000e+02, 4.22392729e+00, 2.22827435e+00, 3.59521370e+01],
       [4.00000000e+02, 2.24548353e+00, 3.06261015e+00, 3.94318559e+01],
       [6.00000000e+02, 9.25505410e-01, 3.76209784e+00, 4.29130187e+01],
       [8.00000000e+02, 1.65506471e-01, 4.30934620e+00, 4.66122155e+01],
       [1.00000000e+03, 3.49933093e-02, 4.53625774e+00, 5.00621226e+01],
       [1.20000000e+03, 3.06601842e-02, 4.58710098e+00, 5.35305111e+01],
       [1.40000000e+03, 3.04810804e-02, 4.59495258e+00, 5.70776079e+01],
       [1.60000000e+03, 3.07337736e-02, 4.59550047e+00, 6.05409899e+01],
       [1.80000000e+03, 2.96638334e-02, 4.59348536e+00, 6.40058644e+01],
       [2.00000000e+03, 3.02413905e-02, 4.59499931e+00, 6.74597449e+01]])

In [8]:
bsde_solver.y_0

Parameter containing:
tensor([4.5950], requires_grad=True)