In [60]:
def _accumulate(current, incoming):
    """Return ``current + incoming``, treating ``None`` as "no gradient yet"."""
    return incoming if current is None else current + incoming


class Const:
    """Leaf node wrapping a fixed value.

    A constant has nothing to differentiate: ``grad`` stays 0 and
    ``backward`` ignores whatever gradient reaches it.
    """

    def __init__(self, val):
        self.val = val
        self.grad = 0

    def forward(self):
        """Return the wrapped value."""
        return self.val

    def backward(self, prev_grad=1):
        """Ignore the incoming gradient; nothing lies below a constant."""

    def zero_grad(self):
        """No-op: a constant's gradient is always 0."""


class Input:
    """Leaf node whose value is supplied later through :meth:`feed`.

    ``grad`` is ``None`` until a backward pass reaches this node; afterwards it
    holds the gradient of the output w.r.t. this input, summed over every path
    from the output. Gradients accumulate across ``backward()`` calls, so call
    ``zero_grad()`` on the output node before starting a fresh pass.
    """

    def __init__(self):
        self.val = None
        self.grad = None

    def feed(self, x):
        """Set the value returned by :meth:`forward`."""
        self.val = x

    def forward(self):
        """Return the fed value.

        Raises:
            ValueError: if :meth:`feed` has not been called yet.
        """
        if self.val is None:
            raise ValueError('The Input node needs to be initialized.')
        return self.val

    def backward(self, prev_grad=1):
        """Add ``prev_grad`` to this node's gradient.

        Accumulating (instead of overwriting) is what gives an input that is
        used more than once, e.g. ``Mult(a, a)``, the correct total gradient.
        """
        self.grad = _accumulate(self.grad, prev_grad)

    def zero_grad(self):
        """Forget the gradient left by earlier backward passes."""
        self.grad = None


class Add:
    """Node computing ``x + y``; the local derivative w.r.t. each operand is 1."""

    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.val = None  # set by forward()
        self.grad = None

    def forward(self):
        """Evaluate both operands and cache and return their sum."""
        # Cache the value: a parent Mult reads it during its backward pass.
        self.val = self.x.forward() + self.y.forward()
        return self.val

    def backward(self, prev_grad=1):
        """Accumulate ``prev_grad`` here and pass it unchanged to both operands."""
        self.grad = _accumulate(self.grad, prev_grad)
        # Propagate only this call's contribution, not the running total;
        # forwarding self.grad would re-count earlier contributions.
        self.x.backward(prev_grad)
        self.y.backward(prev_grad)

    def zero_grad(self):
        """Reset the gradient of this node and of every node below it."""
        self.grad = None
        self.x.zero_grad()
        self.y.zero_grad()


class Mult:
    """Node computing ``x * y``; d(xy)/dx = y and d(xy)/dy = x."""

    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.val = None  # set by forward()
        self.grad = None

    def forward(self):
        """Evaluate both operands and cache and return their product."""
        self.val = self.x.forward() * self.y.forward()
        return self.val

    def backward(self, prev_grad=1):
        """Accumulate ``prev_grad`` and send ``prev_grad * other`` to each operand.

        Uses the operand values cached by the most recent :meth:`forward`.

        Raises:
            ValueError: if an operand has no value yet (forward() not run).
        """
        if self.x.val is None or self.y.val is None:
            raise ValueError('Call forward() before backward().')
        self.grad = _accumulate(self.grad, prev_grad)
        # Propagate only this call's contribution (see Add.backward).
        self.x.backward(prev_grad * self.y.val)
        self.y.backward(prev_grad * self.x.val)

    def zero_grad(self):
        """Reset the gradient of this node and of every node below it."""
        self.grad = None
        self.x.zero_grad()
        self.y.zero_grad()

In [61]:
def f(a, b, c):
    """Plain-Python reference for f(a, b, c) = (a + b) * c.

    The node graph expresses the same computation in stages:
    q = 1 * a, then r = q + b, and finally f = r * c.
    """
    summed = a + b
    return summed * c

# Leaf inputs; their concrete values are fed in a later cell.
a = Input()
b = Input()
c = Input()

# Build (1*a + b) * c one node at a time.
q = Mult(Const(1), a)  # multiplying by a constant exercises the Const node
r = Add(q, b)
output = Mult(r, c)

In [62]:
# Worked example f(-2, 5, -4) = -12, from
# https://youtu.be/d14TUNcbn1k?t=10m57s
a.feed(-2)
b.feed(5)
c.feed(-4)

# Evaluate the same expression twice: directly in Python, then through the graph.
python_res = f(*(leaf.forward() for leaf in (a, b, c)))
graph_res = output.forward()
print(python_res, graph_res)

-12 -12


In [63]:
"""

--> a -> *1
          \
           +
          / \
     --> b   * -->
            /
      -->  c
"""

'\n\n--> a -> *1\n                     +\n          /      --> b   * -->\n            /\n      -->  c\n'

In [65]:
# Backpropagate from the output, then show each node's gradient on its own line.
output.backward()
print(*(node.grad for node in (output, r, q, a, b, c)), sep='\n')

1
-4
-4
-4
-4
3
