In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import random
import math
import time
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST splits from the local "mnist/" directory; one_hot=True makes the
# labels one-hot vectors (the label placeholder built later has 10 columns).
mnist = input_data.read_data_sets("mnist/", one_hot=True)

  return f(*args, **kwds)
  from ._conv import register_converters as _register_converters


Extracting mnist/train-images-idx3-ubyte.gz
Extracting mnist/train-labels-idx1-ubyte.gz
Extracting mnist/t10k-images-idx3-ubyte.gz
Extracting mnist/t10k-labels-idx1-ubyte.gz


In [2]:
# Number of examples in each split, as plain integers.
# `num_examples` is used for all three so the values are directly comparable;
# `.images.shape` would yield a shape tuple rather than a count.
num_train=mnist.train.num_examples
num_val=mnist.validation.num_examples
num_test=mnist.test.num_examples

In [3]:
# ---- hyper-parameters -------------------------------------------------------
# Data / batching
batch_size = 64      # distinct images drawn per training step
img_size = 28        # image side length in pixels
MC_test = 128        # how many times each image is replicated per step (Monte Carlo samples)

# Network widths
sensor_unit = 256    # units of each Dense layer inside the glimpse network
lstm_size = 256      # LSTM cell size

# Attention policy
N_glimpse = 10       # number of glimpses (unrolled recurrent steps) per image
loc_std = 0.2        # standard deviation of the Gaussian used to sample locations

# Effective batch dimension of the graph: every image times every Monte Carlo copy.
tot_size = batch_size * MC_test

In [4]:
class Glimpse_Network:
    """Turns an image batch plus a glimpse location into one feature vector per example.

    Square patches of several sizes are cut out around the location, resized to a
    common resolution and concatenated along the channel axis. The flattened
    patches and the location are each passed through their own Dense layer with a
    ReLU, and the two results are summed and passed through a final ReLU.
    """

    def __init__(self):
        # Side lengths of the square patches extracted at every location.
        self.glimspe_size = [5, 10, 15]
        # Every patch is resized to concat_size x concat_size before concatenation.
        self.concat_size = 5
        # Separate Dense layers for the image patches and for the location.
        self.img_net = tf.layers.Dense(units=sensor_unit, name='glimpse_net/img_net')
        self.loc_net = tf.layers.Dense(units=sensor_unit, name='glimpse_net/loc_net')

    def glimpse_sensor(self, image, loc):
        """Extract the multi-scale patches around `loc` and concatenate them.

        Args:
            image: 4-D image tensor passed to `tf.image.extract_glimpse`.
            loc: glimpse offsets, one 2-vector per example.

        Returns:
            The resized patches concatenated along axis 3, one slice per entry
            of `self.glimspe_size`.
        """
        # First cut out all patches, then resize them (same op order as before).
        patches = []
        for side in self.glimspe_size:
            patches.append(
                tf.image.extract_glimpse(input=image, size=[side, side], offsets=loc))

        common_hw = [self.concat_size, self.concat_size]
        resized = []
        for patch in patches:
            resized.append(tf.image.resize_bilinear(patch, common_hw))

        return tf.concat(values=resized, axis=3)

    def forward(self, image, loc):
        """Build the glimpse feature for `image` observed at `loc`.

        Args:
            image: 4-D image tensor.
            loc: glimpse offsets, one 2-vector per example.

        Returns:
            relu(relu(img_net(flattened patches)) + relu(loc_net(loc))).
        """
        # No gradient is propagated back through the patch extraction.
        retina = tf.stop_gradient(self.glimpse_sensor(image, loc))
        flat_retina = tf.layers.flatten(retina)
        what = tf.nn.relu(self.img_net(inputs=flat_retina))
        where = tf.nn.relu(self.loc_net(inputs=loc))
        return tf.nn.relu(what + where)

In [5]:
# Discard any previously built default graph before defining this one.
tf.reset_default_graph()

# Inputs: 28x28 single-channel images and 10-column one-hot labels.
X = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.int64, [None, 10])

# Location of the first glimpse: uniform in [-1, 1) on both axes, one per row of the batch.
start_location = tf.random_uniform([tot_size, 2], minval=-1.0, maxval=1.0)
gNet = Glimpse_Network()

