In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import random
import math
import time
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from the "mnist/" directory with one-hot encoded labels.
# The resulting object exposes the train / validation / test splits used below.
mnist = input_data.read_data_sets("mnist/", one_hot=True)

  return f(*args, **kwds)
  from ._conv import register_converters as _register_converters


Extracting mnist/train-images-idx3-ubyte.gz
Extracting mnist/train-labels-idx1-ubyte.gz
Extracting mnist/t10k-images-idx3-ubyte.gz
Extracting mnist/t10k-labels-idx1-ubyte.gz


In [2]:
# Number of examples in each MNIST split. All three are plain ints, so
# `num_val` and `num_test` match `num_train`. (`.images.shape` would instead
# give a (count, pixels) tuple.)
num_train=mnist.train.num_examples
num_val=mnist.validation.num_examples
num_test=mnist.test.num_examples

In [3]:
# ---- hyper-parameters ----------------------------------------------------
batch_size = 64                  # distinct images drawn per training step
img_size = 28                    # side length of the square input images
sensor_unit = 256                # width of the glimpse-network dense layers
lstm_size = 256                  # hidden size of the core LSTM
N_glimpse = 10                   # glimpses taken per image
MC_test = 128                    # Monte Carlo copies of every image in a batch
loc_std = 0.2                    # std-dev of the Gaussian location policy
tot_size = MC_test * batch_size  # rows actually fed through the graph per step

In [4]:
class Glimpse_Network():
    """Glimpse network of the recurrent attention model.

    Extracts one square patch per entry of ``glimpse_sizes``, all centred on
    ``loc``. Each patch is resized to ``concat_size`` x ``concat_size``, and the
    patches are stacked along the channel axis. The flattened stack is then
    fused with the location through two dense layers and a ReLU.
    """

    def __init__(self, glimpse_sizes=(5, 7, 9), concat_size=9, units=None):
        """
        Args:
            glimpse_sizes: side lengths (pixels) of the extracted patches.
            concat_size: side length every patch is resized to before stacking.
            units: width of both dense layers; defaults to the global
                ``sensor_unit`` (read at construction time).
        Raises:
            ValueError: if ``glimpse_sizes`` is empty.
        """
        if units is None:
            units = sensor_unit
        self.glimpse_sizes = list(glimpse_sizes)
        if not self.glimpse_sizes:
            raise ValueError('glimpse_sizes must contain at least one patch size')
        # Backward-compatible alias for the original, misspelled attribute name.
        self.glimspe_size = self.glimpse_sizes
        self.concat_size = concat_size
        self.img_net = tf.layers.Dense(units=units, name='glimpse_net/img_net')
        self.loc_net = tf.layers.Dense(units=units, name='glimpse_net/loc_net')

    def glimpse_sensor(self, image, loc):
        """Return the multi-resolution glimpse of ``image`` at ``loc``.

        Args:
            image: 4-D image batch; assumed NHWC (this notebook feeds
                (tot_size, img_size, img_size, 1)).
            loc: (batch, 2) offsets. These are passed to
                ``tf.image.extract_glimpse`` with its defaults, i.e. centred,
                normalized coordinates. By default, pixels falling outside
                the image are filled with noise.
        Returns:
            Tensor of shape (batch, concat_size, concat_size,
            channels * len(glimpse_sizes)).
        """
        patches = [tf.image.extract_glimpse(input=image, size=[gs, gs], offsets=loc)
                   for gs in self.glimpse_sizes]
        # Bring every patch to a common resolution so they can be stacked.
        resized = [tf.image.resize_bilinear(p, [self.concat_size, self.concat_size])
                   for p in patches]
        return tf.concat(values=resized, axis=3)

    def forward(self, image, loc):
        """Embed the glimpse at ``loc`` together with the location itself.

        Returns:
            (batch, units) tensor: ReLU(img_net(flatten(glimpse)) + loc_net(loc)).
        """
        glimpses = self.glimpse_sensor(image, loc)  # (batch, concat, concat, C*len(sizes))
        g_image = self.img_net(inputs=tf.layers.flatten(glimpses))
        g_loc = self.loc_net(inputs=loc)
        return tf.nn.relu(g_image + g_loc)

