In [252]:
class Value:
    """Scalar node of a computation graph with reverse-mode differentiation.

    Attributes:
        data: The numeric value stored in this node.
        parents: ``()`` for a leaf; otherwise ``(left, right, op)`` where
            ``left`` and ``right`` are the operand nodes and ``op`` is the
            operator symbol (``'+'`` or ``'*'``) that produced this node.
        weight: Gradient of the node on which ``backward()`` was called with
            respect to this node. Starts at ``0.0`` and is accumulated
            (never reset) by ``backward()``.
    """

    def __init__(self, data, parents=()):
        self.data = data
        self.parents = parents
        self.weight = 0.0

    @staticmethod
    def _wrap(operand):
        """Return *operand* unchanged if it is a Value, else wrap it in one."""
        return operand if isinstance(operand, Value) else Value(operand)

    def __add__(self, other):
        """Return a new node holding ``self + other``.

        ``other`` may be a ``Value`` or a plain number.
        """
        other = Value._wrap(other)
        child = Value(self.data + other.data, parents=(self, other, '+'))

        def _backward():
            # d(child)/d(self) == d(child)/d(other) == 1.
            # Local step only: propagation order is handled by backward().
            self.weight += child.weight
            other.weight += child.weight

        child._backward = _backward
        return child

    def __radd__(self, other):
        """Support ``number + Value``."""
        return self + other

    def __mul__(self, other):
        """Return a new node holding ``self * other``.

        ``other`` may be a ``Value`` or a plain number.
        """
        other = Value._wrap(other)
        child = Value(self.data * other.data, parents=(self, other, '*'))

        def _backward():
            # d(child)/d(self) == other.data and d(child)/d(other) == self.data.
            # Local step only: propagation order is handled by backward().
            self.weight += other.data * child.weight
            other.weight += self.data * child.weight

        child._backward = _backward
        return child

    def __rmul__(self, other):
        """Support ``number * Value``."""
        return self * other

    def _backward(self):
        """Push this node's gradient to its direct parents (no-op for leaves).

        Nodes created by an operator get an instance-level replacement that
        applies the local chain-rule step for that operator.
        """

    def _topological_order(self):
        """Return every node reachable from ``self``, parents before children.

        Uses an explicit stack instead of recursion so that very deep graphs
        do not hit the interpreter's recursion limit.
        """
        order = []
        seen = set()
        stack = [(self, False)]
        while stack:
            node, expanded = stack.pop()
            if expanded:
                # All of this node's ancestors have been emitted already.
                order.append(node)
                continue
            if id(node) in seen:
                continue
            seen.add(id(node))
            stack.append((node, True))
            # parents is () or (left, right, op); only the first two are nodes.
            for parent in node.parents[:2]:
                stack.append((parent, False))
        return order

    def backward(self):
        """Accumulate d(self)/d(node) into ``weight`` for every ancestor node.

        Each node's local step runs exactly once, in reverse topological
        order, so a node used by several consumers has received its complete
        gradient before passing it on to its own parents.
        """
        order = self._topological_order()
        self.weight = 1.0
        for node in reversed(order):
            node._backward()

    def __repr__(self) -> str:
        if self.parents:
            left, right, op = self.parents
            return f"Val({self.data}, weight: {self.weight}, parents: ({left.data} {op} {right.data}))"
        return f'Val({self.data}, weight: {self.weight}, parents: None)'

In [253]:
# Build two leaf nodes and their product, then show the new node's repr
# followed by a marker line.
a, b = Value(3.0), Value(4.0)
c = a * b

print(c, 'test', sep='\n')

Val(12.0, weight: 0.0, parents: (3.0 * 4.0))
test


In [254]:
# Extend the graph: scale the product node by a constant leaf.
tripler = Value(3.0)
d = c * tripler

In [255]:
# Start backpropagation through backward(): it seeds e.weight = 1.0 before
# propagating. Calling e._backward() directly leaves e.weight at its initial
# 0.0, so every propagated gradient is zero -- presumably the cause of the
# behaviour previously flagged as a BUG here (TODO confirm).
e = d * Value(-4.0)

e.backward()

# de/da = b * 3.0 * -4.0
a.weight

-48.0

In [256]:
# Dump the parents tuple and accumulated gradient of every node, in creation order.
for label, node in (('a', a), ('b', b), ('c', c), ('d', d), ('e', e)):
    print(f'{label}: {node.parents}, {node.weight}')

a: (), -48.0
b: (), -36.0
c: (Val(3.0, weight: -48.0, parents: None), Val(4.0, weight: -36.0, parents: None), '*'), -12.0
d: (Val(12.0, weight: -12.0, parents: (3.0 * 4.0)), Val(3.0, weight: -48.0, parents: None), '*'), -4.0
e: (Val(36.0, weight: -4.0, parents: (12.0 * 3.0)), Val(-4.0, weight: 36.0, parents: None), '*'), 1.0
