In [None]:
from os import path as osp
import numpy as np
import tensorflow as tf
import sonnet as snt

import matplotlib.pyplot as plt
% matplotlib inline

from neurocity.tools.params import num_trainable_params

from tf_tools.eval import make_expr_logger

from data import load_data, tensors_from_data
from model import AIRCell

In [None]:
# --- Hyperparameters ---------------------------------------------------------
learning_rate = 1e-4
batch_size = 64
img_size = (50, 50)    # (rows, cols) of each input image / canvas
crop_size = (20, 20)   # (rows, cols) of each per-step crop
n_latent = 50          # latent size passed to AIRCell
n_hidden = 256         # hidden size of the LSTM transition
n_steps = 3            # number of steps dynamic_rnn unrolls the AIR cell for

# --- Output locations --------------------------------------------------------
results_dir = '../results'
run_name = 'sigmoid_scale_uniform_canvas_bias_lower_pres_weight2'
logdir = osp.join(results_dir, run_name)
checkpoint_name = osp.join(logdir, 'model.ckpt')

# Index of the per-example axis in each data array (imgs are indexed on axis 0,
# nums on axis 1); passed to tensors_from_data.
axes = dict(imgs=0, labels=0, nums=1)

In [None]:
# Load the test and train splits (test first, matching the original order).
# NOTE(review): these are presumably unpickled by load_data — only use trusted files.
test_data, train_data = load_data('mnist_test.pickle'), load_data('mnist_train.pickle')

In [None]:
# Preview a random sample of training images on an 8x8 grid, each titled with
# its per-step presence targets.
imgs = train_data['imgs']
nums = train_data['nums']

fig, fig_axes = plt.subplots(8, 8, figsize=(32, 32))
# Sample without replacement so no image appears twice in the grid, and cap at
# the dataset size so a small dataset does not raise.
n_show = min(fig_axes.size, imgs.shape[0])
idx = np.random.choice(imgs.shape[0], n_show, replace=False)
for i, ax in zip(idx, fig_axes.flat):
    ax.imshow(imgs[i], cmap='gray')
    # nums keeps the example index on axis 1; ravel (unlike squeeze) still
    # yields an iterable when there is only a single step.
    num_str = ' '.join(str(n) for n in nums[:, i].ravel())
    ax.set_title(num_str)

In [None]:
# Build the graph: train/test input tensors, the recurrent AIR cell unrolled
# with dynamic_rnn, and the per-step outputs used by the loss and by make_fig.
tf.reset_default_graph()
train_tensors = tensors_from_data(train_data, batch_size, axes, shuffle=True)
test_tensors = tensors_from_data(test_data, batch_size, axes, shuffle=False)
x, test_x = train_tensors['imgs'], test_tensors['imgs']
y, test_y = train_tensors['nums'], test_tensors['nums']

transition = snt.LSTM(n_hidden)
air = AIRCell(img_size, crop_size, n_latent, transition, max_crop_size=1.0)
initial_state = air.initial_state(x)

# The per-step input is all zeros; presumably it only fixes how many steps
# (n_steps) dynamic_rnn unrolls — confirm AIRCell ignores its input.
dummy_sequence = tf.zeros((n_steps, batch_size, 1), name='dummy_sequence')
# time_major=True: every output has n_steps as its leading axis.
outputs, state = tf.nn.dynamic_rnn(air, dummy_sequence, initial_state=initial_state, time_major=True)
canvas, cropped, what, where, presence_logit = outputs
presence = tf.nn.sigmoid(presence_logit)

# Crops are squashed with a sigmoid, scaled by the presence probability, and
# reshaped to (n_steps, batch_size) + crop_size for plotting.
cropped = tf.reshape(presence * tf.nn.sigmoid(cropped), (n_steps, batch_size,) + tuple(crop_size))
# canvas stays in logit space; prob_canvas is its sigmoid for display.
canvas = tf.reshape(canvas, (n_steps, batch_size,) + tuple(img_size))
prob_canvas = tf.nn.sigmoid(canvas)
# Canvas logits after the last step: the reconstruction fed to the loss.
final_canvas = canvas[-1]

In [None]:
# Report the (presumably total) number of trainable parameters. The function-
# call form of print is valid on both Python 2 and Python 3.
print(num_trainable_params())

In [None]:
# Reconstruction term: per-pixel sigmoid cross-entropy between the input images
# and the final canvas logits, summed over the two pixel axes and averaged over
# the batch. tf.cast is a no-op when the dtypes already match, and it avoids a
# dtype error if the pipeline yields non-float32 images.
pixel_xent = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.cast(x, final_canvas.dtype), logits=final_canvas)
rec_loss = tf.reduce_mean(tf.reduce_sum(pixel_xent, axis=(1, 2)))
tf.summary.scalar('rec_loss', rec_loss)

