In [1]:
from os import path as osp
import numpy as np
import tensorflow as tf
import sonnet as snt
from attrdict import AttrDict

from evaluation import make_fig, make_logger

from data import load_data, tensors_from_data
from mnist_model import SeqAIRonMNIST

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
% matplotlib inline

In [2]:
# Optimisation / model hyper-parameters.
learning_rate = 1e-4
n_steps = 3  # passed to SeqAIRonMNIST as max_steps

# Where this run's summaries and checkpoints are written.
results_dir = '../results/seq_grad_inspect'
run_name = 'no_baseline'
logdir = osp.join(results_dir, run_name)
checkpoint_name = osp.join(logdir, 'model.ckpt')

# Index of the sample axis in each data array (e.g. imgs is (1, N, 50, 50),
# labels is (N, 2)); used by tensors_from_data and for counting samples.
axes = dict(imgs=1, labels=0, nums=1, coords=1)

In [3]:
n_timesteps = 1
batch_size = 64

# Priors are currently disabled: every prior below is None, and that is what
# gets passed to air.train_step. Previously the AttrDicts were built and then
# silently overwritten with None, which made the priors look active.
# Set use_priors = True to use the exp-annealed num-steps prior and the
# loc=0 / scale=1 appearance and where priors instead.
use_priors = False

if use_priors:
    num_steps_prior = AttrDict(
        anneal='exp',
        init=1.,
        final=1e-5,
        steps_div=1e4,
        steps=1e5,
        hold_init=1e3,
        analytic=False
    )
    appearance_prior = AttrDict(loc=0., scale=1.)
    where_scale_prior = AttrDict(loc=0., scale=1.)
    where_shift_prior = AttrDict(loc=0., scale=1.)
else:
    num_steps_prior = None
    appearance_prior = None
    where_scale_prior = None
    where_shift_prior = None

use_reinforce = True
step_bias = .75
transform_var_bias = .5
output_multiplier = .5

l2_weight = 0.  # 1e-5 was the previously tried value

supervised_nums = False

In [4]:
# valid_data = load_data('small_seq_mnist_validation.pickle')
# train_data = load_data('small_seq_mnist_train.pickle')

In [5]:
# Full MNIST splits (the small_seq_mnist variants are in the commented-out cell above).
# NOTE(review): the files have a .pickle extension — presumably unpickled by
# load_data, so only load files from a trusted source.
valid_data, train_data = map(load_data, ('mnist_validation.pickle', 'mnist_train.pickle'))

# del axes['coords']
def fix(d):
    """Reshape a loaded split in place and print the shape of every array in it.

    Adds a leading length-1 axis to ``d.imgs`` and reverses the axis order of
    ``d.nums`` (equivalent to ``.T``). ``d`` itself is modified and returned,
    so calling this twice on the same object applies the reshape twice.
    """
    d.imgs = np.expand_dims(d.imgs, 0)
    d.nums = np.transpose(d.nums)
    for key in d:
        print('{} {}'.format(key, d[key].shape))
    return d

# Apply fix to both splits. fix mutates and returns its argument, so re-running
# this line on already-fixed data reshapes it a second time.
valid_data, train_data = [fix(d) for d in (valid_data, train_data)]

imgs (1, 10000, 50, 50)
labels (10000, 2)
nums (1, 10000, 3)
imgs (1, 60000, 50, 50)
labels (60000, 2)
nums (1, 60000, 3)


In [6]:
tf.reset_default_graph()
train_tensors = tensors_from_data(train_data, batch_size, axes, shuffle=True)
valid_tensors = tensors_from_data(valid_data, batch_size, axes, shuffle=False)
x, valid_x = train_tensors['imgs'], valid_tensors['imgs']
y, test_y = train_tensors['nums'], valid_tensors['nums']

# Hidden-layer spec shared by the input/glimpse encoders, the glimpse decoder
# and the transform estimator: n_layers layers of hidden_size units each.
hidden_size = 32 * 8
n_layers = 2
n_hiddens = [hidden_size] * n_layers

# The graph Variable gets its own Python name. Rebinding the config integer
# `n_timesteps` to the Variable made this cell fail on re-run: after
# tf.reset_default_graph() the old Variable belongs to a discarded graph and
# cannot be used as the initial value of a new one. The TF op name is unchanged.
n_timesteps_var = tf.Variable(n_timesteps, trainable=False, dtype=tf.int32, name='n_timesteps')
seq_x = x[:n_timesteps_var]
seq_y = y[:n_timesteps_var]

air = SeqAIRonMNIST(
    seq_x,
    max_steps=n_steps,
    inpt_encoder_hidden=n_hiddens,
    glimpse_encoder_hidden=n_hiddens,
    glimpse_decoder_hidden=n_hiddens,
    transform_estimator_hidden=n_hiddens,
    steps_pred_hidden=[128, 64],
    baseline_hidden=[256, 128],
    transform_var_bias=transform_var_bias,
    step_bias=step_bias,
    output_multiplier=output_multiplier,
)

In [7]:
# Build the optimisation op and the global step counter. The first six
# arguments are positional, in the order air.train_step expects them.
train_step, global_step = air.train_step(
    learning_rate,
    l2_weight,
    appearance_prior,
    where_scale_prior,
    where_shift_prior,
    num_steps_prior,
    use_reinforce=use_reinforce,
    decay_rate=None,
    nums=seq_y,
    supervised_nums=supervised_nums,
)

reinf (?, 64) (?, 64, 1)
model 34
lstm/w_gates:0 [512, 1024]
SeqAIRonMNIST/SeqAIRCell/StochasticTransformParam/MLP/linear_2/b:0 [8]
SeqAIRonMNIST/SeqAIRCell/StepsPredictor/MLP/linear_2/w:0 [64, 1]
SeqAIRonMNIST/SeqAIRCell/StochasticTransformParam/MLP/linear_2/w:0 [256, 8]
SeqAIRonMNIST/SeqAIRCell/StepsPredictor/MLP/linear_2/b:0 [1]
lstm/b_gates:0 [1024]
SeqAIRonMNIST/SeqAIRCell/Encoder/MLP/linear_1/w:0 [256, 256]
SeqAIRonMNIST/SeqAIRCell/Decoder/MLP/linear/w:0 [50, 256]
SeqAIRonMNIST/SeqAIRCell/StochasticTransformParam/MLP/linear/b:0 [256]
SeqAIRonMNIST/lstm_initial_state_0/w:0 [1, 256]
SeqAIRonMNIST/SeqAIRCell/StochasticTransformParam/MLP/linear/w:0 [256, 256]
SeqAIRonMNIST/SeqAIRCell/Decoder/MLP/linear/b:0 [256]
SeqAIRonMNIST/SeqAIRCell/Encoder/MLP/linear/w:0 [2500, 256]
SeqAIRonMNIST/SeqAIRCell/Decoder/MLP/linear_2/w:0 [256, 400]
SeqAIRonMNIST/SeqAIRCell/Decoder/MLP/linear_1/b:0 [256]
SeqAIRonMNIST/SeqAIRCell/StepsPredictor/MLP/linear/b:0 [128]
SeqAIRonMNIST/SeqAIRCell/Decoder/MLP/l

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Gradient for variable SeqAIRonMNIST/what_init is None
Gradient for variable SeqAIRonMNIST/where_init is None


In [8]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

In [9]:
# Probabilities from the model's number-of-steps distribution.
step_prob = air.num_steps_distrib.prob()
# Averaged over axes 0 and 1; the printed shape is (4,) with n_steps = 3
# (presumably one entry per possible step count).
mean_prob = tf.reduce_mean(step_prob, (0, 1))
# Entropy -sum(p * log p), summed over every entry of step_prob. log is taken
# of a clipped probability so that p == 0 contributes 0 instead of
# 0 * log(0) = 0 * -inf = NaN. Entries with p >= 1e-8 are unaffected.
step_entropy = step_prob * tf.log(tf.maximum(step_prob, 1e-8))
step_entropy = -tf.reduce_sum(step_entropy)

