In [None]:
# encoding=utf8  
import numpy as np
import tensorflow as tf

def unpickle(file):
  import _pickle as cPickle
  fo = open(file, 'rb')
  dict = cPickle.load(fo,encoding='latin1')
  fo.close()
  if 'data' in dict:
    dict['data'] = dict['data'].reshape((-1, 3, 32, 32)).swapaxes(1, 3).swapaxes(1, 2).reshape(-1, 32*32*3) / 256.

  return dict

In [2]:
def load_data_one(f):
  batch = unpickle(f)
  data = batch['data']
  labels = batch['labels']
  print ("Loading %s: %d" % (f, len(data)))
  return data, labels


In [3]:

def load_data(files, data_dir, label_count):
  data, labels = load_data_one(data_dir + '/' + files[0])
  for f in files[1:]:
    data_n, labels_n = load_data_one(data_dir + '/' + f)
    data = np.append(data, data_n, axis=0)
    labels = np.append(labels, labels_n, axis=0)
  labels = np.array([ [ float(i == label) for i in range(label_count) ] for label in labels ])
  return data, labels

In [4]:
def run_in_batch_avg(session, tensors, batch_placeholders, feed_dict={}, batch_size=200):                              
  res = [ 0 ] * len(tensors)                                                                                           
  batch_tensors = [ (placeholder, feed_dict[ placeholder ]) for placeholder in batch_placeholders ]                    
  total_size = len(batch_tensors[0][1])                                                                                
  batch_count = (total_size + batch_size - 1) / batch_size                                                             
  for batch_idx in range(batch_count):                                                                                
    current_batch_size = None                                                                                          
    for (placeholder, tensor) in batch_tensors:                                                                        
      batch_tensor = tensor[ batch_idx*batch_size : (batch_idx+1)*batch_size ]                                         
      current_batch_size = len(batch_tensor)                                                                           
      feed_dict[placeholder] = tensor[ batch_idx*batch_size : (batch_idx+1)*batch_size ]                               
    tmp = session.run(tensors, feed_dict=feed_dict)                                                                    
    res = [ r + t * current_batch_size for (r, t) in zip(res, tmp) ]                                                   
  return [ r / float(total_size) for r in res ]


In [5]:
def weight_variable(shape):
  initial = tf.truncated_normal(shape, stddev=0.01)
  return tf.Variable(initial)

def bias_variable(shape):
  initial = tf.constant(0.01, shape=shape)
  return tf.Variable(initial)

def conv2d(input, in_features, out_features, kernel_size, with_bias=False):
  W = weight_variable([ kernel_size, kernel_size, in_features, out_features ])
  conv = tf.nn.conv2d(input, W, [ 1, 1, 1, 1 ], padding='SAME')
  if with_bias:
    return conv + bias_variable([ out_features ])
  return conv

def batch_activ_conv(current, in_features, out_features, kernel_size, is_training, keep_prob):
  current = tf.contrib.layers.batch_norm(current, scale=True, is_training=is_training, updates_collections=None)
  current = tf.nn.relu(current)
  current = conv2d(current, in_features, out_features, kernel_size)
  current = tf.nn.dropout(current, keep_prob)
  return current

def block(input, layers, in_features, growth, is_training, keep_prob):
  current = input
  features = in_features
  for idx in range(layers):
    tmp = batch_activ_conv(current, features, growth, 3, is_training, keep_prob)
    current = tf.concat((current, tmp),3)
    features += growth
  return current, features

def avg_pool(input, s):
  return tf.nn.avg_pool(input, [ 1, s, s, 1 ], [1, s, s, 1 ], 'VALID')



In [6]:
data_dir = './data'
image_size = 32
image_dim = image_size * image_size * 3
# meta = unpickle(data_dir + '/batches.meta')
# label_names = meta['label_names']
# label_count = len(label_names)
label_count = 10
# train_files = [ 'data_batch_%d' % d for d in range(1, 6) ]
# train_data, train_labels = load_data(train_files, data_dir, label_count)
# pi = np.random.permutation(len(train_data))
# train_data, train_labels = train_data[pi], train_labels[pi]
# test_data, test_labels = load_data([ 'test_batch' ], data_dir, label_count)
# print ("Train:", np.shape(train_data), np.shape(train_labels))
# print ("Test:", np.shape(test_data), np.shape(test_labels))
# data = { 'train_data': train_data,
#   'train_labels': train_labels,
#   'test_data': test_data,
#   'test_labels': test_labels }
depth = 40

