In [1]:
from os import path as osp
import tensorflow as tf

from evaluation import make_fig, make_logger
from experiment_tools import load, init_checkpoint, parse_flags, print_flags, set_flags

In [2]:
import sys
# Make modules in the parent directory importable. The membership guard keeps
# re-runs of this cell from stacking duplicate '../' entries on sys.path.
# NOTE(review): this executes *after* the first cell's `evaluation` /
# `experiment_tools` imports, so on a fresh kernel it cannot be what makes those
# resolvable — move it above them if they live in the parent directory.
if '../' not in sys.path:
    sys.path.append('../')

In [3]:
# Define command-line flags. Values are overridden in the next cell via
# set_flags() and then parsed there.
# NOTE(review): re-running this cell in the same kernel will likely fail with a
# duplicate-flag error — restart the kernel instead.
flags = tf.flags

# (definer, flag name, default) — registered in this exact order.
_FLAG_SPECS = [
    # config modules and output location
    (flags.DEFINE_string, 'data_config', 'configs/static_mnist_data.py'),
    (flags.DEFINE_string, 'model_config', 'configs/imp_weighted_nvil.py'),
    (flags.DEFINE_string, 'results_dir', '../checkpoints'),
    (flags.DEFINE_string, 'run_name', 'test_run'),
    # optimisation
    (flags.DEFINE_integer, 'batch_size', 64),
    # bookkeeping intervals (in training iterations)
    (flags.DEFINE_integer, 'summary_every', 1000),
    (flags.DEFINE_integer, 'log_every', 5000),
    (flags.DEFINE_integer, 'save_every', 5000),
    (flags.DEFINE_integer, 'max_train_iter', 300000),
    (flags.DEFINE_boolean, 'restart', False),
    # fraction of each split used for periodic evaluation
    (flags.DEFINE_float, 'eval_size_fraction', 0.01),
]
for _define, _name, _default in _FLAG_SPECS:
    _define(_name, _default, '')
del _FLAG_SPECS, _define, _name, _default

In [4]:
# Experiment-specific overrides of the defaults defined in the previous cell
# (learning_rate / vimco_per_sample_control are presumably defined by the model
# config — TODO confirm).
set_flags(
    log_every=500,
    eval_size_fraction=0.01,
    model_config='configs/vimco.py',
    learning_rate=1e-5,
    vimco_per_sample_control=True,
)

# Parse flags
parse_flags()
# Reference tf.flags directly rather than the `flags` alias so this cell does not
# depend on that name still pointing at the tf.flags module (a later cell may
# rebind it), keeping the cell safe to re-run.
F = tf.flags.FLAGS

In [5]:
# Prepare environment: build the run directory path and hand it to
# init_checkpoint, which returns the (possibly updated) logdir, a flags value and
# the checkpoint to restart from (None when there is nothing to restore).
# init_checkpoint's flags value is bound to `run_flags` — not `flags` — so the
# `flags = tf.flags` module alias stays valid for any cell that re-reads it.
# NOTE(review): the printed "loading flags from ..." lines suggest run_flags holds
# flags read from the data/model configs — confirm.
logdir = osp.join(F.results_dir, F.run_name)
logdir, run_flags, restart_checkpoint = init_checkpoint(logdir, F.data_config, F.model_config, F.restart)
checkpoint_name = osp.join(logdir, 'model.ckpt')

loading flags from configs/vimco.py
loading flags from configs/static_mnist_data.py


In [6]:
# Build the graph from scratch: first the data pipeline, then the model wired to
# the training image/count tensors it exposes.
# NOTE(review): this resets the default graph, so a session created before a
# re-run of this cell would still point at the old graph — recreate it.
tf.reset_default_graph()
data_dict = load(F.data_config, F.batch_size)
air, train_step, global_step = load(
    F.model_config,
    img=data_dict.train_img,
    num=data_dict.train_num,
)

# Show the final flag values in effect for this run.
print_flags()

Loading 'static_mnist_data' from configs/static_mnist_data.pyc
Loading 'vimco' from configs/vimco.pyc
iw_samples 5
Flags:
batch_size: 64
data_config: configs/static_mnist_data.py
eval_size_fraction: 0.01
git_commit: 9ee473d9d43bdeb436ae9738d62a78b922bde5c6
importance_resample: False
learning_rate: 1e-05
log_every: 500
max_train_iter: 300000
model_config: configs/vimco.py
n_iw_samples: 5
n_steps_per_image: 3
restart: False
results_dir: ../checkpoints
run_name: test_run
save_every: 5000
step_bias: 1.0
summary_every: 1000
train_path: mnist_train.pickle
transform_var_bias: -3.0
use_r_imp_weight: True
valid_path: mnist_validation.pickle
vimco_per_sample_control: True


In [7]:
# Session whose GPU memory allocation grows on demand (allow_growth) rather than
# being reserved up front.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)

In [8]:
# Initialise all global variables; a checkpoint restore afterwards may overwrite them.
tf.global_variables_initializer().run(session=sess)

In [9]:
# Saver used for restoring here and for periodic checkpointing during training.
# When init_checkpoint returned a checkpoint path, load its weights over the
# freshly initialised variables.
saver = tf.train.Saver()
if restart_checkpoint is not None:
    # `.format` (not `,format`) so the path is substituted into the message.
    print("Restoring from '{}'".format(restart_checkpoint))
    saver.restore(sess, restart_checkpoint)

