In [1]:
import numpy as np
import torch
from torch import nn
from torch import optim
import time
types=torch.float32
torch.set_default_dtype(types)

In [214]:
class HJBEquation():
    """HJB equation in PNAS paper doi.org/10.1073/pnas.1718942115"""
    def __init__(self, eqn_config):
        self.dim=eqn_config["dim"]
        self.total_time=eqn_config["total_time"]
        self.Ndis=eqn_config["Ndis"]
        self.delta_t=self.total_time/self.Ndis
        self.sqrt_delta_t = np.sqrt(self.delta_t)
        self.x_init = torch.zeros(self.dim,requires_grad=False)
        self.sigma = np.sqrt(2.0)
        self.lambd = 1.0

    def sample(self, num_sample):
        dw_sample =torch.tensor( np.random.normal(size=[num_sample, self.dim, self.Ndis]) * self.sqrt_delta_t,requires_grad=False)
        x_sample = torch.zeros([num_sample, self.dim, self.Ndis + 1],requires_grad=False)
        x_sample[:, :, 0] = torch.ones([num_sample, self.dim]) * self.x_init
        for i in range(self.Ndis):
            x_sample[:, :, i + 1] = x_sample[:, :, i] + self.sigma * dw_sample[:, :, i]
        return dw_sample, x_sample

    def f_tf(self, t, x, y, z):
        return -self.lambd * torch.sum(torch.square(z), 1, keepdims=True) / 2

    def g_tf(self, x):
        return torch.log((1 + torch.sum(torch.square(x), 1, keepdims=True)) / 2)

In [3]:
class GlobalModelDeepBSDE(nn.Module):
    def __init__(self, net_config, eqn):
        super(GlobalModelDeepBSDE, self).__init__()
        self.net_config = net_config
        self.eqn = eqn
        self.y_0=nn.Parameter(torch.rand(1))
        self.z_0=nn.Parameter((torch.rand((1,self.eqn.dim))*0.2)-0.1)
        self.subnet = [FF_subnet_DBSDE(self.eqn,net_config) for _ in range(self.eqn.Ndis-1)]

    def forward(self, inputs):
        dw, x = inputs
        time_stamp = np.arange(0, self.eqn.Ndis) * self.eqn.delta_t
        all_one_vec = torch.ones((dw.shape[0], 1))
        y = all_one_vec * self.y_0
        z = torch.matmul(all_one_vec, self.z_0)

        for t in range(0, self.eqn.Ndis-1):
            y = y - self.eqn.delta_t * (
                self.eqn.f_tf(time_stamp[t], x[:, :, t], y, z)
            ) + torch.sum(z * dw[:, :, t], 1, keepdims=True)
            z = self.subnet[t](x[:, :, t + 1]) / self.eqn.dim
        # terminal time
        y = y - self.eqn.delta_t * self.eqn.f_tf(time_stamp[-1], x[:, :, -2], y, z) + \
            torch.sum(z * dw[:, :, -1], 1, keepdims=True)
        return y


class FF_subnet_DBSDE(nn.Module):
    def __init__(self, eqn,net_config):
        super(FF_subnet_DBSDE, self).__init__()
        self.dim = eqn.dim
        self.num_hiddens = net_config["num_hiddens"]
        self.net=nn.Sequential(
                            nn.BatchNorm1d(self.dim, eps=1e-06, momentum=0.99,affine=False,dtype=types),
                            nn.Linear(self.dim,self.dim+10,bias=False),
                            nn.BatchNorm1d(self.dim+10, eps=1e-06, momentum=0.99,affine=False,dtype=types),
                            nn.ReLU(),
                            nn.Linear(self.dim+10,self.dim+10,bias=False),
                            nn.BatchNorm1d(self.dim+10, eps=1e-06, momentum=0.99,affine=False,dtype=types),
                            nn.ReLU(),
                            nn.Linear(self.dim+10,self.dim,bias=False),
                            nn.BatchNorm1d(self.dim, eps=1e-06, momentum=0.99,affine=False,dtype=types)
                            )
    def forward(self,x):
        return self.net(x)
    