In [10]:
def compute_grad(y, x, normalize=False):
    """Gradient of ``y`` w.r.t. ``x``, averaged over axes 0 and 1 of ``x``.

    Args:
        y: tensor to differentiate (tf.gradients sums over its elements).
        x: tensor with at least three axes; the result keeps the trailing axes.
        normalize: if True, divide the averaged gradient by its L2 norm over
            the last axis. Defaults to False, which matches the previous
            behaviour (the normalization used to be commented out).

    Returns:
        The averaged, optionally normalized, gradient tensor.

    Raises:
        ValueError: if ``y`` does not depend on ``x``. tf.gradients returns
            None in that case, which would otherwise fail opaquely in reduce_mean.
    """
    grad = tf.gradients(y, x)[0]
    if grad is None:
        raise ValueError('{} does not depend on {}; no gradient to inspect'.format(y.name, x.name))
    g = tf.reduce_mean(grad, (0, 1))
    if normalize:
        g = g / tf.sqrt(tf.reduce_sum(tf.square(g), -1, keep_dims=True))
    return g

# Averaged gradients of each active loss term w.r.t. the step probabilities,
# followed by the mean probabilities themselves. names[i] labels grads[i].
total_grad = compute_grad(air.opt_loss, step_prob)
grads, names = [total_grad], ['total']

if air.use_prior:
    kl_grad = compute_grad(air.prior_loss.value, step_prob)
    grads += [kl_grad]
    names += ['kl']

if use_reinforce:
    reinforce_grad = compute_grad(air.reinforce_loss, step_prob)
    grads += [reinforce_grad]
    names += ['reinforce']

grads += [mean_prob]
names += ['prob']

# One scalar summary per entry along the last axis, tagged prob_grad/<name>_<index>.
for name, g in zip(names, grads):
    step_grads = tf.unstack(g, axis=-1)
    for i, step_grad in enumerate(step_grads):
        tf.summary.scalar('prob_grad/{}_{}'.format(name, i), step_grad)

In [11]:
# Monitored tensors and their labels (same order).
print(grads)
print(names)

[<tf.Tensor 'div:0' shape=(4,) dtype=float32>, <tf.Tensor 'div_1:0' shape=(4,) dtype=float32>, <tf.Tensor 'Mean_1:0' shape=(4,) dtype=float32>]
['total', 'reinforce', 'prob']


In [12]:
# a, d = sess.run([grads, mean_prob])
# for i in a:
#     print i
# print d

In [13]:
# Single op that evaluates every tf.summary registered so far, including the
# prob_grad/* scalars.
all_summaries = tf.summary.merge_all()
# Event writer for this run's logdir; sess.graph is passed so the graph is logged as well.
summary_writer = tf.summary.FileWriter(logdir, sess.graph)
# Used by the training loop to checkpoint to checkpoint_name.
saver = tf.train.Saver()

In [14]:
# Sample counts along each split's sample axis, divided by 100. These reduced
# counts are what make_logger receives (presumably how many samples it
# evaluates per log call — confirm in evaluation.make_logger).
n_train_samples = train_data['imgs'].shape[axes['imgs']] // 100
n_valid_samples = valid_data['imgs'].shape[axes['imgs']] // 100
log = make_logger(air, sess, summary_writer,
                  train_tensors, n_train_samples,
                  valid_tensors, n_valid_samples)

In [15]:
# Display the reduced sample counts handed to make_logger.
n_train_samples, n_valid_samples

(600, 100)

In [16]:
# log(0)

In [17]:
# tensors = {
#     'loss': air.loss.value,
#     'imp_weight': .5*tf.reduce_mean((air.reinforce_imp_weight)**2),
#     'baseline_loss': air.baseline_loss,
#     'baseline': tf.reduce_mean(air.baseline)
# }

In [18]:
# output = sess.run(tensors)
# for o, v in output.iteritems():
#     print '{}: {}'.format(o, v)

In [None]:
# Notes (kept as comments so this cell does not raise a SyntaxError on Run All):
# * Template and coordinate storing in `data.create_mnist`
# * `load_data` returns an `AttrDict`

In [20]:
# Training schedule (iterations are global_step values).
max_train_itr = int(1e6)
summary_every = 50
log_every = 100
checkpoint_every = 5000
make_figures = False  # set True to also call make_fig at every checkpoint

train_itr = sess.run(global_step)
print('Starting training at iter = {}'.format(train_itr))

if train_itr == 0:
    log(0)

while train_itr < max_train_itr:
    # Update first, then read the counter in its own run. When global_step is
    # fetched in the same sess.run as train_step, TF does not order the read
    # against the increment, so the value seen for this iteration is ambiguous.
    sess.run(train_step)
    train_itr = sess.run(global_step)

    if train_itr % summary_every == 0:
        summaries = sess.run(all_summaries)
        summary_writer.add_summary(summaries, train_itr)

    if train_itr % log_every == 0:
        log(train_itr)

    if train_itr % checkpoint_every == 0:
        saver.save(sess, checkpoint_name, global_step=train_itr)
        if make_figures:
            make_fig(air, sess, logdir, train_itr)

Starting training at iter = 0
Step 0, Data train loss = -118.6387, imp_weight = -118.6387, opt_loss = 40.1163, reinforce_loss = 158.7550, rec_loss = -118.6387, num_step = 1.3524, num_step_acc = 0.2222, eval time = 0.2532s
Step 0, Data test loss = -29.1283, imp_weight = -29.1283, opt_loss = 9.4869, reinforce_loss = 38.6153, rec_loss = -29.1283, num_step = 1.4062, num_step_acc = 0.1719, eval time = 0.09908s

Step 100, Data train loss = -199.0291, imp_weight = -199.0291, opt_loss = 66.0636, reinforce_loss = 265.0926, rec_loss = -199.0291, num_step = 1.2274, num_step_acc = 0.2153, eval time = 0.1685s
Step 100, Data test loss = -150.9201, imp_weight = -150.9201, opt_loss = 49.1561, reinforce_loss = 200.0762, rec_loss = -150.9201, num_step = 1.3281, num_step_acc = 0.1875, eval time = 0.01688s

Step 200, Data train loss = -263.1973, imp_weight = -263.1973, opt_loss = 68.8205, reinforce_loss = 332.0178, rec_loss = -263.1973, num_step = 1.2378, num_step_acc = 0.2049, eval time = 0.1832s
Step 20

Step 2100, Data train loss = -507.3460, imp_weight = -507.3460, opt_loss = 162.0468, reinforce_loss = 669.3929, rec_loss = -507.3460, num_step = 1.7378, num_step_acc = 0.1823, eval time = 0.1924s
Step 2100, Data test loss = -439.9178, imp_weight = -439.9178, opt_loss = 163.5695, reinforce_loss = 603.4873, rec_loss = -439.9178, num_step = 1.5625, num_step_acc = 0.2656, eval time = 0.01952s

Step 2200, Data train loss = -506.4790, imp_weight = -506.4790, opt_loss = 145.8249, reinforce_loss = 652.3039, rec_loss = -506.4790, num_step = 1.7448, num_step_acc = 0.1962, eval time = 0.1837s
Step 2200, Data test loss = -528.3838, imp_weight = -528.3838, opt_loss = 158.8640, reinforce_loss = 687.2477, rec_loss = -528.3838, num_step = 1.8594, num_step_acc = 0.2812, eval time = 0.01929s

Step 2300, Data train loss = -489.9593, imp_weight = -489.9593, opt_loss = 158.1847, reinforce_loss = 648.1440, rec_loss = -489.9593, num_step = 1.6562, num_step_acc = 0.2135, eval time = 0.1935s
Step 2300, Data te

Step 4200, Data train loss = -529.2955, imp_weight = -529.2955, opt_loss = 171.5116, reinforce_loss = 700.8071, rec_loss = -529.2955, num_step = 1.6823, num_step_acc = 0.2135, eval time = 0.1794s
Step 4200, Data test loss = -444.4440, imp_weight = -444.4440, opt_loss = 124.4225, reinforce_loss = 568.8665, rec_loss = -444.4440, num_step = 1.5312, num_step_acc = 0.2188, eval time = 0.01897s

