In [1]:
import numpy as np

The function under study is:

$$ f(x_0, x_1) = \sin(x_0) (x_0 + x_1) $$

or, broken down into elementary operations (each $z_i$ becomes one node of the computation graph below):

$$ \begin{aligned}
z_0 &= x_0 \\
z_1 &= x_1 \\
z_2 &= \sin(z_0) \\
z_3 &= z_0 + z_1 \\
z_4 &= z_2 z_3
\end{aligned} $$

Its gradient, derived symbolically by hand, is:

$$ \nabla f(x_0, x_1) = \begin{bmatrix}
\cos(x_0) (x_0 + x_1) + \sin(x_0) \\
\sin(x_0)
\end{bmatrix} $$

In [2]:
def f(x_0, x_1):
    """Closed-form reference: f(x_0, x_1) = sin(x_0) * (x_0 + x_1)."""
    total = x_0 + x_1
    return np.sin(x_0) * total

def f_grad(x_0, x_1):
    """Hand-derived gradient of ``f`` as a length-2 array ``[df/dx_0, df/dx_1]``."""
    sin_x0 = np.sin(x_0)
    # Product rule on sin(x_0) * (x_0 + x_1).
    partial_x0 = np.cos(x_0) * (x_0 + x_1) + sin_x0
    # x_1 only appears in the sum, whose coefficient is sin(x_0).
    partial_x1 = sin_x0
    return np.array([partial_x0, partial_x1])

In [3]:
# One entry per node, as (operation, argument node indices); a node's position
# in this list is the index other nodes use to refer to it. For "inp" nodes the
# argument names which input x_i the node stands for.
compute_graph = [
    ("inp", (0,)),    # z_0 = x_0
    ("inp", (1,)),    # z_1 = x_1
    ("sin", (0,)),    # z_2 = sin(z_0)
    ("add", (0, 1)),  # z_3 = z_0 + z_1
    ("mul", (2, 3)),  # z_4 = z_2 * z_3
]

In [4]:
# Forward implementation for every operation name a graph node may use.
fn_library = dict(
    inp=lambda value: value,
    sin=lambda angle: np.sin(angle),
    add=lambda lhs, rhs: lhs + rhs,
    mul=lambda lhs, rhs: lhs * rhs,
)

In [5]:
def compute(graph, inputs, library=None):
    """Evaluate a computation graph and return the value of its last node.

    Parameters
    ----------
    graph : iterable of (operation, indices)
        Nodes in topological order; a node's position is its index. ``"inp"``
        nodes take the value ``inputs[indices[0]]``; every other node applies
        ``library[operation]`` to the values of the nodes listed in ``indices``.
    inputs : sequence
        Input values referenced by the ``"inp"`` nodes.
    library : dict, optional
        Maps operation names to forward functions. Defaults to the
        module-level ``fn_library``.

    Returns
    -------
    The value of the last node in ``graph``.

    Raises
    ------
    ValueError
        If ``graph`` is empty.
    """
    graph = list(graph)
    if not graph:
        raise ValueError("graph is empty")
    if library is None:
        library = fn_library

    # values[k] is always the value of node k, so node indices stay valid even
    # when "inp" nodes are not the leading entries or are listed out of order.
    values = []
    for operation, indices in graph:
        if operation == "inp":
            values.append(inputs[indices[0]])
        else:
            args = [values[index] for index in indices]
            values.append(library[operation](*args))
    return values[-1]

In [6]:
# Point (x_0, x_1) at which every implementation below is evaluated and compared.
SAMPLE_INPUT = 0.6, 1.4

In [7]:
# Reference value from the closed-form implementation.
f(*SAMPLE_INPUT)

1.1292849467900707

In [8]:
# Evaluate through the graph interpreter; it must agree with the closed-form f.
graph_value = compute(compute_graph, SAMPLE_INPUT)
assert np.isclose(graph_value, f(*SAMPLE_INPUT)), "graph evaluation disagrees with f"
graph_value

1.1292849467900707

In [9]:
def inp_backprop_rule(x):
    """Backprop rule for an input node: identity forward, identity pullback.

    Returns ``(x, pullback)`` where ``pullback(c)`` returns ``(c,)``.
    """

    def inp_pullback(z_cotangent):
        # The identity has derivative 1, so the cotangent passes straight through.
        return (z_cotangent,)

    return x, inp_pullback

def sin_backprop_rule(x):
    """Backprop rule for ``sin``: returns ``(sin(x), pullback)``.

    The pullback scales the incoming cotangent by ``cos(x)``, the derivative
    of ``sin`` at the forward input, and returns it as a 1-tuple.
    """

    def sin_pullback(z_cotangent):
        return (np.cos(x) * z_cotangent,)

    return np.sin(x), sin_pullback

def add_backprop_rule(x, y):
    """Backprop rule for ``x + y``: returns ``(x + y, pullback)``.

    Both partial derivatives are 1, so the pullback hands the incoming
    cotangent unchanged to each operand as ``(c, c)``.
    """

    def add_pullback(z_cotangent):
        return z_cotangent, z_cotangent

    return x + y, add_pullback

