In [1]:
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')
from common.optimizer import SGD
from common.util import eval_perplexity
from common.trainer import RnnlmTrainer
from grulm import Grulm

from dataset import ptb


In [2]:
# --- Model dimensions (handed to the Grulm constructor) ---
wordvec_size = 100
hidden_size = 100

# --- Mini-batch layout (handed to trainer.fit) ---
batch_size = 20
time_size = 35

# --- Optimisation settings ---
lr = 20.0        # starting learning rate given to SGD
max_epoch = 4    # number of single-epoch fit() calls made by the training loop
max_grad = 0.25  # forwarded to trainer.fit as max_grad
# NOTE(review): `dropout` is not referenced by any visible cell -- confirm
# whether Grulm is supposed to receive it.
dropout = 0.1

# --- PTB splits; only the training split's vocabulary mappings are kept ---
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_val, _, _ = ptb.load_data('val')
corpus_test, _, _ = ptb.load_data('test')

In [4]:
# Next-word prediction pairs: the target stream is the input stream
# shifted forward by one token.
xs, ts = corpus[:-1], corpus[1:]
vocab_size = len(word_to_id)

# Wire the model, its optimizer and the trainer together.
model = Grulm(vocab_size, wordvec_size, hidden_size)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

In [5]:
# Train one epoch at a time so the validation perplexity can be checked
# after every pass over the training data.
#
# * When validation perplexity improves, model.save_params() is called.
# * Otherwise the optimizer's learning rate is divided by 4.
#
# The decay is applied to `optimizer.lr` directly instead of rebinding the
# config-cell global `lr`: overwriting `lr` meant that re-running the setup
# cell (`optimizer = SGD(lr)`) without re-running the config cell silently
# started from an already-decayed learning rate.
best_ppl = float('inf')
for epoch in range(max_epoch):
    trainer.fit(xs, ts, max_epoch=1, batch_size=batch_size,
                time_size=time_size, max_grad=max_grad)

    # Reset the model state before scoring the validation corpus
    # (presumably clears the carried-over recurrent state -- confirm in Grulm).
    model.reset_state()
    ppl = eval_perplexity(model, corpus_val)
    print('valid perplexity: ', ppl)

    if ppl < best_ppl:
        best_ppl = ppl
        model.save_params()
    else:
        # No improvement on validation: shrink the step size for later epochs.
        optimizer.lr /= 4.0

    # Reset again so the next training epoch does not start from the state
    # left behind by the validation pass.
    model.reset_state()
    print('-' * 50)

| epoch 1 |  iter 1 / 1327 | time 0[s] | perplexity 9999.32
| epoch 1 |  iter 21 / 1327 | time 6[s] | perplexity 5314.74
| epoch 1 |  iter 41 / 1327 | time 13[s] | perplexity 2583.91
| epoch 1 |  iter 61 / 1327 | time 19[s] | perplexity 1833.87
| epoch 1 |  iter 81 / 1327 | time 24[s] | perplexity 1180.26
| epoch 1 |  iter 101 / 1327 | time 27[s] | perplexity 905.83
| epoch 1 |  iter 121 / 1327 | time 31[s] | perplexity 837.67
| epoch 1 |  iter 141 / 1327 | time 35[s] | perplexity 743.89
| epoch 1 |  iter 161 / 1327 | time 39[s] | perplexity 723.87
| epoch 1 |  iter 181 / 1327 | time 43[s] | perplexity 710.69
| epoch 1 |  iter 201 / 1327 | time 47[s] | perplexity 588.04
| epoch 1 |  iter 221 / 1327 | time 50[s] | perplexity 559.60
| epoch 1 |  iter 241 / 1327 | time 53[s] | perplexity 466.42
| epoch 1 |  iter 261 / 1327 | time 57[s] | perplexity 487.54
| epoch 1 |  iter 281 / 1327 | time 60[s] | perplexity 472.43
| epoch 1 |  iter 301 / 1327 | time 63[s] | perplexity 413.26
| epoch 1 |

In [6]:
# Evaluate on the held-out test corpus using the best-validation weights.
# The training loop calls model.save_params() whenever validation perplexity
# improves, but the in-memory model holds the *last* epoch's weights, which
# may be worse. Reload the saved parameters first so the reported test
# perplexity belongs to the selected checkpoint.
# NOTE(review): assumes Grulm exposes load_params() as the counterpart of
# save_params() with the same default file name -- confirm in grulm.py.
model.load_params()
model.reset_state()
ppl_test = eval_perplexity(model, corpus_test)
print('test perplexity: ', ppl_test)

evaluating perplexity ...
234 / 235
test perplexity:  157.50564277837282
