In [1]:
import numpy as np
import torch
from torch import nn
from torch import optim
import time
import torch.nn.functional as F
types=torch.float32
torch.set_default_dtype(types)

In [2]:
class HJBEquation():
    """HJB equation in PNAS paper doi.org/10.1073/pnas.1718942115"""
    def __init__(self, eqn_config):
        self.dim=eqn_config["dim"]
        self.total_time=eqn_config["total_time"]
        self.Ndis=eqn_config["Ndis"]
        self.delta_t=self.total_time/self.Ndis
        self.sqrt_delta_t = np.sqrt(self.delta_t)
        self.x_init = torch.zeros(self.dim,requires_grad=False)
        self.sigma = np.sqrt(2.0)
        self.lambd = 1.0

    def sample(self, num_sample):
        dw_sample =torch.tensor( np.random.normal(size=[num_sample, self.dim, self.Ndis]) * self.sqrt_delta_t,requires_grad=False)
        x_sample = torch.zeros([num_sample, self.dim, self.Ndis + 1],requires_grad=False)
        x_sample[:, :, 0] = torch.ones([num_sample, self.dim]) * self.x_init
        for i in range(self.Ndis):
            x_sample[:, :, i + 1] = x_sample[:, :, i] + self.sigma * dw_sample[:, :, i]
        return dw_sample, x_sample

    def f_tf(self, t, x, y, z):
        return -self.lambd * torch.sum(torch.square(z), 1, keepdims=True) / 2

    def g_tf(self, x):
        return torch.log((1 + torch.sum(torch.square(x), 1, keepdims=True)) / 2)

In [3]:
class GlobalModelDeepBSDE(nn.Module):
    def __init__(self, net_config, eqn):
        super(GlobalModelDeepBSDE, self).__init__()
        self.net_config = net_config
        self.eqn = eqn
        self.y_0=nn.Parameter(torch.rand(1))
        self.z_0=nn.Parameter((torch.rand((1,self.eqn.dim))*0.2)-0.1)
        self.subnet = [FF_subnet_DBSDE(self.eqn,net_config) for _ in range(self.eqn.Ndis-1)]

    def forward(self, inputs):
        dw, x = inputs
        time_stamp = np.arange(0, self.eqn.Ndis) * self.eqn.delta_t
        all_one_vec = torch.ones((dw.shape[0], 1))
        y = all_one_vec * self.y_0
        z = torch.matmul(all_one_vec, self.z_0)

        for t in range(0, self.eqn.Ndis-1):
            y = y - self.eqn.delta_t * (
                self.eqn.f_tf(time_stamp[t], x[:, :, t], y, z)
            ) + torch.sum(z * dw[:, :, t], 1, keepdims=True)
            z = self.subnet[t](x[:, :, t + 1]) / self.eqn.dim
        # terminal time
        y = y - self.eqn.delta_t * self.eqn.f_tf(time_stamp[-1], x[:, :, -2], y, z) + \
            torch.sum(z * dw[:, :, -1], 1, keepdims=True)
        return y


class FF_subnet_DBSDE(nn.Module):
    def __init__(self, eqn,net_config):
        super(FF_subnet_DBSDE, self).__init__()
        self.dim = eqn.dim
        self.num_hiddens = net_config["num_hiddens"]
        self.net=nn.Sequential(
                            nn.BatchNorm1d(self.dim, eps=1e-06, momentum=0.99,affine=False,dtype=types),
                            nn.Linear(self.dim,self.dim+10,bias=False),
                            nn.BatchNorm1d(self.dim+10, eps=1e-06, momentum=0.99,affine=False,dtype=types),
                            nn.ReLU(),
                            nn.Linear(self.dim+10,self.dim+10,bias=False),
                            nn.BatchNorm1d(self.dim+10, eps=1e-06, momentum=0.99,affine=False,dtype=types),
                            nn.ReLU(),
                            nn.Linear(self.dim+10,self.dim,bias=False),
                            nn.BatchNorm1d(self.dim, eps=1e-06, momentum=0.99,affine=False,dtype=types)
                            )
    def forward(self,x):
        return self.net(x)
    
