# Hedging as a BSDE

# Deep BSDE
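Hedging an option can be formulated as a forward-backward SDE (FBSDE) problem. The code below solves

$$
\begin{aligned}
X_t &= X_0 + \int_0^t \mu(s, X_s, Y_s)\,ds + \int_0^t \sigma(s, X_s)\,dW_s,\\
Y_t &= g(X_T) + \int_t^T f(s, X_s, Y_s, Z_s)\,ds - \int_t^T Z_s\,dW_s,
\end{aligned}
$$

where the unknown initial value $Y_0$ and the process $Z_t$ are learned so that the simulated $Y_T$ matches $g(X_T)$ in mean square. In the example below, $X = S$ is the stock price with $\mu(t,x,y) = \mu x$, $\sigma(t,x) = \sigma x$, $f(t,x,y,z) = -r y$ and $g(x) = (K - x)^+$ (a European put).
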

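This code was copied and adapted from
https://github.com/YifanJiang233/Deep_BSDE_solver/tree/master.

TODO: check the upstream repository's license. If it is MIT, its copyright and permission notice must be kept in all copies or substantial portions of the code.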

Note that the paper

[1] E, W., Han, J., and Jentzen, A. Deep learning-based numerical methods for high-dimensional parabolic partial differential equations and backward stochastic differential equations, Communications in Mathematics and Statistics, 5, 349–380 (2017).

is cited, but the code does not use one subnetwork per time step as [1] does. Instead, a single network takes time as an additional input.
The implementation actually follows

https://arxiv.org/pdf/2101.01869.pdf

[2] Jiang, Y., Li, J. Convergence of the deep BSDE method for FBSDEs with non-Lipschitz coefficients.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from parameterfree import COCOB
from dataclasses import dataclass
from typing import Callable,List

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@dataclass
class fbsde_parameters:
    """Problem definition of the FBSDE solved by ``Model``.

    As consumed by ``Model.forward`` (explicit Euler scheme):

        X_{i+1} = X_i + mu(t_i, X_i, Y_i) dt + sigma(t_i, X_i) dW_i,   X_0 = S0
        Y_{i+1} = Y_i - f(t_i, X_i, Y_i, Z_i) dt + Z_i dW_i

    and training drives Y_T towards g(X_T).
    """

    S0: torch.Tensor  # initial state X_0; converted to a tensor on `device` in __post_init__
    mu: Callable  # forward drift, called as mu(t, x, y)
    sigma: Callable  # forward diffusion, called as sigma(t, x); must return (batch, dim_x, dim_d)
    f: Callable  # BSDE driver, called as f(t, x, y, z)
    g: Callable  # terminal condition, called as g(x)
    T: float  # time horizon
    dim_x: int  # dimension of the forward process X
    dim_y: int  # dimension of the backward process Y
    dim_d: int  # dimension of the Brownian motion W
    guess_sol: List[float]  # [a, b]: Y_0 is initialised uniformly between a and b

    def __post_init__(self):
        # Accept a tensor or anything torch.as_tensor understands (float, list,
        # ndarray), and place it on the same device as the model's tensors.
        self.S0 = torch.as_tensor(self.S0).to(device)


class Model(nn.Module):
    """Deep BSDE network for the FBSDE described by ``fbsde_parameters``.

    The unknown initial value ``Y_0`` is a free parameter, and ``Z_t`` is produced
    by a single feed-forward network evaluated at ``(X_t, t)``; the same network
    is shared across all time steps.
    """

    def __init__(self, fbsde_parameters, dim_h):
        """
        Args:
            fbsde_parameters: problem definition (coefficients, dimensions, ``guess_sol``).
            dim_h: width of the hidden layers of the Z-network.
        """
        super().__init__()
        self.equation = fbsde_parameters

        # Y_0 starts at a uniformly random point between guess_sol[0] and guess_sol[1].
        l = torch.rand(fbsde_parameters.dim_y, device=device)
        self.y_0 = nn.Parameter(
            fbsde_parameters.guess_sol[0] * l + fbsde_parameters.guess_sol[1] * (1 - l)
        )
        # Z-network: input is X_t plus one extra column holding the time t.
        self.linear1 = nn.Linear(fbsde_parameters.dim_x + 1, dim_h)
        self.linear2 = nn.Linear(dim_h, dim_h)
        self.linear3 = nn.Linear(dim_h, dim_h)
        self.linear4 = nn.Linear(dim_h, fbsde_parameters.dim_y * fbsde_parameters.dim_d)
        self.bn1 = nn.BatchNorm1d(dim_h)
        self.bn2 = nn.BatchNorm1d(dim_h)

        # The layers above are created on the CPU, while y_0 and every tensor built
        # in forward() live on `device`; move the whole module so a CUDA run does
        # not feed CUDA inputs into CPU weights.
        self.to(device)

    def get_z(self, x, t):
        """Evaluate the Z-network at state ``x`` of shape (batch, dim_x) and scalar time ``t``.

        Returns:
            Tensor of shape (batch, dim_y, dim_d).
        """
        time_col = torch.full((x.shape[0], 1), t, device=x.device)
        output = torch.cat((x, time_col), 1)
        output = F.gelu(self.linear1(output))
        output = self.bn1(F.gelu(self.linear2(output)))
        output = self.bn2(F.gelu(self.linear3(output)))
        return self.linear4(output).reshape(-1, self.equation.dim_y, self.equation.dim_d)

    def forward(self, batch_size, N):
        """Simulate ``batch_size`` paths of (X, Y) with ``N`` explicit Euler steps on [0, T].

        Returns:
            ``(x, y)``: the terminal values X_T of shape (batch_size, dim_x) and
            Y_T of shape (batch_size, dim_y).
        """
        eq = self.equation
        dt = eq.T / N
        x = eq.S0 + torch.zeros(batch_size, eq.dim_x, device=device)
        y = self.y_0 + torch.zeros(batch_size, eq.dim_y, device=device)

        for i in range(N):
            t = dt * i
            z = self.get_z(x, t)

            dW = torch.randn(batch_size, eq.dim_d, 1, device=device) * np.sqrt(dt)
            # Explicit Euler: all coefficients are evaluated at the current
            # (t_i, X_i, Y_i, Z_i), so X_{i+1} is computed before x is overwritten;
            # otherwise the driver f would see X_{i+1}.
            x_next = x + eq.mu(t, x, y) * dt + torch.matmul(eq.sigma(t, x), dW).reshape(-1, eq.dim_x)
            y = y - eq.f(t, x, y, z) * dt + torch.matmul(z, dW).reshape(-1, eq.dim_y)
            x = x_next
        return x, y



class BSDEsolver:
    """Train a ``Model`` by minimising the mean squared mismatch between the
    simulated terminal value Y_T and the terminal condition g(X_T)."""

    def __init__(self, equation, model):
        self.model = model
        self.equation = equation

    def train(self, batch_size, N, itr, log, *, make_optimizer=None):
        """Run ``itr`` optimisation steps, each on a fresh batch of simulated paths.

        Args:
            batch_size: number of simulated paths per step.
            N: number of Euler time steps per path.
            itr: number of optimisation steps.
            log: if true, print progress about 20 times over the run.
            make_optimizer: callable taking ``model.parameters()`` and returning a
                torch optimizer; defaults to ``COCOB``. A new optimizer is built
                on every call to ``train``.

        Returns:
            ``(loss_data, y0_data)``: per-step loss values and the Y_0 value
            recorded before each optimiser step.
        """
        # MSELoss has no parameters or buffers, so it needs no device placement.
        loss_fun = torch.nn.MSELoss()
        if make_optimizer is None:
            make_optimizer = COCOB
        optimizer = make_optimizer(self.model.parameters())
        loss_data, y0_data = [], []
        # Guard against itr < 20, where itr // 20 == 0 would make `i % 0` raise.
        log_every = max(1, itr // 20)

        for i in range(itr):
            x, y = self.model(batch_size, N)
            loss = loss_fun(self.equation.g(x), y)
            loss_value = loss.item()
            loss_data.append(loss_value)
            y0_data.append(float(self.model.y_0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if log and i % log_every == 0:
                print(f"loss: {loss_value:7.2f} y0: {float(self.model.y_0):7.2f} done: {i/itr*100:5.2f}% Iteration: {i}")
        return loss_data, y0_data

In [2]:
import matplotlib.pyplot as plt
from parameters import parameters_default

par = parameters_default()

dim_x, dim_y, dim_d, dim_h = 1, 1, 1, 11

# Y_0 is initialised uniformly at random between these two values (see Model).
guess_sol = [0, 27]

S0 = par.S0 * torch.ones(dim_x)

# The coefficient functions below take the batch size from their inputs
# (reshape(-1, ...) / broadcasting) instead of reading the global `batch_size`,
# which only exists once the training loop further down has started.


def mu(t, x, y):
    """Forward drift: par.mu * x."""
    return par.mu * x


def sigma(t, x):
    """Forward diffusion par.sigma * x, shaped (batch, dim_x, dim_d) for the matmul with dW."""
    return par.sigma * x.reshape(-1, dim_x, dim_d)


def f_europian(t, x, y, z):
    """BSDE driver f = -par.r * y, shaped (batch, dim_y)."""
    # NOTE(review): for hedging under the drift par.mu, the self-financing driver
    # would also contain a -(par.mu - par.r) / par.sigma * z term; as written this
    # is the replication BSDE only if par.mu == par.r (risk-neutral drift) -- confirm.
    return (-par.r * y).reshape(-1, dim_y)


def g(x):
    """Terminal condition: European put payoff max(par.K - x, 0)."""
    return torch.clamp(par.K - x, min=0)


fbsde_pars = fbsde_parameters(S0, mu, sigma, f_europian, g, par.T, dim_x, dim_y, dim_d, guess_sol)

# Seed *before* building the model, so the random Y_0 and the layer weights are
# reproducible and not only the Brownian increments.
torch.manual_seed(46)
# Keep the network on the same device as S0 and y_0 (required on CUDA machines).
model = Model(fbsde_pars, dim_h).to(device)
bsde_solver = BSDEsolver(fbsde_pars, model)

time_discretizations = [2, 10, 40]
iterations = [100, 100, 100]
batch_sizes = [1000, 1000, 1000]
plotting = False

# The same model is trained further on each successively finer time grid.
for N, it, batch_size in zip(time_discretizations, iterations, batch_sizes):
    print(f"Time discretization: {N} Batch size: {batch_size} Iterations: {it}")
    loss, y0 = bsde_solver.train(batch_size, N, it, log=True)
    print(60 * "=")
    if plotting:
        fig, axs = plt.subplots(1, 2)
        axs[0].plot(loss)
        axs[0].set_title('Loss')

        axs[1].plot(y0)
        axs[1].set_title('y0')

        plt.tight_layout()
        plt.show()

Time discretization: 2 Batch size: 1000 Iterations: 100
loss:  310.19 y0:    9.48 done:  0.00% Iteration: 0
loss:  329.69 y0:    9.53 done:  5.00% Iteration: 5
loss:  306.79 y0:    9.61 done: 10.00% Iteration: 10
loss:  334.30 y0:    9.81 done: 15.00% Iteration: 15
loss:  273.51 y0:   10.28 done: 20.00% Iteration: 20
loss:  144.78 y0:   11.21 done: 25.00% Iteration: 25
loss:   55.34 y0:   12.33 done: 30.00% Iteration: 30
loss:   55.27 y0:   13.43 done: 35.00% Iteration: 35
loss:   54.35 y0:   13.68 done: 40.00% Iteration: 40
loss:   48.96 y0:   13.76 done: 45.00% Iteration: 45
loss:   50.08 y0:   13.80 done: 50.00% Iteration: 50
loss:   53.10 y0:   13.87 done: 55.00% Iteration: 55
loss:   47.25 y0:   13.79 done: 60.00% Iteration: 60
loss:   48.25 y0:   14.03 done: 65.00% Iteration: 65
loss:   60.10 y0:   13.87 done: 70.00% Iteration: 70
loss:   51.81 y0:   13.66 done: 75.00% Iteration: 75
loss:   54.06 y0:   13.72 done: 80.00% Iteration: 80
loss:   49.30 y0:   13.66 done: 85.00% Iterat