In [1]:
%matplotlib inline

In [2]:
from matplotlib import pyplot as plt
import numpy as np
from ipywidgets import interact
from IPython.display import clear_output

import lstm

# Print arrays with 5 digits of precision and without scientific notation,
# which keeps the gradient and weight dumps further down readable.
np.set_printoptions(precision=5, suppress=True)

In [136]:
# Network sizes. The input width is 2 because each time step feeds the two
# operand bits; the output width is 1 (the sum bit). H_DIM is also the width
# of the zero-initialised h/c state vectors built elsewhere in this notebook.
X_DIM, H_DIM, Y_DIM = 2, 3, 1

# Number of time steps the network is constructed with (and trained on).
TIME_STEPS = 5

# Sigmoid is passed together with `sigmoid_p` (presumably its derivative --
# confirm in the lstm module).
net = lstm.LSTMNetwork(
    x_dim=X_DIM,
    h_dim=H_DIM,
    y_dim=Y_DIM,
    time_steps=TIME_STEPS,
    act=lstm.sigmoid,
    act_p=lstm.sigmoid_p,
)

LSTM Network version 0.3


In [137]:
def data(n=10):
    """Generate one random binary-addition example with ``n`` bit positions.

    Two random bit strings are drawn and added with ripple carry. The carry
    propagates from index ``i`` to index ``i + 1``, so index 0 is the
    least-significant bit. The carry out of the last position is discarded.

    Parameters
    ----------
    n : int
        Number of bit positions (sequence length).

    Returns
    -------
    dict
        ``'num1'`` and ``'num2'``: float64 arrays of shape ``(n,)`` holding
        the operand bits as 0.0 / 1.0.
        ``'targ'``: float64 array of shape ``(n,)`` holding the sum bit at
        each position.
    """
    # Draw order (num1 first, then num2) is kept so seeded runs reproduce.
    num1 = np.random.randint(0, 2, [n])
    num2 = np.random.randint(0, 2, [n])
    targ = np.zeros([n])

    carry = 0
    for i in range(n):
        # Full adder: the column total (0..3) splits into carry and sum bit.
        carry, bit = divmod(int(num1[i]) + int(num2[i]) + carry, 2)
        targ[i] = bit

    # The `np.float` alias was removed in NumPy 1.24; the builtin `float`
    # yields the same float64 dtype on every NumPy version.
    return {'num1': num1.astype(float), 'num2': num2.astype(float), 'targ': targ}

# Display one generated example as a quick sanity check of the generator.
data(n=TIME_STEPS)

{'num1': array([ 1.,  1.,  0.,  1.,  1.]),
 'num2': array([ 1.,  0.,  0.,  0.,  0.]),
 'targ': array([ 0.,  0.,  1.,  1.,  1.])}

In [138]:
def test(net, batch=10, n=10, debug=False):
    """Print the percentage of random sequences the network gets fully right.

    For each of ``batch`` freshly generated examples the network is run one
    step at a time: ``net.ff`` receives a single-element input list together
    with the h/c state read back from ``net.units[0]`` after the previous
    step (zeros at the start of every sequence). A sequence counts as wrong
    at the first step whose output is more than 0.5 away from the target,
    and the remaining steps of that sequence are skipped.

    Parameters
    ----------
    net : network object exposing ``ff``, ``outputs`` and ``units``
    batch : int
        Number of sequences to evaluate.
    n : int
        Length of each sequence.
    debug : bool
        When true, print the targets and the outputs collected so far for
        every sequence.

    Returns
    -------
    None
        The result is printed, not returned.
    """
    failed = 0
    for _ in range(batch):
        sample = data(n)
        hidden = np.zeros([1, H_DIM])
        cell = np.zeros([1, H_DIM])
        predictions = []
        for step in range(n):
            x_t = np.array([[sample['num1'][step], sample['num2'][step]]])
            net.ff([x_t], hidden, cell)
            prediction = net.outputs[0].h[0, 0]
            # Carry the recurrent state over to the next single-step call.
            hidden = net.units[0].h
            cell = net.units[0].c
            predictions.append(prediction)
            if abs(prediction - sample['targ'][step]) > .5:
                failed += 1
                break
        if debug:
            print('targ: {}'.format(sample['targ']))
            print('out: {}'.format(predictions))
    print('correct: {}%'.format( (batch-failed)/float(batch) * 100.0 ))
            
        
# Baseline accuracy with the default settings (batch=10, n=10); in notebook
# order this runs before the training cell below.
test(net)

correct: 0.0%


In [154]:
# Time-step indices, reused by the comprehensions below.
TS = range(TIME_STEPS)

