In [16]:
import jax
import jax.numpy as jnp

# $$(a, b) \mapsto a + b$$

dim = 3:  $\implies jac = \left[\begin{array}{ccc|ccc} 
1 & 0 & 0 & 1 & 0 & 0 \\
0 & 1 & 0 & 0 & 1 & 0 \\
0 & 0 & 1 & 0 & 0 & 1 \\
\end{array}\right]$

In [17]:
from jaxtyping import Float, Array
from jax.numpy import array as jar

dim = 3


@jax.jit
def addition(a: Float[Array, "1"], b: Float[Array, "1"]) -> Float[Array, "1"]:
    return a + b


def add_jvp(
    primals: tuple[Float[Array, "1"], Float[Array, "1"]],
    tangents: tuple[Float[Array, "1"], Float[Array, "1"]],
) -> tuple[Float[Array, "1"], Float[Array, "1"]]:
    (a, b) = primals
    (da, db) = tangents

    primal = a + b
    tangent = da + db

    return primal, tangent


inp1, inp2 = jnp.zeros(dim), jnp.zeros(dim)
key = jax.random.PRNGKey(0)

for i in range(10):
    key, new = jax.random.split(key)
    arr = jax.random.normal(new, (dim * 2,))
    da, db = jnp.split(arr, [dim])
    custom = add_jvp((inp1, inp2), (da, db))
    truth = jax.jvp(addition, (inp1, inp2), (da, db))

    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

addition = jax.custom_jvp(addition)
addition.defjvp(add_jvp);

# $$(a, b, \theta) \mapsto \theta a + b$$

dim = 3:  $\implies jac =  \left[\begin{array}{ccc|ccc|c} 
\theta & 0 & 0 & 1 & 0 & 0 & a_1 \\
0 & \theta & 0 & 0 & 1 & 0 & a_2 \\
0 & 0 & \theta & 0 & 0 & 1 & a_3 \\
\end{array}\right]$

In [18]:
dim = 3


@jax.jit
def addition_theta(
    a: Float[Array, "1"], b: Float[Array, "1"], theta: float
) -> Float[Array, "1"]:
    return theta * a + b


def add_jvp_theta(
    primals: tuple[Float[Array, "1"], Float[Array, "1"], float],
    tangents: tuple[Float[Array, "1"], Float[Array, "1"], float],
) -> tuple:
    (a, b, theta) = primals
    (da, db, dtheta) = tangents

    primal = theta * a + b
    tangent = theta * da + db + dtheta * a

    return primal, tangent


key = jax.random.PRNGKey(0)

for i in range(10):
    key, input_key, tangent_key = jax.random.split(key, num=3)

    arr = jax.random.normal(input_key, (dim * 2 + 1,))
    a, b, theta = jnp.split(arr, [dim, dim * 2])

    arr = jax.random.normal(tangent_key, (dim * 2 + 1,))
    da, db, dtheta = jnp.split(arr, [dim, dim * 2])

    custom = add_jvp_theta((a, b, theta), (da, db, dtheta))
    truth = jax.jvp(addition_theta, (a, b, theta), (da, db, dtheta))

    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

addition_theta = jax.custom_jvp(addition_theta)
addition_theta.defjvp(add_jvp_theta)
jnp.column_stack(
    jax.jacfwd(addition_theta, argnums=(0, 1, 2))(
        jar([3.0, 1.0, 1.0]), jar([0.0, 0.0, 0.0]), 2.0
    )
)

Array([[2., 0., 0., 1., 0., 0., 3.],
       [0., 2., 0., 0., 1., 0., 1.],
       [0., 0., 2., 0., 0., 1., 1.]], dtype=float32)

# $$a \mapsto a^\top a$$

dim = 3:   $\implies jac = \left[\begin{array}{ccc} 
2a_1 & 2a_2 & 2a_3
\end{array}\right]$

In [19]:
dim = 3


@jax.jit
def inner(a: Float[Array, "1"]) -> Float[Array, "1"]:
    return a @ a


def inner_jvp(
    primals: tuple[Float[Array, "1"]], tangents: tuple[Float[Array, "1"]]
) -> tuple:
    (a,) = primals
    (da,) = tangents

    primal = a @ a
    tangent = 2 * a @ da

    return primal, tangent


key = jax.random.PRNGKey(0)

