# Introduction to Taylor-mode

In [32]:
import jax.numpy as jnp
import sympy as sp
# load folders from ../src/ as modules
import sys
sys.path.append('../src/') 
from taylor_mode import forward_derivatives

In [33]:
f = lambda x: jnp.sin(x)
val, derivs = forward_derivatives(f, jnp.array(0.0), order=2)
print("value", val)
print("first", derivs[1])
print("second", derivs[2])

value 0.0
first 1.0
second -0.0


In [None]:
from taylor_mode import taylor_series_coefficients
coeffs = taylor_series_coefficients(jnp.sin, jnp.array(0.0), order=3)
print(coeffs)

In [34]:
# Setup and Imports 
import jax, jax.numpy as jnp
from jax.experimental import jet
print("JAX version:", jax.__version__)
jax.config.update("jax_enable_x64", True)     # keep numerics stable

JAX version: 0.6.1


| Mode       | Primitive        | What it returns                                 | Typical use                                                       |
| ---------- | ---------------- | ----------------------------------------------- | ----------------------------------------------------------------- |
| Forward    | `jax.jvp`        | value & 1st-order directional derivative        | Jacobian-vector products                                          |
| Reverse    | `jax.vjp / grad` | value & pullbacks                               | gradients, back-prop                                              |
| **Taylor** | **`jet.jet`**    | *truncated Taylor polynomial* $(f_0,\dots,f_K)$ | *all* higher derivatives along one direction in a **single pass** |


Because it propagates truncated polynomials rather than single tangents/cotangents, Taylor mode avoids the  $K$  
K-fold tracing overhead of nested jacfwd/jacrev and uses far less memory on deep nets

In [35]:
# API Basics
def g(x):                    # any JAX-compatible function
    return jnp.sin(x) * jnp.exp(x)

def closed_d1(x):
    return ( jnp.cos(x) + jnp.sin(x) ) * jnp.exp(x)

x0      = (2.0,)             # primals
series1 = (1.0,)             # first-order offset  (v = 1) 
f0, (f1,) = jet.jet(g, x0, (series1,))
print("f(x0)     =", f0)
print("∂f/∂x|v=1 =", f1)
print("∂f/∂x|v=1 (closed form) =", closed_d1(x0[0]))


f(x0)     = 6.718849697428256
∂f/∂x|v=1 = 3.643917376788898
∂f/∂x|v=1 (closed form) = 3.6439173767888913


We can easily get the second order derivative by chagning series to: 

series1 = (1.0, 0.0)

In [36]:
series1 = (1.0, 0.0)  # first-order offset (v = 1, v' = 0)
f0, (f1, f2) = jet.jet(g, x0, (series1,))

def closed_d2(x):
    return 2 * jnp.exp(x) * jnp.cos(x)
print("f(x0)     =", f0)
print("∂f/∂x|v=1 =", f1)
print("∂²f/∂x²|v=1 =", f2)
print("∂²f/∂x²|v=1 (closed form) =", closed_d2(x0[0]))

f(x0)     = 6.718849697428256
∂f/∂x|v=1 = 3.643917376788898
∂²f/∂x²|v=1 = -6.149864641278714
∂²f/∂x²|v=1 (closed form) = -6.149864641278718


In [37]:
# General rule is to supply a tuple per argument containing v, 1/2 v, ...

In [None]:
# Second order example (Hessian Trace)  
def f(x):
    return jnp.sum(jnp.tanh(x)**2)

x0 = jnp.ones(5)

def second_directional(dir_vec):
    # one Taylor series: first-order = dir_vec, second-order = 0
    _, (_, f2) = jet.jet(f, (x0,), ((dir_vec, jnp.zeros_like(dir_vec)),))
    return f2          # shape (5,) ← directional 2nd-derivative vector

eye = jnp.eye(5)                      # 5 one-hot directions, each (5,)
f2_all = jax.vmap(second_directional)(eye)   # shape (5,5)
laplacian = jnp.sum(f2_all)          # ← this is the trace of the Hessian
print("Laplacian =", laplacian)

# Closed-form check
# one coordinate
x = sp.symbols('x')
f = sp.tanh(x)**2

# second derivative
f_xx = sp.diff(f, x, 2).simplify()

# print closed-form expression
# print("∂²/∂x² tanh(x)² =", sp.simplify(f_xx))
d = sp.symbols('d', integer=True, positive=True)
laplacian = d * f_xx.subs(x, 1)         # evaluate at x = 1
laplacian_simplified = sp.simplify(laplacian)
# print(laplacian_simplified)
import sympy as sp
lap_val = laplacian_simplified.subs(d, 5).evalf()
# print(lap_val)   # ≈ -3.1031334
print("∂²/∂x² tanh(x)² =", lap_val)



Laplacian = -3.108133403856503
∂²/∂x² tanh(x)² = -3.10813340385648


In [60]:
# Performance sanity check
import time, functools
from jax import jacfwd, jit, random

key = random.PRNGKey(0)
W = random.normal(key, (1024, 1024))

def big_fun(x):
    return jnp.tanh(W @ x).sum()

x0 = random.normal(key, (1024,))
# Hessian-vector product with jacfwd(jacrev)
t0 = time.time(); hvp = jacfwd(jax.grad(big_fun))(x0) ; print("nested time", time.time()-t0)
# Same with jet, degree=2
series = ((jnp.eye(1024), jnp.zeros((1024,1024))),)
t0 = time.time(); _, (_, f2) = jet.jet(big_fun, (x0,), series) ; print("jet time", time.time()-t0)


nested time 0.014540910720825195
jet time 0.0019953250885009766


In [61]:
def f_jax(x):
    return jnp.sum(jnp.tanh(x)**2)

batched_fun = jax.vmap(f_jax)

x_batch = jnp.linspace(-1, 1, 128)[:, None] * jnp.ones((128, 5))

# two coefficients (1-st and 2-nd order), both zero – **LIST**, not tuple
series = ([jnp.zeros_like(x_batch),   # v  (order-1)
           jnp.zeros_like(x_batch)],) # w  (order-2)

@jit
def batched_jet(xs, ser):
    vals, (f1, f2) = jet.jet(batched_fun, (xs,), ser)
    return vals, f1          # f1 is order-1 coefficient (“directional grad”)

vals, grads = batched_jet(x_batch, series)
print(vals.shape)   # (128,)
print(grads.shape)  # (128, 5)


(128,)
(128,)


In [51]:
# Taylor series check
# degree-3 jet of exp near 0
f0,(f1,f2,f3) = jet.jet(jnp.exp, (0.0,), ((1.0,0.5,1/6),))
print("exp(0) ≈", f0 + f1 + f2 + f3)          # → 2.718...


exp(0) ≈ 6.166666666666671


In [None]:
# Vector valued output
def F(x): return jnp.stack([jnp.sin(x), jnp.cos(x)])
prim, (g,) = jet.jet(F, (1.0,), ((1.0,),))
print("Jacobian-vector product:", g)           # shows shape (2,)


Jacobian-vector product: [ 0.54030231 -0.84147098]


In [53]:
# May 2025: collapsing Taylor mode propagates one summed K-jet instead of K separate jets, giving another 1.3-2× speed-up on high-order PDE operators