class BSDESolver(object):
    """The fully connected neural network model."""
    def __init__(self,eqn, net_config):
        self.net_config = net_config
        self.eqn = eqn

        #self.model = torch.compile(GlobalModel(net_config, eqn), mode="max-autotune")
        self.model = GlobalModelDeepBSDE(net_config, eqn)
        self.y_0 = self.model.y_0
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.01)
        #lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            #self.net_config.lr_boundaries, self.net_config.lr_values)
        #self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=1e-8)

    def train(self):
        start_time = time.time()
        training_history = []
        valid_data = self.eqn.sample(256)

        # begin sgd iteration
        for step in range(2001):
            #print(step)
            inputs=self.eqn.sample(64)
            results=self.model(inputs)
            loss=self.loss_fn(inputs,results)
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()
            
            if step % 200==0:
                loss = self.loss_fn(valid_data, self.model(valid_data)).detach().numpy()
                y_init = self.y_0.detach().numpy()[0]
                elapsed_time = time.time() - start_time
                training_history.append([step, loss, y_init, elapsed_time])
                print("Epoch ",step, " y_0 ",y_init," time ", elapsed_time," loss ", loss)

        return np.array(training_history)

    def loss_fn(self, inputs,results):
        DELTA_CLIP = 50.0
        dw, x = inputs
        y_terminal = self.model(inputs)
        delta = results - self.eqn.g_tf(x[:, :, -1])
        # use linear approximation outside the clipped range
        loss = torch.mean(torch.where(torch.abs(delta) < DELTA_CLIP, torch.square(delta),
                                       2 * DELTA_CLIP * torch.abs(delta) - DELTA_CLIP ** 2))

        return loss

In [4]:
class GlobalModelMergedDeepBSDE(nn.Module):
    def __init__(self, net_config, eqn):
        super(GlobalModelMergedDeepBSDE, self).__init__()
        self.net_config = net_config
        self.eqn = eqn
        self.dim=eqn.dim
        self.model=FF_subnet_Merged_DBSDE(self.eqn,net_config)
        self.y_0=nn.Parameter(torch.rand(1))

    def forward(self, inputs):
        dw, x = inputs
        time_stamp = np.arange(0, self.eqn.Ndis) * self.eqn.delta_t
        all_one_vec = torch.ones((dw.shape[0], 1))
        y = all_one_vec * self.y_0
        #z = torch.matmul(all_one_vec, self.z_0)

        for t in range(0, self.eqn.Ndis):
            inp=torch.hstack((time_stamp[t]*all_one_vec,x[:,:,t]))
            z = self.model(inp) / self.eqn.dim
            y = y - self.eqn.delta_t * (
                self.eqn.f_tf(time_stamp[t], x[:, :, t], y, z)
            ) + torch.sum(z * dw[:, :, t], 1, keepdims=True)
            
        return y

class FF_subnet_Merged_DBSDE(nn.Module):
    def __init__(self, eqn,net_config):
        super(FF_subnet_Merged_DBSDE, self).__init__()
        self.dim = eqn.dim
        self.net = nn.Sequential(
                            nn.Linear(self.dim+1,2*self.dim,bias=False),
                            nn.ELU(),
                            nn.Linear(2*self.dim,2*self.dim,bias=False),
                            nn.ELU(),
                            nn.Linear(2*self.dim,2*self.dim,bias=False),
                            nn.ELU(),
                            nn.Linear(2*self.dim,self.dim,bias=False)
                            )
    def forward(self,x):
        return self.net(x)
    