In [None]:

weight_decay = 1e-4
layers = int((depth - 4) / 3)
graph = tf.Graph()

xs = tf.placeholder("float", shape=[None, image_dim])
ys = tf.placeholder("float", shape=[None, label_count])
lr = tf.placeholder("float", shape=[])
keep_prob = tf.placeholder(tf.float32)
is_training = tf.placeholder("bool", shape=[])


current = tf.reshape(xs, [ -1, 32, 32, 3 ])
current = conv2d(current, 3, 16, 3)

current, features = block(current, layers, 16, 12, is_training, keep_prob)
current = batch_activ_conv(current, features, features, 1, is_training, keep_prob)
current = avg_pool(current, 2)
current, features = block(current, layers, features, 12, is_training, keep_prob)
current = batch_activ_conv(current, features, features, 1, is_training, keep_prob)
current = avg_pool(current, 2)
current, features = block(current, layers, features, 12, is_training, keep_prob)

current = tf.contrib.layers.batch_norm(current, scale=True, is_training=is_training, updates_collections=None)
current = tf.nn.relu(current)
current = avg_pool(current, 8)
final_dim = features
current = tf.reshape(current, [ -1, final_dim ])
Wfc = weight_variable([ final_dim, label_count ])
bfc = bias_variable([ label_count ])
ys_ = tf.nn.softmax( tf.matmul(current, Wfc) + bfc )

