In [1]:
%matplotlib inline

In [14]:
from matplotlib import pyplot as plt
import numpy as np
from ipywidgets import interact
from IPython.display import clear_output

import lstm

# Print arrays with 5 decimal places and without scientific notation so the
# weight/gradient dumps further down are easier to scan.
np.set_printoptions(precision=5, suppress=True)

In [15]:
# Network dimensions.
X_DIM = 2   # each input step is a pair of bits: [num1[t], num2[t]]
H_DIM = 5   # width of the zero-initialised h state passed to net.ff
C_DIM = 5   # width of the zero-initialised c state passed to net.ff
Y_DIM = 1   # one output value per step (the sum bit)

# Number of steps fed to the network per training sequence.
TIME_STEPS = 5

# NOTE(review): act/act_p are left as None; presumably this selects the
# library's default activation and its derivative -- confirm in lstm.py.
net = lstm.LSTMNetwork(
    x_dim=X_DIM,
    h_dim=H_DIM,
    c_dim=C_DIM,
    y_dim=Y_DIM,
    time_steps=TIME_STEPS,
    act=None,
    act_p=None)

LSTM Network version 0.2


In [16]:
def data(n=10):
    """Generate one random binary-addition problem of ``n`` bits.

    Two random bit sequences are drawn and added bit by bit, index 0 first,
    propagating a carry from each position to the next (i.e. index 0 is the
    least significant bit). The carry left over after the last position is
    discarded.

    Parameters
    ----------
    n : int
        Number of bit positions to generate.

    Returns
    -------
    dict
        ``'num1'`` and ``'num2'``: float arrays of shape ``(n,)`` holding the
        operand bits (0.0 or 1.0); ``'targ'``: float array of shape ``(n,)``
        holding the sum bit at each position.
    """
    num1 = np.random.randint(0, 2, [n])
    num2 = np.random.randint(0, 2, [n])
    targ = np.zeros([n])

    carry = 0
    for i in range(n):
        # Full adder: total is in 0..3, low bit is the output, high bit the carry.
        total = num1[i] + num2[i] + carry
        targ[i] = total % 2
        carry = total // 2

    # The builtin float is used instead of the np.float alias, which was removed
    # from NumPy; the resulting dtype (float64) is the same.
    return {'num1': num1.astype(float), 'num2': num2.astype(float), 'targ': targ}

# Draw one 5-bit sample to eyeball the structure of the generated data.
data(n=5)

{'num1': array([ 1.,  1.,  1.,  0.,  0.]),
 'num2': array([ 1.,  0.,  0.,  0.,  1.]),
 'targ': array([ 0.,  0.,  0.,  1.,  1.])}

In [17]:
def test(net, batch=10, n=10, debug=False):
    """Print the percentage of random addition problems the network gets right.

    For each of ``batch`` freshly generated problems of ``n`` bits, the bit
    pairs are fed to ``net`` one step at a time. A problem counts as wrong as
    soon as one step's output differs from its target by more than 0.5; the
    remaining steps of that problem are then skipped.

    Parameters
    ----------
    net : network object exposing ``ff(xs, h, c)`` and ``outputs``.
    batch : int
        Number of problems to evaluate.
    n : int
        Number of bits per problem.
    debug : bool
        When true, print the targets and the outputs collected so far for
        every problem.

    Returns nothing; the accuracy is printed.
    """
    failed = 0
    for _ in range(batch):
        sample = data(n)
        # NOTE(review): the same zero h/c are passed on every step below and
        # never reassigned here. Unless net.ff updates them in place or keeps
        # state internally, no carry information flows between steps -- confirm.
        h = np.zeros([1, H_DIM])
        c = np.zeros([1, C_DIM])
        outs = []
        for bit1, bit2, target in zip(sample['num1'], sample['num2'], sample['targ']):
            step_input = np.array([[bit1, bit2]])
            net.ff([step_input], h, c)
            out = net.outputs[0].h[0, 0]
            outs.append(out)
            if abs(out - target) > .5:
                # First wrong bit fails the whole problem; stop early.
                failed += 1
                break
        if debug:
            print('targ: {}'.format(sample['targ']))
            print('out: {}'.format(outs))
    accuracy = (batch - failed) / float(batch) * 100.0
    print('correct: {}%'.format(accuracy))
            
        
