In [1]:
import sys
import time
import tensorflow as tf
import numpy as np
sys.path.insert(0, '../')
%load_ext autoreload
%autoreload 2

In [10]:
from seqmodel.bunch import Bunch
from seqmodel.experiment.basic_agent import BasicAgent
from seqmodel import model
from seqmodel import data

In [11]:
vocab = data.Vocabulary.from_vocab_file('../data/tiny_copy/vocab.txt')
valid_iter = data.Seq2SeqIterator(vocab, vocab)
valid_iter.initialize('../data/tiny_copy/valid.txt')

In [13]:
tf.reset_default_graph()
agent_opt = BasicAgent.default_opt()
emb_opt = agent_opt.model.model_opt.embedding
dec_opt = agent_opt.model.model_opt.decoder
enc_opt = agent_opt.model.model_opt.encoder
optim_opt = agent_opt.optim

emb_opt.decoder_dim = 64
emb_opt.encoder_dim = 64

dec_opt.rnn_opt.rnn_cell.cell_opt.num_units = 64
enc_opt.rnn_opt.rnn_cell.cell_opt.num_units = 64

optim_opt.learning_rate = 0.3
optim_opt.name = 'GradientDescentOptimizer'

sess = tf.Session()
agent = BasicAgent(agent_opt, sess)
agent.initialize_model(with_training=True)
agent.initialize_optim()
for v in tf.trainable_variables():
    print('{}, {}'.format(v.name, v.get_shape()))
sess.run(tf.global_variables_initializer())
agent.train(valid_iter, 10, valid_iter, 10, verbose=True)
info = agent.evaluate(valid_iter, 10)
print("PPL: {}, time: {}".format(
    info.eval_cost/info.num_tokens, info.end_time - info.start_time))

[36m[INFO ][0mep: 0, lr: 0.300000


basic_agent/model/encoder_embedding:0, (15, 64)
basic_agent/model/decoder_embedding:0, (15, 64)
basic_agent/model/encoder_rnn/rnn/basic_lstm_cell/weights:0, (128, 256)
basic_agent/model/encoder_rnn/rnn/basic_lstm_cell/biases:0, (256,)
basic_agent/model/decoder_rnn/rnn/basic_lstm_cell/weights:0, (128, 256)
basic_agent/model/decoder_rnn/rnn/basic_lstm_cell/biases:0, (256,)
basic_agent/model/decoder_rnn/logit_w:0, (15, 64)
basic_agent/model/decoder_rnn/logit_b:0, (15,)


[36m[INFO ][0mtrain: @99 tr_loss: 15.19742, eval_loss: 2.28316, wps: 2723.4
[36m[INFO ][0mvalid: @99 tr_loss: 0.00000, eval_loss: 1.94776, wps: 7716.3
[36m[INFO ][0mep: 1, lr: 0.300000
[36m[INFO ][0mtrain: @99 tr_loss: 11.86525, eval_loss: 1.78379, wps: 2897.9
[36m[INFO ][0mvalid: @99 tr_loss: 0.00000, eval_loss: 1.52662, wps: 7527.5
[36m[INFO ][0mep: 2, lr: 0.300000
[36m[INFO ][0mtrain: @99 tr_loss: 8.48660, eval_loss: 1.26395, wps: 2818.2
[36m[INFO ][0mvalid: @99 tr_loss: 0.00000, eval_loss: 0.89669, wps: 8077.1
[36m[INFO ][0mep: 3, lr: 0.300000
[36m[INFO ][0mtrain: @99 tr_loss: 5.59850, eval_loss: 0.82901, wps: 2779.4
[36m[INFO ][0mvalid: @99 tr_loss: 0.00000, eval_loss: 0.70110, wps: 7763.8
[36m[INFO ][0mep: 4, lr: 0.300000
[36m[INFO ][0mtrain: @99 tr_loss: 3.82389, eval_loss: 0.56853, wps: 2962.7
[36m[INFO ][0mvalid: @99 tr_loss: 0.00000, eval_loss: 0.42430, wps: 7859.5
[36m[INFO ][0mep: 5, lr: 0.300000
[36m[INFO ][0mtrain: @99 tr_loss: 2.79611, eva

PPL: 0.142930377801, time: 0.827893018723


In [14]:
test_data = ([['d e f', '']])
test_iter = data.Seq2SeqIterator(vocab, vocab)
test_iter.initialize(test_data)
test_iter.init_batch(1)
env = data.environment.Env(test_iter, re_init=False)

In [15]:
res, trans = agent.sample(env, greedy=True)

In [16]:
res

[SampleOutputTuple(batch=Seq2SeqBatchTuple(features=Seq2SeqFeatureTuple(encoder_input=array([[ 1],
       [ 8],
       [ 9],
       [10],
       [ 3]], dtype=int32), encoder_seq_len=array([5], dtype=int32), decoder_input=array([[2]], dtype=int32), decoder_seq_len=array([1], dtype=int32)), labels=Seq2SeqLabelTuple(decoder_label=array([[0]], dtype=int32), decoder_label_weight=array([[ 1.]], dtype=float32), decoder_seq_label=array([ 1.], dtype=float32)), num_tokens=1.0), samples=[array([[ 8],
       [ 9],
       [10],
       [ 0]])], scores=[array([[ 0.99131376],
       [ 0.9912107 ],
       [ 0.99973089],
       [ 0.99922585]], dtype=float32)])]

In [17]:
res[0].batch.features

Seq2SeqFeatureTuple(encoder_input=array([[ 1],
       [ 8],
       [ 9],
       [10],
       [ 3]], dtype=int32), encoder_seq_len=array([5], dtype=int32), decoder_input=array([[2]], dtype=int32), decoder_seq_len=array([1], dtype=int32))

In [18]:
res[0].samples

[array([[ 8],
        [ 9],
        [10],
        [ 0]])]

In [19]:
res[0].scores

[array([[ 0.99131376],
        [ 0.9912107 ],
        [ 0.99973089],
        [ 0.99922585]], dtype=float32)]