In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as tgrad
from torch.autograd import Variable

import os
import time
import utils
import numpy as np
import pandas as pd

from tqdm import tqdm, trange
import matplotlib.pyplot as plt

import networks
import importlib

In [2]:
# seed = 1234
# torch.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
# np.random.seed(seed)
os.environ['KMP_DUPLICATE_LIB_OK']='True'
torch.set_default_dtype(torch.float32)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.cuda.is_available())
print(device)

if device == 'cuda': 
    print(torch.cuda.get_device_name())

True
cuda


# Data Sampling
Here in our case, the system is European Call Option PDE and the physical information about the system consists of Boundary Value conditions, final Value conditions and the PDE itself.

In [3]:
K = 10
r = 0.035
sigma = 0.2
T = 1
S_range = [0, int(5*K)]
t_range = [0, T]
gs = lambda x: np.fmax(x-K, 0)
M = 100
N = 5000

# Build Neural Network

In [4]:
net = networks.FeedforwardNeuralNetwork(2, 50, 1, 8) #  Network initialization
net.cuda()

FeedforwardNeuralNetwork(
  (layers): ModuleList(
    (0): Linear(in_features=2, out_features=50, bias=True)
    (1-7): 7 x Linear(in_features=50, out_features=50, bias=True)
  )
  (output): Linear(in_features=50, out_features=1, bias=True)
  (relu): ReLU()
)

In [5]:
n_epochs = 5000
lossFunction = nn.MSELoss()
lr = 3e-5
optimizer = optim.Adam(net.parameters(), lr=lr)

x_f_s = torch.tensor(np.log(1.0)).float().to(device).requires_grad_(True)
x_label_s = torch.tensor(np.log(1.0)).float().to(device).requires_grad_(True)
x_data_s = torch.tensor(np.log(1.0)).float().to(device).requires_grad_(True)
w_lr = 3e-7
optimizer_adam_weight = optim.Adam([x_f_s] + [x_label_s] + [x_data_s], lr=w_lr)

In [6]:
# physical loss samples
samples = {"pde": 5000, "bc":500, "fc":500}

# sample data generated by finite difference method
X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor = utils.fdm_data(S_range[-1], T, M, N, "500000sample.csv", device)        

# Modelling


In [7]:
loss_hist = []
x_f_s_hist = []
x_label_s_hist = []
x_data_s_hist = []

start_time = time.time()
for epoch in range(n_epochs):
    
    bc_st_train, bc_v_train, n_st_train, n_v_train = \
    utils.trainingData(K, r, sigma, T, S_range[-1], S_range, t_range, gs, 
                       samples['bc'], 
                       samples['fc'], 
                       samples['pde'], 
                       RNG_key=123)
    
    # save training data points to tensor and send to device
    n_st_train = torch.from_numpy(n_st_train).float().requires_grad_().to(device)
    n_v_train = torch.from_numpy(n_v_train).float().to(device)
    
    bc_st_train = torch.from_numpy(bc_st_train).float().to(device)
    bc_v_train = torch.from_numpy(bc_v_train).float().to(device)   
    
    # pde residual loss
    y1_hat = net(n_st_train)
    grads = tgrad.grad(y1_hat, n_st_train, grad_outputs=torch.ones(y1_hat.shape).cuda(), 
                retain_graph=True, create_graph=True, only_inputs=True)[0]
    dVdt, dVdS = grads[:, 0].view(-1, 1), grads[:, 1].view(-1, 1)
    grads2nd = tgrad.grad(dVdS, n_st_train, grad_outputs=torch.ones(dVdS.shape).cuda(), 
                    create_graph=True, only_inputs=True, allow_unused=True)[0]
    S1 = n_st_train[:, 1].view(-1, 1)
    d2VdS2 = grads2nd[:, 1].view(-1, 1)
    pde_loss = lossFunction(-dVdt, 0.5*((sigma*S1)**2)*d2VdS2 + r*S1*dVdS - r*y1_hat)
    
    # boudary condition loss
    y2_hat = net(bc_st_train)
    bc_loss = lossFunction(bc_v_train, y2_hat)
    
    # sample training data loss
    y3_hat = net(X_train_tensor)
    data_loss = lossFunction(y_train_tensor, y3_hat)
    
    # Backpropagation and Update
    optimizer.zero_grad()
    combined_loss = torch.exp(-x_f_s.detach()) * pde_loss + torch.exp(-x_label_s.detach()) * bc_loss + torch.exp(-x_data_s.detach()) * data_loss + x_data_s + x_label_s + x_f_s
    combined_loss.backward()
    optimizer.step()
    
    # update the weight
    optimizer_adam_weight.zero_grad()
    loss = torch.exp(-x_f_s) * pde_loss.detach() + torch.exp(-x_label_s) * bc_loss.detach() + torch.exp(-x_data_s) * data_loss.detach() + x_data_s + x_label_s + x_f_s
    loss.backward()
    optimizer_adam_weight.step()
    
    # record the loss
    mse_loss = pde_loss + bc_loss + data_loss
    loss_hist.append(mse_loss.item())
    x_f_s_hist.append(torch.exp(-x_f_s).item())
    x_label_s_hist.append(torch.exp(-x_label_s).item())
    x_data_s_hist.append(torch.exp(-x_data_s).item())
    if epoch % 500 == 0:
        print(f'{epoch}/{n_epochs} PDE Loss: {pde_loss.item():.5f}, BC Loss: {bc_loss.item():.5f}, data loss: {data_loss.item():.5f}, total loss: {mse_loss.item():.5f}, minimum loss: {min(loss_hist):.5f}')
        print(f'the weight is {torch.exp(-x_f_s.detach()).item():.5f}, {torch.exp(-x_label_s.detach()).item():.5f}. {torch.exp(-x_data_s.detach()).item():.5f}, the parameter is {x_f_s.item():.5f}, {x_label_s.item():.5f}, {x_data_s.item():.5f}')
    pass