class Merged_BSDE_Solver(object):
    """The fully connected neural network model."""
    def __init__(self,eqn, net_config):
        self.net_config = net_config
        self.eqn = eqn

        #self.model = torch.compile(GlobalModel(net_config, eqn), mode="max-autotune")
        self.model = GlobalModelMergedDeepBSDE(net_config, eqn)
        self.y_0 = self.model.y_0
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.01)
        #lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            #self.net_config.lr_boundaries, self.net_config.lr_values)
        #self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=1e-8)

    def train(self):
        start_time = time.time()
        training_history = []
        valid_data = self.eqn.sample(256)

        # begin sgd iteration
        for step in range(2001):
            #print(step)
            inputs=self.eqn.sample(64)
            results=self.model(inputs)
            loss=self.loss_fn(inputs,results)
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()
            
            if step % 200==0:
                loss = self.loss_fn(valid_data, self.model(valid_data)).detach().numpy()
                y_init = self.y_0.detach().numpy()[0]
                elapsed_time = time.time() - start_time
                training_history.append([step, loss, y_init, elapsed_time])
                print("Epoch ",step, " y_0 ",y_init," time ", elapsed_time," loss ", loss)

        return np.array(training_history)

    def loss_fn(self, inputs,results):
        DELTA_CLIP = 50.0
        dw, x = inputs
        y_terminal = self.model(inputs)
        delta = results - self.eqn.g_tf(x[:, :, -1])
        # use linear approximation outside the clipped range
        loss = torch.mean(torch.where(torch.abs(delta) < DELTA_CLIP, torch.square(delta),
                                       2 * DELTA_CLIP * torch.abs(delta) - DELTA_CLIP ** 2))

        return loss

In [200]:
np.arange(0,11)*0.1

array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ])

In [228]:
class FF_net_Raissi(nn.Module):
    def __init__(self, eqn,net_config):
        super(FF_net_Raissi, self).__init__()
        self.dim = eqn.dim
        self.net = nn.Sequential(
                            nn.Linear(self.dim+1,2*self.dim,bias=True),
                            nn.ELU(),
                            nn.Linear(2*self.dim,2*self.dim,bias=True),
                            nn.ELU(),
                            nn.Linear(2*self.dim,2*self.dim,bias=True),
                            nn.ELU(),
                            nn.Linear(2*self.dim,2*self.dim,bias=True),
                            nn.ELU(),
                            nn.Linear(2*self.dim,1,bias=True),
                            )
    def forward(self,tx):
        return self.net(tx)
        #return self.net(torch.hstack((t*torch.ones((x.shape[0], 1)),x)))
    



class Raissi_BSDE_Solver(object):
    """The fully connected neural network model."""
    def __init__(self,eqn, net_config):
        self.net_config = net_config
        self.eqn = eqn

        #self.model = torch.compile(GlobalModel(net_config, eqn), mode="max-autotune")
        self.net = FF_net_Raissi(eqn,net_config)
        #self.y_0 = nn.Parameter(torch.rand(1))
        self.times=np.arange(0, self.eqn.Ndis+1) * self.eqn.delta_t
        self.optimizer = optim.Adam(self.net.parameters(), lr=0.01)

    def train(self):
        start_time = time.time()
        training_history = []
        valid_data = self.eqn.sample(256)
        dw, x =valid_data

        # begin sgd iteration
        for step in range(2001):
            #print(step)
            inputs=self.eqn.sample(64)
            #results=self.model(inputs)
            loss=self.loss_total(inputs)
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()
            
            if step % 200==0:
                loss = self.loss_total(valid_data).detach().numpy()
                tx=torch.hstack((self.times[0]*torch.ones((x.shape[0], 1)),x[:,:,0]))
                y_init = self.net(tx).detach().numpy()[0][0]
                elapsed_time = time.time() - start_time
                training_history.append([step, loss, y_init, elapsed_time])
                print("Epoch ",step, " y_0 ",y_init," time ", elapsed_time," loss ", loss)

        return np.array(training_history)
    def Dg_tf(self,X): # M x D
        Gt=self.eqn.g_tf(X)
        return torch.autograd.grad(Gt, X, create_graph=True,grad_outputs=torch.ones_like(Gt),allow_unused=True)[0]
    
    def loss_total(self, inputs):
        loss=0.0
        dw, x = inputs
        tx=torch.hstack((self.times[0]*torch.ones((x.shape[0], 1)),x[:,:,0]))
        tx.requires_grad_()
        Y0=self.net(tx)
        Z0 = torch.autograd.grad(Y0, tx, create_graph=True,grad_outputs=torch.ones_like(Y0),allow_unused=True)[0][:,1:]
        
        for t in range(self.eqn.Ndis):
            Y1_tilde = Y0 + self.eqn.f_tf(self.times[t],x[:,:,0],Y0,Z0)*self.eqn.delta_t + torch.sum(Z0 * dw[:, :, t], 1, keepdims=True)
            tx=torch.hstack((self.times[t+1]*torch.ones((x.shape[0], 1)),x[:,:,t+1]))
            tx.requires_grad_()
            Y1=self.net(tx)
            Z1=torch.autograd.grad(Y1, tx, create_graph=True,grad_outputs=torch.ones_like(Y1),allow_unused=True)[0][:,1:]
            loss+=torch.sum(torch.square(Y1_tilde-Y1))
            Y0=Y1
            Z0=Z1
        
        loss += torch.sum(torch.square(Y1 - self.eqn.g_tf(x[:,:,-1])))
        x.requires_grad_()
        loss += torch.sum(torch.square(Z1 - self.Dg_tf(x[:,:,-1])))
        return loss
                