# Recurrent core and its all-zero initial state for a batch of tot_size rows.
lstm_cell = tf.contrib.rnn.LSTMCell(num_units=lstm_size)
state = lstm_cell.zero_state(batch_size=tot_size, dtype=tf.float32)

# Linear heads applied to the LSTM output.
emission_net = tf.layers.Dense(2, name='emission_net')    # mean of the next location
baseline_net = tf.layers.Dense(1, name='baseline_net')    # scalar reward baseline
predict_net = tf.layers.Dense(10, name='predict_net')     # class scores

def loglikelihood(sample,mean):
    """Negative log-likelihood of `sample` under a Gaussian centred at `mean`.

    Despite the name, the returned value is the *negated* log-probability, so it
    can be minimised directly. Both coordinates use the fixed standard deviation
    `loc_std`.

    Args:
        sample: sampled locations, one 2-vector per row.
        mean: Gaussian means, same shape as `sample`.

    Returns:
        Per-row sum over axis 1 of -log N(sample; mean, loc_std).
    """
    per_axis_std = tf.constant([loc_std, loc_std])
    location_dist = tf.distributions.Normal(loc=mean, scale=per_axis_std)
    neg_log_prob = -location_dist.log_prob(sample)
    return tf.reduce_sum(neg_log_prob, axis=1)
    
# Unroll the recurrent attention model for N_glimpse steps and assemble the losses.
# Per-step histories collected while unrolling:
#   loc_his            - location fed to the glimpse network at each step
#                        (filled here but not read again in the visible notebook)
#   loglikelihood_his  - negative log-likelihood of each sampled location
#   baseline_his       - baseline value predicted at each step
loc_his=[]
loglikelihood_his=[]
baseline_his=[]
normalized_loc=start_location
for ng in range(N_glimpse):
    loc_his.append(normalized_loc)
    
    # extract glimpse
    glimpses_out=gNet.forward(X,normalized_loc)
    
    # RNN
    lstm_output,state=lstm_cell(glimpses_out,state)
    
    # emit mean of location
    loc_mean=emission_net(inputs=lstm_output)
    
    # sample next location by gaussian distribution centered at loc_mean
    # (stop_gradient: the sample itself is treated as a constant, so gradients reach
    # loc_mean only through the likelihood term below)
    loc_sample=tf.random_normal(shape=(tot_size,2),mean=loc_mean,stddev=loc_std)
    loc_sample=tf.stop_gradient(loc_sample)
    
    # calculate the -loglikelihood of the sampled position
    # (loglikelihood() returns the negated log-probability summed over both coordinates)
    llh=loglikelihood(loc_sample,loc_mean)
    loglikelihood_his.append(llh)
    
    # normalize the location for next input: tanh squashes the sample into (-1, 1)
    # NOTE(review): the location sampled in the last iteration is never used for a
    # glimpse, yet its likelihood term is still part of loglikelihood_his.
    normalized_loc=tf.tanh(loc_sample)
    normalized_loc=tf.stop_gradient(normalized_loc)
    
    # output the baseline for this step, predicted from the current LSTM output
    # (one value per step, so it does vary over time)
    baseline=baseline_net(inputs=lstm_output)
    baseline_his.append(tf.squeeze(baseline))

# pack data for calculation: one row per glimpse step
# (expected shape [N_glimpse, tot_size] for both tensors)
baseline_his=tf.stack(baseline_his)
loglikelihood_his=tf.stack(loglikelihood_his)

# make prediction from the LSTM output of the final glimpse step
score=predict_net(inputs=lstm_output)
prediction=tf.argmax(score,1)

# calculate reward, do variance reduction and calculate reinforced loglikelihood
# reward is 1.0 where the predicted class equals the label, else 0.0; it is
# broadcast against the per-step rows of baseline_his / loglikelihood_his.
reward=tf.cast(tf.equal(prediction,tf.argmax(y,1)),dtype=tf.float32)
reward_sum=tf.reduce_sum(reward)
# the baseline is wrapped in stop_gradient, so this term does not train baseline_net
reduce_var_reward=reward-tf.stop_gradient(baseline_his)
# minimising (-log p) * (reward - baseline) raises the likelihood of locations
# that were followed by an above-baseline reward
reinforce_llh=tf.reduce_mean(loglikelihood_his*reduce_var_reward)

