In [10]:
from tools.numerical_gradient import *
from models.layers import *
from models.networks.vanilla_rnn import *
from models.networks.lstm_rnn import *
from models.solver.solver import *
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import h5py
import nltk
import re
import pickle
import copy

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

In [11]:
# ---------------------------------------------------------------------------
# Constants: on-disk artifact paths and text-processing settings
# ---------------------------------------------------------------------------

# Artifacts read by this notebook (paths relative to the working directory).
word_dataset_dest = "word_dataset.hdf5"    # HDF5 file holding the "bible" training matrix
idx_mapping_dest = "idx_map.pkl"           # pickled index -> word lookup table

# Other artifacts; not read by the cells in this notebook section -- presumably
# produced/consumed by the preprocessing step.
word_sequence_dest = "word_sequence.hdf5"
word_mapping_dest = "word_map.pkl"
cache_model_dest = "cache_model.pkl"

seq_len = 10
# Regex alternation of token separators (space, tab, CRLF); presumably used
# when tokenizing the raw text -- not referenced in this section.
delims = ' |\t|\r\n'

In [12]:
# Load the pre-tokenized training matrix and the index -> word lookup table.
with h5py.File(word_dataset_dest, 'r') as f:
    # Slice with [:] so the data is copied into memory before the file closes.
    bible = f["bible"][:]
# Pickles are binary: open in 'rb'. Text mode can mangle the byte stream
# (newline translation on Windows) and is rejected outright by Python 3.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(idx_mapping_dest, 'rb') as f:
    idx_mapping = pickle.load(f)
# Largest word index occurring in the training data; the vocabulary size is
# derived from it in the hyperparameter cell.
max_idx = np.max(bible)

In [13]:
### Hyperparameters ###
N,T = bible.shape
num_samples = 256
num_words = max_idx+1
time_dim = 10
hidden_dim = 200
word_vec_dim = 200
print num_words
curPtr = 0

rnn = LSTM_RNN(num_samples, num_words, time_dim, hidden_dim, word_vec_dim, l2_lambda=0.0001) # activate regularization
solver = Solver({"learning_rate" : 5e-3, "type" : "adam",
                 "beta1" : 0.9, "beta2" : 0.99}) # adagrad/adam can sustain higher learning rates

4094


In [8]:
def populate_prediction(d, loss, seed=860, seq_len=100, model=None, idx_to_word=None):
    """Generate a phrase from the model and record it in ``d`` under ``loss``.

    Args:
        d: dict that receives the phrase; an existing entry for ``loss`` is
            overwritten.
        loss: key to store the phrase under (the training loop passes int(loss)).
        seed: word index passed to ``model.predict`` as the starting token.
            NOTE(review): the original comment said this seeds with "God", but
            the prediction cell makes the same claim for 852 -- check
            idx_mapping[860].
        seq_len: forwarded to ``model.predict`` as ``seq_len``.
        model: object with ``predict(seed, seq_len=...)`` returning word
            indices; defaults to the notebook-global ``rnn``.
        idx_to_word: index -> word mapping; defaults to the notebook-global
            ``idx_mapping``.
    """
    # Resolve the defaults at call time, exactly like the original global lookups.
    if model is None:
        model = rnn
    if idx_to_word is None:
        idx_to_word = idx_mapping
    seq = model.predict(seed, seq_len=seq_len)
    d[loss] = " ".join(idx_to_word[i] for i in seq)

In [None]:
# Run the loss function on the neural network with the parameters: y, h0
history = []
print_every = 50
best_weights = None
min_loss = None
prediction_history = {}

for i in xrange(1000):
    if curPtr >= N-num_samples-1:
        curPtr = 0
    if curPtr == 0:
        h0 = np.zeros((num_samples, hidden_dim))
    loss, l, h0 = rnn.loss(bible[curPtr:curPtr+num_samples, :], bible[curPtr+1:curPtr+num_samples+1, :], h0)
    if min_loss is None or loss < min_loss:
        min_loss = loss
        best_weights = l
        populate_prediction(prediction_history, int(loss))
    if i % print_every == 0:
        print "loss at epoch ", i, ": ", loss
    solver.step(i, loss, l)
    curPtr += num_samples
    
grad_descent_plot = plt.plot(*solver.get_loss_history())
plt.setp(grad_descent_plot, 'color', 'r', 'linewidth', 2.0)
plt.show()

