In [1]:
import sys
import time
import tensorflow as tf
import numpy as np
sys.path.insert(0, '../')
%load_ext autoreload
%autoreload 2

In [2]:
from seqmodel.bunch import Bunch
from seqmodel.experiment.policy_agent import PolicyAgent
from seqmodel import model
from seqmodel import data

In [3]:
vocab = data.Vocabulary.from_vocab_file('../data/tiny_copy/vocab.txt')
valid_iter = data.Seq2SeqIterator(vocab, vocab)
valid_iter.initialize('../data/tiny_copy/valid.txt')
train_iter = data.Seq2SeqIterator(vocab, vocab)
train_iter.initialize('../data/tiny_copy/train.txt')

In [4]:
tf.reset_default_graph()
agent_opt = PolicyAgent.default_opt()
emb_opt = agent_opt.policy_model.model_opt.embedding
dec_opt = agent_opt.policy_model.model_opt.decoder
enc_opt = agent_opt.policy_model.model_opt.encoder
optim_opt = agent_opt.optim

emb_opt.decoder_dim = 32
emb_opt.encoder_dim = 32

dec_opt.rnn_opt.rnn_cell.cell_opt.num_units = 32
enc_opt.rnn_opt.rnn_cell.cell_opt.num_units = 32

optim_opt.learning_rate = 0.3
optim_opt.name = 'GradientDescentOptimizer'
optim_opt.max_epochs = 5

sess_config = tf.ConfigProto(device_count={'GPU': 0})
sess = tf.Session(config = sess_config)

agent = PolicyAgent(agent_opt, sess)
agent.initialize_model(with_training=True)
agent.initialize_optim()
for v in tf.trainable_variables():
    print('{}, {}'.format(v.name, v.get_shape()))
sess.run(tf.global_variables_initializer())
agent.train(train_iter, 20, valid_iter, 20, verbose=True)
info = agent.evaluate(valid_iter, 20)
print("PPL: {}, time: {}".format(
    info.eval_cost/info.num_tokens, info.end_time - info.start_time))

[36m[INFO ][0mep: 0, lr: 0.300000


policy_agent/policy/model/encoder_embedding:0, (15, 32)
policy_agent/policy/model/decoder_embedding:0, (15, 32)
policy_agent/policy/model/encoder_rnn/rnn/basic_lstm_cell/weights:0, (64, 128)
policy_agent/policy/model/encoder_rnn/rnn/basic_lstm_cell/biases:0, (128,)
policy_agent/policy/model/decoder_rnn/rnn/basic_lstm_cell/weights:0, (64, 128)
policy_agent/policy/model/decoder_rnn/rnn/basic_lstm_cell/biases:0, (128,)
policy_agent/policy/model/decoder_rnn/logit_w:0, (15, 32)
policy_agent/policy/model/decoder_rnn/logit_b:0, (15,)


[36m[INFO ][0mtrain: @499 tr_loss: 8.62677, eval_loss: 1.31262, wps: 15594.6
[36m[INFO ][0mvalid: @49 eval_loss: 0.46201, wps: 34065.7
[36m[INFO ][0mep: 1, lr: 0.300000
[36m[INFO ][0mtrain: @499 tr_loss: 2.04577, eval_loss: 0.30868, wps: 15741.2
[36m[INFO ][0mvalid: @49 eval_loss: 0.22965, wps: 38115.0
[36m[INFO ][0mep: 2, lr: 0.300000
[36m[INFO ][0mtrain: @499 tr_loss: 1.11693, eval_loss: 0.16864, wps: 15802.6
[36m[INFO ][0mvalid: @49 eval_loss: 0.12822, wps: 36624.0
[36m[INFO ][0mep: 3, lr: 0.300000
[36m[INFO ][0mtrain: @499 tr_loss: 0.73904, eval_loss: 0.11146, wps: 15517.7
[36m[INFO ][0mvalid: @49 eval_loss: 0.08695, wps: 37627.9
[36m[INFO ][0mep: 4, lr: 0.300000
[36m[INFO ][0mtrain: @499 tr_loss: 0.56944, eval_loss: 0.08575, wps: 15671.8
[36m[INFO ][0mvalid: @49 eval_loss: 0.07734, wps: 37723.6


PPL: 0.0773383819605, time: 0.175106048584


In [5]:
valid_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.EACH_MATCH)
valid_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_env)
print(info.eval_loss)

valid_hard_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.ALL_MATCH)
valid_hard_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_hard_env)
print(info.eval_loss)

0.963726984127
0.867


In [6]:
agent.reset_training_state()
train_env = data.env.CopyEnv(train_iter, re_init=False, reward_mode=data.env.ToyRewardMode.ALL_MATCH)
info = agent.policy_gradient(train_env, 20, valid_hard_env, 20)

[36m[INFO ][0mep: 0, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: 0.10987, base_loss: 0.00000, avg_return: 0.81570, wps: 895.5
[36m[INFO ][0mvalid: @50 avg_return: 0.87500, wps: 1340.7
[36m[INFO ][0mep: 1, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: 0.09725, base_loss: 0.00000, avg_return: 0.82690, wps: 903.4
[36m[INFO ][0mvalid: @50 avg_return: 0.87200, wps: 1346.8
[36m[INFO ][0mep: 2, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: 0.08960, base_loss: 0.00000, avg_return: 0.83360, wps: 893.8
[36m[INFO ][0mvalid: @50 avg_return: 0.86100, wps: 1398.6
[36m[INFO ][0mep: 3, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: 0.08656, base_loss: 0.00000, avg_return: 0.83840, wps: 897.5
[36m[INFO ][0mvalid: @50 avg_return: 0.90600, wps: 1379.2
[36m[INFO ][0mep: 4, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: 0.08011, base_loss: 0.00000, avg_return: 0.84730, wps: 896.9
[36m[INFO ][0mvalid: @50 avg_return: 0.90300, wps: 1425.8


In [7]:
valid_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.EACH_MATCH)
valid_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_env)
print(info.eval_loss)

valid_hard_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.ALL_MATCH)
valid_hard_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_hard_env)
print(info.eval_loss)

0.9764
0.903