# regression baseline towards reward
# NOTE(review): baseline_net reads lstm_output without stop_gradient, so this
# loss also sends gradients into the LSTM - confirm that this is intended.
baseline_mse=tf.reduce_mean(tf.square(reward-baseline_his))

# softmax to output (classification loss on the final scores)
# NOTE(review): tf.losses.softmax_cross_entropy presumably already returns a
# reduced scalar, which would make the outer reduce_mean a no-op - confirm.
softmax_loss=tf.reduce_mean(tf.losses.softmax_cross_entropy(onehot_labels=y,logits=score))

# summarize loss: unweighted sum of the three terms
loss=reinforce_llh+baseline_mse+softmax_loss


# a single RMSProp step minimises the combined loss
optimizier=tf.train.RMSPropOptimizer(learning_rate=1e-4)
train_step = optimizier.minimize(loss)

In [None]:
# Training schedule and per-epoch histories (read again by the plotting cell).
max_epoch = 100
print_every = 50
num_iteration = num_train // batch_size
loss_his = []
reward_his = []


def _timestamp():
    """Return the current local time formatted as 'YYYY-mm-dd HH:MM:SS'."""
    return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch_idx in range(max_epoch):
        print(_timestamp(), 'start epoch %d/%d' % (epoch_idx + 1, max_epoch))
        epoch_loss = 0
        epoch_reward = 0

        # step is 1-based so it can be printed and compared directly.
        for step in range(1, num_iteration + 1):
            batch_images, batch_labels = mnist.train.next_batch(batch_size)

            # Repeat the whole mini-batch MC_test times along axis 0 so that every
            # image is evaluated with MC_test independently sampled glimpse paths.
            tiled_images = np.tile(batch_images, (MC_test, 1))
            tiled_labels = np.tile(batch_labels, (MC_test, 1))
            feed = {
                X: tiled_images.reshape(tot_size, 28, 28, 1),
                y: tiled_labels,
            }

            loss_1, loss_2, loss_3, loss_out, reward_out, _ = sess.run(
                [reinforce_llh, baseline_mse, softmax_loss, loss, reward_sum, train_step],
                feed_dict=feed)

            epoch_loss += loss_out
            epoch_reward += reward_out

            # Report on the first step, every print_every-th step and the last step.
            should_report = (step == 1
                             or step % print_every == 0
                             or step == num_iteration)
            if should_report:
                print(_timestamp(),
                      'iter %3d: l_1 = %8f, l_2 = %8f, l_3 = %8f, tot_l = %8f, tot_R =' %
                      (step, loss_1, loss_2, loss_3, loss_out), reward_out)

        # Per-epoch means over all steps.
        mean_loss = epoch_loss / num_iteration
        mean_reward = epoch_reward / num_iteration
        loss_his.append(mean_loss)
        reward_his.append(mean_reward)
        print(_timestamp(),
              'end epoch, average loss =', mean_loss, 'average reward =', mean_reward)

2018-04-16 17:02:13 start epoch 1/100
2018-04-16 17:02:14 iter   1: l_1 = -0.025237, l_2 = 0.127027, l_3 = 2.294904, tot_l = 2.396694, tot_R = 1149.0
2018-04-16 17:02:30 iter  50: l_1 = -0.012037, l_2 = 0.100284, l_3 = 2.322230, tot_l = 2.410478, tot_R = 905.0
2018-04-16 17:02:47 iter 100: l_1 = -0.020275, l_2 = 0.134339, l_3 = 2.277950, tot_l = 2.392015, tot_R = 1240.0
2018-04-16 17:03:04 iter 150: l_1 = -0.065901, l_2 = 0.268693, l_3 = 2.203079, tot_l = 2.405871, tot_R = 3206.0
2018-04-16 17:03:21 iter 200: l_1 = -0.034244, l_2 = 0.235213, l_3 = 2.010151, tot_l = 2.211121, tot_R = 3180.0
2018-04-16 17:03:38 iter 250: l_1 = -0.039293, l_2 = 0.251383, l_3 = 1.714825, tot_l = 1.926915, tot_R = 3740.0
2018-04-16 17:03:55 iter 300: l_1 = -0.037555, l_2 = 0.238984, l_3 = 1.379016, tot_l = 1.580445, tot_R = 4524.0
2018-04-16 17:04:12 iter 350: l_1 = -0.025285, l_2 = 0.248260, l_3 = 1.418285, tot_l = 1.641260, tot_R = 4115.0
2018-04-16 17:04:29 iter 400: l_1 = -0.025978, l_2 = 0.221614, l_3 

