In [None]:
pip install -q optax dm-haiku

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/371.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.4/371.0 kB[0m [31m4.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m371.0/371.0 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25h

Due Date: November 6th
# Part 1 (80 %)

The code below is similar to the Cake Eating problem code we implemented in class. The difference is that the consumption policy function is written as a simple neural network with a single hidden layer and tanh activation.

We will interpret the size of the cake as total wealth, and cake consumption as general consumption. The wealth not consumed today is *savings* (line 51, `savings = xt - ct`). The dynamics of wealth are described by line 54 (`x_tp1 = R * savings`). That line is equivalent to assuming that your savings are invested in a risk-free savings account that pays 0 interest and therefore has a gross return of 1, denoted by *R* (line 53, `R = 1.`).




In [None]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk


γ = 2.  # CRRA risk-aversion coefficient in U
β = 0.95  # discount factor: period t is weighted by β**t

# Per-period utility c**(1 - γ) / (1 - γ); with γ = 2 this is -1 / c.
def U(c):
    return c**(1 - γ) / (1 - γ)

# Adam optimizer, its learning rate, and the horizon T (periods t = 0, ..., T-1).
optimizer = optax.adam
lr = 1e-3
T = 50

# Policy network: wealth x -> one raw score (32 tanh hidden units, linear output).
def nnet(x):
  X = jnp.column_stack([x])  # scalar wealth -> (1, 1) input matrix
  X = hk.Linear(32)(X)
  X = jnp.tanh(X)
  X = hk.Linear(1)(X)
  X = jnp.squeeze(X)  # drop size-1 axes: a scalar input yields a scalar score
  return X

# Make nnet pure: init(rng, x) builds parameters Θ; nnet(Θ, x) then evaluates it.
init, nnet = hk.without_apply_rng(hk.transform(nnet))
rng = jax.random.PRNGKey(0)  # fixed seed for parameter initialization
Θ = init(rng, jnp.array(1.))

# Adam state for Θ; it is threaded through update_gradient_descent below.
opt_state = optimizer(lr).init(Θ)

# Loss: minus the discounted lifetime utility of the (deterministic) wealth path.
def L(Θ):
  """Return -Σ_{t<T} β**t U(c_t) for the path that starts from wealth x_0 = 1."""
  x = 1.  # initial wealth (cake size)
  G = 0.  # placeholder; G is recomputed from the scan output below

  state = x
  inputs = jnp.arange(T)  # t = 0, ..., T-1, fed to the scan one per step
  # One scan step: carry = wealth x_t; emitted value = β**t U(c_t).
  def core(state, inputs):
    t = inputs
    xt = state
    # sigmoid keeps c_t / x_t in (0, 1); the -4 shift gives ≈ 0.018 when the score is 0.
    ct = jax.nn.sigmoid(nnet(Θ, xt) - 4.) * xt
    ut = U(ct)
    savings = xt - ct  # wealth not consumed today

    R = 1.  # gross return on savings (risk-free, zero interest)
    x_tp1 = R * savings  # next period's wealth

    discounted_utility = β**t * ut
    return x_tp1, discounted_utility

  # x is the final wealth x_T (unused); discounted_utility has one entry per t.
  x, discounted_utility = jax.lax.scan(core, state, inputs)
  G = discounted_utility.sum()
  return -G


# Objective value (discounted lifetime utility) of the current parameters.
@jax.jit
def evaluation(Θ):
  return -L(Θ)


# One Adam step on L; returns the updated parameters and optimizer state.
@jax.jit
def update_gradient_descent(Θ, opt_state):
  grad = jax.grad(L)(Θ)
  updates, opt_state = optimizer(lr).update(grad, opt_state)
  Θ = optax.apply_updates(Θ, updates)
  return Θ, opt_state


# Train for 100,000 Adam steps, printing the utility every 1,000 steps.
for iteration in range(100000):
  Θ, opt_state = update_gradient_descent(Θ, opt_state)

  if iteration % 1000 == 0:
    print('The Utility is',evaluation(Θ))

The Utility is -1301.506
The Utility is -869.2982
The Utility is -846.71466
The Utility is -839.00684
The Utility is -833.38855
The Utility is -830.0226
The Utility is -828.1885
The Utility is -826.8771
The Utility is -825.7527
The Utility is -824.7186
The Utility is -823.7797
The Utility is -822.95953
The Utility is -822.26624
The Utility is -821.70044
The Utility is -821.2106
The Utility is -820.8086
The Utility is -820.4761
The Utility is -820.17773
The Utility is -819.92346
The Utility is -819.6885
The Utility is -819.4805
The Utility is -819.30194
The Utility is -819.0996
The Utility is -818.935
The Utility is -818.78064
The Utility is -818.6075
The Utility is -818.45886
The Utility is -818.3211
The Utility is -818.1902
The Utility is -818.0672
The Utility is -817.9671
The Utility is -817.8436
The Utility is -817.7755
The Utility is -817.6581
The Utility is -817.5817
The Utility is -817.4724
The Utility is -817.3819
The Utility is -817.32
The Utility is -817.2313
The Utility is -8

Suppose now that your savings are fully invested in the stock market, so the evolution of wealth is now stochastic. The stock market gross return is modeled by the function below:

In [None]:
def stock_return(rng, μs=0.06, σs=0.2, shape=()):
  """Draw log-normal gross stock-market return(s).

  log R = μs + σs * ε with ε ~ N(0, 1), so R = exp(μs + σs * ε) > 0.

  Args:
    rng: a JAX PRNG key. The same key always produces the same draw, so
      callers must split off a fresh key for every period and every path.
    μs: mean of the log return (default 0.06).
    σs: standard deviation of the log return (default 0.2).
    shape: shape of the array of independent draws; the default () returns
      a single scalar return, as before.

  Returns:
    Array of gross returns with the requested shape.
  """
  ε = jax.random.normal(rng, shape)
  log_return = μs + σs * ε
  return jnp.exp(log_return)

In [None]:
# Same model as above, but savings now earn the random stock return. The cell
# re-declares its parameters and network so that it can be run on its own.
γ = 2.
β = 0.95


def U(c):
    """CRRA per-period utility c**(1 - γ) / (1 - γ)."""
    return c**(1 - γ) / (1 - γ)


optimizer = optax.adam
lr = 1e-3
T = 50
batch_size = 128      # wealth paths averaged per gradient step
n_eval_paths = 1024   # wealth paths averaged for each printed utility


def nnet(x):
  """Policy network: wealth x -> one raw consumption score."""
  X = jnp.column_stack([x])   # scalar wealth -> (1, 1)
  X = hk.Linear(32)(X)
  X = jnp.tanh(X)
  X = hk.Linear(1)(X)
  X = jnp.squeeze(X)          # (1, 1) -> scalar
  return X


init, nnet = hk.without_apply_rng(hk.transform(nnet))
rng = jax.random.PRNGKey(0)
Θ = init(rng, jnp.array(1.))


opt_state = optimizer(lr).init(Θ)


def simulate_path(Θ, key):
  """Discounted lifetime utility Σ_t β**t U(c_t) of ONE random path from x_0 = 1.

  Each period gets its own subkey of `key`, so every period's stock return is
  an independent draw (reusing a single key would repeat the same return in
  every period and every training step).
  """
  def core(xt, inputs):
    t, key_t = inputs
    ct = jax.nn.sigmoid(nnet(Θ, xt) - 4.) * xt   # consumption share in (0, 1)
    savings = xt - ct
    x_tp1 = stock_return(key_t) * savings
    return x_tp1, β**t * U(ct)

  step_keys = jax.random.split(key, T)
  _, discounted_utility = jax.lax.scan(core, 1., (jnp.arange(T), step_keys))
  return discounted_utility.sum()


def L(Θ, key=rng, n_paths=batch_size):
  """Loss: minus the Monte Carlo mean of lifetime utility over `n_paths` paths.

  Args:
    Θ: network parameters.
    key: PRNG key; pass a fresh key per gradient step so each step sees new shocks.
    n_paths: number of independent wealth paths (a Python int).
  """
  path_keys = jax.random.split(key, n_paths)
  G = jax.vmap(simulate_path, in_axes=(None, 0))(Θ, path_keys)
  return -G.mean()


@jax.jit
def evaluation(Θ, key=rng):
  """Estimated expected discounted lifetime utility over n_eval_paths paths."""
  return -L(Θ, key, n_eval_paths)


@jax.jit
def update_gradient_descent(Θ, opt_state, key=rng):
  """One Adam step on L using the shocks drawn from `key`."""
  grad = jax.grad(L)(Θ, key)
  updates, opt_state = optimizer(lr).update(grad, opt_state)
  Θ = optax.apply_updates(Θ, updates)
  return Θ, opt_state


# Fresh shocks at every training step; a fixed evaluation key keeps the printed
# values comparable across iterations (common random numbers).
train_key, eval_key = jax.random.split(jax.random.PRNGKey(1))
for iteration in range(100000):
  train_key, step_key = jax.random.split(train_key)
  Θ, opt_state = update_gradient_descent(Θ, opt_state, step_key)

  if iteration % 1000 == 0:
    print(evaluation(Θ, eval_key))

-918.53546
-603.8025
-589.3341
-585.246
-583.1464
-581.8572
-581.0594
-580.55164
-580.21466
-579.953
-579.72296
-579.49524
-579.26764
-579.0636
-578.91895
-578.6887
-578.5043
-578.35565
-578.21326
-578.06964
-577.96515
-577.843
-577.7491
-577.7085
-577.6126
-577.50977
-577.4608
-577.3866
-577.36285
-577.2839
-577.25745
-577.24945
-577.2052
-577.1188
-577.0862
-577.0678
-577.0211
-576.99084
-576.9856
-576.9595
-576.9342
-576.9092
-576.88696
-576.865
-576.8428
-576.8224
-576.8032
-576.78516
-576.7652
-576.74866
-576.73145
-576.716
-576.7004
-576.68494
-576.6708
-576.65674
-576.64294
-576.63074
-576.61896
-576.6061
-576.59454
-576.5834
-576.5721
-576.56177
-576.5514
-576.54205
-576.5316
-576.52216
-576.5133
-576.5047
-576.4957
-576.4874
-576.4801
-576.4719
-576.4646
-576.45605
-576.4493
-576.4429
-576.43616
-576.4296
-576.4227
-576.4164
-576.4103
-576.4043
-576.3988
-576.39246
-576.38696
-576.3812
-576.37616
-576.37115
-576.3656
-576.36096
-576.3557
-576.3507
-576.3457
-576.34155
-576.337

# Part 2 (20 %)
Suppose that instead of investing your wealth entirely in the stock market, you have the option to allocate a fraction $\alpha$ of your wealth to the stock market, with the remainder invested in a risk-free savings account that pays a gross return of 1.04 (i.e., 4 % net interest). Notice that $\alpha$ is bounded below by 0 and above by 1.

Solve for the optimal consumption ($c$) and asset allocation ($\alpha$).

- Print the average sum of discounted rewards (utilities) using 1 million simulations.

 - Plot the average consumption-wealth ratio ($c / x$) for each time period $t=0, 1, ..., 49$

 - Plot the average asset allocation in the risky asset ($\alpha$) for each time period $t=0, 1, ..., 49$

Hint: Starting from the code of the previous assignment, the modifications you have to implement are minimal. Namely:

- The output of the neural network should now be a 2-D vector, corresponding to the consumption-wealth ratio ($c / x$) and $\alpha$, respectively


In [None]:
import jax
import jax.numpy as jnp

def scan_fn(carry, x):
    """One scan step: double the carry and emit (doubled carry + current input)."""
    doubled = 2 * carry
    emitted = doubled + x
    return doubled, emitted


# Toy walk-through of jax.lax.scan: the carry starts at 1 and doubles each step.
init_state = jnp.asarray(1.0)
input_sequence = jnp.arange(1.0, 5.0)

final_state, result_sequence = jax.lax.scan(scan_fn, init_state, input_sequence)

for label, value in (("Final State:", final_state), ("Result Sequence:", result_sequence)):
    print(label, value)


Final State: 16.0
Result Sequence: [ 3.  6. 11. 20.]


In [None]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk
import matplotlib.pyplot as plt


γ = 2.
β = 0.95


def U(c):
    """CRRA per-period utility c**(1 - γ) / (1 - γ)."""
    return c**(1 - γ) / (1 - γ)


optimizer = optax.adam
lr = 1e-3
T = 50
# Gross risk-free return. The statement's "1.04 % gross return" is read as a
# gross return of 1.04 (4 % net; cf. R = 1 for zero interest in Part 1).
# NOTE(review): the earlier value 1.0104 read it as 1.04 % net -- confirm.
rf_return = 1.04
batch_size = 128             # wealth paths averaged per gradient step
n_eval_paths = 1024          # wealth paths averaged for each printed utility
num_simulations = 1_000_000  # paths for the final Monte Carlo evaluation
sim_chunk = 100_000          # paths simulated per jitted call (bounds memory)


def nnet(x):
  """Policy network: wealth x -> two raw scores, for c / x and for α."""
  X = jnp.column_stack([x])      # scalar wealth -> (1, 1)
  X = hk.Linear(32)(X)
  X = jnp.tanh(X)
  X = hk.Linear(2)(X)            # (1, 2)
  return jnp.squeeze(X, axis=0)  # (2,): keeps the scan carry (wealth) a scalar


init, nnet = hk.without_apply_rng(hk.transform(nnet))
rng = jax.random.PRNGKey(0)
Θ = init(rng, jnp.array(1.))


opt_state = optimizer(lr).init(Θ)


def policy(Θ, xt):
  """Return (c / x, α) at wealth xt, each squashed into (0, 1) by a sigmoid."""
  raw = nnet(Θ, xt)
  ct_x_ratio = jax.nn.sigmoid(raw[0] - 4.)  # same consumption share as Part 1
  alpha = jax.nn.sigmoid(raw[1])            # enforces 0 < α < 1
  return ct_x_ratio, alpha


def simulate_path(Θ, key):
  """Simulate one wealth path from x_0 = 1 with a fresh stock shock each period.

  Returns:
    (Σ_t β**t U(c_t), c_t / x_t for t < T, α_t for t < T).
  """
  def core(xt, inputs):
    t, key_t = inputs
    ct_x_ratio, alpha = policy(Θ, xt)
    ct = ct_x_ratio * xt
    savings = xt - ct
    # Portfolio gross return: share α in stocks, 1 - α in the risk-free account.
    R = stock_return(key_t) * alpha + rf_return * (1 - alpha)
    x_tp1 = R * savings
    return x_tp1, (β**t * U(ct), ct_x_ratio, alpha)

  step_keys = jax.random.split(key, T)
  _, (discounted_utility, ct_x_ratios, alphas) = jax.lax.scan(
      core, 1., (jnp.arange(T), step_keys))
  return discounted_utility.sum(), ct_x_ratios, alphas


def L(Θ, key=rng, n_paths=batch_size):
  """Loss: minus the mean lifetime utility over `n_paths` simulated paths."""
  path_keys = jax.random.split(key, n_paths)
  G, _, _ = jax.vmap(simulate_path, in_axes=(None, 0))(Θ, path_keys)
  return -G.mean()


@jax.jit
def evaluation(Θ, key=rng):
  """Estimated expected discounted lifetime utility over n_eval_paths paths."""
  return -L(Θ, key, n_eval_paths)


@jax.jit
def update_gradient_descent(Θ, opt_state, key=rng):
  """One Adam step on L using the shocks drawn from `key`."""
  grad = jax.grad(L)(Θ, key)
  updates, opt_state = optimizer(lr).update(grad, opt_state)
  Θ = optax.apply_updates(Θ, updates)
  return Θ, opt_state


# Fresh shocks at every training step; a fixed evaluation key keeps the printed
# values comparable across iterations.
train_key, eval_key = jax.random.split(jax.random.PRNGKey(1))
for iteration in range(100000):
  train_key, step_key = jax.random.split(train_key)
  Θ, opt_state = update_gradient_descent(Θ, opt_state, step_key)

  if iteration % 1000 == 0:
    print('The Utility is', evaluation(Θ, eval_key))


# Simulation: evaluate the trained policy (no further training) on
# num_simulations fresh paths, in chunks of sim_chunk paths to bound memory.
@jax.jit
def simulate_chunk(Θ, key):
  """Means over sim_chunk paths: (lifetime utility, c / x per t, α per t)."""
  keys = jax.random.split(key, sim_chunk)
  G, ct_x_ratios, alphas = jax.vmap(simulate_path, in_axes=(None, 0))(Θ, keys)
  return G.mean(), ct_x_ratios.mean(axis=0), alphas.mean(axis=0)


assert num_simulations % sim_chunk == 0, "num_simulations must be a multiple of sim_chunk"
n_chunks = num_simulations // sim_chunk
chunk_results = [simulate_chunk(Θ, chunk_key)
                 for chunk_key in jax.random.split(jax.random.PRNGKey(2), n_chunks)]
# All chunks have the same size, so the mean of chunk means is the overall mean.
avg_rewards = sum(float(g) for g, _, _ in chunk_results) / n_chunks
avg_ct_x_ratios = jnp.stack([c for _, c, _ in chunk_results]).mean(axis=0)
avg_alphas = jnp.stack([a for _, _, a in chunk_results]).mean(axis=0)

print("Average Sum of Discounted Rewards:", avg_rewards)

# Plot average consumption-wealth ratio (c / x) for each time period
fig, ax = plt.subplots()
ax.plot(jnp.arange(T), avg_ct_x_ratios)
ax.set_xlabel("Time Period (t)")
ax.set_ylabel("Average Consumption-Wealth Ratio (c / x)")
ax.set_title("Average Consumption-Wealth Ratio Over Time")
plt.show()

# Plot average asset allocation α for each time period
fig, ax = plt.subplots()
ax.plot(jnp.arange(T), avg_alphas)
ax.set_xlabel("Time Period (t)")
ax.set_ylabel("Average Asset Allocation (α)")
ax.set_title("Average Asset Allocation Over Time")
plt.show()

TypeError: ignored