In [1]:
from os import path as osp
import numpy as np
import tensorflow as tf
import sonnet as snt

from tensorflow.contrib.distributions import Bernoulli

import matplotlib.pyplot as plt
% matplotlib inline

from neurocity import minimize_clipped
from neurocity.tools.params import num_trainable_params

from tf_tools.eval import make_expr_logger

from data import load_data, tensors_from_data
from model import AIRCell
from ops import Loss

In [2]:
learning_rate = 1e-4
batch_size = 64
img_size = 50, 50
crop_size = 20, 20
n_latent = 50
n_hidden = 256
n_steps = 3

results_dir = '../results'
run_name = 'discrete'

logdir = osp.join(results_dir, run_name)
checkpoint_name = osp.join(logdir, 'model.ckpt')
axes = {'imgs': 0, 'labels': 0, 'nums': 1}

In [3]:
prior_weight = 0

num_steps_prior = 0.
latent_code_prior = None

use_reinforce = True

In [4]:
test_data = load_data('mnist_test.pickle')
train_data = load_data('mnist_train.pickle')

In [5]:
# imgs = train_data['imgs']
# nums = train_data['nums']

# fig, fig_axes = plt.subplots(8, 8, figsize=(32, 32))
# idx = np.random.choice(imgs.shape[0], 64)
# for i, ax in zip(idx, fig_axes.flatten()):
#     ax.imshow(imgs[i], cmap='gray')
#     num_str = ' '.join([str(n) for n in nums[:, i].squeeze()])
#     ax.set_title(num_str)

In [6]:
tf.reset_default_graph()
train_tensors = tensors_from_data(train_data, batch_size, axes, shuffle=True)
test_tensors = tensors_from_data(test_data, batch_size, axes, shuffle=False)
x, test_x = train_tensors['imgs'], test_tensors['imgs']
y, test_y = train_tensors['nums'], test_tensors['nums']

transition = snt.LSTM(n_hidden)
air = AIRCell(img_size, crop_size, n_latent, transition, max_crop_size=1.0,
              presence_bias=1.,
              explore_eps=.0,
              debug=True)

initial_state = air.initial_state(x)

dummy_sequence = tf.zeros((n_steps, batch_size, 1), name='dummy_sequence')
outputs, state = tf.nn.dynamic_rnn(air, dummy_sequence, initial_state=initial_state, time_major=True)
canvas, cropped, what, where, presence_logit, presence = outputs
presence_prob = tf.nn.sigmoid(presence_logit)

with tf.variable_scope('notebook'):
    cropped = tf.reshape(presence * tf.nn.sigmoid(cropped), (n_steps, batch_size,) + tuple(crop_size))
    canvas = tf.reshape(canvas, (n_steps, batch_size,) + tuple(img_size))
    prob_canvas = tf.nn.sigmoid(canvas)
    final_canvas = canvas[-1]
    
    
with tf.variable_scope('baseline'):
    baseline = snt.TrainableVariable([], initializers={'w': tf.zeros_initializer()}, name='constant_baseline')
    
baseline_tensor = baseline()

In [7]:
print num_trainable_params()

990483.0


In [8]:
###    Loss #################################################################################
loss = Loss()
prior_loss = Loss()

###    Reconstruction Loss ##################################################################
rec_loss_per_sample = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=final_canvas)
rec_loss_per_sample = tf.reduce_sum(rec_loss_per_sample, axis=(1, 2))
rec_loss = tf.reduce_mean(rec_loss_per_sample)
tf.summary.scalar('rec_loss', rec_loss)

loss.add(rec_loss, rec_loss_per_sample)
# # ###    Prior Loss ###########################################################################

if prior_weight > 0.:
    
    if num_steps_prior is not None:    
        num_steps_prior_loss_per_sample = tf.squeeze((tf.reduce_sum(presence, 0) - num_steps_prior) ** 2)
        num_steps_prior_loss = tf.reduce_mean(num_steps_prior_loss_per_sample)
        tf.summary.scalar('num_steps_prior_loss', num_steps_prior_loss)

        prior_loss.add(num_steps_prior_loss, num_steps_prior_loss_per_sample)


    if latent_code_prior is not None:
        latent_code_prior_loss_per_sample = tf.reduce_mean(tf.reduce_sum((what - latent_code_prior) ** 2, -1), 0)
        latent_code_prior_loss = tf.reduce_mean(latent_code_prior_loss_per_sample)
        tf.summary.scalar('latent_code_prior_loss', latent_code_prior_loss)

        prior_loss.add(latent_code_prior_loss, latent_code_prior_loss_per_sample)

    tf.summary.scalar('prior_loss', prior_loss._value)
    loss.add(prior_loss, weight=prior_weight)

