In [1]:
%matplotlib inline



In [2]:
from matplotlib import pyplot as plt
import numpy as np
from ipywidgets import interact
from IPython.display import clear_output

import lstm

# Print arrays with 5 digits of precision and suppress scientific notation
# for small values, so the weight/gradient dumps below are easy to scan.
np.set_printoptions(precision=5, suppress=True)

In [11]:
# Sizes handed to the network constructor as x_dim / h_dim / y_dim.
# The inputs fed in later cells are the two operand bits of one position.
X_DIM = 2
H_DIM = 3
Y_DIM = 1

# Number of time steps the network is constructed with.
TIME_STEPS = 3

# Build the network from the project-local `lstm` module.
# NOTE(review): `act_p` is presumably the derivative of `act` (sigmoid) —
# confirm against the lstm module.
net = lstm.LSTMNetwork(x_dim=X_DIM, 
                       h_dim=H_DIM, 
                       y_dim=Y_DIM, 
                       time_steps=TIME_STEPS, 
                       act=lstm.sigmoid, 
                       act_p=lstm.sigmoid_p)

LSTM Network version 0.3


In [12]:
def data(n=10):
    """Generate one random binary-addition problem with ``n`` bit positions.

    Two random bit sequences are drawn and added position by position.
    Index 0 is the least-significant bit: the carry produced at position
    ``i`` is fed into position ``i + 1``. A carry out of the last position
    is discarded.

    Parameters
    ----------
    n : int
        Number of bit positions (sequence length). ``n=0`` yields empty arrays.

    Returns
    -------
    dict
        ``'num1'``, ``'num2'``: float arrays of shape ``(n,)`` with the operand bits.
        ``'targ'``: float array of shape ``(n,)`` with the sum bit of each position.
    """
    num1 = np.random.randint(0, 2, [n])
    num2 = np.random.randint(0, 2, [n])
    targ = np.zeros([n])
    carry = 0
    for i in range(n):
        # Full adder: the total is in 0..3; its low bit is the output bit and
        # its high bit is the carry into the next position.
        carry, targ[i] = divmod(int(num1[i]) + int(num2[i]) + carry, 2)

    # `np.float` was removed in NumPy 1.24; the builtin `float` selects the
    # same float64 dtype.
    return {'num1': num1.astype(float), 'num2': num2.astype(float), 'targ': targ}

# Show one generated sample of length 5 as the cell's output.
data(n=5)

{'num1': array([ 0.,  1.,  0.,  0.,  1.]),
 'num2': array([ 0.,  0.,  1.,  1.,  0.]),
 'targ': array([ 0.,  1.,  1.,  1.,  1.])}

In [13]:
def test(net, batch=10, n=10, debug=False):
    """Score ``net`` on ``batch`` freshly generated addition sequences of length ``n``.

    A sequence counts as wrong as soon as one step's output is more than 0.5
    away from its target bit; the remaining steps of that sequence are then
    skipped. The percentage of sequences without such a miss is printed.
    Nothing is returned.

    With ``debug=True`` the targets and the outputs collected so far are
    printed for every sequence.

    NOTE(review): each step passes the same all-zero ``h``/``c`` arrays, which
    this function never reassigns, and reads only ``net.outputs[0]``. Unless
    ``lstm`` keeps state between ``ff`` calls, the carry cannot be remembered
    from one step to the next — confirm against the lstm module.
    """
    failures = 0
    for _ in range(batch):
        sample = data(n)
        h = np.zeros([1, H_DIM])
        c = np.zeros([1, H_DIM])
        seen = []
        for step in range(n):
            # One input row holding the two operand bits of this position.
            x_t = np.array([[sample['num1'][step], sample['num2'][step]]])
            net.ff([x_t], h, c)
            y_t = net.outputs[0].h[0, 0]
            seen.append(y_t)
            if abs(y_t - sample['targ'][step]) > .5:
                # First miss marks the whole sequence as failed.
                failures += 1
                break
        if debug:
            print('targ: {}'.format(sample['targ']))
            print('out: {}'.format(seen))
    accuracy = (batch - failures) / float(batch) * 100.0
    print('correct: {}%'.format(accuracy))
            
        
# Evaluate the network with the default batch size and sequence length.
test(net)

correct: 0.0%


In [15]:
# Time-step indices reused by the comprehensions below.
TS = range(TIME_STEPS)

