Simple Neural Network from scratch in pure python.

(I only use matplotlib to graph results out)

It's actually quite simple and uses just two neurons.

Sum of Square Residuals
$$SSR = \sum_{i=1}^n(y_{actual,i} - y_{predicted,i})^2$$

In [124]:
import matplotlib.pyplot as plt
import math
import random

# Setup and Step 0
n_elems = 12

X = [ (n+1)/(n_elems/2) for n in range(n_elems) ]
print(X)

# actual values
"""
fib_seq = [0,1]
for i in range((n_elems-3)//2):
  fib_seq.append(fib_seq[i] + fib_seq[i+1])


A =  [-2*n for n in fib_seq] + [ i for i in reversed([-3*n for n in fib_seq]) ]
A = [ (n+2/fib_seq[-1])/n_elems for n in A]
print(len(A))
"""

A = [ softplus(math.sin(math.pi*x)*2) for x in X]

# activation function
def softplus(x):
  return math.log(1 + math.exp(x))

# setup for random
random.seed(13)
def rand():
  return random.gauss(0, 3)

# weights and biases
iw1 = rand() # input weight
ow1 = rand() # output weight
b1 = rand() # bias

iw2 = rand() # input weight
ow2 = rand() # output weight
b2 = rand() # bias



def p_y1(x):
    return softplus((x * iw1) + b1 )

def p_y2(x):
    return softplus((x * iw2) + b2 )
            
def calc_N1():
  return [  (p_y1(x) * ow1) for x in X ]

def calc_N2():
  return [  (p_y2(x) * ow2) for x in X ]

def calc_P():
    return [(n1+n2) for (n1,n2) in zip(calc_N1(), calc_N2())]

# predicted values
N1 = calc_N1()
N2 = calc_N2()
P = calc_P()
print("P", len(P))

# SSR (Sum of Square Residuals)
ssr_iw1 = []
ssr_b1 = []
ssr_ow1 = []

ssr_iw2 = []
ssr_b2 = []
ssr_ow2 = []

ssr_P = []

def SSR(actual, predicted):
    assert len(actual) == len(predicted), "Lists must be of same length"
    s = sum((lambda a,p: (a - p)**2)(a, p) for (a,p) in zip(actual, predicted))
    ssr_P.append(s)
    return s

# Derivatives
def dSSR_diw1(actual, predicted):
    # dSSR/diw1 = dSSR/dP * dP/dy1 * dy1/dx1 * dx1/diw1
    assert len(actual) == len(predicted), "Lists must be of same length"
    products = []
    for i, (a, p) in enumerate(zip(actual, predicted)):
        dSSP_dP = -2 * (a - p)
        dP_y1 = ow1
        dy1_dx1 = (math.exp(p_y1(i)))/(1 + math.exp(p_y1(i)))
        dx1_diw1 = i
        products.append(dSSP_dP * dP_y1 * dy1_dx1 * dx1_diw1)
        
    s = sum(products)
    ssr_iw1.append(s)
    return s
                            
def dSSR_db1(actual, predicted):
    # dSSR/db1 = dSSR/dP * dP/dy1 * dy1/dx1 * dx1/db1
    assert len(actual) == len(predicted), "Lists must be of same length"
    products = []
    for i, (a, p) in enumerate(zip(actual, predicted)):
        dSSP_dP = -2 * (a - p)
        dP_y1 = ow1
        dy1_dx1 = (math.exp(p_y1(i)))/(1 + math.exp(p_y1(i)))
        dx1_db1 = 1
        products.append(dSSP_dP * dP_y1 * dy1_dx1 * dx1_db1)
        
    s = sum(products)
    ssr_b1.append(s)
    return s

def dSSR_dow1(actual, predicted):
    # dSSR/dow1 = dSSR/dP * dP/dow1
    assert len(actual) == len(predicted), "Lists must be of same length"
    products = []
    for i, (a, p) in enumerate(zip(actual, predicted)):
        dSSP_dP = -2 * (a - p)
        dP_ow1 = p_y1(i)

        products.append(dSSP_dP * dP_ow1)
        
    s = sum(products)
    ssr_ow1.append(2)
    return s