class BSDESolver(object):
    """The fully connected neural network model."""
    def __init__(self,eqn, net_config):
        self.net_config = net_config
        self.eqn = eqn

        #self.model = torch.compile(GlobalModel(net_config, eqn), mode="max-autotune")
        self.model = GlobalModelDeepBSDE(net_config, eqn)
        self.y_0 = self.model.y_0
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.01)
        #lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            #self.net_config.lr_boundaries, self.net_config.lr_values)
        #self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=1e-8)

    def train(self):
        start_time = time.time()
        training_history = []
        valid_data = self.eqn.sample(256)

        # begin sgd iteration
        for step in range(2001):
            #print(step)
            inputs=self.eqn.sample(64)
            results=self.model(inputs)
            loss=self.loss_fn(inputs,results)
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()
            
            if step % 200==0:
                loss = self.loss_fn(valid_data, self.model(valid_data)).detach().numpy()
                y_init = self.y_0.detach().numpy()[0]
                elapsed_time = time.time() - start_time
                training_history.append([step, loss, y_init, elapsed_time])
                print("Epoch ",step, " y_0 ",y_init," time ", elapsed_time," loss ", loss)

        return np.array(training_history)

    def loss_fn(self, inputs,results):
        DELTA_CLIP = 50.0
        dw, x = inputs
        y_terminal = self.model(inputs)
        delta = results - self.eqn.g_tf(x[:, :, -1])
        # use linear approximation outside the clipped range
        loss = torch.mean(torch.where(torch.abs(delta) < DELTA_CLIP, torch.square(delta),
                                       2 * DELTA_CLIP * torch.abs(delta) - DELTA_CLIP ** 2))

        return loss

In [4]:
class GlobalModelMergedDeepBSDE(nn.Module):
    def __init__(self, net_config, eqn):
        super(GlobalModelMergedDeepBSDE, self).__init__()
        self.net_config = net_config
        self.eqn = eqn
        self.dim=eqn.dim
        self.model=FF_subnet_Merged_DBSDE(self.eqn,net_config)
        self.y_0=nn.Parameter(torch.rand(1))

    def forward(self, inputs):
        dw, x = inputs
        time_stamp = np.arange(0, self.eqn.Ndis) * self.eqn.delta_t
        all_one_vec = torch.ones((dw.shape[0], 1))
        y = all_one_vec * self.y_0
        #z = torch.matmul(all_one_vec, self.z_0)

        for t in range(0, self.eqn.Ndis):
            inp=torch.hstack((time_stamp[t]*all_one_vec,x[:,:,t]))
            z = self.model(inp) / self.eqn.dim
            y = y - self.eqn.delta_t * (
                self.eqn.f_tf(time_stamp[t], x[:, :, t], y, z)
            ) + torch.sum(z * dw[:, :, t], 1, keepdims=True)
            
        return y

class FF_subnet_Merged_DBSDE(nn.Module):
    def __init__(self, eqn,net_config):
        super(FF_subnet_Merged_DBSDE, self).__init__()
        self.dim = eqn.dim
        self.net = nn.Sequential(
                            nn.Linear(self.dim+1,2*self.dim,bias=False),
                            nn.ELU(),
                            nn.Linear(2*self.dim,2*self.dim,bias=False),
                            nn.ELU(),
                            nn.Linear(2*self.dim,2*self.dim,bias=False),
                            nn.ELU(),
                            nn.Linear(2*self.dim,self.dim,bias=False)
                            )
    def forward(self,x):
        return self.net(x)
    