# Presence term: cross-entropy between the per-step presence targets and the
# predicted presence logits, averaged over steps and batch.
alpha = 1.  # weight of the presence term relative to the reconstruction term
num_steps_penalty = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.cast(y, presence_logit.dtype), logits=presence_logit))
tf.summary.scalar('steps_loss', num_steps_penalty)

# Total objective.
loss = rec_loss + alpha * num_steps_penalty
tf.summary.scalar('loss', loss)

opt = tf.train.RMSPropOptimizer(learning_rate, momentum=.9, centered=True)
train_step = opt.minimize(loss)

In [None]:
# Open a session, initialise every global variable, then gather all summaries
# registered so far into a single merged op.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)
all_summaries = tf.summary.merge_all()

In [None]:
# Event writer for TensorBoard summaries under logdir, and a Saver over all
# saveable variables in the graph.
summary_writer, saver = tf.summary.FileWriter(logdir), tf.train.Saver()

In [None]:
# Handles on the raw training arrays, and the iteration counter that the
# training loop starts (and resumes) from.
imgs, presence_gt = train_data['imgs'], train_data['nums']
train_itr = 0

In [None]:
def make_fig(checkpoint_dir, global_step, max_imgs=10):
    """Render a progress figure for one training batch and save it as a PNG.

    Layout (one column per batch element):
      * row 0: input images;
      * next ``n_steps`` rows: canvas probabilities after each AIR step;
      * last ``n_steps`` rows: presence-scaled crops, each titled with its
        presence probability.

    Args:
        checkpoint_dir: directory the PNG is written to.
        global_step: training iteration, used in the file name.
        max_imgs: maximum number of batch elements (columns) to plot.
    """
    # A single run so that inputs and predictions come from the same batch.
    xx, pred_canvas, pred_crop, pres = sess.run([x, prob_canvas, cropped, presence])

    n_cols = min(max_imgs, batch_size)
    n_rows = 2 * n_steps + 1
    scale = 1.
    figsize = scale * np.asarray((n_cols, n_rows))
    # squeeze=False always returns a 2-D grid, even when n_cols == 1.
    fig, grid = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    for j, ax in enumerate(grid[0]):
        ax.imshow(xx[j], cmap='gray', vmin=0, vmax=1)

    for i, ax_row in enumerate(grid[1:1 + n_steps]):
        for j, ax in enumerate(ax_row):
            ax.imshow(pred_canvas[i, j], cmap='gray', vmin=0, vmax=1)

    for i, ax_row in enumerate(grid[1 + n_steps:]):
        for j, ax in enumerate(ax_row):
            ax.imshow(pred_crop[i, j], cmap='gray', vmin=0, vmax=1)
            ax.set_title('{:.02f}'.format(pres[i, j, 0]), fontsize=4 * scale)

    for ax in grid.flat:
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    fig_name = osp.join(checkpoint_dir, 'progress_fig_{}.png'.format(global_step))
    fig.savefig(fig_name, dpi=300)
    # Close only this figure, so other open figures are left alone.
    plt.close(fig)
    
exprs = dict(loss=loss, rec_loss=rec_loss, steps_loss=num_steps_penalty)
# Batch counts (presumably how many batches each logger evaluates). `//` keeps
# these integers on Python 3, where `/` would produce a float.
n_train_batches = train_data['imgs'].shape[0] // batch_size
n_test_batches = test_data['imgs'].shape[0] // batch_size
train_log = make_expr_logger(sess, summary_writer, n_train_batches, exprs, name='train')
# data_dict maps the training input tensors to the test-set tensors, presumably
# so the same expressions are evaluated on test data.
test_log = make_expr_logger(sess, summary_writer, n_test_batches, exprs, name='test',
                            data_dict={x: test_x, y: test_y})


def log(train_itr):
    """Evaluate and log the tracked expressions on train and test data."""
    train_log(train_itr)
    test_log(train_itr)

In [None]:
# Resume from train_itr: log once, then train. Every `log_every` iterations,
# write summaries, evaluate the train/test loggers, optionally save a
# checkpoint, and save a progress figure.
log_every = 1000
save_checkpoints = False  # set True to also write a checkpoint at each log step

log(train_itr)
for train_itr in xrange(train_itr + 1, int(1e7)):
    sess.run(train_step)
    if train_itr % log_every == 0:
        summaries = sess.run(all_summaries)
        summary_writer.add_summary(summaries, train_itr)
        log(train_itr)
        if save_checkpoints:
            saver.save(sess, checkpoint_name, global_step=train_itr)
        make_fig(logdir, train_itr)