Simple Neural Network from scratch in pure python.

(I only use matplotlib to graph results out)

It's actually super simple and uses just one neuron.

Sum of Square Residuals
$$SSR = \sum_{i=1}^n(y_{actual,i} - y_{predicted,i})^2$$

In [248]:
import matplotlib.pyplot as plt
import math
import random

# Setup and Step 0
n_elems = 6

# actual values
fib_seq = [0,1]
for i in range(n_elems-len(fib_seq)):
  fib_seq.append(fib_seq[i] + fib_seq[i+1])

A = fib_seq

# activation function
def softplus(x):
  return math.log(1 + math.exp(x))

# setup for random
random.seed(11)
def rand():
  return random.gauss(0, 1)

# weights and biases
iw1 = rand() # input weight
ow1 = rand() # output weight
b1 = rand() # bias


def p_y(x):
    return softplus((x * iw1) + b1 )
            
def calc_P():
  return [  (p_y(x) * ow1) for x in range(n_elems) ]

# predicted values
P = calc_P()


# SSR (Sum of Square Residuals)
ssr_iw1 = []
ssr_b1 = []
ssr_ow1 = []
ssr_P = []

def SSR(actual, predicted):
    assert len(actual) == len(predicted), "Lists must be of same length"
    s = sum((lambda a,p: (a - p)**2)(a, p) for (a,p) in zip(actual, predicted))
    ssr_P.append(s)
    return s

# Derivatives
def dSSR_diw1(actual, predicted):
    # dSSR/diw1 = dSSR/dP * dP/dy1 * dy1/dx1 * dx1/diw1
    assert len(actual) == len(predicted), "Lists must be of same length"
    products = []
    for i, (a, p) in enumerate(zip(actual, predicted)):
        dSSP_dP = -2 * (a - p)
        dP_y1 = ow1
        dy1_dx1 = (math.exp(p_y(i)))/(1 + math.exp(p_y(i)))
        dx1_diw1 = i
        products.append(dSSP_dP * dP_y1 * dy1_dx1 * dx1_diw1)
        
    s = sum(products)
    ssr_iw1.append(s)
    print("dSSR_diw1= ", s)
    return s
                            


def dSSR_db1(actual, predicted):
    # dSSR/db1 = dSSR/dP * dP/dy1 * dy1/dx1 * dx1/db1
    assert len(actual) == len(predicted), "Lists must be of same length"
    products = []
    for i, (a, p) in enumerate(zip(actual, predicted)):
        dSSP_dP = -2 * (a - p)
        dP_y1 = ow1
        dy1_dx1 = (math.exp(p_y(i)))/(1 + math.exp(p_y(i)))
        dx1_db1 = 1
        products.append(dSSP_dP * dP_y1 * dy1_dx1 * dx1_db1)
        
    s = sum(products)
    ssr_b1.append(s)
    print("dSSR_db1= ", s)
    return s

def dSSR_dow1(actual, predicted):
    # dSSR/dow1 = dSSR/dP * dP/dow1
    assert len(actual) == len(predicted), "Lists must be of same length"
    products = []
    for i, (a, p) in enumerate(zip(actual, predicted)):
        dSSP_dP = -2 * (a - p)
        dP_ow1 = p_y(i)

        products.append(dSSP_dP * dP_ow1)
        
    s = sum(products)
    ssr_ow1.append(s)
    print("dSSR_dow1= ", s)
    return s


# Plotting
def plot_results(step_n, A, P):
    fig = plt.figure(figsize=(16.0, 8.0))
    fig.suptitle(f"Step {step_n}")
    gs = plt.GridSpec(2, 4, figure=fig)
    
    
    ax0 = fig.add_subplot(gs[0, :])

    ax0.set_title("actual vs predicted")
    ax0.plot(A, 'co')
    ax0.plot(P, 'm')
    ax0.legend(["actual", "predicted"])

    ax1 = fig.add_subplot(gs[1, 0])
    ax2 = fig.add_subplot(gs[1, 1])
    ax3 = fig.add_subplot(gs[1, 2])
    ax4 = fig.add_subplot(gs[1, 3])

    ax1.set_title("SSR P")
    ax1.plot(ssr_P, 'r.')
    
    ax2.set_title("d SSR P/ d iw1")
    ax2.plot(ssr_iw1, 'r.')
    
    ax3.set_title("d SSR P/ d b1")
    ax3.plot(ssr_b1, 'r.')
    
    ax4.set_title("d SSR P/ d ow1")
    ax4.plot(ssr_ow1, 'r.')
    print("iw1=", iw1, "b1=", b1, "ow1=", ow1)
    
    # there need to be 1 before 000, because Windows doesn't respect alphabetical order
    img_path = (f"./01_images/img_1{step_n:03}_01.jpg") 
    plt.savefig(img_path)
    plt.close()