class Merged_BSDE_Solver(object):
    """The fully connected neural network model."""
    def __init__(self,eqn, net_config):
        self.net_config = net_config
        self.eqn = eqn

        #self.model = torch.compile(GlobalModel(net_config, eqn), mode="max-autotune")
        self.model = GlobalModelMergedDeepBSDE(net_config, eqn)
        self.y_0 = self.model.y_0
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.01)
        #lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            #self.net_config.lr_boundaries, self.net_config.lr_values)
        #self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=1e-8)

    def train(self):
        start_time = time.time()
        training_history = []
        valid_data = self.eqn.sample(256)

        # begin sgd iteration
        for step in range(2001):
            #print(step)
            inputs=self.eqn.sample(64)
            results=self.model(inputs)
            loss=self.loss_fn(inputs,results)
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()
            
            if step % 200==0:
                loss = self.loss_fn(valid_data, self.model(valid_data)).detach().numpy()
                y_init = self.y_0.detach().numpy()[0]
                elapsed_time = time.time() - start_time
                training_history.append([step, loss, y_init, elapsed_time])
                print("Epoch ",step, " y_0 ",y_init," time ", elapsed_time," loss ", loss)

        return np.array(training_history)

    def loss_fn(self, inputs,results):
        DELTA_CLIP = 50.0
        dw, x = inputs
        y_terminal = self.model(inputs)
        delta = results - self.eqn.g_tf(x[:, :, -1])
        # use linear approximation outside the clipped range
        loss = torch.mean(torch.where(torch.abs(delta) < DELTA_CLIP, torch.square(delta),
                                       2 * DELTA_CLIP * torch.abs(delta) - DELTA_CLIP ** 2))

        return loss

In [5]:
class GlobalModelMergedResidualDeepBSDE(nn.Module):
    def __init__(self, net_config, eqn):
        super(GlobalModelMergedResidualDeepBSDE, self).__init__()
        self.net_config = net_config
        self.eqn = eqn
        self.dim=eqn.dim
        self.model=FF_subnet_Merged_Residual_DBSDE(self.eqn,net_config)
        self.y_0=nn.Parameter(torch.rand(1))

    def forward(self, inputs):
        dw, x = inputs
        time_stamp = np.arange(0, self.eqn.Ndis) * self.eqn.delta_t
        all_one_vec = torch.ones((dw.shape[0], 1))
        y = all_one_vec * self.y_0
        #z = torch.matmul(all_one_vec, self.z_0)

        for t in range(0, self.eqn.Ndis):
            inp=torch.hstack((time_stamp[t]*all_one_vec,y,self.eqn.g_tf(x[:,:,t]),x[:,:,t]))
            z = self.model(inp) / self.eqn.dim
            y = y - self.eqn.delta_t * (
                self.eqn.f_tf(time_stamp[t], x[:, :, t], y, z)
            ) + torch.sum(z * dw[:, :, t], 1, keepdims=True)
            
        return y

class FF_subnet_Merged_Residual_DBSDE(nn.Module):
    def __init__(self, eqn,net_config):
        super(FF_subnet_Merged_Residual_DBSDE, self).__init__()
        self.dim = eqn.dim
        self.l1=nn.Linear(self.dim+1+1+1,2*self.dim,bias=False)
        self.l2=nn.Linear(2*self.dim,2*self.dim,bias=False)
        self.l3=nn.Linear(2*self.dim,2*self.dim,bias=False)
        self.l4=nn.Linear(2*self.dim,self.dim,bias=False)

    def forward(self,x):
        out1=F.ELU(self.l1(x))
        out2=F.ELU(self.l3(F.ELU(self.l2(out1))))
        return self.l4(out1+out2)
    