# ###   REINFORCE ############################################################################

opt_loss = loss.value
if use_reinforce:
    # clipped_presence_prob = tf.clip_by_value(presence_prob, 1e-7, 1. - 1e-7)
    # log_prob = Bernoulli(probs=clipped_presence_prob).log_prob(presence)
    # log_prob = tf.squeeze(tf.reduce_mean(log_prob, 0))

    # instead of maximising probability we'll minimise cross-entropy, where labels are the taken actions
    log_prob = tf.nn.sigmoid_cross_entropy_with_logits(labels=presence, logits=presence_logit)

    importance_weight = loss._per_sample
    importance_weight -= baseline()

    reinforce_loss_per_sample = tf.stop_gradient(importance_weight) * log_prob
    reinforce_loss = tf.reduce_mean(reinforce_loss_per_sample)
    tf.summary.scalar('reinforce_loss', reinforce_loss)

    opt_loss -= reinforce_loss

    
    
### Optimizer #################################################################################
lr_tensor = tf.Variable(learning_rate, name='learning_rate', trainable=False)
opt = tf.train.RMSPropOptimizer(lr_tensor, momentum=.9, centered=True)
# true_train_step = opt.minimize(opt_loss)
true_train_step = minimize_clipped(opt, opt_loss, clip_value=.3, normalize_by_num_params=True)

# # ### Baseline Optimisation ##################################################################################
baseline_target = loss.value
baseline_loss = (baseline_target - baseline()) ** 2
tf.summary.scalar('baseline_loss', baseline_loss)

baseline_opt = tf.train.RMSPropOptimizer(10 * lr_tensor, momentum=.9, centered=True)
baseline_train_step = baseline_opt.minimize(baseline_loss)

###    Train Step ##############################################################################
train_step = [true_train_step, baseline_train_step]

###    Metrics #################################################################################
gt_num = tf.reduce_sum(y, 0)
pred_num = tf.reduce_sum(presence, 0)
num_step_accuracy = tf.reduce_mean(tf.to_float(tf.equal(gt_num, pred_num)))
num_step = tf.reduce_mean(tf.to_float(pred_num))

In [9]:
vs = tf.trainable_variables()
if 'baseline' in locals():
    vs = list(set(vs) - set([baseline]))
gs = tf.gradients(opt_loss, vs)

for v, g in zip(vs, gs):
    if g is None:
        print 'Skipping', v.name
    else:
        assert v.get_shape() == g.get_shape(), v.name

named_grads = {v.name: g for v, g in zip(vs, gs) if g is not None}

Skipping baseline/constant_baseline/w:0


In [10]:
def grad_variance(n=10, sort_by_var=True):
    gs = {k: [] for k in named_grads}
    for i in xrange(n):
        values = sess.run(named_grads)
        for k, v in values.iteritems():
            gs[k].append(v)

    for k, v in gs.iteritems():
        v = np.stack(v, 0).reshape((n, -1))
        gs[k] = np.var(v, 0).mean()
        
    sort_idx = 1 if sort_by_var else 0
    gs = sorted(gs.items(), key=lambda x: x[sort_idx], reverse=True)
    return gs

def print_grad_variance():
    grad_vars = grad_variance(10)
    print
    for g in grad_vars:
        if g[1] > 1e-2:
            print g
    print

In [11]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
all_summaries = tf.summary.merge_all()

In [12]:
summary_writer = tf.summary.FileWriter(logdir)
saver = tf.train.Saver()

In [13]:
imgs = train_data['imgs']
presence_gt = train_data['nums']
train_itr = -1