In [5]:
# Build the graph from scratch so re-running this cell does not duplicate
# variables in the default graph.
tf.reset_default_graph()
# Input images (NHWC, one channel). The batch axis is None, but
# start_location and the LSTM zero state below are built for exactly
# tot_size rows, so X must be fed exactly tot_size rows.
X=tf.placeholder(dtype=tf.float32,shape=[None,img_size,img_size,1])
# One-hot class labels (10 classes).
y=tf.placeholder(dtype=tf.int64,shape=[None,10])
# Location of the first glimpse, drawn uniformly from [-1, 1) per coordinate.
start_location=tf.random_uniform(shape=[tot_size,2],minval=-1.0,maxval=1.0)
gNet=Glimpse_Network()

# Core recurrent network and its all-zero initial state.
lstm_cell = tf.contrib.rnn.LSTMCell(lstm_size)
state = lstm_cell.zero_state(tot_size, tf.float32)

# Heads applied to the LSTM output.
emission_net=tf.layers.Dense(units=2,name='emission_net')    # mean of the next location
baseline_net=tf.layers.Dense(units=1,name='baseline_net')    # REINFORCE baseline
predict_net=tf.layers.Dense(units=10,name='predict_net')     # class logits

def loglikelihood(sample,mean,std=None):
    """NEGATIVE Gaussian log-likelihood of ``sample`` under N(mean, std^2).

    Despite the name, this returns -log p summed over axis 1. Minimizing
    (-log p) * advantage is the REINFORCE update.

    Args:
        sample: tensor of sampled locations, same shape as ``mean``.
        mean: (batch, dims) tensor of Gaussian means.
        std: standard deviation shared by every dimension; defaults to the
            global ``loc_std`` (read at call time).
    Returns:
        (batch,) tensor with -log p(sample) per row.
    """
    if std is None:
        std=loc_std
    # A scalar scale broadcasts to any number of location dimensions; the
    # previous hard-coded [std, std] only supported 2-D locations.
    gaussian=tf.distributions.Normal(loc=mean,scale=tf.constant(std,dtype=mean.dtype))
    neg_llh=-gaussian.log_prob(sample)
    return tf.reduce_sum(neg_llh,axis=1)
    
# Unroll the attention loop for N_glimpse steps.
loc_his=[]            # normalized location fed to each glimpse (incl. the random start)
loglikelihood_his=[]  # per-step NEGATIVE log-likelihood of the sampled next location
baseline_his=[]       # per-step baseline prediction, shape (tot_size,)
normalized_loc=start_location
for ng in range(N_glimpse):
    loc_his.append(normalized_loc)

    # extract and embed the glimpse at the current location
    glimpses_out=gNet.forward(X,normalized_loc)

    # advance the core RNN by one step
    lstm_output,state=lstm_cell(glimpses_out,state)

    # emit the mean of the next location
    loc_mean=emission_net(inputs=lstm_output)

    # Sample the next location from N(loc_mean, loc_std^2) and treat the
    # sample as a constant. tf.random_normal returns loc_mean + loc_std*eps,
    # so if gradients flowed through it, (sample - loc_mean) would no longer
    # depend on loc_mean. The REINFORCE gradient of the log-likelihood would
    # then be exactly zero. Stopping the gradient also keeps the
    # classification loss from reaching the emission net via the next
    # glimpse's location input.
    loc_sample=tf.stop_gradient(
        tf.random_normal(shape=(tot_size,2),mean=loc_mean,stddev=loc_std))

    # -log pi(loc_sample | loc_mean), one value per row
    llh=loglikelihood(loc_sample,loc_mean)
    loglikelihood_his.append(llh)

    # squash the location into (-1, 1) for the next glimpse
    normalized_loc=tf.tanh(loc_sample)

    # Baseline for this step. The dense layer's weights are shared across
    # steps, but its output depends on the current LSTM output.
    # axis=1 keeps the batch axis even when it has size 1.
    baseline=baseline_net(inputs=lstm_output)
    baseline_his.append(tf.squeeze(baseline,axis=1))