2018-04-16 17:20:33 iter 650: l_1 = -0.015717, l_2 = 0.098979, l_3 = 0.325096, tot_l = 0.408358, tot_R = 7271.0
2018-04-16 17:20:50 iter 700: l_1 = -0.017993, l_2 = 0.084479, l_3 = 0.293296, tot_l = 0.359782, tot_R = 7469.0
2018-04-16 17:21:07 iter 750: l_1 = -0.006817, l_2 = 0.109889, l_3 = 0.405203, tot_l = 0.508275, tot_R = 7168.0
2018-04-16 17:21:25 iter 800: l_1 = -0.040609, l_2 = 0.085930, l_3 = 0.252888, tot_l = 0.298208, tot_R = 7522.0
2018-04-16 17:21:41 iter 850: l_1 = -0.025301, l_2 = 0.063874, l_3 = 0.196865, tot_l = 0.235438, tot_R = 7687.0
2018-04-16 17:21:44 iter 859: l_1 = -0.019447, l_2 = 0.097095, l_3 = 0.314870, tot_l = 0.392517, tot_R = 7355.0
2018-04-16 17:21:44 end epoch, average loss = 0.5205756924933688 average reward = 7092.831199068684
2018-04-16 17:21:44 start epoch 5/100
2018-04-16 17:21:45 iter   1: l_1 = 0.003591, l_2 = 0.120966, l_3 = 0.399207, tot_l = 0.523764, tot_R = 7057.0
2018-04-16 17:22:01 iter  50: l_1 = -0.024635, l_2 = 0.099965, l_3 = 0.324808, 

2018-04-16 17:38:08 iter 300: l_1 = 0.007072, l_2 = 0.090214, l_3 = 0.314377, tot_l = 0.411663, tot_R = 7348.0
2018-04-16 17:38:25 iter 350: l_1 = 0.012587, l_2 = 0.083972, l_3 = 0.269525, tot_l = 0.366083, tot_R = 7452.0
2018-04-16 17:38:42 iter 400: l_1 = -0.018656, l_2 = 0.054621, l_3 = 0.165755, tot_l = 0.201720, tot_R = 7756.0
2018-04-16 17:39:00 iter 450: l_1 = -0.011127, l_2 = 0.065617, l_3 = 0.215294, tot_l = 0.269784, tot_R = 7629.0
2018-04-16 17:39:17 iter 500: l_1 = 0.006747, l_2 = 0.115167, l_3 = 0.441177, tot_l = 0.563091, tot_R = 7044.0
2018-04-16 17:39:34 iter 550: l_1 = 0.014842, l_2 = 0.090267, l_3 = 0.293469, tot_l = 0.398578, tot_R = 7370.0
2018-04-16 17:39:51 iter 600: l_1 = -0.014915, l_2 = 0.047844, l_3 = 0.152166, tot_l = 0.185094, tot_R = 7785.0
2018-04-16 17:40:08 iter 650: l_1 = -0.015854, l_2 = 0.039473, l_3 = 0.120372, tot_l = 0.143991, tot_R = 7875.0
2018-04-16 17:40:25 iter 700: l_1 = -0.007868, l_2 = 0.066035, l_3 = 0.266424, tot_l = 0.324591, tot_R = 760

2018-04-16 17:56:10 iter   1: l_1 = -0.004730, l_2 = 0.041646, l_3 = 0.145524, tot_l = 0.182440, tot_R = 7833.0
2018-04-16 17:56:27 iter  50: l_1 = -0.004748, l_2 = 0.060363, l_3 = 0.186805, tot_l = 0.242421, tot_R = 7637.0
2018-04-16 17:56:45 iter 100: l_1 = 0.006161, l_2 = 0.083738, l_3 = 0.250588, tot_l = 0.340488, tot_R = 7421.0
2018-04-16 17:57:02 iter 150: l_1 = -0.005991, l_2 = 0.043951, l_3 = 0.178493, tot_l = 0.216453, tot_R = 7824.0
2018-04-16 17:57:19 iter 200: l_1 = -0.013470, l_2 = 0.044495, l_3 = 0.136271, tot_l = 0.167295, tot_R = 7812.0
2018-04-16 17:57:37 iter 250: l_1 = -0.011316, l_2 = 0.037161, l_3 = 0.112336, tot_l = 0.138181, tot_R = 7886.0
2018-04-16 17:57:54 iter 300: l_1 = 0.009656, l_2 = 0.073057, l_3 = 0.231801, tot_l = 0.314513, tot_R = 7532.0
2018-04-16 17:58:12 iter 350: l_1 = 0.005616, l_2 = 0.055145, l_3 = 0.176535, tot_l = 0.237296, tot_R = 7719.0
2018-04-16 17:58:29 iter 400: l_1 = -0.011993, l_2 = 0.035744, l_3 = 0.115068, tot_l = 0.138819, tot_R = 78

