In [1]:
from os import path as osp
import tensorflow as tf

from evaluation import make_fig, make_logger
from experiment_tools import load, init_checkpoint, parse_flags, print_flags, set_flags

In [2]:
import sys
sys.path.append('../')

In [3]:
# Define flags
flags = tf.flags

flags.DEFINE_string('data_config', 'configs/static_mnist_data.py', '')
flags.DEFINE_string('model_config', 'configs/imp_weighted_nvil.py', '')
flags.DEFINE_string('results_dir', '../checkpoints', '')
flags.DEFINE_string('run_name', 'test_run', '')

flags.DEFINE_integer('batch_size', 64, '')

flags.DEFINE_integer('summary_every', 1000, '')
flags.DEFINE_integer('log_every', 5000, '')
flags.DEFINE_integer('save_every', 5000, '')
flags.DEFINE_integer('max_train_iter', int(3 * 1e5), '')
flags.DEFINE_boolean('restart', False, '')

flags.DEFINE_float('eval_size_fraction', .01, '')

# Parse flags
parse_flags()
F = flags.FLAGS

In [4]:
set_flags(
    log_every=500,
    eval_size_fraction=0.01
)

In [5]:
# Prepare enviornment
logdir = osp.join(F.results_dir, F.run_name)
logdir, flags, restart_checkpoint = init_checkpoint(logdir, F.data_config, F.model_config, F.restart)
checkpoint_name = osp.join(logdir, 'model.ckpt')

In [6]:
# Build the graph
tf.reset_default_graph()
data_dict = load(F.data_config, F.batch_size)
air, train_step, global_step = load(F.model_config, img=data_dict.train_img, num=data_dict.train_num)

print_flags()

Loading 'static_mnist_data' from configs/static_mnist_data.pyc
Loading 'imp_weighted_nvil' from configs/imp_weighted_nvil.pyc
iw_samples 5
KL by sampling!
KL by sampling!
KL by sampling!
KL by sampling!
constructed baseline
iw resampling!
resampled [<tf.Tensor 'Gather:0' shape=(64,) dtype=float32>]
iw resampling!
resampled [<tf.Tensor 'loss_1/Gather:0' shape=(64,) dtype=float32>]
iw resampling!
resampled [<tf.Tensor 'loss_1/Gather_1:0' shape=(64,) dtype=float32>]
iw resampling!
resampled [<tf.Tensor 'loss_1/Gather_2:0' shape=(64,) dtype=float32>]
iw resampling!
resampled [<tf.Tensor 'loss_1/Gather_3:0' shape=(64,) dtype=float32>]
iw resampling!
resampled [<tf.Tensor 'loss_1/Gather_4:0' shape=(64,) dtype=float32>]
iw resampling!
resampled [<tf.Tensor 'loss_1/Gather_5:0' shape=(64,) dtype=float32>]
AIR tensors:
baseline (64, 1)
baseline_loss ()
canvas (320, 50, 50)
elbo_importance_weights (64, 5)
glimpse (320, 3, 20, 20)
gt_num_steps (64,)
kl_div ()
kl_div_per_sample (320,)
kl_num_steps 

In [7]:
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

In [8]:
sess.run(tf.global_variables_initializer())

In [9]:
saver = tf.train.Saver()
if restart_checkpoint is not None:
    print "Restoring from '{}'",format(restart_checkpoint)
    saver.restore(sess, restart_checkpoint)

In [10]:
summary_writer = tf.summary.FileWriter(logdir, sess.graph)
all_summaries = tf.summary.merge_all()

In [11]:
# Logging
ax = data_dict['axes']['imgs']
factor = F.eval_size_fraction
train_batches, valid_batches = [int(data_dict[k]['imgs'].shape[ax] * factor) for k in ('train_data', 'valid_data')]
log = make_logger(air, sess, summary_writer, data_dict.train_tensors,
                  train_batches, data_dict.valid_tensors, valid_batches)

make_logger: unable to log all expressions:
	Skipping nums_xe
	Skipping l2_loss


In [12]:
train_itr = sess.run(global_step)
print 'Starting training at iter = {}'.format(train_itr)

if train_itr == 0:
    log(0)

while train_itr < F.max_train_iter:

    train_itr, _ = sess.run([global_step, train_step])

    if train_itr % F.summary_every == 0:
        summaries = sess.run(all_summaries)
        summary_writer.add_summary(summaries, train_itr)

    if train_itr % F.log_every == 0:
        log(train_itr)

    if train_itr % F.save_every == 0:
        saver.save(sess, checkpoint_name, global_step=train_itr)
        make_fig(air, sess, logdir, train_itr)

Starting training at iter = 0
Step 0, Data train baseline_loss = 119003.3220, kl_what = 0.1543, nelbo = -344.0491, num_step_acc = 0.3264, kl_num_steps = 14.5199, reinforce_loss = 2313.2861, rec_loss = -366.3901, num_step = 0.6111, kl_div = 20.7757, kl_where = 6.1384, proxy_loss = 1969.2370, eval time = 0.9988s
Step 0, Data test baseline_loss = 109291.7031, kl_what = 0.2041, nelbo = -319.3570, num_step_acc = 0.4062, kl_num_steps = 14.4977, reinforce_loss = 2098.5679, rec_loss = -342.6302, num_step = 0.6875, kl_div = 21.7059, kl_where = 7.0041, proxy_loss = 1779.2109, eval time = 0.2062s

Step 500, Data train baseline_loss = 12288.2697, kl_what = 0.1719, nelbo = -631.5133, num_step_acc = 0.3160, kl_num_steps = 15.0484, reinforce_loss = -152.8551, rec_loss = -651.6534, num_step = 0.4253, kl_div = 18.7251, kl_where = 3.5292, proxy_loss = -784.3684, eval time = 0.9819s
Step 500, Data test baseline_loss = 16116.3252, kl_what = 1.0604, nelbo = -595.7603, num_step_acc = 0.2344, kl_num_steps = 

KeyboardInterrupt: 

In [13]:
make_fig(air, sess, n_samples=64) 

In [14]:
s = air.iw_distrib.sample()

In [15]:
sess.run(s)

array([1, 3, 2, 4, 3, 4, 4, 1, 2, 1, 2, 3, 2, 0, 2, 1, 3, 0, 3, 3, 0, 1, 0,
       1, 2, 3, 0, 1, 0, 2, 1, 0, 4, 4, 3, 4, 0, 1, 0, 1, 4, 0, 1, 0, 1, 1,
       4, 2, 4, 1, 4, 1, 0, 3, 1, 2, 0, 0, 0, 3, 1, 2, 0, 3], dtype=int32)

In [16]:
d = air.iw_distrib
d.logits
tf.distributions

AttributeError: 'module' object has no attribute 'distributions'

In [None]:
dir(air.iw_distrib.__class__)