In [1]:
import tensorflow as tf
import numpy as np
import random
import matplotlib.pyplot as plt
from keras.datasets import mnist
import pickle
from utils import load_mnist
import os

Using TensorFlow backend.


In [2]:
old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)

In [3]:
def unpickle(file):    
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

In [4]:
logits_train = unpickle('logits_dir/resnet164_logits_train.txt')

In [5]:
# Data used in ResNet164
(x_train, y_train), (x_val, y_val), (x_test, y_test) = load_mnist()

In [6]:
# Prepare data to train with small model
data = {'X_train': x_train.reshape(len(x_train), 784).copy(), 'y_train': y_train,
        'X_val': x_val.reshape(len(x_val), 784).copy(), 'y_val': y_val,
        'X_test': x_test.reshape(len(x_test), 784).copy(), 'y_test': y_test,
       }

In [7]:
class StudentModel:
    def __init__(self, 
                 model_type,
                 num_steps=500, 
                 batch_size=128, 
                 display_step=100, 
                 n_hidden_1=256,
                 n_hidden_2=256,
                 num_input=784, 
                 num_classes=10,
                 dropoutprob=0.75,
                 checkpoint_dir="checkpoint",
                 checkpoint_file="smallmodel",
                 temperature=1.0,
                 log_dir="logs", 
                 learning_rate=0.001):
        
        self.learning_rate = learning_rate
        self.num_steps = num_steps
        self.batch_size = batch_size
        self.display_step = display_step
        self.n_hidden_1 = n_hidden_1  # 1st layer number of neurons
        self.n_hidden_2 = n_hidden_2  # 2nd layer number of neurons
        self.num_input = num_input  # MNIST data input (img shape: 28*28)
        self.num_classes = num_classes
        self.temperature = temperature
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_file = checkpoint_file
        self.checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_file)
        self.max_checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_file + "max")
        self.log_dir = os.path.join(log_dir, self.checkpoint_file)
        self.model_type = model_type

        self.weights = {
            'h1': tf.Variable(tf.random_normal([self.num_input, self.n_hidden_1]),
                              name="%s_%s" % (self.model_type, "h1")),
            'h2': tf.Variable(tf.random_normal([self.n_hidden_1, self.n_hidden_2]),
                              name="%s_%s" % (self.model_type, "h2")),
            'out': tf.Variable(tf.random_normal([self.n_hidden_2, self.num_classes]),
                               name="%s_%s" % (self.model_type, "out")),
            'linear': tf.Variable(tf.random_normal([self.num_input, self.num_classes]),
                                  name="%s_%s" % (self.model_type, "linear"))
        }
        self.biases = {
            'b1': tf.Variable(tf.random_normal([self.n_hidden_1]), name="%s_%s" % (self.model_type, "b1")),
            'b2': tf.Variable(tf.random_normal([self.n_hidden_2]), name="%s_%s" % (self.model_type, "b2")),
            'out': tf.Variable(tf.random_normal([self.num_classes]), name="%s_%s" % (self.model_type, "out")),
            'linear': tf.Variable(tf.random_normal([self.num_classes]), name="%s_%s" % (self.model_type, "linear"))
        }

        self.build_model()

        self.saver = tf.train.Saver()

    # Create model
    def build_model(self):
        self.X = tf.placeholder(tf.float32, [None, self.num_input], name="%s_%s" % (self.model_type, "xinput"))
        self.Y = tf.placeholder(tf.float32, [None, self.num_classes], name="%s_%s" % (self.model_type, "yinput"))

        self.flag = tf.placeholder(tf.bool, None, name="%s_%s" % (self.model_type, "flag"))
        self.soft_Y = tf.placeholder(tf.float32, [None, self.num_classes], name="%s_%s" % (self.model_type, "softy"))
        self.softmax_temperature = tf.placeholder(tf.float32, name="%s_%s" % (self.model_type, "softmaxtemperature"))

        with tf.name_scope("%sfclayer" % (self.model_type)), tf.variable_scope("%sfclayer" % (self.model_type)):
            # Hidden fully connected layer with 256 neurons
            layer_1 = tf.add(tf.matmul(self.X, self.weights['h1']), self.biases['b1'])
            # # Hidden fully connected layer with 256 neurons
            layer_2 = tf.add(tf.matmul(layer_1, self.weights['h2']), self.biases['b2'])
            # # Output fully connected layer with a neuron for each class
            logits = (tf.matmul(layer_2, self.weights['out']) + self.biases['out'])
            # logits = tf.add(tf.matmul(self.X, self.weights['linear']), self.biases['linear'])

        with tf.name_scope("%sprediction" % (self.model_type)), tf.variable_scope("%sprediction" % (self.model_type)):
            self.prediction = tf.nn.softmax(logits)

            self.correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(self.Y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))

        with tf.name_scope("%soptimization" % (self.model_type)), tf.variable_scope(
                        "%soptimization" % (self.model_type)):
            # Define loss and optimizer
            self.loss_op_standard = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=self.Y))

            self.total_loss = self.loss_op_standard

            self.loss_op_soft = tf.cond(self.flag,
                                        true_fn=lambda: tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                                            logits=logits / self.softmax_temperature, labels=self.soft_Y)),
                                        false_fn=lambda: 0.0)

            self.total_loss += tf.square(self.softmax_temperature) * self.loss_op_soft

            # optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
            # optimizer = tf.train.GradientDescentOptimizer(0.05)
            # self.global_step = tf.Variable(0, trainable=False)
            # self.increment_global_step_op = tf.assign(self.global_step, self.global_step+1)
            # self.ad_learning_rate = tf.train.exponential_decay(self.learning_rate, self.global_step,
            #                                1000, 0.96, staircase=True)
            optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
            # self.train_op = optimizer.minimize(self.total_loss, global_step=self.global_step)
            self.train_op = optimizer.minimize(self.total_loss)

        with tf.name_scope("%ssummarization" % (self.model_type)), tf.variable_scope(
                        "%ssummarization" % (self.model_type)):
            tf.summary.scalar("loss_op_standard", self.loss_op_standard)
            tf.summary.scalar("total_loss", self.total_loss)
            # Create a summary to monitor accuracy tensor
            tf.summary.scalar("accuracy", self.accuracy)

            for var in tf.trainable_variables():
                tf.summary.histogram(var.name, var)

            # Merge all summaries into a single op

            # If using TF 1.6 or above, simply use the following merge_all function
            # which supports scoping
            self.merged_summary_op = tf.summary.merge_all(scope=self.model_type)

    def start_session(self):
        self.sess = tf.Session()

    def close_session(self):
        self.sess.close()

    def train(self, data, logits_distill=None):
        teacher_flag = False
        if logits_distill is not None:
            teacher_flag = True
        
        X_train = data['X_train']
        y_train = data['y_train']
        X_val = data['X_val']
        y_val = data['y_val']
        
        # Initialize the variables (i.e. assign their default value)
        self.sess.run(tf.global_variables_initializer())
        train_summary_writer = tf.summary.FileWriter(self.log_dir, graph=self.sess.graph)

        max_accuracy = 0

        print("Starting Training")

        def dev_step():
            validation_x = X_val
            validation_y = y_val
        
            loss, acc = self.sess.run([self.loss_op_standard, self.accuracy], feed_dict={self.X: validation_x,
                                                                                         self.Y: validation_y,
                                                                                         # self.soft_Y: validation_y,
                                                                                         self.flag: False,
                                                                                         self.softmax_temperature: 1.0})

            if acc > max_accuracy:
                save_path = self.saver.save(self.sess, self.checkpoint_path)
                print("Model Checkpointed to %s " % (save_path))

            print("Step " + str(step) + ", Validation Loss= " + "{:.4f}".format(
                loss) + ", Validation Accuracy= " + "{:.3f}".format(acc))

        for step in range(1, self.num_steps + 1):
            batch_x, batch_y, batch_logits = self.get_batch(X_train, y_train, logits_distill)
            soft_targets = batch_y
            if teacher_flag:
                # soft_targets = self.sess.run(tf.nn.softmax(batch_logits / self.temperature))
                soft_targets = batch_logits
                
            # self.sess.run(self.train_op,
            _, summary = self.sess.run([self.train_op, self.merged_summary_op],
                                       feed_dict={self.X: batch_x,
                                                  self.Y: batch_y,
                                                  self.soft_Y: soft_targets,
                                                  self.flag: teacher_flag,
                                                  self.softmax_temperature: self.temperature}
                                       )
            train_summary_writer.add_summary(summary, step)
            
            if (step % self.display_step) == 0 or step == 1:
                dev_step()
        else:
            # Final Evaluation and checkpointing before training ends
            dev_step()

        train_summary_writer.close()

        print("Optimization Finished!")

    def get_batch(self, X_train, y_train, logits_distill=None):
        num_train = X_train.shape[0]
        batch_mask = np.random.choice(num_train, self.batch_size)
        X_batch = X_train[batch_mask]
        y_batch = y_train[batch_mask]
        if logits_distill is not None:
            logit_batch = logits_distill[batch_mask]
        else:
            logit_batch = None
        return X_batch, y_batch, logit_batch
    
    def predict(self, data_X, temperature=1.0):
        return self.sess.run(self.prediction,
                             feed_dict={self.X: data_X, self.flag: False, self.softmax_temperature: temperature})

    def run_inference(self, data):   
        X_test = data['X_test']
        y_test = data['y_test']
        batch_size = self.batch_size
        batch_num = int(len(X_test) / batch_size)
        test_accuracy = 0

        for i in range(batch_num):
            batch_x = X_test[:batch_size * (i + 1)]
            batch_y = y_test[:batch_size * (i + 1)]
            test_accuracy += self.sess.run(self.accuracy, feed_dict={self.X: batch_x,
                                                                     self.Y: batch_y,
                                                                     self.flag: False,
                                                                     self.softmax_temperature: 1.0
                                                                    })
        # test_images, test_labels = dataset.get_test_data()
        # print("Testing Accuracy:", self.sess.run(self.accuracy, feed_dict={self.X: test_images,
                                                                           # self.Y: test_labels,
                                                                           # # self.soft_Y: test_labels,
                                                                           # self.flag: False,
                                                                           # self.softmax_temperature: 1.0
                                                                           # }))
        test_accuracy /= batch_num
        print("Testing Accuracy: %g"%test_accuracy)
        
    def run_inference_ex(self, dataset_ex):
        test_images, test_labels = dataset_ex.get_test_data_ex()
        print("Testing Accuracy:", self.sess.run(self.accuracy, feed_dict={self.X: test_images,
                                                                           self.Y: test_labels,
                                                                           # self.soft_Y: test_labels,
                                                                           self.flag: False,
                                                                           self.softmax_temperature: 1.0
                                                                           }))

    def load_model_from_file(self, load_path):
        ckpt = tf.train.get_checkpoint_state(load_path)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            print("Created model with fresh parameters.")
            self.sess.run(tf.global_variables_initializer())

