In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt

In [126]:
r=0.05
d=2
s_0=90
delta=0.1
sigma=0.2
rho=0
K=100
T=3
N=9
lr=5e-4
delta_t=T/N
sqrt_delta_t=np.sqrt(delta_t)
batch_size=2048
M=10000

In [127]:
def sample(batch_size):
    S=np.ones(shape=(batch_size,d,N+1))*s_0
    for i in range(batch_size):
        for j in range(N):
            brown=np.random.normal(size=(d,))*sqrt_delta_t
            S[i,:,j+1]=S[i,:,j]*np.exp((r-delta-0.5*sigma**2)*delta_t+sigma*brown)
    return S

In [128]:
def g_tf(x):
    return torch.nn.functional.relu(torch.max(x,dim=1)[0]-K)

In [129]:
class one_time_net(torch.nn.Module):
    def __init__(self,d):
        super(one_time_net,self).__init__()
        self.n_neuron=[d,d+40,d+40,1]
        self.norm=torch.nn.BatchNorm1d(self.n_neuron[0])
        #self.layer1=self._one_layer(self.n_neuron[0],self.n_neuron[1],torch.nn.ReLU())
        #self.layer2=self._one_layer(self.n_neuron[1],self.n_neuron[2],torch.nn.ReLU())
        #self.layer3=self._one_layer(self.n_neuron[2],self.n_neuron[3],None)
        self.layers=torch.nn.ModuleList([self._one_layer(self.n_neuron[0],self.n_neuron[1],torch.nn.ReLU()),self._one_layer(self.n_neuron[1],self.n_neuron[2],torch.nn.ReLU()),
                                        self._one_layer(self.n_neuron[2],self.n_neuron[3],torch.nn.Sigmoid())])
    def _one_layer(self,input_dim,output_dim,activation_fn=torch.nn.ReLU()):
        one_layer=torch.nn.Sequential()
        one_layer.add_module('Linear',torch.nn.Linear(input_dim,output_dim))
        one_layer.add_module('Norm',torch.nn.BatchNorm1d(output_dim))
        if activation_fn != None:
            one_layer.add_module('activation',activation_fn)
        return one_layer
    def forward(self,x):
        norm=self.norm(x)
        l1=self.layers[0](x)
        l2=self.layers[1](l1)
        out=self.layers[2](l2)
        return out      

In [130]:
stopping_rule={}
for i in range(1,N):
    stopping_rule['t{}'.format(i)]=one_time_net(d+1)
optimizers={}
for i in range(1,N):
    optimizers['t{}'.format(i)]=torch.optim.Adam(stopping_rule['t{}'.format(i)].parameters(),lr=lr)
Y0=torch.nn.Parameter(torch.tensor(10,dtype=torch.float32))
Yoptimizer=torch.optim.Adam(list([Y0]),lr=lr)
for i in list(stopping_rule):
    for name,para in stopping_rule[i].named_parameters():
        if 'Linear' in name and 'weight' in name:
            torch.nn.init.xavier_normal_(para)

In [131]:
for i in range(M):
    S=sample(batch_size)
    S=torch.tensor(S,dtype=torch.float32,requires_grad=False)
    g=g_tf(S)
    S=torch.cat([S,torch.unsqueeze(g,dim=1)],dim=1)
    continue_value=S[:,-1,N]*torch.exp(-1*r*torch.tensor(N).double()*delta_t)
    for k in reversed(range(1,N)):
        stopping_value=S[:,-1,k]*torch.exp(-1*r*torch.tensor(k).double()*delta_t)
        F=stopping_rule['t{}'.format(k)](S[:,:,k])
        loss=-1*(stopping_value*F.squeeze()+continue_value*(1-F.squeeze())).mean()
        optimizer=optimizers['t{}'.format(k)]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        newF=stopping_rule['t{}'.format(k)](S[:,:,k])
        index=(newF>0.5).reshape((-1,))
        continue_value[index]=stopping_value[index]
    loss=(Y0-continue_value.mean()).pow(2)
    if (i+1)%500==0:
        print('Episode {} loss:{},Y0:{}'.format(i+1,loss.item(),Y0.item()))
        #print(loss.item())
        #print(Y0.item())
    Yoptimizer.zero_grad()
    loss.backward()
    Yoptimizer.step()

Episode 500 loss:13.911758422851562,Y0:9.782146453857422
Episode 1000 loss:7.793548107147217,Y0:9.582137107849121
Episode 1500 loss:5.795093059539795,Y0:9.39547061920166
Episode 2000 loss:3.217606782913208,Y0:9.218242645263672
Episode 2500 loss:2.5977768898010254,Y0:9.049223899841309
Episode 3000 loss:1.1588634252548218,Y0:8.887639045715332


In [132]:
S=sample(4096*100)

In [133]:
S=torch.tensor(S,dtype=torch.float32,requires_grad=False)
g=g_tf(S)
S=torch.cat([S,torch.unsqueeze(g,dim=1)],dim=1)
continue_value=S[:,-1,N]*torch.exp(-1*r*torch.tensor(N).double()*delta_t)
for k in reversed(range(1,N)):
    stopping_value=S[:,-1,k]*torch.exp(-1*r*torch.tensor(k).double()*delta_t)
    newF=stopping_rule['t{}'.format(k)](S[:,:,k])
    index=(newF>0.5).reshape((-1,))
    continue_value[index]=stopping_value[index]

In [134]:
print(continue_value.mean().item())

7.718098163604736