end_time = time.time()
print('run time:', end_time - start_time)

loss_weights_hist = pd.DataFrame({
        'PDE_Weight': x_f_s_hist,
        'BC_Weight': x_label_s_hist,
        'Data_Weight': x_data_s_hist
    })
loss_weights_hist.to_csv(f'weights/{w_lr}.csv', index=False)

0/5000 PDE Loss: 0.00000, BC Loss: 676.24292, data loss: 425.44766, total loss: 1101.69055, minimum loss: 1101.69055
the weight is 1.00000, 1.00000. 1.00000, the parameter is -0.00000, 0.00000, 0.00000
500/5000 PDE Loss: 0.01691, BC Loss: 146.57896, data loss: 79.39012, total loss: 225.98599, minimum loss: 225.98599
the weight is 1.00015, 0.99987. 0.99987, the parameter is -0.00015, 0.00013, 0.00013
1000/5000 PDE Loss: 1.04168, BC Loss: 3.06250, data loss: 6.72915, total loss: 10.83332, minimum loss: 10.83332
the weight is 1.00024, 0.99986. 0.99986, the parameter is -0.00024, 0.00014, 0.00014
1500/5000 PDE Loss: 0.72667, BC Loss: 0.12501, data loss: 0.39300, total loss: 1.24469, minimum loss: 1.24469
the weight is 1.00018, 0.99986. 0.99986, the parameter is -0.00018, 0.00014, 0.00014
2000/5000 PDE Loss: 0.00398, BC Loss: 0.00150, data loss: 0.00183, total loss: 0.00730, minimum loss: 0.00730
the weight is 1.00037, 0.99986. 0.99986, the parameter is -0.00037, 0.00014, 0.00014
2500/5000 

In [8]:
# Evaluate the model on the test set
net.eval()
with torch.no_grad():
    test_outputs = net(X_test_tensor)
    test_loss = lossFunction(test_outputs, y_test_tensor)
    print(f'Test Loss: {test_loss.item():.4f}')

Test Loss: 0.0016


In [9]:
for i in range(3000):
    if i % 500 == 0:
        print(loss_hist[i])

1101.6905517578125
225.98599243164062
10.833324432373047
1.2446893453598022
0.007300898432731628
0.0029041250236332417


