In [1]:
import sys
import time
import tensorflow as tf
import numpy as np
sys.path.insert(0, '../')
%load_ext autoreload
%autoreload 2

In [2]:
from seqmodel.bunch import Bunch
from seqmodel.experiment.policy_agent import ActorCriticAgent
from seqmodel import model
from seqmodel import data

In [3]:
vocab = data.Vocabulary.from_vocab_file('../data/tiny_copy/vocab.txt')
valid_iter = data.Seq2SeqIterator(vocab, vocab)
valid_iter.initialize('../data/tiny_copy/valid.txt')
train_iter = data.Seq2SeqIterator(vocab, vocab)
train_iter.initialize('../data/tiny_copy/train.txt')

In [4]:
tf.reset_default_graph()
agent_opt = ActorCriticAgent.default_opt()
emb_opt = agent_opt.policy_model.model_opt.embedding
dec_opt = agent_opt.policy_model.model_opt.decoder
enc_opt = agent_opt.policy_model.model_opt.encoder
optim_opt = agent_opt.optim

agent_opt.discount_factor = 1.0

emb_opt.decoder_dim = 32
emb_opt.encoder_dim = 32

dec_opt.rnn_opt.rnn_cell.cell_opt.num_units = 32
enc_opt.rnn_opt.rnn_cell.cell_opt.num_units = 32

optim_opt.learning_rate = 0.3
optim_opt.name = 'GradientDescentOptimizer'
optim_opt.max_epochs = 5

emb_opt = agent_opt.value_model.model_opt.embedding
dec_opt = agent_opt.value_model.model_opt.decoder
enc_opt = agent_opt.value_model.model_opt.encoder

emb_opt.decoder_dim = 32
emb_opt.encoder_dim = 32

dec_opt.rnn_opt.rnn_cell.cell_opt.num_units = 32
enc_opt.rnn_opt.rnn_cell.cell_opt.num_units = 32

sess_config = tf.ConfigProto(device_count={'GPU': 0})
sess = tf.Session(config = sess_config)

agent = ActorCriticAgent(agent_opt, sess)
agent.initialize_model(with_training=True)
agent.initialize_optim()
for v in tf.trainable_variables():
    print('{}, {}'.format(v.name, v.get_shape()))
sess.run(tf.global_variables_initializer())
agent.train(train_iter, 20, valid_iter, 20, verbose=True)