def dSSR_diw2(actual, predicted):
    assert len(actual) == len(predicted), "Lists must be of same length"
    products = []
    for i, (a, p) in enumerate(zip(actual, predicted)):
        dSSP_dP = -2 * (a - p)
        dP_y2 = ow2
        dy2_dx2 = (math.exp(p_y2(i)))/(1 + math.exp(p_y2(i)))
        dx2_diw2 = i
        products.append(dSSP_dP * dP_y2 * dy2_dx2 * dx2_diw2)
        
    s = sum(products)
    ssr_iw2.append(s)
    return s
                            
def dSSR_db2(actual, predicted):
    assert len(actual) == len(predicted), "Lists must be of same length"
    products = []
    for i, (a, p) in enumerate(zip(actual, predicted)):
        dSSP_dP = -2 * (a - p)
        dP_y2 = ow2
        dy2_dx2 = (math.exp(p_y2(i)))/(1 + math.exp(p_y2(i)))
        dx2_db2 = 1
        products.append(dSSP_dP * dP_y2 * dy2_dx2 * dx2_db2)
        
    s = sum(products)
    ssr_b1.append(s)
    return s

def dSSR_dow2(actual, predicted):
    assert len(actual) == len(predicted), "Lists must be of same length"
    products = []
    for i, (a, p) in enumerate(zip(actual, predicted)):
        dSSP_dP = -2 * (a - p)
        dP_ow2 = p_y2(i)

        products.append(dSSP_dP * dP_ow2)
        
    s = sum(products)
    ssr_ow1.append(s)
    return s


# Plotting
def plot_results(step_n, A, P):
    fig = plt.figure(figsize=(16.0, 10.0))
    fig.suptitle(f"Step {step_n}")
    gs = plt.GridSpec(4, 4, figure=fig)
    
    
    ax0 = fig.add_subplot(gs[0:2, :])

    ax0.set_title("actual vs predicted")
    ax0.set_ylim([-3, 3])
    ax0.plot(X,A, 'ko')
    ax0.plot(X,N1, 'c')
    ax0.plot(X,N2, 'y')
    ax0.plot(X,P, 'm')
    ax0.legend(["actual", "neuron 1", "neuron 2", "predicted"])

    ax11 = fig.add_subplot(gs[2, 0])
    ax12 = fig.add_subplot(gs[2, 1])
    ax13 = fig.add_subplot(gs[2, 2])
    ax14 = fig.add_subplot(gs[2, 3])

    ax11.set_title("SSR P")
    ax11.plot(ssr_P, 'm')
    
    ax12.set_title("d SSR P/ d iw1")
    ax12.plot(ssr_iw1, 'c')
    
    ax13.set_title("d SSR P/ d b1")
    ax13.plot(ssr_b1, 'c')
    
    ax14.set_title("d SSR P/ d ow1")
    ax14.plot(ssr_ow1, 'c')
    
    ax22 = fig.add_subplot(gs[3, 1])
    ax23 = fig.add_subplot(gs[3, 2])
    ax24 = fig.add_subplot(gs[3, 3])

    ax22.set_title("d SSR P/ d iw2")
    ax22.plot(ssr_iw1, 'y')
    
    ax23.set_title("d SSR P/ d b2")
    ax23.plot(ssr_b1, 'y')
    
    ax24.set_title("d SSR P/ d ow2")
    ax24.plot(ssr_ow1, 'y')
    
    plt.tight_layout()
    
    print("iw1=", iw1, "b1=", b1, "ow1=", ow1)
    print("iw2=", iw2, "b2=", b2, "ow2=", ow2)

    
    # there need to be 1 before 000, because Windows doesn't respect alphabetical order
    img_path = (f"./02_images/img_1{step_n:03}_02.jpg") 
    plt.savefig(img_path)
    plt.close()


SSR(A, P)
plot_results(0, A, P)



LR = 0.01 # Learning rate