# Online training: one freshly generated sequence per iteration.
for i in range(100000):
    d = data(n=TIME_STEPS)
    # One (1, 2) input row per time step: the two operand bits at position t.
    xs = [np.array([[d['num1'][t], d['num2'][t]]]) for t in TS]

    # Every sequence starts from zero hidden and cell state.
    h0 = np.zeros([1, H_DIM])
    c0 = np.zeros([1, H_DIM])
    
    net.ff(xs, h0, c0)
    
    # Per-step error signal handed to bp: output minus target.
    # (Presumably the gradient of a squared-error loss -- confirm in lstm.bp.)
    out = [net.outputs[t].h for t in TS]
    dys = [out[t] - d['targ'][t] for t in TS]
        
    net.bp(dys, learning_rate=0.05)
    
    # Progress report every 10000 iterations (including i == 0); the previous
    # report is cleared so only the latest one stays visible.
    if i % 10000 == 0:
        clear_output(wait=True)
        print(i)
        # Evaluated on 10-step sequences, i.e. longer than the TIME_STEPS
        # used for training.
        test(net, batch=100, n=10)
        for key in net.W:
            if np.max(np.abs(net.W[key])) > 50.0:
                print('weights out of control')
                # NOTE(review): this break only leaves the weight scan, so the
                # warning prints once per report and training continues --
                # confirm that stopping training was not the intent.
                break

90000
correct: 100.0%


In [143]:
# Dump the gradient array stored under each key of net.grad.
# (Presumably these are the values left by the most recent net.bp call --
# confirm in the lstm module.)
for key in net.grad:
    print(key)
    print(net.grad[key])

i
[[-0.00005  0.00004  0.00005]
 [ 0.00002  0.00001 -0.00003]
 [-0.00005  0.00003  0.00003]
 [ 0.00004  0.00001 -0.00007]
 [ 0.00004 -0.00007 -0.00007]
 [-0.00024  0.00017 -0.00024]]
y
[[-0.00065]
 [-0.00042]
 [ 0.00115]
 [ 0.00168]]
c
[[ 0.00041 -0.00022  0.00019]
 [ 0.00019  0.00026 -0.00027]
 [-0.00021 -0.00005  0.00008]
 [-0.00052  0.00057 -0.00057]
 [-0.00004  0.00029 -0.00005]
 [-0.00073  0.00058 -0.00062]]
o
[[-0.00008  0.00008  0.00021]
 [ 0.00001  0.00003  0.0001 ]
 [-0.00005  0.00007 -0.00013]
 [ 0.00006  0.00003 -0.00027]
 [ 0.00009 -0.00012 -0.00006]
 [-0.00027  0.00034 -0.00073]]
f
[[-0.00033 -0.00012 -0.00015]
 [ 0.00012  0.0001   0.0001 ]
 [-0.00003 -0.00007 -0.00017]
 [ 0.00053  0.00019  0.00013]
 [ 0.00032  0.00018  0.00021]
 [ 0.00053  0.00019  0.00013]]


In [144]:
# Dump every weight matrix in net.W, rounded to 2 decimals for readability.
for key in net.W:
    print(key)
    print(np.round(net.W[key], 2))

i
[[-0.31 -0.59  0.11]
 [ 0.38 -1.09  0.77]
 [-0.05  0.61 -0.82]
 [ 1.09  0.55  3.05]
 [ 1.51  1.33 -1.99]
 [ 0.75  1.75  2.61]]
y
[[ -5.72]
 [  4.73]
 [-10.89]
 [  2.88]]
c
[[-0.8   0.51 -1.42]
 [-0.9   1.7  -1.62]
 [ 0.12 -0.93  0.75]
 [-2.55 -2.81 -2.78]
 [-1.49 -2.97  3.21]
 [ 1.41  3.29  2.7 ]]
o
[[-0.51 -0.68  1.46]
 [ 0.2  -1.37  5.98]
 [ 0.22  0.87 -2.45]
 [ 1.12  0.3   5.18]
 [ 1.57  1.87 -5.34]
 [ 0.93  1.41  2.05]]
f
[[-0.25 -0.32  0.31]
 [-0.    0.5   0.3 ]
 [ 0.36  0.23  0.95]
 [-0.05 -2.67  0.16]
 [-0.32  1.88 -0.46]
 [-0.91 -1.54 -1.34]]


In [153]:
# Length-generalisation check: 100 sequences of 100 steps each, far longer
# than the TIME_STEPS-step sequences used in the training cell.
test(net, batch=100, n=100, debug=False)

correct: 100.0%
