In [1]:
from tools.numerical_gradient import *
from models.layers import *
from models.networks.vanilla_rnn import *
from models.networks.lstm_rnn import *
from models.solver.solver import *
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import h5py
import nltk
import re
import pickle

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

In [2]:
# ------------------------------------------------------------------
# Constants: artifact file names and text-processing settings
# ------------------------------------------------------------------

# HDF5 files (word_dataset_dest is the one opened with h5py in the load cell).
word_dataset_dest = "word_dataset.hdf5"
word_sequence_dest = "word_sequence.hdf5"

# Pickle files (idx_mapping_dest is the one read with pickle in the load cell;
# the other two are presumably pickles as well, going by extension -- confirm).
idx_mapping_dest = "idx_map.pkl"
word_mapping_dest = "word_map.pkl"
cache_model_dest = "cache_model.pkl"

# Alternation of separators: space, tab, or CRLF.
# Presumably a pattern for re.split -- confirm at the use site.
delims = ' |\t|\r\n'

# Not referenced by the visible cells; presumably the sequence length used
# when the dataset was built -- confirm.
seq_len = 10

In [3]:
# Load the index-encoded dataset and the index mapping from disk.
# The "bible" dataset is read fully into memory as an array ([:] copies it
# out of the HDF5 file before the file is closed).
with h5py.File(word_dataset_dest, 'r') as f:
    bible = f["bible"][:]
# Pickle data is a byte stream, so the file must be opened in binary mode;
# text mode can corrupt binary pickle protocols through newline translation.
# NOTE(security): pickle.load can execute arbitrary code -- only load files
# produced by a trusted source.
with open(idx_mapping_dest, 'rb') as f:
    idx_mapping = pickle.load(f)
# Largest index value present in the dataset; used later to size the vocabulary.
max_idx = np.max(bible)

In [4]:
### Hyperparameters ###

# Quantities derived from the loaded data.
N, T = bible.shape        # row and column counts of the index array
num_words = 1 + max_idx   # one more than the largest index found in `bible`
print(num_words)

# Hand-picked sizes.
num_samples = 256
time_dim = 10
hidden_dim = 100
word_vec_dim = 100

# Row offset of the next minibatch; the training cell advances it.
curPtr = 0

# Network with L2 regularization switched on through l2_lambda.
rnn = LSTM_RNN(
    num_samples, num_words, time_dim, hidden_dim, word_vec_dim,
    l2_lambda=0.0001,
)

# adagrad/adam can sustain higher learning rates.
solver = Solver({
    "learning_rate" : 1e-2,
    "type" : "adam",
    "beta1" : 0.9,
    "beta2" : 0.99,
})

11648


In [None]:
# Run the loss function on the neural network with the parameters: y, h0
history = []
print_every = 50

for i in xrange(5000):
    if curPtr >= N-num_samples-1:
        curPtr = 0
    if curPtr == 0:
        h0 = np.zeros((num_samples, hidden_dim))
    loss, l, h0 = rnn.loss(bible[curPtr:curPtr+num_samples, :], bible[curPtr+1:curPtr+num_samples+1, :], h0)
    if i % print_every == 0:
        print "loss at epoch ", i, ": ", loss
    solver.step(i, loss, l)
    curPtr += num_samples
    
grad_descent_plot = plt.plot(*solver.get_loss_history())
plt.setp(grad_descent_plot, 'color', 'r', 'linewidth', 2.0)
plt.show()