Step 4300, Data train loss = -497.1604, imp_weight = -497.1604, opt_loss = 155.8989, reinforce_loss = 653.0593, rec_loss = -497.1604, num_step = 1.6163, num_step_acc = 0.2188, eval time = 0.1787s
Step 4300, Data test loss = -452.6310, imp_weight = -452.6310, opt_loss = 150.1467, reinforce_loss = 602.7777, rec_loss = -452.6310, num_step = 1.5000, num_step_acc = 0.1875, eval time = 0.01958s

Step 4400, Data train loss = -510.2068, imp_weight = -510.2068, opt_loss = 163.8861, reinforce_loss = 674.0929, rec_loss = -510.2068, num_step = 1.7066, num_step_acc = 0.2101, eval time = 0.1831s
Step 4400, Data te

Step 6300, Data train loss = -499.5271, imp_weight = -499.5271, opt_loss = 159.7515, reinforce_loss = 659.2786, rec_loss = -499.5271, num_step = 1.6458, num_step_acc = 0.1892, eval time = 0.1806s
Step 6300, Data test loss = -470.4097, imp_weight = -470.4097, opt_loss = 145.5066, reinforce_loss = 615.9163, rec_loss = -470.4097, num_step = 1.6250, num_step_acc = 0.2188, eval time = 0.01995s

Step 6400, Data train loss = -528.9655, imp_weight = -528.9655, opt_loss = 170.6001, reinforce_loss = 699.5657, rec_loss = -528.9655, num_step = 1.6545, num_step_acc = 0.2135, eval time = 0.182s
Step 6400, Data test loss = -511.0793, imp_weight = -511.0793, opt_loss = 218.9889, reinforce_loss = 730.0682, rec_loss = -511.0793, num_step = 1.5469, num_step_acc = 0.2969, eval time = 0.01881s

Step 6500, Data train loss = -502.4463, imp_weight = -502.4463, opt_loss = 156.0632, reinforce_loss = 658.5094, rec_loss = -502.4463, num_step = 1.6267, num_step_acc = 0.1997, eval time = 0.1846s
Step 6500, Data tes

Step 8400, Data train loss = -508.5145, imp_weight = -508.5145, opt_loss = 167.4895, reinforce_loss = 676.0039, rec_loss = -508.5145, num_step = 1.6215, num_step_acc = 0.1892, eval time = 0.174s
Step 8400, Data test loss = -456.2158, imp_weight = -456.2158, opt_loss = 150.4776, reinforce_loss = 606.6934, rec_loss = -456.2158, num_step = 1.6250, num_step_acc = 0.1250, eval time = 0.01928s

Step 8500, Data train loss = -514.5843, imp_weight = -514.5843, opt_loss = 160.0290, reinforce_loss = 674.6133, rec_loss = -514.5843, num_step = 1.6910, num_step_acc = 0.1806, eval time = 0.1772s
Step 8500, Data test loss = -494.5698, imp_weight = -494.5698, opt_loss = 191.1527, reinforce_loss = 685.7225, rec_loss = -494.5698, num_step = 1.6094, num_step_acc = 0.2656, eval time = 0.01968s

Step 8600, Data train loss = -497.1361, imp_weight = -497.1361, opt_loss = 156.7696, reinforce_loss = 653.9057, rec_loss = -497.1361, num_step = 1.5799, num_step_acc = 0.2031, eval time = 0.2323s
Step 8600, Data tes

Step 10500, Data train loss = -536.5315, imp_weight = -536.5315, opt_loss = 164.1647, reinforce_loss = 700.6962, rec_loss = -536.5315, num_step = 1.4722, num_step_acc = 0.2639, eval time = 0.194s
Step 10500, Data test loss = -480.3174, imp_weight = -480.3174, opt_loss = 138.6963, reinforce_loss = 619.0137, rec_loss = -480.3174, num_step = 1.4531, num_step_acc = 0.2500, eval time = 0.02213s

Step 10600, Data train loss = -509.3467, imp_weight = -509.3467, opt_loss = 144.2442, reinforce_loss = 653.5909, rec_loss = -509.3467, num_step = 1.3194, num_step_acc = 0.3056, eval time = 0.1857s
Step 10600, Data test loss = -502.2973, imp_weight = -502.2973, opt_loss = 173.1334, reinforce_loss = 675.4307, rec_loss = -502.2973, num_step = 1.6094, num_step_acc = 0.2656, eval time = 0.01921s

Step 10700, Data train loss = -509.9578, imp_weight = -509.9578, opt_loss = 147.2362, reinforce_loss = 657.1939, rec_loss = -509.9578, num_step = 1.3906, num_step_acc = 0.2378, eval time = 0.1698s
Step 10700, Da

Step 12600, Data train loss = -524.7757, imp_weight = -524.7757, opt_loss = 172.0944, reinforce_loss = 696.8701, rec_loss = -524.7757, num_step = 1.6580, num_step_acc = 0.2031, eval time = 0.1788s
Step 12600, Data test loss = -436.6789, imp_weight = -436.6789, opt_loss = 153.8156, reinforce_loss = 590.4945, rec_loss = -436.6789, num_step = 1.5000, num_step_acc = 0.1406, eval time = 0.01921s

Step 12700, Data train loss = -530.6256, imp_weight = -530.6256, opt_loss = 168.4777, reinforce_loss = 699.1033, rec_loss = -530.6256, num_step = 1.7309, num_step_acc = 0.1875, eval time = 0.191s
Step 12700, Data test loss = -398.4207, imp_weight = -398.4207, opt_loss = 144.9896, reinforce_loss = 543.4103, rec_loss = -398.4207, num_step = 1.4062, num_step_acc = 0.1719, eval time = 0.02082s

Step 12800, Data train loss = -534.3347, imp_weight = -534.3347, opt_loss = 176.0405, reinforce_loss = 710.3752, rec_loss = -534.3347, num_step = 1.7066, num_step_acc = 0.2309, eval time = 0.1883s
Step 12800, Da

Step 14700, Data train loss = -530.4851, imp_weight = -530.4851, opt_loss = 164.2553, reinforce_loss = 694.7404, rec_loss = -530.4851, num_step = 1.7344, num_step_acc = 0.2049, eval time = 0.1822s
Step 14700, Data test loss = -492.0665, imp_weight = -492.0665, opt_loss = 179.1624, reinforce_loss = 671.2289, rec_loss = -492.0665, num_step = 1.7344, num_step_acc = 0.2500, eval time = 0.02066s

Step 14800, Data train loss = -524.4951, imp_weight = -524.4951, opt_loss = 162.5853, reinforce_loss = 687.0804, rec_loss = -524.4951, num_step = 1.6701, num_step_acc = 0.2049, eval time = 0.1766s
Step 14800, Data test loss = -451.8069, imp_weight = -451.8069, opt_loss = 160.0807, reinforce_loss = 611.8877, rec_loss = -451.8069, num_step = 1.7031, num_step_acc = 0.2500, eval time = 0.01818s

Step 14900, Data train loss = -522.7555, imp_weight = -522.7555, opt_loss = 158.5335, reinforce_loss = 681.2890, rec_loss = -522.7555, num_step = 1.7222, num_step_acc = 0.1840, eval time = 0.1811s
Step 14900, D

Step 16800, Data train loss = -522.9739, imp_weight = -522.9739, opt_loss = 149.5188, reinforce_loss = 672.4927, rec_loss = -522.9739, num_step = 1.7135, num_step_acc = 0.1840, eval time = 0.1994s
Step 16800, Data test loss = -471.5223, imp_weight = -471.5223, opt_loss = 153.2183, reinforce_loss = 624.7407, rec_loss = -471.5223, num_step = 1.6094, num_step_acc = 0.1719, eval time = 0.02204s

Step 16900, Data train loss = -532.4659, imp_weight = -532.4659, opt_loss = 174.3365, reinforce_loss = 706.8024, rec_loss = -532.4659, num_step = 1.7101, num_step_acc = 0.2222, eval time = 0.1838s
Step 16900, Data test loss = -468.6757, imp_weight = -468.6757, opt_loss = 159.6163, reinforce_loss = 628.2920, rec_loss = -468.6757, num_step = 1.5938, num_step_acc = 0.1875, eval time = 0.01871s