for i in range(10):
    key, input_key = jax.random.split(key, num=2)

    arr = jax.random.normal(input_key, (dim * 2,))
    a, da = jnp.split(arr, [dim])

    custom = inner_jvp((a,), (da,))
    truth = jax.jvp(inner, (a,), (da,))

    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

inner = jax.custom_jvp(inner)
inner.defjvp(inner_jvp)
jnp.column_stack(jax.jacfwd(inner, argnums=(0))(jar([3.0, 1.0, 1.0])))

Array([[6., 2., 2.]], dtype=float32)

# $$a \mapsto B a$$

dim = 3:   $\implies jac = \left[\begin{array}{ccc} 
B_{11} & B_{12} & B_{13} \\
B_{21} & B_{22} & B_{23} \\
B_{31} & B_{32} & B_{33} \\
\end{array}\right] = B$ ... makes sense..

In [None]:
dim = 3

B = jar(
    [
        [1.0, 2.0, 1.0],
        [1.0, 2.0, 2.0],
        [1.0, 2.0, 3.0],
    ]
)


@jax.jit
def inv_norm(a: Float[Array, "1"]) -> Float[Array, "1"]:
    return B @ a


def norm_jvp(
    primals: tuple[Float[Array, "1"]], tangents: tuple[Float[Array, "1"]]
) -> tuple:
    (a,) = primals
    (da,) = tangents

    primal = B @ a
    tangent = B @ da

    return primal, tangent


key = jax.random.PRNGKey(0)

for i in range(10):
    key, input_key = jax.random.split(key, num=2)

    arr = jax.random.normal(input_key, (dim * 2,))
    a, da = jnp.split(arr, [dim])

    custom = norm_jvp((a,), (da,))
    truth = jax.jvp(inv_norm, (a,), (da,))

    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

inv_norm = jax.custom_jvp(inv_norm)
inv_norm.defjvp(norm_jvp)
jnp.column_stack(jax.jacfwd(inv_norm, argnums=(0))(jar([3.0, 1.0, 1.0])))

Array([[1., 1., 1.],
       [2., 2., 2.],
       [1., 2., 3.]], dtype=float32)

# $$f := (a, b) \mapsto a^\top b$$

dim = 3:   $\implies jac = \left[\begin{array}{ccc|ccc} 
b_1 & b_2 & b_3 & a_1 & a_2 & a_3 \\
\end{array}\right]$

$\partial f (a, b) (da, db) = b^\top da + a ^\top db$

In [21]:
from jaxtyping import Float, Array
from jax.numpy import array as jar

dim = 3


@jax.jit
def inner(a: Float[Array, "1"], b: Float[Array, "1"]) -> Float[Array, "1"]:
    return a @ b


def add_jvp(
    primals: tuple[Float[Array, "1"], Float[Array, "1"]],
    tangents: tuple[Float[Array, "1"], Float[Array, "1"]],
) -> tuple[Float[Array, "1"], Float[Array, "1"]]:
    (a, b) = primals
    (da, db) = tangents

    primal = a @ b
    tangent = b @ da + a @ db

    return primal, tangent


inp1, inp2 = jnp.zeros(dim), jnp.zeros(dim)
key = jax.random.PRNGKey(0)

for i in range(10):
    key, new = jax.random.split(key)
    arr = jax.random.normal(new, (dim * 2,))
    da, db = jnp.split(arr, [dim])
    custom = add_jvp((inp1, inp2), (da, db))
    truth = jax.jvp(inner, (inp1, inp2), (da, db))

    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

inner = jax.custom_jvp(inner)
inner.defjvp(add_jvp);

# $$f := (B, a) \mapsto B a$$

$\partial f (B, a)(dB, da) = B da + dBa$

In [22]:
dim = 3


@jax.jit
def solve_system(B: Float[Array, "2"], a: Float[Array, "1"]) -> Float[Array, "1"]:
    return B @ a


def lin_map_jvp(
    primals: tuple[Float[Array, "1"]], tangents: tuple[Float[Array, "1"]]
) -> tuple:
    (B, a) = primals
    (dB, da) = tangents

    primal = B @ a

    tangent = dB @ a + B @ da

    return primal, tangent


key = jax.random.PRNGKey(0)

