145 changes: 85 additions & 60 deletions lectures/ifp_dl.md
@@ -231,10 +231,10 @@ class Config(NamedTuple):
"""
seed: int = 1234 # Seed for network initialization
epochs: int = 400 # No of training epochs
path_length: int = 320 # Length of each consumption path
path_length: int = 200 # Length of each consumption path
layer_sizes: tuple = (1, 6, 6, 6, 1) # Network layer sizes
learning_rate: float = 0.001 # Constant learning rate
num_paths: int = 220 # Number of paths to average over
num_paths: int = 100 # Number of paths to average over
```

We use a class called `LayerParams` to store parameters representing a single
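
The `LayerParams` definition itself lies outside the changed lines, so it does not appear in this diff. A minimal sketch consistent with how it is used in `initialize_layer` below (a named tuple holding one weight-bias pair) would be:

```python
from typing import NamedTuple
import jax.numpy as jnp

class LayerParams(NamedTuple):
    W: jnp.ndarray   # weight matrix, shape (in_dim, out_dim)
    b: jnp.ndarray   # bias vector, shape (out_dim,)
```
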
@@ -255,9 +255,13 @@ The following function initializes a single layer of the network using Le Cun
initialization.
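
Only the new signature and docstring of `initialize_layer` appear in this hunk; the body is elided. For reference, LeCun initialization draws each weight from a normal distribution with variance `1/in_dim`. A sketch of such a body, using the `LayerParams` container sketched above and not necessarily matching the lecture's exact code:

```python
import jax
import jax.numpy as jnp

def initialize_layer_sketch(in_dim, out_dim, key):
    # LeCun normal: standard deviation sqrt(1 / in_dim), biases start at zero
    s = jnp.sqrt(1.0 / in_dim)
    W = s * jax.random.normal(key, (in_dim, out_dim))
    b = jnp.zeros(out_dim)
    return LayerParams(W=W, b=b)
```
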

```{code-cell} ipython3
def initialize_layer(in_dim, out_dim, key):
def initialize_layer(
in_dim: int, # Input dimension for the layer
out_dim: int, # Output dimension for the layer
key: jax.Array # Random key for initialization
):
"""
Initialize weights and biases for a single layer of a the network.
Initialize weights and biases for a single layer of the network.
Use LeCun initialization.

"""
@@ -271,10 +275,13 @@ The next function builds an entire network, as represented by its parameters, by
initializing layers and stacking them into a list.
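
The body is again elided in the hunk below. Stacking the layers typically amounts to splitting the random key once per layer and pairing consecutive entries of `layer_sizes`. Building on the sketches above (and not necessarily matching the lecture's exact code):

```python
import jax

def initialize_network_sketch(key, layer_sizes):
    keys = jax.random.split(key, len(layer_sizes) - 1)
    return [
        initialize_layer_sketch(m, n, k)
        for m, n, k in zip(layer_sizes[:-1], layer_sizes[1:], keys)
    ]
```

For example, `initialize_network_sketch(jax.random.PRNGKey(0), (1, 6, 6, 6, 1))` would return a list of four `LayerParams` pairs.
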

```{code-cell} ipython3
def initialize_network(key, layer_sizes):
def initialize_network(
key: jax.Array, # Random key for initialization
layer_sizes: tuple # Layer sizes (input, hidden..., output)
):
"""
Build a network by initializing all of the parameters.
A network is a list of LayerParams instances, each
A network is a list of LayerParams instances, each
containing a weight-bias pair (W, b).

"""
@@ -299,10 +306,10 @@ Here's a function to train the network by gradient ascent, given a generic loss
function.

```{code-cell} ipython3
@partial(jax.jit, static_argnames=('config', 'loss_fn'))
def train_network(
config: Config, # Configuration object with training parameters
loss_fn: callable, # Loss function taking params and returning loss
print_interval: int = 100 # How often to print progress
config: Config, # Configuration with training parameters
loss_fn: callable # Loss function taking params, returning loss
):
"""
Train a neural network using policy gradient ascent.
@@ -323,29 +330,41 @@ def train_network(
)
opt_state = optimizer.init(params)

# Training loop
value_history = []
best_value = -jnp.inf
best_params = params
# Training loop state
def step(i, state):
params, opt_state, best_value, best_params, value_history = state

for i in range(config.epochs):
# Compute value and gradients at existing parameterization
loss, grads = jax.value_and_grad(loss_fn)(params)
lifetime_value = - loss
value_history.append(lifetime_value)
lifetime_value = -loss

# Update value history
value_history = value_history.at[i].set(lifetime_value)

# Track best parameters
if lifetime_value > best_value:
best_value = lifetime_value
best_params = params
is_best = lifetime_value > best_value
best_value = jnp.where(is_best, lifetime_value, best_value)
best_params = jax.tree.map(
lambda new, old: jnp.where(is_best, new, old),
params, best_params
)

# Update parameters using optimizer
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
if i % print_interval == 0:
print(f"Iteration {i}: Value = {lifetime_value:.4f}")

# Use best parameters instead of final
params = best_params
return params, value_history, best_value
return params, opt_state, best_value, best_params, value_history

# Run training loop
value_history = jnp.zeros(config.epochs)
initial_state = (params, opt_state, -jnp.inf, params, value_history)
final_state = jax.lax.fori_loop(
0, config.epochs, step, initial_state
)

# Extract results
_, _, best_value, best_params, value_history = final_state
return best_params, value_history, best_value
```
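
The substantive change in this hunk is that the Python `for` loop, with its list `append` and `if` statement for tracking the best parameters, is replaced by a `jax.lax.fori_loop`, so the whole training loop can run inside the JIT-compiled function. Inside a traced loop, Python-side mutation and branching are unavailable, so the history lives in a preallocated array updated with `.at[i].set(...)`, and the running best is kept with `jnp.where`. A stripped-down illustration of that pattern, with hypothetical names:

```python
import jax
import jax.numpy as jnp

def running_best(n_steps, values):
    """Running maximum of values[0], ..., values[n_steps - 1], tracked
    inside a jitted fori_loop in the same style as the training loop above."""

    def step(i, state):
        best, history = state
        v = values[i]
        is_best = v > best                    # traced boolean, not a Python `if`
        best = jnp.where(is_best, v, best)    # keep whichever value is larger
        history = history.at[i].set(best)     # functional in-place update
        return best, history

    init = (
        jnp.full((), -jnp.inf, dtype=values.dtype),
        jnp.zeros(n_steps, dtype=values.dtype),
    )
    return jax.lax.fori_loop(0, n_steps, step, init)

best, history = jax.jit(running_best, static_argnums=0)(4, jnp.array([1.0, 3.0, 2.0, 5.0]))
```

For whole parameter pytrees, the same `jnp.where` selection is applied leaf by leaf with `jax.tree.map`, as in the `best_params` update above.
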


@@ -398,7 +417,10 @@ Now we provide a function that implements a consumption policy as a neural network
parameters of the network.
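
The body of `forward` is elided in the hunk below; only the new signature and docstring are shown. A typical forward pass for this kind of policy network applies an affine map and a nonlinearity at each hidden layer, then squashes the output into $(0, 1)$ so it can be read as a consumption rate. A sketch, with the activations assumed rather than taken from the lecture:

```python
import jax
import jax.numpy as jnp

def forward_sketch(params, a):
    x = jnp.array([a])                          # the input is the asset level
    for layer in params[:-1]:
        x = jnp.tanh(x @ layer.W + layer.b)     # hidden layers (tanh assumed)
    last = params[-1]
    rate = jax.nn.sigmoid(x @ last.W + last.b)  # output in (0, 1): the rate c/a
    return rate[0]
```
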

```{code-cell} ipython3
def forward(params, a):
def forward(
params: list, # Network parameters (LayerParams list)
a: float # Current asset level
):
"""
Evaluate neural network policy: maps a given asset level a to
consumption rate c/a by running a forward pass through the network.
@@ -424,7 +446,11 @@ network.

```{code-cell} ipython3
@partial(jax.jit, static_argnames=('path_length'))
def compute_lifetime_value(params, cake_eating_model, path_length):
def compute_lifetime_value(
params: list, # Network parameters
cake_eating_model: tuple, # Model parameters (γ, β, R)
path_length: int # Length of simulation path
):
"""
Compute the lifetime value of a path generated from
the policy embedded in params and the initial condition a_0 = 1.
@@ -503,8 +529,13 @@ config = Config(num_paths=1)
# Create a loss function that has params as the only argument
loss_fn = lambda params: loss_function(params, model, config.path_length)

# Warmup to trigger JIT compilation
print("Warming up JIT compilation...")
_ = train_network(config, loss_fn)

start_time = time.time()
params, value_history, best_value = train_network(config, loss_fn)
best_value.block_until_ready()
elapsed = time.time() - start_time

print(f"\nBest value: {best_value:.4f}")
@@ -663,20 +694,16 @@ We approximate this expectation using Monte Carlo.
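
For reference, the update behind the operator is the standard endogenous grid step: fix a point $s$ on the exogenous savings grid, evaluate next-period consumption from the current policy, invert the Euler equation, and read off the endogenous asset level. In generic notation (a sketch; the lecture's own derivation precedes this hunk):

$$
c(s)
= \left( \beta R \, \frac{1}{n} \sum_{i=1}^{n}
    \sigma\big(R s + Y_i\big)^{-\gamma} \right)^{-1/\gamma},
\qquad
a(s) = s + c(s),
$$

where $\sigma$ is the current consumption policy obtained by interpolating `c_in` on `a_in`, the $Y_i = \exp(Z_i)$ are Monte Carlo income draws, and marginal utility is $u'(c) = c^{-\gamma}$.
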
Here is the EGM operator $K$ for the IID case:

```{code-cell} ipython3
def K(c_in, a_in, ifp, s_grid, n_shocks=50):
def K(
c_in: jnp.ndarray, # Current consumption policy on endogenous grid
a_in: jnp.ndarray, # Current endogenous asset grid
ifp: IFP, # IFP model instance
s_grid: jnp.ndarray, # Exogenous savings grid
n_shocks: int = 50 # Number of points for Monte Carlo integration
):
"""
The Euler equation operator for the IFP model with IID shocks using EGM.

Args:
c_in: Current consumption policy on endogenous grid
a_in: Current endogenous asset grid
ifp: IFP model instance
s_grid: Exogenous savings grid
n_shocks: Number of points for Monte Carlo integration

Returns:
c_out: Updated consumption policy
a_out: Updated endogenous asset grid
"""
R, β, γ, z_mean, z_std, z_samples = ifp
y_samples = jnp.exp(z_samples)
@@ -714,20 +741,16 @@ def K(c_in, a_in, ifp, s_grid, n_shocks=50):
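
Time iteration applies the operator $K$ repeatedly until successive consumption policies are numerically indistinguishable. Schematically (a sketch only; the cell below is partially elided in this diff and the lecture's implementation may differ in detail):

```python
import jax.numpy as jnp

def solve_model_sketch(ifp, s_grid, n_shocks=50, tol=1e-5, max_iter=1000):
    # Start from the "consume everything" policy: c = a on the grid.
    c, a = s_grid, s_grid
    for _ in range(max_iter):
        c_new, a_new = K(c, a, ifp, s_grid, n_shocks)
        error = jnp.max(jnp.abs(c_new - c))
        c, a = c_new, a_new
        if error < tol:
            break
    return c, a
```
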
Here's the solver using time iteration:

```{code-cell} ipython3
def solve_model(ifp, s_grid, n_shocks=50, tol=1e-5, max_iter=1000):
def solve_model(
ifp: IFP, # IFP model instance
s_grid: jnp.ndarray, # Exogenous savings grid
n_shocks: int = 50, # Number of income shock realizations
tol: float = 1e-5, # Convergence tolerance
max_iter: int = 1000 # Maximum iterations
):
"""
Solve the IID model using time iteration with EGM.

Args:
ifp: IFP model instance
s_grid: Exogenous savings grid
n_shocks: Number of income shock realizations for integration
tol: Convergence tolerance
max_iter: Maximum iterations

Returns:
c_out: Optimal consumption policy on endogenous grid
a_out: Endogenous asset grid
"""
# Initialize with consumption = assets (consume everything)
a_init = s_grid.copy()
@@ -793,20 +816,17 @@ The key is to simulate paths with IID normal income shocks.
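
The hunk header notes that the key step is simulating paths with IID normal income shocks. A sketch of a single simulated path under the assumed lognormal income specification and assumed timing $a_{t+1} = R(a_t - c_t) + Y_{t+1}$ (a plain Python loop for clarity; the lecture's version below is vectorized and JIT-compiled, and its body is partially elided):

```python
import jax.numpy as jnp

def simulate_path_sketch(params, ifp, z_draws, a0=1.0):
    R, β, γ, z_mean, z_std, z_samples = ifp
    y_path = jnp.exp(z_mean + z_std * z_draws)   # lognormal income draws
    a, consumption = a0, []
    for y in y_path:
        c = forward(params, a) * a               # consumption = rate × assets
        a = R * (a - c) + y                      # budget constraint
        consumption.append(c)
    return jnp.array(consumption)
```
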

```{code-cell} ipython3
@partial(jax.jit, static_argnames=('path_length', 'num_paths'))
def compute_lifetime_value_ifp(params, ifp, path_length, num_paths, key):
def compute_lifetime_value_ifp(
params: list, # Neural network parameters
ifp: IFP, # IFP model instance
path_length: int, # Length of each simulated path
num_paths: int, # Number of paths to simulate for averaging
key: jax.Array # JAX random key for generating income shocks
):
"""
Compute expected lifetime value by averaging over multiple
Compute expected lifetime value by averaging over multiple
simulated paths.

Args:
params: Neural network parameters
ifp: IFP model instance
path_length: Length of each simulated path
num_paths: Number of paths to simulate for averaging
key: JAX random key for generating income shocks

Returns:
Average lifetime value across all simulated paths
"""
R, β, γ, z_mean, z_std, z_samples = ifp

@@ -866,10 +886,15 @@ ifp_loss_fn = lambda params: loss_function_ifp(
params, ifp, config.path_length, config.num_paths, key
)

# Warmup to trigger JIT compilation
print("Warming up JIT compilation...")
_ = train_network(config, ifp_loss_fn)

start_time = time.time()
ifp_params, ifp_value_history, best_ifp_value = train_network(
config, ifp_loss_fn, print_interval=50
config, ifp_loss_fn
)
best_ifp_value.block_until_ready()
elapsed = time.time() - start_time

print(f"\nBest value: {best_ifp_value:.4f}")