class Merged_Residual_BSDE_Solver(object):
    """The fully connected neural network model."""
    def __init__(self,eqn, net_config):
        self.net_config = net_config
        self.eqn = eqn

        #self.model = torch.compile(GlobalModel(net_config, eqn), mode="max-autotune")
        self.model = GlobalModelMergedDeepBSDE(net_config, eqn)
        self.y_0 = self.model.y_0
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.01)
        #lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            #self.net_config.lr_boundaries, self.net_config.lr_values)
        #self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=1e-8)

    def train(self):
        start_time = time.time()
        training_history = []
        valid_data = self.eqn.sample(256)

        # begin sgd iteration
        for step in range(2001):
            #print(step)
            inputs=self.eqn.sample(64)
            results=self.model(inputs)
            loss=self.loss_fn(inputs,results)
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()
            
            if step % 200==0:
                loss = self.loss_fn(valid_data, self.model(valid_data)).detach().numpy()
                y_init = self.y_0.detach().numpy()[0]
                elapsed_time = time.time() - start_time
                training_history.append([step, loss, y_init, elapsed_time])
                print("Epoch ",step, " y_0 ",y_init," time ", elapsed_time," loss ", loss)

        return np.array(training_history)

    def loss_fn(self, inputs,results):
        DELTA_CLIP = 50.0
        dw, x = inputs
        y_terminal = self.model(inputs)
        delta = results - self.eqn.g_tf(x[:, :, -1])
        # use linear approximation outside the clipped range
        loss = torch.mean(torch.where(torch.abs(delta) < DELTA_CLIP, torch.square(delta),
                                       2 * DELTA_CLIP * torch.abs(delta) - DELTA_CLIP ** 2))

        return loss

In [None]:
class FF_subnet_DenseNet_BSDE(nn.Module):
    def __init__(self, eqn,net_config):
        super(FF_subnet_Merged_Residual_DBSDE, self).__init__()
        self.dim = eqn.dim
        self.l1=nn.Linear(self.dim,self.dim+20,bias=False)
        self.l2=nn.Linear(2*self.dim+20,self.dim,bias=False)
        self.l3=nn.Linear(2*self.dim,self.dim,bias=False)
        self.l4=nn.Linear(2*self.dim,1,bias=False)

    def forward(self,x):
        y2=F.relu(self.l1(x))
        x2=torch.hstack((x,y2))
        y3=F.relu(self.l2(x2))
        x3=torch.hstack((x2,y3))
        y4=F.relu(self.l3(x3))
        x4=torch.hstack((x3,y4))
        return self.l4(x4)

In [6]:
class FF_net_Raissi(nn.Module):
    def __init__(self, eqn,net_config):
        super(FF_net_Raissi, self).__init__()
        self.dim = eqn.dim
        self.net = nn.Sequential(
                            nn.Linear(self.dim+1,2*self.dim,bias=True),
                            nn.ELU(),
                            nn.Linear(2*self.dim,2*self.dim,bias=True),
                            nn.ELU(),
                            nn.Linear(2*self.dim,2*self.dim,bias=True),
                            nn.ELU(),
                            nn.Linear(2*self.dim,2*self.dim,bias=True),
                            nn.ELU(),
                            nn.Linear(2*self.dim,1,bias=True),
                            )
    def forward(self,tx):
        return self.net(tx)
        #return self.net(torch.hstack((t*torch.ones((x.shape[0], 1)),x)))
    



