In [None]:
from os import path as osp
import numpy as np
import tensorflow as tf
import sonnet as snt

import matplotlib.pyplot as plt
% matplotlib inline

from neurocity.tools.params import num_trainable_params

from data import create_mnist
from model import AIRCell

In [None]:
# --- Optimisation -----------------------------------------------------------
learning_rate = 1e-4
batch_size = 64

# --- Model geometry ---------------------------------------------------------
# Both sizes are 2-tuples; `img_size` fixes the spatial dims of the input
# placeholder and `crop_size` the spatial dims of the per-step crops.
img_size = (50, 50)
crop_size = (20, 20)
n_latent = 50     # passed to AIRCell
n_hidden = 256    # passed to the snt.LSTM transition core
n_steps = 3       # length of the sequence unrolled by dynamic_rnn

# --- Output locations -------------------------------------------------------
# Summaries, checkpoints and progress figures all go under `logdir`.
logdir = 'supervised'
checkpoint_name = osp.join(logdir, 'model.ckpt')

In [None]:
# Build the dataset via the project-local `create_mnist` helper. The result is
# indexed by string keys further down ('imgs' and 'nums').
# NOTE(review): presumably n_samples is the number of generated examples -- confirm in data.py.
data = create_mnist(n_samples=60000)

In [None]:
# Visual sanity check: display one example image (index 2) in grayscale.
imgs = data['imgs']
plt.imshow(imgs[2], cmap='gray')

In [None]:
# Start from an empty graph so re-running this cell does not accumulate ops.
tf.reset_default_graph()
# Inputs: a fixed-size batch of images and, for every step, a per-example
# ground-truth presence target.
x = tf.placeholder(tf.float32, (batch_size,) + img_size, name='inpt')
y = tf.placeholder(tf.float32, (n_steps, batch_size, 1), name='gt_presence')

# Recurrent core: an LSTM wrapped by the project-local AIRCell, whose initial
# state is derived from the input image.
transition = snt.LSTM(n_hidden)
air = AIRCell(img_size, crop_size, n_latent, transition)
initial_state = air.initial_state(x)

# The cell is driven by an all-zero, time-major sequence whose only purpose is
# to make dynamic_rnn unroll the cell for `n_steps` steps.
dummy_sequence = tf.zeros((n_steps, batch_size, 1), name='dummy_sequence')
outputs, state = tf.nn.dynamic_rnn(air, dummy_sequence, initial_state=initial_state, time_major=True)
# Per-step outputs of the cell, stacked along the leading (time) axis.
canvas, cropped, what, where, presence_logit = outputs
presence = tf.nn.sigmoid(presence_logit)

# Squash crops to (0, 1), gate them by the presence probability and restore
# their 2-D spatial layout.
cropped = tf.reshape(presence * tf.nn.sigmoid(cropped), (n_steps, batch_size,) + tuple(crop_size))
# `canvas` is kept as logits; `prob_canvas` is its sigmoid, used for display.
canvas = tf.reshape(canvas, (n_steps, batch_size,) + tuple(img_size))
prob_canvas = tf.nn.sigmoid(canvas)
# Canvas logits after the last step; used as the reconstruction.
final_canvas = canvas[-1]

In [None]:
# Report the model size using the neurocity helper.
# NOTE(review): presumably counts the trainable variables of the default graph -- confirm.
print num_trainable_params()

In [None]:
# Reconstruction term: per-pixel Bernoulli cross-entropy between the input
# image and the final canvas logits, summed over the two spatial axes and
# averaged over the batch. (A squared-error variant was tried earlier.)
pixel_xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=final_canvas, labels=x)
per_image_xent = tf.reduce_sum(pixel_xent, axis=(1, 2))
rec_loss = tf.reduce_mean(per_image_xent)
tf.summary.scalar('rec_loss', rec_loss)

# Supervised presence term: cross-entropy between the ground-truth presence
# targets and the predicted presence logits, averaged over all elements.
# `alpha` belonged to an earlier penalty (alpha * mean number of steps taken)
# and is not used by the current loss.
alpha = 1.
presence_xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=presence_logit, labels=y)
num_steps_penalty = tf.reduce_mean(presence_xent)
tf.summary.scalar('steps_loss', num_steps_penalty)

