In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import random
import math
import time
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST through the TF tutorial helper, using "mnist/" as the data directory.
# one_hot=True makes every label a 10-element one-hot vector; the splits are
# read below as mnist.train, mnist.validation and mnist.test.
mnist = input_data.read_data_sets("mnist/", one_hot=True)

In [None]:
# Number of examples in each MNIST split.
# All three are plain integer counts (previously num_val / num_test held the
# `.images.shape` tuple, which is not a count and was inconsistent with num_train).
num_train=mnist.train.num_examples
num_val=mnist.validation.num_examples
num_test=mnist.test.num_examples

# Baseline: CNN

In [None]:
# Build the baseline CNN: two conv->BN->ReLU->max-pool stages followed by a
# dense->BN->ReLU layer and a 10-way linear classifier.

# Start from an empty graph so that re-running this cell does not stack a second
# copy of every op/variable on top of the previous one.
tf.reset_default_graph()

# Inputs: NHWC images, one-hot labels, and a flag switching batch-norm between
# batch statistics (training) and moving averages (evaluation).
X=tf.placeholder(dtype=tf.float32,shape=[None,28,28,1],name='X')
y=tf.placeholder(dtype=tf.int64,shape=[None,10],name='y')
is_training=tf.placeholder(dtype=tf.bool,name='is_training')

# conv stage 1
cnn_out1=tf.layers.conv2d(X,128,kernel_size=3,strides=(1, 1),padding='same')
bn_out1=tf.layers.batch_normalization(cnn_out1,axis=3,training=is_training)
relu_out1=tf.nn.relu(bn_out1)
pool_out1=tf.layers.max_pooling2d(relu_out1,[2,2],[2,2])

# conv stage 2
cnn_out2=tf.layers.conv2d(pool_out1,128,kernel_size=3,strides=(1, 1),padding='same')
bn_out2=tf.layers.batch_normalization(cnn_out2,axis=3,training=is_training)
relu_out2=tf.nn.relu(bn_out2)
pool_out2=tf.layers.max_pooling2d(relu_out2,[2,2],[2,2])

# classifier head
flt=tf.layers.flatten(pool_out2)
out1=tf.layers.dense(flt,1024)
bn_out3=tf.layers.batch_normalization(out1,axis=1,training=is_training)
# The normalised activations are now actually fed forward (bn_out3 used to be
# dead code), with a ReLU so the two dense layers do not collapse into one
# linear map.
relu_out3=tf.nn.relu(bn_out3)
score=tf.layers.dense(relu_out3,10)

# accuracy of the arg-max prediction against the one-hot labels
predictions = tf.argmax(score, 1)
correct_predictions = tf.equal(predictions, tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"))

loss=tf.losses.softmax_cross_entropy(onehot_labels=y,logits=score)
optimizier=tf.train.AdamOptimizer(learning_rate=1e-5)
# tf.layers.batch_normalization registers its moving mean/variance updates in
# UPDATE_OPS; they only run if the train op depends on them. Without this the
# moving statistics stay at their initial values and evaluation with
# is_training=False normalises with the wrong statistics.
update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_step = optimizier.minimize(loss)

In [None]:
# Training schedule for the CNN baseline.
batch_size=8       # examples per mini-batch, used by train() and eval()
print_every=625    # train() logs the loss every this many iterations
max_epoch=1        # number of passes of the epoch loop

def train():
    """Run num_train//batch_size optimisation steps on mini-batches from mnist.train.

    Each step feeds one batch (reshaped to N x 28 x 28 x 1) with is_training=True
    and runs `train_step`. The current loss is printed with a timestamp on the
    first step, on every `print_every`-th step and on the last step.
    Relies on the module-level `sess`, placeholders and hyperparameters.
    """
    total_steps=num_train//batch_size
    last_step=total_steps-1
    for step in range(total_steps):
        batch_images,batch_labels=mnist.train.next_batch(batch_size)
        feed={X:batch_images.reshape(-1,28,28,1),y:batch_labels,is_training:True}
        batch_loss,_ = sess.run([loss,train_step],feed_dict=feed)
        should_log=step==0 or (step+1)%print_every==0 or step==last_step
        if not should_log:
            continue
        timestamp=time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))
        print(timestamp,
              'iteration %d/%d:' % (step+1,total_steps),'current training loss = %f' % (batch_loss))
            