class Raissi_BSDE_Solver(object):
    """The fully connected neural network model."""
    def __init__(self,eqn, net_config):
        self.net_config = net_config
        self.eqn = eqn

        #self.model = torch.compile(GlobalModel(net_config, eqn), mode="max-autotune")
        self.net = FF_net_Raissi(eqn,net_config)
        #self.y_0 = nn.Parameter(torch.rand(1))
        self.times=np.arange(0, self.eqn.Ndis+1) * self.eqn.delta_t
        self.optimizer = optim.Adam(self.net.parameters(), lr=0.01)

    def train(self):
        start_time = time.time()
        training_history = []
        valid_data = self.eqn.sample(256)
        dw, x =valid_data

        # begin sgd iteration
        for step in range(2001):
            #print(step)
            inputs=self.eqn.sample(64)
            #results=self.model(inputs)
            loss=self.loss_total(inputs)
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()
            
            if step % 200==0:
                loss = self.loss_total(valid_data).detach().numpy()
                tx=torch.hstack((self.times[0]*torch.ones((x.shape[0], 1)),x[:,:,0]))
                y_init = self.net(tx).detach().numpy()[0][0]
                elapsed_time = time.time() - start_time
                training_history.append([step, loss, y_init, elapsed_time])
                print("Epoch ",step, " y_0 ",y_init," time ", elapsed_time," loss ", loss)

        return np.array(training_history)
    def Dg_tf(self,X): # M x D
        Gt=self.eqn.g_tf(X)
        return torch.autograd.grad(Gt, X, create_graph=True,grad_outputs=torch.ones_like(Gt),allow_unused=True)[0]
    
    def loss_total(self, inputs):
        loss=0.0
        dw, x = inputs
        tx=torch.hstack((self.times[0]*torch.ones((x.shape[0], 1)),x[:,:,0]))
        tx.requires_grad_()
        Y0=self.net(tx)
        Z0 = torch.autograd.grad(Y0, tx, create_graph=True,grad_outputs=torch.ones_like(Y0),allow_unused=True)[0][:,1:]
        
        for t in range(self.eqn.Ndis):
            Y1_tilde = Y0 + self.eqn.f_tf(self.times[t],x[:,:,0],Y0,Z0)*self.eqn.delta_t + torch.sum(Z0 * dw[:, :, t], 1, keepdims=True)
            tx=torch.hstack((self.times[t+1]*torch.ones((x.shape[0], 1)),x[:,:,t+1]))
            tx.requires_grad_()
            Y1=self.net(tx)
            Z1=torch.autograd.grad(Y1, tx, create_graph=True,grad_outputs=torch.ones_like(Y1),allow_unused=True)[0][:,1:]
            loss+=torch.sum(torch.square(Y1_tilde-Y1))
            Y0=Y1
            Z0=Z1
        
        loss += torch.sum(torch.square(Y1 - self.eqn.g_tf(x[:,:,-1])))
        x.requires_grad_()
        loss += torch.sum(torch.square(Z1 - self.Dg_tf(x[:,:,-1])))
        return loss
                

In [7]:
eqn=HJBEquation({"dim":100,"total_time":1.0,"Ndis":20})

In [8]:
bsde_solver = BSDESolver(eqn,{"num_hiddens":2,"dtype":'float64'})
merged_bsde_solver = Merged_BSDE_Solver(eqn,{"num_hiddens":2,"dtype":'float64'})
merged_residual_bsde_solver = Merged_Residual_BSDE_Solver(eqn,{"num_hiddens":2,"dtype":'float64'})
raissi_bsde_solver=Raissi_BSDE_Solver(eqn,{"num_hiddens":2,"dtype":'float64'})

In [9]:
bsde_solver.train()

Epoch  0  y_0  0.38692307  time  0.09283852577209473  loss  17.6296045420346
Epoch  200  y_0  2.0784166  time  4.934387683868408  loss  3.9479276851932985
Epoch  400  y_0  2.88576  time  9.793327331542969  loss  2.0068472878605927
Epoch  600  y_0  3.5896077  time  14.63205337524414  loss  0.9711366747533632
Epoch  800  y_0  4.198273  time  19.45731210708618  loss  0.23364227338426266
Epoch  1000  y_0  4.502965  time  24.28793954849243  loss  0.04261526331890004
Epoch  1200  y_0  4.5827975  time  29.12455654144287  loss  0.03163510447388217
Epoch  1400  y_0  4.5947804  time  33.97587180137634  loss  0.03160657193441353
Epoch  1600  y_0  4.5963  time  38.83889031410217  loss  0.031927983016275176
Epoch  1800  y_0  4.594609  time  43.664732456207275  loss  0.03179026073366001
Epoch  2000  y_0  4.5957284  time  48.55494570732117  loss  0.030988513138864682


