In [1]:
from os import path as osp
import numpy as np
import tensorflow as tf
import sonnet as snt
from attrdict import AttrDict

from evaluation import make_fig, make_seq_fig, make_logger

from data import load_data, tensors_from_data
from mnist_model import SeqAIRonMNIST

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
% matplotlib inline

ImportError: No module named seaborn

In [None]:
# --- Optimisation / model-size settings -------------------------------------
# learning_rate is forwarded to air.train_step; n_steps is passed to the model
# constructor as `max_steps`.
learning_rate = 1e-5
n_steps = 3

# --- Output locations --------------------------------------------------------
# Summaries are written under `logdir`; checkpoints use `checkpoint_name` as
# their path prefix.
run_name = 'test'
results_dir = '../results/seq'
logdir = osp.join(results_dir, run_name)
checkpoint_name = osp.join(logdir, 'model.ckpt')

# Axis index per data field. Handed to tensors_from_data and also used to read
# the number of samples out of the image array's shape.
axes = dict(imgs=1, labels=0, nums=1, coords=1)

In [None]:
# Initial value of the non-trainable `n_timesteps` graph variable that limits
# how many leading entries of the input tensors are fed to the model.
n_timesteps = 1
batch_size = 64

# Prior over the number of steps. All entries are forwarded untouched to
# air.train_step; their exact meaning is defined by the model code.
num_steps_prior = AttrDict({
    'anneal': 'exp',
    'init': 1.,
    'final': 1e-5,
    'steps_div': 1e4,
    'steps': 1e5,
    'hold_init': 1e3,
    'analytic': False,
})

# Location/scale pairs for the remaining priors (also forwarded to train_step).
appearance_prior = AttrDict(loc=0., scale=1.)
where_scale_prior = AttrDict(loc=.5, scale=1.)
where_shift_prior = AttrDict(loc=0., scale=1.)

# To run without priors, replace the four objects above with None:
# num_steps_prior = None
# appearance_prior = None
# where_scale_prior = None
# where_shift_prior = None

# Model construction switches and biases (passed to SeqAIRonMNIST / train_step).
discrete_steps = True
use_reinforce = True
step_bias = .75
transform_var_bias = .5
output_multiplier = .5

# Auxiliary loss weights; both are currently disabled.
l2_weight = 0.  # alternative value: 1e-5
nums_xe_weight = 0.

In [None]:
# valid_data = load_data('small_seq_mnist_validation.pickle')
# train_data = load_data('small_seq_mnist_train.pickle')

In [None]:
# Load the pickled validation and training splits through the project loader.
valid_data, train_data = (
    load_data('mnist_validation.pickle'),
    load_data('mnist_train.pickle'),
)

def fix(d):
    """Reshape a loaded dataset in place and return it.

    Prepends a new leading axis to ``d.imgs`` and transposes ``d.nums``.
    The object passed in is mutated and the same object is returned.

    NOTE(review): presumably the added leading axis serves as a length-1
    time dimension so single-image data matches the sequence layout
    described by ``axes`` -- confirm against the data loader.
    """
    with_leading_axis = d.imgs[np.newaxis, ...]
    d.imgs = with_leading_axis
    transposed_nums = d.nums.T
    d.nums = transposed_nums
    return d

# Run both splits through fix(), validation first; fix() mutates and returns
# the object it is given.
valid_data = fix(valid_data)
train_data = fix(train_data)

In [None]:
# Build the whole model graph from scratch so this cell can be re-run safely.
tf.reset_default_graph()

# Input pipelines; only the training pipeline is shuffled.
train_tensors = tensors_from_data(train_data, batch_size, axes, shuffle=True)
valid_tensors = tensors_from_data(valid_data, batch_size, axes, shuffle=False)
x, valid_x = train_tensors['imgs'], valid_tensors['imgs']
y, test_y = train_tensors['nums'], valid_tensors['nums']

# Every encoder/decoder MLP uses n_layers hidden layers of 32 * 8 units.
n_layers = 2
n_hiddens = [32 * 8] * n_layers

