Learn JAX's powerful automatic differentiation (autodiff) system through a fun and relatable cake-baking analogy! 🧑‍🍳
This repository contains a tutorial notebook (`.ipynb`) that walks you through core autodiff concepts and practical JAX implementations. We start with the basics of getting gradients and progressively build up to advanced techniques like Jacobians, Hessians, complex number differentiation, and even defining your own custom "secret recipes" for derivatives.
- Intuitive Analogy: Uses cake baking to make potentially complex topics like Jacobians, Hessians, JVPs/VJPs, and custom rules easier to grasp.
- Comprehensive Coverage: Goes beyond `jax.grad` to cover a wide range of JAX's autodiff features.
- Practical Code: Provides runnable code examples within a cohesive narrative.
- Target Audience: Suitable for those new to JAX autodiff or intermediate users looking for a deeper, more practical understanding.
- The Basics: Getting started with gradients (`jax.grad`), handling different input/output structures (PyTrees), and verifying results. Also covers `jax.value_and_grad`.
- Advanced Baking: Dealing with multiple cake properties at once (Jacobians with `jax.jacfwd`/`jax.jacrev`), understanding how the rate of improvement itself changes (Hessians with `jax.hessian`), efficiently calculating directional changes (JVPs/VJPs with `jax.jvp`/`jax.vjp`, and HVPs), controlling gradient flow (`jax.lax.stop_gradient`), and batching gradient computations (`jax.vmap`).
- Exotic Flavors & Vibrations: Exploring how JAX differentiates functions involving complex numbers, understanding the difference between 'smooth' (holomorphic) and 'tricky' (non-holomorphic) cases, and using the right tools accordingly (`grad` with `holomorphic=True`, Jacobians).
- Secret Family Recipes: Teaching JAX custom differentiation rules with `jax.custom_jvp` and `jax.custom_vjp` to overcome limitations, such as fixing numerical instability, enforcing specific baking rules (like gradient clipping), or handling complex iterative processes (like dough maturation) that standard autodiff struggles with.
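The "Secret Family Recipes" item mentions fixing numerical instability and enforcing gradient clipping. Here is a minimal sketch of both, modeled on the patterns in JAX's custom-derivatives tutorial. The function names (`naive_log1pexp`, `log1pexp`, `clip_gradient`) and the clipping bounds are illustrative choices, not code taken from the notebook.

```python
import jax
import jax.numpy as jnp
from functools import partial

# Naive recipe: log(1 + e^x). JAX's derivative of this evaluates to nan
# once e^x overflows to inf (reverse mode ends up computing 0 * inf).
def naive_log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

# custom_jvp: same primal computation, but JAX is given a stable derivative rule.
@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # 1 - 1 / (1 + e^x) is the same derivative, but it tends to 1 instead of nan.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

# custom_vjp: identity on the forward pass, clipped cotangent on the backward pass.
# lo and hi are marked non-differentiable, so they arrive as plain Python values.
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
    return x

def clip_gradient_fwd(lo, hi, x):
    return x, None  # nothing needs to be saved for the backward pass

def clip_gradient_bwd(lo, hi, _residual, g):
    return (jnp.clip(g, lo, hi),)  # one cotangent per differentiable input (just x)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

x = 1000.0
naive_grad = jax.grad(naive_log1pexp)(x)  # nan
stable_grad = jax.grad(log1pexp)(x)       # 1.0

# The upstream cotangent is 5.0, but clip_gradient caps it at 1.0.
clipped = jax.grad(lambda v: 5.0 * clip_gradient(-1.0, 1.0, v))(2.0)  # 1.0
```

Because the custom JVP rule is linear in `x_dot`, JAX can transpose it, which is why `jax.grad` (reverse mode) works on `log1pexp` without a separate VJP rule.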
- Automatic Differentiation (Forward & Reverse Modes)
- `jax.grad`: Computing gradients of scalar functions.
- `jax.value_and_grad`: Computing the function value and gradient together efficiently.
- Higher-Order Derivatives: Stacking `grad` for second, third, etc. derivatives.
- PyTrees: Differentiating with respect to standard Python containers (dicts, lists, tuples).
- `jax.jacfwd`, `jax.jacrev`: Computing full Jacobian matrices using forward and reverse modes.
- `jax.jvp`, `jax.vjp`: Understanding the core Jacobian-Vector Product and Vector-Jacobian Product primitives.
- `jax.hessian`: Computing full Hessian matrices (second derivatives).
- Hessian-Vector Products (HVPs): Efficiently computing the action of the Hessian on a vector (`H @ v`) without forming the full `H`.
- `jax.lax.stop_gradient`: Preventing gradient flow through specific parts of a computation for algorithmic control.
- `jax.vmap`: Combining with `grad` for efficient per-example or batched gradient calculations.
- Complex Number Differentiation: Handling differentiation involving complex numbers, including the `holomorphic=True` argument for `jax.grad`.
- `jax.custom_jvp`: Defining custom forward-mode differentiation rules (and often getting reverse mode automatically).
- `jax.custom_vjp`: Defining custom reverse-mode differentiation rules for fine-grained control or handling complex operations.
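The first few concepts fit in a few lines. The sketch below uses a hypothetical `cake_score` function (not from the notebook) that scores a recipe stored as a dict, so it doubles as a PyTree example. It then stacks `jax.grad` to get higher-order derivatives of a simple polynomial.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar "cake quality" score for a recipe stored as a dict (a PyTree).
def cake_score(recipe):
    sugar, flour = recipe["sugar"], recipe["flour"]
    return -(sugar - 2.0) ** 2 - 0.5 * (flour - 3.0) ** 2 + sugar * flour

recipe = {"sugar": 1.0, "flour": 2.0}

# jax.grad returns gradients with the same dict structure as the input.
grads = jax.grad(cake_score)(recipe)  # sugar -> 4.0, flour -> 2.0

# jax.value_and_grad returns the score and the gradients from a single call.
score, grads_again = jax.value_and_grad(cake_score)(recipe)  # score is 0.5

# Higher-order derivatives: stack grad on a scalar-to-scalar function.
cube = lambda x: x ** 3
d1 = jax.grad(cube)  # 3x^2
d2 = jax.grad(d1)    # 6x
d3 = jax.grad(d2)    # 6
```

Because the recipe is a dict, the gradient comes back keyed by ingredient, so you can read off how much each ingredient moves the score.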
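For the Jacobian-related entries, here is a sketch with a hypothetical `cake_props` function that maps two ingredient amounts to three cake properties. It shows that forward- and reverse-mode Jacobians agree, that `jax.jvp` and `jax.vjp` compute `J @ v` and `u @ J` without materializing `J`, and that a forward-over-reverse HVP matches `H @ v`. The `hvp` helper is a common pattern, not a JAX API function.

```python
import jax
import jax.numpy as jnp

# Hypothetical map from 2 ingredient amounts to 3 cake properties.
def cake_props(x):
    sugar, flour = x[0], x[1]
    return jnp.stack([sugar * flour, sugar ** 2, jnp.sin(flour)])

x = jnp.array([1.0, 2.0])

# Full 3x2 Jacobian via forward mode (column by column) or reverse mode (row by row).
J_fwd = jax.jacfwd(cake_props)(x)
J_rev = jax.jacrev(cake_props)(x)

# JVP: push an input-space direction v through the Jacobian (J @ v).
v = jnp.array([1.0, 0.0])
y, Jv = jax.jvp(cake_props, (x,), (v,))

# VJP: pull an output-space weighting u back through the Jacobian (u @ J).
y2, vjp_fn = jax.vjp(cake_props, x)
u = jnp.array([1.0, 0.0, 0.0])
(uJ,) = vjp_fn(u)

# Hessian of a scalar function built from the same properties.
def scalar_score(x):
    return jnp.sum(cake_props(x))

H = jax.hessian(scalar_score)(x)

# Hessian-vector product: forward-mode JVP of the reverse-mode gradient.
def hvp(f, x, v):
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

Hv = hvp(scalar_score, x, v)
```

As a rule of thumb, `jacfwd` tends to be cheaper for "tall" Jacobians (more outputs than inputs) and `jacrev` for "wide" ones (more inputs than outputs).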
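The remaining tools can be sketched the same way. The function names, the squared-error `loss`, and the toy data below are assumptions for illustration. `stop_gradient` makes `x * stop_gradient(x)` differentiate as if the second factor were a constant. `jax.vmap` over `jax.grad` yields one gradient per example. `holomorphic=True` lets `jax.grad` return the complex derivative of a holomorphic function.

```python
import jax
import jax.numpy as jnp

# stop_gradient: the wrapped value is used in the forward pass
# but treated as a constant when differentiating.
def frozen_square(x):
    return x * jax.lax.stop_gradient(x)  # forward value x**2, derivative x (not 2x)

grad_frozen = jax.grad(frozen_square)(3.0)  # 3.0

# Per-example gradients: map grad over a batch, sharing the weight w.
def loss(w, xb, yb):
    return (w * xb - yb) ** 2

w = 2.0
xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([1.0, 1.0, 1.0])
# in_axes: w is not batched (None); xs and ys are batched along axis 0.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)  # [2, 12, 30]

# Complex numbers: a holomorphic function with complex input and output.
def complex_square(z):
    return z ** 2

dz = jax.grad(complex_square, holomorphic=True)(1.0 + 2.0j)  # 2z = 2+4j
```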