# Accuracy with the default batch/length, run before the training loop below.
test(net)

correct: 0.0%


In [None]:
TS = range(TIME_STEPS)

# Online training: one freshly generated addition problem per iteration.
for i in range(1000000):
    d = data(n=TIME_STEPS)
    # One (1, 2) input row per time step: [num1[t], num2[t]].
    xs = [np.array([[d['num1'][t], d['num2'][t]]]) for t in TS]

    # Every sequence starts from a zero state.
    h0 = np.zeros([1, H_DIM])
    c0 = np.zeros([1, C_DIM])

    net.ff(xs, h0, c0)

    # Per-step error (output - target), handed to backprop.
    out = [net.outputs[t].h for t in TS]
    dys = [out[t] - d['targ'][t] for t in TS]

    net.bp(dys, learning_rate=0.0001)

    # Periodic progress report and divergence check.
    if i % 100000 == 0:
        clear_output(wait=True)
        print(i)
        test(net, batch=100, n=5)

        # Treat non-finite weights as diverged too: NaN compares False against
        # the magnitude threshold and would otherwise go unnoticed.
        diverged = any(
            (not np.all(np.isfinite(net.W[key]))) or np.max(np.abs(net.W[key])) > 50.0
            for key in net.W
        )
        if diverged:
            print('weights out of control')
            # Stop training altogether; continuing only blows the weights up further.
            break

0
correct: 2.0%


In [12]:
# Dump the gradient currently stored for each weight matrix, keyed by name.
for name in net.grad:
    grad = net.grad[name]
    print(name)
    print(grad)

i
[[ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]]
y
[[ 0.    ]
 [ 0.    ]
 [ 0.    ]
 [ 0.    ]
 [ 0.    ]
 [-2.0982]]
c
[[ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]]
o
[[ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]]
f
[[ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]]


In [13]:
# Dump every weight matrix rounded to two decimals, to spot exploding values.
for name in net.W:
    rounded = np.round(net.W[name], 2)
    print(name)
    print(rounded)

i
[[     235.62      234.2       235.66      235.56      235.57]
 [    -235.43     -234.28     -236.04     -237.32     -236.34]
 [    -235.48     -233.73     -235.68     -235.71     -235.63]
 [     235.06      233.05      235.72      236.21      235.26]
 [     235.08      233.23      235.75      236.05      235.64]
 [-1654669.29 -1638972.86 -1656006.56 -1659164.54 -1657647.55]
 [   13335.73    12679.66    13391.9     13521.99    13460.11]
 [-1640950.33 -1625927.76 -1642230.14 -1645258.95 -1643799.31]]
y
[[ 0.96]
 [-0.52]
 [ 0.3 ]
 [ 0.83]
 [-2.06]
 [ 0.48]]
c
[[    -235.67      234.38      235.69     -236.12     -236.17]
 [     236.21     -234.2      -235.95      235.93      235.6 ]
 [     235.78     -234.27     -235.83      236.06      235.99]
 [    -235.83      233.65      235.66     -235.72     -235.41]
 [    -235.6       234.4       235.55     -235.71     -236.44]
 [ 1654670.52 -1638972.96 -1656006.68  1659173.    1657646.13]
 [  -13334.6     12676.81    13391.93   -13524.62   -134

In [14]:
# Evaluate a single 5-bit problem and print its targets next to the raw outputs.
test(net, batch=1, n=5, debug=True)

targ: [ 1.  1.  1.  1.  0.]
out: [0.54568293717596394, 0.54568293717596394, 0.54568293717596394, 0.54568293717596394, 0.54568293717596394]
correct: 0.0%
