In [1]:
import math
import random
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [14]:
class Node:
    """A scalar value in a computation graph that supports reverse-mode autodiff.

    Every arithmetic operation returns a new ``Node`` that records:
      * ``_prev``     - the operand nodes it was computed from,
      * ``_op``       - a short label for the operation (used by ``draw_dot``),
      * ``_backProp`` - a closure applying the local chain-rule step, i.e. adding
                        this node's ``grad`` (scaled by the local derivative) into
                        its operands' ``grad``.
    Call ``backProp()`` on the final output to populate ``grad`` everywhere.
    Plain numbers mixed into expressions are wrapped in constant ``Node``s.
    """

    def __init__(self, data, _children = (), _op = ''):
        self.data = data
        self.grad = 0
        # Leaves have no operands, so their local backward step is a no-op.
        self._backProp = lambda: None
        self._prev = set(_children)
        self._op = _op

    def __add__(self, other):
        """Return ``self + other``; the local derivative w.r.t. each operand is 1."""
        other = other if isinstance(other, Node) else Node(other)
        sol = Node(self.data + other.data, (self, other), '+')

        def _backProp():
            self.grad += sol.grad
            other.grad += sol.grad
        sol._backProp = _backProp

        return sol

    def __mul__(self, other):
        """Return ``self * other``; each operand's local derivative is the other's value."""
        other = other if isinstance(other, Node) else Node(other)
        sol = Node(self.data * other.data, (self, other), '*')

        def _backProp():
            # Product rule: d(a*b)/da = b, d(a*b)/db = a. Using += keeps this
            # correct for ``x * x``, where both lines update the same node.
            self.grad += other.data * sol.grad
            other.grad += self.data * sol.grad
        sol._backProp = _backProp

        return sol

    def __pow__(self, other):
        """Return ``self ** other`` for a constant int/float exponent.

        Raises:
            TypeError: if ``other`` is not an int or float (a Node exponent would
                need a gradient w.r.t. the exponent, which is not implemented).
        """
        if not isinstance(other, (int, float)):
            raise TypeError("Node ** only supports int/float exponents")
        sol = Node(self.data**other, (self,), f'**{other}')

        def _backProp():
            # Power rule: d(x**k)/dx = k * x**(k-1).
            self.grad += (other * self.data**(other-1)) * sol.grad
        sol._backProp = _backProp

        return sol

    # Backward-compatible alias for the previous (non-operator) method name.
    __exp__ = __pow__

    def relu(self):
        """Return ``max(0, self)``; gradient flows only where the output is positive."""
        sol = Node(0 if self.data < 0 else self.data, (self,), 'ReLU')

        def _backProp():
            self.grad += (sol.data > 0) * sol.grad

        sol._backProp = _backProp

        return sol

    def _topoOrder(self):
        """Return every node reachable from ``self``, each operand before its result.

        This is an iterative post-order DFS with an explicit stack, so deep graphs
        do not hit Python's recursion limit.
        """
        order, visited = [], set()
        stack = [(self, False)]
        while stack:
            node, childrenDone = stack.pop()
            if childrenDone:
                # All of ``node``'s operands have already been appended.
                order.append(node)
            elif node not in visited:
                visited.add(node)
                stack.append((node, True))
                stack.extend((child, False) for child in node._prev if child not in visited)
        return order

    def backProp(self):
        """Fill in ``grad`` for every node reachable from ``self``.

        Seeds ``self.grad`` with 1 and then runs each node's local ``_backProp`` in
        reverse topological order. This ensures a node's gradient is complete
        before it is pushed to its operands. Gradients accumulate with ``+=``, so
        reset them to 0 between calls.
        """
        self.grad = 1
        for v in reversed(self._topoOrder()):
            v._backProp()

    def __neg__(self):
        """Return ``-self``."""
        return self * -1

    def __radd__(self, other):
        """Return ``other + self`` for a non-Node ``other`` (addition commutes)."""
        return self + other

    def __sub__(self, other):
        """Return ``self - other``."""
        return self + (-other)

    def __rsub__(self, other):
        """Return ``other - self`` for a non-Node ``other``."""
        return other + (-self)

    def __rmul__(self, other):
        """Return ``other * self`` for a non-Node ``other`` (multiplication commutes)."""
        return self * other

    def __truediv__(self, other):
        """Return ``self / other``, expressed as ``self * other**-1``."""
        return self * other**-1

    def __rtruediv__(self, other):
        """Return ``other / self`` for a non-Node ``other``."""
        return other * self**-1

    def __repr__(self):
        return f"Node(data={self.data}, grad={self.grad})"

    # Backward-compatible aliases for the previous names, which Python's operator
    # dispatch never calls.
    __revAdd__ = __radd__
    __revSub__ = __rsub__
    __revMul__ = __rmul__
    __div__ = __truediv__
    __revDiv__ = __rtruediv__
    __info__ = __repr__

    def updateWeight(root, lr):
        """Take one gradient-descent step on the leaf nodes of ``root``'s graph.

        Callable as ``Node.updateWeight(root, lr)`` or ``root.updateWeight(lr)``.

        Only leaves (nodes with no operands) are moved. Interior nodes are derived
        values that the next forward pass recomputes; shifting them would leave
        ``root.data`` inconsistent with its inputs. Constant leaves (e.g. the
        ``Node(2)`` created by ``x * 2``) are leaves too, so they are updated as
        well.

        Args:
            root: output node whose graph has just been back-propagated.
            lr: learning rate; each leaf's ``data`` moves by ``-lr * grad``.
        """
        for n in Node._topoOrder(root):
            if not n._prev:
                n.data -= lr * n.grad