# pack per-step histories into (N_glimpse, tot_size) tensors
baseline_his=tf.stack(baseline_his)
loglikelihood_his=tf.stack(loglikelihood_his)

# classify from the LSTM output of the final glimpse
score=predict_net(inputs=lstm_output)
prediction=tf.argmax(score,1)

# Reward is 1 for a correct final prediction and 0 otherwise. The
# gradient-stopped baseline is subtracted (broadcast over N_glimpse) for
# variance reduction. loglikelihood_his holds NEGATIVE log-probabilities, so
# minimizing this term follows the REINFORCE gradient.
reward=tf.cast(tf.equal(prediction,tf.argmax(y,1)),dtype=tf.float32)
reduce_var_reward=reward-tf.stop_gradient(baseline_his)
reinforce_llh=tf.reduce_mean(loglikelihood_his*reduce_var_reward)

# regress the baseline towards the observed reward
baseline_mse=tf.reduce_mean(tf.square(reward-baseline_his))

# Classification loss. tf.losses.softmax_cross_entropy already reduces to
# a scalar mean, so no extra reduce_mean is needed.
softmax_loss=tf.losses.softmax_cross_entropy(onehot_labels=y,logits=score)

# total objective: REINFORCE + baseline regression + classification
loss=reinforce_llh+baseline_mse+softmax_loss

# NOTE(review): 1e-6 is very small for Adam; the logged loss barely moves
# over many epochs. Consider a larger value (e.g. 1e-4 to 1e-3).
learning_rate=1e-6
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate)
optimizier=optimizer  # alias kept for the original (misspelled) name
train_step = optimizer.minimize(loss)

In [6]:
max_epoch=100
print_every=50
num_iteration=num_train//batch_size
loss_his=[]   # mean total loss of every completed epoch


def _now():
    """Current local time formatted for the progress log."""
    return time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))


# NOTE(review): the session, and with it the trained weights, is closed when
# this block exits. Add a tf.train.Saver if the model is needed afterwards.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for epoch in range(max_epoch):
        print(_now(),'start epoch %d/%d' % (epoch+1,max_epoch))
        tot_loss=0
        for it in range(num_iteration):
            images,labels=mnist.train.next_batch(batch_size)
            # Repeat the batch MC_test times so every image is processed along
            # MC_test independently sampled glimpse trajectories.
            images=np.tile(images,(MC_test,1))
            labels=np.tile(labels,(MC_test,1))
            feed_dict={X:images.reshape(tot_size,img_size,img_size,1),y:labels}
            loss_1,loss_2,loss_3,loss_out,_=sess.run([reinforce_llh,baseline_mse,softmax_loss,loss,train_step],
                                                     feed_dict=feed_dict)
            tot_loss+=loss_out
            if it==0 or (it+1)%print_every==0 or it==num_iteration-1:
                print(_now(),
                      'iter',it+1,': loss_1 =',loss_1,'loss_2 =',loss_2,'loss_3 =',loss_3,'total_loss =',loss_out)
        loss_his.append(tot_loss/num_iteration)
        print(_now(),'end epoch, average loss =',(tot_loss/num_iteration))

2018-04-13 18:39:00 start epoch 1/100
2018-04-13 18:39:01 iter 1 : loss_1 = -0.055326004 loss_2 = 0.15681723 loss_3 = 2.3258123 total_loss = 2.4273036
2018-04-13 18:39:12 iter 50 : loss_1 = -0.03875608 loss_2 = 0.1341498 loss_3 = 2.3155053 total_loss = 2.410899
2018-04-13 18:39:22 iter 100 : loss_1 = -0.028772796 loss_2 = 0.13506845 loss_3 = 2.3040724 total_loss = 2.410368
2018-04-13 18:39:32 iter 150 : loss_1 = -0.021591634 loss_2 = 0.1231664 loss_3 = 2.3270884 total_loss = 2.428663
2018-04-13 18:39:43 iter 200 : loss_1 = -0.02415372 loss_2 = 0.12844166 loss_3 = 2.314249 total_loss = 2.418537
2018-04-13 18:39:53 iter 250 : loss_1 = -0.03769056 loss_2 = 0.15967295 loss_3 = 2.2869048 total_loss = 2.4088871
2018-04-13 18:40:04 iter 300 : loss_1 = -0.039121605 loss_2 = 0.16623122 loss_3 = 2.2905312 total_loss = 2.4176407
2018-04-13 18:40:14 iter 350 : loss_1 = -0.020655824 loss_2 = 0.13412912 loss_3 = 2.3145723 total_loss = 2.4280457
2018-04-13 18:40:24 iter 400 : loss_1 = -0.0030090595 l

