Unify K operator signature in cake_eating_egm_jax with cake_eating_egm #732
Conversation
Update the Coleman-Reffett operator `K` and solver functions in the JAX implementation to match the signature from the NumPy version:

- `K` now takes `(c_in, x_in, model)` and returns `(c_out, x_out)`
- `solve_model_time_iter` now takes `(c_init, x_init)` and returns `(c, x)`
- Applied the same changes to `K_crra` and `solve_model_crra` in the exercises

The efficient JAX implementation is fully preserved:

- Vectorization with `vmap`
- JIT compilation
- Use of `jax.lax.while_loop`

Tested successfully with a maximum deviation of 1.43e-06 from the analytical solution and an execution time of ~0.009 seconds.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
**Additional Context**

This change addresses a parameter signature inconsistency that was introduced in the recent refactoring (PR #730). By unifying the signatures, we ensure:

The JAX implementation benefits are retained:

All changes have been tested by converting to Python with jupytext and executing successfully.
The α parameter is already defined with a default value (0.4) in the `create_model` function, so there's no need to set it as a global variable and pass it explicitly.

Simplified:

- `model = create_model()` instead of `α = 0.4; model = create_model(α=α)`
- `model_crra = create_model()` in the CRRA exercise section

Tested successfully with the same results as before.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
The α parameter doesn't need to be passed explicitly to `create_model` since it already has a default value of 0.4. The `α = 0.4` line is still needed for the lambda function closures (`f` and `f_prime` capture it).

Changed `create_model(u=u, f=f, α=α, ...)` to `create_model(u=u, f=f, ...)`.

Tested successfully with the same convergence behavior.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
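The closure dependence described above can be seen in a minimal sketch (names chosen to mirror the lecture's pattern, not copied from it): the lambdas capture the *global* `α`, and read it at call time, so the `α = 0.4` assignment must stay even though `create_model` defaults α to 0.4:

```python
α = 0.4                              # still required: the lambdas close over it

f = lambda k: k**α                   # production function, captures global α
f_prime = lambda k: α * k**(α - 1)   # its derivative, also captures α

assert f(2.0) == 2.0 ** 0.4

# closures read α at call time: rebinding the global changes f's behavior
α = 0.5
assert f(2.0) == 2.0 ** 0.5
```

This late-binding behavior is exactly why the closure-based version needs the global assignment, and why the explicit-parameter version discussed later removes that hidden dependency.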
**Code Cleanup**

Added two commits to clean up redundant alpha parameter usage in both lectures:

JAX version ():

NumPy version (

Both versions tested successfully with identical results to before. This makes the code cleaner and more consistent.
📖 Netlify Preview Ready! Preview URL: https://pr-732--sunny-cactus-210e3e.netlify.app (7212a13) 📚 Changed Lecture Pages: cake_eating_egm_jax
Changed from the closure-based approach to explicit parameter passing:

- `f = lambda k, α: k**α` (instead of `f = lambda k: k**α` with global α)
- `f_prime = lambda k, α: α * k**(α - 1)`
- Updated the `K` operator to call `f(s, α)` and `f_prime(s, α)`

This makes the NumPy version consistent with the JAX implementation and ensures the α stored in the Model is actually used in the `K` operator (previously it was unpacked but unused).

Benefits:

- Consistency between the NumPy and JAX versions
- Clearer function dependencies (α is an explicit parameter)
- Actually uses `model.α` instead of relying on a closure

Tested successfully with the same convergence behavior.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
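A hedged sketch of the explicit-parameter style (the `Model` fields and `K_step` helper are illustrative, not the lecture's actual code): α travels through the model and into `f` as an argument, so the value stored in `model.α` is the one actually used:

```python
from typing import NamedTuple, Callable

f = lambda k, α: k**α                  # α is now an explicit argument
f_prime = lambda k, α: α * k**(α - 1)

class Model(NamedTuple):
    f: Callable
    f_prime: Callable
    α: float

def K_step(s, model):
    """Toy stand-in for how the K operator uses the production function:
    α is unpacked from the model and passed explicitly."""
    return model.f(s, model.α), model.f_prime(s, model.α)

model = Model(f=f, f_prime=f_prime, α=0.4)
y, mp = K_step(2.0, model)   # uses model.α, not a global binding
```

Unlike the closure version, changing a global `α` after constructing `model` has no effect here, which is the consistency point the commit is making.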
**Function Parameter Refactoring**

Added another commit to improve consistency and fix an issue where

Problem:

Solution:

Benefits:

Tested successfully - same convergence behavior as before.
📖 Netlify Preview Ready! Preview URL: https://pr-732--sunny-cactus-210e3e.netlify.app (2d2e99f) 📚 Changed Lecture Pages: cake_eating_egm_jax
📖 Netlify Preview Ready! Preview URL: https://pr-732--sunny-cactus-210e3e.netlify.app (1ca20dd) 📚 Changed Lecture Pages: cake_eating_egm, cake_eating_egm_jax
Changed all references from `grid` to `s_grid` to match the NumPy implementation and clarify that this is the exogenous savings grid:

- `Model.grid` → `Model.s_grid`
- Updated comment: "state grid" → "exogenous savings grid"
- Updated all variable names throughout (`K`, `K_crra`, initializations)
- Also renamed the loop variable from `k` to `s` for consistency

This makes the JAX version consistent with the NumPy version's naming conventions and makes it clearer that we're working with the exogenous grid for savings (not the endogenous grid for wealth `x`).

Tested successfully with identical results.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
**Naming Consistency: grid → s_grid**

Added another commit to improve naming consistency between the two lectures:

Changes:

Rationale:

Consistency achieved:

Tested successfully with the same results.
Changed remaining instances where `k` was used instead of `s`:

- Mathematical notation: `x = k + σ(k)` → `x = s + σ(s)`
- Added a missing inline comment in `K_crra`: `x_i = s_i + c_i`

This completes the transition to using `s` for savings throughout, maintaining consistency with the exogenous savings grid terminology.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
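The renamed identity is a one-liner in code (a toy illustration with an assumed, hypothetical consumption policy): each endogenous wealth point is the exogenous savings point plus the consumption chosen there:

```python
import numpy as np

s_grid = np.linspace(0.2, 1.0, 5)   # exogenous savings grid
c = 0.02 * s_grid                   # hypothetical consumption values σ(s)
x = s_grid + c                      # endogenous wealth grid: x_i = s_i + c_i
```

Writing this as `s + σ(s)` rather than `k + σ(k)` is what keeps the savings-grid terminology consistent across both lectures.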
📖 Netlify Preview Ready! Preview URL: https://pr-732--sunny-cactus-210e3e.netlify.app (4eac159) 📚 Changed Lecture Pages: cake_eating_egm, cake_eating_egm_jax
📖 Netlify Preview Ready! Preview URL: https://pr-732--sunny-cactus-210e3e.netlify.app (d90822f) 📚 Changed Lecture Pages: cake_eating_egm, cake_eating_egm_jax
📖 Netlify Preview Ready! Preview URL: https://pr-732--sunny-cactus-210e3e.netlify.app (1d2f54e) 📚 Changed Lecture Pages: cake_eating_egm, cake_eating_egm_jax
**Summary**

This PR unifies the Coleman-Reffett operator `K` signature between the NumPy implementation (`cake_eating_egm.md`) and the JAX implementation (`cake_eating_egm_jax.md`).

**Changes**

- Updated the `K` operator to take `(c_in, x_in, model)` parameters and return a `(c_out, x_out)` tuple
- Updated `solve_model_time_iter` to take `(c_init, x_init)` and return `(c, x)`
- Applied the same changes to `K_crra` and `solve_model_crra` in the exercises section

**Implementation Details**

The efficient JAX implementation is fully preserved:

- `vmap` for parallel computation over grid points
- `jax.lax.while_loop` for iteration

**Test Results**

✅ Successfully tested with:

- Maximum deviation from the analytical solution: 1.43e-06
- Execution time: ~0.009 seconds

The changes maintain backward compatibility with the mathematical formulation while ensuring consistency across both lecture implementations.
🤖 Generated with Claude Code
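The `jax.lax.while_loop` iteration pattern referenced above can be sketched as follows. This is a hedged illustration only: `solve_fixed_point` and its placeholder update `K` are made-up stand-ins for the lecture's real operator, shown just to demonstrate carrying `(c, x)` plus an error through the loop state:

```python
import jax
import jax.numpy as jnp

def solve_fixed_point(c_init, x_init, tol=1e-8, max_iter=1000):
    """Iterate (c, x) -> K(c, x) to convergence using jax.lax.while_loop."""

    def K(c, x):
        # placeholder contraction standing in for the EGM update;
        # its fixed point is c = 2 everywhere
        return 0.5 * c + 1.0, x

    def cond(state):
        i, c, x, err = state
        return jnp.logical_and(err > tol, i < max_iter)

    def body(state):
        i, c, x, _ = state
        c_new, x_new = K(c, x)
        err = jnp.max(jnp.abs(c_new - c))
        return i + 1, c_new, x_new, err

    _, c, x, _ = jax.lax.while_loop(cond, body, (0, c_init, x_init, jnp.inf))
    return c, x

c, x = solve_fixed_point(jnp.zeros(3), jnp.zeros(3))
```

Because `lax.while_loop` requires the carry to keep the same pytree structure and dtypes across iterations, threading `(c, x)` together through the state is the natural way to preserve the unified `(c_in, x_in) -> (c_out, x_out)` signature under JIT.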