In [264]:
import numpy as np
from typing import List
from collections.abc import Callable
import operator as op
import math
from collections import deque, defaultdict

In [248]:
class Operator:
    """Thin callable wrapper: keeps a function and applies it to positional arguments.

    Args:
        op: The callable to apply; may be ``None`` until it is replaced later.
    """

    def __init__(self, op: Callable):
        self.op = op

    def eval(self, *args: List[np.number]):
        """Return ``op(*args)`` for the stored callable."""
        fn = self.op
        return fn(*args)

class Node:
    """A vertex of a scalar computation graph.

    Args:
        op: Callable combining the children's values; wrapped in an `Operator`.
        partials: Zero-argument callable returning the local derivative of this node's
            output with respect to each input (subclasses compute it from ``input_cache``).
        input_nodes: Child nodes in operand order; ``None`` means a leaf.

    A leaf still gets one slot (``N == 1``), so ``input_cache`` and ``gradient`` always
    have at least one entry.
    """

    def __init__(self, op: Callable = None, partials: Callable = None, input_nodes: List['Node'] = None):
        self.input_nodes = input_nodes if input_nodes is not None else []
        self.N = max(len(self.input_nodes), 1)
        self.op = Operator(op)
        self.partials = partials
        # Most recent child values: written by `push`, read by `partials`.
        self.input_cache = np.zeros(self.N)

        self.reset_gradient()
        self.name: str

    def push(self):
        """Evaluate the subgraph rooted here and return this node's value.

        Children are evaluated recursively (a shared child is re-evaluated once per
        consumer) and cached before ``op`` is applied to the cached values.
        """
        for slot, child in enumerate(self.input_nodes):
            self.input_cache[slot] = child.push()
        return self.op.eval(*self.input_cache)

    def reset_gradient(self):
        """Zero the per-input gradient buffer."""
        self.gradient = np.zeros(self.N)

In [249]:
# Operator name -> callable used by the node classes below.
OPS = dict(
    ADD=op.add,
    MULT=op.mul,
    SUB=op.sub,
    DIV=np.divide,  # numpy ufunc, unlike the `operator` functions used for the other arithmetic entries
    EXP=math.exp,
    POW=math.pow,
)

class InputNode(Node):
    """Leaf node whose value is supplied from outside by replacing ``self.op``.

    ``push`` calls the stored operator with no arguments. Its local partial is the
    constant ``[1]``, so after backpropagation ``gradient[0]`` holds the total upstream
    derivative.
    """

    def __init__(self, name: str = ""):
        def unit_partial():
            return np.array([1])

        super().__init__(partials=unit_partial)
        self.name = 'Input_{}'.format(name)

    def push(self):
        # No children to evaluate: the stored operator yields the value directly.
        value = self.op.eval()
        return value

class AddNode(Node):
    """Binary addition node: output = x0 + x1; both local partials are 1."""

    def __init__(self, *args: List[Node]):
        def local_grads():
            return np.array([1, 1])

        super().__init__(op=OPS['ADD'], partials=local_grads, input_nodes=args)
        self.name = 'Add'

class MultNode(Node):
    """Binary multiplication node: output = x0 * x1.

    Local partials are (x1, x0), i.e. the cached inputs in reverse order.
    """

    def __init__(self, *args: List[Node]):
        def swapped_inputs():
            # d(x0*x1)/dx0 = x1 and d(x0*x1)/dx1 = x0.
            return self.input_cache[::-1]

        super().__init__(op=OPS['MULT'], partials=swapped_inputs, input_nodes=args)
        self.name = 'Mult'

class SubNode(Node):
    """Binary subtraction node: output = x0 - x1; local partials are (1, -1)."""

    def __init__(self, *args: List[Node]):
        def local_grads():
            return np.array([1, -1])

        super().__init__(op=OPS['SUB'], partials=local_grads, input_nodes=args)
        self.name = 'Sub'

class DivNode(Node):
    """Binary division node: output = x0 / x1 (via ``OPS['DIV']``).

    Local partials: d/dx0 = 1 / x1 and d/dx1 = -x0 / x1**2.
    """

    def __init__(self, *args: List[Node]):
        def local_grads():
            numerator = self.input_cache[0]
            denominator = self.input_cache[1]
            return [1 / denominator, -numerator / math.pow(denominator, 2)]

        super().__init__(op=OPS['DIV'], partials=local_grads, input_nodes=args)
        self.name = 'Div'