2018-04-13 18:50:04 iter 600 : loss_1 = -0.03037629 loss_2 = 0.24028563 loss_3 = 2.1479075 total_loss = 2.357817
2018-04-13 18:50:15 iter 650 : loss_1 = -0.015234714 loss_2 = 0.23176336 loss_3 = 2.1507587 total_loss = 2.3672874
2018-04-13 18:50:25 iter 700 : loss_1 = -0.017314034 loss_2 = 0.2280232 loss_3 = 2.1563177 total_loss = 2.3670268
2018-04-13 18:50:36 iter 750 : loss_1 = -0.013461946 loss_2 = 0.23351666 loss_3 = 2.1352973 total_loss = 2.355352
2018-04-13 18:50:46 iter 800 : loss_1 = -0.027532807 loss_2 = 0.23548241 loss_3 = 2.118918 total_loss = 2.3268676
2018-04-13 18:50:57 iter 850 : loss_1 = -0.008855047 loss_2 = 0.22237507 loss_3 = 2.1432023 total_loss = 2.3567224
2018-04-13 18:50:58 iter 859 : loss_1 = -0.04012896 loss_2 = 0.24921879 loss_3 = 2.1141462 total_loss = 2.323236
2018-04-13 18:50:58 end epoch, average loss = 2.3532611596570443
2018-04-13 18:50:58 start epoch 5/100
2018-04-13 18:50:59 iter 1 : loss_1 = -0.020074796 loss_2 = 0.2318038 loss_3 = 2.1240907 total_loss

2018-04-13 19:00:48 iter 250 : loss_1 = -0.01551086 loss_2 = 0.23249622 loss_3 = 1.9861124 total_loss = 2.2030978
2018-04-13 19:00:59 iter 300 : loss_1 = -0.0018654246 loss_2 = 0.21926527 loss_3 = 2.02346 total_loss = 2.2408597
2018-04-13 19:01:09 iter 350 : loss_1 = 0.0051425556 loss_2 = 0.21801524 loss_3 = 2.0309238 total_loss = 2.2540817
2018-04-13 19:01:20 iter 400 : loss_1 = -0.02603564 loss_2 = 0.24226587 loss_3 = 2.0163126 total_loss = 2.2325428
2018-04-13 19:01:30 iter 450 : loss_1 = 0.018486302 loss_2 = 0.21162936 loss_3 = 2.0783114 total_loss = 2.308427
2018-04-13 19:01:41 iter 500 : loss_1 = -0.012742212 loss_2 = 0.22957441 loss_3 = 2.012807 total_loss = 2.229639
2018-04-13 19:01:51 iter 550 : loss_1 = -0.0037573054 loss_2 = 0.2246943 loss_3 = 2.0031836 total_loss = 2.2241206
2018-04-13 19:02:02 iter 600 : loss_1 = -0.0072469325 loss_2 = 0.21943025 loss_3 = 2.0141907 total_loss = 2.226374
2018-04-13 19:02:12 iter 650 : loss_1 = -0.03132675 loss_2 = 0.2370646 loss_3 = 1.95290