SSR(A, P)
plot_results(0, A, P)


LR = 0.01 # Learning rate

iw1= -1.224072675713965 b1= 0.9949996276709397 ow1= 0.3775881983015979


In [249]:
for n in range(200):
    step_iw1 = (dSSR_diw1(A, P) * LR)
    step_b1 = (dSSR_db1(A, P) * LR)
    step_ow1 = (dSSR_dow1(A, P) * LR)

    iw1 -= step_iw1
    b1 -= step_b1
    ow1 -= step_ow1

    P = calc_P()

    SSR(A, P)
    plot_results(n+1, A, P)
    print("step", n+1)

dSSR_diw1=  -17.485638049162585
dSSR_db1=  -4.272199423890715
dSSR_dow1=  -0.4451145302847917
iw1= -1.0492162952223392 b1= 1.0377216219098468 ow1= 0.3820393436044458
step 1
dSSR_diw1=  -17.788035983703395
dSSR_db1=  -4.329203766742774
dSSR_dow1=  -1.0102408651427344
iw1= -0.8713359353853052 b1= 1.0810136595772746 ow1= 0.39214175225587317
step 2
dSSR_diw1=  -18.454932796846865
dSSR_db1=  -4.46190178152654
dSSR_dow1=  -1.9530629847808416
iw1= -0.6867866074168365 b1= 1.12563267739254 ow1= 0.41167238210368157
step 3
dSSR_diw1=  -19.79300857499461
dSSR_db1=  -4.727419294527547
dSSR_dow1=  -3.631810082854193
iw1= -0.4888565216668904 b1= 1.1729068703378156 ow1= 0.4479904829322235
step 4
dSSR_diw1=  -22.444813702593255
dSSR_db1=  -5.232809192968425
dSSR_dow1=  -6.774966822189565
iw1= -0.26440838464095784 b1= 1.2252349622674998 ow1= 0.5157401511541191
step 5
dSSR_diw1=  -27.455824067685818
dSSR_db1=  -6.084377194276533
dSSR_dow1=  -12.4668406347769
iw1= 0.010149856035900351 b1= 1.28607873421026

step 49
dSSR_diw1=  -13.593600809524405
dSSR_db1=  -2.298977650481809
dSSR_dow1=  -24.496455996008784
iw1= 1.1019242520258374 b1= 0.45026606476615144 ow1= 0.80868783159851
step 50
dSSR_diw1=  13.785773028563879
dSSR_db1=  6.159544356562502
dSSR_dow1=  24.366746158948036
iw1= 0.9640665217401987 b1= 0.3886706212005264 ow1= 0.5650203700090296
step 51
dSSR_diw1=  -13.944699074235157
dSSR_db1=  -2.4177494755090088
dSSR_dow1=  -24.966039922934748
iw1= 1.1035135124825501 b1= 0.4128481159556165 ow1= 0.8146807692383771
step 52
dSSR_diw1=  13.930755685012194
dSSR_db1=  6.16362667163407
dSSR_dow1=  24.14830428602503
iw1= 0.9642059556324282 b1= 0.3512118492392758 ow1= 0.5731977263781267
step 53
dSSR_diw1=  -13.93798893285523
dSSR_db1=  -2.4147880247825007
dSSR_dow1=  -24.43824777651939
iw1= 1.1035858449609806 b1= 0.3753597294871008 ow1= 0.8175802041433207
step 54
dSSR_diw1=  13.571242795724196
dSSR_db1=  6.016469066099701
dSSR_dow1=  23.219779516528334
iw1= 0.9678734170037386 b1= 0.315195038826103

