In [55]:
class VarNode:
    """A named variable in the computation graph.

    Leaf variables (x, y, z in the test graph) have no children. A VarNode
    that does have children (the loss node `l`) takes its value from its
    first child during the forward pass.

    Args:
        var_name: label used when printing the node.
        value: current value; this notebook uses ``float('inf')`` to mean
            "not computed yet".
        children: nodes this one is computed from. ``None`` (the default)
            creates a fresh empty list for this node.
        parents: nodes that consume this one. ``None`` (the default)
            creates a fresh empty list for this node.
    """

    def __init__(self, var_name, value, children=None, parents=None):
        self.var_name = var_name
        self.value = value
        # `None` defaults instead of `[]`: a list default is evaluated once at
        # definition time, so every node relying on it would share one list.
        self.children = children if children is not None else []
        self.parents = parents if parents is not None else []
        # gradient of the loss w.r.t. this node; starts at 1 and is filled
        # in by the backward pass
        self.gradient = 1

class OperationNode:
    """An operation ('*', '+', '-') applied to its children in the graph.

    Args:
        operation: operator symbol dispatched on by the forward/backward cells.
        value: result of the operation; this notebook uses ``float('inf')``
            to mean "not computed yet".
        children: operand nodes, in order. ``None`` (the default) creates a
            fresh empty list for this node.
        parents: nodes that consume this result. ``None`` (the default)
            creates a fresh empty list for this node.
    """

    def __init__(self, operation, value, children=None, parents=None):
        self.operation = operation
        self.value = value
        # `None` defaults instead of `[]`: a list default is evaluated once at
        # definition time, so every node relying on it would share one list.
        self.children = children if children is not None else []
        self.parents = parents if parents is not None else []
        # gradient of the loss w.r.t. this node; starts at 1 and is filled
        # in by the backward pass
        self.gradient = 1

In [56]:
# test graph for the function l = (x * y) + (y * z), x = 2, y = 3, z = 4
# Every node whose value must be computed starts at float('inf'), the
# "not computed yet" sentinel the forward pass checks before evaluating a node.
x = VarNode('x', 2, [], [])
y = VarNode('y', 3, [], [])
z = VarNode('z', 4, [], [])

mult1 = OperationNode('*', float('inf'), [x, y], [])
mult2 = OperationNode('*', float('inf'), [y, z], [])
add1 = OperationNode('+', float('inf'), [mult1, mult2], [])
loss = VarNode('l', float('inf'), [add1], [])


def _link_parents(node):
    """Register `node` as a parent of each of its children (the reverse edges).

    Deriving parents from children keeps the two edge lists consistent.
    """
    for child in node.children:
        child.parents.append(node)


for node in (mult1, mult2, add1, loss):
    _link_parents(node)

graph = [x, y, z, mult1, mult2, add1, loss]

print(len(x.parents))

1


In [57]:
# do forward pass
# Evaluate each node once all of its children have values (float('inf')
# marks "not computed yet"). A node whose inputs are not ready is re-queued,
# so the result does not depend on `graph` already being in topological
# order. If every pending node is re-queued without progress, the pass
# raises instead of silently leaving values uncomputed.
pending = list(graph)
stalled = 0  # consecutive re-queues without computing anything
while pending:
    node = pending.pop(0)
    if any(child.value == float('inf') for child in node.children):
        pending.append(node)
        stalled += 1
        if stalled >= len(pending):
            raise RuntimeError("forward pass stalled: some node inputs are never computed")
        continue
    stalled = 0

    if isinstance(node, OperationNode):
        values = [child.value for child in node.children]
        if node.operation == '*':
            res = 1
            for v in values:
                res *= v
        elif node.operation == '+':
            res = sum(values)
        elif node.operation == '-':
            # first operand minus the rest: a - b - c ...
            res = values[0] if values else 0
            for v in values[1:]:
                res -= v
        else:
            raise ValueError(f"unsupported operation: {node.operation!r}")
        node.value = res
    elif node.children:  # must be the final loss node
        node.value = node.children[0].value
        



In [58]:
# Inspect the forward pass: one line per node, labelled by its variable name
# (VarNode) or its operator symbol (OperationNode).
for node in graph:
    label = node.var_name if isinstance(node, VarNode) else node.operation
    print(f"{label} value = {node.value}")

x value = 2
y value = 3
z value = 4
* value = 6
* value = 12
+ value = 18
l value = 18


In [59]:
# implement backward pass (reverse-mode differentiation)
# Chain rule: dl/dnode = sum over parents p of dl/dp * dp/dnode.
# Nodes reachable from `loss` are visited in reverse topological order, so a
# node's gradient is complete before it is pushed down to its children.
# Gradients are reset before accumulating, so re-running this cell gives the
# same result.
def _topological_order(root):
    """Return the nodes reachable from `root` via `children`, children first."""
    ordered, seen = [], set()

    def visit(n):
        if n in seen:
            return
        seen.add(n)
        for child in n.children:
            visit(child)
        ordered.append(n)

    visit(root)
    return ordered


topo_order = _topological_order(loss)
for node in topo_order:
    node.gradient = 0
loss.gradient = 1  # dl/dl

for node in reversed(topo_order):
    if isinstance(node, OperationNode):
        if node.operation == '+':
            # d(a + b + ...)/d(operand) = 1 for every operand
            for child in node.children:
                child.gradient += node.gradient
        elif node.operation == '-':
            # differentiated as a - b - ...: +1 for the first operand, -1 for the rest
            for i, child in enumerate(node.children):
                child.gradient += node.gradient if i == 0 else -node.gradient
        elif node.operation == '*':
            # d(a * b * ...)/d(operand_i) = product of the other operands;
            # indexing by position handles repeated operands such as x * x
            for i, child in enumerate(node.children):
                others = 1
                for j, other in enumerate(node.children):
                    if j != i:
                        others *= other.value
                child.gradient += node.gradient * others
        else:
            raise ValueError(f"unsupported operation: {node.operation!r}")
    elif node.children:
        # a VarNode with children (the loss) takes its first child's value,
        # so the gradient passes straight through to that child
        node.children[0].gradient += node.gradient


    

In [60]:
# Inspect the backward pass: one line per node, labelled by its variable name
# (VarNode) or its operator symbol (OperationNode).
for node in graph:
    label = node.var_name if isinstance(node, VarNode) else node.operation
    print(f"{label} gradient = {node.gradient}")

x gradient = 3
y gradient = 6
z gradient = 3
* gradient = 1
* gradient = 1
+ gradient = 1
l gradient = 1