[36m[INFO ][0mep: 0, lr: 0.300000


policy_agent/policy/model/encoder_embedding:0, (15, 32)
policy_agent/policy/model/decoder_embedding:0, (15, 32)
policy_agent/policy/model/encoder_rnn/rnn/basic_lstm_cell/weights:0, (64, 128)
policy_agent/policy/model/encoder_rnn/rnn/basic_lstm_cell/biases:0, (128,)
policy_agent/policy/model/decoder_rnn/rnn/basic_lstm_cell/weights:0, (64, 128)
policy_agent/policy/model/decoder_rnn/rnn/basic_lstm_cell/biases:0, (128,)
policy_agent/policy/model/decoder_rnn/logit_w:0, (15, 32)
policy_agent/policy/model/decoder_rnn/logit_b:0, (15,)
policy_agent/value/model/encoder_embedding:0, (15, 32)
policy_agent/value/model/decoder_embedding:0, (15, 32)
policy_agent/value/model/encoder_rnn/rnn/basic_lstm_cell/weights:0, (64, 128)
policy_agent/value/model/encoder_rnn/rnn/basic_lstm_cell/biases:0, (128,)
policy_agent/value/model/decoder_rnn/rnn/basic_lstm_cell/weights:0, (64, 128)
policy_agent/value/model/decoder_rnn/rnn/basic_lstm_cell/biases:0, (128,)
policy_agent/value/model/decoder_rnn/regression_w:0, 

[36m[INFO ][0mtrain: @499 tr_loss: 8.75940, eval_loss: 1.33140 (3.78634), wps: 15345.9
[36m[INFO ][0mvalid: @49 eval_loss: 0.49504 (1.64056), wps: 31225.5
[36m[INFO ][0mep: 1, lr: 0.300000
[36m[INFO ][0mtrain: @499 tr_loss: 2.08008, eval_loss: 0.31413 (1.36907), wps: 15619.9
[36m[INFO ][0mvalid: @49 eval_loss: 0.24724 (1.28048), wps: 36564.7
[36m[INFO ][0mep: 2, lr: 0.300000
[36m[INFO ][0mtrain: @499 tr_loss: 1.14843, eval_loss: 0.17270 (1.18851), wps: 15497.4
[36m[INFO ][0mvalid: @49 eval_loss: 0.17177 (1.18741), wps: 35876.5
[36m[INFO ][0mep: 3, lr: 0.300000
[36m[INFO ][0mtrain: @499 tr_loss: 0.78560, eval_loss: 0.11829 (1.12557), wps: 15673.9
[36m[INFO ][0mvalid: @49 eval_loss: 0.10460 (1.11027), wps: 37163.1
[36m[INFO ][0mep: 4, lr: 0.300000
[36m[INFO ][0mtrain: @499 tr_loss: 0.58187, eval_loss: 0.08775 (1.09171), wps: 15368.8
[36m[INFO ][0mvalid: @49 eval_loss: 0.08459 (1.08827), wps: 28585.3


<seqmodel.experiment.run_info.TrainingState at 0x7fe77815a510>

In [5]:
info = agent.evaluate(valid_iter, 20)
print("PPL: {}]".format(
    info.eval_cost/info.num_tokens))

valid_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.EACH_MATCH)
valid_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_env)
print('Each match: {}'.format(info.eval_loss))

valid_hard_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.ALL_MATCH)
valid_hard_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_hard_env)
print('Exact match: {}'.format(info.eval_loss))

valid_bleu_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.SEN_BLEU)
valid_bleu_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_bleu_env)
print('BLEU: {}'.format(info.eval_loss))

PPL: 0.0845933976382]
Each match: -0.958009126984
Exact match: -0.845
BLEU: -0.917033496979


In [6]:
agent.reset_training_state()
train_env = data.env.CopyEnv(train_iter, re_init=False, reward_mode=data.env.ToyRewardMode.EACH_MATCH)
info = agent.policy_gradient(train_env, 20, valid_hard_env, 20, max_steps=20)

[36m[INFO ][0mep: 0, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: -0.10057, base_loss: 0.02064, avg_return: 0.93499, wps: 3858.3
[36m[INFO ][0mvalid: @50 avg_return: 0.89100, wps: 8995.6
[36m[INFO ][0mep: 1, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: -0.09817, base_loss: 0.01208, avg_return: 0.94365, wps: 3972.9
[36m[INFO ][0mvalid: @50 avg_return: 0.89000, wps: 8954.8
[36m[INFO ][0mep: 2, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: -0.09995, base_loss: 0.01084, avg_return: 0.94547, wps: 3975.0
[36m[INFO ][0mvalid: @50 avg_return: 0.88400, wps: 8660.5
[36m[INFO ][0mep: 3, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: -0.09916, base_loss: 0.01001, avg_return: 0.94777, wps: 3976.1
[36m[INFO ][0mvalid: @50 avg_return: 0.89300, wps: 9082.8
[36m[INFO ][0mep: 4, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: -0.09278, base_loss: 0.00941, avg_return: 0.95108, wps: 3951.1
[36m[INFO ][0mvalid: @50 avg_return: 0.90400, wps: 9084.4


In [12]:
info = agent.evaluate(valid_iter, 20)
print("PPL: {}".format(
    info.eval_cost/info.num_tokens))

valid_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.EACH_MATCH)
valid_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_env)
print('Each match: {}'.format(-1 * info.eval_loss))

valid_hard_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.ALL_MATCH)
valid_hard_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_hard_env)
print('Exact match: {}'.format(-1 * info.eval_loss))

valid_bleu_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.SEN_BLEU)
valid_bleu_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_bleu_env)
print('BLEU: {}'.format(-1 * info.eval_loss))

PPL: 0.0621478553859
Each match: 0.975246428571
Exact match: 0.904
BLEU: 0.943561717181


In [13]:
test_data = ([['a a b c a d a f a', 'a a b c a d a f a'], ['a b c d', 'a b c d']])
test_iter = data.Seq2SeqIterator(vocab, vocab)
test_iter.initialize(test_data)
test_iter.init_batch(2)
env = data.env.CopyEnv(test_iter, re_init=False, reward_mode=data.env.ToyRewardMode.EACH_MATCH)
# env = data.env.Seq2SeqEnv(test_iter, re_init=False)
transitions, states, rewards = agent.rollout(env, greedy=True)
rewards = np.array(rewards)
returns, targets = agent._compute_return(states, rewards)
print('Return: ') 
print(returns.T)
print('Target: ')
print(targets.T)
pg_data = env.create_transition_return(states, returns)
val_data = env.create_transition_value(states, targets)

Return: 
[[-0.01575243 -0.00287337  0.01562468  0.00134041 -0.03334575 -0.10050561
  -0.10050561 -0.10050561 -0.10050561 -0.10050561]
 [-0.33200121 -0.37118188 -0.39225872 -0.3926519  -0.30676489 -0.21994679
  -0.21045455 -0.13292591 -0.0579662   0.00362953]]
Target: 
[[ 1.   0.8  0.6  0.4  0.2  0.   0.   0.   0.   0. ]
 [ 0.5  0.4  0.3  0.2  0.2  0.2  0.1  0.1  0.1  0.1]]


In [14]:
pg_data.features.encoder_input.T

array([[ 1,  5,  6,  7,  8,  3,  3,  3,  3,  3,  3],
       [ 1,  5,  5,  6,  7,  5,  8,  5, 10,  5,  3]], dtype=int32)

In [15]:
pg_data.labels.decoder_label.T

array([[ 5,  6,  7,  8,  0,  0,  0,  0,  0,  0],
       [ 5,  5,  6,  5,  7,  8, 10,  5,  9,  0]])

In [16]:
pg_data.labels.decoder_label_weight.T

array([[-0.01575243, -0.00287337,  0.01562468,  0.00134041, -0.03334575,
        -0.        , -0.        , -0.        , -0.        , -0.        ],
       [-0.33200121, -0.37118188, -0.39225872, -0.3926519 , -0.30676489,
        -0.21994679, -0.21045455, -0.13292591, -0.0579662 ,  0.00362953]])