In [None]:
# Render matplotlib figures inline in the notebook output.
# NOTE(review): none of the visible cells plot anything; this (and the pyplot import) may be removable.
%matplotlib inline

In [None]:
from matplotlib import pyplot as plt
import numpy as np
from ipywidgets import interact
from IPython.display import clear_output

import lstm

# Print arrays with 5 decimals in fixed-point notation (no scientific notation for tiny values).
np.set_printoptions(suppress=True, precision=5)

In [None]:
# Problem size: each time step feeds one bit from each addend (X_DIM = 2)
# and the network emits one sum bit (Y_DIM = 1). The cell-state width is
# tied to the hidden-state width.
X_DIM = 2
H_DIM = 3
C_DIM = H_DIM
Y_DIM = 1

# Sequence length (number of bit positions) used for training.
TIME_STEPS = 5

# Seed NumPy's global RNG so the data drawn by data() is reproducible under
# Restart & Run All. Presumably lstm also draws its initial weights from
# np.random; TODO confirm in lstm.
SEED = 0
np.random.seed(SEED)

# NOTE(review): act/act_p=None presumably select lstm's default activation
# and its derivative; confirm in lstm.LSTMNetwork.
net = lstm.LSTMNetwork(x_dim=X_DIM,
                       h_dim=H_DIM,
                       y_dim=Y_DIM,
                       c_dim=C_DIM,
                       time_steps=TIME_STEPS,
                       act=None,
                       act_p=None)

In [None]:
def data(n=10):
    """Generate one random binary-addition example with ``n`` bit positions.

    Two random bit strings are drawn and added with carry, position 0 first:
    the carry produced at position ``i`` feeds position ``i + 1``, so index 0
    is the least significant bit. The carry out of the last position is
    dropped, i.e. ``targ`` encodes ``(num1 + num2) mod 2**n``.

    Parameters
    ----------
    n : int
        Number of bit positions (sequence length). ``0`` yields empty arrays.

    Returns
    -------
    dict
        ``'num1'``, ``'num2'``: float arrays of shape ``(n,)`` with the random
        0/1 addend bits; ``'targ'``: float array of shape ``(n,)`` with the
        corresponding sum bits.
    """
    num1 = np.random.randint(0, 2, [n])
    num2 = np.random.randint(0, 2, [n])
    targ = np.zeros([n])
    carry = 0
    for i in range(n):
        s = num1[i] + num2[i] + carry
        # s is in {0, 1, 2, 3}: the sum bit is its parity, the new carry its high bit.
        targ[i] = s % 2
        carry = s // 2
    # `np.float` was removed in NumPy 1.24; the builtin `float` maps to the same float64 dtype.
    return {'num1': num1.astype(float), 'num2': num2.astype(float), 'targ': targ}

# Preview one small random example (displayed as the cell's output).
data(5)

In [None]:
def test(net, batch=10, n=10, debug=False):
    """Print the percentage of random examples the network gets fully right.

    For each of ``batch`` examples from ``data(n)``, the network is fed one
    bit pair per ``net.ff`` call and its output ``net.outputs[0].h[0, 0]`` is
    compared with the target bit. An example counts as wrong as soon as one
    prediction is more than 0.5 away from its target; the remaining steps of
    that example are then skipped.

    NOTE(review): ``h`` and ``c`` are fresh zeros passed to every ``net.ff``
    call and are never updated from the network's outputs, so unless
    ``net.ff`` mutates them in place (TODO confirm in lstm) no state, and in
    particular no carry, flows between steps. Training instead feeds the
    whole sequence in a single ``net.ff`` call.

    Parameters
    ----------
    net : object
        Network under evaluation (the notebook passes an ``lstm.LSTMNetwork``).
    batch : int
        Number of random examples to evaluate.
    n : int
        Bit positions per example.
    debug : bool
        If True, print the targets and the collected predictions per example.
    """
    failures = 0
    for _ in range(batch):
        sample = data(n)
        h_init = np.zeros([1, H_DIM])
        c_init = np.zeros([1, C_DIM])
        predictions = []
        for step in range(n):
            bit_pair = np.array([[sample['num1'][step], sample['num2'][step]]])
            net.ff([bit_pair], h_init, c_init)
            prediction = net.outputs[0].h[0, 0]
            predictions.append(prediction)
            if abs(prediction - sample['targ'][step]) > .5:
                failures += 1
                break
        if debug:
            print('targ: {}'.format(sample['targ']))
            print('out: {}'.format(predictions))
    accuracy = (batch - failures) / float(batch) * 100.0
    print('correct: {}%'.format(accuracy))
            
        
# Baseline accuracy before the training cell runs: 10 random 10-bit examples.
test(net, batch=10, n=10)

In [None]:
# Train on a fresh random TIME_STEPS-bit example per iteration, reporting
# progress periodically and stopping early if the weights diverge.
NUM_ITERATIONS = 1000000
LEARNING_RATE = 0.0001
REPORT_EVERY = 10000
MAX_ABS_WEIGHT = 50.0  # any |weight| above this is treated as divergence

TS = range(TIME_STEPS)

for i in range(NUM_ITERATIONS):
    d = data(n=TIME_STEPS)
    xs = [np.array([[d['num1'][t], d['num2'][t]]]) for t in TS]

    h0 = np.zeros([1, H_DIM])
    c0 = np.zeros([1, C_DIM])

    # The whole sequence goes through the unrolled network in one call.
    net.ff(xs, h0, c0)

    # Per-step output error (prediction - target) handed to backprop.
    out = [net.outputs[t].h for t in TS]
    dys = [out[t] - d['targ'][t] for t in TS]

    net.bp(dys, learning_rate=LEARNING_RATE)

    if i % REPORT_EVERY == 0:
        clear_output(wait=True)
        print(i)
        test(net, batch=100, n=5)
        # Check finiteness too: NaN weights would slip past a max-abs test
        # because `nan > MAX_ABS_WEIGHT` is False.
        diverged = [key for key in net.W
                    if not np.all(np.isfinite(net.W[key]))
                    or np.max(np.abs(net.W[key])) > MAX_ABS_WEIGHT]
        if diverged:
            print('weights out of control')
            # Leave the training loop itself; a `break` inside a scan over
            # net.W would only stop the scan and let training continue.
            break

In [None]:
# Dump every entry of net.grad (presumably the gradients left by the most
# recent net.bp call; confirm in lstm), name first, then values.
for param_name in net.grad:
    print(param_name)
    print(net.grad[param_name])

In [None]:
# Dump every weight matrix in net.W, rounded to 2 decimals for readability.
for param_name in net.W:
    print(param_name)
    print(np.round(net.W[param_name], 2))

In [None]:
# Inspect a single 5-bit example: print its targets next to the network's predictions.
test(net, debug=True, n=5, batch=1)