for i in range(10):
    key, vec_key, mat_key = jax.random.split(key, num=3)

    arr = jax.random.normal(vec_key, (dim * 2,))
    a, da = jnp.split(arr, [dim])

    arr = jax.random.normal(vec_key, (dim * 2, dim))
    B, dB = jnp.split(arr, [dim])

    custom = lin_map_jvp(
        (B, a),
        (dB, da),
    )
    truth = jax.jvp(
        solve_system,
        (B, a),
        (dB, da),
    )

    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

solve_system = jax.custom_jvp(solve_system)
solve_system.defjvp(lin_map_jvp);

# $$f := (B, a) \mapsto a^\top B$$

$\partial f (B, a)(dB, da) = da^\top B + a^\top dB$

In [23]:
dim = 3


@jax.jit
def solve_system(B: Float[Array, "2"], a: Float[Array, "1"]) -> Float[Array, "1"]:
    return a.T @ B


def lin_map_jvp(
    primals: tuple[Float[Array, "1"]], tangents: tuple[Float[Array, "1"]]
) -> tuple:
    (B, a) = primals
    (dB, da) = tangents

    primal = a.T @ B

    tangent = a.T @ dB + da.T @ B

    return primal, tangent


key = jax.random.PRNGKey(0)

for i in range(10):
    key, vec_key, mat_key = jax.random.split(key, num=3)

    arr = jax.random.normal(vec_key, (dim * 2,))
    a, da = jnp.split(arr, [dim])

    arr = jax.random.normal(vec_key, (dim * 2, dim))
    B, dB = jnp.split(arr, [dim])

    custom = lin_map_jvp(
        (B, a),
        (dB, da),
    )
    truth = jax.jvp(
        solve_system,
        (B, a),
        (dB, da),
    )

    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

solve_system = jax.custom_jvp(solve_system)
solve_system.defjvp(lin_map_jvp);

# $$f := (B, a) \mapsto a^\top B a$$

$\partial_a f (B, a)(da) = da^\top(B^\top+B) a$

$\partial_B f (B, a)(dB) = \sum dB \circ (aa^\top)$

In [24]:
dim = 3


@jax.jit
def solve_system(B: Float[Array, "2"], a: Float[Array, "1"]) -> Float[Array, "1"]:
    return a @ (B @ a)


def lin_map_jvp(
    primals: tuple[Float[Array, "1"]], tangents: tuple[Float[Array, "1"]]
) -> tuple:
    (B, a) = primals
    (dB, da) = tangents

    primal = a @ (B @ a)

    tangent = da @ ((B.T + B) @ a) + jnp.sum(jnp.outer(a, a) * dB)

    return primal, tangent


key = jax.random.PRNGKey(0)

for i in range(10):
    key, vec_key, mat_key = jax.random.split(key, num=3)

    arr = jax.random.normal(vec_key, (dim * 2,))
    a, da = jnp.split(arr, [dim])

    arr = jax.random.normal(vec_key, (dim * 2, dim))
    B, dB = jnp.split(arr, [dim])

    # print(da @ ((B.T + B) @ a))
    # prim, tangent = jax.jvp(lin_map, primals=(B, a), tangents=(0 * dB, da))
    # print(tangent)

    # print(jnp.sum(dB * jnp.outer(a, a)))
    # prim, tangent = jax.jvp(lin_map, primals=(B, a), tangents=(dB, da * 0))
    # print(tangent)

    # break
    custom = lin_map_jvp(
        (B, a),
        (dB, da),
    )
    truth = jax.jvp(
        solve_system,
        (B, a),
        (dB, da),
    )
    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

solve_system = jax.custom_jvp(solve_system)
solve_system.defjvp(lin_map_jvp);

# $$f := (A, B) \mapsto AB$$

$\partial f (B, a)(dB, da) = dAB  + AdB$

In [25]:
dim = 3


@jax.jit
def solve_system(A: Float[Array, "2"], B: Float[Array, "2"]) -> Float[Array, "1"]:
    return A @ B


def lin_map_jvp(primals, tangents) -> tuple:
    (A, B) = primals
    (dA, dB) = tangents

    primal = A @ B

    tangent = dA @ B + A @ dB

    return primal, tangent


key = jax.random.PRNGKey(0)

for i in range(10):
    key, mat1_key, vec_key = jax.random.split(key, num=3)

    arr = jax.random.normal(mat1_key, (dim * 2, dim))
    A, dA = jnp.split(arr, [dim])

    arr = jax.random.normal(vec_key, (dim * 2, dim))
    B, dB = jnp.split(arr, [dim])

    custom = lin_map_jvp(
        (A, B),
        (dA, dB),
    )
    truth = jax.jvp(
        solve_system,
        (A, B),
        (dA, dB),
    )
    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