In [10]:
# TensorBoard: an event writer for the run directory (with the session's graph
# attached), and one op merging every summary registered in the graph.
summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)
all_summaries = tf.summary.merge_all()

In [11]:
# Logging: periodic evaluation on a fraction (F.eval_size_fraction) of each split.
# NOTE(review): the *_batches values are the split size along the image axis
# scaled by that fraction — confirm whether make_logger treats them as batch or
# sample counts.
ax = data_dict['axes']['imgs']
factor = F.eval_size_fraction
train_batches = int(data_dict['train_data']['imgs'].shape[ax] * factor)
valid_batches = int(data_dict['valid_data']['imgs'].shape[ax] * factor)
log = make_logger(
    air, sess, summary_writer,
    data_dict.train_tensors, train_batches,
    data_dict.valid_tensors, valid_batches,
)

make_logger: unable to log all expressions:
	Skipping baseline_loss
	Skipping nums_xe
	Skipping l2_loss


In [12]:
import numpy as np
# Sanity check on one batch: mean magnitudes of the per-sample NELBO, the
# baseline, and their difference, plus the signed mean difference.
# NOTE: in the recorded run the two arrays had shapes (64, 1) and (64, 5), so
# `a - b` broadcasts across the second axis.
a, b = sess.run([air.nelbo_per_sample, air.baseline])
gap = a - b
print(' '.join(str(v) for v in (np.abs(a).mean(), np.abs(b).mean(), np.abs(gap).mean(), gap.mean())))


388.31 425.072 62.6797 53.9437


In [13]:
# o = sess.run(air.control)
# # o = sess.run(air.biggest)
# # o = sess.run(air.second_biggest)
# # o = sess.run(air.all_but_one_average)
# # o = sess.run(air.summed_exped_per_sample_elbo)
# # o = sess.run(air.summed_exped_per_sample_elbo - air.exped_per_sample_elbo)
# # o = sess.run(air.all_but_one_average - air.control)
# #
# print np.isnan(o).any(), o.min(), o.mean(), o.max()
# print o

In [None]:
# Shapes of the NELBO / baseline arrays fetched in the sanity-check cell.
print('{} {}'.format(a.shape, b.shape))

(64, 1) (64, 5)


In [None]:
# Main training loop. Resumes from the current global_step (non-zero after a
# restore), logs once at step 0 for a fresh run, then steps the optimiser until
# F.max_train_iter. Summaries, evaluation logs, and checkpoints (plus a figure)
# are emitted every F.summary_every / F.log_every / F.save_every iterations.
train_itr = sess.run(global_step)
start_itr = train_itr
print('Starting training at iter = {}'.format(train_itr))

if train_itr == 0:
    log(0)

while train_itr < F.max_train_iter:

    # Run the update, then read global_step in a separate call. TensorFlow does
    # not define the evaluation order of fetches within one sess.run, so fetching
    # global_step alongside train_step (which presumably increments it — the loop
    # relies on that to terminate) may return either the old or the new value.
    sess.run(train_step)
    train_itr = sess.run(global_step)

    if train_itr % F.summary_every == 0:
        summaries = sess.run(all_summaries)
        summary_writer.add_summary(summaries, train_itr)

    if train_itr % F.log_every == 0:
        log(train_itr)

    if train_itr % F.save_every == 0:
        saver.save(sess, checkpoint_name, global_step=train_itr)
        make_fig(air, sess, logdir, train_itr)

# Keep the final weights even when F.max_train_iter is not a multiple of
# F.save_every — but only if this run actually trained.
if train_itr > start_itr and train_itr % F.save_every != 0:
    saver.save(sess, checkpoint_name, global_step=train_itr)

# Make sure buffered summaries reach the event file.
summary_writer.flush()

Starting training at iter = 0
Step 0, Data train kl_what = 1.4929, kl_shift = 8.4830, nelbo = -342.9016, num_step_acc = 0.2083, kl_num_steps = 14.6062, reinforce_loss = -371.0725, rec_loss = 355.9470, num_step = 1.6875, kl_scale = 8.4921, kl_div = 33.0742, kl_where = 16.9751, proxy_loss = -713.9741, eval time = 0.8359s
Step 0, Data test kl_what = 0.6001, kl_shift = 7.1789, nelbo = -392.7419, num_step_acc = 0.1719, kl_num_steps = 14.5929, reinforce_loss = -389.8730, rec_loss = 174.4362, num_step = 1.4219, kl_scale = 7.0753, kl_div = 29.4472, kl_where = 14.2542, proxy_loss = -782.6150, eval time = 0.1787s

Step 500, Data train kl_what = 1.0298, kl_shift = 4.1411, nelbo = -623.8975, num_step_acc = 0.3403, kl_num_steps = 15.5134, reinforce_loss = -77.7015, rec_loss = -144.3114, num_step = 0.7431, kl_scale = 4.2879, kl_div = 24.9721, kl_where = 8.4290, proxy_loss = -701.5990, eval time = 0.9346s
Step 500, Data test kl_what = 1.3271, kl_shift = 5.8700, nelbo = -596.7224, num_step_acc = 0.328

In [None]:
# make_fig(air, sess, n_samples=64) 