class ExpNode(Node):
    """Unary exponential node: output = exp(x0); the local partial is exp(x0) itself."""

    def __init__(self, *args: List[Node]):
        def local_grads():
            # d/dx exp(x) = exp(x)
            return [math.exp(self.input_cache[0])]

        super().__init__(op=OPS['EXP'], partials=local_grads, input_nodes=args)
        self.name = 'Exp'
    
class PowNode(Node):
    """Binary power node: output = ``math.pow(base, exponent)``; inputs are ``(base, exponent)``.

    Local partials:
        d/d(base)     = exponent * base**(exponent - 1)   (0 when exponent == 0)
        d/d(exponent) = base**exponent * ln(base)         (real only for base > 0;
                                                           0 for base == 0 < exponent;
                                                           NaN otherwise)
    """

    def __init__(self, *args: List[Node]):
        super().__init__(op = OPS['POW'], partials = self._local_partials, input_nodes = args)
        self.name = 'Pow'

    def _local_partials(self):
        """Return ``[d/d(base), d/d(exponent)]`` at the values cached by the last forward pass."""
        base, exponent = self.input_cache[0], self.input_cache[1]

        # Power rule needs the leading `exponent` factor; with exponent == 0 the output is
        # constant in the base, so the slope is exactly 0 (no log term).
        d_base = 0.0 if exponent == 0 else exponent * math.pow(base, exponent - 1)

        # ln(base) is only real for base > 0. Returning NaN instead of letting math.log raise
        # keeps backward usable for e.g. x**2 at negative x, where the base gradient is fine.
        if base > 0:
            d_exponent = math.pow(base, exponent) * math.log(base)
        elif base == 0 and exponent > 0:
            d_exponent = 0.0  # 0**exponent is identically 0 for exponent > 0
        else:
            d_exponent = math.nan
        return [d_base, d_exponent]

In [259]:
class Function:
    """Scalar function defined by a DAG of nodes, with forward evaluation and reverse-mode gradients.

    Args:
        input_nodes: Leaf nodes that receive the positional arguments of `eval`, in order.
        output_node: Root node whose value `eval` returns and from which `backward` starts.

    Attributes:
        inputs: Maps each non-leaf node to the set of distinct nodes it consumes.
        outputs: Maps each non-root node to the set of distinct nodes that consume it.
    """

    def __init__(self, input_nodes: List['InputNode'], output_node: 'Node'):
        self.input_nodes = input_nodes
        self.output_node = output_node
        self.outputs = defaultdict(set)
        self.inputs = defaultdict(set)
        # Number of (parent, operand-slot) edges into each node. Unlike `outputs`, a parent that
        # uses the same child twice (e.g. y * y) counts twice here; `backward` needs that count
        # to know when every upstream contribution to a node has arrived.
        self._fan_out = defaultdict(int)

        # Iterative DFS with a visited set: each node's edges are recorded exactly once,
        # shared subgraphs are not re-walked, and deep graphs don't hit the recursion limit.
        visited = set()
        stack = [output_node]
        while stack:
            cur = stack.pop()
            if cur in visited:
                continue
            visited.add(cur)
            for node in cur.input_nodes:
                self.inputs[cur].add(node)
                self.outputs[node].add(cur)
                self._fan_out[node] += 1
                stack.append(node)

    # forward pass
    def eval(self, *args: List[np.number]):
        """Bind `args` to the input nodes positionally and return the output node's value.

        Raises:
            ValueError: if the number of arguments differs from the number of input nodes.
        """
        if len(args) != len(self.input_nodes):
            raise ValueError('Length of args mismatch')
        for arg, input_node in zip(args, self.input_nodes):
            # The default argument freezes the current `arg`; a plain closure would see only the last one.
            input_node.op = Operator(lambda a=arg: a)
        return self.output_node.push()

    # backpropagate
    def backward(self):
        """Reverse-mode pass: fill each node's ``gradient`` and print d(output)/d(input) per input.

        Local partials read each node's ``input_cache``, which is populated by `eval`, so call
        `eval` first. Afterwards ``node.gradient[i]`` holds the derivative of the output with
        respect to ``node.input_nodes[i]`` through ``node``. For input nodes ``gradient[0]`` is
        the total derivative of the output with respect to that input.
        """
        # Inputs not reachable from the output have zero gradient; clear any stale value.
        for input_node in self.input_nodes:
            input_node.reset_gradient()

        dq = deque([(self.output_node, 1.0)])
        visits = defaultdict(int)

        while dq:
            cur, upstream = dq.popleft()
            if visits[cur] == 0:
                cur.reset_gradient()
            visits[cur] += 1
            local = cur.partials()
            for i in range(cur.N):
                cur.gradient[i] += local[i] * upstream
            # Forward to children exactly once, after every incoming edge has contributed.
            # The root has no incoming edges but still needs its single visit.
            if visits[cur] == max(self._fan_out[cur], 1):
                for i, child in enumerate(cur.input_nodes):
                    dq.append((child, cur.gradient[i]))

        for input_node in self.input_nodes:
            print(f'{input_node.name}: {input_node.gradient[0]}')