cross_entropy = -tf.reduce_mean(ys * tf.log(ys_ + 1e-12))
l2 = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
train_step = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(cross_entropy + l2 * weight_decay)
correct_prediction = tf.equal(tf.argmax(ys_, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    


In [None]:
para_dict={}
for k in tf.global_variables():
    if k not in tf.contrib.framework.get_variables_by_suffix('Momentum'):
            para_dict[k.name[:-2]] = k

In [None]:
sess=tf.InteractiveSession()
saver = tf.train.Saver(para_dict)
#saver.restore(sess,'./inqmodel/stage2/64pinq80/64pinq80ok_93149_7.ckpt')
#saver.restore(sess,'./modellog/weightonlypara93.ckpt')
saver.restore(sess,'./inqmodel/stage2/inq16_97/inq1697_92729_5.ckpt')
#saver.restore(sess,'./prunemodel/stage2/inc100adj/prune100ar_92969_10ok.ckpt')

In [None]:
import config


In [11]:
def apply_inq(weights, inq_dict):
   
    for target in config.all_para:
        wl = target
        bit = config.inq_para[wl]
        # Get target layer's weights
        weight_obj = weights[wl]
        weight_arr = weight_obj.eval()
        
        
        weight_rest = np.reshape(weight_arr,[-1])
        dic_tem = np.reshape(inq_dict[wl],[-1])
        idx_rest = np.flip(np.argsort(abs(np.reshape(weight_rest,[-1]))),0)
        
        num_prune = int(len(weight_rest)*config.inqpercen_para[wl])
       # print(wl,sum(np.reshape(inq_dict[wl],-1)))
     
       # print('prune',num_prune)
        weight_toINQ = weight_rest[idx_rest[:num_prune]] 
        n1 = (np.floor(np.log2(max(abs(np.reshape(weight_arr,[-1])))*4/3)))
        n2 = n1 +1 - bit/4
        print(n1,n2,n1-n2)
        upper_bound = 2**(np.floor(np.log2(max(abs(np.reshape(weight_arr,[-1])))*4/3)))
        lower_bound = 2**(n1 +1 - bit/4)
        mistake = 2**(upper_bound +1 - bit/4)
        print(wl, lower_bound , mistake, lower_bound/mistake)
        weight_toINQ[abs(weight_toINQ) < lower_bound] = 0
        weight_toINQ[weight_toINQ != 0] = 2**(np.floor(np.log2(abs(weight_toINQ[weight_toINQ != 0]*4/3))))*np.sign(weight_toINQ[weight_toINQ != 0])

        
        weight_rest[idx_rest[:num_prune]] = weight_toINQ
        
        weight_arr =  np.reshape(weight_rest,np.shape(weight_arr))
        dic_tem [idx_rest[:num_prune]] =  np.zeros_like(dic_tem [idx_rest[:num_prune]])
        inq_dict[wl] = np.reshape(dic_tem,np.shape(inq_dict[wl]))
        # Apply pruning
       # print('left',sum(np.reshape(inq_dict[wl],-1)))

        # Store pruned weights as tensorflow objects
        sess.run(weight_obj.assign(weight_arr))

    return inq_dict

In [12]:
prune_dict = {}
for target in config.all_para:
    wl =target
    weight_obj = para_dict[wl]
    prune_dict[wl] = np.ones_like(weight_obj.eval())

In [13]:
prune_dict = apply_inq(para_dict, prune_dict)

-1.0 -4.0 3.0
Variable 0.0625 0.176776695297 0.353553390593
-2.0 -5.0 3.0
Variable_1 0.03125 0.148650889375 0.210224103813
-3.0 -6.0 3.0
Variable_2 0.015625 0.136313466583 0.114625505401
-3.0 -6.0 3.0
Variable_3 0.015625 0.136313466583 0.114625505401
-3.0 -6.0 3.0
Variable_4 0.015625 0.136313466583 0.114625505401
-3.0 -6.0 3.0
Variable_5 0.015625 0.136313466583 0.114625505401
-3.0 -6.0 3.0
Variable_6 0.015625 0.136313466583 0.114625505401
-3.0 -6.0 3.0
Variable_7 0.015625 0.136313466583 0.114625505401
-3.0 -6.0 3.0
Variable_8 0.015625 0.136313466583 0.114625505401
-3.0 -6.0 3.0
Variable_9 0.015625 0.136313466583 0.114625505401
-3.0 -6.0 3.0
Variable_10 0.015625 0.136313466583 0.114625505401
-4.0 -7.0 3.0
Variable_11 0.0078125 0.130534222803 0.0598502050437
-3.0 -6.0 3.0
Variable_12 0.015625 0.136313466583 0.114625505401
-3.0 -6.0 3.0
Variable_13 0.015625 0.136313466583 0.114625505401
-4.0 -7.0 3.0
Variable_14 0.0078125 0.130534222803 0.0598502050437
-3.0 -6.0 3.0
Variable_15 0.015625 0

In [14]:
saver.save(sess,'./inqmodel/stage1/inq1697.ckpt')

'./inqmodel/stage1/inq1697.ckpt'

In [15]:
import pickle
# create dict
# save dict
f1 = open("C:/Users/lhlne/Desktop/project/densenet/inqmodel/stage1/inq1697.txt","wb")
pickle.dump(prune_dict, f1)
f1.close()
# load dict
f2 = open("C:/Users/lhlne/Desktop/project/densenet/inqmodel/stage1/inq1697.txt","rb")
load_list = pickle.load(f2)
f2.close()
# print 
print(load_list)

{'Variable': array([[[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.]],

        [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,
           0.,  0.,  0.]],

        [[ 1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  0.,
           1.,  1.,  1.],
         [ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.]]],


       [[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 

In [16]:
config.inqpercen_para

{'Variable': 0.975,
 'Variable_1': 0.975,
 'Variable_10': 0.975,
 'Variable_11': 0.975,
 'Variable_12': 0.975,
 'Variable_13': 0.975,
 'Variable_14': 0.975,
 'Variable_15': 0.975,
 'Variable_16': 0.975,
 'Variable_17': 0.975,
 'Variable_18': 0.975,
 'Variable_19': 0.975,
 'Variable_2': 0.975,
 'Variable_20': 0.975,
 'Variable_21': 0.975,
 'Variable_22': 0.975,
 'Variable_23': 0.975,
 'Variable_24': 0.975,
 'Variable_25': 0.975,
 'Variable_26': 0.975,
 'Variable_27': 0.975,
 'Variable_28': 0.975,
 'Variable_29': 0.975,
 'Variable_3': 0.975,
 'Variable_30': 0.975,
 'Variable_31': 0.975,
 'Variable_32': 0.975,
 'Variable_33': 0.975,
 'Variable_34': 0.975,
 'Variable_35': 0.975,
 'Variable_36': 0.975,
 'Variable_37': 0.975,
 'Variable_38': 0.975,
 'Variable_39': 0.975,
 'Variable_4': 0.975,
 'Variable_5': 0.975,
 'Variable_6': 0.975,
 'Variable_7': 0.975,
 'Variable_8': 0.975,
 'Variable_9': 0.975}

In [17]:
config.inq_para

{'Variable': 16,
 'Variable_1': 16,
 'Variable_10': 16,
 'Variable_11': 16,
 'Variable_12': 16,
 'Variable_13': 16,
 'Variable_14': 16,
 'Variable_15': 16,
 'Variable_16': 16,
 'Variable_17': 16,
 'Variable_18': 16,
 'Variable_19': 16,
 'Variable_2': 16,
 'Variable_20': 16,
 'Variable_21': 16,
 'Variable_22': 16,
 'Variable_23': 16,
 'Variable_24': 16,
 'Variable_25': 16,
 'Variable_26': 16,
 'Variable_27': 16,
 'Variable_28': 16,
 'Variable_29': 16,
 'Variable_3': 16,
 'Variable_30': 16,
 'Variable_31': 16,
 'Variable_32': 16,
 'Variable_33': 16,
 'Variable_34': 16,
 'Variable_35': 16,
 'Variable_36': 16,
 'Variable_37': 16,
 'Variable_38': 16,
 'Variable_39': 16,
 'Variable_4': 16,
 'Variable_5': 16,
 'Variable_6': 16,
 'Variable_7': 16,
 'Variable_8': 16,
 'Variable_9': 16}

In [17]:
f2 = open("C:/Users/lhlne/Desktop/project/densenet/inqmodel/stage1/inqfinal64inc20adj.txt","rb")
load_list = pickle.load(f2)
f2.close()
# print 
print(load_list)

{'Variable': array([[[[ 0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
           1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,
           1.,  1.,  1.],
         [ 0.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,
           1.,  1.,  1.]],

        [[ 0.,  0.,  1.,  1.,  0.,  1.,  0.,  0.,  1.,  1.,  1.,  1.,  0.,
           1.,  1.,  1.],
         [ 1.,  0.,  1.,  1.,  1.,  1.,  0.,  0.,  1.,  1.,  1.,  1.,  1.,
           1.,  1.,  0.],
         [ 0.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  1.,  1.,  1.,  1.,  0.,
           1.,  1.,  1.]],

        [[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,
           1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
           1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  1.,  1.,  0.,
           1.,  1.,  1.]]],


       [[[ 0.,  1.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 

In [11]:
config.inq_para

{'Variable': 32,
 'Variable_1': 32,
 'Variable_10': 32,
 'Variable_11': 32,
 'Variable_12': 32,
 'Variable_13': 32,
 'Variable_14': 32,
 'Variable_15': 32,
 'Variable_16': 32,
 'Variable_17': 32,
 'Variable_18': 32,
 'Variable_19': 32,
 'Variable_2': 32,
 'Variable_20': 32,
 'Variable_21': 32,
 'Variable_22': 32,
 'Variable_23': 32,
 'Variable_24': 32,
 'Variable_25': 32,
 'Variable_26': 32,
 'Variable_27': 32,
 'Variable_28': 32,
 'Variable_29': 32,
 'Variable_3': 32,
 'Variable_30': 32,
 'Variable_31': 32,
 'Variable_32': 32,
 'Variable_33': 32,
 'Variable_34': 32,
 'Variable_35': 32,
 'Variable_36': 32,
 'Variable_37': 32,
 'Variable_38': 32,
 'Variable_39': 32,
 'Variable_4': 32,
 'Variable_5': 32,
 'Variable_6': 32,
 'Variable_7': 32,
 'Variable_8': 32,
 'Variable_9': 32}

In [19]:
para_dict['Variable_3'].eval()

array([[[[  4.43102850e-04,   2.34267558e-04,   3.21055355e-04, ...,
           -2.68972857e-04,   6.01321110e-04,  -5.89598611e-04],
         [  1.70976692e-03,   1.12546142e-03,  -8.56381084e-04, ...,
           -1.93680215e-04,   1.48893334e-04,   7.44329416e-04],
         [ -1.18282827e-04,   2.84192432e-03,  -1.23337284e-03, ...,
            7.04517151e-05,  -9.96216899e-04,   1.56250000e-02],
         ..., 
         [ -3.12500000e-02,  -1.13697106e-03,  -6.25000000e-02, ...,
            2.90875643e-04,   2.84710107e-03,  -6.25000000e-02],
         [ -3.12500000e-02,   1.04962173e-03,  -3.12500000e-02, ...,
           -1.23348983e-03,   1.24882394e-03,  -1.90443161e-03],
         [  2.55485508e-03,  -8.21402791e-05,  -7.52473308e-04, ...,
            2.09355610e-04,  -1.07425428e-03,   4.93022126e-05]],

        [[ -1.02333703e-04,  -9.37894278e-04,   9.87505773e-04, ...,
           -1.27386840e-04,  -4.72871616e-04,   4.66628320e-04],
         [  1.49076537e-03,   5.39196248e-04,

In [17]:
from matplotlib import pyplot as plt
for i in para_dict.items():
    if 'Variable' in i[0]:
        a = para_dict[i[0]].eval()
        b = np.reshape(a,[-1])
        c = sum(b ==0)
        d = np.shape(b)
   #     e = np.percentile(b, 75)
        print(i[0])
 #       print(e)
   #     print(int(np.log10(abs(e))))
        print('percentage',c/d)
#         print(sum(b ==0))
#         plt.hist(abs(b)) 
#         plt.title(i[0]) 
#         plt.show()

Variable
percentage [ 0.]
Variable_1
percentage [ 0.]
Variable_2
percentage [ 0.]
Variable_3
percentage [ 0.]
Variable_4
percentage [ 0.]
Variable_5
percentage [ 0.]
Variable_6
percentage [ 0.]
Variable_7
percentage [ 0.]
Variable_8
percentage [ 0.]
Variable_9
percentage [ 0.]
Variable_10
percentage [ 0.]
Variable_11
percentage [ 0.]
Variable_12
percentage [ 0.]
Variable_13
percentage [ 0.]
Variable_14
percentage [ 0.]
Variable_15
percentage [ 0.]
Variable_16
percentage [ 0.]
Variable_17
percentage [ 0.]
Variable_18
percentage [ 0.]
Variable_19
percentage [ 0.]
Variable_20
percentage [ 0.]
Variable_21
percentage [ 0.]
Variable_22
percentage [ 0.]
Variable_23
percentage [ 0.]
Variable_24
percentage [ 0.]
Variable_25
percentage [ 0.]
Variable_26
percentage [ 0.]
Variable_27
percentage [ 0.]
Variable_28
percentage [ 0.]
Variable_29
percentage [ 0.]
Variable_30
percentage [ 0.]
Variable_31
percentage [ 0.]
Variable_32
percentage [ 0.]
Variable_33
percentage [ 0.]
Variable_34
percentage [ 0

In [30]:
a = para_dict['Variable_27'].eval()
b = np.reshape(a,[-1])
print(sum(b <=0.00001))
print(np.shape(b))


25166
(32832,)


In [27]:
parm_dict={}
for k in tf.trainable_variables():
    parm_dict[k.name] = k
with tf.Session(graph=graph) as session:
    saver= tf.train.Saver(parm_dict) 
    saver.restore(session,"weights_densenet")

Variable:0
BatchNorm/beta:0
BatchNorm/gamma:0
Variable_1:0
BatchNorm_1/beta:0
BatchNorm_1/gamma:0
Variable_2:0
BatchNorm_2/beta:0
BatchNorm_2/gamma:0
Variable_3:0
BatchNorm_3/beta:0
BatchNorm_3/gamma:0
Variable_4:0
BatchNorm_4/beta:0
BatchNorm_4/gamma:0
Variable_5:0
BatchNorm_5/beta:0
BatchNorm_5/gamma:0
Variable_6:0
BatchNorm_6/beta:0
BatchNorm_6/gamma:0
Variable_7:0
BatchNorm_7/beta:0
BatchNorm_7/gamma:0
Variable_8:0
BatchNorm_8/beta:0
BatchNorm_8/gamma:0
Variable_9:0
BatchNorm_9/beta:0
BatchNorm_9/gamma:0
Variable_10:0
BatchNorm_10/beta:0
BatchNorm_10/gamma:0
Variable_11:0
BatchNorm_11/beta:0
BatchNorm_11/gamma:0
Variable_12:0
BatchNorm_12/beta:0
BatchNorm_12/gamma:0
Variable_13:0
BatchNorm_13/beta:0
BatchNorm_13/gamma:0
Variable_14:0
BatchNorm_14/beta:0
BatchNorm_14/gamma:0
Variable_15:0
BatchNorm_15/beta:0
BatchNorm_15/gamma:0
Variable_16:0
BatchNorm_16/beta:0
BatchNorm_16/gamma:0
Variable_17:0
BatchNorm_17/beta:0
BatchNorm_17/gamma:0
Variable_18:0
BatchNorm_18/beta:0
BatchNorm_18

In [7]:
depth = 40
weight_decay = 1e-4
layers = int((depth - 4) / 3)

xs = tf.placeholder("float", shape=[None, image_dim])
ys = tf.placeholder("float", shape=[None, label_count])
lr = tf.placeholder("float", shape=[])
keep_prob = tf.placeholder(tf.float32)
is_training = tf.placeholder("bool", shape=[])


current = tf.reshape(xs, [ -1, 32, 32, 3 ])
current = conv2d(current, 3, 16, 3)

current, features = block(current, layers, 16, 12, is_training, keep_prob)
current = batch_activ_conv(current, features, features, 1, is_training, keep_prob)
current = avg_pool(current, 2)
current, features = block(current, layers, features, 12, is_training, keep_prob)
current = batch_activ_conv(current, features, features, 1, is_training, keep_prob)
current = avg_pool(current, 2)
current, features = block(current, layers, features, 12, is_training, keep_prob)

current = tf.contrib.layers.batch_norm(current, scale=True, is_training=is_training, updates_collections=None)
current = tf.nn.relu(current)
current = avg_pool(current, 8)
final_dim = features
current = tf.reshape(current, [ -1, final_dim ])
Wfc = weight_variable([ final_dim, label_count ])
bfc = bias_variable([ label_count ])
ys_ = tf.nn.softmax( tf.matmul(current, Wfc) + bfc )

cross_entropy = -tf.reduce_mean(ys * tf.log(ys_ + 1e-12))
l2 = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
train_step = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(cross_entropy + l2 * weight_decay)
correct_prediction = tf.equal(tf.argmax(ys_, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

sess = tf.Session()
batch_size = 64
learning_rate = 0.1
tf.global_variables_initializer()
    # saver = tf.train.Saver()
    # train_data, train_labels = data['train_data'], data['train_labels']
    # batch_count = int(len(train_data) / batch_size)
    # batches_data = np.split(train_data[:batch_count * batch_size], batch_count)
    # batches_labels = np.split(train_labels[:batch_count * batch_size], batch_count)
    # print ("Batch per epoch: ", batch_count)
    # for epoch in range(1, 1+300):
    #   if epoch == 150: learning_rate = 0.01
    #   if epoch == 225: learning_rate = 0.001
    #   for batch_idx in range(batch_count):
    #     xs_, ys_ = batches_data[batch_idx], batches_labels[batch_idx]
    #     batch_res = session.run([ train_step, cross_entropy, accuracy ],
    #       feed_dict = { xs: xs_, ys: ys_, lr: learning_rate, is_training: True, keep_prob: 0.8 })
    #     if batch_idx % 100 == 0: print (epoch, batch_idx, batch_res[1:])

    #   save_path = saver.save(session, 'densenet_%d.ckpt' % epoch)
    #   test_results = run_in_batch_avg(session, [ cross_entropy, accuracy ], [ xs, ys ],
    #       feed_dict = { xs: data['test_data'], ys: data['test_labels'], is_training: False, keep_prob: 1. })
    #   print (epoch, batch_res[1:], test_results)


<tf.Operation 'init' type=NoOp>