
Conversation

@jstac jstac commented Nov 23, 2025

Summary

This PR unifies the Coleman-Reffett operator K signature between the NumPy implementation (cake_eating_egm.md) and the JAX implementation (cake_eating_egm_jax.md).

Changes

  • Updated K operator to take (c_in, x_in, model) parameters and return (c_out, x_out) tuple
  • Updated solve_model_time_iter to take (c_init, x_init) and return (c, x)
  • Applied same signature changes to K_crra and solve_model_crra in the exercises section
  • Updated all function calls and variable names for consistency
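For concreteness, here is a minimal self-contained sketch of the unified pattern. The operator body is a toy placeholder, not the lecture's actual EGM step, and the `Model` field is assumed; only the `(c_in, x_in, model) → (c_out, x_out)` shape is the point:

```python
import numpy as np
from typing import NamedTuple

class Model(NamedTuple):
    β: float = 0.96   # assumed field; the lecture's Model holds more parameters

def K(c_in, x_in, model):
    # Placeholder body: the real operator performs an EGM step via the
    # Euler equation; here we only illustrate the (c_out, x_out) return shape.
    c_out = 0.5 * c_in + 1.0   # toy contraction on the consumption values
    x_out = x_in               # the lecture recomputes the endogenous grid x = s + c
    return c_out, x_out

def solve_model_time_iter(c_init, x_init, model, tol=1e-8, max_iter=1_000):
    # Iterate K to convergence, carrying both arrays through each step
    c, x = c_init, x_init
    for _ in range(max_iter):
        c_new, x_new = K(c, x, model)
        if np.max(np.abs(c_new - c)) < tol:
            return c_new, x_new
        c, x = c_new, x_new
    return c, x
```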

Implementation Details

The efficient JAX implementation is fully preserved:

  • Vectorization using vmap for parallel computation over grid points
  • JIT compilation for optimal performance
  • Use of jax.lax.while_loop for iteration
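The three ingredients above fit together as in the following sketch. The per-grid-point update here is a toy contraction, not the lecture's Euler-equation step; only the `vmap` + `jit` + `lax.while_loop` pattern is illustrative:

```python
import jax
import jax.numpy as jnp

@jax.jit
def solve(c_init, tol=1e-8):
    # Vectorize a per-grid-point update with vmap (toy contraction here)
    step = jax.vmap(lambda c: 0.5 * c + 1.0)

    def cond(state):
        # Keep iterating while the sup-norm error exceeds the tolerance
        c, err = state
        return err > tol

    def body(state):
        c, _ = state
        c_new = step(c)
        return c_new, jnp.max(jnp.abs(c_new - c))

    c, _ = jax.lax.while_loop(cond, body, (c_init, jnp.array(jnp.inf)))
    return c

c_star = solve(jnp.zeros(5))  # fixed point of c ↦ 0.5c + 1 is 2.0
```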

Test Results

✅ Successfully tested with:

  • Maximum absolute deviation from analytical solution: 1.43e-06
  • Execution time: ~0.009 seconds
  • All CRRA exercise tests passing

The changes maintain backward compatibility with the mathematical formulation while ensuring consistency across both lecture implementations.

🤖 Generated with Claude Code

Update the Coleman-Reffett operator K and solver functions in the JAX
implementation to match the signature from the NumPy version:

- K now takes (c_in, x_in, model) and returns (c_out, x_out)
- solve_model_time_iter now takes (c_init, x_init) and returns (c, x)
- Applied same changes to K_crra and solve_model_crra in exercises

The efficient JAX implementation is fully preserved:
- Vectorization with vmap
- JIT compilation
- Use of jax.lax.while_loop

Tested successfully with maximum deviation of 1.43e-06 from analytical
solution and execution time of ~0.009 seconds.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
jstac commented Nov 23, 2025

Additional Context

This change addresses a parameter signature inconsistency that was introduced in the recent refactoring (PR #730). By unifying the signatures, we ensure:

  1. Consistency: Both lectures now use the same function signature pattern
  2. Clarity: The parameter names (c_in, x_in, c_out, x_out) clearly indicate consumption and endogenous grid values
  3. Maintainability: Future updates to one lecture can be more easily applied to the other

The JAX implementation benefits are retained:

  • Performance: JIT compilation and vectorization via vmap
  • Clarity: Same mathematical structure as the NumPy version
  • Efficiency: Sub-10ms solve time with high accuracy

All changes have been tested by converting to Python with jupytext and executing successfully.

jstac and others added 2 commits November 24, 2025 04:33
The α parameter is already defined with a default value (0.4) in the
create_model function, so there's no need to set it as a global variable
and pass it explicitly.

Simplified:
- model = create_model() instead of α = 0.4; model = create_model(α=α)
- model_crra = create_model() in the CRRA exercise section

Tested successfully with same results as before.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
The α parameter doesn't need to be passed explicitly to create_model
since it already has a default value of 0.4. The α = 0.4 line is still
needed for the lambda function closures (f and f_prime capture it).

Changed:
- create_model(u=u, f=f, α=α, ...)
+ create_model(u=u, f=f, ...)

Tested successfully with same convergence behavior.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
jstac commented Nov 23, 2025

Code Cleanup

Added two commits to clean up redundant alpha parameter usage in both lectures:

JAX version (cake_eating_egm_jax.md):

  • Removed unnecessary global α = 0.4 variable
  • Since f and f_prime take α as an explicit parameter, no global is needed
  • Simplified to create_model() which uses the default α = 0.4

NumPy version (cake_eating_egm.md):

  • Removed redundant α=α from create_model() call
  • Kept α = 0.4 definition since the lambda closures f = lambda k: k**α capture it
  • Still uses default parameter value from create_model()
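The closure point is easy to see in a small self-contained sketch (the lambda bodies match the lecture; the comment about create_model is the rationale described above):

```python
α = 0.4
f = lambda k: k**α                  # closure: reads the global α at call time
f_prime = lambda k: α * k**(α - 1)  # likewise captures α

# Because f and f_prime capture α, the global definition must stay even
# though create_model() no longer receives α explicitly (it defaults to 0.4).
```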

Both versions tested successfully with identical results to before. This makes the code cleaner and more consistent.

@github-actions

📖 Netlify Preview Ready!

Preview URL: https://pr-732--sunny-cactus-210e3e.netlify.app (7212a13)

📚 Changed Lecture Pages: cake_eating_egm_jax

Changed from closure-based approach to explicit parameter passing:
- f = lambda k, α: k**α (instead of f = lambda k: k**α with global α)
- f_prime = lambda k, α: α * k**(α - 1)
- Updated K operator to call f(s, α) and f_prime(s, α)

This makes the NumPy version consistent with the JAX implementation and
ensures the α stored in the Model is actually used in the K operator
(previously it was unpacked but unused).

Benefits:
- Consistency between NumPy and JAX versions
- Clearer function dependencies (α is an explicit parameter)
- Actually uses model.α instead of relying on closure

Tested successfully with same convergence behavior.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

jstac commented Nov 23, 2025

Function Parameter Refactoring

Added another commit to improve consistency and fix an issue where model.α was stored but unused:

Problem:
In the NumPy version, the K operator unpacked α from the model but never used it. Instead, f and f_prime were defined as closures that captured the global α = 0.4.

Solution:
Refactored to match the JAX version's approach:

  • f = lambda k, α: k**α (takes α as explicit parameter)
  • f_prime = lambda k, α: α * k**(α - 1)
  • K operator now calls f(s, α) and f_prime(s, α)
  • Removed global α = 0.4 (no longer needed)
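In explicit-parameter style, the same functions become (the final two lines are a hypothetical fragment of how the K operator uses the α unpacked from the model, not the lecture's full operator):

```python
α = 0.4  # now only a value passed around, not a captured global

f = lambda k, α: k**α                 # α is an explicit parameter
f_prime = lambda k, α: α * k**(α - 1)

# Inside the K operator, the α unpacked from the model is actually used:
s = 2.0
y, y_prime = f(s, α), f_prime(s, α)
```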

Benefits:

  1. Consistency: Both NumPy and JAX versions now use the same pattern
  2. Clarity: Function dependencies on α are explicit, not via closure
  3. Correctness: The α stored in the model is actually used
  4. Maintainability: Easier to understand and modify

Tested successfully - same convergence behavior as before.

@github-actions

📖 Netlify Preview Ready!

Preview URL: https://pr-732--sunny-cactus-210e3e.netlify.app (2d2e99f)

📚 Changed Lecture Pages: cake_eating_egm_jax

@github-actions

📖 Netlify Preview Ready!

Preview URL: https://pr-732--sunny-cactus-210e3e.netlify.app (1ca20dd)

📚 Changed Lecture Pages: cake_eating_egm, cake_eating_egm_jax

Changed all references from 'grid' to 's_grid' to match the NumPy
implementation and clarify that this is the exogenous savings grid:

- Model.grid → Model.s_grid
- Updated comment: "state grid" → "exogenous savings grid"
- Updated all variable names throughout (K, K_crra, initializations)
- Also renamed loop variable from 'k' to 's' for consistency

This makes the JAX version consistent with the NumPy version's naming
conventions and makes it clearer that we're working with the exogenous
grid for savings (not the endogenous grid for wealth x).

Tested successfully with identical results.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

jstac commented Nov 23, 2025

Naming Consistency: grid → s_grid

Added another commit to improve naming consistency between the two lectures:

Changes:

  • Renamed grid to s_grid throughout the JAX version
  • Updated comment: "state grid" → "exogenous savings grid"
  • Renamed loop variable from k to s in K operators (savings, not capital)

Rationale:
The NumPy version clearly names this s_grid (savings grid) and comments it as the "exogenous savings grid". This distinguishes it from the endogenous wealth grid x = s + c, and makes it clear that we iterate over savings values, not capital or wealth.

Consistency achieved:
Both lectures now use identical naming conventions, making it easier for readers to understand the relationship between the exogenous grid (savings s) and the endogenous grid (wealth x).
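A hypothetical sketch of the resulting naming (the Model fields and the toy policy are assumed for illustration only):

```python
import numpy as np
from typing import NamedTuple

class Model(NamedTuple):
    s_grid: np.ndarray   # exogenous savings grid (formerly `grid`)

model = Model(s_grid=np.linspace(1e-3, 1.0, 5))

# The endogenous wealth grid is built from a consumption policy c evaluated
# on the savings grid: x = s + c (toy policy c = 0.5 * s here)
c = 0.5 * model.s_grid
x_grid = model.s_grid + c
```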

Tested successfully with same results.

Changed remaining instances where k was used instead of s:
- Mathematical notation: x = k + σ(k) → x = s + σ(s)
- Added missing inline comment in K_crra: x_i = s_i + c_i

This completes the transition to using 's' for savings throughout,
maintaining consistency with the exogenous savings grid terminology.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
@github-actions

📖 Netlify Preview Ready!

Preview URL: https://pr-732--sunny-cactus-210e3e.netlify.app (4eac159)

📚 Changed Lecture Pages: cake_eating_egm, cake_eating_egm_jax

@github-actions

📖 Netlify Preview Ready!

Preview URL: https://pr-732--sunny-cactus-210e3e.netlify.app (d90822f)

📚 Changed Lecture Pages: cake_eating_egm, cake_eating_egm_jax

@github-actions

📖 Netlify Preview Ready!

Preview URL: https://pr-732--sunny-cactus-210e3e.netlify.app (1d2f54e)

📚 Changed Lecture Pages: cake_eating_egm, cake_eating_egm_jax

@jstac jstac merged commit ed8e89a into main Nov 23, 2025
1 check passed
@jstac jstac deleted the egm_jax branch November 23, 2025 20:20