In [2]:
#import libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

from model_em import fbsde
from model_em import BSDEsolverEm 
from model_mil import BSDEsolverMil


In [3]:
# ---- equation and training configuration ----

# Sizes handed to `fbsde` below and used in the reshape calls of b / sigma / f / g.
dim_x = 100
dim_y = 1
dim_d = 100
# Passed as the second argument of the solvers' `train` calls; the comparison
# plot labels these runs "N = 50".
N = 50

# dim_h and num_h go to the solver constructors, itr to `train`
# (presumably network width / depth and iteration count -- confirm in model_em).
dim_h = 110
num_h = 32
itr = 3000
batch_size = 64  # batch size is number of paths used in each iteration

# Initial state (every component equal to 100) and the horizon passed to `fbsde`.
x_0 = torch.full((dim_x,), 100.0)
T = 1

# Use the first CUDA device when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Constants read by f (R, delta) and by Q (gamma_h, gamma_l, v_h, v_l, slope).
R = 0.02
delta = 2/3
gamma_h = 0.2
gamma_l = 0.02
v_h = 47
v_l = 65
# Slope of the straight line joining (v_h, gamma_h) and (v_l, gamma_l).
slope = (gamma_h - gamma_l) / (v_h - v_l)

# Multipliers of x in the drift b (mu) and in the diagonal diffusion sigma (sig).
mu = 0.02
sig = 0.2

def Q(y):
    """Piecewise-linear function of ``y`` built from the module-level constants.

    * ``y <  v_h``         -> ``gamma_h``
    * ``y >  v_l``         -> ``gamma_l``
    * ``v_h <= y <= v_l``  -> ``slope * (y - v_h) + gamma_h``

    The three regions are mutually exclusive, so the result is assembled by
    summing each region's value multiplied by its boolean mask.  The output
    has the same shape and dtype as ``y``.
    """
    ramp = slope * (y - v_h) + gamma_h
    pieces = (
        (y < v_h, gamma_h),
        (y > v_l, gamma_l),
        ((y >= v_h) & (y <= v_l), ramp),
    )

    total = torch.zeros_like(y)
    for mask, value in pieces:
        total += mask * value
    return total

def b(t, x, y):
    """Drift coefficient handed to ``fbsde``: ``mu * x`` with ``dim_x`` columns.

    ``t`` and ``y`` are part of the signature but are not used here.

    The leading (batch) dimension is inferred from ``x`` instead of being
    tied to the global ``batch_size``, so the function also works when it is
    called with a different number of paths.  For inputs holding
    ``batch_size * dim_x`` elements the result is unchanged.
    """
    return (mu * x).reshape(-1, dim_x)


def sigma(t, x):
    """Diffusion coefficient handed to ``fbsde``: a diagonal matrix per path.

    ``torch.diag_embed`` places ``sig * x`` on the diagonal of a square
    matrix for every row of ``x``; the result is viewed as
    ``(batch, dim_x, dim_d)``.  ``t`` is not used.

    The batch dimension is inferred from ``x`` rather than taken from the
    global ``batch_size``, so any number of paths is accepted.  The reshape
    only succeeds when ``dim_d == dim_x`` (as configured in this notebook),
    because the embedded matrices are square.
    """
    return torch.diag_embed(sig * x).reshape(-1, dim_x, dim_d)


def f(t, x, y, z):
    """Function handed to ``fbsde`` as ``f``: ``(-(1 - delta) * Q(y) - R) * y``.

    Only ``y`` enters the computation; ``t``, ``x`` and ``z`` are accepted
    but not used.  The result is viewed as ``(batch, dim_y)``.

    The batch dimension is inferred from ``y`` instead of the global
    ``batch_size``, so the function is not restricted to one batch size.
    """
    coeff = -(1 - delta) * Q(y) - R
    return (coeff * y).reshape(-1, dim_y)

def g(x):
    """Function handed to ``fbsde`` as ``g``: row-wise minimum of ``x``.

    Takes the minimum over dimension 1 (the values, not the indices) and
    views the result as ``(batch, dim_y)``.

    The batch dimension is inferred from ``x`` instead of the global
    ``batch_size``, so inputs with any number of rows are accepted.
    """
    smallest = torch.min(x, 1)[0]
    return smallest.reshape(-1, dim_y)


# Bundle the initial state, the coefficient callables, the horizon and the
# problem dimensions into the equation object given to the solvers below.
equation = fbsde(
    x_0, b, sigma, f, g, T,
    dim_x, dim_y, dim_d,
)

# Number of repetitions of the training loop; the analysis cells average
# over this many saved runs.
samples = 10