solve_system = jax.custom_jvp(solve_system)
solve_system.defjvp(lin_map_jvp);

# $$f := (A, b) \mapsto A^{-1}b = x$$

In [26]:
dim = 3


@jax.jit
def solve_system(A: Float[Array, "2"], b: Float[Array, "1"]) -> Float[Array, "1"]:
    return jax.scipy.linalg.inv(A) @ b


def solve_system_jvp(primals, tangents) -> tuple:
    (A, b) = primals
    (dA, db) = tangents

    A_inv = jax.scipy.linalg.inv(A)

    primal = A_inv @ b

    tangent = -A_inv @ dA @ primal + A_inv @ db

    return primal, tangent


key = jax.random.PRNGKey(0)

for i in range(10):
    key, mat1_key, vec_key = jax.random.split(key, num=3)

    arr = jax.random.normal(mat1_key, (dim * 2, dim))
    A, dA = jnp.split(arr, [dim])

    arr = jax.random.normal(vec_key, (dim * 2,))
    b, db = jnp.split(arr, [dim])

    custom = solve_system_jvp(
        (A, b),
        (dA, db),
    )
    truth = jax.jvp(
        solve_system,
        (A, b),
        (dA, db),
    )
    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

solve_system = jax.custom_jvp(solve_system)
solve_system.defjvp(solve_system_jvp);

# $$f := (A) \mapsto A^{-1}$$

In [27]:
dim = 3


@jax.jit
def solve_system(A: Float[Array, "2"]) -> Float[Array, "2"]:
    return jax.scipy.linalg.inv(A)


def inv_jvp(primals, tangents) -> tuple:
    (A,) = primals
    (dA,) = tangents

    A_inv = jax.scipy.linalg.inv(A)

    primal = A_inv

    tangent = -A_inv @ dA @ A_inv

    return primal, tangent


key = jax.random.PRNGKey(0)

for i in range(10):
    key, mat1_key = jax.random.split(key, num=2)

    arr = jax.random.normal(mat1_key, (dim * 2, dim))
    A, dA = jnp.split(arr, [dim])

    custom = inv_jvp(
        (A,),
        (dA,),
    )
    truth = jax.jvp(
        solve_system,
        (A,),
        (dA,),
    )
    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

solve_system = jax.custom_jvp(solve_system)
solve_system.defjvp(inv_jvp);

# $$f(A) = L$$
###  $$A = LL^\top$$
###  $$I_> \circ L = 0$$

In [28]:
dim = 3


@jax.jit
def chol(A: Float[Array, "2"]) -> Float[Array, "2"]:
    return jax.scipy.linalg.cholesky(A, lower=True)


def chol_jvp(primals, tangents) -> tuple:
    (A,) = primals
    (dA,) = tangents

    L = jax.scipy.linalg.cholesky(A, lower=True)

    left_term = jax.scipy.linalg.solve_triangular(L, dA, lower=True)
    term = jax.scipy.linalg.solve_triangular(L, left_term.T, lower=True).T
    dL = L @ ((jnp.tril(term, k=-1)) + 0.5 * jnp.diag(jnp.diag(term)))

    return L, dL


for i in range(10):
    key, mat1_key = jax.random.split(key, num=2)

    arr = jax.random.normal(mat1_key, (dim * 2, dim))
    A, dA = jnp.split(arr, [dim])
    A = A.T @ A
    dA = dA.T @ dA
    custom = chol_jvp(
        (A,),
        (dA,),
    )
    truth = jax.jvp(
        chol,
        (A,),
        (dA,),
    )
    assert jnp.allclose(custom[0], truth[0]), (custom[0], truth[0])
    assert jnp.allclose(custom[1], truth[1]), (
        custom[1],
        truth[1],
        jnp.linalg.matrix_norm(custom[1] - truth[1]),
    )

chol = jax.custom_jvp(chol)
chol.defjvp(chol_jvp);

# $$f(A^{m \times n}) = (Q^{m \times n},R^{n \times n})$$
###  $$A = QR$$
###  $$I_< \circ R = 0$$
###  $$Q^\top Q = I$$

In [29]:
dim = 10