[0.16666666666666666, 0.3333333333333333, 0.5, 0.6666666666666666, 0.8333333333333334, 1.0, 1.1666666666666667, 1.3333333333333333, 1.5, 1.6666666666666667, 1.8333333333333333, 2.0]
P 12
iw1= -0.25805695865858497 b1= -2.348949561217586 ow1= 4.554277398725461
iw2= -5.343323947499308 b2= 1.9972768560328862 ow2= 0.8535303422715148


In [125]:
for n in range(600):
    step_iw1 = (dSSR_diw1(A, P) * LR)
    step_b1 = (dSSR_db1(A, P) * LR)
    step_ow1 = (dSSR_dow1(A, P) * LR)

    step_iw2 = (dSSR_diw2(A, P) * LR)
    step_b2 = (dSSR_db2(A, P) * LR)
    step_ow2 = (dSSR_dow2(A, P) * LR)
    
    iw1 -= step_iw1
    b1 -= step_b1
    ow1 -= step_ow1
    
    iw2 -= step_iw2
    b2 -= step_b2
    ow2 -= step_ow2

    N1 = calc_N1()
    N2 = calc_N2()
    P = calc_P()

    SSR(A, P)
    plot_results(n+1, A, P)
    print("step", n+1)

iw1= 0.4422409387767306 b1= -2.132424610422213 ow1= 4.558431545882631
iw2= -5.214291138179935 b2= 2.0352553530459074 ow2= 0.8424751404156855
step 1
iw1= -1.917870976777778 b1= -2.309042019633724 ow1= 4.424934454233359
iw2= -5.447276175756526 b2= 2.0202933642464536 ow2= 0.8237450456885723
step 2
iw1= -0.5089058933391493 b1= -1.9774564880556331 ow1= 4.4250324145196664
iw2= -5.184945013303251 b2= 2.081186965236706 ow2= 0.8181494071892682
step 3
iw1= 0.15624192052652808 b1= -1.7920193532527724 ow1= 4.427400593939347
iw2= -5.0633408888310765 b2= 2.1123335032162336 ow2= 0.8000936881058743
step 4
iw1= -1.0697762230537804 b1= -1.8745750729994501 ow1= 4.404695358143198
iw2= -5.23632643829245 b2= 2.098063514586431 ow2= 0.7738766014323127
step 5
iw1= -0.02100562442794618 b1= -1.628803434794854 ow1= 4.405148364021993
iw2= -5.052419137322825 b2= 2.139006612459056 ow2= 0.7583145073191826
step 6
iw1= -0.8099869101505204 b1= -1.6707176569429478 ow1= 4.402983133979208
iw2= -5.178926827112327 b2= 2.1283

step 55
iw1= -1.2456103565672776 b1= 0.062755698254091 ow1= 4.278821774880714
iw2= -5.232854198661973 b2= 2.187197953707656 ow2= -0.4559256322867022
step 56
iw1= -1.8188592836009339 b1= -0.03905600442958003 ow1= 4.270568028867682
iw2= -5.171541844957216 b2= 2.1995499265084386 ow2= -0.48551120571374556
step 57
iw1= -1.2713051041272303 b1= 0.09644082711110509 ow1= 4.2698579519547755
iw2= -5.233513866263542 b2= 2.1848608757859163 ow2= -0.4948931665480467
step 58
iw1= -1.8375195126376558 b1= -0.005461321824169238 ow1= 4.2614151931141
iw2= -5.167649421267305 b2= 2.198256635666224 ow2= -0.5241598747076756
step 59
iw1= -1.2961685420169782 b1= 0.12899840601212875 ow1= 4.260800678957393
iw2= -5.233935514153726 b2= 2.1824555920262885 ow2= -0.5330318657827684
step 60
iw1= -1.8554288847112963 b1= 0.02706562641335672 ow1= 4.252180431466455
iw2= -5.163727150800901 b2= 2.196865487588844 ow2= -0.5619810809904378
step 61
iw1= -1.3202489601761698 b1= 0.16048820225616564 ow1= 4.251664934738097
iw2= -5.23

