In [1]:
import numpy as np

In [2]:
relu = lambda x: x if x > 0 else 0
drelu = lambda x: x > 0
sigm = lambda x : 1.0 / (1.0 + np.exp(-a.value))
def dsigm(x):
    s = sigm(x)
    return s * (1 - s)

In [3]:
class Node:
    def __init__(self, node_type, g_id, childern, value = None):
        self.node_type = node_type
        self.value = value
        self.childern = childern
        self.grad = 0.0
        self.id = g_id
    
    def __repr__(self):
        return f"Node : {self.node_type} {self.id}"

In [4]:
class ComputationalGraph:
    def __init__(self):
        self.nodes = []
    
    def add_value_node(self, value):
        var = Node(node_type='var', g_id=len(self.nodes), childern=[], value=value)
        self.nodes.append(var)
        return var
    
    def add_binary_node(self, n_type, a, b):
        # n_type = 'add', 'sub', 'mul', 'div', 'pow', 'cost'
        node = Node(node_type=n_type, g_id=len(self.nodes), childern = [a, b])
        self.nodes.append(node)
        print(node.childern)
        return node

    def add_unary_node(self, n_type, a):
        # n_type = 'neg', 'relu', 'tanh', 'sigm'
        node = Node(node_type=n_type, g_id=len(self.nodes), childern = [a])
        self.nodes.append(node)
        return node
    
    def zero_grads(self):
        for node in self.nodes:
            node.grad = 0
    
    def forward(self):
        for node in self.nodes:
            print(node, node.childern)

            if node.node_type == 'var':
                continue
            elif node.node_type == 'add':
                a, b = node.children
                node.value = a.value + b.value
                continue
            elif node.node_type == 'sub':
                a, b = node.children
                node.value = a.value - b.value
                continue
            elif node.node_type == 'mul':
                a, b = node.children
                node.value = a.value * b.value
                continue
            elif node.node_type == 'div':
                a, b = node.children
                assert b.value != 0
                node.value = a.value / b.value
                continue
            elif node.node_type == 'pow':
                a, b = node.children
                node.value = a.value ** b.value
                continue
            elif node.node_type == 'neg':
                a, = node.children
                node.value = - a.value
                continue
            elif node.node_type == 'relu':
                a, = node.children
                node.value = relu(a.value)
                continue
            elif node.node_type == 'sigm':
                a, = node.children
                node.value = sigm(a.value)
                continue
            elif node.node_type == 'logsoftmax':
                a, = node.children
                shifted = a.value - np.max(a.value, axis=-1, keepdims=True)
                logsum = np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
                node.value = shifted - logsum  # log(softmax(x))
                continue
            elif node.node_type == 'cross_entropy':
                logprob, true = node.children
                node.value = -np.sum(logprob.value * true.value)
                continue
            else:
                raise ValueError
                continue
    
    def init_grad(self, node):
        if node.grad is None:
            node.grad = np.zeros_like(node.value)
        
    def backward(self):
        reversed_nodes = reversed(self.nodes)

        for node in reversed_nodes:
            match node.node_type:
                case 'var':
                    continue
                case 'add':
                    a, b = node.children
                    init_grad(a)
                    init_grad(b)

                    a_grad, b_grad = node.grad, node.grad
                    a.grad += node.grad
                    b.grad += node.grad
                    continue
                case 'sub':
                    a, b = node.children
                    init_grad(a)
                    init_grad(b)

                    a_grad, b_grad = node.grad, node.grad
                    a.grad += n.grad
                    b.grad += n.grad
                    continue
                case 'mul':
                    a, b = node.children
                    init_grad(a)
                    init_grad(b)

                    a.grad += node.grad @ b.grad.T
                    b.grad += a.grad.T @ node.grad
                    continue
                case 'div':
                    a, b = node.children
                    init_grad(a)
                    init_grad(b)

                    assert b.value != 0
                    node.value = a.value / b.value
                    continue
                case 'pow':
                    a, b = node.children
                    init_grad(a)
                    init_grad(b)

                    node.value = a.value ** b.value
                    continue
                case 'neg':
                    a, = node.children
                    init_grad(a)
                    
                    a.grad += node.grad
                    continue
                case 'relu':
                    a, = node.children
                    init_grad(a)
                    
                    a.grad += drelu(a.value).T @ node.grad
                    continue
                case 'sigm':
                    a, = node.children
                    init_grad(a)
                    s = node.value
                    a.grad += node.grad * (s * (1.0 - s))
                    continue
                case 'logsoftmax':
                    x, = node.children
                    init_grad(x)
                    g = node.grad
                    logp = node.value
                    p = np.exp(logp)
                    s = np.sum(g * p, axis=-1, keepdims=True)
                    x.grad += g - p * s
                    continue
                case 'cross_entropy':
                    y_hat, true = node.children
                    init_grad(y_hat)
                    y_hat.grad += -true.value * node.grad
                    continue
                case _:
                    raise ValueError
                    continue


In [5]:
rng = np.random.RandomState(0)
x = rng.randn(1, 3)
y_idx = 1
y_onehot = np.zeros((1, 2))
y_onehot[0, y_idx] = 1

In [6]:
W1 = np.random.random(size = (x.shape[1], 4))
W2 = np.random.random(size = (4, 2))
b1 = np.random.random(size = 4)
b2 = np.random.random(size = 2)

In [7]:
g = ComputationalGraph()
nw1 = g.add_value_node(W1)
nw2 = g.add_value_node(W2)
nb1 = g.add_value_node(b1)
nb2 = g.add_value_node(b2)
nx = g.add_value_node(x)
ny = g.add_value_node(y_onehot[0])

In [8]:
# h = sigmoid(x @ W1 + b1)
# o = logsoftmax( h @ W2 + b2)
# cross_entropy(o, y)

In [9]:
h = g.add_unary_node('sigm', g.add_binary_node('add', g.add_binary_node('mul', nx, nw1), nb1))
o = g.add_unary_node('logsoftmax', g.add_binary_node('add', g.add_binary_node('mul', nx, nw2), nb2))
loss = g.add_binary_node('cross_entropy', o, ny)

[Node : var 4, Node : var 0]
[Node : mul 6, Node : var 2]
[Node : var 4, Node : var 1]
[Node : mul 9, Node : var 3]
[Node : logsoftmax 11, Node : var 5]


In [10]:
print(g.nodes[6].childern)

[Node : var 4, Node : var 0]


In [11]:
g.forward()

Node : var 0 []
Node : var 1 []
Node : var 2 []
Node : var 3 []
Node : var 4 []
Node : var 5 []
Node : mul 6 [Node : var 4, Node : var 0]


AttributeError: 'Node' object has no attribute 'children'

In [172]:
g.zero_grads()
loss.grad = 1.0   # dL/dL = 1, seed gradient

In [173]:
g.backward()

AttributeError: 'Node' object has no attribute 'children'