# The graph variable gets its own Python name instead of overwriting the
# integer `n_timesteps` from the configuration cell. Rebinding the config name
# to a tf.Variable made this cell fail on re-execution: after
# tf.reset_default_graph() the stale variable from the old graph would be used
# as the initial value of the new one.
n_timesteps_var = tf.Variable(n_timesteps, trainable=False, dtype=tf.int32, name='n_timesteps')

# Only the first `n_timesteps_var` entries along the leading axis are used.
seq_x = x[:n_timesteps_var]
seq_y = y[:n_timesteps_var]

air = SeqAIRonMNIST(seq_x,
                max_steps=n_steps,
                inpt_encoder_hidden=n_hiddens,
                glimpse_encoder_hidden=n_hiddens,
                glimpse_decoder_hidden=n_hiddens,
                transform_estimator_hidden=n_hiddens,
                steps_pred_hidden=[128, 64],
                baseline_hidden=[256, 128],
                transform_var_bias=transform_var_bias,
                step_bias=step_bias,
                output_multiplier=output_multiplier,
                discrete_steps=discrete_steps
)

In [None]:
# Ask the model for its training op and global-step tensor. The priors and
# loss weights come from the hyperparameter cell; `seq_y` supplies the object
# counts for the (optionally weighted) cross-entropy term.
train_step, global_step = air.train_step(
    learning_rate,
    l2_weight,
    appearance_prior,
    where_scale_prior,
    where_shift_prior,
    num_steps_prior,
    use_reinforce=use_reinforce,
    decay_rate=None,
    nums=seq_y,
    nums_xe_weight=nums_xe_weight,
)

In [None]:
# Let TensorFlow grow its GPU memory allocation on demand instead of reserving
# everything up front, then open a session and initialise all variables.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

In [None]:
# Merge every summary op registered so far, open an event writer in `logdir`
# (the session graph is passed to the writer), and create the checkpoint saver.
all_summaries = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)
saver = tf.train.Saver()

In [None]:
# Sample counts are read from the image arrays along the axis given in `axes`
# and floor-divided by 100 before being handed to the logger.
# NOTE(review): presumably this keeps periodic evaluation cheap by using only
# 1% of each split -- confirm against make_logger.
n_train_samples = train_data['imgs'].shape[axes['imgs']] // 100
n_valid_samples = valid_data['imgs'].shape[axes['imgs']] // 100

log = make_logger(
    air, sess, summary_writer,
    train_tensors, n_train_samples,
    valid_tensors, n_valid_samples,
)

In [None]:
# Training schedule, expressed in global steps.
max_train_iters = int(1e6)   # stop once the global step reaches this value
summary_every = 1000         # write merged TensorBoard summaries
log_every = 500              # call the evaluation logger
checkpoint_every = 5000      # save a checkpoint

train_itr = sess.run(global_step)
# print(...) with a single argument behaves identically on Python 2 and 3,
# whereas the bare print statement is a syntax error on Python 3.
print('Starting training at iter = {}'.format(train_itr))

if train_itr == 0:
    # Fresh run: execute the baseline train op once (REINFORCE only) and log
    # the initial state before any regular training step.
    if air.use_reinforce:
        sess.run(air._baseline_train_step)
    log(0)

while train_itr < max_train_iters:

    # NOTE(review): global_step is fetched in the same run call as train_step;
    # whether the returned value is the pre- or post-update step depends on how
    # train_step updates it -- confirm if exact step alignment matters.
    train_itr, _ = sess.run([global_step, train_step])

    # Alternative summary schedules that were tried:
    #     if train_itr % 100 == 0:
    #     if (train_itr % 1000) < 100:
    if train_itr % summary_every == 0:
        summaries = sess.run(all_summaries)
        summary_writer.add_summary(summaries, train_itr)

    if train_itr % log_every == 0:
        log(train_itr)

    if train_itr % checkpoint_every == 0:
        saver.save(sess, checkpoint_name, global_step=train_itr)
        # Optional figure dump at checkpoint time:
        # make_fig(air, sess, logdir, train_itr)

In [None]:
# Call the project's evaluation helper on the current model and session.
# NOTE(review): presumably this draws a figure of the model's outputs on a
# sequence -- see evaluation.make_seq_fig to confirm.
make_seq_fig(air, sess)