Step 17000, Data train loss = -524.5538, imp_weight = -524.5538, opt_loss = 183.1265, reinforce_loss = 707.6803, rec_loss = -524.5538, num_step = 1.5660, num_step_acc = 0.2500, eval time = 0.1815s
Step 17000, D

Step 18900, Data train loss = -509.8383, imp_weight = -509.8383, opt_loss = 155.4519, reinforce_loss = 665.2901, rec_loss = -509.8383, num_step = 1.6771, num_step_acc = 0.1910, eval time = 0.1768s
Step 18900, Data test loss = -498.4914, imp_weight = -498.4914, opt_loss = 184.4106, reinforce_loss = 682.9020, rec_loss = -498.4914, num_step = 1.6875, num_step_acc = 0.2188, eval time = 0.01969s

Step 19000, Data train loss = -531.8979, imp_weight = -531.8979, opt_loss = 172.6062, reinforce_loss = 704.5041, rec_loss = -531.8979, num_step = 1.6528, num_step_acc = 0.2257, eval time = 0.1987s
Step 19000, Data test loss = -530.2011, imp_weight = -530.2011, opt_loss = 146.9338, reinforce_loss = 677.1349, rec_loss = -530.2011, num_step = 1.7188, num_step_acc = 0.2656, eval time = 0.0223s

Step 19100, Data train loss = -530.9480, imp_weight = -530.9480, opt_loss = 175.9733, reinforce_loss = 706.9213, rec_loss = -530.9480, num_step = 1.7083, num_step_acc = 0.2222, eval time = 0.1924s
Step 19100, Da

Step 21000, Data train loss = -525.9524, imp_weight = -525.9524, opt_loss = 174.8410, reinforce_loss = 700.7934, rec_loss = -525.9524, num_step = 1.6997, num_step_acc = 0.2049, eval time = 0.1797s
Step 21000, Data test loss = -499.5146, imp_weight = -499.5146, opt_loss = 157.9893, reinforce_loss = 657.5039, rec_loss = -499.5146, num_step = 1.8438, num_step_acc = 0.1875, eval time = 0.01892s

Step 21100, Data train loss = -522.0480, imp_weight = -522.0480, opt_loss = 165.0872, reinforce_loss = 687.1352, rec_loss = -522.0480, num_step = 1.6094, num_step_acc = 0.2153, eval time = 0.2022s
Step 21100, Data test loss = -500.0629, imp_weight = -500.0629, opt_loss = 196.6645, reinforce_loss = 696.7274, rec_loss = -500.0629, num_step = 1.5000, num_step_acc = 0.3125, eval time = 0.02198s

Step 21200, Data train loss = -518.3892, imp_weight = -518.3892, opt_loss = 177.9945, reinforce_loss = 696.3837, rec_loss = -518.3892, num_step = 1.6788, num_step_acc = 0.1910, eval time = 0.1748s
Step 21200, D

Step 23100, Data train loss = -524.7514, imp_weight = -524.7514, opt_loss = 160.2687, reinforce_loss = 685.0201, rec_loss = -524.7514, num_step = 1.7101, num_step_acc = 0.1875, eval time = 0.1828s
Step 23100, Data test loss = -467.2227, imp_weight = -467.2227, opt_loss = 136.7809, reinforce_loss = 604.0036, rec_loss = -467.2227, num_step = 1.7031, num_step_acc = 0.2031, eval time = 0.01973s

Step 23200, Data train loss = -517.5161, imp_weight = -517.5161, opt_loss = 165.1077, reinforce_loss = 682.6238, rec_loss = -517.5161, num_step = 1.6875, num_step_acc = 0.2101, eval time = 0.1843s
Step 23200, Data test loss = -497.8367, imp_weight = -497.8367, opt_loss = 135.9966, reinforce_loss = 633.8334, rec_loss = -497.8367, num_step = 1.6406, num_step_acc = 0.2031, eval time = 0.01922s

Step 23300, Data train loss = -529.1228, imp_weight = -529.1228, opt_loss = 173.4101, reinforce_loss = 702.5329, rec_loss = -529.1228, num_step = 1.6476, num_step_acc = 0.2153, eval time = 0.1826s
Step 23300, D

Step 25200, Data train loss = -500.8138, imp_weight = -500.8138, opt_loss = 179.4140, reinforce_loss = 680.2277, rec_loss = -500.8138, num_step = 1.5677, num_step_acc = 0.2031, eval time = 0.1867s
Step 25200, Data test loss = -495.4247, imp_weight = -495.4247, opt_loss = 144.5711, reinforce_loss = 639.9958, rec_loss = -495.4247, num_step = 1.7656, num_step_acc = 0.1094, eval time = 0.01873s

Step 25300, Data train loss = -521.1172, imp_weight = -521.1172, opt_loss = 169.7830, reinforce_loss = 690.9001, rec_loss = -521.1172, num_step = 1.6597, num_step_acc = 0.1806, eval time = 0.1773s
Step 25300, Data test loss = -456.5816, imp_weight = -456.5816, opt_loss = 131.8604, reinforce_loss = 588.4420, rec_loss = -456.5816, num_step = 1.6250, num_step_acc = 0.1406, eval time = 0.0186s

Step 25400, Data train loss = -500.3276, imp_weight = -500.3276, opt_loss = 159.9628, reinforce_loss = 660.2903, rec_loss = -500.3276, num_step = 1.6215, num_step_acc = 0.1806, eval time = 0.1873s
Step 25400, Da

Step 27300, Data train loss = -523.8331, imp_weight = -523.8331, opt_loss = 175.4317, reinforce_loss = 699.2649, rec_loss = -523.8331, num_step = 1.6267, num_step_acc = 0.2240, eval time = 0.1891s
Step 27300, Data test loss = -520.0125, imp_weight = -520.0125, opt_loss = 218.3144, reinforce_loss = 738.3268, rec_loss = -520.0125, num_step = 1.7812, num_step_acc = 0.2031, eval time = 0.0199s

Step 27400, Data train loss = -542.8487, imp_weight = -542.8487, opt_loss = 172.6097, reinforce_loss = 715.4585, rec_loss = -542.8487, num_step = 1.7031, num_step_acc = 0.1962, eval time = 0.186s
Step 27400, Data test loss = -511.8897, imp_weight = -511.8897, opt_loss = 181.3689, reinforce_loss = 693.2586, rec_loss = -511.8897, num_step = 1.7031, num_step_acc = 0.1875, eval time = 0.02108s

Step 27500, Data train loss = -508.8869, imp_weight = -508.8869, opt_loss = 170.4303, reinforce_loss = 679.3172, rec_loss = -508.8869, num_step = 1.6597, num_step_acc = 0.1823, eval time = 0.1978s
Step 27500, Dat

Step 29400, Data train loss = -527.1465, imp_weight = -527.1465, opt_loss = 163.0520, reinforce_loss = 690.1985, rec_loss = -527.1465, num_step = 1.6684, num_step_acc = 0.1944, eval time = 0.2201s
Step 29400, Data test loss = -507.2887, imp_weight = -507.2887, opt_loss = 166.8460, reinforce_loss = 674.1348, rec_loss = -507.2887, num_step = 1.6875, num_step_acc = 0.1094, eval time = 0.02034s

Step 29500, Data train loss = -536.3480, imp_weight = -536.3480, opt_loss = 190.3034, reinforce_loss = 726.6514, rec_loss = -536.3480, num_step = 1.6649, num_step_acc = 0.2083, eval time = 0.1916s
Step 29500, Data test loss = -539.1941, imp_weight = -539.1941, opt_loss = 151.0504, reinforce_loss = 690.2444, rec_loss = -539.1941, num_step = 1.8594, num_step_acc = 0.2500, eval time = 0.02645s

Step 29600, Data train loss = -525.3957, imp_weight = -525.3957, opt_loss = 165.9174, reinforce_loss = 691.3132, rec_loss = -525.3957, num_step = 1.7188, num_step_acc = 0.1632, eval time = 0.1772s
Step 29600, D

