<a href="https://colab.research.google.com/github/Leonaxi/Machine-Learning-in-Finance/blob/main/L11_Least_squares_monte_carlo_with_neural_nets_ipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
pip install optax dm-haiku

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
[K     |████████████████████████████████| 145 kB 33.0 MB/s 
[?25hCollecting dm-haiku
  Downloading dm_haiku-0.0.8-py3-none-any.whl (350 kB)
[K     |████████████████████████████████| 350 kB 60.4 MB/s 
Collecting chex>=0.0.4
  Downloading chex-0.1.5-py3-none-any.whl (85 kB)
[K     |████████████████████████████████| 85 kB 4.8 MB/s 
Collecting jmp>=0.0.2
  Downloading jmp-0.0.2-py3-none-any.whl (16 kB)
Installing collected packages: jmp, chex, optax, dm-haiku
Successfully installed chex-0.1.5 dm-haiku-0.0.8 jmp-0.0.2 optax-0.1.3


In [3]:
import jax.numpy as jnp
import jax
import haiku as hk
import optax

lr = 1e-4                # learning rate handed to the optimizer factory
optimizer = optax.adam   # optimizer factory; optimizer(lr) builds the optax transformation


Spot = jnp.array([38, 36, 35])   # initial stock prices, one entry per stock
σ = jnp.array([0.2, .25, .3])     # stock volatilities, one entry per stock
K = 40      # strike price
r = 0.06    # risk-free rate

# Two path counts. Training estimates the loss on a small random mini-batch:
# stochastic gradient descent does not need a precise loss, and n = 100000
# paths per step is too slow. Evaluation uses many paths so the reported
# price is precise.
n = 100000        # number of simulated paths for evaluation
batch_size = 512  # number of simulated paths per training step

m = 50      # number of exercise dates
T = 1       # maturity
order = 12  # polynomial order; only referenced by the commented-out polynomial-basis model
Δt = T / m  # interval between two exercise dates

# Derive the number of stocks from the data so it cannot drift out of sync
# with Spot and σ.
n_stocks = Spot.shape[0]
assert σ.shape == Spot.shape, "Spot and σ need one entry per stock"

# One time step of the simulated stock prices (used as the body of jax.lax.scan).
def step(S, rng):
    """Advance every simulated price by one interval of length Δt.

    Euler step of geometric Brownian motion with drift r:
        S <- S + r * S * Δt + σ * S * dB,   dB = sqrt(Δt) * N(0, 1),
    with an independent normal draw for each entry of S.

    Shaped for jax.lax.scan: ``S`` is the carried state and ``rng`` is this
    step's key from the scanned key sequence. Both returned values are the
    new prices; the first becomes the next carry and the second is stacked
    into the simulated history.
    """
    # The key must come from the caller. A fixed key here (e.g. PRNGKey(0))
    # would draw identical shocks at every step, so the paths would no
    # longer be random.
    shock = jax.random.normal(rng, S.shape)   # standard normals, same shape as S
    dB = jnp.sqrt(Δt) * shock                 # Brownian increment over Δt
    drift = r * S * Δt
    diffusion = σ * S * dB
    S_next = S + drift + diffusion            # prices at the next instant
    return S_next, S_next


def payoff_put(S):
    """Immediate exercise value of a put on the maximum of the stocks.

    For each row (path) of S, returns max(K - max over stocks, 0).
    This is the payoff, not the option price.
    """
    highest_price = jnp.max(S, axis=1)   # best-performing stock on each path
    return jnp.maximum(K - highest_price, 0.)

# Normalize the input (subtract the mean and divide by the standard deviation), or try a larger, more flexible network.

def model(Si):
  """Continuation-value network for a single exercise date (haiku function).

  Must be wrapped with ``hk.transform`` before it is called.

  Args:
    Si: simulated prices at one date, shape (batch, n_assets). A 1-D array
      is treated as (batch,) paths of a single asset.

  Returns:
    Array of shape (batch,): one estimated continuation value per path.
  """
  # Keep one row per path and one feature column per asset. Flattening with
  # reshape(-1, 1) would turn a (batch, n_assets) input into
  # (batch * n_assets, 1), so the network would return n_assets values per
  # path that cannot be compared with the per-path cash flows.
  x = Si.reshape(-1, 1) if Si.ndim == 1 else Si
  # Approximate standardization with a fixed shift/scale close to the spot and
  # strike levels; keeps the ReLU layers in a well-conditioned range.
  out = (x - 37.) / 5

  out = hk.Linear(64)(out)
  out = jax.nn.relu(out)

  out = hk.Linear(64)(out)
  out = jax.nn.relu(out)

  out = hk.Linear(1)(out)   # shape (batch, 1)
  # Drop only the trailing feature axis. jnp.squeeze would also remove the
  # batch axis when batch == 1.
  return out[:, 0]


In [4]:
# When training with gradient descent, always normalize (standardize) the inputs; it tends to speed up convergence.

In [5]:
# Build the network parameters.
rng = jax.random.PRNGKey(0)   # pseudo-random key from a fixed integer seed

# hk.transform turns the haiku function into a pure (init, apply) pair. The
# network has no stochastic layers, so without_apply_rng drops the rng
# argument from apply. NOTE: this rebinds `model` to the apply function,
# called as model(params, x). Re-running this cell alone will not work until
# the cell that defines the haiku `model` has been re-run first.
init, model = hk.without_apply_rng(
    hk.transform(model))
# Only the shape of the dummy input matters: (batch_size, n_stocks).
Θ = init(rng, jnp.ones([batch_size, n_stocks]))

# we have n_stock=49 nnet works, each element has 49 copies
# we want to do every leaf for the python tree (1 nnet for 1 timeset for each regression)
def stack(Θ):
  return jnp.stack([Θ] * 49) 

# Apply `stack` to every leaf of the parameter pytree. The result has the same
# tree structure as Θ, with each leaf gaining a leading axis of one copy per
# regression date; lax.scan in compute_price slices one parameter set per date.
# jax.tree_util.tree_map is used because the top-level alias jax.tree_map is
# deprecated in recent JAX releases.
Θ = jax.tree_util.tree_map(stack, Θ)

# The optimizer state must be (re)initialized whenever the parameters are.
opt_state = optimizer(lr).init(Θ)


In [6]:
# @jax.jit
# def update_gradient_descent(Θ, opt_state):

#   def L(Θ):
#     return compute_price(Θ)[1].sum()

#   grad = jax.grad(L)(Θ)
#   updates, opt_state = optimizer(lr).update(grad, opt_state)
#   Θ = optax.apply_updates(Θ, updates)
#   return Θ, opt_state


In [8]:
# LSMC algorithm (least-squares Monte Carlo with one neural network per date)
def compute_price(Θ, batch_size, rng):
    """Simulate paths and price the Bermudan max-put by backward induction.

    Training calls this with a small ``batch_size`` and a fresh ``rng`` on
    every step, so each mini-batch sees new paths (always sampling the same
    data gives a worse model). Evaluation calls it with a large batch.

    Args:
      Θ: parameter pytree whose leaves carry a leading axis of length
        m - 1: one network per exercise date before maturity, ordered from
        the date just before maturity back to the first date.
      batch_size: number of simulated paths.
      rng: PRNG key for the simulation.

    Returns:
      (price, mse):
        price: mean discounted cash flow at t = 0 under the exercise
          policy implied by Θ.
        mse: array of length m - 1 with each date's regression MSE, in the
          same (backward) date order as Θ.
    """
    n_assets = Spot.shape[0]
    # Every path starts from the same spot vector: shape (batch_size, n_assets).
    S0 = jnp.ones((batch_size, n_assets)) * Spot
    step_keys = jax.random.split(rng, m)   # a different key for every time step

    # lax.scan carries the current prices through `step` and stacks each
    # step's second output, giving S with shape (m, batch_size, n_assets)
    # for dates 1..m.
    _, S = jax.lax.scan(step, S0, step_keys)

    discount = jnp.exp(-r * Δt)

    # Maturity: the payoff (zero when out of the money), discounted one period.
    discounted_future_cashflows = payoff_put(S[-1]) * discount

    def core(discounted_future_cashflows, inputs):
        """Backward-induction step at one exercise date.

        Carry: the discounted future cash flows of every path. Inputs: that
        date's prices and network parameters. Emits the regression MSE.
        """
        Si, Θi = inputs
        value_if_wait = model(Θi, Si)   # estimated continuation value
        mse = jnp.mean((value_if_wait - discounted_future_cashflows) ** 2)
        value_if_exercise = payoff_put(Si)
        # Only in-the-money paths may exercise. The network output is
        # unconstrained, so without the payoff > 0 check a slightly negative
        # estimate would "exercise" a worthless option and wipe out that
        # path's positive future cash flow.
        exercise = (value_if_exercise > 0) & (value_if_exercise >= value_if_wait)
        discounted_future_cashflows = discount * jnp.where(
            exercise,
            value_if_exercise,
            discounted_future_cashflows)
        return discounted_future_cashflows, mse

    # Walk backwards in time. Reverse the dates and drop maturity (handled
    # above), leaving dates m-1, ..., 1 paired with the stacked networks.
    S_backward = jnp.flip(S, 0)[1:]
    discounted_future_cashflows, mse = jax.lax.scan(
        core, discounted_future_cashflows, (S_backward, Θ))

    return discounted_future_cashflows.mean(), mse

#print(compute_price(Θ, batch_size, rng))



    # 1st output: mean of the discounted future cash flows (the price estimate).
    ## It should gradually increase during training from its initial value of about 3.95 towards about 4.47.
    # If Θ gives a wrong continuation value, the comparison that decides whether to exercise is wrong.
    # Following a strategy built on a wrong continuation value produces a suboptimal (lower) value for the future cash flows.
    ## Training chooses Θ to minimize the MSE.
    # 2nd output: array of per-date MSEs ([MSE of 1st regression, MSE of 2nd regression, ...]).




In [None]:
%timeit compute_price()

In [None]:
# Training loss, computed on a small mini-batch of batch_size simulated paths.
def L(Θ, rng):
  """Sum over exercise dates of the continuation-value regression MSE.

  compute_price returns (price, per-date MSE array); only the MSE term is
  used for training.
  """
  _price, per_date_mse = compute_price(Θ, batch_size, rng)
  return per_date_mse.sum()

# One optimizer step; jit-compiled because it runs thousands of times.
@jax.jit
def gradient_descent_step(Θ, opt_state, rng):
  """Take one optimizer step on a freshly simulated mini-batch.

  Args:
    Θ: stacked network parameters (one set per exercise date).
    opt_state: optimizer state matching Θ.
    rng: PRNG key carried between steps.

  Returns:
    (Θ, opt_state, rng): updated parameters, updated optimizer state, and the
    key to pass to the next call.
  """
  # Split off a sub-key for this step's paths and carry the other key forward.
  # Sampling with the carried key itself would reuse it when it is split
  # again on the next call; JAX keys must not be reused after being consumed.
  rng, step_key = jax.random.split(rng)
  grad = jax.grad(L)(Θ, step_key)   # gradient of the loss w.r.t. Θ
  updates, opt_state = optimizer(lr).update(grad, opt_state)
  Θ = optax.apply_updates(Θ, updates)
  return Θ, opt_state, rng

# gradient_descent_step(Θ, opt_state, rng)  # run to check if it works



In [None]:
# Evaluation: price the current exercise policy with n paths (large batch).
# If Θ is right, the exercise decisions are right, so we track the price
# (output [0] of compute_price) rather than the training MSE (output [1]).
@jax.jit
def evaluation(Θ):
  """Monte Carlo price of the policy defined by Θ, using n paths from a fixed seed.

  The fixed key means every call sees the same paths, so prices are
  comparable across training iterations.
  """
  eval_key = jax.random.PRNGKey(0)
  price, _ = compute_price(Θ, n, eval_key)
  return price


# iteration
# Training loop: one gradient step per iteration on a fresh mini-batch, with a
# large-sample evaluation of the current policy every EVAL_EVERY iterations.
N_ITERATIONS = 10000
EVAL_EVERY = 100

rng = jax.random.PRNGKey(0)
price_history = []   # (iteration, price) pairs for later inspection/plotting
for iteration in range(N_ITERATIONS):
  Θ, opt_state, rng = gradient_descent_step(Θ, opt_state, rng)

  # Θ defines the continuation values and therefore the exercise decisions,
  # so the evaluated price tracks how good the learned policy is.
  if iteration % EVAL_EVERY == 0:
    price = float(evaluation(Θ))
    price_history.append((iteration, price))
    print(f"iteration {iteration:>5}: price {price:.4f}")

# The printed price is expected to rise as the policy improves.
# NOTE(review): the "3.x -> 4.47" range noted earlier matches the classic
# single-asset American put (S=36, K=40); confirm the reference value for this
# three-asset max-put.