step 110
iw1= -2.132566381365977 b1= 0.6074164361391371 ow1= 4.024036816343129
iw2= -5.07262990765133 b2= 2.1419598737370564 ow2= -1.2881050081115633
step 111
iw1= -1.7551982990214752 b1= 0.7099352754814348 ow1= 4.025870922024633
iw2= -5.19262655356278 b2= 2.109921112797544 ow2= -1.289058857528482
step 112
iw1= -2.1389515501394216 b1= 0.624209658814772 ow1= 4.015753454362169
iw2= -5.069497495223765 b2= 2.139334186818431 ow2= -1.3101371359461655
step 113
iw1= -1.7681179887334457 b1= 0.7252625170447216 ow1= 4.017646457869533
iw2= -5.189666578384464 b2= 2.107144869437675 ow2= -1.3109612032905076
step 114
iw1= -2.145101768427267 b1= 0.6406546305583593 ow1= 4.007567482265465
iw2= -5.066401025046227 b2= 2.1366996386806067 ow2= -1.3317448542934023
step 115
iw1= -1.7808002510896994 b1= 0.7402299805536956 ow1= 4.009515607529296
iw2= -5.186631133908776 b2= 2.1043909058033146 ow2= -1.3324490247369696
step 116
iw1= -2.1510284439355454 b1= 0.6567622700915361 ow1= 3.999480570926055
iw2= -5.063338687

step 165
iw1= -2.0358112274695768 b1= 1.0268740123210864 ow1= 3.8386425246403477
iw2= -5.092215492743129 b2= 2.0449601377963798 ow2= -1.760654135699389
step 166
iw1= -2.253860455604256 b1= 0.9739612157486509 ow1= 3.8308196444203513
iw2= -4.991802317076878 b2= 2.0706825056068707 ow2= -1.7743929524141846
step 167
iw1= -2.043918583992247 b1= 1.0357642454014648 ow1= 3.833076472255209
iw2= -5.087962881646338 b2= 2.042973794473088 ow2= -1.774281119399623
step 168
iw1= -2.257011283731094 b1= 0.9839621820173197 ow1= 3.8253600680094695
iw2= -4.988915826004248 b2= 2.0683891993786916 ow2= -1.7877897974885464
step 169
iw1= -2.0518841897939954 b1= 1.0445115964798375 ow1= 3.8276002567294793
iw2= -5.083695235797018 b2= 2.041014669523227 ow2= -1.7876935859146474
step 170
iw1= -2.2601353471385375 b1= 0.9937994484821845 ow1= 3.8199898351815884
iw2= -4.986013822788936 b2= 2.0661210478498355 ow2= -1.8009763010190762
step 171
iw1= -2.059709558829435 b1= 1.0531206991875843 ow1= 3.822212634699104
iw2= -5.079

step 220
iw1= -2.336153759300753 b1= 1.1956157720643785 ow1= 3.7106165045077533
iw2= -4.906842164193193 b2= 2.0181436440049207 ow2= -2.0750283524547153
step 221
iw1= -2.213988517241003 b1= 1.2338637204531189 ow1= 3.712538912439923
iw2= -4.97399292909624 b2= 1.9977631833751779 ow2= -2.0751412002578737
step 222
iw1= -2.339268185955213 b1= 1.202221876236105 ow1= 3.707079233194432
iw2= -4.903385908414483 b2= 2.0165771757593056 ow2= -2.0841861068976923
step 223
iw1= -2.218669718376469 b1= 1.2400349317434116 ow1= 3.7090073279342555
iw2= -4.970020061112972 b2= 1.996327598839004 ow2= -2.084284091558725
step 224
iw1= -2.3423960069655845 b1= 1.2087337951644543 ow1= 3.7035953543588938
iw2= -4.899906825801239 b2= 2.015037256340774 ow2= -2.093231519622609
step 225
iw1= -2.223247824444422 b1= 1.2461439342273912 ow1= 3.7055312861836467
iw2= -4.966076806241707 b2= 1.9949042352906037 ow2= -2.093312134611856
step 226
iw1= -2.3455380136069013 b1= 1.2151531595011815 ow1= 3.7001637667894816
iw2= -4.8964048