Step 31500, Data train loss = -502.7984, imp_weight = -502.7984, opt_loss = 159.8191, reinforce_loss = 662.6174, rec_loss = -502.7984, num_step = 1.6198, num_step_acc = 0.2066, eval time = 0.1831s
Step 31500, Data test loss = -485.9983, imp_weight = -485.9983, opt_loss = 173.1098, reinforce_loss = 659.1080, rec_loss = -485.9983, num_step = 1.5781, num_step_acc = 0.2344, eval time = 0.01942s

Step 31600, Data train loss = -530.1537, imp_weight = -530.1537, opt_loss = 162.2668, reinforce_loss = 692.4205, rec_loss = -530.1537, num_step = 1.6545, num_step_acc = 0.2274, eval time = 0.1744s
Step 31600, Data test loss = -425.6735, imp_weight = -425.6735, opt_loss = 175.2613, reinforce_loss = 600.9348, rec_loss = -425.6735, num_step = 1.2969, num_step_acc = 0.1719, eval time = 0.02004s

Step 31700, Data train loss = -518.5688, imp_weight = -518.5688, opt_loss = 164.2871, reinforce_loss = 682.8559, rec_loss = -518.5688, num_step = 1.6250, num_step_acc = 0.2153, eval time = 0.178s
Step 31700, Da

Step 33600, Data train loss = -522.2753, imp_weight = -522.2753, opt_loss = 167.9130, reinforce_loss = 690.1883, rec_loss = -522.2753, num_step = 1.6528, num_step_acc = 0.1979, eval time = 0.1822s
Step 33600, Data test loss = -483.5404, imp_weight = -483.5404, opt_loss = 125.9919, reinforce_loss = 609.5323, rec_loss = -483.5404, num_step = 1.7969, num_step_acc = 0.1250, eval time = 0.02092s

Step 33700, Data train loss = -520.6133, imp_weight = -520.6133, opt_loss = 167.0878, reinforce_loss = 687.7011, rec_loss = -520.6133, num_step = 1.7240, num_step_acc = 0.2066, eval time = 0.1822s
Step 33700, Data test loss = -429.9692, imp_weight = -429.9692, opt_loss = 106.1242, reinforce_loss = 536.0934, rec_loss = -429.9692, num_step = 1.6875, num_step_acc = 0.1094, eval time = 0.02007s

Step 33800, Data train loss = -527.4949, imp_weight = -527.4949, opt_loss = 173.7892, reinforce_loss = 701.2841, rec_loss = -527.4949, num_step = 1.6094, num_step_acc = 0.1962, eval time = 0.1938s
Step 33800, D

Step 35700, Data train loss = -523.9334, imp_weight = -523.9334, opt_loss = 169.9849, reinforce_loss = 693.9183, rec_loss = -523.9334, num_step = 1.6250, num_step_acc = 0.2517, eval time = 0.1858s
Step 35700, Data test loss = -450.9573, imp_weight = -450.9573, opt_loss = 163.6841, reinforce_loss = 614.6414, rec_loss = -450.9573, num_step = 1.4688, num_step_acc = 0.1875, eval time = 0.02043s

Step 35800, Data train loss = -523.4121, imp_weight = -523.4121, opt_loss = 165.9614, reinforce_loss = 689.3736, rec_loss = -523.4121, num_step = 1.6667, num_step_acc = 0.1753, eval time = 0.1823s
Step 35800, Data test loss = -464.5528, imp_weight = -464.5528, opt_loss = 105.1279, reinforce_loss = 569.6807, rec_loss = -464.5528, num_step = 1.6875, num_step_acc = 0.2031, eval time = 0.01917s

Step 35900, Data train loss = -516.0698, imp_weight = -516.0698, opt_loss = 155.3111, reinforce_loss = 671.3809, rec_loss = -516.0698, num_step = 1.6406, num_step_acc = 0.1944, eval time = 0.1867s
Step 35900, D

Step 37800, Data train loss = -535.7910, imp_weight = -535.7910, opt_loss = 168.0680, reinforce_loss = 703.8591, rec_loss = -535.7910, num_step = 1.7031, num_step_acc = 0.1875, eval time = 0.1793s
Step 37800, Data test loss = -446.0394, imp_weight = -446.0394, opt_loss = 121.8464, reinforce_loss = 567.8859, rec_loss = -446.0394, num_step = 1.6406, num_step_acc = 0.1094, eval time = 0.01842s

Step 37900, Data train loss = -527.5250, imp_weight = -527.5250, opt_loss = 176.8105, reinforce_loss = 704.3356, rec_loss = -527.5250, num_step = 1.6372, num_step_acc = 0.2101, eval time = 0.1865s
Step 37900, Data test loss = -459.9946, imp_weight = -459.9946, opt_loss = 153.5429, reinforce_loss = 613.5375, rec_loss = -459.9946, num_step = 1.7031, num_step_acc = 0.1875, eval time = 0.02084s

Step 38000, Data train loss = -515.4128, imp_weight = -515.4128, opt_loss = 169.1501, reinforce_loss = 684.5629, rec_loss = -515.4128, num_step = 1.6510, num_step_acc = 0.1823, eval time = 0.193s
Step 38000, Da

Step 39900, Data train loss = -509.5479, imp_weight = -509.5479, opt_loss = 171.8835, reinforce_loss = 681.4314, rec_loss = -509.5479, num_step = 1.6181, num_step_acc = 0.1875, eval time = 0.1933s
Step 39900, Data test loss = -427.2770, imp_weight = -427.2770, opt_loss = 149.2167, reinforce_loss = 576.4937, rec_loss = -427.2770, num_step = 1.3594, num_step_acc = 0.2188, eval time = 0.02171s

Step 40000, Data train loss = -517.9697, imp_weight = -517.9697, opt_loss = 166.9837, reinforce_loss = 684.9534, rec_loss = -517.9697, num_step = 1.6753, num_step_acc = 0.1997, eval time = 0.1832s
Step 40000, Data test loss = -459.8980, imp_weight = -459.8980, opt_loss = 152.9902, reinforce_loss = 612.8882, rec_loss = -459.8980, num_step = 1.4375, num_step_acc = 0.2500, eval time = 0.01941s

Step 40100, Data train loss = -519.5110, imp_weight = -519.5110, opt_loss = 166.5736, reinforce_loss = 686.0846, rec_loss = -519.5110, num_step = 1.6458, num_step_acc = 0.1979, eval time = 0.2006s
Step 40100, D

Step 42000, Data train loss = -533.1051, imp_weight = -533.1051, opt_loss = 180.3431, reinforce_loss = 713.4482, rec_loss = -533.1051, num_step = 1.7049, num_step_acc = 0.2014, eval time = 0.1835s
Step 42000, Data test loss = -493.9705, imp_weight = -493.9705, opt_loss = 212.5876, reinforce_loss = 706.5581, rec_loss = -493.9705, num_step = 1.4219, num_step_acc = 0.2500, eval time = 0.01868s

Step 42100, Data train loss = -548.7127, imp_weight = -548.7127, opt_loss = 191.8028, reinforce_loss = 740.5154, rec_loss = -548.7127, num_step = 1.6302, num_step_acc = 0.2240, eval time = 0.1885s
Step 42100, Data test loss = -496.4141, imp_weight = -496.4141, opt_loss = 148.2234, reinforce_loss = 644.6375, rec_loss = -496.4141, num_step = 1.6562, num_step_acc = 0.1875, eval time = 0.01913s

Step 42200, Data train loss = -507.1909, imp_weight = -507.1909, opt_loss = 162.3281, reinforce_loss = 669.5189, rec_loss = -507.1909, num_step = 1.5972, num_step_acc = 0.1962, eval time = 0.1866s
Step 42200, D

Step 44100, Data train loss = -519.0144, imp_weight = -519.0144, opt_loss = 161.3598, reinforce_loss = 680.3743, rec_loss = -519.0144, num_step = 1.6997, num_step_acc = 0.1788, eval time = 0.1766s
Step 44100, Data test loss = -396.3869, imp_weight = -396.3869, opt_loss = 110.5204, reinforce_loss = 506.9072, rec_loss = -396.3869, num_step = 1.2656, num_step_acc = 0.1250, eval time = 0.01924s

