In [5]:
import pickle

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from ResNet import ResNet
from utils import load_cifar100

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

Using TensorFlow backend.


In [2]:
class Foo:
    """Plain attribute container passed to ``ResNet(sess, args)`` as ``args``.

    Every constructor argument is stored unchanged on an attribute of the same
    name. NOTE(review): presumably this mimics the argparse namespace the ResNet
    training script expects -- confirm which fields ResNet actually reads.
    """

    def __init__(self, phase, dataset, epoch, batch_size, res_n, lr, checkpoint_dir, log_dir):
        # Assigned in the same order as the parameter list.
        self.phase, self.dataset, self.epoch = phase, dataset, epoch
        self.batch_size, self.res_n, self.lr = batch_size, res_n, lr
        self.checkpoint_dir, self.log_dir = checkpoint_dir, log_dir

In [6]:
# Load CIFAR-100 via the project's utils helper: train images/labels, then test images/labels.
(x_train, y_train,
 x_test, y_test) = load_cifar100()

In [15]:
# Channels-first copies of each split (images arrive as (N, 32, 32, 3), so axes
# (0, 3, 1, 2) give (N, 3, 32, 32)), with labels reduced to class indices by an
# argmax over axis 1 (presumably one-hot targets). The first 49,000 training
# images form the train split and the rest the validation split.
data = {
    'X_train': np.transpose(x_train[:49000], (0, 3, 1, 2)).copy(),
    'y_train': np.argmax(y_train[:49000], axis=1),
    'X_val': np.transpose(x_train[49000:], (0, 3, 1, 2)).copy(),
    'y_val': np.argmax(y_train[49000:], axis=1),
    'X_test': np.transpose(x_test, (0, 3, 1, 2)).copy(),
    'y_test': np.argmax(y_test, axis=1),
}

In [9]:
np.shape(x_train)  # (N, H, W, C): channels are the last axis

(50000, 32, 32, 3)

In [8]:
# Settings handed to ResNet. NOTE(review): presumably these mirror the flags of
# the original training run -- confirm which fields build_model/load read.
args = Foo(phase='train',
           dataset='cifar100',
           epoch=40,
           batch_size=100,
           res_n=18,
           lr=0.1,
           checkpoint_dir='checkpoint',
           log_dir='logs')

# Score the first 49,000 training images (the same slice used for X_train/y_train).
images = x_train[:49000]
# One row of 100 logits per image, filled batch by batch below.
logits_predict = np.zeros((len(images), 100))

tf.logging.set_verbosity(tf.logging.ERROR)