step 275
iw1= -2.3053591070318364 b1= 1.3834580838076216 ow1= 3.6338942192812644
iw2= -4.880797270277239 b2= 1.9614721105918378 ow2= -2.287444894007268
step 276
iw1= -2.4328326459351293 b1= 1.3489697844283033 ow1= 3.628267914533958
iw2= -4.800010346304308 b2= 1.9841187180067377 ow2= -2.295746859272338
step 277
iw1= -2.3072942756558206 b1= 1.3885284315125188 ow1= 3.6315302029884426
iw2= -4.87813032861221 b2= 1.9601184127808196 ow2= -2.29413316315727
step 278
iw1= -2.436955316740129 b1= 1.353332234159341 ow1= 3.6258177779536807
iw2= -4.795683884242623 b2= 1.9832810301339452 ow2= -2.3024987922367885
step 279
iw1= -2.3091123476001156 b1= 1.3935823679068204 ow1= 3.629195675544812
iw2= -4.875544972128454 b2= 1.958751947185508 ow2= -2.3007496534063354
step 280
iw1= -2.441159957745801 b1= 1.3576193597124588 ow1= 3.623388764409301
iw2= -4.791306771960707 b2= 1.9824699088333575 ow2= -2.3091875844530216
step 281
iw1= -2.310810947843927 b1= 1.3986218373544261 ow1= 3.626889395327029
iw2= -4.8730450

step 330
iw1= -2.58116329932025 b1= 1.4401374820313058 ow1= 3.5630664470673445
iw2= -4.666679813466096 b2= 1.9681050717691568 ow2= -2.4584296748601995
step 331
iw1= -2.3151138503831787 b1= 1.5224918063172748 ow1= 3.572415174191909
iw2= -4.848507569860318 b2= 1.91232805655523 ow2= -2.4499010474180163
step 332
iw1= -2.587358120715745 b1= 1.442674405303359 ow1= 3.560373360184724
iw2= -4.662025762679723 b2= 1.9674212042671715 ow2= -2.4637148706677943
step 333
iw1= -2.3146348142456405 b1= 1.5270908451464666 ow1= 3.570012585449436
iw2= -4.848988353825764 b2= 1.91005175698432 ow2= -2.454860901845066
step 334
iw1= -2.5933711408879594 b1= 1.445201968273699 ow1= 3.5576444720822256
iw2= -4.657581417536363 b2= 1.9666811858714737 ow2= -2.4689440072523987
step 335
iw1= -2.3142640169903106 b1= 1.531593859658556 ow1= 3.567564681206958
iw2= -4.849503445203438 b2= 1.907771884377568 ow2= -2.459776331364632
step 336
iw1= -2.599158636607974 b1= 1.447728677372165 ow1= 3.55488160597817
iw2= -4.6533768012647 

step 385
iw1= -2.3562673456654597 b1= 1.60188579312523 ow1= 3.503787927153877
iw2= -4.839003851564941 b2= 1.869893817059688 ow2= -2.5700432305346173
step 386
iw1= -2.6352604854456887 b1= 1.5174538524205843 ow1= 3.4906799139752347
iw2= -4.634804749219402 b2= 1.9318148189163615 ow2= -2.583979512573909
step 387
iw1= -2.358940318627062 b1= 1.603603953362512 ow1= 3.501806174481008
iw2= -4.837452395340545 b2= 1.8692439445401952 ow2= -2.573861565464741
step 388
iw1= -2.6344297288460963 b1= 1.5201977196796728 ow1= 3.488857454598678
iw2= -4.63539108485613 b2= 1.9305450767466283 ow2= -2.587632729938437
step 389
iw1= -2.3615766663684306 b1= 1.605308940296484 ow1= 3.499895499216798
iw2= -4.835865904616102 b2= 1.8686279605120957 ow2= -2.577628208089759
step 390
iw1= -2.6336453682707104 b1= 1.522905552345938 ow1= 3.487102977475743
iw2= -4.6358997214181406 b2= 1.9293206389849078 ow2= -2.591237779601296
step 391
iw1= -2.364168230821424 b1= 1.6070057653864422 ow1= 3.4980548786385666
iw2= -4.83425232940