Step 44200, Data train loss = -522.0273, imp_weight = -522.0273, opt_loss = 162.9852, reinforce_loss = 685.0125, rec_loss = -522.0273, num_step = 1.5816, num_step_acc = 0.1997, eval time = 0.1769s
Step 44200, Data test loss = -520.2247, imp_weight = -520.2247, opt_loss = 167.5754, reinforce_loss = 687.8000, rec_loss = -520.2247, num_step = 1.7969, num_step_acc = 0.2500, eval time = 0.01905s

Step 44300, Data train loss = -537.1570, imp_weight = -537.1570, opt_loss = 179.7123, reinforce_loss = 716.8693, rec_loss = -537.1570, num_step = 1.6389, num_step_acc = 0.2170, eval time = 0.1806s
Step 44300, D

Step 46200, Data train loss = -527.3890, imp_weight = -527.3890, opt_loss = 150.2316, reinforce_loss = 677.6205, rec_loss = -527.3890, num_step = 1.8264, num_step_acc = 0.1684, eval time = 0.1797s
Step 46200, Data test loss = -512.3695, imp_weight = -512.3695, opt_loss = 165.5273, reinforce_loss = 677.8969, rec_loss = -512.3695, num_step = 1.7344, num_step_acc = 0.2031, eval time = 0.02035s

Step 46300, Data train loss = -547.3203, imp_weight = -547.3203, opt_loss = 181.8695, reinforce_loss = 729.1898, rec_loss = -547.3203, num_step = 1.7205, num_step_acc = 0.1875, eval time = 0.1864s
Step 46300, Data test loss = -479.9910, imp_weight = -479.9910, opt_loss = 163.4704, reinforce_loss = 643.4614, rec_loss = -479.9910, num_step = 1.5938, num_step_acc = 0.2188, eval time = 0.01992s

Step 46400, Data train loss = -533.1684, imp_weight = -533.1684, opt_loss = 163.6843, reinforce_loss = 696.8527, rec_loss = -533.1684, num_step = 1.7066, num_step_acc = 0.1962, eval time = 0.1885s
Step 46400, D

Step 48300, Data train loss = -520.8590, imp_weight = -520.8590, opt_loss = 156.8571, reinforce_loss = 677.7160, rec_loss = -520.8590, num_step = 1.7031, num_step_acc = 0.1892, eval time = 0.1849s
Step 48300, Data test loss = -463.4720, imp_weight = -463.4720, opt_loss = 153.3071, reinforce_loss = 616.7791, rec_loss = -463.4720, num_step = 1.5625, num_step_acc = 0.1250, eval time = 0.02109s

Step 48400, Data train loss = -522.2873, imp_weight = -522.2873, opt_loss = 164.7909, reinforce_loss = 687.0783, rec_loss = -522.2873, num_step = 1.6701, num_step_acc = 0.1944, eval time = 0.1832s
Step 48400, Data test loss = -502.6870, imp_weight = -502.6870, opt_loss = 176.5934, reinforce_loss = 679.2805, rec_loss = -502.6870, num_step = 1.5625, num_step_acc = 0.2188, eval time = 0.01973s

Step 48500, Data train loss = -550.8311, imp_weight = -550.8311, opt_loss = 176.0951, reinforce_loss = 726.9261, rec_loss = -550.8311, num_step = 1.6580, num_step_acc = 0.2170, eval time = 0.1816s
Step 48500, D

Step 50400, Data train loss = -512.5403, imp_weight = -512.5403, opt_loss = 172.6905, reinforce_loss = 685.2308, rec_loss = -512.5403, num_step = 1.5677, num_step_acc = 0.1962, eval time = 0.1908s
Step 50400, Data test loss = -509.6325, imp_weight = -509.6325, opt_loss = 153.1296, reinforce_loss = 662.7621, rec_loss = -509.6325, num_step = 1.5312, num_step_acc = 0.3125, eval time = 0.01997s

Step 50500, Data train loss = -524.1016, imp_weight = -524.1016, opt_loss = 167.6671, reinforce_loss = 691.7687, rec_loss = -524.1016, num_step = 1.6528, num_step_acc = 0.1806, eval time = 0.1958s
Step 50500, Data test loss = -505.3898, imp_weight = -505.3898, opt_loss = 149.9531, reinforce_loss = 655.3429, rec_loss = -505.3898, num_step = 1.7812, num_step_acc = 0.1875, eval time = 0.01917s

Step 50600, Data train loss = -533.7029, imp_weight = -533.7029, opt_loss = 175.0206, reinforce_loss = 708.7236, rec_loss = -533.7029, num_step = 1.6875, num_step_acc = 0.2049, eval time = 0.1806s
Step 50600, D

Step 52500, Data train loss = -526.2784, imp_weight = -526.2784, opt_loss = 169.2506, reinforce_loss = 695.5290, rec_loss = -526.2784, num_step = 1.6753, num_step_acc = 0.1944, eval time = 0.1843s
Step 52500, Data test loss = -537.2550, imp_weight = -537.2550, opt_loss = 155.2615, reinforce_loss = 692.5165, rec_loss = -537.2550, num_step = 1.9531, num_step_acc = 0.1875, eval time = 0.01864s

Step 52600, Data train loss = -531.5662, imp_weight = -531.5662, opt_loss = 177.8089, reinforce_loss = 709.3751, rec_loss = -531.5662, num_step = 1.6094, num_step_acc = 0.2170, eval time = 0.1786s
Step 52600, Data test loss = -472.9211, imp_weight = -472.9211, opt_loss = 137.4695, reinforce_loss = 610.3906, rec_loss = -472.9211, num_step = 1.5000, num_step_acc = 0.1875, eval time = 0.01951s

Step 52700, Data train loss = -539.7419, imp_weight = -539.7419, opt_loss = 164.0017, reinforce_loss = 703.7436, rec_loss = -539.7419, num_step = 1.7135, num_step_acc = 0.2014, eval time = 0.1758s
Step 52700, D

Step 54600, Data train loss = -510.1770, imp_weight = -510.1770, opt_loss = 157.3479, reinforce_loss = 667.5249, rec_loss = -510.1770, num_step = 1.6458, num_step_acc = 0.1840, eval time = 0.1824s
Step 54600, Data test loss = -449.2332, imp_weight = -449.2332, opt_loss = 174.8979, reinforce_loss = 624.1310, rec_loss = -449.2332, num_step = 1.4219, num_step_acc = 0.1875, eval time = 0.01831s

Step 54700, Data train loss = -517.9703, imp_weight = -517.9703, opt_loss = 159.6107, reinforce_loss = 677.5810, rec_loss = -517.9703, num_step = 1.6788, num_step_acc = 0.1910, eval time = 0.179s
Step 54700, Data test loss = -432.5341, imp_weight = -432.5341, opt_loss = 165.8727, reinforce_loss = 598.4068, rec_loss = -432.5341, num_step = 1.4688, num_step_acc = 0.1562, eval time = 0.02149s

Step 54800, Data train loss = -547.8687, imp_weight = -547.8687, opt_loss = 164.5232, reinforce_loss = 712.3918, rec_loss = -547.8687, num_step = 1.7726, num_step_acc = 0.2135, eval time = 0.1817s
Step 54800, Da

Step 56700, Data train loss = -522.6772, imp_weight = -522.6772, opt_loss = 167.5658, reinforce_loss = 690.2431, rec_loss = -522.6772, num_step = 1.6615, num_step_acc = 0.2135, eval time = 0.1882s
Step 56700, Data test loss = -471.5461, imp_weight = -471.5461, opt_loss = 176.6602, reinforce_loss = 648.2063, rec_loss = -471.5461, num_step = 1.3750, num_step_acc = 0.1875, eval time = 0.01978s