In [10]:
'''
lr = 3e-5
lr = lr
1099.8822021484375
54.46702575683594
10.79588794708252
0.26179707050323486
0.06549832224845886
0.04295624420046806
v2
1091.264404296875
175.3550567626953
9.200620651245117
0.021437039598822594
0.011154992505908012
0.00615318538621068
v3
1096.0784912109375
145.78138732910156
11.065857887268066
1.195802927017212
0.007282810285687447
0.005661157425493002
v4
1099.22607421875
414.196533203125
9.993719100952148
1.4779069423675537
0.00975001323968172
0.0027645404916256666
v5
1101.3546142578125
117.58456420898438
10.586881637573242
0.4234362244606018
0.16712422668933868
0.08312012255191803

lr = lr*0.1
1108.269287109375
458.4380187988281
18.00096893310547
10.638126373291016
2.3813610076904297
0.04458016902208328
v2
1098.9173583984375
498.7457275390625
8.37392520904541
0.04848393797874451
0.01035694032907486
0.006585360039025545
v3
1106.7056884765625
199.62017822265625
20.819793701171875
11.254127502441406
2.2797794342041016
0.03742136433720589
v4
1107.873291015625
74.37163543701172
14.879474639892578
3.1905810832977295
0.045051589608192444
0.008426403626799583
v5
1110.307373046875
28.044471740722656
16.48096466064453
5.725974082946777
0.373512327671051
0.007787228096276522
v6
1093.041748046875
88.30891418457031
4.8466057777404785
0.030280468985438347
0.003480653278529644
0.003087927121669054

lr = lr*0.01
1108.57177734375
32.370243072509766
19.039159774780273
6.253387451171875
0.04844483360648155
0.009771297685801983
v2
1097.6781005859375
354.0700378417969
11.808877944946289
0.6618903875350952
0.017533782869577408
0.008471298031508923
v3
1092.56689453125
155.4998779296875
15.183956146240234
1.3004826307296753
0.13293592631816864
0.06324443966150284
v4
1097.779052734375
259.1787109375
4.381502151489258
0.03809420019388199
0.0030492269434034824
0.003581415396183729


lr = lr*0.001
1102.8289794921875
70.73458862304688
20.345287322998047
11.07504653930664
2.21193265914917
0.020034609362483025
v2
1100.103271484375
29.110830307006836
13.54207992553711
2.2702531814575195
0.03685794398188591
0.005187606438994408
v3
1101.7392578125
39.58943557739258
9.618980407714844
1.3002090454101562
0.0525912344455719
0.028041506186127663
v4
1106.5574951171875
54.07958984375
11.572875022888184
3.614226818084717
0.036611057817935944
0.007018315140157938


lr = lr*0.0001
1108.546875
64.16551208496094
19.342121124267578
11.909713745117188
2.2722582817077637
0.028394339606165886
v1
1097.0103759765625
641.002685546875
6.1454362869262695
0.028341539204120636
0.019869372248649597
0.015094837173819542
v2
1099.1807861328125
74.23635864257812
5.6405744552612305
0.03148069232702255
0.01238433551043272
0.008084086701273918

lr = lr*0.00001
1095.9866943359375
49.48216247558594
13.448659896850586
1.1453372240066528
0.0064399996772408485
0.004670318216085434
v1
1108.866943359375
298.6302795410156
20.8073787689209
11.087276458740234
2.8229360580444336
0.1352125108242035
v2
1089.82958984375
53.824317932128906
11.834549903869629
1.6474502086639404
0.027329374104738235
0.011542053893208504


lr = lr*0.0000001


lr = lr*0.000000001

'''

'\nlr = 3e-5\nlr = lr\n1099.8822021484375\n54.46702575683594\n10.79588794708252\n0.26179707050323486\n0.06549832224845886\n0.04295624420046806\nv2\n1091.264404296875\n175.3550567626953\n9.200620651245117\n0.021437039598822594\n0.011154992505908012\n0.00615318538621068\nv3\n1096.0784912109375\n145.78138732910156\n11.065857887268066\n1.195802927017212\n0.007282810285687447\n0.005661157425493002\nv4\n1099.22607421875\n414.196533203125\n9.993719100952148\n1.4779069423675537\n0.00975001323968172\n0.0027645404916256666\nv5\n1101.3546142578125\n117.58456420898438\n10.586881637573242\n0.4234362244606018\n0.16712422668933868\n0.08312012255191803\n\nlr = lr*0.1\n1108.269287109375\n458.4380187988281\n18.00096893310547\n10.638126373291016\n2.3813610076904297\n0.04458016902208328\nv2\n1098.9173583984375\n498.7457275390625\n8.37392520904541\n0.04848393797874451\n0.01035694032907486\n0.006585360039025545\nv3\n1106.7056884765625\n199.62017822265625\n20.819793701171875\n11.254127502441406\n2.2797794342