step 440
iw1= -2.6465560388747376 b1= 1.5762778531259203 ow1= 3.457521431527856
iw2= -4.616651007911185 b2= 1.9121963475441026 ow2= -2.6694579834215504
step 441
iw1= -2.4037778312542915 b1= 1.6528245171440836 ow1= 3.468295201897545
iw2= -4.802168832427308 b2= 1.8544282071971006 ow2= -2.660026915548176
step 442
iw1= -2.648428618749447 b1= 1.5778321123082897 ow1= 3.456604124022327
iw2= -4.614850152663416 b2= 1.9119014488873423 ow2= -2.672219310510425
step 443
iw1= -2.4042923165959253 b1= 1.6548266144632027 ow1= 3.467463297724511
iw2= -4.801655130407946 b2= 1.8537148382835134 ow2= -2.6627049319099045
step 444
iw1= -2.6503788399448096 b1= 1.5793464660982397 ow1= 3.4556873878664067
iw2= -4.613012445805475 b2= 1.911620776513592 ow2= -2.6749570059139827
step 445
iw1= -2.4047414487528758 b1= 1.6568333765805132 ow1= 3.4666378278051067
iw2= -4.801216087262116 b2= 1.8529808338455305 ow2= -2.6653530294274614
step 446
iw1= -2.6524005511824953 b1= 1.5808220882440285 ow1= 3.4547695988274802
iw2= -4.6

step 495
iw1= -2.4109721562044455 b1= 1.7006307471738766 ow1= 3.4444881788383817
iw2= -4.805787305268817 b2= 1.8320842334215226 ow2= -2.7232983870756446
step 496
iw1= -2.7000780757303833 b1= 1.6104313781342794 ow1= 3.4300696346247657
iw2= -4.57794587703719 b2= 1.902913585365578 ow2= -2.737429126797088
step 497
iw1= -2.4116142360855863 b1= 1.7018391530010928 ow1= 3.4435497398025734
iw2= -4.806049340116056 b2= 1.8314193451835612 ow2= -2.725331121467399
step 498
iw1= -2.700907691590944 b1= 1.6115331070241223 ow1= 3.429107514185972
iw2= -4.5778322206484034 b2= 1.902397593104354 ow2= -2.739468026582166
step 499
iw1= -2.412305041247995 b1= 1.7029980505945015 ow1= 3.442620481570495
iw2= -4.8062799983148015 b2= 1.8307865209789536 ow2= -2.7273430534042338
step 500
iw1= -2.7016277383081326 b1= 1.6126363012050065 ow1= 3.4281637356844468
iw2= -4.577813283914583 b2= 1.9018738819698764 ow2= -2.7414786816207974
step 501
iw1= -2.4130421338657633 b1= 1.70410911903581 ow1= 3.4417027386399663
iw2= -4.806

step 550
iw1= -2.705938234998149 b1= 1.6374406449678642 ow1= 3.4132015899996286
iw2= -4.586050591826592 b2= 1.8917430259570311 ow2= -2.7834435468171153
step 551
iw1= -2.434681551768936 b1= 1.723783925718993 ow1= 3.426562960672063
iw2= -4.805137895576036 b2= 1.8228550068587859 ow2= -2.7716368129055597
step 552
iw1= -2.7063394038175 b1= 1.6381945520266838 ow1= 3.4129188748254986
iw2= -4.586089905100057 b2= 1.8915740824718488 ow2= -2.784840385437548
step 553
iw1= -2.435258747562415 b1= 1.724496614983569 ow1= 3.426291169112322
iw2= -4.805161899763746 b2= 1.8226804847260416 ow2= -2.773026211852426
step 554
iw1= -2.706789574166943 b1= 1.6389244333706035 ow1= 3.412651100225099
iw2= -4.586090253576973 b2= 1.8914232486505191 ow2= -2.7862194092020345
step 555
iw1= -2.435796712037218 b1= 1.725213164276083 ow1= 3.426037915740778
iw2= -4.805216249806103 b2= 1.8225021104893095 ow2= -2.7743939152271935
step 556
iw1= -2.7072885239299023 b1= 1.6396302746659366 ow1= 3.4123969475128013
iw2= -4.5860533355