In [1]:
%matplotlib inline

In [2]:
from matplotlib import pyplot as plt
import numpy as np
from ipywidgets import interact
from IPython.display import clear_output

import lstm

# Show arrays with 5 digits of precision and suppress scientific notation
# for small values, which keeps the weight/gradient dumps below readable.
np.set_printoptions(precision=5, suppress=True)

In [3]:
# Network dimensions.
#   X_DIM: two inputs per time step (one bit from each operand).
#   H_DIM: size used for the initial h / c state arrays passed to net.ff.
#   Y_DIM: one output per time step (the predicted sum bit).
X_DIM, H_DIM, Y_DIM = 2, 3, 1

# Number of bit positions per training sequence.
TIME_STEPS = 5

# NOTE(review): what act=None / act_p=None select is not visible from this
# notebook -- confirm against lstm.LSTMNetwork.
net = lstm.LSTMNetwork(
    x_dim=X_DIM,
    h_dim=H_DIM,
    y_dim=Y_DIM,
    time_steps=TIME_STEPS,
    act=None,
    act_p=None,
)

LSTM Network version 0.3


In [4]:
def data(n=10):
    """Generate one random binary-addition example with ``n`` bit positions.

    Two random bit sequences are added position by position, with the carry
    propagating from index 0 upward (index 0 is the least-significant bit).
    The carry out of the last position is discarded.

    Parameters
    ----------
    n : int, optional
        Number of bit positions (time steps) to generate.

    Returns
    -------
    dict
        ``'num1'`` and ``'num2'``: float arrays of shape ``(n,)`` holding the
        operand bits; ``'targ'``: float array of shape ``(n,)`` holding the
        sum bit at each position.
    """
    num1 = np.random.randint(0, 2, [n])
    num2 = np.random.randint(0, 2, [n])
    targ = np.zeros([n])

    carry = 0
    for i in range(n):
        # The column total is in 0..3: its high bit is the carry into the
        # next position, its low bit is the sum bit for this position.
        total = int(num1[i]) + int(num2[i]) + carry
        carry, bit = divmod(total, 2)
        targ[i] = bit

    # The builtin ``float`` is used because the ``np.float`` alias was removed
    # from NumPy (1.24+); both map to float64.
    return {'num1': num1.astype(float), 'num2': num2.astype(float), 'targ': targ}

# Display one 5-bit sample to eyeball the generator's output.
data(n=5)

{'num1': array([ 1.,  1.,  0.,  0.,  1.]),
 'num2': array([ 1.,  0.,  0.,  0.,  0.]),
 'targ': array([ 0.,  0.,  1.,  0.,  1.])}

In [5]:
def test(net, batch=10, n=10, debug=False):
    """Evaluate ``net`` on freshly generated addition sequences and print accuracy.

    For each of ``batch`` sequences of ``n`` steps, the inputs are fed to
    ``net.ff`` one step at a time and ``net.outputs[0].h[0, 0]`` is compared
    with the target bit. A sequence counts as wrong at the first step whose
    output is more than 0.5 away from the target; the remaining steps of that
    sequence are not evaluated.

    Parameters
    ----------
    net : network object exposing ``ff`` and ``outputs``
    batch : int
        Number of sequences to evaluate.
    n : int
        Number of steps per sequence.
    debug : bool
        When true, print the targets and the outputs collected for each
        sequence (outputs stop at the first wrong step).

    Prints the percentage of sequences with no wrong step; returns nothing.
    """
    failures = 0
    for _ in range(batch):
        sample = data(n)
        # NOTE(review): every step is fed as a one-element sequence together
        # with these same zero arrays, so unless net.ff keeps state internally
        # or updates h/c in place, no carry information reaches later steps.
        # Confirm against lstm.LSTMNetwork.ff.
        h_init = np.zeros([1, H_DIM])
        c_init = np.zeros([1, H_DIM])

        predictions = []
        steps = zip(sample['num1'], sample['num2'], sample['targ'])
        for bit1, bit2, target in steps:
            step_input = np.array([[bit1, bit2]])
            net.ff([step_input], h_init, c_init)
            prediction = net.outputs[0].h[0, 0]
            predictions.append(prediction)
            if abs(prediction - target) > .5:
                # First mistake decides the whole sequence.
                failures += 1
                break

        if debug:
            print('targ: {}'.format(sample['targ']))
            print('out: {}'.format(predictions))

    accuracy = (batch - failures) / float(batch) * 100.0
    print('correct: {}%'.format(accuracy))
            
        
# Baseline accuracy of the freshly constructed network, before the training cell runs.
test(net)

correct: 0.0%


