In [1]:
import sys
import time
import tensorflow as tf
import numpy as np
sys.path.insert(0, '../')
%load_ext autoreload
%autoreload 2

In [2]:
from seqmodel.bunch import Bunch
from seqmodel.experiment.policy_agent import ActorCriticAgent
from seqmodel import model
from seqmodel import data

In [3]:
vocab = data.Vocabulary.from_vocab_file('../data/tiny_copy/vocab.txt')
valid_iter = data.Seq2SeqIterator(vocab, vocab)
valid_iter.initialize('../data/tiny_copy/valid.txt')
train_iter = data.Seq2SeqIterator(vocab, vocab)
train_iter.initialize('../data/tiny_copy/train.txt')

In [4]:
tf.reset_default_graph()
agent_opt = ActorCriticAgent.default_opt()
emb_opt = agent_opt.policy_model.model_opt.embedding
dec_opt = agent_opt.policy_model.model_opt.decoder
enc_opt = agent_opt.policy_model.model_opt.encoder
optim_opt = agent_opt.optim

agent_opt.discount_factor = 0.5

emb_opt.decoder_dim = 32
emb_opt.encoder_dim = 32

dec_opt.rnn_opt.rnn_cell.cell_opt.num_units = 32
enc_opt.rnn_opt.rnn_cell.cell_opt.num_units = 32

optim_opt.learning_rate = 0.3
optim_opt.name = 'GradientDescentOptimizer'
optim_opt.max_epochs = 5

emb_opt = agent_opt.value_model.model_opt.embedding
dec_opt = agent_opt.value_model.model_opt.decoder
enc_opt = agent_opt.value_model.model_opt.encoder

emb_opt.decoder_dim = 32
emb_opt.encoder_dim = 32

dec_opt.rnn_opt.rnn_cell.cell_opt.num_units = 32
enc_opt.rnn_opt.rnn_cell.cell_opt.num_units = 32

sess_config = tf.ConfigProto(device_count={'GPU': 0})
sess = tf.Session(config = sess_config)

agent = ActorCriticAgent(agent_opt, sess)
agent.initialize_model(with_training=True)
agent.initialize_optim()
for v in tf.trainable_variables():
    print('{}, {}'.format(v.name, v.get_shape()))
sess.run(tf.global_variables_initializer())
agent.train(train_iter, 20, valid_iter, 20, verbose=True)