In [None]:
# Train every scheme `samples` times, each time with a freshly constructed solver.
#
# NOTE(review): `test_num` is assumed to become the numeric suffix of the .npy
# logs written by `train` -- the analysis cells load loss_data_EM{i},
# loss_data_EM{i+100} and loss_data_Mil{i}.  Confirm in model_em / model_mil.
for i in range(samples):
    # Euler-Maruyama with the default N (= 50)
    bsde_solver_em = BSDEsolverEm(equation, dim_h, num_h)
    bsde_solver_em.train(batch_size, N, itr, log=True, test_num=i)

    # Euler-Maruyama with N = 100.
    # N is passed positionally: the previous `N=100, itr` form put a positional
    # argument after a keyword argument, which is a SyntaxError.
    # The run index is offset by 100 so these logs match the files the
    # N = 100 analysis cell loads and do not reuse the N = 50 run index.
    bsde_solver_em = BSDEsolverEm(equation, dim_h, num_h)
    bsde_solver_em.train(batch_size, 100, itr, log=True, test_num=i + 100)

    # Milstein with the default N (= 50)
    bsde_solver_mil = BSDEsolverMil(equation, dim_h, num_h)
    bsde_solver_mil.train(batch_size, N, itr, log=True, test_num=i)


In [None]:
# Summary of the Euler-Maruyama runs stored as loss_data_EM{i}.npy / y0_data_EM{i}.npy
# for i in range(samples): per-iteration mean and spread across runs, then plots.

# Reference value the estimates are compared against
# (original note: "exact y0 computed in paper"; 57.300 was noted as an alternative).
y0 = 58.113

# One array per run, kept as lists under their original names.
loss_std_mat = [np.load(f'loss_data_EM{i}.npy') for i in range(samples)]
y0_std_mat = [np.load(f'y0_data_EM{i}.npy') for i in range(samples)]

# Stack to (samples, itr) so every statistic is a single reduction over axis 0.
loss_runs = np.stack(loss_std_mat).astype(float)
y0_runs = np.stack(y0_std_mat).astype(float)
assert loss_runs.shape == (samples, itr), "each loss log must hold one value per iteration"
assert y0_runs.shape == (samples, itr), "each y0 log must hold one value per iteration"

mean_loss = loss_runs.mean(axis=0)
mean_y0 = y0_runs.mean(axis=0)
mean_y0_error = np.abs(y0 - y0_runs).mean(axis=0)

# Spread across runs at every iteration.
loss_std = loss_runs.std(axis=0)
# The band below is drawn around the mean *estimate*, so it uses the standard
# deviation of the estimates themselves rather than of their absolute errors
# (the latter understates the spread whenever runs fall on both sides of y0).
y0_std = y0_runs.std(axis=0)


fig, ax = plt.subplots()
ax.plot(mean_y0)
ax.fill_between(range(itr), mean_y0 - y0_std, mean_y0 + y0_std, alpha=0.2)
ax.set_xlabel('Iterations')
ax.set_ylabel('$y_0$ approximation')
ax.grid()


fig, ax = plt.subplots()
ax.plot(mean_y0_error)
ax.set_yscale('log')
ax.set_ylim(10e-2, 4)  # 10e-2 is 0.1
ax.set_xlabel('Iterations')
ax.set_ylabel('$y_0$ error')
ax.grid()

print("y0: ", mean_y0[-1])
print("y0 error: ", mean_y0_error[-1])
print('Relative error: ', mean_y0_error[-1]/y0 * 100, '%')
print("network loss: ", mean_loss[-1])
print('Average computation time of 440.35s')

In [None]:
# Summary of the second set of Euler-Maruyama runs, stored as
# loss_data_EM{i+100}.npy / y0_data_EM{i+100}.npy for i in range(samples)
# (labelled "N = 100" in the comparison plot).

# Reference value the estimates are compared against
# (original note: "exact y0 computed in paper"; 57.300 was noted as an alternative).
y0 = 58.113

# One array per run, kept as lists under their original names.
loss_std_mat_2 = [np.load(f'loss_data_EM{i+100}.npy') for i in range(samples)]
y0_std_mat_2 = [np.load(f'y0_data_EM{i+100}.npy') for i in range(samples)]

# Stack to (samples, itr) so every statistic is a single reduction over axis 0.
loss_runs_2 = np.stack(loss_std_mat_2).astype(float)
y0_runs_2 = np.stack(y0_std_mat_2).astype(float)
assert loss_runs_2.shape == (samples, itr), "each loss log must hold one value per iteration"
assert y0_runs_2.shape == (samples, itr), "each y0 log must hold one value per iteration"