2018-04-16 18:14:51 iter 650: l_1 = 0.001329, l_2 = 0.059281, l_3 = 0.207599, tot_l = 0.268209, tot_R = 7660.0
2018-04-16 18:15:08 iter 700: l_1 = -0.007817, l_2 = 0.023487, l_3 = 0.080275, tot_l = 0.095945, tot_R = 8008.0
2018-04-16 18:15:26 iter 750: l_1 = -0.003077, l_2 = 0.044348, l_3 = 0.199190, tot_l = 0.240461, tot_R = 7812.0
2018-04-16 18:15:43 iter 800: l_1 = -0.008615, l_2 = 0.027393, l_3 = 0.086295, tot_l = 0.105074, tot_R = 7976.0
2018-04-16 18:16:01 iter 850: l_1 = -0.003743, l_2 = 0.033058, l_3 = 0.101574, tot_l = 0.130890, tot_R = 7905.0
2018-04-16 18:16:04 iter 859: l_1 = -0.007754, l_2 = 0.031247, l_3 = 0.102394, tot_l = 0.125888, tot_R = 7923.0
2018-04-16 18:16:04 end epoch, average loss = 0.1933712500111743 average reward = 7795.722933643772
2018-04-16 18:16:04 start epoch 16/100
2018-04-16 18:16:04 iter   1: l_1 = 0.008120, l_2 = 0.068106, l_3 = 0.231934, tot_l = 0.308160, tot_R = 7547.0
2018-04-16 18:16:23 iter  50: l_1 = -0.002858, l_2 = 0.033761, l_3 = 0.114716, 

2018-04-16 18:32:56 iter 300: l_1 = -0.001743, l_2 = 0.045250, l_3 = 0.142181, tot_l = 0.185688, tot_R = 7796.0
2018-04-16 18:33:13 iter 350: l_1 = -0.000331, l_2 = 0.046993, l_3 = 0.141891, tot_l = 0.188553, tot_R = 7782.0
2018-04-16 18:33:30 iter 400: l_1 = -0.001696, l_2 = 0.028224, l_3 = 0.093252, tot_l = 0.119781, tot_R = 7959.0
2018-04-16 18:33:48 iter 450: l_1 = 0.002331, l_2 = 0.045103, l_3 = 0.133756, tot_l = 0.181190, tot_R = 7797.0
2018-04-16 18:34:05 iter 500: l_1 = -0.010726, l_2 = 0.015760, l_3 = 0.054833, tot_l = 0.059866, tot_R = 8078.0
2018-04-16 18:34:23 iter 550: l_1 = 0.002593, l_2 = 0.052474, l_3 = 0.147322, tot_l = 0.202389, tot_R = 7723.0
2018-04-16 18:34:40 iter 600: l_1 = -0.003375, l_2 = 0.029431, l_3 = 0.087960, tot_l = 0.114016, tot_R = 7948.0
2018-04-16 18:34:57 iter 650: l_1 = -0.004324, l_2 = 0.039236, l_3 = 0.119240, tot_l = 0.154152, tot_R = 7835.0
2018-04-16 18:35:15 iter 700: l_1 = -0.008300, l_2 = 0.027541, l_3 = 0.089217, tot_l = 0.108457, tot_R = 7

