In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [None]:

class Model2(nn.Module):
    """Time-stepping network for a forward-backward SDE.

    Holds a trainable initial value ``y_0`` and, for each of the ``N`` time
    steps, a small feed-forward network that maps the forward state ``x`` to
    ``z`` of shape ``(batch, dim_y, dim_d)``.

    Args:
        equation: object exposing ``dim_x``, ``dim_y``, ``dim_d``, ``T``,
            ``x_0``, ``guess_y_0`` and the callables ``b(t, x, y)``,
            ``sigma(t, x)`` and ``f(t, x, y, z)``.
        dim_h: hidden width of every per-step network.
        N: number of time steps on ``[0, T]``.

    Relies on a module-level ``device`` being defined before instantiation.
    """

    def __init__(self, equation, dim_h, N):
        super(Model2, self).__init__()
        self.equation = equation
        self.N = N
        # Trainable initial value: a random convex combination of the two
        # bounds given in equation.guess_y_0.
        l = torch.rand(equation.dim_y, device=device)
        self.y_0 = nn.Parameter(equation.guess_y_0[0]*l + equation.guess_y_0[1]*(1-l))
        # The per-step layers must live in nn.ModuleList, not plain Python
        # lists: only then are they registered as sub-modules, so that
        # parameters() exposes them to the optimizer and .to(device) moves them.
        self.linear1 = nn.ModuleList([nn.Linear(equation.dim_x, dim_h) for _ in range(N)])
        self.linear2 = nn.ModuleList([nn.Linear(dim_h, dim_h) for _ in range(N)])
        self.bn = nn.ModuleList([nn.BatchNorm1d(dim_h) for _ in range(N)])
        self.linear4 = nn.ModuleList([nn.Linear(dim_h, equation.dim_y*equation.dim_d) for _ in range(N)])

    def get_z(self, x, i):
        """Evaluate the i-th step network on ``x``; returns ``(batch, dim_y, dim_d)``."""
        x = F.relu(self.linear1[i](x))
        x = self.bn[i](F.relu(self.linear2[i](x)))
        return self.linear4[i](x).reshape(-1, self.equation.dim_y, self.equation.dim_d)

    def forward(self, batch_size):
        """Simulate ``batch_size`` paths over the ``N`` steps.

        Returns:
            Tuple ``(x, y)`` with the terminal forward state, shape
            ``(batch_size, dim_x)``, and the terminal backward state, shape
            ``(batch_size, dim_y)``.
        """
        dt = self.equation.T / self.N
        sqrt_dt = np.sqrt(dt)  # loop-invariant scaling of the Brownian increments
        x = self.equation.x_0 + torch.zeros(batch_size, self.equation.dim_x, device=device)
        y = self.y_0 + torch.zeros(batch_size, self.equation.dim_y, device=device)

        for i in range(self.N):
            t = dt*i
            z = self.get_z(x, i)

            dW = torch.randn(batch_size, self.equation.dim_d, 1, device=device) * sqrt_dt
            # Explicit Euler step: both updates are evaluated at the pre-step
            # state (t_i, x_i, y_i, z_i). x is therefore only overwritten after
            # y has been advanced, so f never sees x_{i+1} together with z_i.
            x_next = (x + self.equation.b(t, x, y)*dt
                      + torch.matmul(self.equation.sigma(t, x), dW).reshape(-1, self.equation.dim_x))
            y = y - self.equation.f(t, x, y, z)*dt + torch.matmul(z, dW).reshape(-1, self.equation.dim_y)
            x = x_next
        return x, y