Step 56800, Data train loss = -544.6801, imp_weight = -544.6801, opt_loss = 183.1226, reinforce_loss = 727.8028, rec_loss = -544.6801, num_step = 1.6979, num_step_acc = 0.2153, eval time = 0.2182s
Step 56800, Data test loss = -541.6903, imp_weight = -541.6903, opt_loss = 225.2733, reinforce_loss = 766.9636, rec_loss = -541.6903, num_step = 1.7031, num_step_acc = 0.3125, eval time = 0.02142s

Step 56900, Data train loss = -528.9903, imp_weight = -528.9903, opt_loss = 157.0322, reinforce_loss = 686.0226, rec_loss = -528.9903, num_step = 1.6910, num_step_acc = 0.2153, eval time = 0.2188s
Step 56900, D

Step 58800, Data train loss = -535.8003, imp_weight = -535.8003, opt_loss = 179.5202, reinforce_loss = 715.3205, rec_loss = -535.8003, num_step = 1.5955, num_step_acc = 0.2101, eval time = 0.1812s
Step 58800, Data test loss = -519.0376, imp_weight = -519.0376, opt_loss = 103.8773, reinforce_loss = 622.9149, rec_loss = -519.0376, num_step = 2.0625, num_step_acc = 0.1094, eval time = 0.02064s

Step 58900, Data train loss = -518.3272, imp_weight = -518.3272, opt_loss = 166.8809, reinforce_loss = 685.2081, rec_loss = -518.3272, num_step = 1.6128, num_step_acc = 0.1997, eval time = 0.1846s
Step 58900, Data test loss = -443.0969, imp_weight = -443.0969, opt_loss = 147.4014, reinforce_loss = 590.4983, rec_loss = -443.0969, num_step = 1.4375, num_step_acc = 0.1406, eval time = 0.01845s

Step 59000, Data train loss = -539.5394, imp_weight = -539.5394, opt_loss = 164.4881, reinforce_loss = 704.0275, rec_loss = -539.5394, num_step = 1.7170, num_step_acc = 0.2066, eval time = 0.174s
Step 59000, Da

Step 60900, Data train loss = -516.0235, imp_weight = -516.0235, opt_loss = 160.2089, reinforce_loss = 676.2324, rec_loss = -516.0235, num_step = 1.6215, num_step_acc = 0.1875, eval time = 0.1847s
Step 60900, Data test loss = -436.3336, imp_weight = -436.3336, opt_loss = 120.0161, reinforce_loss = 556.3497, rec_loss = -436.3336, num_step = 1.5156, num_step_acc = 0.1406, eval time = 0.01948s

Step 61000, Data train loss = -530.6974, imp_weight = -530.6974, opt_loss = 168.9806, reinforce_loss = 699.6780, rec_loss = -530.6974, num_step = 1.6528, num_step_acc = 0.2188, eval time = 0.1764s
Step 61000, Data test loss = -555.2385, imp_weight = -555.2385, opt_loss = 158.7855, reinforce_loss = 714.0239, rec_loss = -555.2385, num_step = 1.8125, num_step_acc = 0.2344, eval time = 0.02054s

Step 61100, Data train loss = -537.1973, imp_weight = -537.1973, opt_loss = 169.1288, reinforce_loss = 706.3260, rec_loss = -537.1973, num_step = 1.6059, num_step_acc = 0.2083, eval time = 0.1818s
Step 61100, D

Step 63000, Data train loss = -532.2750, imp_weight = -532.2750, opt_loss = 169.1756, reinforce_loss = 701.4507, rec_loss = -532.2750, num_step = 1.6545, num_step_acc = 0.1927, eval time = 0.1814s
Step 63000, Data test loss = -444.9018, imp_weight = -444.9018, opt_loss = 121.5356, reinforce_loss = 566.4374, rec_loss = -444.9018, num_step = 1.6406, num_step_acc = 0.0938, eval time = 0.01949s

Step 63100, Data train loss = -540.8747, imp_weight = -540.8747, opt_loss = 179.9709, reinforce_loss = 720.8456, rec_loss = -540.8747, num_step = 1.6181, num_step_acc = 0.2222, eval time = 0.189s
Step 63100, Data test loss = -447.3949, imp_weight = -447.3949, opt_loss = 150.1790, reinforce_loss = 597.5739, rec_loss = -447.3949, num_step = 1.5781, num_step_acc = 0.1562, eval time = 0.02116s

Step 63200, Data train loss = -531.7953, imp_weight = -531.7953, opt_loss = 164.7178, reinforce_loss = 696.5131, rec_loss = -531.7953, num_step = 1.7413, num_step_acc = 0.1736, eval time = 0.1857s
Step 63200, Da

Step 65100, Data train loss = -552.5668, imp_weight = -552.5668, opt_loss = 172.0290, reinforce_loss = 724.5958, rec_loss = -552.5668, num_step = 1.7465, num_step_acc = 0.1753, eval time = 0.1624s
Step 65100, Data test loss = -454.6965, imp_weight = -454.6965, opt_loss = 165.0411, reinforce_loss = 619.7377, rec_loss = -454.6965, num_step = 1.4688, num_step_acc = 0.2188, eval time = 0.01812s

Step 65200, Data train loss = -528.2681, imp_weight = -528.2681, opt_loss = 167.8688, reinforce_loss = 696.1368, rec_loss = -528.2681, num_step = 1.5868, num_step_acc = 0.2188, eval time = 0.18s
Step 65200, Data test loss = -473.2891, imp_weight = -473.2891, opt_loss = 137.1306, reinforce_loss = 610.4197, rec_loss = -473.2891, num_step = 1.5625, num_step_acc = 0.1250, eval time = 0.02002s

Step 65300, Data train loss = -542.6655, imp_weight = -542.6655, opt_loss = 191.6844, reinforce_loss = 734.3499, rec_loss = -542.6655, num_step = 1.6476, num_step_acc = 0.2413, eval time = 0.19s
Step 65300, Data 

Step 67200, Data train loss = -531.6320, imp_weight = -531.6320, opt_loss = 170.5569, reinforce_loss = 702.1889, rec_loss = -531.6320, num_step = 1.7587, num_step_acc = 0.1562, eval time = 0.1822s
Step 67200, Data test loss = -533.4117, imp_weight = -533.4117, opt_loss = 186.3832, reinforce_loss = 719.7949, rec_loss = -533.4117, num_step = 1.6875, num_step_acc = 0.2031, eval time = 0.0189s

Step 67300, Data train loss = -511.6466, imp_weight = -511.6466, opt_loss = 154.5808, reinforce_loss = 666.2274, rec_loss = -511.6466, num_step = 1.5920, num_step_acc = 0.1962, eval time = 0.1909s
Step 67300, Data test loss = -522.8671, imp_weight = -522.8671, opt_loss = 174.0374, reinforce_loss = 696.9044, rec_loss = -522.8671, num_step = 1.7656, num_step_acc = 0.1562, eval time = 0.02214s

Step 67400, Data train loss = -522.4062, imp_weight = -522.4062, opt_loss = 162.0821, reinforce_loss = 684.4883, rec_loss = -522.4062, num_step = 1.6979, num_step_acc = 0.1927, eval time = 0.1852s
Step 67400, Da

Step 69300, Data train loss = -509.6013, imp_weight = -509.6013, opt_loss = 156.6348, reinforce_loss = 666.2360, rec_loss = -509.6013, num_step = 1.6493, num_step_acc = 0.1562, eval time = 0.185s
Step 69300, Data test loss = -446.9808, imp_weight = -446.9808, opt_loss = 149.1141, reinforce_loss = 596.0948, rec_loss = -446.9808, num_step = 1.3750, num_step_acc = 0.2031, eval time = 0.01992s

Step 69400, Data train loss = -549.1737, imp_weight = -549.1737, opt_loss = 162.4861, reinforce_loss = 711.6597, rec_loss = -549.1737, num_step = 1.6667, num_step_acc = 0.2083, eval time = 0.1818s
Step 69400, Data test loss = -542.3424, imp_weight = -542.3424, opt_loss = 181.0300, reinforce_loss = 723.3724, rec_loss = -542.3424, num_step = 1.7031, num_step_acc = 0.2344, eval time = 0.01913s