In [None]:
def make_fig(checkpoint_dir, global_step):
    xx, pred_canvas, pred_crop, pres = sess.run([x, prob_canvas, cropped, presence])

    max_imgs = 10
    bs = min(max_imgs, batch_size)
    scale = 1.
    figsize = scale * np.asarray((bs, 2 * n_steps + 1))
    fig, axes = plt.subplots(2 * n_steps + 1, bs, figsize=figsize)

    for i, ax in enumerate(axes[0]):
        ax.imshow(xx[i], cmap='gray', vmin=0, vmax=1)

    for i, ax_row in enumerate(axes[1:1+n_steps]):
        for j, ax in enumerate(ax_row):
            ax.imshow(pred_canvas[i, j], cmap='gray', vmin=0, vmax=1)

    for i, ax_row in enumerate(axes[1+n_steps:]):
        for j, ax in enumerate(ax_row):
            ax.imshow(pred_crop[i, j], cmap='gray', vmin=0, vmax=1)
            ax.set_title('{:.02f}'.format(pres[i, j, 0]), fontsize=4*scale)

    for ax in axes.flatten():
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    
    fig_name = osp.join(checkpoint_dir, 'progress_fig_{}.png'.format(global_step))
    fig.savefig(fig_name, dpi=300)
    plt.close('all')
    
exprs = {
#     'loss': loss,
    'rec_loss': rec_loss,
    'prior_loss': prior_loss.value,
    'reinforce_loss': reinforce_loss,
    'baseline_loss': baseline_loss,
    'num_step_acc': num_step_accuracy,
    'num_step': num_step
}
train_log = make_expr_logger(sess, summary_writer, train_data['imgs'].shape[0] / batch_size, exprs, name='train')
test_log = make_expr_logger(sess, summary_writer, test_data['imgs'].shape[0] / batch_size, exprs, name='test', data_dict={x: test_x, y: test_y})

def log(train_itr):
    train_log(train_itr)
    test_log(train_itr)

In [None]:
for train_itr in xrange(train_itr+1, int(1e7)):
        
    sess.run(train_step)
    if train_itr % 1000 == 0:
        summaries = sess.run(all_summaries)
        summary_writer.add_summary(summaries, train_itr)
        
    if train_itr % 1000 == 0:
        log(train_itr)
        
    if train_itr % 1000 == 0:
#         saver.save(sess, checkpoint_name, global_step=train_itr)
        make_fig(logdir, train_itr)    
    
    if train_itr % 1000 == 0:
        print 'baseline value = {}'.format(sess.run(baseline_tensor))
        print_grad_variance()

Step 0, Data train prior_loss = 0.0000, baseline_loss = 265755.3693, reinforce_loss = 296.3327, rec_loss = 512.4475, num_step = 2.2046, num_step_acc = 0.1993, eval time = 17.28s
Step 0, Data test prior_loss = 0.0000, baseline_loss = 235370.5521, reinforce_loss = 283.1625, rec_loss = 485.1533, num_step = 2.1885, num_step_acc = 0.1917, eval time = 0.3637s
baseline value = 0.00333332177252

(u'AIRCell/presence/linear_1/b:0', 476.58575)
(u'AIRCell/canvas_value:0', 13.716306)
(u'AIRCell/presence/linear/b:0', 8.7405243)
(u'AIRCell/presence/linear_1/w:0', 0.41233018)
(u'AIRCell/Encoder/linear/b:0', 0.33140901)
(u'lstm_initial_state_1/w:0', 0.13984928)
(u'AIRCell/rnn_inpt/linear/b:0', 0.11753002)
(u'AIRCell/Encoder/linear_1/b:0', 0.10108509)
(u'lstm/b_gates:0', 0.10018014)
(u'AIRCell/Encoder_1/linear_1/b:0', 0.085266963)
(u'where_init:0', 0.049342848)
(u'AIRCell/Encoder_1/linear/b:0', 0.043961104)
(u'lstm_initial_state_0/w:0', 0.022068784)
(u'AIRCell/Decoder/linear/b:0', 0.020463252)
(u'what_i

In [None]:
# named_outputs = {
#     'canvas': canvas,
#     'cropped': cropped,
#     'what': what,
#     'where': where, 
#     'presence_logit': presence_logit,
#     'presence': presence,
#     'rec_loss': rec_loss,
#     'imp_weight': importance_weight,
#     'log_prob': log_prob,
#     'baseline_loss': baseline_loss,
#     'clipped_pres': clipped_presence_prob,
#     'pres_prob': presence_prob
# }

In [None]:
from tensorflow.contrib.distributions.python.ops import kullback_leibler

In [None]:
for d in kullback_leibler._DIVERGENCES:
    print d[0].__name__, d[1].__name__

In [None]:
for e in exprs:
    print e, exprs[e]

In [None]:
ee = sess.run(exprs)

In [None]:
baseline_tensor = baseline()