In [60]:
checkpoint_dir="new_studentcpt_t1.75"
log_dir=checkpoint_dir + "\logs"
temperature = 1.75

In [61]:
student_model = StudentModel(num_steps=3000, 
                                 batch_size=128,
                                 learning_rate=0.001,
                                 temperature=temperature,
                                 dropoutprob=0,
                                 checkpoint_dir=checkpoint_dir,
                                 log_dir=log_dir,
                                 model_type="student");

In [62]:
# To speedup the calculations.
session = tf.Session()
soft_targets = session.run(tf.nn.softmax(logits_train / temperature))

In [63]:
student_model.start_session()
student_model.train(data, soft_targets)

Starting Training
Model Checkpointed to new_studentcpt_t1.75\smallmodel 
Step 1, Validation Loss= 8475.0674, Validation Accuracy= 0.078
Model Checkpointed to new_studentcpt_t1.75\smallmodel 
Step 100, Validation Loss= 812.3453, Validation Accuracy= 0.768
Model Checkpointed to new_studentcpt_t1.75\smallmodel 
Step 200, Validation Loss= 547.8128, Validation Accuracy= 0.828
Model Checkpointed to new_studentcpt_t1.75\smallmodel 
Step 300, Validation Loss= 439.7054, Validation Accuracy= 0.851
Model Checkpointed to new_studentcpt_t1.75\smallmodel 
Step 400, Validation Loss= 385.6259, Validation Accuracy= 0.858
Model Checkpointed to new_studentcpt_t1.75\smallmodel 
Step 500, Validation Loss= 330.9241, Validation Accuracy= 0.868
Model Checkpointed to new_studentcpt_t1.75\smallmodel 
Step 600, Validation Loss= 297.5023, Validation Accuracy= 0.874
Model Checkpointed to new_studentcpt_t1.75\smallmodel 
Step 700, Validation Loss= 275.9644, Validation Accuracy= 0.875
Model Checkpointed to new_stude

In [64]:
# Load the best model from created checkpoint
student_model.load_model_from_file(checkpoint_dir)
# Test the model against the testing set
student_model.run_inference(data)

Reading model parameters from new_studentcpt_t1.75\smallmodel
Testing Accuracy: 0.857691


In [59]:
# Close current tf sessions
student_model.close_session()
tf.reset_default_graph()

In [13]:
logits_train[0]

array([-6.2030087 , -3.9134192 , -6.317644  ,  6.6796656 , -2.957327  ,
       12.809154  , -3.4207807 ,  0.18087192,  2.3757007 ,  0.04373608],
      dtype=float32)

In [11]:
y_train.shape

(50000, 10)

In [9]:
x = [0., -1., 2., 3.]

In [12]:
session = tf.Session()
soft_targets = session.run(tf.nn.softmax(logits_train))

In [16]:
big_model_pred = np.argmax(session.run(tf.nn.softmax(logits_train)), 1)

In [17]:
test_pred = np.argmax(y_train,axis=1)

In [19]:
np.sum(big_model_pred == test_pred) / len(y_train)

0.99966