-
-
Notifications
You must be signed in to change notification settings - Fork 55
Add JAX implementation to ifp_advanced lecture #705
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
- Renamed "Implementation" section to "Numba Implementation" - Added new "JAX Implementation" section before "Exercises" - Implemented IFP_JAX as NamedTuple for JAX JIT compatibility - Created global utility functions (u_prime, u_prime_inv, R, Y) - Added create_ifp_jax() factory function - Implemented K_jax Coleman-Reffett operator with JAX - Added solve_model_time_iter_jax solver - Included comparison section showing Numba vs JAX solutions - Configured JAX for 64-bit precision - Fixed import conflicts between numba.jit and jax.jit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
Implementation HighlightsWhy NamedTuple?The JAX implementation uses
Code Architecture# Traditional class approach (doesn't work with JAX JIT)
class IFP_JAX:
def __init__(self, ...):
self.γ = γ
self.P = jnp.array(P) # Array attribute prevents hashing
# Our solution: NamedTuple + factory pattern
class IFP_JAX(NamedTuple):
γ: float
P: jnp.ndarray
def create_ifp_jax(...): # Factory function handles construction
return IFP_JAX(γ=γ, P=jnp.array(P), ...)Performance ConsiderationsThe JAX implementation:
Example Usage ComparisonNumba: ifp = IFP()
a_star, σ_star = solve_model_time_iter(ifp, a_init, σ_init)JAX: ifp_jax = create_ifp_jax()
a_star_jax, σ_star_jax = solve_model_time_iter_jax(ifp_jax, a_init_jax, σ_init_jax)The API remains similar, making it easy for students to compare both approaches. |
|
📖 Netlify Preview Ready! Preview URL: https://pr-705--sunny-cactus-210e3e.netlify.app (a532c78) 📚 Changed Lecture Pages: ifp_advanced |
- Add bridging text connecting mathematical equations to code implementation - Add detailed code walkthrough for Coleman-Reffett operator - Add explanation of solver function and convergence - Add economic interpretation of default parameters - Expand interpretation of consumption policy results - Fix grammatical errors (comma splice, missing period) - Rename variables for clarity: a_in→ae_vals, σ_in→c_vals, a_out→ae_out, σ_out→c_out 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
Improved Code ExplanationsThis commit enhances the pedagogical quality of the lecture by adding clearer explanations that connect the mathematical theory to the code implementation: Key Improvements
All additions follow the one-sentence-per-paragraph format used throughout the lecture. |
|
📖 Netlify Preview Ready! Preview URL: https://pr-705--sunny-cactus-210e3e.netlify.app (7c0a00b) 📚 Changed Lecture Pages: ifp_advanced |
Summary
This PR adds a comprehensive JAX implementation to the Income Fluctuation Problem II lecture (
ifp_advanced.md), providing an alternative high-performance implementation alongside the existing Numba version.Changes Made
1. Section Restructuring
2. JAX Implementation Details
Core Components:
IFP_JAX: Implemented as aNamedTuple(instead of a regular class) to ensure compatibility with JAX's JIT compilation and hashability requirementsu_prime(c, γ): Marginal utilityu_prime_inv(c, γ): Inverse marginal utilityR(z, ζ, a_r, b_r): Gross return on assetsY(z, η, a_y, b_y): Labor incomecreate_ifp_jax(): Factory function to construct IFP_JAX instances with parameter validationK_jax(): Coleman-Reffett operator using JAX's JIT compilation and vectorizationsolve_model_time_iter_jax(): Time iteration solver for JAX3. Numerical Validation
Added comprehensive comparison section showing:
4. Technical Improvements
jax.config.update("jax_enable_x64", True))jax.jitasjax_jitto avoid overwritingnumba.jitTest Results
Both implementations successfully converge:
Numerical comparison:
The solutions are essentially identical, with minor differences arising from:
Benefits
The JAX implementation offers:
Validation
✅ Script runs successfully with exit code 0
✅ All plots generate correctly
✅ Both implementations produce consistent results
✅ No breaking changes to existing Numba implementation
🤖 Generated with Claude Code