with tf.Graph().as_default():

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        cnn = ResNet(sess, args)
        # build graph
        cnn.build_model()

        tf.global_variables_initializer().run()

        # Replace the freshly initialised variables with the saved checkpoint.
        cnn.saver = tf.train.Saver()
        could_load, _ = cnn.load(cnn.checkpoint_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
            # Continuing would score the images with random initial weights, and
            # those logits are pickled to disk by a later cell as if they were real.
            raise RuntimeError(
                'No checkpoint could be loaded from {!r}'.format(cnn.checkpoint_dir))

        # Only whole batches are fed below; a leftover partial batch would be
        # skipped and silently leave rows of zeros in logits_predict.
        if len(images) % cnn.batch_size != 0:
            raise ValueError(
                '{} images is not a multiple of batch_size={}; the last partial '
                'batch would be skipped'.format(len(images), cnn.batch_size))

        # NOTE(review): this reads cnn.train_logits. If that head is built in
        # training mode (e.g. batch norm using per-batch statistics), these logits
        # depend on batch composition -- confirm in ResNet.build_model.
        iteration = len(images) // cnn.batch_size
        for idx in range(iteration):
            batch_x = images[idx * cnn.batch_size:(idx + 1) * cnn.batch_size]

            predict_feed_dict = {
                cnn.train_inptus: batch_x
            }

            logits = cnn.sess.run(cnn.train_logits, feed_dict=predict_feed_dict)
            logits_predict[idx * cnn.batch_size:(idx + 1) * cnn.batch_size] = logits
    

 [*] Reading checkpoints...
 [*] Success to read ResNet.model-20000
 [*] Load SUCCESS


In [10]:
# Sanity check: the logits array's shape, then the first image's logit row.
print(logits_predict.shape, logits_predict[0], sep='\n')

(49000, 100)
[-3.32770371 -2.74245572 -1.11332297  2.34665799  2.48451042 -6.06812906
  2.02936721  0.18213981 -1.15276384 -1.29720879 -2.68880868 -1.79230702
  1.27115953  0.4580842  -0.74583721  2.88306355 -2.76108265 -0.99674559
 -0.32749847  4.04755068 -3.51942849  5.56274509 -2.0092423   3.71169114
 -0.88281202 -3.58735561  0.26666641  1.8727057  -3.60197568 -1.24948847
  2.50073552  5.4399457   3.10534453  2.285532   -0.22726886 -1.32995486
 -1.05093575 -0.78550661 -0.28507647 -2.67669797 -0.8436814  -1.86945808
 -0.86866707 -1.8867681  -0.48333511 -1.94488466  1.46266103  1.13487577
 -0.43671042  0.58670652 -1.11991513  2.5235517   0.47974521 -5.74578428
 -2.30505466  1.17884636  1.08945751 -4.9439292  -2.35735059  4.22439718
  0.59201896 -3.98419762  1.35350811  5.92362976 -0.08460547  1.36316156
  1.77266073  1.52367055  0.48252025  0.45280215 -2.11718726  1.7239995
  6.47616339  5.24690723  3.72679043  4.6872468  -1.03298676  2.65097833
 -0.76894796  1.71192181  2.72173691  1

In [11]:
def SoftMax(s, axis=1):
    """Numerically stable softmax of ``s`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating. This leaves
    the result mathematically unchanged but keeps ``np.exp`` from overflowing
    when scores are large.

    Parameters
    ----------
    s : array_like
        Scores, e.g. an N x K matrix of logits (one row per sample).
    axis : int, optional
        Axis to normalise over. Defaults to 1, the class axis of an N x K matrix.

    Returns
    -------
    numpy.ndarray
        Probabilities with the same shape as ``s``, summing to 1 along ``axis``.
    """
    # keepdims=True keeps the reduced axis, so the max and the sum broadcast back over s.
    exps = np.exp(s - np.max(s, axis=axis, keepdims=True))
    return exps / exps.sum(axis=axis, keepdims=True)

In [12]:
# Top-1 agreement between the big model's most probable class and the true
# label, over the same 49,000 training images that were scored above.
y_pred_big = SoftMax(logits_predict).argmax(axis=1)
y_true = np.argmax(y_train[:49000], axis=1)
print('Train accuracy of big model: {}'.format((y_true == y_pred_big).mean()))

Train accuracy of big model: 0.8726326530612245


In [14]:
# Persist the big model's training-set logits (pickle is imported in the top import cell).
# NOTE: despite the .txt extension this file is a binary pickle of the NumPy
# array. The name is kept unchanged in case other code loads it by this path.
with open('resnet18_logits_train.txt', 'wb') as fp:
    pickle.dump(logits_predict, fp)

In [15]:
# Check test accuracy
with tf.Graph().as_default():
        
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        cnn = ResNet(sess,args)
        # build graph
        cnn.build_model()

        tf.global_variables_initializer().run()

        cnn.saver = tf.train.Saver()
        could_load, _ = cnn.load(cnn.checkpoint_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")        


        predict_feed_dict = {
            cnn.test_inptus : x_test
        }

        logits_test = cnn.sess.run(cnn.test_logits, feed_dict=predict_feed_dict)

 [*] Reading checkpoints...
 [*] Success to read ResNet.model-20000
 [*] Load SUCCESS


In [16]:
# Test-set sanity check: logits shape and first row, then top-1 accuracy.
print(logits_test.shape)
print(logits_test[0])
# NOTE: reuses (overwrites) y_pred_big / y_true from the training-accuracy cell.
y_pred_big = np.argmax(SoftMax(logits_test), axis=1)
y_true = np.argmax(y_test, axis=1)
# Fixed copy-paste label: this is the accuracy on the test set, not the training set.
print('Test accuracy of big model: {}'.format(np.mean(y_true == y_pred_big)))

(10000, 100)
[-5.1639113e+00 -7.5697646e+00  1.9156983e+00  5.8520656e+00
 -1.0737603e+00  3.1251509e+00 -3.1966140e+00 -1.7989796e-01
 -5.8173418e-01 -2.2535324e+00 -4.7869942e-01  4.9313660e+00
  7.0755067e+00  1.2944865e+00 -4.6343422e+00  2.5551825e+00
  9.5872897e-01  3.7017753e+00  1.6514316e+00 -5.4881730e+00
 -3.0841486e+00  8.6602241e-01  6.5830040e-01  1.3590215e+00
  3.4914798e-01  1.5674220e+00 -1.7908673e+00 -3.3286507e+00
 -3.9012530e+00 -2.2346768e+00  6.0295162e+00 -6.6462827e-01
  1.6215535e+00 -1.6266257e-01 -1.5407908e+00  3.7863574e+00
 -1.1186889e+00  1.3114735e-03 -1.3134971e+00 -9.0551227e-03
  5.1770344e+00 -6.5990796e+00 -1.6741199e+00 -4.0537977e+00
  3.5908180e-01  4.2425418e-01  4.7662096e+00 -7.7134719e+00
 -1.5018659e+00  4.3662581e+00 -5.1767999e-01 -5.8424187e+00
 -4.2030573e+00 -7.6513343e+00 -3.3669430e-01  6.7252803e+00
  2.6845725e+00 -1.8835118e+00 -1.9494309e+00 -1.3035301e+00
 -1.2396645e-01  8.8761449e-01 -5.4402494e-01  2.4234636e+00
  4.9623609