In [None]:
#import libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F


#import models
from model_em import fbsde
from model_em import BSDEsolverEm   
from model_mil import BSDEsolverMil
from model_rk import BSDEsolverRK

In [None]:
#defining our equation
# Problem dimensions (dim_x / dim_y fix the shapes returned by b, f and g below;
# dim_d is the last axis of sigma's output).
dim_x = 1
dim_y = 1
dim_d = 1
# Network size, passed to every solver as (dim_h, num_h).
dim_h = 24
num_h = 3
# Time steps, training iterations, and paths per iteration.
N = 50
itr = 3000
batch_size = 256  #batch size is number of paths used in each iteration

T = 1
x_0 = torch.ones(dim_x)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Forward drift a*(b - x) with a = b = 1.
def b(t, x, y):
    """Return the forward drift 1 - x, reshaped to (batch_size, dim_x).

    ``t`` and ``y`` are part of the signature but unused here.
    """
    drift = 1 - x
    return drift.reshape(batch_size, dim_x)


def sigma(t, x):
    """Return the forward diffusion sqrt(|x|), reshaped to (batch_size, dim_x, dim_d).

    ``t`` is part of the signature but unused here.
    """
    magnitude = torch.abs(x)
    return magnitude.sqrt().reshape(batch_size, dim_x, dim_d)


def f(t, x, y, z):
    """Return the BSDE driver -x*y, reshaped to (batch_size, dim_y).

    ``t`` and ``z`` are part of the signature but unused here.
    """
    neg_xy = -(y * x)
    return neg_xy.reshape(batch_size, dim_y)


def g(x):
    """Terminal condition: a (batch_size, dim_y) tensor of ones on ``device``.

    The value does not depend on ``x``.
    """
    shape = (batch_size, dim_y)
    return torch.ones(shape, device=device)

# Bundle initial state, coefficients, horizon and dimensions into the FBSDE object
# that every solver below is constructed from.
equation = fbsde(
    x_0,
    b, sigma, f, g,
    T,
    dim_x, dim_y, dim_d,
)

# Number of independent training runs per scheme; the analysis cells average over them.
samples = 10

#Calculates exact y0 (~0.39647)
def cir_bond_price(kappa, theta, vol, T, x0):
    """Closed-form E[exp(-int_0^T X_s ds)] for dX = kappa*(theta - X) dt + vol*sqrt(X) dW, X_0 = x0.

    This is the CIR zero-coupon bond price A(T) * exp(-B(T) * x0), with
        gamma = sqrt(kappa^2 + 2 vol^2)
        D     = (gamma + kappa) (e^{gamma T} - 1) + 2 gamma
        A     = (2 gamma e^{(kappa + gamma) T / 2} / D)^{2 kappa theta / vol^2}
        B     = 2 (e^{gamma T} - 1) / D
    For the BSDE above (driver -x*y, terminal value 1) this equals y0.
    """
    gamma = np.sqrt(kappa**2 + 2 * vol**2)
    growth = np.exp(gamma * T)
    denom = (gamma + kappa) * (growth - 1) + 2 * gamma
    A = (2 * gamma * np.exp((kappa + gamma) * T / 2) / denom) ** (2 * kappa * theta / vol**2)
    B = 2 * (growth - 1) / denom
    return A * np.exp(-B * x0)


# Coefficients matching the equation above: b(t, x, y) = 1 - x (kappa = theta = 1),
# sigma(t, x) = sqrt(|x|) (vol = 1), T = 1 and x_0 = 1 -- keep in sync with the config cell.
# Distinct parameter names are used so the SDE functions `b` and `sigma` are not rebound to ints.
y0_exact = cir_bond_price(kappa=1.0, theta=1.0, vol=1.0, T=1.0, x0=1.0)
print("Exact y0: ", y0_exact)

In [None]:
#training the models
def train_runs(solver_cls, n_steps, test_num_offset=0):
    """Train ``samples`` independent solvers of class ``solver_cls`` on ``equation``.

    Run ``i`` (0 <= i < samples) is trained with ``n_steps`` time steps and tagged
    ``test_num = test_num_offset + i``. With ``log=True`` the solver presumably writes
    the per-run loss/y0 ``.npy`` files read by the analysis cells (TODO confirm in model_*).

    Returns the last solver trained, or ``None`` when ``samples`` is 0.
    """
    solver = None
    for run in range(samples):
        solver = solver_cls(equation, dim_h, num_h)
        # Pass n_steps positionally: a keyword argument followed by positional ones is a SyntaxError.
        solver.train(batch_size, n_steps, itr, log=True, test_num=test_num_offset + run)
    return solver


# The 2N Euler-Maruyama runs are tagged from 100 upwards (the analysis cell loads 'EM{i+100}'),
# so their files do not collide with the N-step runs.
EM_2N_OFFSET = 100

#training euler-maruyama model
bsde_solver_em = train_runs(BSDEsolverEm, N)

#training euler-maruyama model with 2N
bsde_solver_em = train_runs(BSDEsolverEm, 2 * N, test_num_offset=EM_2N_OFFSET)

#training milstein model
bsde_solver_mil = train_runs(BSDEsolverMil, N)

#training runge-kutta model
bsde_solver_rk = train_runs(BSDEsolverRK, N)


In [None]:
# Euler-Maruyama (N steps): individual runs in light blue, run average in red.
avg_loss = np.zeros(itr)
avg_y0 = np.zeros(itr)  # run-averaged |y0 - y0_exact| per iteration (an error, not y0 itself)

# Figure 1: training loss of each run.
for run in range(samples):
    run_loss = np.load(f'loss_data_EM{run}.npy')
    plt.plot(run_loss, 'lightblue')
    avg_loss += run_loss / samples