def eval(dataset,num_iteration):
    """Average loss and accuracy over `num_iteration` mini-batches drawn from `dataset`.

    Batches of `batch_size` examples are fetched with dataset.next_batch and run
    with is_training=False. Returns the tuple (mean_loss, mean_accuracy), where
    each mean is the plain average of the per-batch values.
    Relies on the module-level `sess`, `loss`, `accuracy` and placeholders.
    """
    loss_sum=0
    accuracy_sum=0
    for _ in range(num_iteration):
        batch_images,batch_labels=dataset.next_batch(batch_size)
        feed={X:batch_images.reshape(-1,28,28,1),y:batch_labels,is_training:False}
        batch_loss,batch_accuracy = sess.run([loss,accuracy],feed_dict=feed)
        loss_sum=loss_sum+batch_loss
        accuracy_sum=accuracy_sum+batch_accuracy
    return loss_sum/num_iteration,accuracy_sum/num_iteration
    
# Accuracy recorded after every epoch for each split; read by the plotting cell.
acc_train_his,acc_val_his,acc_test_his=[],[],[]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(max_epoch):
        epoch_no=epoch+1
        print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())),'start epoch %d/%d:' % (epoch_no,max_epoch))
        train()
        # Evaluate on a fixed number of mini-batches per split.
        loss_train,acc_train=eval(mnist.train,2000)
        loss_val,acc_val=eval(mnist.validation,500)
        loss_test,acc_test=eval(mnist.test,1000)
        for history,value in ((acc_train_his,acc_train),(acc_val_his,acc_val),(acc_test_his,acc_test)):
            history.append(value)
        accuracy_percent=(acc_train*100.0,acc_val*100.0,acc_test*100.0)
        print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())),'end epoch %d/%d:' % (epoch_no,max_epoch),
             'acc_train=%.3f%% acc_val=%.3f%% acc_test=%.3f%%' % accuracy_percent)

In [None]:
# Plot the per-epoch accuracy of the CNN baseline on the three splits.
# The x positions are derived from the recorded history rather than from the
# global max_epoch, which a later cell reassigns (re-running this cell would
# otherwise fail with an x/y length mismatch). Epochs are numbered from 1 to
# match the training log.
epochs=range(1,len(acc_train_his)+1)
plt.figure(1)
# 'o' markers keep the curves visible even when only a single epoch was run
# (a bare line through one point draws nothing).
ptr,=plt.plot(epochs,acc_train_his,'ro-')
pva,=plt.plot(epochs,acc_val_his,'bo-')
pte,=plt.plot(epochs,acc_test_his,'yo-')
plt.xlabel('training epoch')
plt.ylabel('accuracy')
plt.title('accuracy on three sets')
plt.legend((ptr,pva,pte),('train','validation','test'))
plt.show()

# RAM

In [None]:
# parameters of the recurrent attention model (RAM)

# data / batching
img_size=28                  # side length of the input images
batch_size=64                # distinct images per training step
MC_test=128                  # copies of each image fed per step (Monte Carlo replicas)
tot_size=batch_size*MC_test  # rows actually fed to the graph per step

# network sizes
sensor_unit=256              # units of the glimpse-network dense layers
lstm_size=256                # units of the LSTM cell

# attention policy
N_glimpse=10                 # glimpses taken per image
loc_std=0.2                  # std of the Gaussian the next location is sampled from

