In [1]:
import model
import data
import tensorflow as tf
import numpy as np
import random

To avoid verbose warning messages, lower TensorFlow's logging verbosity (the previous level is kept in old_v):

In [2]:
old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)

Load the MNIST dataset via a Dataset object:

In [3]:
dataset = data.Dataset(batch_size=128)

Extracting ./train-images-idx3-ubyte.gz
Extracting ./train-labels-idx1-ubyte.gz
Extracting ./t10k-images-idx3-ubyte.gz
Extracting ./t10k-labels-idx1-ubyte.gz
55000


Training parameters:

In [4]:
learning_rate = 0.001
num_steps = 5000
batch_size = 128

Model parameters:

In [5]:
temperature = 5.0
dropout = 0.75
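For reference, temperature is presumably the T of the softened softmax used in knowledge distillation, p_i = exp(z_i / T) / sum_j exp(z_j / T). Larger T spreads probability mass over more classes. The model module's use of it is not shown here, so the NumPy sketch below illustrates only the formula; softmax_with_temperature is a hypothetical helper and the logits are made up.

```python
import numpy as np


def softmax_with_temperature(logits, temperature=1.0):
    """Row-wise softmax of logits / temperature (hypothetical helper)."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    z = z - z.max(axis=-1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


logits = np.array([8.0, 2.0, 1.0])  # made-up logits for illustration
hard = softmax_with_temperature(logits, temperature=1.0)
soft = softmax_with_temperature(logits, temperature=5.0)  # same T as above
print(hard.round(3))  # almost all mass on class 0
print(soft.round(3))  # flatter: other classes get visible probability
```

With T = 5, the top class drops from about 0.997 to about 0.646. This makes the relative similarity of the other classes visible to the student.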

## Main Model Training

- ### Teacher Model

Output directory for model checkpoint:

In [6]:
checkpoint_dir="teachercpt"

Model definition:

In [7]:
teacher_model = model.BigModel(num_steps=num_steps, 
                               batch_size=batch_size,
                               learning_rate=learning_rate,
                               temperature=temperature,
                               dropoutprob=dropout,
                               checkpoint_dir=checkpoint_dir,
                               model_type="teacher");

Training:

Periodically (at step 1 and then every 100 steps in the log below), the validation accuracy is computed. If it is the best so far, a model checkpoint is saved. This is, in a way, analogous to **early stopping**.
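The checkpointing rule amounts to "save only when validation accuracy beats the best value seen so far". The checkpoint left at the end is therefore the model from the best validation step, which is what early stopping would return. The plain-Python sketch below replays that rule. The function name and accuracy values are hypothetical, and the real loop inside model.BigModel.train / model.SmallModel.train is not shown in this notebook.

```python
def checkpoint_steps(val_accuracies):
    """Replay the 'save only on a new best validation accuracy' rule.

    val_accuracies: dict mapping step -> validation accuracy (hypothetical
    values standing in for what the training loop evaluates).
    Returns the steps at which a checkpoint would be written and the
    (step, accuracy) pair of the checkpoint that survives at the end.
    """
    best_step, best_acc = None, float("-inf")
    saved = []
    for step in sorted(val_accuracies):
        acc = val_accuracies[step]
        if acc > best_acc:  # new maximum: overwrite the checkpoint
            best_step, best_acc = step, acc
            saved.append(step)
    return saved, (best_step, best_acc)


saved, best = checkpoint_steps({100: 0.90, 200: 0.93, 300: 0.92, 400: 0.95, 500: 0.94})
print(saved)  # [100, 200, 400]
print(best)   # (400, 0.95)
```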

In [8]:
teacher_model.start_session()
teacher_model.train(dataset)

Starting Training
Model Checkpointed to teachercpt\bigmodel.ckpt 
Step 1, Validation Loss= 54235.1562, Validation Accuracy= 0.108
Model Checkpointed to teachercpt\bigmodel.ckpt 
Step 100, Validation Loss= 2075.0391, Validation Accuracy= 0.869
Model Checkpointed to teachercpt\bigmodel.ckpt 
Step 200, Validation Loss= 1210.1136, Validation Accuracy= 0.918
Model Checkpointed to teachercpt\bigmodel.ckpt 
Step 300, Validation Loss= 810.2395, Validation Accuracy= 0.936
Model Checkpointed to teachercpt\bigmodel.ckpt 
Step 400, Validation Loss= 638.9126, Validation Accuracy= 0.945
Model Checkpointed to teachercpt\bigmodel.ckpt 
Step 500, Validation Loss= 551.1878, Validation Accuracy= 0.949
Model Checkpointed to teachercpt\bigmodel.ckpt 
Step 600, Validation Loss= 447.9606, Validation Accuracy= 0.952
Model Checkpointed to teachercpt\bigmodel.ckpt 
Step 700, Validation Loss= 359.3526, Validation Accuracy= 0.956
Model Checkpointed to teachercpt\bigmodel.ckpt 
Step 800, Validation Loss= 354.0197,

Test the **teacher model** (compute its accuracy against the testing dataset) using the best model according to the validation set, that is, the *checkpointed model*.
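Testing accuracy is the fraction of test images whose most probable predicted class matches the true class. The NumPy sketch below assumes one-hot labels (as in the classic TensorFlow MNIST reader). It does not reproduce run_inference, whose code is not shown; accuracy is a hypothetical helper.

```python
import numpy as np


def accuracy(probs, one_hot_labels):
    """Fraction of examples whose most probable class equals the true class."""
    predicted = np.argmax(probs, axis=1)
    actual = np.argmax(one_hot_labels, axis=1)
    return float(np.mean(predicted == actual))


probs = np.array([[0.9, 0.1, 0.0],
                  [0.2, 0.7, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
one_hot = np.eye(3)[[0, 1, 2, 1]]  # true classes 0, 1, 2, 1
print(accuracy(probs, one_hot))  # 0.75: the last example is misclassified
```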

In [9]:
# Load the best model from created checkpoint
teacher_model.load_model_from_file(checkpoint_dir)
# Test the model against the testing set
teacher_model.run_inference(dataset)

Reading model parameters from teachercpt\bigmodel.ckpt
Testing Accuracy: 0.977455


In [10]:
# Close current tf session
teacher_model.close_session()

- ### Simple Student Model
Simple, as in trained on the same data and with the same parameters as the teacher model, without any guidance from the teacher.

Output directory for model checkpoint:

In [11]:
checkpoint_dir="sstudentcpt"

Model definition:

In [12]:
student_model = model.SmallModel(num_steps=num_steps, 
                                 batch_size=batch_size,
                                 learning_rate=learning_rate,
                                 temperature=temperature,
                                 dropoutprob=dropout,
                                 checkpoint_dir=checkpoint_dir,
                                 model_type="student");

Training:

Periodically (at step 1 and then every 100 steps in the log below), the validation accuracy is computed. If it is the best so far, a model checkpoint is saved. This is, in a way, analogous to **early stopping**.

In [13]:
student_model.start_session()
student_model.train(dataset)

Starting Training
Model Checkpointed to sstudentcpt\smallmodel 
Step 1, Validation Loss= 15.6250, Validation Accuracy= 0.117
Model Checkpointed to sstudentcpt\smallmodel 
Step 100, Validation Loss= 8.6834, Validation Accuracy= 0.161
Model Checkpointed to sstudentcpt\smallmodel 
Step 200, Validation Loss= 5.7821, Validation Accuracy= 0.277
Model Checkpointed to sstudentcpt\smallmodel 
Step 300, Validation Loss= 4.1909, Validation Accuracy= 0.395
Model Checkpointed to sstudentcpt\smallmodel 
Step 400, Validation Loss= 3.1995, Validation Accuracy= 0.488
Model Checkpointed to sstudentcpt\smallmodel 
Step 500, Validation Loss= 2.5701, Validation Accuracy= 0.557
Model Checkpointed to sstudentcpt\smallmodel 
Step 600, Validation Loss= 2.1426, Validation Accuracy= 0.607
Model Checkpointed to sstudentcpt\smallmodel 
Step 700, Validation Loss= 1.8318, Validation Accuracy= 0.648
Model Checkpointed to sstudentcpt\smallmodel 
Step 800, Validation Loss= 1.6146, Validation Accuracy= 0.685
Model Check

Test the **student model** (compute its accuracy against the testing dataset) using the best model according to the validation set, that is, the *checkpointed model*.

In [16]:
# Load the best model from created checkpoint
student_model.load_model_from_file(checkpoint_dir)
# Test the model against the testing set
student_model.run_inference(dataset)

Reading model parameters from sstudentcpt\smallmodel
Testing Accuracy: 0.886801


In [17]:
# Close current tf session
student_model.close_session()

- ### Distilled Student Model
The training targets are the **logits** produced by the teacher model on the standard training set.
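This notebook does not show how SmallModel.train turns the teacher's logits into a loss. One standard formulation (Hinton et al., 2015) softens both models' logits with the same temperature and minimizes the cross-entropy between the two distributions, scaled by T squared. It is sketched below in NumPy as an assumption, not as the author's code. Regressing directly on the logits with a squared error is another option. All logits are made up, and distillation_loss is a hypothetical helper.

```python
import numpy as np


def log_softmax(z):
    """Numerically stable row-wise log-softmax."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def distillation_loss(student_logits, teacher_logits, temperature):
    """Soft-target cross-entropy (Hinton et al., 2015), scaled by T**2."""
    soft_targets = np.exp(log_softmax(teacher_logits / temperature))
    student_log_probs = log_softmax(student_logits / temperature)
    per_example = -np.sum(soft_targets * student_log_probs, axis=-1)
    # T**2 keeps gradient magnitudes comparable across temperatures
    return float(np.mean(per_example) * temperature ** 2)


teacher = np.array([[6.0, 1.0, 0.5]])   # made-up teacher logits
close = np.array([[5.5, 1.2, 0.4]])     # student roughly agreeing with teacher
far = np.array([[0.5, 6.0, 1.0]])       # student preferring another class
print(distillation_loss(close, teacher, temperature=5.0))
print(distillation_loss(far, teacher, temperature=5.0))  # larger loss
```

The loss is smallest when the student's softened distribution equals the teacher's. In the standard formulation it is usually combined with an ordinary cross-entropy on the true labels.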

Load the pretrained **teacher model**:

In [6]:
# Model definition
teacher_model = model.BigModel(num_steps=num_steps, 
                               batch_size=batch_size,
                               learning_rate=learning_rate,
                               temperature=temperature,
                               dropoutprob=dropout,
                               checkpoint_dir="teachercpt",
                               model_type="teacher");
# Start tf session
teacher_model.start_session()

In [7]:
# Load best model from teacher checkpoint
checkpoint_dir = "teachercpt"
teacher_model.load_model_from_file(checkpoint_dir)

Reading model parameters from teachercpt\bigmodel.ckpt


Verify the **teacher** model state before training the **student**:

In [8]:
# Test the model against the testing set
teacher_model.run_inference(dataset)

Testing Accuracy: 0.977437


Output directory for distilled student model checkpoint:

In [9]:
checkpoint_dir="dstudentcpt"

Student model definition:

In [10]:
student_model = model.SmallModel(num_steps=num_steps, 
                                 batch_size=batch_size,
                                 learning_rate=learning_rate,
                                 temperature=temperature,
                                 dropoutprob=dropout,
                                 checkpoint_dir=checkpoint_dir,
                                 model_type="student");

Training:

Periodically (at step 1 and then every 100 steps in the log below), the validation accuracy is computed. If it is the best so far, a model checkpoint is saved. This is, in a way, analogous to **early stopping**.

In [11]:
student_model.start_session()
student_model.train(dataset, teacher_model)

Starting Training
Model Checkpointed to dstudentcpt\smallmodel 
Step 1, Validation Loss= 12.7508, Validation Accuracy= 0.075
Model Checkpointed to dstudentcpt\smallmodel 
Step 100, Validation Loss= 7.2757, Validation Accuracy= 0.194
Model Checkpointed to dstudentcpt\smallmodel 
Step 200, Validation Loss= 4.9041, Validation Accuracy= 0.360
Model Checkpointed to dstudentcpt\smallmodel 
Step 300, Validation Loss= 3.5745, Validation Accuracy= 0.486
Model Checkpointed to dstudentcpt\smallmodel 
Step 400, Validation Loss= 2.7939, Validation Accuracy= 0.576
Model Checkpointed to dstudentcpt\smallmodel 
Step 500, Validation Loss= 2.3026, Validation Accuracy= 0.637
Model Checkpointed to dstudentcpt\smallmodel 
Step 600, Validation Loss= 1.9884, Validation Accuracy= 0.682
Model Checkpointed to dstudentcpt\smallmodel 
Step 700, Validation Loss= 1.7399, Validation Accuracy= 0.719
Model Checkpointed to dstudentcpt\smallmodel 
Step 800, Validation Loss= 1.5756, Validation Accuracy= 0.745
Model Check

Test the **distilled student model** (compute its accuracy against the testing dataset) using the best model according to the validation set, that is, the *checkpointed model*.

In [12]:
# Load the best model from created checkpoint
student_model.load_model_from_file(checkpoint_dir)
# Test the model against the testing set
student_model.run_inference(dataset)

Reading model parameters from dstudentcpt\smallmodel
Testing Accuracy: 0.907361


In [13]:
# Close current tf sessions
teacher_model.close_session()
student_model.close_session()

## Experiments

- ### 1. Learn from Probabilities
Remove one class (digit 3) from the distilled student's training set, then test the student's accuracy on that class.
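data.DatasetExclude is not shown here. Presumably it drops every training example of exclude_class and keeps the removed test examples for get_test_data_ex(). The NumPy sketch below shows such a split, assuming one-hot labels. split_by_class is a hypothetical helper, not part of the data module, and the arrays are toy data.

```python
import numpy as np


def split_by_class(images, one_hot_labels, excluded):
    """Return ((kept_x, kept_y), (excl_x, excl_y)) split on the true class."""
    classes = np.argmax(one_hot_labels, axis=1)
    keep = classes != excluded
    return ((images[keep], one_hot_labels[keep]),
            (images[~keep], one_hot_labels[~keep]))


images = np.arange(12, dtype=np.float32).reshape(6, 2)  # 6 toy "images"
toy_labels = np.eye(10)[[3, 1, 3, 0, 2, 3]]              # one-hot, 10 classes
(kept_x, kept_y), (only3_x, only3_y) = split_by_class(images, toy_labels, excluded=3)
print(kept_y.argmax(axis=1))   # [1 0 2]
print(only3_y.argmax(axis=1))  # [3 3 3]
```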

In [9]:
# MNIST dataset with class 3 excluded from training
dataset_ex = data.DatasetExclude(batch_size=128, exclude_class=3)

Extracting ./train-images-idx3-ubyte.gz
Extracting ./train-labels-idx1-ubyte.gz
Extracting ./t10k-images-idx3-ubyte.gz
Extracting ./t10k-labels-idx1-ubyte.gz
55000


Load the pretrained **teacher model**:

In [7]:
# Model definition
teacher_model = model.BigModel(num_steps=num_steps, 
                               batch_size=batch_size,
                               learning_rate=learning_rate,
                               temperature=temperature,
                               dropoutprob=dropout,
                               checkpoint_dir="teachercpt",
                               model_type="teacher");
# Start tf session
teacher_model.start_session()

In [8]:
# Load best model from teacher checkpoint
checkpoint_dir = "teachercpt"
teacher_model.load_model_from_file(checkpoint_dir)

Reading model parameters from teachercpt\bigmodel.ckpt


Verify the **teacher** model state before training the **student**:

In [9]:
# Test the model against the testing set
teacher_model.run_inference(dataset)

Testing Accuracy: 0.977418


Output directory for distilled student model checkpoint:

In [6]:
checkpoint_dir="ex1_dstudentcpt"

Student model definition:

In [7]:
student_model = model.SmallModel(num_steps=num_steps, 
                                 batch_size=batch_size,
                                 learning_rate=learning_rate,
                                 temperature=temperature,
                                 dropoutprob=dropout,
                                 checkpoint_dir=checkpoint_dir,
                                 model_type="student");

Training:

Periodically (at step 1 and then every 100 steps in the log below), the validation accuracy is computed. If it is the best so far, a model checkpoint is saved. This is, in a way, analogous to **early stopping**.

In [13]:
student_model.start_session()
student_model.train(dataset_ex, teacher_model)

Starting Training
Model Checkpointed to ex1_dstudentcpt\smallmodel 
Step 1, Validation Loss= 13.8729, Validation Accuracy= 0.123
Model Checkpointed to ex1_dstudentcpt\smallmodel 
Step 100, Validation Loss= 8.1246, Validation Accuracy= 0.207
Model Checkpointed to ex1_dstudentcpt\smallmodel 
Step 200, Validation Loss= 6.2020, Validation Accuracy= 0.348
Model Checkpointed to ex1_dstudentcpt\smallmodel 
Step 300, Validation Loss= 5.3502, Validation Accuracy= 0.468
Model Checkpointed to ex1_dstudentcpt\smallmodel 
Step 400, Validation Loss= 4.9623, Validation Accuracy= 0.551
Model Checkpointed to ex1_dstudentcpt\smallmodel 
Step 500, Validation Loss= 4.7677, Validation Accuracy= 0.612
Model Checkpointed to ex1_dstudentcpt\smallmodel 
Step 600, Validation Loss= 4.7209, Validation Accuracy= 0.656
Model Checkpointed to ex1_dstudentcpt\smallmodel 
Step 700, Validation Loss= 4.6939, Validation Accuracy= 0.685
Model Checkpointed to ex1_dstudentcpt\smallmodel 
Step 800, Validation Loss= 4.7184, Va

Test the **distilled student model** (compute its accuracy against the testing dataset) using the best model according to the validation set, that is, the *checkpointed model*.

In [11]:
# Load the best model from created checkpoint
student_model.load_model_from_file(checkpoint_dir)
# Test the model against the testing set
student_model.run_inference(dataset)

Reading model parameters from ex1_dstudentcpt\smallmodel
Testing Accuracy: 0.824392


In [12]:
# Test the model against a testing set containing only the initially excluded class
student_model.run_inference_ex(dataset_ex)

Testing Accuracy: 0.0


In [13]:
# Close current tf sessions
teacher_model.close_session()
student_model.close_session()

### Tests

In [23]:
only3Data = dataset_ex.get_test_data_ex()

In [27]:
test_images_ex, test_labels_ex = dataset_ex.get_test_data_ex()
test_images, test_labels = dataset.get_test_data()

In [45]:
res = student_model.predict(test_images_ex)
res

array([[1.1815443e-16, 1.0408784e-21, 4.1169969e-06, ..., 2.3125008e-19,
        1.1995219e-06, 4.0976341e-16],
       [5.9604062e-14, 1.7043893e-17, 1.2083776e-15, ..., 9.4897306e-01,
        2.8533020e-06, 2.3076200e-06],
       [1.6960637e-20, 8.4145722e-26, 8.5389697e-18, ..., 3.1040057e-24,
        1.4308544e-09, 7.0362976e-19],
       ...,
       [6.5682438e-14, 1.8740755e-17, 9.9998462e-01, ..., 1.1531632e-16,
        1.5435768e-05, 6.6995472e-12],
       [6.1676939e-20, 3.4999019e-14, 7.3642968e-07, ..., 6.9019401e-10,
        9.9999928e-01, 3.5820596e-10],
       [1.5866134e-07, 1.4374173e-16, 4.7634237e-02, ..., 3.4302823e-07,
        2.2476334e-07, 6.4896581e-17]], dtype=float32)

In [49]:
labels = [np.argmax(x) for x in res]
labels

[5, 7, 5, 5, 5, 2, 5, 5, 5, 5, 5, 5, 2, 2, 5, 8, 2, 5, 9, 2, 5, 8, 6, 5, 5,
 5, 5, 8, 5, 2, 5, 7, 5, 7, 5, 5, 2, 5, 5, 5, 5, 8, 2, 7, 5, 5, 2, 5, 8, 5,
 5, 2, 6, 5, 5, 5, 2, 5, 9, 8, 5, 5, 1, 9, 5, 5, 5, 8, 2, 5, 2, 8, 5, 8, 5,
 7, 5, 5, 5, 7, 6, 5, 5, 5, 7, 5, 2, 8, 2, 5, 5, 5, 5, 5, 5, 5, 8, 5, 5, 5,
 5, 5, 5, 0, 8, 8, 8, 8, 5, 5, 9, 8, 8, 5, 5, 2, 5, 5, 5, 8, 8, 7, 5, 5, 0,
 5, 9, 8, 9, 5, 5, 5, 5, 5, 5, 5, 5, 7, 5, 9, 5, 5, 5, 9, 5, 5, 7, 2, 5, 5,
 5, 5, 5, 2, 5, 5, 5, 5, 5, 9, 5, 8, 5, 7, 5, 5, 5, 8, 7, 5, 5, 5, 2, 7, 8,
 7, 7, 8, 5, 5, 5, 6, 5, 5, 2, 2, 2, 5, 5, 5, 8, 5, 5, 5, 5, 8, 2, 5, 2, 5,
 5, 5, 5, 5, 5, 2, 5, 5, 7, 5, 8, 5, 5, 5, 5, 5, 5, 5, 7, 5, 8, 5, 7, 5, 5,
 5, 5, 5, 5, 8, 5, 5, 5, 5, 5, 5, 5, 5, 8, 5, 5, 5, 9, 0, 5, 5, 2, 7, 2, 5,
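The per-image list above is easier to read as a histogram of predicted classes. None of the listed predictions is 3 and most are 5, which is consistent with the 0.0 accuracy on the excluded class. The sketch below uses a vectorized argmax and np.bincount on toy probabilities. prediction_counts is a hypothetical helper; in the notebook it would be called as prediction_counts(res).

```python
import numpy as np


def prediction_counts(probs, num_classes=10):
    """Count how many rows of `probs` have each class as their argmax."""
    predicted = np.argmax(probs, axis=1)  # vectorized [np.argmax(x) for x in probs]
    return np.bincount(predicted, minlength=num_classes)


toy_probs = np.eye(10)[[5, 5, 2, 8, 5]]  # stands in for `res`
counts = prediction_counts(toy_probs)
print(counts)  # [0 0 1 0 0 3 0 0 1 0]
```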
