In [1]:
import numpy as np

# doing by hand

## case 1

In [26]:
def test_1(a=2,b=3,c=4, a_dot=1, b_dot=1, c_dot=1, g_bar=1):
    
    def output(a, b, c):
        d = b+c
        e= a*c
        f= d+e
        g = e/f
        return d,e,f,g
    
    def forward(a,b,c,d,e,f,g,a_dot, b_dot, c_dot):
        d_dot = b_dot + c_dot
        e_dot = a_dot*c + a*c_dot
        f_dot = d_dot + e_dot
        g_dot = 1.0/(f*f)*(e_dot*f - e*f_dot)
        return g_dot
    
    def backward(a,b,c,d,e,f,g,g_bar):
        f_bar = -e/(f*f)*g_bar
        e_bar = 1.0/f*g_bar
        d_bar = f_bar
        e_bar += f_bar
        a_bar = e_bar*c
        c_bar = e_bar*a
        b_bar = d_bar
        c_bar += d_bar
        return a_bar, b_bar, c_bar
    
    def validation(a_dot, b_dot, c_dot, g_dot, a_bar, b_bar, c_bar, g_bar):
        LHS = a_dot*a_bar + b_dot*b_bar + c_dot*c_bar
        RHS = g_dot*g_bar
        print(LHS,RHS,f'error:{abs(LHS-RHS)}')
    
    d,e,f,g = output(a,b,c)
    g_dot = forward(a,b,c,d,e,f,g,a_dot, b_dot, c_dot)
    a_bar, b_bar, c_bar = backward(a,b,c,d,e,f,g,g_bar)
    validation(a_dot, b_dot, c_dot, g_dot, a_bar, b_bar, c_bar, g_bar)
    

In [28]:
test_1(a=2,b=3,c=4, a_dot=1, b_dot=1, c_dot=1, g_bar=1)

0.11555555555555555 0.11555555555555555 error:0.0


## case 2

In [32]:
def test_2(a=2,b=3,c=4, a_dot=1, b_dot=1, c_dot=1, g_bar=1):
    
    def output(a, b, c):
        d = b+c
        e= a*c
        e_ = e
        f= d+e
        g = f*e
        e = np.exp(g)
        g += e
        return d,e,f,g, e_
    
    def forward(a,b,c,d,e,f,g,e_,a_dot, b_dot, c_dot):
        d_dot = b_dot + c_dot
        e_dot = a_dot*c + a*c_dot
        f_dot = d_dot + e_dot
        g_dot = f_dot*e_ + f*e_dot
        e_dot = e*g_dot
        g_dot += e_dot
        return g_dot
    
    def backward(a,b,c,d,e,f,g,e_,g_bar):
        e_bar = g_bar
        g_bar = g_bar
        g_bar += e_bar*e
        f_bar = g_bar*e_
        e_bar = g_bar*f
        d_bar = f_bar
        e_bar += f_bar
        a_bar = e_bar*c
        c_bar = e_bar*a
        b_bar = d_bar
        c_bar += d_bar
        return a_bar, b_bar, c_bar
    
    def validation(a_dot, b_dot, c_dot, g_dot, a_bar, b_bar, c_bar, g_bar):
        LHS = a_dot*a_bar + b_dot*b_bar + c_dot*c_bar
        RHS = g_dot*g_bar
        print(LHS,RHS,f'error:{abs(LHS-RHS)}')
    
    d,e,f,g,e_ = output(a,b,c)
    g_dot = forward(a,b,c,d,e,f,g,e_,a_dot, b_dot, c_dot)
    a_bar, b_bar, c_bar = backward(a,b,c,d,e,f,g,e_,g_bar)
    validation(a_dot, b_dot, c_dot, g_dot, a_bar, b_bar, c_bar, g_bar)
    

In [33]:
test_2(a=2,b=3,c=4, a_dot=1, b_dot=1, c_dot=1, g_bar=3)

6.025315658178581e+54 6.025315658178581e+54 error:0.0


## simple matrix case