[36m[INFO ][0mep: 0, lr: 0.300000


policy_agent/policy/model/encoder_embedding:0, (15, 32)
policy_agent/policy/model/decoder_embedding:0, (15, 32)
policy_agent/policy/model/encoder_rnn/rnn/basic_lstm_cell/weights:0, (64, 128)
policy_agent/policy/model/encoder_rnn/rnn/basic_lstm_cell/biases:0, (128,)
policy_agent/policy/model/decoder_rnn/rnn/basic_lstm_cell/weights:0, (64, 128)
policy_agent/policy/model/decoder_rnn/rnn/basic_lstm_cell/biases:0, (128,)
policy_agent/policy/model/decoder_rnn/logit_w:0, (15, 32)
policy_agent/policy/model/decoder_rnn/logit_b:0, (15,)
policy_agent/value/model/encoder_embedding:0, (15, 32)
policy_agent/value/model/decoder_embedding:0, (15, 32)
policy_agent/value/model/encoder_rnn/rnn/basic_lstm_cell/weights:0, (64, 128)
policy_agent/value/model/encoder_rnn/rnn/basic_lstm_cell/biases:0, (128,)
policy_agent/value/model/decoder_rnn/rnn/basic_lstm_cell/weights:0, (64, 128)
policy_agent/value/model/decoder_rnn/rnn/basic_lstm_cell/biases:0, (128,)
policy_agent/value/model/decoder_rnn/regression_w:0, 

[36m[INFO ][0mtrain: @499 tr_loss: 8.74717, eval_loss: 1.32847, wps: 15394.8
[36m[INFO ][0mvalid: @49 eval_loss: 0.45677, wps: 32005.0
[36m[INFO ][0mep: 1, lr: 0.300000
[36m[INFO ][0mtrain: @499 tr_loss: 2.09442, eval_loss: 0.31601, wps: 15507.3
[36m[INFO ][0mvalid: @49 eval_loss: 0.24743, wps: 37340.9
[36m[INFO ][0mep: 2, lr: 0.300000
[36m[INFO ][0mtrain: @499 tr_loss: 1.12990, eval_loss: 0.17042, wps: 15633.9
[36m[INFO ][0mvalid: @49 eval_loss: 0.17978, wps: 37437.8
[36m[INFO ][0mep: 3, lr: 0.300000
[36m[INFO ][0mtrain: @499 tr_loss: 0.74202, eval_loss: 0.11155, wps: 15602.9
[36m[INFO ][0mvalid: @49 eval_loss: 0.09635, wps: 37646.7
[36m[INFO ][0mep: 4, lr: 0.300000
[36m[INFO ][0mtrain: @499 tr_loss: 0.55698, eval_loss: 0.08412, wps: 15507.8
[36m[INFO ][0mvalid: @49 eval_loss: 0.07772, wps: 37124.9


<seqmodel.experiment.run_info.TrainingState at 0x7fab1c53ed50>

In [5]:
info = agent.evaluate(valid_iter, 20)
print("PPL: {}]".format(
    info.eval_cost/info.num_tokens))

valid_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.EACH_MATCH)
valid_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_env)
print('Each match: {}'.format(info.eval_loss))

valid_hard_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.ALL_MATCH)
valid_hard_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_hard_env)
print('Exact match: {}'.format(info.eval_loss))

valid_bleu_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.SEN_BLEU)
valid_bleu_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_bleu_env)
print('BLEU: {}'.format(info.eval_loss))

PPL: 0.0777161237701]
Each match: 0.962945634921
Exact match: 0.876
BLEU: 0.94766835828


In [6]:
agent.reset_training_state()
train_env = data.env.CopyEnv(train_iter, re_init=False, reward_mode=data.env.ToyRewardMode.EACH_MATCH)
info = agent.policy_gradient(train_env, 20, valid_hard_env, 20, max_steps=20)

[36m[INFO ][0mep: 0, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: -0.09527, base_loss: 0.00760, avg_return: 0.93500, wps: 597.9
[36m[INFO ][0mvalid: @50 avg_return: 0.90900, wps: 1377.0
[36m[INFO ][0mep: 1, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: -0.07511, base_loss: 0.00246, avg_return: 0.94578, wps: 603.8
[36m[INFO ][0mvalid: @50 avg_return: 0.91400, wps: 1372.4
[36m[INFO ][0mep: 2, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: -0.07019, base_loss: 0.00229, avg_return: 0.94943, wps: 603.3
[36m[INFO ][0mvalid: @50 avg_return: 0.91900, wps: 1298.1
[36m[INFO ][0mep: 3, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: -0.06476, base_loss: 0.00214, avg_return: 0.95280, wps: 601.7
[36m[INFO ][0mvalid: @50 avg_return: 0.91500, wps: 1349.7
[36m[INFO ][0mep: 4, lr: 0.300000
[36m[INFO ][0mtrain: @500 tr_loss: -0.06103, base_loss: 0.00205, avg_return: 0.95523, wps: 603.2
[36m[INFO ][0mvalid: @50 avg_return: 0.91800, wps: 1386.1


In [7]:
info = agent.evaluate(valid_iter, 20)
print("PPL: {}".format(
    info.eval_cost/info.num_tokens))

valid_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.EACH_MATCH)
valid_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_env)
print('Each match: {}'.format(info.eval_loss))

valid_hard_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.ALL_MATCH)
valid_hard_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_hard_env)
print('Exact match: {}'.format(info.eval_loss))

valid_bleu_env = data.env.CopyEnv(valid_iter, re_init=False, reward_mode=data.env.ToyRewardMode.SEN_BLEU)
valid_bleu_env.restart(batch_size=20)
info = agent.evaluate_policy(valid_bleu_env)
print('BLEU: {}'.format(info.eval_loss))

PPL: 0.0471603506177
Each match: 0.976432539683
Exact match: 0.918
BLEU: 0.966214409899


In [8]:
test_data = ([['a a b c a d a f a', 'a a b c a d a f a'], ['a b c d', 'a b c d']])
test_iter = data.Seq2SeqIterator(vocab, vocab)
test_iter.initialize(test_data)
test_iter.init_batch(2)
env = data.env.CopyEnv(test_iter, re_init=False, reward_mode=data.env.ToyRewardMode.EACH_MATCH)
# env = data.env.Seq2SeqEnv(test_iter, re_init=False)
transitions, states, rewards = agent.rollout(env, greedy=True)
rewards = np.array(rewards)
returns, targets = agent._compute_return(states, rewards)
print('Return: ') 
print(returns.T)
print('Target: ')
print(targets.T)
pg_data = env.create_transition_return(states, returns)
val_data = env.create_transition_value(states, targets)

Return: 
[[ 0.01633463 -0.00246463 -0.00823536 -0.00223882  0.01537414  0.02000853
   0.02424275  0.00252199 -0.07068406  0.00957855]
 [-0.03090422  0.00377899  0.02025822  0.02127074 -0.03912793 -0.25561094
  -0.25561094 -0.25561094 -0.25561094 -0.25561094]]
Target: 
[[ 0.19941406  0.19882813  0.19765625  0.1953125   0.190625    0.18125
   0.1625      0.125       0.05        0.1       ]
 [ 0.3875      0.375       0.35        0.3         0.2         0.          0.
   0.          0.          0.        ]]


In [9]:
pg_data.features.encoder_input.T

array([[ 1,  5,  5,  6,  7,  5,  8,  5, 10,  5,  3],
       [ 1,  5,  6,  7,  8,  3,  3,  3,  3,  3,  3]], dtype=int32)

In [10]:
pg_data.labels.decoder_label.T

array([[ 5,  5,  6,  7,  5,  8,  5, 10,  7,  0],
       [ 5,  6,  7,  8,  0,  0,  0,  0,  0,  0]])

In [11]:
pg_data.labels.decoder_label_weight.T

array([[ 0.01633463, -0.00246463, -0.00823536, -0.00223882,  0.01537414,
         0.02000853,  0.02424275,  0.00252199, -0.07068406,  0.00957855],
       [-0.03090422,  0.00377899,  0.02025822,  0.02127074, -0.03912793,
        -0.        , -0.        , -0.        , -0.        , -0.        ]])