In [268]:
# Build a small graph over inputs a, b, c, d and evaluate it at (a, b, c, d) = (1, 2, 3, 4).
input_A = InputNode('A')
input_B = InputNode('B')
input_C = InputNode('C')
input_D = InputNode('D')

# Shared sub-expressions:
#   s = (a + b) * c,   e = exp(s),   q = d - e / b
# Output:
#   f = exp(c ** a) - ((q + s) * q) / e
add = AddNode(input_A, input_B)
mult = MultNode(add, input_C)           # s
exp = ExpNode(mult)                     # e
div = DivNode(exp, input_B)
sub = SubNode(input_D, div)             # q
pow_node = PowNode(input_C, input_A)    # c ** a  (not named `pow`, which would shadow the builtin)
add2 = AddNode(sub, mult)
mult2 = MultNode(add2, sub)
div2 = DivNode(mult2, exp)
exp2 = ExpNode(pow_node)
sub2 = SubNode(exp2, div2)              # f

func = Function([input_A, input_B, input_C, input_D], sub2)
result = func.eval(1, 2, 3, 4)

# Sanity-check the graph against the closed form above.
_s = (1 + 2) * 3
_e = math.exp(_s)
_q = 4 - _e / 2
assert math.isclose(result, math.exp(3 ** 1) - ((_q + _s) * _q) / _e)
print(result)

-1997.1918622804708


In [269]:
# Reverse-mode pass over the graph evaluated above; prints d(output)/d(input) for each input node.
func.backward()

Input_A: -6009.596521604817
Input_B: -4054.274192775904
Input_C: -6055.709637746562
Input_D: 0.9979020333305264


In [141]:
class Tensor:
    """Container holding an array value.

    Args:
        value: The wrapped array, stored as-is (no copy).
    """

    def __init__(self, value: np.ndarray):
        # `np.ndarray` is the array type; `np.array` is a factory function and not a valid annotation.
        self.value = value

class Layer:
    """Parameters of a fully connected layer (work in progress: no forward pass yet).

    Args:
        input_size: Number of rows of ``W``.
        output_size: Number of columns of ``W`` and length of ``B``.

    Attributes:
        W: Weights, shape ``(input_size, output_size)``, drawn from ``np.random.rand``.
        B: Biases, shape ``(output_size,)``, initialised to zeros.
    """

    def __init__(self, input_size, output_size):
        self.W = Tensor(np.random.rand(input_size, output_size))
        # One bias per output column. Assuming row inputs are multiplied as x @ W (shape
        # (batch, output_size)) — TODO confirm once forward exists — this broadcasts for
        # any batch size; an (input_size, output_size) bias would only do so when
        # batch == input_size.
        self.B = Tensor(np.zeros(output_size))

class ActivationLayer:
    """Placeholder activation layer; stores no state yet."""

    def __init__(self, input_size):
        """Accept ``input_size`` (currently unused) and build nothing."""

class Model:
    """Placeholder model class; holds no state yet."""

    def __init__(self):
        # `self` is required: without it, `Model()` raises TypeError because Python
        # passes the new instance as the first positional argument.
        pass