In [13]:
def test_3(A,B,A_dot,B_dot,C_bar):
    C = A*B
    
    #forward
    C_dot = A_dot*B + A*B_dot
    
    #adjoint
    A_bar = C_bar*B.T
    B_bar = A.T*C_bar
    
    #validation
    LHS = np.trace(A_bar.T*A_dot + B_bar.T*B_dot)
    RHS = np.trace(C_bar.T*C_dot)
    print(LHS,RHS,f'error:{abs(LHS-RHS)}')
    

In [14]:
A = np.random.randint(0, 3, size=(4, 4))
B = np.random.randint(1, 5, size=(4, 4))
A_dot = np.random.randn(4,4)
B_dot = np.random.randn(4,4)
C_bar = np.random.randn(4,4)
test_3(A,B,A_dot,B_dot,C_bar)

-1.5191489014569166 -1.5191489014569164 error:2.220446049250313e-16


# JAX

## testing with g as our object function

In [2]:
import jax
import jax.numpy as jnp

def output(a, b, c):
    d = b + c
    e = a * c
    f = d + e
    g = e / f
    return g  # Return only g because we're interested in gradients of g

# Suppose we want the gradient of g with respect to a
grad_g_wrt_a = jax.grad(output, argnums=0)

# Suppose we want the gradient of g with respect to b
grad_g_wrt_b = jax.grad(output, argnums=1)

# Suppose we want the gradient of g with respect to c
grad_g_wrt_c = jax.grad(output, argnums=2)

# To compute the gradients, we need to provide actual input values
a, b, c = 1.0, 2.0, 3.0  # Example values, these can be changed as needed

# Calculate the gradients
gradient_wrt_a = grad_g_wrt_a(a, b, c)
gradient_wrt_b = grad_g_wrt_b(a, b, c)
gradient_wrt_c = grad_g_wrt_c(a, b, c)

# Print the results
print('Gradient with respect to a:', gradient_wrt_a)
print('Gradient with respect to b:', gradient_wrt_b)
print('Gradient with respect to c:', gradient_wrt_c)


Gradient with respect to a: 0.234375
Gradient with respect to b: -0.046875
Gradient with respect to c: 0.03125


In [6]:
import jax.numpy as jnp

def output(a, b, c):
    d = b+c
    e= a*c
    f= d+e
    g = f*e
    e = jnp.exp(g)
    g += e
    return g
    

grad_g_wrt_a = jax.grad(output, argnums=0)
grad_g_wrt_b = jax.grad(output, argnums=1)
grad_g_wrt_c = jax.grad(output, argnums=2)

a, b, c = 1.0, 2.0, 3.0  

gradient_wrt_a = grad_g_wrt_a(a, b, c)
gradient_wrt_b = grad_g_wrt_b(a, b, c)
gradient_wrt_c = grad_g_wrt_c(a, b, c)

print('Gradient with respect to a:', gradient_wrt_a)
print('Gradient with respect to b:', gradient_wrt_b)
print('Gradient with respect to c:', gradient_wrt_c)

Gradient with respect to a: 874141000000.0
Gradient with respect to b: 79467364000.0
Gradient with respect to c: 370847700000.0


## jax version case 1

In [29]:
import jax
import jax.numpy as jnp
#test_1_JAX
# Define the function
def f(a, b, c):
    d = b + c
    e = a * c
    f = d + e
    g = e / f
    return g

# Forward mode AD to get the tangents (derivatives of the output w.r.t. an input perturbation)
a, b, c = 2.0, 3.0, 4.0  # Example input values
a_dot, b_dot, c_dot = 1.0, 1.0, 1.0  # Example perturbations

from jax import jacfwd
forward_jacobian = jacfwd(f, (0, 1, 2))
g_dot = forward_jacobian(a, b, c)
g_dot = (g_dot[0] * a_dot + g_dot[1] * b_dot + g_dot[2] * c_dot)

# Reverse mode AD to get the gradients (how the output changes w.r.t. each input)
g_bar = 1.0  # The seed for the reverse mode is typically set to 1
from jax import grad
a_bar = grad(f, 0)(a, b, c) * g_bar
b_bar = grad(f, 1)(a, b, c) * g_bar
c_bar = grad(f, 2)(a, b, c) * g_bar