Step 69500, Data train loss = -553.7348, imp_weight = -553.7348, opt_loss = 174.9475, reinforce_loss = 728.6823, rec_loss = -553.7348, num_step = 1.7049, num_step_acc = 0.2049, eval time = 0.193s
Step 69500, Dat

Step 71400, Data train loss = -533.1840, imp_weight = -533.1840, opt_loss = 168.5917, reinforce_loss = 701.7758, rec_loss = -533.1840, num_step = 1.6736, num_step_acc = 0.1823, eval time = 0.1909s
Step 71400, Data test loss = -474.9770, imp_weight = -474.9770, opt_loss = 134.4011, reinforce_loss = 609.3781, rec_loss = -474.9770, num_step = 1.6719, num_step_acc = 0.1094, eval time = 0.02036s

Step 71500, Data train loss = -533.8949, imp_weight = -533.8949, opt_loss = 163.9265, reinforce_loss = 697.8215, rec_loss = -533.8949, num_step = 1.6528, num_step_acc = 0.2101, eval time = 0.1878s
Step 71500, Data test loss = -535.8447, imp_weight = -535.8447, opt_loss = 217.0925, reinforce_loss = 752.9371, rec_loss = -535.8447, num_step = 1.7344, num_step_acc = 0.3125, eval time = 0.01921s

Step 71600, Data train loss = -533.4551, imp_weight = -533.4551, opt_loss = 165.6562, reinforce_loss = 699.1113, rec_loss = -533.4551, num_step = 1.6788, num_step_acc = 0.2326, eval time = 0.1756s
Step 71600, D

Step 73500, Data train loss = -523.5993, imp_weight = -523.5993, opt_loss = 165.0930, reinforce_loss = 688.6923, rec_loss = -523.5993, num_step = 1.6372, num_step_acc = 0.2188, eval time = 0.1788s
Step 73500, Data test loss = -543.5380, imp_weight = -543.5380, opt_loss = 209.0560, reinforce_loss = 752.5940, rec_loss = -543.5380, num_step = 1.7344, num_step_acc = 0.2188, eval time = 0.01903s

Step 73600, Data train loss = -534.2201, imp_weight = -534.2201, opt_loss = 161.9961, reinforce_loss = 696.2161, rec_loss = -534.2201, num_step = 1.6476, num_step_acc = 0.1892, eval time = 0.1843s
Step 73600, Data test loss = -493.6873, imp_weight = -493.6873, opt_loss = 144.9930, reinforce_loss = 638.6802, rec_loss = -493.6873, num_step = 1.6406, num_step_acc = 0.1875, eval time = 0.02007s

Step 73700, Data train loss = -520.7867, imp_weight = -520.7867, opt_loss = 171.2896, reinforce_loss = 692.0763, rec_loss = -520.7867, num_step = 1.6476, num_step_acc = 0.1944, eval time = 0.1839s
Step 73700, D

Step 75600, Data train loss = -527.3914, imp_weight = -527.3914, opt_loss = 165.5201, reinforce_loss = 692.9116, rec_loss = -527.3914, num_step = 1.6649, num_step_acc = 0.1806, eval time = 0.1835s
Step 75600, Data test loss = -475.4633, imp_weight = -475.4633, opt_loss = 140.0880, reinforce_loss = 615.5512, rec_loss = -475.4633, num_step = 1.5781, num_step_acc = 0.2344, eval time = 0.01902s

Step 75700, Data train loss = -529.4142, imp_weight = -529.4142, opt_loss = 161.8224, reinforce_loss = 691.2365, rec_loss = -529.4142, num_step = 1.6128, num_step_acc = 0.2240, eval time = 0.1819s
Step 75700, Data test loss = -491.8780, imp_weight = -491.8780, opt_loss = 138.7210, reinforce_loss = 630.5990, rec_loss = -491.8780, num_step = 1.7344, num_step_acc = 0.1250, eval time = 0.0195s

Step 75800, Data train loss = -521.9401, imp_weight = -521.9401, opt_loss = 156.1509, reinforce_loss = 678.0911, rec_loss = -521.9401, num_step = 1.7535, num_step_acc = 0.1615, eval time = 0.1882s
Step 75800, Da

Step 77700, Data train loss = -524.4464, imp_weight = -524.4464, opt_loss = 171.0422, reinforce_loss = 695.4885, rec_loss = -524.4464, num_step = 1.5938, num_step_acc = 0.2257, eval time = 0.1968s
Step 77700, Data test loss = -468.7150, imp_weight = -468.7150, opt_loss = 122.5704, reinforce_loss = 591.2854, rec_loss = -468.7150, num_step = 1.7344, num_step_acc = 0.1094, eval time = 0.02342s

Step 77800, Data train loss = -535.7700, imp_weight = -535.7700, opt_loss = 168.3706, reinforce_loss = 704.1406, rec_loss = -535.7700, num_step = 1.7083, num_step_acc = 0.2257, eval time = 0.1806s
Step 77800, Data test loss = -471.8835, imp_weight = -471.8835, opt_loss = 119.6674, reinforce_loss = 591.5508, rec_loss = -471.8835, num_step = 1.5781, num_step_acc = 0.1719, eval time = 0.01873s

Step 77900, Data train loss = -531.2885, imp_weight = -531.2885, opt_loss = 168.0867, reinforce_loss = 699.3751, rec_loss = -531.2885, num_step = 1.6458, num_step_acc = 0.1927, eval time = 0.1776s
Step 77900, D

Step 79800, Data train loss = -522.0952, imp_weight = -522.0952, opt_loss = 163.6660, reinforce_loss = 685.7613, rec_loss = -522.0952, num_step = 1.6007, num_step_acc = 0.2222, eval time = 0.2287s
Step 79800, Data test loss = -482.1916, imp_weight = -482.1916, opt_loss = 132.1759, reinforce_loss = 614.3676, rec_loss = -482.1916, num_step = 1.5781, num_step_acc = 0.1250, eval time = 0.0247s

Step 79900, Data train loss = -520.0135, imp_weight = -520.0135, opt_loss = 166.2109, reinforce_loss = 686.2244, rec_loss = -520.0135, num_step = 1.6441, num_step_acc = 0.1979, eval time = 0.2628s
Step 79900, Data test loss = -509.9578, imp_weight = -509.9578, opt_loss = 146.7956, reinforce_loss = 656.7534, rec_loss = -509.9578, num_step = 1.6875, num_step_acc = 0.2344, eval time = 0.02842s



KeyboardInterrupt: 

In [None]:
# Evaluate the input batch and the model's canvas in a single `sess.run` call
# so both come from the same pipeline step (the training tensors are drawn from
# a shuffled input pipeline; presumably `air` was built on `x` — confirm).
gt, imgs = sess.run([
    x,           # ground-truth image batch
    air.canvas,  # model reconstruction
])

In [None]:
# Visual check: model canvas `imgs` (top row) vs. ground-truth `gt` (bottom row)
# for one randomly chosen example, one column per timestep.
# NOTE(review): this cell rebinds `axes`, shadowing the `axes` dict from the
# config cell that is passed to `tensors_from_data`; re-running the data cell
# after this one would break — consider renaming to e.g. `ax_grid`.
nt = imgs.shape[0]  # number of timesteps (assumes imgs is (T, B, H, W) — TODO confirm)
fig, axes = plt.subplots(2, nt, figsize=(nt*2, 4), sharex=True, sharey=True)
# plt.subplots squeezes the returned array to 1-D when nt == 1; force a
# (2, nt) grid so the per-timestep loop below works for any nt.
axes = axes.reshape((2, nt))
n = np.random.randint(imgs.shape[1])  # random example index within the batch
for t, ax in enumerate(axes.T):
    # vmin/vmax pin the gray colormap to [0, 1]; assumes pixel values lie in that range.
    ax[0].imshow(imgs[t, n], cmap='gray', vmin=0., vmax=1.)
    ax[1].imshow(gt[t, n], cmap='gray', vmin=0., vmax=1.)
    for a in ax:
        # Hide grid lines drawn over the images.
        a.grid(False)