In [None]:
class Glimpse_Network():
    """Glimpse network of the attention model.

    Crops square patches of several sizes around a location, resizes them to a
    common resolution, and fuses the flattened patches with the location through
    two dense layers followed by a ReLU.
    """

    def __init__(self):
        # Side lengths (in pixels) of the square patches cropped at each location.
        # The attribute spelling is kept as-is since it is visible to callers.
        self.glimspe_size=[5,9,13]
        # Every patch is resized to concat_size x concat_size before stacking.
        self.concat_size=13
        # Dense projections of the flattened patches and of the location.
        self.img_net=tf.layers.Dense(units=sensor_unit,name='glimpse_net/img_net')
        self.loc_net=tf.layers.Dense(units=sensor_unit,name='glimpse_net/loc_net')

    def glimpse_sensor(self,image,loc):
        """Crop one patch per configured size at `loc` and stack them along axis 3.

        Each patch is bilinearly resized to concat_size x concat_size first, so
        the result has one group of channels per glimpse size.
        """
        resized_patches=[]
        for patch_size in self.glimspe_size:
            patch=tf.image.extract_glimpse(input=image,size=[patch_size,patch_size],offsets=loc)
            resized_patches.append(tf.image.resize_bilinear(patch,[self.concat_size,self.concat_size]))
        return tf.concat(values=resized_patches,axis=3)

    def forward(self,image,loc):
        """Return relu(img_net(flattened patches) + loc_net(loc)) for `image` at `loc`."""
        # batch x 13 x 13 x 3 for single-channel input (three sizes stacked on axis 3)
        patch_stack=self.glimpse_sensor(image,loc)
        image_feature=self.img_net(inputs=tf.layers.flatten(patch_stack))
        location_feature=self.loc_net(inputs=loc)
        return tf.nn.relu(image_feature+location_feature)

In [None]:
# Discard whatever was built in the default graph before (e.g. the CNN baseline).
tf.reset_default_graph()

# Inputs: NHWC images and one-hot labels.
X = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1])
y = tf.placeholder(dtype=tf.int64, shape=[None, 10])

# The first glimpse location of every row is drawn uniformly from [-1, 1) x [-1, 1).
start_location = tf.random_uniform(shape=[tot_size, 2], minval=-1.0, maxval=1.0)
gNet = Glimpse_Network()

# Recurrent core and its all-zero initial state for tot_size rows.
lstm_cell = tf.contrib.rnn.LSTMCell(num_units=lstm_size)
state = lstm_cell.zero_state(batch_size=tot_size, dtype=tf.float32)

# Heads applied to the LSTM output: next-location mean (2 values),
# scalar baseline, and 10-way class scores.
emission_net = tf.layers.Dense(units=2, name='emission_net')
baseline_net = tf.layers.Dense(units=1, name='baseline_net')
predict_net = tf.layers.Dense(units=10, name='predict_net')

def loglikelihood(sample,mean):
    """Return the NEGATIVE log-likelihood of `sample` under a Gaussian centred at `mean`.

    The Gaussian uses a standard deviation of `loc_std` for both coordinates; the
    per-coordinate negative log-densities are summed over axis 1, giving one
    value per row.
    """
    per_dim_std=tf.constant([loc_std,loc_std])
    log_density=tf.distributions.Normal(loc=mean,scale=per_dim_std).log_prob(sample)
    return tf.reduce_sum(-log_density,axis=1)
    
# Unroll the recurrent attention model for N_glimpse steps and build the
# hybrid loss (REINFORCE for the location policy + baseline regression +
# classification cross-entropy).
loc_his=[]            # location used for the glimpse at every step
loglikelihood_his=[]  # negative log-likelihood of the location sampled at every step
baseline_his=[]       # baseline (expected reward) predicted at every step
normalized_loc=start_location
for ng in range(N_glimpse):
    loc_his.append(normalized_loc)

    # extract glimpse
    glimpses_out=gNet.forward(X,normalized_loc)

    # RNN
    lstm_output,state=lstm_cell(glimpses_out,state)

    # emit mean of location
    loc_mean=emission_net(inputs=lstm_output)

    # sample next location by gaussian distribution centered at loc_mean.
    # The sample must be treated as a constant: tf.random_normal computes
    # noise*stddev+mean and is therefore differentiable w.r.t. loc_mean, which
    # makes d log p(sample|mean)/d mean cancel to exactly zero (no REINFORCE
    # signal) and leaks a back-prop path through the sampled location.
    loc_sample=tf.stop_gradient(
        tf.random_normal(shape=(tot_size,2),mean=loc_mean,stddev=loc_std))

    # calculate the -loglikelihood of the sampled position
    llh=loglikelihood(loc_sample,loc_mean)
    loglikelihood_his.append(llh)
    # NOTE(review): the location sampled at the last step is never used for a
    # glimpse, yet its likelihood is still reinforced below - confirm intended.

    # normalize the location for next input
    normalized_loc=tf.tanh(loc_sample)

    # output the per-step baseline predicted from the current LSTM output
    baseline=baseline_net(inputs=lstm_output)
    # squeeze only the unit axis so the batch axis survives whatever its size
    baseline_his.append(tf.squeeze(baseline,axis=1))