plt.yscale('log')
plt.plot(avg_loss, 'red')

# Figure 2: absolute error of the learned y0 against the closed form.
plt.figure()
for run in range(samples):
    run_err = np.abs(np.load(f'y0_data_EM{run}.npy') - y0_exact)
    plt.plot(run_err, 'lightblue')
    avg_y0 += run_err / samples
plt.yscale('log')
plt.plot(avg_y0, 'red')

print('y0 error: ', avg_y0[-1])
print('loss: ', avg_loss[-1])
# Hard-coded value; not measured by this cell.
print('average computation time:  43.39s')

In [None]:
#euler-maruyama model with 2N steps
# These runs were tagged test_num = i + 100 by the training cell, hence the offset in the file names.
# Individual runs are drawn in light blue and the run average in red, as in the N-step cell.
avg_loss_2N = np.zeros(itr)
avg_y0_2N = np.zeros(itr)  # run-averaged |y0 - y0_exact| per iteration

# Figure 1: training loss (previously an empty figure was opened and nothing was drawn).
plt.figure()
for i in range(samples):
    loss_data = np.load(f'loss_data_EM{i+100}.npy')
    avg_loss_2N += loss_data/(samples)
    plt.plot(loss_data, 'lightblue')
plt.yscale('log')
plt.plot(avg_loss_2N, 'red')

# Figure 2: absolute error of the learned y0.
plt.figure()
for i in range(samples):
    y0_data = np.load(f'y0_data_EM{i+100}.npy')
    y0_error = np.abs(y0_data - y0_exact)
    avg_y0_2N += y0_error/(samples)
    plt.plot(y0_error, 'lightblue')
plt.yscale('log')
plt.plot(avg_y0_2N, 'red')

print('y0 error: ', avg_y0_2N[-1])
print('loss: ', avg_loss_2N[-1])
# Hard-coded value; not measured by this cell.
print('average computation time:  89.07s')

In [None]:
#analyzing Milstein results
# Runs are read as Mil{i}, i = 0..samples-1, matching test_num=i in the training cell and the
# EM/RK cells. Reading Mil{i+1} skipped run 0 and requested a run `samples` that is never trained.
# NOTE(review): assumes BSDEsolverMil names its log files by test_num like the other solvers -- confirm.
avg_loss_mil = np.zeros(itr)
avg_y0_mil = np.zeros(itr)  # run-averaged |y0 - y0_exact| per iteration

# Figure 1: training loss of each run (light blue) and the average (red).
plt.figure()
for i in range(samples):
    loss_data = np.load(f'loss_data_Mil{i}.npy')
    avg_loss_mil += loss_data/(samples)
    plt.plot(loss_data, 'lightblue')
plt.plot(avg_loss_mil, 'red')
plt.yscale('log')

# Figure 2: absolute error of the learned y0.
plt.figure()
for i in range(samples):
    y0_data = np.load(f'y0_data_Mil{i}.npy')
    y0_error = np.abs(y0_data - y0_exact)
    avg_y0_mil += y0_error/(samples)
    plt.plot(y0_error, 'lightblue')
plt.plot(avg_y0_mil, 'red')
plt.yscale('log')

# Hard-coded value; not measured by this cell.
print('Average computation time of 93.16 seconds')
print('final error', avg_y0_mil[-1])

In [None]:
#analyzing Runge-Kutta results
avg_loss_rk = np.zeros(itr)
avg_y0_rk = np.zeros(itr)  # run-averaged |y0 - y0_exact| per iteration

# Figure 1: training loss of each run (light blue) and the average (red).
plt.figure()
for i in range(samples):
    loss_data = np.load(f'loss_data_RK_fwd{i}.npy')
    avg_loss_rk += loss_data/(samples)
    plt.plot(loss_data, 'lightblue')
plt.plot(avg_loss_rk, 'red')
plt.yscale('log')

# Figure 2: absolute error of the learned y0.
plt.figure()
for i in range(samples):
    y0_data = np.load(f'y0_data_RK_fwd{i}.npy')
    y0_error = np.abs(y0_data - y0_exact)
    avg_y0_rk += y0_error/(samples)
    plt.plot(y0_error, 'lightblue')
plt.plot(avg_y0_rk, 'red')
plt.yscale('log')

# TODO: record the RK computation time. The previously printed "93.16 seconds" was identical to the
# Milstein cell's hard-coded figure (a copy-paste), so it is no longer printed.
print('final loss', avg_loss_rk[-1])
print('final error', avg_y0_rk[-1])

In [None]:
#plotting average loss and y0 error
# (legend label, average loss curve, average |y0 error| curve) for every scheme analysed above.
# Labelling each line at plot time keeps the legend in sync with what is actually drawn
# (the hand-written legend listed 'Milstein' while its curve was not plotted).
summary = [
    ('Euler-Maruyama N=50', avg_loss, avg_y0),
    ('Euler-Maruyama N=100', avg_loss_2N, avg_y0_2N),
    ('Stochastic Runge-Kutta', avg_loss_rk, avg_y0_rk),
    ('Milstein', avg_loss_mil, avg_y0_mil),
]

plt.figure()
plt.title('Loss of network')
for label, loss_curve, _ in summary:
    plt.plot(loss_curve, label=label)
plt.yscale('log')
plt.xlabel('iteration')
plt.ylabel('average loss')
plt.grid()
plt.legend()

plt.figure()
plt.title('Error at Y0')
for label, _, err_curve in summary:
    plt.plot(err_curve, label=label)
plt.yscale('log')
plt.xlabel('iteration')
plt.ylabel('average |y0 - y0_exact|')
plt.grid()
plt.legend()