In [None]:
from os import path as osp
import numpy as np
import tensorflow as tf
import sonnet as snt
from attrdict import AttrDict

from evaluation import make_fig, make_logger

from data import load_data, tensors_from_data
from mnist_model import (AIRonMNIST, NVILEstimatorWithBaseline, ImportanceWeightedNVILEstimatorWithBaseline,
                         KLMixin, KLNoStepsGradMixin, KLBySamplingMixin)

In [None]:
# --- Run configuration -------------------------------------------------------
# Optimiser step size and the maximum number of AIR steps (passed to the model
# as `max_steps`).
learning_rate = 1e-5
n_steps = 3

# Output locations: summaries go to `logdir`, checkpoints use the
# `checkpoint_name` prefix inside it.
results_dir = '../results'
run_name = 'sampled_kl'
logdir = osp.join(results_dir, run_name)
checkpoint_name = osp.join(logdir, 'model.ckpt')

# Per-array axis used when slicing the loaded data into batches
# (axes['imgs'] is also used later to count images) — presumably the example
# axis of each array; confirm against tensors_from_data.
axes = dict(imgs=0, labels=0, nums=1)

In [None]:
# --- Model / training hyper-parameters ---------------------------------------
batch_size = 64
iw_samples = 5            # passed to the model as `iw_samples`

# Biases handed to the model constructor as `step_bias` / `transform_var_bias`.
step_bias = 1.0
transform_var_bias = -3.0

# Boolean switches.
discrete_steps = True     # passed to the model as `discrete_steps`
# NOTE(review): `use_reinforce` is not referenced by any visible cell and is
# not passed to the model — confirm whether it is meant to be wired in.
use_reinforce = True

In [None]:
# Load the validation split first, then the training split.
# NOTE(review): the files are `.pickle`; if load_data unpickles them, only use
# trusted files (unpickling can execute arbitrary code).
valid_data, train_data = map(load_data, ['mnist_validation.pickle', 'mnist_train.pickle'])

In [None]:
# Build the input pipeline and the AIR model in a fresh default graph, so this
# cell can be re-run. Keep the op-creation order below unchanged: TensorFlow
# derives op/variable names from it, and checkpoints are matched by those names.
tf.reset_default_graph()
train_tensors = tensors_from_data(train_data, batch_size, axes, shuffle=True)
valid_tensors = tensors_from_data(valid_data, batch_size, axes, shuffle=False)
x, valid_x = train_tensors['imgs'], valid_tensors['imgs']
# NOTE(review): the validation counterpart of `y` is named `test_y` (not
# `valid_y`, unlike `valid_x`); it is not used in any visible cell.
y, test_y = train_tensors['nums'], valid_tensors['nums']
    
# Hidden-layer spec shared by the input encoder, glimpse encoder/decoder and
# transform estimator below: n_layers layers of n_hidden units each.
n_hidden = 32 * 8
n_layers = 2
n_hiddens = [n_hidden] * n_layers

# Concrete model assembled from AIRonMNIST, the importance-weighted NVIL
# estimator with baseline, and the sampling-based KL mixin (all defined in
# mnist_model, not visible here). The base-class order determines the MRO.
class ConcreteAIR(AIRonMNIST, ImportanceWeightedNVILEstimatorWithBaseline, KLBySamplingMixin):
    # Class-level switches; presumably read by the estimator base class —
    # confirm their semantics in mnist_model.
    importance_resample = False
    use_r_imp_weight = True
    
# Instantiate the model on the training images `x`.
# NOTE(review): `use_reinforce` from the hyper-parameter cell is not passed here.
air = ConcreteAIR(x,
                max_steps=n_steps, 
                inpt_encoder_hidden=n_hiddens,
                glimpse_encoder_hidden=n_hiddens,
                glimpse_decoder_hidden=n_hiddens,
                transform_estimator_hidden=n_hiddens,
                steps_pred_hidden=[128, 64],
                baseline_hidden=[256, 128],
                transform_var_bias=transform_var_bias,
                step_bias=step_bias,
                discrete_steps=discrete_steps,
                iw_samples=iw_samples)