2018-04-16 18:51:11 iter   1: l_1 = 0.011959, l_2 = 0.062867, l_3 = 0.223949, tot_l = 0.298774, tot_R = 7630.0
2018-04-16 18:51:29 iter  50: l_1 = -0.004771, l_2 = 0.021664, l_3 = 0.066941, tot_l = 0.083834, tot_R = 8018.0
2018-04-16 18:51:47 iter 100: l_1 = -0.006185, l_2 = 0.022175, l_3 = 0.067824, tot_l = 0.083813, tot_R = 8008.0
2018-04-16 18:52:04 iter 150: l_1 = -0.000073, l_2 = 0.044949, l_3 = 0.124977, tot_l = 0.169853, tot_R = 7809.0
2018-04-16 18:52:22 iter 200: l_1 = -0.000913, l_2 = 0.027050, l_3 = 0.080046, tot_l = 0.106183, tot_R = 7960.0
2018-04-16 18:52:39 iter 250: l_1 = 0.004262, l_2 = 0.036505, l_3 = 0.100757, tot_l = 0.141523, tot_R = 7873.0
2018-04-16 18:52:57 iter 300: l_1 = 0.000680, l_2 = 0.034373, l_3 = 0.121686, tot_l = 0.156740, tot_R = 7907.0
2018-04-16 18:53:15 iter 350: l_1 = -0.005553, l_2 = 0.040133, l_3 = 0.127874, tot_l = 0.162453, tot_R = 7844.0
2018-04-16 18:53:33 iter 400: l_1 = -0.008443, l_2 = 0.027209, l_3 = 0.093390, tot_l = 0.112155, tot_R = 79

2018-04-16 19:10:12 iter 650: l_1 = 0.000384, l_2 = 0.037518, l_3 = 0.110957, tot_l = 0.148859, tot_R = 7860.0
2018-04-16 19:10:29 iter 700: l_1 = 0.008056, l_2 = 0.045392, l_3 = 0.154467, tot_l = 0.207915, tot_R = 7788.0
2018-04-16 19:10:47 iter 750: l_1 = -0.003403, l_2 = 0.032580, l_3 = 0.101326, tot_l = 0.130504, tot_R = 7911.0
2018-04-16 19:11:04 iter 800: l_1 = -0.002708, l_2 = 0.028547, l_3 = 0.083337, tot_l = 0.109176, tot_R = 7951.0
2018-04-16 19:11:22 iter 850: l_1 = 0.000749, l_2 = 0.046173, l_3 = 0.143365, tot_l = 0.190287, tot_R = 7784.0
2018-04-16 19:11:25 iter 859: l_1 = 0.001012, l_2 = 0.038432, l_3 = 0.113709, tot_l = 0.153153, tot_R = 7863.0
2018-04-16 19:11:25 end epoch, average loss = 0.12230744112144215 average reward = 7943.476135040745
2018-04-16 19:11:25 start epoch 27/100
2018-04-16 19:11:25 iter   1: l_1 = -0.001518, l_2 = 0.025533, l_3 = 0.074591, tot_l = 0.098606, tot_R = 7977.0
2018-04-16 19:11:44 iter  50: l_1 = -0.001125, l_2 = 0.033615, l_3 = 0.116828, t

In [None]:
# Plot the per-epoch average loss and reward recorded by the training cell.
# The x-axis covers only the epochs actually recorded, so the plot also works when
# training stopped before max_epoch (range(max_epoch) would then mismatch the data
# length and raise). Loss and reward live on very different scales, so each gets
# its own y-axis instead of sharing one labelled 'loss'.
epochs_recorded = range(len(loss_his))

fig = plt.figure(2)
ax_loss = fig.add_subplot(111)
ax_reward = ax_loss.twinx()  # second y-axis sharing the same x-axis

plo, = ax_loss.plot(epochs_recorded, loss_his, color='b')
prw, = ax_reward.plot(epochs_recorded, reward_his, color='r')

ax_loss.legend((plo, prw), ('loss', 'reward'))
ax_loss.set_xlabel('training epoch')
ax_loss.set_ylabel('loss')
ax_reward.set_ylabel('reward')
plt.show()

In [None]:
# mean=tf.zeros((100,2),dtype=tf.float32)
# std=tf.constant([1,1],dtype=tf.float32)
# gaussian=tf.distributions.Normal(loc=mean,scale=std)
# rand=tf.random_normal(shape=(100,2),mean=0,stddev=1)
# sampled=mean+rand
# prob=-gaussian.log_prob(sampled)
# prob=tf.reduce_mean(tf.reduce_sum(prob,1))
# with tf.Session() as sess:
#     out=sess.run([prob])
#     print(out)