array([[0.00000000e+00, 1.76296045e+01, 3.86923075e-01, 9.28385258e-02],
       [2.00000000e+02, 3.94792769e+00, 2.07841659e+00, 4.93438768e+00],
       [4.00000000e+02, 2.00684729e+00, 2.88576007e+00, 9.79332733e+00],
       [6.00000000e+02, 9.71136675e-01, 3.58960772e+00, 1.46320534e+01],
       [8.00000000e+02, 2.33642273e-01, 4.19827318e+00, 1.94573121e+01],
       [1.00000000e+03, 4.26152633e-02, 4.50296497e+00, 2.42879395e+01],
       [1.20000000e+03, 3.16351045e-02, 4.58279753e+00, 2.91245565e+01],
       [1.40000000e+03, 3.16065719e-02, 4.59478045e+00, 3.39758718e+01],
       [1.60000000e+03, 3.19279830e-02, 4.59630013e+00, 3.88388903e+01],
       [1.80000000e+03, 3.17902607e-02, 4.59460878e+00, 4.36647325e+01],
       [2.00000000e+03, 3.09885131e-02, 4.59572840e+00, 4.85549457e+01]])

In [10]:
merged_bsde_solver.train()

Epoch  0  y_0  0.7678617  time  0.10190987586975098  loss  14.800367954367239
Epoch  200  y_0  1.7751629  time  4.937001943588257  loss  3.072468239566029
Epoch  400  y_0  2.6254697  time  9.718029499053955  loss  1.4265072958530047
Epoch  600  y_0  3.6082768  time  14.433924913406372  loss  0.9759215394144511
Epoch  800  y_0  4.4713254  time  19.25156307220459  loss  0.04651214364877901
Epoch  1000  y_0  4.6000304  time  23.997984647750854  loss  0.0207139657017685
Epoch  1200  y_0  4.598907  time  28.747323513031006  loss  0.02157977348341072
Epoch  1400  y_0  4.601288  time  33.48168683052063  loss  0.020587065153183225
Epoch  1600  y_0  4.6000586  time  38.23451256752014  loss  0.019999116820022256
Epoch  1800  y_0  4.595997  time  42.979350090026855  loss  0.01992012546604236
Epoch  2000  y_0  4.598587  time  47.728952169418335  loss  0.020817297450629654


array([[0.00000000e+00, 1.48003680e+01, 7.67861724e-01, 1.01909876e-01],
       [2.00000000e+02, 3.07246824e+00, 1.77516294e+00, 4.93700194e+00],
       [4.00000000e+02, 1.42650730e+00, 2.62546968e+00, 9.71802950e+00],
       [6.00000000e+02, 9.75921539e-01, 3.60827684e+00, 1.44339249e+01],
       [8.00000000e+02, 4.65121436e-02, 4.47132540e+00, 1.92515631e+01],
       [1.00000000e+03, 2.07139657e-02, 4.60003042e+00, 2.39979846e+01],
       [1.20000000e+03, 2.15797735e-02, 4.59890699e+00, 2.87473235e+01],
       [1.40000000e+03, 2.05870652e-02, 4.60128784e+00, 3.34816868e+01],
       [1.60000000e+03, 1.99991168e-02, 4.60005856e+00, 3.82345126e+01],
       [1.80000000e+03, 1.99201255e-02, 4.59599686e+00, 4.29793501e+01],
       [2.00000000e+03, 2.08172975e-02, 4.59858704e+00, 4.77289522e+01]])

In [11]:
raissi_bsde_solver.train()

Epoch  0  y_0  0.71274763  time  0.12448453903198242  loss  3307.733676285858
Epoch  200  y_0  4.517213  time  8.27889084815979  loss  10.49734853331923
Epoch  400  y_0  4.491789  time  16.60650324821472  loss  10.174380070709734
Epoch  600  y_0  4.5488873  time  24.91992950439453  loss  10.502842291889444
Epoch  800  y_0  4.463564  time  33.19879460334778  loss  9.140609841890804
Epoch  1000  y_0  4.4624968  time  41.50823616981506  loss  8.771660674796836
Epoch  1200  y_0  4.483387  time  49.781808376312256  loss  8.059260242680217
Epoch  1400  y_0  4.573529  time  58.0403196811676  loss  9.151293541494878
Epoch  1600  y_0  4.474177  time  66.32269954681396  loss  7.8776581443452605
Epoch  1800  y_0  4.514343  time  74.58523607254028  loss  6.945754072302962
Epoch  2000  y_0  4.4668703  time  82.96582818031311  loss  6.355837613860909