In [None]:
from graphviz import Digraph

def trace(root):
    """Collect every node and edge reachable from ``root`` via ``_prev`` links.

    Args:
        root: starting node; any object with a ``_prev`` iterable of operand nodes.

    Returns:
        (nodes, edges): ``nodes`` is a set of every reachable node (including
        ``root``). ``edges`` is a set of ``(child, parent)`` pairs, where ``child``
        is one of ``parent``'s operands.

    An explicit stack replaces recursion, so very deep graphs (long chains of
    operations) do not raise RecursionError.
    """
    nodes, edges = set(), set()
    stack = [root]
    while stack:
        v = stack.pop()
        if v in nodes:
            continue
        nodes.add(v)
        for child in v._prev:
            edges.add((child, v))
            stack.append(child)
    return nodes, edges

def draw_dot(root, format='svg', rankdir='LR'):
    """Build a graphviz ``Digraph`` of the computation graph rooted at ``root``.

    Each node becomes a record box labelled with its value and gradient. A node
    produced by an operation also gets a separate op box (named id + op label)
    that points into its record. Operand records point at the op box of the
    node they feed.

    Args:
        root: output node of the graph to draw.
        format: graphviz output format, passed to ``Digraph``.
        rankdir: layout direction; must be ``'LR'`` or ``'TB'``.

    Returns:
        The unrendered ``Digraph``.
    """
    assert rankdir in ('LR', 'TB')
    graph_nodes, graph_edges = trace(root)
    graph = Digraph(format=format, graph_attr={'rankdir': rankdir})

    for node in graph_nodes:
        uid = str(id(node))
        graph.node(name=uid, label="{ value: %.4f | gradient: %.4f }" % (node.data, node.grad), shape='record')
        if not node._op:
            continue
        op_uid = uid + node._op
        graph.node(name=op_uid, label=node._op)
        graph.edge(op_uid, uid)

    for operand, consumer in graph_edges:
        # Operand edges target the consumer's op box rather than its record.
        graph.edge(str(id(operand)), str(id(consumer)) + consumer._op)

    return graph

In [None]:
# Toy gradient-descent loop on y = relu(2*x + 1): each pass rebuilds the graph,
# back-propagates, saves a PNG of the graph, and steps x against dy/dx.
x = Node(1.0)
iterations = 4
learningRate = 0.1
# Opening an external image viewer on every iteration blocks or fails on
# headless kernels (and under Restart & Run All); the PNGs are written either way.
viewGraphs = False

for i in range(iterations):
    # Fresh graph each step: the constants 2 and 1 become new Node leaves.
    y = (x * 2 + 1).relu()
    y.backProp()
    dot = draw_dot(y)
    dot.render(f"graph_iter_{i+1}", format='png', view=viewGraphs)

    # Snapshot before updateWeight mutates node data, so the printed x and y
    # describe the same forward pass.
    xUsed, yValue = x.data, y.data
    Node.updateWeight(y, learningRate)

    print(f"Iteration {i+1}: x = {xUsed}, y = {yValue}, dy/dx = {x.grad}, next x = {x.data}")

    # Gradients accumulate with +=, so clear them before the next backProp.
    for node in trace(y)[0]:
        node.grad = 0