In [6]:
# Training settings, named so they can be tuned in one place.
N_ITERATIONS = 1000000
LEARNING_RATE = 0.0001
REPORT_EVERY = 10000    # iterations between progress reports
EVAL_BATCH = 100        # sequences evaluated per progress report
WEIGHT_LIMIT = 50.0     # absolute weight value treated as divergence

TS = range(TIME_STEPS)

for i in range(N_ITERATIONS):
    # One fresh random addition problem per iteration.
    d = data(n=TIME_STEPS)
    xs = [np.array([[d['num1'][t], d['num2'][t]]]) for t in TS]

    # Zero initial state, re-created every iteration.
    h0 = np.zeros([1, H_DIM])
    c0 = np.zeros([1, H_DIM])

    net.ff(xs, h0, c0)

    # Per-step difference between network output and target, passed to bp.
    out = [net.outputs[t].h for t in TS]
    dys = [out[t] - d['targ'][t] for t in TS]

    net.bp(dys, learning_rate=LEARNING_RATE)

    if i % REPORT_EVERY == 0:
        clear_output(wait=True)
        print(i)
        # Evaluate on the same sequence length the network is trained on
        # (previously hard-coded as 5, which only matched by coincidence).
        test(net, batch=EVAL_BATCH, n=TIME_STEPS)
        # Warn once per report if any weight matrix has blown up. Training
        # continues regardless; stop the cell manually if this appears.
        if any(np.max(np.abs(net.W[key])) > WEIGHT_LIMIT for key in net.W):
            print('weights out of control')

20000
correct: 6.0%


KeyboardInterrupt: 

In [7]:
# Dump the contents of net.grad: each key on its own line, followed by its array.
for name in net.grad:
    print(name, net.grad[name], sep='\n')

i
[[-0.00793 -0.00051 -0.0143 ]
 [-0.00084  0.0023  -0.00482]
 [-0.00903 -0.00283  0.00648]
 [-0.03829  0.00538 -0.03374]
 [-0.01391 -0.01334 -0.01574]
 [-0.10209 -0.0219  -0.00932]]
y
[[ 0.17377]
 [-0.03495]
 [-0.04651]
 [ 0.63749]]
c
[[-0.00233 -0.00533  0.0161 ]
 [-0.00043  0.00022  0.00281]
 [-0.00403  0.00175  0.00375]
 [-0.01305  0.00423  0.02744]
 [-0.00142 -0.01386  0.01824]
 [-0.06521 -0.0025   0.06935]]
o
[[-0.00267  0.00647 -0.01501]
 [-0.00088  0.00292 -0.0001 ]
 [-0.01488 -0.00058 -0.01359]
 [-0.07886  0.04985 -0.1069 ]
 [-0.01183  0.01298 -0.05892]
 [-0.0646  -0.0142   0.00313]]
f
[[-0.00815 -0.00324  0.00702]
 [-0.00025 -0.00228  0.00522]
 [-0.00678  0.00452 -0.01492]
 [-0.02663  0.00064 -0.0057 ]
 [-0.01729  0.00533 -0.00825]
 [-0.01821 -0.00624  0.01604]]


In [8]:
# Dump every weight matrix in net.W, rounded to 2 decimals for readability.
for name in net.W:
    rounded = np.round(net.W[name], 2)
    print(name, rounded, sep='\n')

i
[[-0.86  0.55 -0.44]
 [-0.35 -0.04  0.69]
 [-0.03  0.79  0.75]
 [-0.32  0.45  1.07]
 [-0.03  0.5   0.04]
 [ 0.02 -0.16 -0.44]]
y
[[-0.25]
 [ 0.11]
 [ 0.45]
 [ 0.66]]
c
[[ 0.78  0.8  -0.71]
 [ 0.43  0.07 -0.62]
 [-0.99  0.77 -0.72]
 [ 0.93  0.78  0.02]
 [ 0.78  0.92 -1.03]
 [ 0.9  -0.37  0.27]]
o
[[-0.84 -0.51 -0.18]
 [ 0.98  0.96  0.02]
 [ 0.75  0.35 -0.8 ]
 [ 0.12 -0.29  0.86]
 [ 0.1  -0.14  0.31]
 [ 0.85 -0.37 -0.48]]
f
[[ 0.61  0.85  0.69]
 [-0.3   0.95 -0.1 ]
 [ 0.64 -0.03  0.12]
 [-0.81  0.56  0.8 ]
 [ 0.73 -0.77  0.71]
 [-0.18 -0.69  0.27]]


In [9]:
# Inspect a single 5-step sequence; debug=True prints the targets and the
# outputs collected up to the first wrong step.
test(net, batch=1, n=5, debug=True)

targ: [ 0.  1.  0.  1.  0.]
out: [0.61103804266529749]
correct: 0.0%