array([[0.00000000e+00, 3.30773368e+03, 7.12747633e-01, 1.24484539e-01],
       [2.00000000e+02, 1.04973485e+01, 4.51721287e+00, 8.27889085e+00],
       [4.00000000e+02, 1.01743801e+01, 4.49178886e+00, 1.66065032e+01],
       [6.00000000e+02, 1.05028423e+01, 4.54888725e+00, 2.49199295e+01],
       [8.00000000e+02, 9.14060984e+00, 4.46356392e+00, 3.31987946e+01],
       [1.00000000e+03, 8.77166067e+00, 4.46249676e+00, 4.15082362e+01],
       [1.20000000e+03, 8.05926024e+00, 4.48338699e+00, 4.97818084e+01],
       [1.40000000e+03, 9.15129354e+00, 4.57352877e+00, 5.80403197e+01],
       [1.60000000e+03, 7.87765814e+00, 4.47417688e+00, 6.63226995e+01],
       [1.80000000e+03, 6.94575407e+00, 4.51434278e+00, 7.45852361e+01],
       [2.00000000e+03, 6.35583761e+00, 4.46687031e+00, 8.29658282e+01]])

In [12]:
merged_residual_bsde_solver.train()

Epoch  0  y_0  0.97598165  time  0.10168290138244629  loss  13.162296418591458
Epoch  200  y_0  1.9735289  time  5.021441698074341  loss  3.1716427480950564
Epoch  400  y_0  2.8692765  time  9.858968496322632  loss  1.5456351927991623
Epoch  600  y_0  3.8270943  time  14.847797870635986  loss  0.47027912564070457
Epoch  800  y_0  4.544493  time  19.682626724243164  loss  0.02609388216400159
Epoch  1000  y_0  4.599294  time  24.50575304031372  loss  0.020081667465717507
Epoch  1200  y_0  4.59985  time  29.317641496658325  loss  0.019959196827357393
Epoch  1400  y_0  4.599803  time  34.13115668296814  loss  0.019768003305034847
Epoch  1600  y_0  4.602144  time  38.947619915008545  loss  0.019978553125954184
Epoch  1800  y_0  4.5970664  time  43.76622200012207  loss  0.022010927543323677
Epoch  2000  y_0  4.5960317  time  48.58127045631409  loss  0.020750965863285058


array([[0.00000000e+00, 1.31622964e+01, 9.75981653e-01, 1.01682901e-01],
       [2.00000000e+02, 3.17164275e+00, 1.97352886e+00, 5.02144170e+00],
       [4.00000000e+02, 1.54563519e+00, 2.86927652e+00, 9.85896850e+00],
       [6.00000000e+02, 4.70279126e-01, 3.82709432e+00, 1.48477979e+01],
       [8.00000000e+02, 2.60938822e-02, 4.54449320e+00, 1.96826267e+01],
       [1.00000000e+03, 2.00816675e-02, 4.59929419e+00, 2.45057530e+01],
       [1.20000000e+03, 1.99591968e-02, 4.59985018e+00, 2.93176415e+01],
       [1.40000000e+03, 1.97680033e-02, 4.59980297e+00, 3.41311567e+01],
       [1.60000000e+03, 1.99785531e-02, 4.60214376e+00, 3.89476199e+01],
       [1.80000000e+03, 2.20109275e-02, 4.59706640e+00, 4.37662220e+01],
       [2.00000000e+03, 2.07509659e-02, 4.59603167e+00, 4.85812705e+01]])