# Validate that the dot products match
lhs = a_dot * a_bar + b_dot * b_bar + c_dot * c_bar
rhs = g_dot * g_bar

print("LHS:", lhs)
print("RHS:", rhs)
assert jnp.isclose(lhs, rhs), "Validation failed: LHS and RHS do not match!"


LHS: 0.11555557
RHS: 0.11555557


## jax version of case 2

In [7]:
#test_2_JAX
# Define the function
def f(a, b, c):
    d = b+c
    e= a*c
    e_ = e
    f= d+e
    g = f*e
    e = jnp.exp(g)
    g += e
    return g

# Forward mode AD to get the tangents (derivatives of the output w.r.t. an input perturbation)
a, b, c = 1.0, 2.0, 3.0  # Example input values
a_dot, b_dot, c_dot = 1.0, 1.0, 1.0  # Example perturbations

from jax import jacfwd
forward_jacobian = jacfwd(f, (0, 1, 2))
g_dot = forward_jacobian(a, b, c)
g_dot = (g_dot[0] * a_dot + g_dot[1] * b_dot + g_dot[2] * c_dot)

# Reverse mode AD to get the gradients (how the output changes w.r.t. each input)
g_bar = 1.0  # The seed for the reverse mode is typically set to 1
from jax import grad
a_bar = grad(f, 0)(a, b, c) * g_bar
b_bar = grad(f, 1)(a, b, c) * g_bar
c_bar = grad(f, 2)(a, b, c) * g_bar

# Validate that the dot products match
lhs = a_dot * a_bar + b_dot * b_bar + c_dot * c_bar
rhs = g_dot * g_bar

print("LHS:", lhs)
print("RHS:", rhs)
assert jnp.isclose(lhs, rhs), "Validation failed: LHS and RHS do not match!"


LHS: 1324456100000.0
RHS: 1324456100000.0


## jax in matrix

In [8]:
import jax.numpy as jnp
from jax import grad

def matrix_multiply(A, B):
    return jnp.dot(A, B)

A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
B = jnp.array([[5.0, 6.0], [7.0, 8.0]])

C = matrix_multiply(A, B)

def loss_function(A, B):
    C = matrix_multiply(A, B)
    return jnp.sum(C)

grad_loss_wrt_A = grad(loss_function, argnums=0)(A, B)
grad_loss_wrt_B = grad(loss_function, argnums=1)(A, B)

print("C:")
print(C)

print("Gradient with respect to A:")
print(grad_loss_wrt_A)

print("Gradient with respect to B:")
print(grad_loss_wrt_B)

C:
[[19. 22.]
 [43. 50.]]
Gradient with respect to A:
[[11. 15.]
 [11. 15.]]
Gradient with respect to B:
[[4. 4.]
 [6. 6.]]


## jax version of case 3

In [22]:
import jax
import jax.numpy as jnp

# Initialize random matrices for A and B
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (3, 3))
B = jax.random.normal(key, (3, 3))

# Define the matrix multiplication function
def matmul(A, B):
    return A @ B

# Perform reverse mode AD to get A_bar and B_bar
C_bar = jax.random.normal(key, (3, 3))
_, vjp_fun = jax.vjp(matmul, A, B)  # vjp_fun now holds the gradient functions
A_bar, B_bar = vjp_fun(C_bar)  # Call it with C_bar to get the gradients

# Perform forward mode AD to get A_dot and B_dot
A_dot = jax.random.normal(key, (3, 3))
B_dot = jax.random.normal(key, (3, 3))
_, C_dot = jax.jvp(matmul, (A, B), (A_dot, B_dot))

# Validate the trace relationship
lhs = jnp.trace(A_bar.T @ A_dot) + jnp.trace(B_bar.T @ B_dot)
rhs = jnp.trace(C_bar.T @ C_dot)

print('LHS:', lhs)
print('RHS:', rhs)
assert jnp.isclose(lhs, rhs), "Validation failed: LHS and RHS do not match!"


LHS: 1.9607668
RHS: 1.9607666