def mul_backprop_rule(x, y):
    """Backprop rule for ``x * y``: returns ``(x * y, pullback)``.

    By the product rule, each operand's cotangent is the incoming cotangent
    scaled by the *other* operand: ``pullback(c) == (y * c, x * c)``.
    """

    def mul_pullback(z_cotangent):
        return y * z_cotangent, x * z_cotangent

    return x * y, mul_pullback

In [10]:
# Backprop rule per operation name; each rule maps forward args to (value, pullback).
backprop_library = dict(
    inp=inp_backprop_rule,
    sin=sin_backprop_rule,
    add=add_backprop_rule,
    mul=mul_backprop_rule,
)

In [11]:
def vjp(graph, inputs, library=None):
    """Evaluate ``graph`` and return its output together with a pullback.

    Reverse-mode AD: the forward pass records, for every non-input node, the
    pullback returned by its backprop rule. The returned ``pullback`` replays
    them in reverse to map an output cotangent to input cotangents.

    Parameters
    ----------
    graph : iterable of (operation, indices)
        Nodes in topological order; a node's position is its index. ``"inp"``
        nodes take the value ``inputs[indices[0]]``; every other node applies
        ``library[operation]`` to the values of the nodes listed in ``indices``.
    inputs : sequence
        Input values referenced by the ``"inp"`` nodes.
    library : dict, optional
        Maps operation names to backprop rules
        ``rule(*args) -> (value, pullback_fn)``. Defaults to the module-level
        ``backprop_library``.

    Returns
    -------
    (output, pullback)
        ``output`` is the value of the last node. ``pullback(output_cotangent)``
        returns a float array of length ``len(inputs)`` holding each input's
        cotangent; seeding with 1.0 yields the gradient of a scalar output.
        Cotangents are stored as one float per node, so node values are
        assumed to be real scalars.

    Raises
    ------
    ValueError
        If ``graph`` is empty.
    """
    graph = list(graph)  # iterated again inside the pullback
    if not graph:
        raise ValueError("graph is empty")
    if library is None:
        library = backprop_library

    values = []
    # (node index, pullback_fn, argument indices) for every non-input node.
    tape = []
    for node, (operation, indices) in enumerate(graph):
        if operation == "inp":
            values.append(inputs[indices[0]])
            continue
        args = [values[index] for index in indices]
        result, pullback_fn = library[operation](*args)
        values.append(result)
        tape.append((node, pullback_fn, indices))

    def pullback(output_cotangent):
        cotangent_values = np.zeros(len(values))
        cotangent_values[-1] = output_cotangent
        # Walk the tape backwards: consumers always come after their arguments
        # in the graph, so a node's cotangent is complete when it is reached.
        for node, pullback_fn, indices in reversed(tape):
            cotangent_args = pullback_fn(cotangent_values[node])
            for index, cotangent in zip(indices, cotangent_args):
                # += because a node may feed several consumers.
                cotangent_values[index] += cotangent
        # Route each "inp" node's cotangent to the input it reads; an input
        # referenced by several "inp" nodes accumulates all their cotangents.
        input_cotangents = np.zeros(len(inputs))
        for node, (operation, indices) in enumerate(graph):
            if operation == "inp":
                input_cotangents[indices[0]] += cotangent_values[node]
        return input_cotangents

    return values[-1], pullback
    

In [12]:
# Run the forward pass once; back_fn maps an output cotangent to the input cotangents.
out, back_fn = vjp(compute_graph, SAMPLE_INPUT)

In [13]:
# Forward value recorded by vjp; compare with f(*SAMPLE_INPUT) above.
out

1.1292849467900707

In [15]:
# Seeding the pullback with 1.0 yields the gradient of the scalar output;
# check it against the hand-derived gradient before displaying it.
vjp_grad = back_fn(1.0)
np.testing.assert_allclose(vjp_grad, f_grad(*SAMPLE_INPUT))
vjp_grad

array([2.2153137 , 0.56464247])

In [16]:
# Hand-derived gradient, for comparison with back_fn(1.0) above.
f_grad(*SAMPLE_INPUT)

array([2.2153137 , 0.56464247])

In [17]:
def value_and_grad(graph, inputs):
    """Return ``(value, gradient)`` of the graph's scalar output.

    Runs ``vjp`` once and seeds its pullback with a unit cotangent (1.0),
    which for a scalar output gives the gradient with respect to ``inputs``.
    """
    value, pullback = vjp(graph, inputs)
    return value, pullback(1.0)

In [18]:
# Autodiff result; assert it matches the closed-form value and gradient.
ad_value, ad_grad = value_and_grad(compute_graph, SAMPLE_INPUT)
np.testing.assert_allclose(ad_value, f(*SAMPLE_INPUT))
np.testing.assert_allclose(ad_grad, f_grad(*SAMPLE_INPUT))
ad_value, ad_grad

(1.1292849467900707, array([2.2153137 , 0.56464247]))

In [19]:
# Reference value and gradient from the hand-written closed-form functions.
f(*SAMPLE_INPUT), f_grad(*SAMPLE_INPUT)

(1.1292849467900707, array([2.2153137 , 0.56464247]))