# Total objective.
loss = rec_loss + num_steps_penalty
tf.summary.scalar('loss', loss)

# Centered RMSProp with momentum (Adam was the previous choice).
opt = tf.train.RMSPropOptimizer(learning_rate, momentum=.9, centered=True)
train_step = opt.minimize(loss)

In [None]:
# Collect every scalar summary registered so far into a single op.
all_summaries = tf.summary.merge_all()

# Open a session on the default graph and initialise all variables.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [None]:
# Event-file writer for the summaries and a saver for checkpoints.
# Both write under `logdir` (the saver via `checkpoint_name` in the training loop).
summary_writer = tf.summary.FileWriter(logdir)
saver = tf.train.Saver()

In [None]:
# Arrays sampled from by the training loop.
imgs = data['imgs']
# Ground-truth presence targets; the loop indexes them as presence_gt[:, idx],
# i.e. examples live on axis 1.
# NOTE(review): presumably shaped (n_steps, n_samples, 1) to match placeholder `y` -- confirm.
presence_gt = data['nums']
# Iteration counter. The loop starts at train_itr + 1 and rebinds this name,
# so re-running only the loop cell resumes from the last completed iteration.
train_itr = 1

In [None]:
def make_fig(checkpoint_dir, global_step, fd):
    """Save a grid of inputs, per-step canvases and per-step crops as a PNG.

    The figure has ``2 * n_steps + 1`` rows and up to 10 columns (one column
    per example of the batch):

    * row 0: the input images fed through ``fd``;
    * rows 1 .. n_steps: canvas probabilities after each step;
    * rows n_steps + 1 .. 2 * n_steps: presence-gated crops for each step.

    Args:
        checkpoint_dir: directory the PNG is written to.
        global_step: training iteration, embedded in the file name.
        fd: feed dict used to evaluate the model; must contain the image
            placeholder ``x``.
    """
    # Draw the inputs from the feed dict itself rather than from a notebook
    # global, so the first row always matches the evaluated predictions.
    inputs = fd[x]
    pred_canvas, pred_crop = sess.run([prob_canvas, cropped], fd)

    max_imgs = 10
    bs = min(max_imgs, batch_size)
    scale = 3.
    n_rows = 2 * n_steps + 1
    # figsize is (width, height): one `scale`-inch square per grid cell.
    figsize = scale * np.asarray((bs, n_rows))
    # squeeze=False keeps `axes` 2-D even when there is a single column.
    fig, axes = plt.subplots(n_rows, bs, figsize=figsize, squeeze=False)

    for i, ax in enumerate(axes[0]):
        ax.imshow(inputs[i], cmap='gray', vmin=0, vmax=1)

    for i, ax_row in enumerate(axes[1:1+n_steps]):
        for j, ax in enumerate(ax_row):
            ax.imshow(pred_canvas[i, j], cmap='gray', vmin=0, vmax=1)

    for i, ax_row in enumerate(axes[1+n_steps:]):
        for j, ax in enumerate(ax_row):
            ax.imshow(pred_crop[i, j], cmap='gray', vmin=0, vmax=1)

    for ax in axes.flatten():
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    fig_name = osp.join(checkpoint_dir, 'progress_fig_{}.png'.format(global_step))
    fig.savefig(fig_name, dpi=300)
    # Close only this figure so memory is released without touching any other
    # figure the notebook may have open.
    plt.close(fig)

In [None]:
for train_itr in xrange(train_itr+1, 100000):
    
    idx = np.random.choice(imgs.shape[0], batch_size)
    xx = imgs[idx]
    yy = presence_gt[:, idx]
    
    fd = {x: xx, y: yy}
    sess.run(train_step, fd)
    if train_itr % 100 == 0:
        l = sess.run([loss, rec_loss, num_steps_penalty], fd)
        print train_itr, l
        
    if train_itr % 1000 == 0:
        summaries = sess.run(all_summaries)
        summary_writer.add_summary(summaries, train_itr)
        
    if train_itr % 1000 == 0:
        saver.save(sess, checkpoint_name, global_step=train_itr)
        make_fig(logdir, train_itr, fd)