In [215]:
eqn=HJBEquation({"dim":100,"total_time":1.0,"Ndis":20})

In [208]:
bsde_solver = BSDESolver(eqn,{"num_hiddens":2,"dtype":'float64'})
merged_bsde_solver = BSDESolver(eqn,{"num_hiddens":2,"dtype":'float64'})
raissi_bsde_solver=Raissi_BSDE_Solver(eqn,{"num_hiddens":2,"dtype":'float64'})

In [8]:
bsde_solver.train()

Epoch  0  y_0  0.024651418  time  0.11657214164733887  loss  21.008406567426228
Epoch  200  y_0  1.7480297  time  5.101586818695068  loss  5.965802361533653
Epoch  400  y_0  2.5713747  time  9.989607572555542  loss  3.1400265639773015
Epoch  600  y_0  3.2447877  time  14.885701179504395  loss  1.8466851338470927
Epoch  800  y_0  3.9214158  time  19.769885778427124  loss  0.6652847482735919
Epoch  1000  y_0  4.3777547  time  24.661163568496704  loss  0.11554710254698819
Epoch  1200  y_0  4.5520897  time  29.545178174972534  loss  0.03571667198614602
Epoch  1400  y_0  4.5878286  time  34.431739807128906  loss  0.032207409658111974
Epoch  1600  y_0  4.5954113  time  39.327043771743774  loss  0.031009219598391556
Epoch  1800  y_0  4.598763  time  44.22826647758484  loss  0.03173958647942544
Epoch  2000  y_0  4.5953074  time  49.11518955230713  loss  0.03190324346831917


array([[0.00000000e+00, 2.10084066e+01, 2.46514175e-02, 1.16572142e-01],
       [2.00000000e+02, 5.96580236e+00, 1.74802971e+00, 5.10158682e+00],
       [4.00000000e+02, 3.14002656e+00, 2.57137465e+00, 9.98960757e+00],
       [6.00000000e+02, 1.84668513e+00, 3.24478769e+00, 1.48857012e+01],
       [8.00000000e+02, 6.65284748e-01, 3.92141581e+00, 1.97698858e+01],
       [1.00000000e+03, 1.15547103e-01, 4.37775469e+00, 2.46611636e+01],
       [1.20000000e+03, 3.57166720e-02, 4.55208969e+00, 2.95451782e+01],
       [1.40000000e+03, 3.22074097e-02, 4.58782864e+00, 3.44317398e+01],
       [1.60000000e+03, 3.10092196e-02, 4.59541130e+00, 3.93270438e+01],
       [1.80000000e+03, 3.17395865e-02, 4.59876299e+00, 4.42282665e+01],
       [2.00000000e+03, 3.19032435e-02, 4.59530735e+00, 4.91151896e+01]])

In [9]:
merged_bsde_solver.train()