In [None]:
# Build the training op from the model, passing the training counts `y`
# (train_tensors['nums']) as `nums`. The two results are presumably the update
# op and the global-step counter; both are run/read with sess.run in the
# training loop.
train_step, global_step = air.train_step(learning_rate, nums=y)

In [None]:
# List every tf.Tensor attribute of the model with its static shape, to check
# the graph was wired as expected.
# print() is called with a single pre-formatted string, so the output is
# identical under Python 2 and Python 3 (the old `print x, y` statement is a
# SyntaxError on Python 3).
print('AIR tensors:')
for attr_name in dir(air):
    attr = getattr(air, attr_name)
    if isinstance(attr, tf.Tensor):
        print('{} {}'.format(attr_name, attr.shape))

In [None]:
# Session with allow_growth enabled: the GPU allocator starts small and grows
# as needed instead of pre-allocating the whole GPU memory region.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)

# One op that evaluates all summaries registered in the default graph
# (tf.summary.merge_all returns None if there are none).
all_summaries = tf.summary.merge_all()

In [None]:
# Initialise all global variables. Re-running this cell re-initialises them,
# discarding any training done so far in this session.
# NOTE(review): no cell restores from `checkpoint_name`, so training always
# starts from freshly initialised weights in a new kernel.
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [None]:
# Event-file writer in `logdir`; handing it the session graph writes the graph
# so TensorBoard can display it. It is created before the Saver, so the
# written graph does not contain the Saver's ops.
summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

# Checkpoint saver with default settings, used by the training loop.
saver = tf.train.Saver()

In [None]:
# Evaluation logger. Each split's batch count is the number of images along
# the `imgs` axis integer-divided by `factor`.
# NOTE(review): this divides by `factor`, not `batch_size`. If make_logger
# runs that many batches of `batch_size` images, it covers roughly
# batch_size/factor of each split — confirm this subsampling is intended.
ax = axes['imgs']
factor = 100
train_batches = train_data['imgs'].shape[ax] // factor
valid_batches = valid_data['imgs'].shape[ax] // factor

log = make_logger(air, sess, summary_writer,
                  train_tensors, train_batches,
                  valid_tensors, valid_batches)

In [None]:
# Main training loop. It starts from the current value of `global_step`, so
# re-running this cell in the same session continues training.
max_train_itr = 10 ** 6     # stop once global_step reaches this value
summary_every = 1000        # write TensorBoard summaries every N steps
log_every = 500             # run the evaluation logger every N steps
checkpoint_every = 5000     # save a checkpoint and a figure every N steps

train_itr = sess.run(global_step)
print('Starting training at iter = {}'.format(train_itr))

if train_itr == 0:
    log(0)

while train_itr < max_train_itr:
    # Run the update, THEN read the counter. Fetching `global_step` in the
    # same sess.run as `train_step` leaves the read unordered relative to the
    # increment, which can make the iteration number off by one and cause
    # periodic actions below to be skipped or repeated.
    sess.run(train_step)
    train_itr = sess.run(global_step)

    # merge_all() returns None when no summaries exist; sess.run(None) raises.
    if all_summaries is not None and train_itr % summary_every == 0:
        summaries = sess.run(all_summaries)
        summary_writer.add_summary(summaries, train_itr)

    if train_itr % log_every == 0:
        log(train_itr)

    if train_itr % checkpoint_every == 0:
        saver.save(sess, checkpoint_name, global_step=train_itr)
        make_fig(air, sess, logdir, train_itr)

In [None]:
# Produce a final figure from the current model with n_samples=64.
# NOTE(review): unlike the call in the training loop, this passes neither
# `logdir` nor `train_itr` — confirm make_fig's defaults for these (e.g.
# whether it then displays inline instead of saving to disk).
make_fig(air, sess, n_samples=64)