In [10]:

from engine import Value
from nn import MLP
import random
import numpy as np
from graphviz import Digraph


In [11]:
def trace(root):
    # builds sets of nodes and edges in the graph
    nodes, edges = set(), set()
    def build(v):
        if v not in nodes:
            nodes.add(v)
            for c in v._children:
                edges.add((c,v))
                build(c)
    build(root)
    return nodes, edges
    
def draw_dot(root):
    dot = Digraph(format='svg',graph_attr={'rankdir':'LR'}) # left to right
    nodes, edges = trace(root)
    for n in nodes:
        uid = str(id(n))
        # for any value in the graph, create a rectangular ('record') node for it
        # print(n)
        # print(n.label)
        val_node_label = '' if not len(n.label) else f'{n.label}, '    
        val_node_label += f'v:{n.data:.2f}, ∂(L)/∂(v):{n.grad:.2f}'
        # print(val_node_label)
        dot.node(name=uid,label=val_node_label,shape='record')
        if n._op:
            # if the node is a result of some operation, create entering op node to it
            dot.node(name=uid+n._op, label=n._op) # create op node
            dot.edge(uid+n._op, uid) # edge from op symbol to op result
    
    for n1, n2 in edges:
        dot.edge(str(id(n1)),str(id(n2))+n2._op) # edge from previous value to next op symbol

    return dot


In [12]:
# random.seed(42)
nn = MLP(3,[4,4,1],'relu')


X_train = [
    [2.0, 3.0, -1.0],
    [3.0, -1.0, 0.5],
    [0.5, 1.0, 1.0],
    [1.0, 1.0, -1.0]
        
]
y_train = [1.0, -1.0, -1.0, 1.0] 

n_epoch = 1000
lr = 0.05
for ep in range(n_epoch):
    # loss_data = Value(0.0)
    loss = 0.0
    for xi,yi in zip(X_train,y_train):
        yi_pred = nn(xi)
        loss_i = (yi_pred - yi)**2
        loss = loss_i + loss
        # print(f'{yi_pred.data:.2f}', yi)
    print("loss=", loss.data)
    
    # loss = Value(loss_data, children=(y_pred,))
    for p in nn.parameters():
        p.grad = 0
    loss.backward() # get d(L)/d(p) for each parameter p in the network 
    for p in nn.parameters():
        p.data -= lr * p.grad # Update network params: p -= lr * d(L)/d(p)


loss= 2.9794744865893903
loss= 0.9451531802830089
loss= 2.7068860038673774
loss= 11.336819645267735
loss= 9.485250421048393
loss= 4.540941338464529
loss= 3.748582934762484
loss= 3.344644035403734
loss= 2.7242527277031576
loss= 2.0050675952706554
loss= 1.481615672793131
loss= 1.0771403571632663
loss= 0.7689053251961159
loss= 0.5232219664280264
loss= 0.34928652124836845
loss= 0.2314091265853346
loss= 0.15963500167011502
loss= 0.1103035707612063
loss= 0.08976586486317639
loss= 0.12395513368341778
loss= 0.2930494368085301
loss= 1.1326568685830924
loss= 2.0912036325734453
loss= 6.828665249727044
loss= 1.9176254884426545
loss= 0.6194316262126067
loss= 0.8172935615321865
loss= 0.7912066340990824
loss= 1.9507099549590372
loss= 0.9438873487547269
loss= 2.49159300096818
loss= 0.7023579998510625
loss= 1.6655396027938476
loss= 1.0181690248943591
loss= 2.759274003028871
loss= 0.5691654985131166
loss= 1.2214156423440041
loss= 1.0425978274370489
loss= 2.8947092171421747
loss= 0.5126160144127507
loss=

In [None]:
# draw_dot(loss)