# pack data for calculation: both become [N_glimpse, tot_size]
baseline_his=tf.stack(baseline_his)
loglikelihood_his=tf.stack(loglikelihood_his)

# make prediction from the LSTM output of the last glimpse
score=predict_net(inputs=lstm_output)
prediction=tf.argmax(score,1)

# calculate reward (1 if the prediction is correct, else 0), do variance
# reduction and calculate reinforced loglikelihood; reward ([tot_size])
# broadcasts against the per-step baselines
reward=tf.cast(tf.equal(prediction,tf.argmax(y,1)),dtype=tf.float32)
reduce_var_reward=reward-tf.stop_gradient(baseline_his)
# loglikelihood_his holds NEGATIVE log-likelihoods, so minimising this term
# raises the likelihood of locations with above-baseline reward
reinforce_llh=tf.reduce_mean(loglikelihood_his*reduce_var_reward)

# regression baseline towards reward
baseline_mse=tf.reduce_mean(tf.square(reward-baseline_his))

# softmax to output (tf.losses.softmax_cross_entropy already returns a scalar)
softmax_loss=tf.losses.softmax_cross_entropy(onehot_labels=y,logits=score)

# summarize loss
loss=reinforce_llh+baseline_mse+softmax_loss


optimizier=tf.train.AdamOptimizer(learning_rate=1e-5)
train_step = optimizier.minimize(loss)

In [None]:
# Train the recurrent attention model.
max_epoch=100
num_iteration=num_train//batch_size
# Log the losses only every `log_every` iterations (plus the first and last of
# each epoch) instead of on every step, which flooded the cell output.
log_every=100

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for epoch in range(max_epoch):
        print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())),'start epoch %d/%d' % (epoch+1,max_epoch))
        for it in range(num_iteration):
            images,labels=mnist.train.next_batch(batch_size)
            # prepare data for monte carlo test: repeat the batch MC_test times
            # so the graph receives exactly tot_size rows
            images=np.tile(images,(MC_test,1))
            labels=np.tile(labels,(MC_test,1))
            feed_dict={X:images.reshape(tot_size,28,28,1),y:labels}
            loss_1,loss_2,loss_3,loss_out,_=sess.run([reinforce_llh,baseline_mse,softmax_loss,loss,train_step],
                                                     feed_dict=feed_dict)
            if it==0 or (it+1)%log_every==0 or it==num_iteration-1:
                print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())),
                      'iteration %d/%d:' % (it+1,num_iteration),
                      'loss_1 =',loss_1,'loss_2 =',loss_2,'loss_3 =',loss_3,'total_loss =',loss_out)
        print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())),'end epoch')

In [None]:
# mean=tf.zeros((100,2),dtype=tf.float32)
# std=tf.constant([1,1],dtype=tf.float32)
# gaussian=tf.distributions.Normal(loc=mean,scale=std)
# rand=tf.random_normal(shape=(100,2),mean=0,stddev=1)
# sampled=mean+rand
# prob=-gaussian.log_prob(sampled)
# prob=tf.reduce_mean(tf.reduce_sum(prob,1))
# with tf.Session() as sess:
#     out=sess.run([prob])
#     print(out)