2018-04-13 19:11:59 iter 859 : loss_1 = 0.0037964755 loss_2 = 0.20830555 loss_3 = 1.9899135 total_loss = 2.2020154
2018-04-13 19:11:59 end epoch, average loss = 2.204590381371406
2018-04-13 19:11:59 start epoch 12/100
2018-04-13 19:11:59 iter 1 : loss_1 = -0.019462049 loss_2 = 0.22582126 loss_3 = 1.9883912 total_loss = 2.1947503
2018-04-13 19:12:10 iter 50 : loss_1 = -0.014454904 loss_2 = 0.2291849 loss_3 = 1.9781506 total_loss = 2.1928806
2018-04-13 19:12:20 iter 100 : loss_1 = -0.013904506 loss_2 = 0.22119336 loss_3 = 1.973253 total_loss = 2.1805418
2018-04-13 19:12:31 iter 150 : loss_1 = -0.012217127 loss_2 = 0.2192277 loss_3 = 2.0176237 total_loss = 2.2246342
2018-04-13 19:12:41 iter 200 : loss_1 = -0.025270632 loss_2 = 0.23630567 loss_3 = 1.9488223 total_loss = 2.1598573
2018-04-13 19:12:52 iter 250 : loss_1 = -0.027992744 loss_2 = 0.24050418 loss_3 = 1.9443772 total_loss = 2.1568887
2018-04-13 19:13:03 iter 300 : loss_1 = -0.0012562476 loss_2 = 0.21612544 loss_3 = 1.999311 total_

2018-04-13 19:23:18 iter 550 : loss_1 = -0.033322595 loss_2 = 0.24497333 loss_3 = 1.8435125 total_loss = 2.0551634
2018-04-13 19:23:30 iter 600 : loss_1 = 0.0072070984 loss_2 = 0.2113719 loss_3 = 1.9898062 total_loss = 2.2083852
2018-04-13 19:23:41 iter 650 : loss_1 = -0.008256571 loss_2 = 0.22490224 loss_3 = 1.9353552 total_loss = 2.152001
2018-04-13 19:23:53 iter 700 : loss_1 = 0.007913722 loss_2 = 0.20637059 loss_3 = 1.9356592 total_loss = 2.1499436
2018-04-13 19:24:04 iter 750 : loss_1 = 0.00020046005 loss_2 = 0.21247001 loss_3 = 1.9564221 total_loss = 2.1690927
2018-04-13 19:24:15 iter 800 : loss_1 = -0.01684562 loss_2 = 0.23090371 loss_3 = 1.89007 total_loss = 2.1041281
2018-04-13 19:24:27 iter 850 : loss_1 = -0.021455605 loss_2 = 0.23513587 loss_3 = 1.89012 total_loss = 2.1038003
2018-04-13 19:24:29 iter 859 : loss_1 = -0.004628774 loss_2 = 0.21159232 loss_3 = 1.9468842 total_loss = 2.1538477
2018-04-13 19:24:29 end epoch, average loss = 2.160622043298759
2018-04-13 19:24:29 sta

2018-04-13 19:35:50 iter 200 : loss_1 = -0.015965674 loss_2 = 0.2145046 loss_3 = 1.9182682 total_loss = 2.1168072
2018-04-13 19:36:03 iter 250 : loss_1 = -0.009621205 loss_2 = 0.20566913 loss_3 = 2.0228932 total_loss = 2.2189412
2018-04-13 19:36:16 iter 300 : loss_1 = 0.011320075 loss_2 = 0.18608221 loss_3 = 2.0502934 total_loss = 2.2476957
2018-04-13 19:36:29 iter 350 : loss_1 = -0.01118616 loss_2 = 0.21149588 loss_3 = 2.0016527 total_loss = 2.2019625
2018-04-13 19:36:42 iter 400 : loss_1 = 0.0027131303 loss_2 = 0.20021844 loss_3 = 1.9764624 total_loss = 2.179394
2018-04-13 19:36:55 iter 450 : loss_1 = 0.0077846595 loss_2 = 0.19166443 loss_3 = 1.9893612 total_loss = 2.1888103
2018-04-13 19:37:09 iter 500 : loss_1 = -0.022807816 loss_2 = 0.2206494 loss_3 = 1.9727962 total_loss = 2.1706378
2018-04-13 19:37:22 iter 550 : loss_1 = -0.016110841 loss_2 = 0.20824723 loss_3 = 1.9744029 total_loss = 2.1665392
2018-04-13 19:37:35 iter 600 : loss_1 = -0.014000619 loss_2 = 0.20532775 loss_3 = 1.9