Epoch  0  y_0  0.8912061  time  0.07808279991149902  loss  13.67742695001863
Epoch  200  y_0  2.5708623  time  5.006568908691406  loss  3.4231703089464363
Epoch  400  y_0  3.4342017  time  9.888447761535645  loss  1.495930806863312
Epoch  600  y_0  4.0999155  time  14.765660762786865  loss  0.3966359336852514
Epoch  800  y_0  4.471634  time  19.64422583580017  loss  0.05698976950765269
Epoch  1000  y_0  4.577501  time  24.534709453582764  loss  0.03384459338410853
Epoch  1200  y_0  4.593001  time  29.413618326187134  loss  0.03387129571859156
Epoch  1400  y_0  4.5921683  time  34.290764808654785  loss  0.03251961704018899
Epoch  1600  y_0  4.5948195  time  39.1920850276947  loss  0.03286078295005468
Epoch  1800  y_0  4.5956793  time  44.14352083206177  loss  0.03386297165276013
Epoch  2000  y_0  4.5951767  time  49.03537082672119  loss  0.03344524612148178


array([[0.00000000e+00, 1.36774270e+01, 8.91206086e-01, 7.80827999e-02],
       [2.00000000e+02, 3.42317031e+00, 2.57086229e+00, 5.00656891e+00],
       [4.00000000e+02, 1.49593081e+00, 3.43420172e+00, 9.88844776e+00],
       [6.00000000e+02, 3.96635934e-01, 4.09991550e+00, 1.47656608e+01],
       [8.00000000e+02, 5.69897695e-02, 4.47163391e+00, 1.96442258e+01],
       [1.00000000e+03, 3.38445934e-02, 4.57750082e+00, 2.45347095e+01],
       [1.20000000e+03, 3.38712957e-02, 4.59300089e+00, 2.94136183e+01],
       [1.40000000e+03, 3.25196170e-02, 4.59216833e+00, 3.42907648e+01],
       [1.60000000e+03, 3.28607830e-02, 4.59481955e+00, 3.91920850e+01],
       [1.80000000e+03, 3.38629717e-02, 4.59567928e+00, 4.41435208e+01],
       [2.00000000e+03, 3.34452461e-02, 4.59517670e+00, 4.90353708e+01]])

In [230]:
raissi_bsde_solver.train()

Epoch  0  y_0  [0.8659583]  time  0.1250009536743164  loss  2377.671400258971
Epoch  200  y_0  [4.51918]  time  8.195310831069946  loss  10.066446997864508
Epoch  400  y_0  [4.5304646]  time  16.30252981185913  loss  9.850023431799535
Epoch  600  y_0  [4.5362096]  time  24.36612558364868  loss  9.330826614109046
Epoch  800  y_0  [4.5330725]  time  32.407026052474976  loss  9.489533338991754
Epoch  1000  y_0  [4.5201335]  time  40.50288701057434  loss  8.900748463315185
Epoch  1200  y_0  [4.5890684]  time  48.672719955444336  loss  10.100365392985488
Epoch  1400  y_0  [4.5701017]  time  56.72680640220642  loss  8.745620604769396
Epoch  1600  y_0  [4.5049753]  time  64.78528261184692  loss  8.401678908261598
Epoch  1800  y_0  [4.5326095]  time  72.94533801078796  loss  8.456892131958792
Epoch  2000  y_0  [4.5011926]  time  81.32745790481567  loss  8.357344548002208


  return np.array(training_history)


array([[0, array(2377.67140026), array([0.8659583], dtype=float32),
        0.1250009536743164],
       [200, array(10.066447), array([4.51918], dtype=float32),
        8.195310831069946],
       [400, array(9.85002343), array([4.5304646], dtype=float32),
        16.30252981185913],
       [600, array(9.33082661), array([4.5362096], dtype=float32),
        24.36612558364868],
       [800, array(9.48953334), array([4.5330725], dtype=float32),
        32.407026052474976],
       [1000, array(8.90074846), array([4.5201335], dtype=float32),
        40.50288701057434],
       [1200, array(10.10036539), array([4.5890684], dtype=float32),
        48.672719955444336],
       [1400, array(8.7456206), array([4.5701017], dtype=float32),
        56.72680640220642],
       [1600, array(8.40167891), array([4.5049753], dtype=float32),
        64.78528261184692],
       [1800, array(8.45689213), array([4.5326095], dtype=float32),
        72.94533801078796],
       [2000, array(8.35734455), array([4.50