<a href="https://colab.research.google.com/github/Frank-III/Fin553-ML-in-Finance/blob/main/Copy_of_553_2022_Project_3_Stochastic_Cake_Eating.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install -q optax dm-haiku

[K     |████████████████████████████████| 154 kB 29.3 MB/s 
[K     |████████████████████████████████| 352 kB 52.5 MB/s 
[K     |████████████████████████████████| 85 kB 4.3 MB/s 
[?25h

Due Date: November 30
# Problem statement

The code below is similar to the Cake Eating problem code we implemented in class. The differences are:
- Each time interval corresponds to one year (instead of one month)
- The consumption policy function is written as a simple single-layer neural network, with tanh activation (instead of the usual relu)

We will interpret the size of the cake as being total wealth, and cake consumption as general consumption. The portion of wealth not consumed today is the *savings* (line 51). The dynamics of wealth are described by line 54. That line is equivalent to assuming that your savings are invested in a risk-free savings account that pays 0 interest, and therefore has a gross return of 1, denoted by *R* (line 53).




In [None]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk


# Preference parameters: utility curvature and per-period discount factor.
γ = 2.0
β = 0.95

# Optimization settings: optimizer constructor, its learning rate, and the
# number of periods in the horizon.
optimizer = optax.adam
lr = 0.001
T = 50

In [None]:
def U(c):
    """Return the CRRA utility of consumption `c` with curvature `γ`.

    Computes `c**(1 - γ) / (1 - γ)`. When `γ == 1` that expression divides by
    zero, so the conventional log-utility member of the CRRA family,
    `log(c)`, is returned instead.

    Args:
      c: consumption (scalar or array); expected to be positive.

    Returns:
      Utility with the same shape as `c`.
    """
    # γ is a plain Python number, so this branch is resolved at trace time.
    if γ == 1:
        return jnp.log(c)
    return c**(1 - γ) / (1 - γ)

# Optimizer constructor, learning rate, and horizon length used below.
optimizer = optax.adam
lr = 0.001
T = 50


def nnet(x):
  """Map wealth `x` to an unconstrained policy output.

  The input is reshaped into a single feature column, passed through one
  hidden layer of 32 tanh units and a linear layer with one output, and the
  singleton dimensions of the result are squeezed away.
  """
  features = jnp.column_stack([x])
  hidden = jnp.tanh(hk.Linear(32)(features))
  output = hk.Linear(1)(hidden)
  return jnp.squeeze(output)


# Turn `nnet` into a pure (init, apply) pair; `nnet` is rebound to the apply
# function, which takes the parameters as its first argument.
init, nnet = hk.without_apply_rng(hk.transform(nnet))
rng = jax.random.PRNGKey(seed=0)
Θ = init(rng, jnp.array(1.0))

# Optimizer state matching the freshly initialized parameters.
opt_state = optimizer(learning_rate=lr).init(Θ)


def L(Θ):
  """Return the negative discounted lifetime utility of the policy `Θ`.

  Starting from wealth 1, each of the `T` periods consumes the fraction
  `sigmoid(nnet(Θ, x) - 4)` of current wealth; what is left is carried to the
  next period at a gross return of 1.

  Args:
    Θ: parameters passed to `nnet`.

  Returns:
    Minus the sum over t = 0..T-1 of `β**t * U(c_t)`, so minimizing this value
    maximizes discounted utility.
  """

  def core(state, inputs):
    t = inputs
    xt = state

    # The sigmoid keeps the consumed share between 0 and 1; the -4 offset
    # means a zero network output corresponds to a share of about 1.8%.
    ct = jax.nn.sigmoid(nnet(Θ, xt) - 4.) * xt
    ut = U(ct)
    savings = xt - ct

    R = 1.
    x_tp1 = R * savings

    discounted_utility = β**t * ut
    return x_tp1, discounted_utility

  # Only the stacked per-period utilities are needed; the final wealth
  # returned by the scan is discarded.
  _, discounted_utility = jax.lax.scan(core, 1., jnp.arange(T))
  return -discounted_utility.sum()


@jax.jit
def evaluation(Θ):
  """Return the discounted utility achieved by `Θ` (the sign-flipped loss)."""
  loss = L(Θ)
  return -loss


@jax.jit
def update_gradient_descent(Θ, opt_state):
  """Apply one optimizer step to `Θ` using the gradient of `L`.

  Returns the updated parameters together with the updated optimizer state.
  """
  gradients = jax.grad(L)(Θ)
  step, new_opt_state = optimizer(lr).update(gradients, opt_state)
  return optax.apply_updates(Θ, step), new_opt_state




In [None]:
# Train, reporting the achieved discounted utility every 1000 iterations.
for iteration in range(1_000_000):
  Θ, opt_state = update_gradient_descent(Θ, opt_state)

  if iteration % 1000:
    continue
  print(evaluation(Θ))

-1301.5056
-869.2983
-846.71466
-839.00665
-833.3895
-830.0196


KeyboardInterrupt: ignored

Suppose now that your savings are fully invested in the stock market, so the evolution of wealth is now stochastic. The stock market gross return is modeled by the function below:

In [None]:
@jax.jit
def stock_return(rng):
  """Draw one gross stock return from a lognormal distribution.

  The log-return is `0.06 + 0.2 * ε` with `ε` a single standard-normal draw
  from the PRNG key `rng`; the exponential of that log-return is returned.
  """
  μs, σs = 0.06, 0.2
  shock = jax.random.normal(rng, ())
  return jnp.exp(μs + σs * shock)

In [None]:
#rng = jax.random.PRNGKey(0)
#jnp.mean([stock_return(key) for key in jax.random.split(rng, 1000000)])

Write code to solve for the optimal consumption policy in this environment. 
What is the expected sum of discounted rewards (value function) resulting from that policy? Use at least 1 million sample paths to estimate that number.

In [None]:
# A (T, 1000000) array of standard-normal draws from the current key.
jax.random.normal(key=rng, shape=(T, 1000000))

(50, 1000000)

In [None]:
# Sanity check: the shock is averaged over 1,000,000 draws *before* exp() is
# applied, so each printed value sits near exp(μs) rather than near the mean
# gross return.
rng = jax.random.PRNGKey(0)
μs, σs = 0.06, 0.2
for _ in range(5):
  rng, _ = jax.random.split(rng)
  print(jnp.exp(μs + σs * jnp.mean(jax.random.normal(rng, (T, 1000000)), axis=1))[:5])

[1.0618982 1.0619469 1.0619837 1.062034  1.0615014]
[1.0618906 1.06188   1.0615606 1.0618215 1.062045 ]
[1.0617937 1.062007  1.0619162 1.0616356 1.0621024]
[1.0620306 1.0619072 1.0615717 1.0618093 1.0617504]
[1.06171   1.0615904 1.0617591 1.0620694 1.061711 ]


In [None]:
# Average gross return per year: exp() is applied to every draw first and the
# 1,000,000 simulated returns of each year are averaged afterwards.
rng = jax.random.PRNGKey(0)
print("Year 1 to 5")
for _ in range(5):
  rng, _ = jax.random.split(rng)
  print(jnp.mean(jnp.exp(μs + σs * jax.random.normal(rng, (T, 1000000))), axis=1)[:5])

Year 1 to 5
[1.0833415 1.0834395 1.0834734 1.0834988 1.0829507]
[1.0833129 1.0833772 1.0830647 1.083262  1.0835037]
[1.0832926 1.0834638 1.0833498 1.0830674 1.0835745]
[1.0834701 1.0833898 1.0830349 1.0832455 1.0832254]
[1.083132  1.0830543 1.0832282 1.0835247 1.0832232]


In [None]:
# Print five individual simulated gross returns for each of five successive keys.
rng = jax.random.PRNGKey(0)
for _ in range(5):
  rng, _ = jax.random.split(rng)
  print(jnp.exp(μs + σs * jax.random.normal(rng, shape=(T,))[:5]))

[1.3033675 1.1019223 1.0690558 0.892919  1.1381866]
[1.3243431 1.1481453 1.4099396 1.8082697 1.3144269]
[1.0459208 1.0225538 1.1918788 1.1268678 1.0007964]
[0.9401438 1.0572538 1.2984145 0.8893047 0.7812677]
[0.7606866 1.146791  1.2302936 1.0184064 1.2396514]


In [None]:
def jax_mean_return(rng):
  """Simulate one gross stock return per PRNG key using `jax.lax.scan`.

  Despite the name, no averaging happens here: the caller takes the mean of
  the second element of the result.

  Args:
    rng: a stack of PRNG keys; the scan iterates over its leading axis and
      draws one return from each key.

  Returns:
    The `(carry, returns)` pair produced by `jax.lax.scan`: `carry` is the
    last simulated return and `returns` holds one gross return per key.
  """
  def draw_return(_, key):
    # The incoming carry is ignored; each step depends only on its own key.
    μs = 0.06
    σs = 0.2
    ε = jax.random.normal(key, ())
    gross_return = jnp.exp(μs + σs * ε)
    return gross_return, gross_return

  return jax.lax.scan(draw_return, init=0, xs=rng)

# Average simulated gross return over 1,000,000 keys split from one seed.
rng = jax.random.PRNGKey(0)
jnp.mean(jax_mean_return(jax.random.split(rng, 1000000))[1])

DeviceArray(1.0834519, dtype=float32)

In [None]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk


# Preference parameters: utility curvature and per-period discount factor.
γ = 2.0
β = 0.95

# Monte Carlo sizes: paths used for evaluation and paths per training batch.
n_samples = 1_000_000
batch_size = 10_000

# Mean and standard deviation of the stock's log-return.
μs = 0.06
σs = 0.2

def U(c):
    """Return the CRRA utility of consumption `c` with curvature `γ`.

    Computes `c**(1 - γ) / (1 - γ)`. When `γ == 1` that expression divides by
    zero, so the conventional log-utility member of the CRRA family,
    `log(c)`, is returned instead.

    Args:
      c: consumption (scalar or array); expected to be positive.

    Returns:
      Utility with the same shape as `c`.
    """
    # γ is a plain Python number, so this branch is resolved at trace time.
    if γ == 1:
        return jnp.log(c)
    return c**(1 - γ) / (1 - γ)


# Optimizer constructor, learning rate, and horizon length used below.
optimizer = optax.adam
lr = 0.001
T = 50


def nnet(x):
  """Map wealth `x` to an unconstrained policy output.

  The input is reshaped into a single feature column, passed through one
  hidden layer of 32 tanh units and a linear layer with one output, and the
  singleton dimensions of the result are squeezed away.
  """
  features = jnp.column_stack([x])
  hidden = jnp.tanh(hk.Linear(32)(features))
  output = hk.Linear(1)(hidden)
  return jnp.squeeze(output)



# Turn `nnet` into a pure (init, apply) pair; `nnet` is rebound to the apply
# function, which takes the parameters as its first argument.
init, nnet = hk.without_apply_rng(hk.transform(nnet))
rng = jax.random.PRNGKey(seed=0)
Θ = init(rng, jnp.array(1.0))

# Optimizer state matching the freshly initialized parameters.
opt_state = optimizer(learning_rate=lr).init(Θ)


def L(Θ, batch_size, rng):
  """Return the negative Monte Carlo estimate of expected discounted utility.

  Simulates `batch_size` wealth paths over `T` periods. Every path starts with
  wealth 1, consumes the fraction `sigmoid(nnet(Θ, x) - 4)` of current wealth
  each period, and has its savings multiplied by that period's simulated
  gross stock return.

  Args:
    Θ: parameters passed to `nnet`.
    batch_size: number of simulated paths; used as an array dimension.
    rng: PRNG key from which all `T * batch_size` return shocks are drawn.

  Returns:
    Minus the sum over periods of the across-path average of
    `β**t * U(c_t)`, so minimizing this value maximizes expected utility.
  """
  state = jnp.ones((batch_size,))  # every path starts with wealth 1
  # One lognormal gross return per period and path: shape (T, batch_size).
  stock_returns = jnp.exp(μs + σs * jax.random.normal(rng, (T, batch_size)))
  inputs = jnp.arange(T), stock_returns

  def core(state, inputs):
    t, R = inputs  # t: scalar period index, R: (batch_size,) gross returns
    xt = state  # (batch_size,) wealth at the start of period t

    ct = jax.nn.sigmoid(nnet(Θ, xt) - 4.) * xt
    ut = U(ct)  # (batch_size,)
    savings = xt - ct

    x_tp1 = R * savings

    discounted_utility = β**t * ut
    return x_tp1, discounted_utility

  # Only the stacked (T, batch_size) utilities are needed; the final wealth
  # returned by the scan is discarded.
  _, discounted_utility = jax.lax.scan(core, state, inputs)
  # Average across paths within each period, then add up the periods.
  G = discounted_utility.mean(axis=1).sum()
  return -G


@jax.jit
def evaluation(Θ):
  """Estimate the expected discounted utility of `Θ` on `n_samples` paths.

  The key is always built from seed 0, so repeated calls evaluate on the
  same simulated returns.
  """
  eval_key = jax.random.PRNGKey(0)
  loss = L(Θ, n_samples, eval_key)
  return -loss


@jax.jit
def update_gradient_descent(rng, Θ, opt_state):
  """Apply one optimizer step using a freshly simulated batch of paths.

  Args:
    rng: PRNG key carried from one iteration to the next.
    Θ: current network parameters.
    opt_state: current optimizer state.

  Returns:
    `(rng, Θ, opt_state)`: the key to pass to the next call, the updated
    parameters, and the updated optimizer state.
  """
  # Split into a key to carry forward and a separate subkey consumed here, so
  # the key that draws this batch is never sampled from or split again.
  rng, subkey = jax.random.split(rng)
  grad = jax.grad(L)(Θ, batch_size, subkey)
  updates, opt_state = optimizer(lr).update(grad, opt_state)
  Θ = optax.apply_updates(Θ, updates)
  return rng, Θ, opt_state


In [None]:
# Quick check of the loss on a small batch of 200 paths.
L(Θ, batch_size=200, rng=rng)

DeviceArray(694.96045, dtype=float32)

In [None]:
# Train from a fixed seed, reporting the estimated expected discounted
# utility every 1000 iterations.
rng = jax.random.PRNGKey(0)
for iteration in range(100_000):
  rng, Θ, opt_state = update_gradient_descent(rng, Θ, opt_state)

  if iteration % 1000:
    continue
  print(evaluation(Θ))

-677.7496
-434.88882
-434.11334
-433.84802
-433.70032
-433.6125
-433.5578
-433.43768
-433.3656
-433.29797
-433.2447
-433.19153
-433.15735
-433.1344
-433.13126
-433.13556
-433.0511
-433.0437
-433.0252
-433.02295
-433.02264
-433.0065
-432.99207
-432.98975
-432.98056
-433.0334
-432.97656
-433.0379
-432.96753
-432.99396
-432.97968
-432.96393
-432.98825
-432.9608
-432.9598
-433.00934
-432.98218
-432.95618
-432.94287
-432.96173
-432.9615
-432.93796
-432.93393
-432.9399
-432.94553
-432.93622
-432.92776
-432.94223
-432.93408
-433.02484
-432.92572
-432.9438
-432.93103
-432.9436
-432.9215
-432.9455
-432.94867
-432.91907
-432.94354
-432.91785
-432.96182
-432.92892
-432.91602
-432.98203
-432.91827
-432.91284
-432.90497
-432.90845
-432.91418
-432.92038
-432.90778
-432.92188
-432.95868
-432.92438
-432.929
-432.9549
-432.8949
-432.91852
-432.90054
-432.89713
-432.90515
-432.89398
-432.89236
-432.89856
-432.891
-432.89386
-432.93103
-432.8956
-432.89545
-432.88116
-432.88235
-432.88498
-432.88098
-432