In [45]:
%matplotlib inline
import networkx as nx
import math
import random
from collections import defaultdict
from matplotlib import pyplot as plt

In [51]:
class Node:
    """Common base class for every vertex kind in the computation graph."""

    def __init__(self):
        # A freshly created node starts out holding zero.
        self.value = 0
    
class Sum(Node):
    """Graph node that adds up all of its incoming values."""

    def __call__(self, values):
        """Return the total of ``values`` (an iterable of numbers)."""
        total = sum(values)
        return total

    def derivative(self, value, activation):
        """Return 1: a sum changes one-for-one with each of its inputs.

        ``value`` and ``activation`` are accepted so the signature matches
        the other node kinds; neither is used.
        """
        return 1
    
class Multiply(Node):
    """Graph node that multiplies all of its incoming values together."""

    def __call__(self, values):
        """Return the product of ``values``; an empty iterable gives 1."""
        product = 1
        for factor in values:
            product *= factor
        return product

    def derivative(self, value, activation):
        """Return ``activation / value``.

        ``activation`` is this node's output (the full product) and ``value``
        is one of the factors, so the quotient is the product of the remaining
        factors.  Because this is a division, a ``value`` of zero raises
        ``ZeroDivisionError`` for plain Python numbers.
        """
        quotient = activation / value
        return quotient
    
class Input(Node):
    """Leaf node that simply holds the value it was constructed with."""

    def __init__(self, input_value):
        # Sets ``value`` directly instead of going through Node.__init__.
        self.value = input_value

    def __call__(self, values=None):
        """Return the stored value; ``values`` is accepted but ignored."""
        stored = self.value
        return stored
    
class Parameter(Node):
    """Leaf node holding a value; used for the weights ``w0``/``w1`` in this notebook."""

    def __init__(self, initial_value):
        # Sets ``value`` directly instead of going through Node.__init__.
        self.value = initial_value

    def __call__(self, values=None):
        """Return the stored value; ``values`` is accepted but ignored.

        The original body was a bare ``return`` and therefore produced
        ``None``; it now mirrors ``Input.__call__``.
        """
        return self.value
    
class Loss(Node):
    def __init__(self, expected_value):
        self.expected = expected_value
    
    def __call__(self, values):
        return (sum(values) - self.expected)**2
    
    def derivative(self, value, activation):
        return 2 * activation

Create the architecture

In [95]:
G = nx.DiGraph()

G.add_node('x0', kind=Input(5), value=5)
G.add_node('x1', kind=Input(7), value=7)
G.add_node('w0', kind=Parameter(0.1), value=0.1)
G.add_node('w1', kind=Parameter(0.1), value=0.1)

G.add_node('mult0', kind=Multiply())
G.add_edge('x0', 'mult0')
G.add_edge('w0', 'mult0')

G.add_node('mult1', kind=Multiply())
G.add_edge('x1', 'mult1')
G.add_edge('w1', 'mult1')

G.add_node('sum', kind=Sum())
G.add_edge('mult0', 'sum')
G.add_edge('mult1', 'sum')

G.add_node('loss', kind=Loss(17))
G.add_edge('sum', 'loss')

In [96]:
for i in range(5000):
    # Forward pass
    for v in nx.topological_sort(G):
        inputs = [G.node[u]['value'] for u in G.predecessors(v)]
        if inputs:
            G.node[v]['value'] = G.node[v]['kind'](inputs)

    # Backward pass
    for n in G.nodes():
        G.node[n]['grad'] = 0
    G.node['loss']['grad'] = 1

    for v in nx.topological_sort(G, reverse=True):
        for u in G.predecessors(v):
            derivative_u_wrt_v = G.node[v]['kind'].derivative(G.node[u]['value'], G.node[v]['value'])
            derivative_v_wrt_F = G.node[v]['grad']
            G.node[u]['grad'] += derivative_u_wrt_v * derivative_v_wrt_F

    print(G.node['sum']['value'])
    
    # Update parameters
    for n in G.nodes():
        if n.startswith('w'):
           G.node[n]['value'] += 0.0001 * G.node[n]['grad']


1.2000000000000002
4.894672
7.063448696616243
8.52472746332695
9.58781508297604
10.40093426458957
11.045439759577537
11.57020021689847
12.0065445570301
12.3755765964918
12.692078915352532
12.966740039611476
13.207494391050961
13.420364253200827
13.61000837598131
13.78009061550288
13.933534298879596
14.072701834942583
14.199524138240381
14.315595581014396
14.422244781867505
14.520588146943721
14.611570897372445
14.695998882330965
14.774563515354222
14.847861515052672
14.916410675916865
14.980662574094078
15.041012883961109
15.097809815669045
15.151361062630047
15.201939558257301
15.24978827426927
15.29512424232557
15.338141942292339
15.379016170911065
15.417904481808748
15.454949269993225
15.490279560015983
15.524012545958186
15.556254922632622
15.587104040389276
15.616648910281004
15.64497108179492
15.672145411658665
15.698240739213798
15.723320481374827
15.747443158155328
15.770662858057696
15.793029651224368
15.814589957082157
15.835386872235944
15.855460463548878
15.874848030656274