2018-04-13 19:50:37 iter 800 : loss_1 = -0.0038941994 loss_2 = 0.19214778 loss_3 = 2.0001504 total_loss = 2.188404
2018-04-13 19:50:52 iter 850 : loss_1 = -0.0023754865 loss_2 = 0.18894058 loss_3 = 2.0305748 total_loss = 2.21714
2018-04-13 19:50:55 iter 859 : loss_1 = -0.0063428194 loss_2 = 0.19309333 loss_3 = 2.0379355 total_loss = 2.224686
2018-04-13 19:50:55 end epoch, average loss = 2.200668978607835
2018-04-13 19:50:55 start epoch 23/100
2018-04-13 19:50:55 iter 1 : loss_1 = 0.005145523 loss_2 = 0.17741421 loss_3 = 2.040583 total_loss = 2.2231426
2018-04-13 19:51:09 iter 50 : loss_1 = -0.01676277 loss_2 = 0.20816994 loss_3 = 1.973248 total_loss = 2.1646552
2018-04-13 19:51:24 iter 100 : loss_1 = -0.022270305 loss_2 = 0.20534329 loss_3 = 1.9724791 total_loss = 2.1555521
2018-04-13 19:51:38 iter 150 : loss_1 = -0.012059224 loss_2 = 0.20498414 loss_3 = 1.9590943 total_loss = 2.1520193
2018-04-13 19:51:53 iter 200 : loss_1 = -0.034936678 loss_2 = 0.21835461 loss_3 = 1.9214648 total_lo

2018-04-13 20:05:58 iter 450 : loss_1 = -0.010209637 loss_2 = 0.204074 loss_3 = 2.0210824 total_loss = 2.2149467
2018-04-13 20:06:14 iter 500 : loss_1 = -0.004381531 loss_2 = 0.19995503 loss_3 = 2.0162354 total_loss = 2.211809
2018-04-13 20:06:29 iter 550 : loss_1 = 0.010647161 loss_2 = 0.17880674 loss_3 = 2.049679 total_loss = 2.239133
2018-04-13 20:06:45 iter 600 : loss_1 = -0.0065267035 loss_2 = 0.19162194 loss_3 = 2.0534637 total_loss = 2.238559
2018-04-13 20:07:00 iter 650 : loss_1 = -0.012475845 loss_2 = 0.20185924 loss_3 = 2.0412145 total_loss = 2.230598


KeyboardInterrupt: 

In [None]:
# Plot the mean training loss of every completed epoch. The x-axis uses
# len(loss_his), not max_epoch: training may stop early (e.g. a
# KeyboardInterrupt), and mismatched x/y lengths would make plot() raise.
fig,ax=plt.subplots()
ax.plot(range(1,len(loss_his)+1),loss_his)
ax.set_xlabel('training epoch')
ax.set_ylabel('loss')
ax.set_title('Mean training loss per epoch')
plt.show()

In [None]:
# Disabled scratch check, kept for reference. It draws samples as
# mean + tf.random_normal from N(0, 1), evaluates -log_prob, sums over axis 1
# and averages. This is a standalone version of a computation similar to
# loglikelihood(). Uncomment to run; otherwise this cell can be deleted.
# mean=tf.zeros((100,2),dtype=tf.float32)
# std=tf.constant([1,1],dtype=tf.float32)
# gaussian=tf.distributions.Normal(loc=mean,scale=std)
# rand=tf.random_normal(shape=(100,2),mean=0,stddev=1)
# sampled=mean+rand
# prob=-gaussian.log_prob(sampled)
# prob=tf.reduce_mean(tf.reduce_sum(prob,1))
# with tf.Session() as sess:
#     out=sess.run([prob])
#     print(out)