@jax.jit
def qr_solve(A: Float[Array, "2"]) -> tuple[Float[Array, "2"], Float[Array, "2"]]:
    return jax.scipy.linalg.qr(A, mode="economic")


def qr_solve_jvp(primals, tangents) -> tuple:
    (A,) = primals
    (dA,) = tangents

    primal = jax.scipy.linalg.qr(A, mode="economic")
    Q, R = primal

    L1 = jnp.tril(jax.scipy.linalg.solve(R.T, (Q.T @ dA).T).T, k=-1)
    dR = Q.T @ dA - (L1 - L1.T) @ R
    dQ = jax.scipy.linalg.solve(R.T, (dA - Q @ dR).T).T

    tangent = (dQ, jnp.triu(dR))

    return primal, tangent


for i in range(10):
    key, mat1_key, vec_key = jax.random.split(key, num=3)

    arr = jax.random.normal(mat1_key, (dim * 2, dim - 1))
    A, dA = jnp.split(arr, [dim])

    custom = qr_solve_jvp(
        (A,),
        (dA,),
    )
    truth = jax.jvp(
        qr_solve,
        (A,),
        (dA,),
    )
    assert jnp.allclose(custom[0][0], truth[0][0], atol=1e-05), (
        "\n" + jnp.array_str(custom[0][0]) + "\n" + jnp.array_str(truth[0][0])
    )
    assert jnp.allclose(custom[0][1], truth[0][1], atol=1e-05), (
        "\n" + jnp.array_str(custom[0][1]) + "\n" + jnp.array_str(truth[0][1])
    )
    assert jnp.allclose(custom[1][0], truth[1][0], atol=1e-05), (
        "\n" + jnp.array_str(custom[1][0]) + "\n" + jnp.array_str(truth[1][0])
    )
    assert jnp.allclose(custom[1][1], truth[1][1], atol=1e-05), (
        "\n" + jnp.array_str(custom[1][1]) + "\n" + jnp.array_str(truth[1][1])
    )

qr_solve = jax.custom_jvp(qr_solve)
qr_solve.defjvp(qr_solve_jvp);

### $$f(a) = ||v||$$

In [62]:
dim = 3


@jax.jit
def norm(a: Float[Array, "1"]) -> Float:
    return jnp.linalg.norm(a)


def norm_jvp(
    primals: tuple[Float[Array, "1"]], tangents: tuple[Float[Array, "1"]]
) -> tuple:
    (a,) = primals
    (da,) = tangents

    primal = jnp.linalg.norm(a)
    tangent = a @ da / primal

    return primal, tangent


key = jax.random.PRNGKey(0)

for i in range(10):
    key, input_key = jax.random.split(key, num=2)

    arr = jax.random.normal(input_key, (dim * 2,))
    a, da = jnp.split(arr, [dim])

    custom = norm_jvp((a,), (da,))
    truth = jax.jvp(norm, (a,), (da,))

    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

norm = jax.custom_jvp(norm)
norm.defjvp(norm_jvp);

### $$f(a) = 1/||v||$$

In [64]:
dim = 3


@jax.jit
def inv_norm(a: Float[Array, "1"]) -> Float:
    return 1 / jnp.linalg.norm(a)


def inv_norm_jvp(
    primals: tuple[Float[Array, "1"]], tangents: tuple[Float[Array, "1"]]
) -> tuple:
    (a,) = primals
    (da,) = tangents

    primal = jnp.linalg.norm(a)
    tangent = -a @ da / (primal * a @ a)

    return primal, tangent


key = jax.random.PRNGKey(0)

for i in range(10):
    key, input_key = jax.random.split(key, num=2)

    arr = jax.random.normal(input_key, (dim * 2,))
    a, da = jnp.split(arr, [dim])

    custom = inv_norm_jvp((a,), (da,))
    truth = jax.jvp(inv_norm, (a,), (da,))

    assert jnp.allclose(custom[1], truth[1]), (custom[1], truth[1])

inv_norm = jax.custom_jvp(inv_norm)
inv_norm.defjvp(inv_norm_jvp);

In [None]:
key = jax.random.PRNGKey(0)

In [61]:
numkey, key = jax.random.split(key, num=2)
v = jar([1, 0, 0])

h = 0.000001
dv = jar([3, 0, 0])

print((jnp.linalg.norm(v + h * dv) - jnp.linalg.norm(v)) / h)

v @ dv / jnp.linalg.norm(v)

2.9802322


Array(3., dtype=float32)