step 97
dSSR_diw1=  -7.157123091160589
dSSR_db1=  -0.787949930500339
dSSR_dow1=  -10.669811665473286
iw1= 1.2057088741523443 b1= -0.27634929698122274 ow1= 0.8006435832556299
step 98
dSSR_diw1=  6.57307516209411
dSSR_db1=  3.1213401224847646
dSSR_dow1=  10.269100824718665
iw1= 1.1399781225314032 b1= -0.3075626982060704 ow1= 0.6979525750084432
step 99
dSSR_diw1=  -6.892375875043028
dSSR_db1=  -0.7294792656829615
dSSR_dow1=  -10.229771032120503
iw1= 1.2089018812818335 b1= -0.30026790554924077 ow1= 0.8002502853296483
step 100
dSSR_diw1=  6.316762006706686
dSSR_db1=  3.0226456547951077
dSSR_dow1=  9.828008170311273
iw1= 1.1457342612147667 b1= -0.33049436209719185 ow1= 0.7019702036265355
step 101
dSSR_diw1=  -6.622340884110338
dSSR_db1=  -0.6694795347502358
dSSR_dow1=  -9.785445895143432
iw1= 1.21195767005587 b1= -0.3237995667496895 ow1= 0.7998246625779698
step 102
dSSR_diw1=  6.055533789572976
dSSR_db1=  2.9234717841836626
dSSR_dow1=  9.38370248534816
iw1= 1.1514023321601403 b1= -0.35303428

step 145
dSSR_diw1=  -1.1841343438781058
dSSR_db1=  0.5272943928787179
dSSR_dow1=  -1.6500730722908061
iw1= 1.249457530752249 b1= -0.759191507359115 ow1= 0.8011976642442471
step 146
dSSR_diw1=  0.9815388030729943
dSSR_db1=  1.1132527178938112
dSSR_dow1=  1.3103517022873952
iw1= 1.2396421427215192 b1= -0.7703240345380531 ow1= 0.7880941472213732
step 147
dSSR_diw1=  -1.0341144231433574
dSSR_db1=  0.5563092794503364
dSSR_dow1=  -1.4497839971817204
iw1= 1.2499832869529528 b1= -0.7758871273325565 ow1= 0.8025919871931904
step 148
dSSR_diw1=  0.8519973707640469
dSSR_db1=  1.0652311751193988
dSSR_dow1=  1.1132815158312441
iw1= 1.2414633132453123 b1= -0.7865394390837505 ow1= 0.7914591720348779
step 149
dSSR_diw1=  -0.8968719506273857
dSSR_db1=  0.5819971200887069
dSSR_dow1=  -1.2686486709960993
iw1= 1.250432032751586 b1= -0.7923594102846375 ow1= 0.8041456587448389
step 150
dSSR_diw1=  0.7347710001464289
dSSR_db1=  1.020945813167053
dSSR_dow1=  0.9351301457006809
iw1= 1.2430843227501218 b1= -0.8

step 192
dSSR_diw1=  0.04876354696295504
dSSR_db1=  0.6277636044603169
dSSR_dow1=  -0.1677550336210034
iw1= 1.244598029323248 b1= -1.0999526623437437 ow1= 0.8646574308901753
step 193
dSSR_diw1=  0.03706763485917097
dSSR_db1=  0.6205235042129408
dSSR_dow1=  -0.18345274983372928
iw1= 1.2442273529746561 b1= -1.106157897385873 ow1= 0.8664919583885127
step 194
dSSR_diw1=  0.049242003485225894
dSSR_db1=  0.6196493085990815
dSSR_dow1=  -0.17131111136633326
iw1= 1.243734932939804 b1= -1.1123543904718638 ow1= 0.868205069502176
step 195
dSSR_diw1=  0.040807865195809256
dSSR_db1=  0.6133299610569863
dSSR_dow1=  -0.1829983219348339
iw1= 1.243326854287846 b1= -1.1184876900824336 ow1= 0.8700350527215244
step 196
dSSR_diw1=  0.05002172544643191
dSSR_db1=  0.611718170440233
dSSR_dow1=  -0.17424716811115015
iw1= 1.2428266370333816 b1= -1.124604871786836 ow1= 0.8717775244026359
step 197
dSSR_diw1=  0.04402717894508967
dSSR_db1=  0.6060980969450979
dSSR_dow1=  -0.18293155424384722
iw1= 1.2423863652439306