class BSDEsolver2():
    """Trains a ``Model2`` so that its terminal ``y`` matches ``equation.g(x)``.

    Args:
        equation: equation object passed on to ``Model2``; must also expose
            the terminal condition ``g(x)``.
        dim_h: hidden width forwarded to ``Model2``.
        N: number of time steps forwarded to ``Model2``.

    Relies on a module-level ``device`` being defined.
    """

    def __init__(self, equation, dim_h, N):
        self.model = Model2(equation, dim_h, N).to(device)
        self.equation = equation

    def train(self, batch_size, itr, log):
        """Run ``itr`` Adam steps on the mean-squared terminal mismatch.

        Args:
            batch_size: number of simulated paths per iteration.
            itr: number of optimisation iterations.
            log: when truthy, print progress about 20 times over the run.

        Returns:
            Tuple ``(loss_data, y0_data)``: per-iteration loss values and the
            value of ``y_0`` recorded before each optimizer step.
        """
        criterion = torch.nn.MSELoss().to(device)
        optimizer = torch.optim.Adam(self.model.parameters())
        loss_data, y0_data = [], []
        # Print every itr/20 iterations. Clamp to 1 so that short runs
        # (itr < 20) do not end up computing `i % 0`.
        log_every = max(1, itr // 20)

        for i in range(itr):
            x, y = self.model(batch_size)
            loss = criterion(self.equation.g(x), y)
            loss_data.append(float(loss))
            y0_data.append(float(self.model.y_0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if log and i % log_every == 0:
                print(f"loss: {float(loss):7.2f} y0: {float(self.model.y_0):7.2f} done: {i/itr*100:5.2f}% Iteration: {i}")
        return loss_data, y0_data
            

In [None]:

# Scalar coefficients read by the equation callbacks defined in this cell.
mu = 0.07   # multiplies x in the drift b(t, x, y)
sig = 0.3   # multiplies x in the diffusion sigma(t, x)
r = 0.02    # rate used by the drivers f_europian / f_american
K = 110     # level in the terminal condition g(x) = max(K - x, 0)

# Problem dimensions.
dim_x = 1
dim_y = 1
dim_d = 1

# Network width, number of time steps and training schedule.
dim_h = 3
N = 10
itr = 3000
batch_size = 1000

# Lower/upper bound between which the initial y_0 is drawn.
guess_y_0 = [14, 14]

# Time horizon and initial forward state (one entry per x-dimension).
T = 1
x_0 = torch.ones(dim_x) * 100


def b(t, x, y):
    """Drift of the forward process: ``mu * x`` as a ``(batch, dim_x)`` tensor.

    ``t`` and ``y`` are accepted for interface compatibility and not used.
    The batch size is inferred from ``x`` rather than taken from the global
    ``batch_size``, so the function works for any number of paths.
    """
    return mu*x.reshape(-1, dim_x)


def sigma(t, x):
    """Diffusion of the forward process: ``sig * x`` as ``(batch, dim_x, dim_d)``.

    ``t`` is accepted for interface compatibility and not used. The batch size
    is inferred from ``x`` rather than taken from the global ``batch_size``,
    so the function works for any number of paths.
    """
    return sig*x.reshape(-1, dim_x, dim_d)

# NOTE (original author): f_american does not work yet.
def f_american(t, x, y, z):
    """Driver ``-r*y + r*K*1{y <= g(x)}`` as a ``(batch, dim_y)`` tensor.

    ``t`` and ``z`` are accepted for interface compatibility and not used.
    The batch size is inferred from the inputs rather than taken from the
    global ``batch_size``.
    """
    # The boolean mask is promoted to float when multiplied by r*K.
    return (-r*y + r*K*(y <= g(x))).reshape(-1, dim_y)

def f_europian(t, x, y, z):
    """Driver ``-r*y`` as a ``(batch, dim_y)`` tensor.

    ``t``, ``x`` and ``z`` are accepted for interface compatibility and not
    used. The batch size is inferred from ``y`` rather than taken from the
    global ``batch_size``.
    """
    return (-r*y).reshape(-1, dim_y)


def g(x):
    """Terminal condition ``max(K - x, 0)``, evaluated element-wise.

    The zero floor has shape ``(1, dim_y)`` and is created on the device of
    the input, so it broadcasts over whatever batch dimension ``x`` carries
    instead of requiring the global ``batch_size`` / ``device``.
    """
    payoff = K - x
    floor = torch.zeros(1, dim_y, device=payoff.device)
    return torch.max(payoff, floor)


# Bundle the initial state, the coefficient callbacks, the horizon, the
# dimensions and the y_0 guess range into one equation object. The driver
# selected here is f_europian.
equation = fbsde(
    x_0, b, sigma, f_europian, g, T,
    dim_x, dim_y, dim_d, guess_y_0,
)

# Solver for that equation: hidden width dim_h, N time steps.
bsde_solver = BSDEsolver2(equation, dim_h, N)

In [None]:

import matplotlib.pyplot as plt

# Seed torch's RNG so the random increments drawn during training are repeatable.
torch.manual_seed(44)
loss, y0 = bsde_solver.train(batch_size, itr, log=True)

# Side-by-side training diagnostics: loss history (left), y0 history (right).
fig, axs = plt.subplots(nrows=1, ncols=2)
for panel, series, title in zip(axs, (loss, y0), ('Loss', 'y0')):
    panel.plot(series)
    panel.set_title(title)

fig.tight_layout()
plt.show()