for i in range(100000):
    # Draw a fresh random addition problem with one bit pair per time step.
    d = data(n=TIME_STEPS)
    # One input row per step: [bit of num1, bit of num2].
    xs = [np.array([[d['num1'][t], d['num2'][t]]]) for t in TS]

    # All-zero initial state arrays handed to the forward pass.
    h0 = np.zeros([1, H_DIM])
    c0 = np.zeros([1, H_DIM])
    
    net.ff(xs, h0, c0)
    
    # Per-step error signal: the step's `h` output minus its target bit.
    # NOTE(review): `h` presumably has H_DIM columns, so the scalar target is
    # broadcast across all of them — confirm this is what `bp` expects.
    out = [net.outputs[t].h for t in TS]
    dys = [out[t] - d['targ'][t] for t in TS]
        
    net.bp(dys, learning_rate=0.01)
    
    # Every 10000 iterations: replace the previous progress output, evaluate,
    # and scan the weights for large magnitudes.
    if i % 10000 == 0:
        clear_output(wait=True)
        print(i)
        # Evaluation uses sequences of length 5, while training uses TIME_STEPS.
        test(net, batch=100, n=5)
        for key in net.W:
            if np.max(np.abs(net.W[key])) > 50.0:
                print('weights out of control')
                # This break only leaves the weight scan; training continues.
                break

80000
correct: 32.0%


KeyboardInterrupt: 

In [8]:
# Dump every entry of net.grad under its key for inspection.
for key in net.grad:
    print(key)
    print(net.grad[key])

i
[[-0.00001 -0.00143  0.00105]
 [-0.00005 -0.01349  0.0029 ]
 [ 0.00003  0.0143  -0.0002 ]
 [ 0.00019  0.03307 -0.01427]
 [ 0.00012  0.0058  -0.01378]
 [ 0.00019  0.03307 -0.01427]]
y
[[-0.00648]
 [-0.07301]
 [ 0.08577]
 [ 0.16082]]
c
[[ 0.00005  0.00284  0.00201]
 [ 0.0002   0.01211  0.02267]
 [-0.00011 -0.0068  -0.02556]
 [-0.00076 -0.04472 -0.0518 ]
 [-0.00053 -0.03153 -0.00313]
 [-0.00076 -0.04472 -0.0518 ]]
o
[[-0.00061 -0.00546  0.00424]
 [-0.00324 -0.01879  0.0204 ]
 [ 0.00241  0.0065  -0.0136 ]
 [ 0.01049  0.07953 -0.07002]
 [ 0.00586  0.06666 -0.04379]
 [ 0.01049  0.07953 -0.07002]]
f
[[-0.00002 -0.00091  0.00048]
 [-0.0001  -0.00649  0.00601]
 [ 0.00006  0.00603 -0.00697]
 [ 0.00037  0.01804 -0.01325]
 [ 0.00025  0.00651  0.00002]
 [ 0.00037  0.01804 -0.01325]]


In [9]:
# Dump every entry of net.W under its key, rounded to 2 decimals.
for key in net.W:
    print(key)
    print(np.round(net.W[key], 2))

i
[[-0.66 -0.36 -0.3 ]
 [-0.41 -0.64 -0.74]
 [-0.69 -1.15  1.88]
 [-0.13  0.45  4.77]
 [ 0.84  2.35 -0.61]
 [ 2.07  0.02 -2.12]]
y
[[-1.8 ]
 [-3.46]
 [-3.2 ]
 [ 0.15]]
c
[[-1.63 -1.19  0.53]
 [-1.06 -0.3  -0.16]
 [-1.9  -2.12  0.81]
 [-0.74 -0.7   0.33]
 [-1.17 -1.    2.6 ]
 [ 0.13  0.19 -0.33]]
o
[[-1.36 -0.35 -1.38]
 [ 0.74 -0.63  0.74]
 [-2.01  1.32 -1.79]
 [-2.85  1.47 -1.27]
 [-1.46 -0.39 -1.19]
 [ 1.26 -0.66  3.54]]
f
[[-0.07  0.77 -0.31]
 [ 0.42 -0.95  0.65]
 [ 0.16 -0.36 -0.1 ]
 [ 1.17 -0.35  2.05]
 [-0.54  0.67  2.82]
 [ 0.51 -1.85 -1.38]]


In [10]:
# Run a single length-5 sequence with debug=True to print the targets next to
# the raw outputs.
test(net, batch=1, n=5, debug=True)

targ: [ 1.  0.  1.  0.  0.]
out: [0.97847817931059511, -0.012380963869431161, 0.97847817931059511, -0.024924521375988551, 0.97847817931059511]
correct: 0.0%