loss at epoch  0 :  83.4880585441
loss at epoch  50 :  53.9052425993
loss at epoch  100 :  51.0988528886
loss at epoch  150 :  45.8932644176
loss at epoch  200 :  43.3796779137
loss at epoch  250 :  51.7494907532
loss at epoch  300 :  40.6747315017
loss at epoch  350 :  37.5778213771
loss at epoch  400 :  39.7500726394
loss at epoch  450 :  42.7446731187
loss at epoch  500 :  38.6105970869
loss at epoch  550 :  40.6156264635
loss at epoch  600 :  50.6368125474
loss at epoch  650 :  41.1233254759
loss at epoch  700 :  50.6872332572
loss at epoch  750 :  40.5473983634
loss at epoch  800 :  40.0201072283
loss at epoch  850 :  44.4298524655
loss at epoch 

In [41]:
# Persist each parameter array to its own HDF5 file; the file name and the
# dataset name are both the parameter name. Entry k of `l` (as returned by the
# most recent rnn.loss call) is indexable, and its element [0] is the array saved.
# NOTE(review): this saves the arrays from `l`, not `best_weights` -- confirm intended.
param_names = ["words", "W_xh", "W_hh", "W_hy", "b_affine", "b_rnn"]
for idx, name in enumerate(param_names):
    with h5py.File(name, 'w') as f:
        f.create_dataset(name, data=l[idx][0])

In [15]:
# Only run this cell after the parameter files have been written (e.g. by a
# previous run of the save cell). Each file holds one dataset named after the
# parameter, which is copied into rnn.params under the same key.
for param_name in ("words", "W_xh", "W_hh", "W_hy", "b_affine", "b_rnn"):
    with h5py.File(param_name, 'r') as f:
        rnn.params[param_name] = f[param_name][:]

In [20]:
seq = rnn.predict(852, seq_len=150) # We fed in "God" as the first token.
print seq
words = [idx_mapping[i] for i in seq]
print " ".join(words)

[1924, 778, 1251, 289, 1511, 778, 2500, 1924, 2004, 3793, 2091, 169, 289, 1307, 670, 1713, 3527, 289, 1511, 2069, 778, 11, 1924, 778, 3936, 1924, 2288, 670, 3751, 3527, 1088, 289, 1511, 134, 2629, 860, 2614, 758, 670, 3751, 778, 2500, 1924, 2288, 807, 986, 379, 2091, 1852, 725, 2253, 3167, 1511, 359, 3400, 2931, 725, 2836, 670, 1901, 1813, 778, 3488, 517, 3440, 1432, 2850, 303, 289, 534, 289, 722, 3167, 1511, 2614, 3377, 505, 3569, 725, 778, 3716, 1924, 670, 3751, 4005, 2730, 2629, 3226, 3716, 289, 1511, 4005, 1528, 778, 2860, 289, 1475, 3984, 266, 2614, 758, 289, 1511, 533, 2091, 2836, 289, 2964, 2253, 778, 535, 670, 3751, 778, 3137, 1088, 1475, 3558, 929, 2450, 1511, 30, 778, 907, 1924, 1712, 1666, 2660, 238, 3131, 3751, 3199, 2730, 2851, 1924, 315, 289, 3097, 100, 1924, 537, 929, 1985, 303, 289, 1511, 2931, 1178, 289, 2069]
of the sword , and the king of Judah began to reign , years . But Elijah , and all the captain of the host of Israel . And Elijah did , and brought him for his h

In [9]:
def print_log(key, prediction_history):
    print "At loss : ", key, ", we have the following phrase : ", prediction_history[key]

print prediction_history.keys()

[55.585483596433647, 82.890960913599471, 54.927192808938784, 66.616409983764513, 66.985596593904816, 83.026719080666012, 75.92157852744927, 80.322396619154148, 60.641278026358215, 80.582223620789733, 76.727845969987897, 73.239770776682093, 56.15846778702138, 81.556825835363483, 79.077395318832998, 82.623797348416602, 83.284525824614363, 82.346491818891494, 65.432689302988635, 81.956737227702774, 54.988617951563924, 71.154392678780553, 53.269048975791179, 68.979245174516564, 83.166705092144497, 53.238245751484619, 57.52535908187361, 82.762777225135849, 64.117406801376291, 75.10812997404571, 81.672812487360858, 56.506015375014222, 83.383003060934072, 55.268718707100014]