mean_loss_2 = loss_runs_2.mean(axis=0)
mean_y0_2 = y0_runs_2.mean(axis=0)
mean_y0_error_2 = np.abs(y0 - y0_runs_2).mean(axis=0)

# Spread across runs at every iteration.
loss_std_2 = loss_runs_2.std(axis=0)
# Standard deviation of the estimates themselves (not of their absolute
# errors), because the band below is drawn around the mean estimate.
y0_std_2 = y0_runs_2.std(axis=0)


fig, ax = plt.subplots()
ax.plot(mean_y0_2)
# Uses this cell's own spread (y0_std_2); the band was previously drawn with
# `y0_std`, a variable belonging to the N = 50 analysis cell.
ax.fill_between(range(itr), mean_y0_2 - y0_std_2, mean_y0_2 + y0_std_2, alpha=0.2)
ax.set_xlabel('Iterations')
ax.set_ylabel('$y_0$ approximation')
ax.grid()

fig, ax = plt.subplots()
ax.plot(mean_y0_error_2)
ax.set_xlabel('Iterations')
ax.set_ylabel('$y_0$ error')
ax.set_yscale('log')
ax.grid()

print("y0: ", mean_y0_2[-1])
print("y0 error: ", mean_y0_error_2[-1])
print('Relative error: ', mean_y0_error_2[-1]/y0 * 100, '%')
print("network loss: ", mean_loss_2[-1])
print('Average computation time of 1541.6s')

In [None]:
# Summary of the Milstein runs stored as loss_data_Mil{i}.npy / y0_data_Mil{i}.npy
# for i in range(samples): per-iteration mean and spread across runs, then plots.

# Reference value the estimates are compared against.
y0 = 58.113

# One array per run, kept as lists under their original names.
mil_loss_std_mat = [np.load(f'loss_data_Mil{i}.npy') for i in range(samples)]
mil_y0_std_mat = [np.load(f'y0_data_Mil{i}.npy') for i in range(samples)]

# Stack to (samples, itr) so every statistic is a single reduction over axis 0.
loss_runs_mil = np.stack(mil_loss_std_mat).astype(float)
y0_runs_mil = np.stack(mil_y0_std_mat).astype(float)
assert loss_runs_mil.shape == (samples, itr), "each loss log must hold one value per iteration"
assert y0_runs_mil.shape == (samples, itr), "each y0 log must hold one value per iteration"

mean_loss_mil = loss_runs_mil.mean(axis=0)
mean_y0_mil = y0_runs_mil.mean(axis=0)
mean_y0_error_mil = np.abs(y0 - y0_runs_mil).mean(axis=0)

# Spread across runs at every iteration.
loss_std_mil = loss_runs_mil.std(axis=0)
# Standard deviation of the estimates themselves (not of their absolute
# errors), because the band below is drawn around the mean estimate.
y0_std_mil = y0_runs_mil.std(axis=0)


fig, ax = plt.subplots()
ax.plot(mean_y0_mil, 'g')
# Uses this cell's own spread (y0_std_mil); the band was previously drawn with
# `y0_std`, a variable belonging to the Euler-Maruyama analysis cell.
ax.fill_between(range(itr), mean_y0_mil - y0_std_mil, mean_y0_mil + y0_std_mil, alpha=0.2)
ax.set_xlabel('Iterations')
ax.set_ylabel('$y_0$ approximation')
ax.grid()

fig, ax = plt.subplots()
ax.plot(mean_y0_error_mil)
ax.set_xlabel('Iterations')
ax.set_ylabel('$y_0$ error')
ax.set_yscale('log')
ax.grid()


print("y0: ", mean_y0_mil[-1])
print("y0 error: ", mean_y0_error_mil[-1])
print('Relative error: ', mean_y0_error_mil[-1]/y0 * 100, '%')
print("network loss: ", mean_loss_mil[-1])
print('Average computation time of ---s')


In [None]:
# Overlay the mean absolute y0 error of the three schemes on one log-scaled axis.
# Requires mean_y0_error, mean_y0_error_2 and mean_y0_error_mil from the
# analysis cells above.
fig, ax = plt.subplots()  # own figure, so the curves never land on a figure that is still open
ax.plot(mean_y0_error, label='Euler-Maruyama, $N = 50$')
ax.plot(mean_y0_error_2, label='Euler-Maruyama, $N = 100$')
ax.plot(mean_y0_error_mil, label='Milstein, $N = 50$')
# Same axis labels as the individual error plots so the figure stands alone.
ax.set_xlabel('Iterations')
ax.set_ylabel('$y_0$ error')
ax.legend()
ax.grid()
ax.set_yscale('log')
