# Distilling the knowledge in a Neural Network

In [1]:
import numpy as np
import matplotlib.pyplot as plt

from data_utils import get_CIFAR10_data, get_MNIST_data, unpickle
from CNN import ThreeLayerConvNet
from model import myModel

from ResNet164.resnet164 import ResNet164
from ResNet164.utils import load_mnist
from VGG16.VGG16 import VGG16
from keras.models import Model

import h5py
import time
import pickle
import timeit, os, math

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

Using TensorFlow backend.


## MNIST data set

### ResNet164 as a big model

In [2]:
# Load the logits (inputs to the final softmax) that were previously predicted with ResNet164
logits_train = unpickle('ResNet164/resnet164_logits_train.txt')

In [2]:
# Data used in ResNet164
(x_train, y_train), (x_val, y_val), (x_test, y_test) = load_mnist()

In [3]:
# Build the data dict for the small model. For every split the image array has its
# last axis moved to position 1 (channels-first, matching input_dim=(1, 28, 28) used
# by the student network) and the label matrix is reduced to class indices via argmax.
data = {}
for _split, _images, _labels in (('train', x_train, y_train),
                                 ('val', x_val, y_val),
                                 ('test', x_test, y_test)):
    data['X_' + _split] = _images.transpose(0, 3, 1, 2).copy()
    data['y_' + _split] = np.argmax(_labels, axis=1)

In [4]:
# Search for the best distillation temperature for a student model
def searchDistill_temperature(data, num_epochs, batch_size, learning_rate,
                              logits_teacher, class_out=None, class_in=None, save_name=None):
    """Train one fresh student per temperature and pickle the resulting accuracy table.

    Temperatures run from 1.0 to 49.5 in steps of 0.5. The table has one column per
    temperature and the following rows:
        row 0 -- the temperature,
        row 1 -- the student's ``best_val_acc`` after training,
        row 2 -- (only when ``class_out`` or ``class_in`` is given) the accuracy on a
                 subset of the validation set, described below.

    Args:
        data: dict that is handed to ``myModel``; 'X_val' and 'y_val' are also read here.
        num_epochs, batch_size, learning_rate: forwarded to ``myModel`` (Adam optimizer).
        logits_teacher: teacher logits, forwarded to ``myModel`` as ``logit_distill``.
        class_out: if given, row 2 is the accuracy on validation samples labelled
            ``class_out``.
        class_in: iterable of labels; if given, row 2 is the accuracy on validation
            samples whose label is NOT in ``class_in``. Takes precedence over
            ``class_out`` when both are supplied.
        save_name: optional prefix of the output file. The table is pickled to
            ``save_name + '_temperature_search.txt'``, or to 'temperature_search.txt'
            when ``save_name`` is None.

    Returns:
        None. The result is only written to disk.
    """
    temperatures = np.arange(1.0, 50.0, 0.5)

    # The validation subset does not depend on the temperature, so build its mask once.
    # np.isin always yields a boolean array, even for an empty class_in (the previous
    # `mask = True` start value stayed a plain bool in that case and broke the indexing).
    subset_mask = None
    if class_out is not None:
        subset_mask = data['y_val'] == class_out
    if class_in is not None:
        subset_mask = ~np.isin(data['y_val'], class_in)

    n_rows = 2 if subset_mask is None else 3
    accuracy = np.zeros((n_rows, len(temperatures)))

    for i, t in enumerate(temperatures):
        tic = time.time()
        # The network must be rebuilt on every iteration so its parameters are
        # re-initialised; otherwise training would continue from the previous run.
        net = ThreeLayerConvNet(input_dim=(1, 28, 28), num_filters=28, filter_size=5, hidden_dim=50,
                                reg=0.001, weight_scale=1, dtype=np.float32)
        small_model = myModel(net, data,
                              num_epochs=num_epochs, batch_size=batch_size,
                              optimizer='adam',
                              optim_config={
                                  'learning_rate': learning_rate,
                              },
                              temperature=t, logit_distill=logits_teacher,
                              verbose=False)
        small_model.train()

        accuracy[0, i] = t
        accuracy[1, i] = small_model.best_val_acc

        if subset_mask is not None:
            accuracy[2, i] = small_model.check_accuracy(data['X_val'][subset_mask],
                                                        data['y_val'][subset_mask])
        toc = time.time()
        print('Temperature {}, Training acc {}, Validation acc {}, Execution time {}'.format(
            t, small_model.train_acc_history[-1], accuracy[1, i], toc - tic))

    # Write the table to disk
    if save_name is not None:
        name = save_name + '_temperature_search.txt'
    else:
        name = 'temperature_search.txt'

    with open(name, 'wb') as fp:
        pickle.dump(accuracy, fp)

In [27]:
searchDistill_temperature(data, num_epochs=1, batch_size=100,
                              learning_rate=1e-3, logits_teacher=logits_train)

Execution time for temperature 1.0 is 387.4820120334625
Execution time for temperature 1.5 is 404.292197227478
Execution time for temperature 2.0 is 412.21479296684265
Execution time for temperature 2.5 is 371.93810391426086
Execution time for temperature 3.0 is 366.4074969291687
Execution time for temperature 3.5 is 367.27303814888
Execution time for temperature 4.0 is 422.70680809020996
Execution time for temperature 4.5 is 395.60803604125977
Execution time for temperature 5.0 is 370.8509747982025
Execution time for temperature 5.5 is 372.39865708351135
Execution time for temperature 6.0 is 372.6793038845062
Execution time for temperature 6.5 is 372.3903729915619
Execution time for temperature 7.0 is 372.86075091362
Execution time for temperature 7.5 is 372.43779015541077
Execution time for temperature 8.0 is 373.5046968460083
Execution time for temperature 8.5 is 372.1003658771515
Execution time for temperature 9.0 is 372.278094291687
Execution time for temperature 9.5 is 374.258793

In [29]:
# NOTE(review): the search call above passes no save_name, so searchDistill_temperature
# writes 'temperature_search.txt' into the working directory -- confirm that this
# ResNet164/ copy is the file produced by that run.
T = unpickle('ResNet164/temperature_search.txt')

In [30]:
print(T)

[[ 1.      1.5     2.      2.5     3.      3.5     4.      4.5     5.
   5.5     6.      6.5     7.      7.5     8.      8.5     9.      9.5
  10.     10.5    11.     11.5    12.     12.5    13.     13.5    14.
  14.5    15.     15.5    16.     16.5    17.     17.5    18.     18.5
  19.     19.5    20.     20.5    21.     21.5    22.     22.5    23.
  23.5    24.     24.5    25.     25.5    26.     26.5    27.     27.5
  28.     28.5    29.     29.5    30.     30.5    31.     31.5    32.
  32.5    33.     33.5    34.     34.5    35.     35.5    36.     36.5
  37.     37.5    38.     38.5    39.     39.5    40.     40.5    41.
  41.5    42.     42.5    43.     43.5    44.     44.5    45.     45.5
  46.     46.5    47.     47.5    48.     48.5    49.     49.5   ]
 [ 0.977   0.9768  0.9747  0.9735  0.9774  0.9782  0.98    0.9797  0.9795
   0.9798  0.9794  0.9781  0.9805  0.9811  0.9788  0.9795  0.9804  0.9808
   0.98    0.9786  0.9819  0.9798  0.9792  0.9817  0.9803  0.9796  0.9804
   0.9

In [33]:
# Validation accuracy rises slightly up to about T = 4.0 and then stays roughly flat;
# print the temperature (row 0) at which the validation accuracy (row 1) is highest.
i = np.argmax(T[1,:])
print(T[0,i])

11.0


### Train distill model with dataset omitting a specific class

In [4]:
# Remove every training example of digit 3. The result is a shallow copy of `data`
# in which only the two training entries are replaced by their filtered versions.
mask_not_3 = data['y_train'] != 3
data_not_3 = dict(data,
                  X_train=data['X_train'][mask_not_3],
                  y_train=data['y_train'][mask_not_3])

In [16]:
# Load the ResNet164 teacher, which achieves 99.7% test accuracy
big_model = ResNet164()

big_model.compile()
# Load the pre-trained weights
big_model.load_weights('ResNet164/ResNet164.h5')

Instructions for updating:
Colocations handled automatically by placer.


In [6]:
# Get logits from teacher model for new dataset
logits_not_3 = big_model.predict(x_train[mask_not_3], verbose = 1)
print(logits_not_3[0])

[-6.2030087  -3.9134192  -6.317644    6.6796656  -2.957327   12.809154
 -3.4207807   0.18087192  2.3757007   0.04373608]


In [8]:
print(logits_not_3.shape)
with open('ResNet164/resnet164_logits_not_3.txt', 'wb') as fp:
    pickle.dump(logits_not_3, fp)

(44899, 10)


In [None]:
# To save time, load the ResNet164 logits that were pickled earlier instead of recomputing them
logits_not_3 = unpickle('ResNet164/resnet164_logits_not_3.txt')

In [9]:
searchDistill_temperature(data_not_3, num_epochs=1, batch_size=100,
                          learning_rate=1e-3, logits_teacher=logits_not_3, 
                          class_out=3, save_name='ResNet164/not_3')

Temperature 1.0, Validation accuracy 0.8715, Execution time 363.0232141017914
Temperature 1.5, Validation accuracy 0.8787, Execution time 335.13925313949585
Temperature 2.0, Validation accuracy 0.8755, Execution time 333.1119830608368
Temperature 2.5, Validation accuracy 0.8809, Execution time 329.9299228191376
Temperature 3.0, Validation accuracy 0.883, Execution time 331.1158969402313
Temperature 3.5, Validation accuracy 0.8834, Execution time 329.57569193840027
Temperature 4.0, Validation accuracy 0.8795, Execution time 329.90114998817444
Temperature 4.5, Validation accuracy 0.8811, Execution time 329.37685918807983
Temperature 5.0, Validation accuracy 0.8847, Execution time 330.04841232299805
Temperature 5.5, Validation accuracy 0.889, Execution time 332.824871301651
Temperature 6.0, Validation accuracy 0.8837, Execution time 332.21027302742004
Temperature 6.5, Validation accuracy 0.8883, Execution time 329.65273809432983
Temperature 7.0, Validation accuracy 0.8924, Execution time 

In [11]:
T = unpickle('ResNet164/not_3_temperature_search.txt')
print(T.T)

[[1.00000000e+00 8.71500000e-01 0.00000000e+00]
 [1.50000000e+00 8.78700000e-01 0.00000000e+00]
 [2.00000000e+00 8.75500000e-01 0.00000000e+00]
 [2.50000000e+00 8.80900000e-01 5.82524272e-03]
 [3.00000000e+00 8.83000000e-01 2.62135922e-02]
 [3.50000000e+00 8.83400000e-01 2.81553398e-02]
 [4.00000000e+00 8.79500000e-01 1.26213592e-02]
 [4.50000000e+00 8.81100000e-01 0.00000000e+00]
 [5.00000000e+00 8.84700000e-01 4.75728155e-02]
 [5.50000000e+00 8.89000000e-01 6.89320388e-02]
 [6.00000000e+00 8.83700000e-01 1.84466019e-02]
 [6.50000000e+00 8.88300000e-01 7.66990291e-02]
 [7.00000000e+00 8.92400000e-01 9.70873786e-02]
 [7.50000000e+00 8.81700000e-01 0.00000000e+00]
 [8.00000000e+00 9.03300000e-01 2.25242718e-01]
 [8.50000000e+00 8.88200000e-01 7.66990291e-02]
 [9.00000000e+00 8.88700000e-01 5.92233010e-02]
 [9.50000000e+00 8.81700000e-01 8.73786408e-03]
 [1.00000000e+01 8.86900000e-01 6.21359223e-02]
 [1.05000000e+01 8.86400000e-01 5.43689320e-02]
 [1.10000000e+01 8.92200000e-01 1.174757

In [13]:
# Keep only the training examples of digits 7 and 8. The result is a shallow copy of
# `data` in which only the two training entries are replaced by their filtered versions.
mask_7_8 = np.isin(data['y_train'], [7, 8])
data_7_8 = dict(data,
                X_train=data['X_train'][mask_7_8],
                y_train=data['y_train'][mask_7_8])

In [14]:
# Number of training examples that are a 7 or an 8
print(np.count_nonzero(mask_7_8))

10017


In [17]:
# Get logits from teacher model for new dataset
logits_7_8 = big_model.predict(x_train[mask_7_8], verbose = 1)
print(logits_7_8[0])
print(logits_7_8.shape)

[-1.1341337  0.8800306  2.634239  -0.8169588 -1.7799497 -5.7832465
 -9.071484  16.43823   -4.005032   3.3581958]
(10017, 10)


In [18]:
with open('ResNet164/resnet164_logits_7_8.txt', 'wb') as fp:
    pickle.dump(logits_7_8, fp)

In [20]:
searchDistill_temperature(data_7_8, num_epochs=1, batch_size=50,
                          learning_rate=1e-3, logits_teacher=logits_7_8, 
                          class_in=[7,8], save_name='ResNet164/7_8')

Temperature 1.0, Training acc 0.9959069581711091, Validation acc 0.2083, Execution time 86.46539497375488
Temperature 1.5, Training acc 0.9916142557651991, Validation acc 0.2076, Execution time 86.74065589904785
Temperature 2.0, Training acc 0.9949086552860138, Validation acc 0.2089, Execution time 85.94187068939209
Temperature 2.5, Training acc 0.9962064490366377, Validation acc 0.2126, Execution time 90.26660513877869
Temperature 3.0, Training acc 0.9962064490366377, Validation acc 0.2106, Execution time 86.58881402015686
Temperature 3.5, Training acc 0.99570729759409, Validation acc 0.2091, Execution time 88.85907030105591
Temperature 4.0, Training acc 0.9953079764400519, Validation acc 0.2085, Execution time 86.62404584884644
Temperature 4.5, Training acc 0.9944095038434662, Validation acc 0.2086, Execution time 87.35391306877136
Temperature 5.0, Training acc 0.9963062793251473, Validation acc 0.2087, Execution time 88.70967197418213
Temperature 5.5, Training acc 0.9951083158630328

Temperature 40.0, Training acc 0.9936108615353898, Validation acc 0.2098, Execution time 95.30957198143005
Temperature 40.5, Training acc 0.9956074673055805, Validation acc 0.2086, Execution time 87.16808700561523
Temperature 41.0, Training acc 0.9956074673055805, Validation acc 0.2109, Execution time 90.91860389709473
Temperature 41.5, Training acc 0.9956074673055805, Validation acc 0.2096, Execution time 90.25294804573059
Temperature 42.0, Training acc 0.9945093341319756, Validation acc 0.2092, Execution time 88.49704217910767
Temperature 42.5, Training acc 0.9950084855745233, Validation acc 0.2103, Execution time 104.67101621627808
Temperature 43.0, Training acc 0.9947089947089947, Validation acc 0.2129, Execution time 92.67923307418823
Temperature 43.5, Training acc 0.9952081461515424, Validation acc 0.211, Execution time 89.91246581077576
Temperature 44.0, Training acc 0.9962064490366377, Validation acc 0.2107, Execution time 94.59250903129578
Temperature 44.5, Training acc 0.9960

Search 'optimal' bias

In [8]:
# searchDistill_temperature was called with save_name='ResNet164/7_8', so the table is
# written to save_name + '_temperature_search.txt'; read it back from that same path.
T = unpickle('ResNet164/7_8_temperature_search.txt')
print(T.T)

[[1.00000000e+00 2.08300000e-01 0.00000000e+00]
 [1.50000000e+00 2.07600000e-01 0.00000000e+00]
 [2.00000000e+00 2.08900000e-01 0.00000000e+00]
 [2.50000000e+00 2.12600000e-01 5.18921655e-03]
 [3.00000000e+00 2.10600000e-01 2.15162638e-03]
 [3.50000000e+00 2.09100000e-01 0.00000000e+00]
 [4.00000000e+00 2.08500000e-01 0.00000000e+00]
 [4.50000000e+00 2.08600000e-01 0.00000000e+00]
 [5.00000000e+00 2.08700000e-01 0.00000000e+00]
 [5.50000000e+00 2.09100000e-01 1.26566257e-04]
 [6.00000000e+00 2.08900000e-01 0.00000000e+00]
 [6.50000000e+00 2.09000000e-01 2.53132515e-04]
 [7.00000000e+00 2.08600000e-01 0.00000000e+00]
 [7.50000000e+00 2.08800000e-01 1.26566257e-04]
 [8.00000000e+00 2.08600000e-01 0.00000000e+00]
 [8.50000000e+00 2.08900000e-01 0.00000000e+00]
 [9.00000000e+00 2.08900000e-01 0.00000000e+00]
 [9.50000000e+00 2.08800000e-01 0.00000000e+00]
 [1.00000000e+01 2.08600000e-01 0.00000000e+00]
 [1.05000000e+01 2.11900000e-01 3.41728895e-03]
 [1.10000000e+01 2.10400000e-01 1.898493

In [14]:
# Show the five columns with the highest value in row 2 (the class-subset accuracy
# written by searchDistill_temperature), best first; each output row is one column of T.
T[:,T[2,:].argsort()[-5:][::-1]].T

array([[3.95000000e+01, 2.14600000e-01, 6.83457790e-03],
       [2.95000000e+01, 2.14200000e-01, 6.70801164e-03],
       [3.75000000e+01, 2.13000000e-01, 5.94861410e-03],
       [4.75000000e+01, 2.13300000e-01, 5.31578281e-03],
       [3.60000000e+01, 2.12800000e-01, 5.18921655e-03]])

In [11]:
def search_bias(data, num_epochs, batch_size, learning_rate, 
                logits_teacher, temperature, 
                class_out=3, class_in=None, save_name='not_3'):
    """Train one distilled student, then sweep an offset on selected output biases.

    After training, the entries of the last-layer bias ``b3`` selected by ``class_out``
    (or ``class_in``) are set to ``original + offset`` for every offset in
    ``np.arange(-10.0, 10.1, 0.1)`` and the accuracy is re-evaluated each time. The
    resulting table has one column per offset and three rows:
        row 0 -- the offset,
        row 1 -- accuracy on the full validation set,
        row 2 -- accuracy on a subset of the validation set (see below).

    Args:
        data: dict that is handed to ``myModel``; 'X_val' and 'y_val' are also read here.
        num_epochs, batch_size, learning_rate: forwarded to ``myModel`` (Adam optimizer).
        logits_teacher: teacher logits, forwarded to ``myModel`` as ``logit_distill``.
        temperature: distillation temperature, forwarded to ``myModel``.
        class_out: label whose bias entry is swept; row 2 is then the accuracy on the
            validation samples with that label.
        class_in: iterable of labels whose bias entries are swept; row 2 is then the
            accuracy on validation samples whose label is NOT in ``class_in``. Takes
            precedence over ``class_out`` when both are supplied.
        save_name: optional prefix of the output file. The table is pickled to
            ``save_name + '_bias_search.txt'``, or to 'bias_search.txt' when None.

    Returns:
        None. The result is only written to disk.

    Raises:
        ValueError: if both ``class_out`` and ``class_in`` are None.
    """
    # Fail before the (slow) training instead of hitting an undefined name afterwards.
    if class_out is None and class_in is None:
        raise ValueError('search_bias needs class_out or class_in to select the bias entries')

    net = ThreeLayerConvNet(input_dim=(1, 28, 28),num_filters=28,filter_size=5,hidden_dim=50,
                    reg=0.001,weight_scale=1,dtype=np.float32)
    small_model = myModel(net, data,
                    num_epochs=num_epochs, batch_size=batch_size,
                    optimizer='adam',
                    optim_config={
                      'learning_rate': learning_rate,
                    },
                    temperature=temperature,logit_distill=logits_teacher,
                    verbose=False)
    small_model.train()

    bias_range = np.arange(-10.0,10.1,0.1)
    accuracy = np.zeros((3,len(bias_range)))

    X_val, y_val = data['X_val'], data['y_val']
    if class_out is not None:
        mask = y_val == class_out
        mask_bias = class_out
    if class_in is not None:
        # np.isin always yields a boolean array, even for an empty class_in
        mask = ~np.isin(y_val, class_in)
        mask_bias = class_in

    # The evaluated subset is the same for every offset, so slice it only once.
    X_subset, y_subset = X_val[mask], y_val[mask]
    n_subset = mask.sum()

    # Mutate the model's bias array in place, but always start from a saved copy:
    # this avoids accumulating rounding error from repeated `+= 0.1` steps and lets
    # the bias be restored exactly afterwards.
    b3 = small_model.model.params['b3']
    original_b3 = b3.copy()

    for i, offset in enumerate(bias_range):
        b3[mask_bias] = original_b3[mask_bias] + offset
        accuracy[0,i] = offset
        accuracy[1,i] = small_model.check_accuracy(X_val, y_val)
        accuracy[2,i] = small_model.check_accuracy(X_subset, y_subset, n_subset)

    # Return to the original bias. The previous code subtracted 10 from *every*
    # entry of b3, which shifted the biases that were never swept.
    b3[...] = original_b3

    # write to the file
    if save_name is not None: 
        name = save_name + '_bias_search.txt'
    else:
        name = 'bias_search.txt'

    with open(name, 'wb') as fp:
        pickle.dump(accuracy, fp)

In [5]:
# Remove every training example of digit 3. The result is a shallow copy of `data`
# in which only the two training entries are replaced by their filtered versions.
mask_not_3 = data['y_train'] != 3
data_not_3 = dict(data,
                  X_train=data['X_train'][mask_not_3],
                  y_train=data['y_train'][mask_not_3])

In [6]:
# To save time, load the ResNet164 logits that were pickled earlier instead of recomputing them
logits_not_3 = unpickle('ResNet164/resnet164_logits_not_3.txt')

In [12]:
search_bias(data_not_3, num_epochs=1, batch_size=100, learning_rate=1e-3, 
                logits_teacher=logits_not_3, temperature=17.0, 
                class_out=3, save_name='ResNet164/not_3')

In [13]:
bias = unpickle('ResNet164/not_3_bias_search.txt')
print(bias.T)

[[-1.00000000e+01  8.81300000e-01  0.00000000e+00]
 [-9.90000000e+00  8.81300000e-01  0.00000000e+00]
 [-9.80000000e+00  8.81300000e-01  0.00000000e+00]
 [-9.70000000e+00  8.81300000e-01  0.00000000e+00]
 [-9.60000000e+00  8.81300000e-01  0.00000000e+00]
 [-9.50000000e+00  8.81300000e-01  0.00000000e+00]
 [-9.40000000e+00  8.81300000e-01  0.00000000e+00]
 [-9.30000000e+00  8.81300000e-01  0.00000000e+00]
 [-9.20000000e+00  8.81300000e-01  0.00000000e+00]
 [-9.10000000e+00  8.81300000e-01  0.00000000e+00]
 [-9.00000000e+00  8.81300000e-01  0.00000000e+00]
 [-8.90000000e+00  8.81300000e-01  0.00000000e+00]
 [-8.80000000e+00  8.81300000e-01  0.00000000e+00]
 [-8.70000000e+00  8.81300000e-01  0.00000000e+00]
 [-8.60000000e+00  8.81300000e-01  0.00000000e+00]
 [-8.50000000e+00  8.81300000e-01  0.00000000e+00]
 [-8.40000000e+00  8.81300000e-01  0.00000000e+00]
 [-8.30000000e+00  8.81300000e-01  0.00000000e+00]
 [-8.20000000e+00  8.81300000e-01  0.00000000e+00]
 [-8.10000000e+00  8.81300000e-

In [14]:
# Keep only the training examples of digits 7 and 8. The result is a shallow copy of
# `data` in which only the two training entries are replaced by their filtered versions.
mask_7_8 = np.isin(data['y_train'], [7, 8])
data_7_8 = dict(data,
                X_train=data['X_train'][mask_7_8],
                y_train=data['y_train'][mask_7_8])

In [19]:
# To save time, load the ResNet164 logits for digits 7 and 8 that were pickled earlier.
# They were written to 'ResNet164/resnet164_logits_7_8.txt', so read from that same path.
logits_7_8 = unpickle('ResNet164/resnet164_logits_7_8.txt')

In [23]:
search_bias(data_7_8, num_epochs=1, batch_size=50, learning_rate=1e-3, 
                logits_teacher=logits_7_8, temperature=30.0, 
                class_out=None, class_in=[7,8], save_name='ResNet164/7_8')

In [15]:
# search_bias was called with save_name='ResNet164/7_8', so the table is written to
# save_name + '_bias_search.txt'; read it back from that same path.
bias = unpickle('ResNet164/7_8_bias_search.txt')
print(bias.T)

[[-1.00000000e+01  4.09100000e-01  2.73129984e-01]
 [-9.90000000e+00  4.07800000e-01  2.70851791e-01]
 [-9.80000000e+00  4.04700000e-01  2.66295406e-01]
 [-9.70000000e+00  4.03400000e-01  2.63637514e-01]
 [-9.60000000e+00  4.01100000e-01  2.59967093e-01]
 [-9.50000000e+00  3.98800000e-01  2.56549804e-01]
 [-9.40000000e+00  3.96700000e-01  2.53132515e-01]
 [-9.30000000e+00  3.95000000e-01  2.50474623e-01]
 [-9.20000000e+00  3.92400000e-01  2.46551069e-01]
 [-9.10000000e+00  3.89900000e-01  2.42880648e-01]
 [-9.00000000e+00  3.87500000e-01  2.39336793e-01]
 [-8.90000000e+00  3.85800000e-01  2.36425769e-01]
 [-8.80000000e+00  3.84000000e-01  2.33261612e-01]
 [-8.70000000e+00  3.82600000e-01  2.30730287e-01]
 [-8.60000000e+00  3.81000000e-01  2.28452095e-01]
 [-8.50000000e+00  3.78400000e-01  2.24528541e-01]
 [-8.40000000e+00  3.75700000e-01  2.20478420e-01]
 [-8.30000000e+00  3.72900000e-01  2.16554866e-01]
 [-8.20000000e+00  3.71000000e-01  2.13264144e-01]
 [-8.10000000e+00  3.67800000e-

In [16]:
# Show the five columns with the highest value in row 2 (the class-subset accuracy
# written by search_bias), best first; each output row is one column of `bias`.
bias[:,bias[2,:].argsort()[-5:][::-1]].T

array([[-10.        ,   0.4091    ,   0.27312998],
       [ -9.9       ,   0.4078    ,   0.27085179],
       [ -9.8       ,   0.4047    ,   0.26629541],
       [ -9.7       ,   0.4034    ,   0.26363751],
       [ -9.6       ,   0.4011    ,   0.25996709]])

# Test with VGG16 as teacher model

In [4]:
# Build the VGG16 teacher model
big_model = VGG16()
big_model.compile()
# Load the pre-trained weights
big_model.load_weights('VGG16/VGG16.h5')
# Drop the final softmax: build a Keras model that shares the teacher's input but
# outputs the second-to-last layer, so that predict() yields the logits
big_model_woSM = Model(inputs=big_model.model.input, outputs=big_model.model.layers[-2].output)

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.


In [5]:
# Predict to get logits
logits_train = big_model_woSM.predict(x_train, verbose = 1)



In [6]:
def SoftMax(s):
    """Softmax along axis 1 of a score matrix.

    Args:
        s: scores of shape (N, K), one row per sample.

    Returns:
        Array of shape (N, K) whose rows are non-negative and sum to 1.
    """
    # Shifting every row by its maximum does not change the result but keeps
    # exp() from overflowing when the scores are large.
    shifted = s - np.max(s, axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=1, keepdims=True)

In [7]:
print(logits_train[0])
y_pred_big = np.argmax(SoftMax(logits_train),axis=1)
y_true = np.argmax(y_train,axis=1)
print('Train accuracy of big model: {}'.format(np.mean(y_true==y_pred_big)))

[-3.1318357  -0.58660555 -6.1529765   7.2426195  -6.0088015  11.868018
 -0.74990106 -5.515557    0.960757    2.0794492 ]
Train accuracy of big model: 0.99914


In [8]:
print(logits_train.shape)
with open('VGG16/vgg16_logits_train.txt', 'wb') as fp:
    pickle.dump(logits_train, fp)

(50000, 10)


In [9]:
# Predict to get logits for test
logits_test = big_model_woSM.predict(x_test, verbose = 1)



In [10]:
print(logits_test[0])
y_pred_big = np.argmax(SoftMax(logits_test),axis=1)
y_true = np.argmax(y_test,axis=1)
# These predictions come from the test set, so label the metric as test accuracy
print('Test accuracy of big model: {}'.format(np.mean(y_true==y_pred_big)))

[-10.137868    7.378173    4.5927176   2.4612358  -3.1164281  -9.72687
  -9.742915   21.030071   -8.707618    4.8068333]
Train accuracy of big model: 0.9968


In [5]:
# To save time, load the VGG16 logits that were pickled earlier instead of recomputing them
logits_train = unpickle('VGG16/vgg16_logits_train.txt')

In [11]:
# Train small model with distilling
net = ThreeLayerConvNet(input_dim=(1, 28, 28),num_filters=28,filter_size=5,hidden_dim=50,
                        reg=0.001,weight_scale=1,dtype=np.float32)
small_model = myModel(net, data, 
                      num_epochs=1, batch_size=100,
                      optimizer='adam',
                      optim_config={
                          'learning_rate': 1e-3,},
                      temperature=5.0,logit_distill=logits_train,
                      verbose=True, print_every=100)
tic = time.time()
small_model.train()
toc = time.time()
print('Execution time: ',toc-tic)

(Iteration 1 / 500) loss: 31.080262
(Epoch 0 / 1) train acc: 0.191260; val_acc: 0.195300
(Iteration 101 / 500) loss: 11.306222
(Iteration 201 / 500) loss: 9.500901
(Iteration 301 / 500) loss: 9.689771
(Iteration 401 / 500) loss: 10.326211
(Epoch 1 / 1) train acc: 0.977080; val_acc: 0.975300
Execution time:  533.5676369667053


In [12]:
print('Test accuracy: {}'.format(small_model.check_accuracy(data['X_test'],data['y_test'])))

Test accuracy: 0.9765


In [6]:
# Temperature search with the VGG16 logits as teacher.
# NOTE(review): save_name is used as a plain prefix, so with 'VGG16/' the table is
# written to 'VGG16/_temperature_search.txt' (leading underscore in the file name).
searchDistill_temperature(data, num_epochs=1, batch_size=100,
                          learning_rate=1e-3, logits_teacher=logits_train, 
                          save_name='VGG16/')

Temperature 1.0, Training acc 0.96478, Validation acc 0.9625, Execution time 658.2517602443695
Temperature 1.5, Training acc 0.97676, Validation acc 0.9761, Execution time 502.8496437072754
Temperature 2.0, Training acc 0.9772, Validation acc 0.9777, Execution time 448.3176610469818
Temperature 2.5, Training acc 0.97134, Validation acc 0.9717, Execution time 459.4270541667938
Temperature 3.0, Training acc 0.98002, Validation acc 0.9794, Execution time 416.7173328399658
Temperature 3.5, Training acc 0.97548, Validation acc 0.9761, Execution time 391.3283898830414
Temperature 4.0, Training acc 0.97922, Validation acc 0.9777, Execution time 549.3627419471741
Temperature 4.5, Training acc 0.98122, Validation acc 0.9815, Execution time 539.9320690631866
Temperature 5.0, Training acc 0.97904, Validation acc 0.9775, Execution time 513.6396298408508
Temperature 5.5, Training acc 0.9791, Validation acc 0.9797, Execution time 464.8682780265808
Temperature 6.0, Training acc 0.9797, Validation acc

Temperature 44.0, Training acc 0.97028, Validation acc 0.9734, Execution time 480.2203960418701
Temperature 44.5, Training acc 0.9725, Validation acc 0.977, Execution time 489.51318287849426
Temperature 45.0, Training acc 0.97332, Validation acc 0.9754, Execution time 428.68076610565186
Temperature 45.5, Training acc 0.97074, Validation acc 0.9739, Execution time 489.9127428531647
Temperature 46.0, Training acc 0.97102, Validation acc 0.9746, Execution time 583.3919117450714
Temperature 46.5, Training acc 0.97208, Validation acc 0.9754, Execution time 535.7426490783691
Temperature 47.0, Training acc 0.97248, Validation acc 0.9747, Execution time 540.3633768558502
Temperature 47.5, Training acc 0.97332, Validation acc 0.9739, Execution time 586.9538519382477
Temperature 48.0, Training acc 0.97058, Validation acc 0.9722, Execution time 499.61859607696533
Temperature 48.5, Training acc 0.97408, Validation acc 0.9749, Execution time 464.3874228000641
Temperature 49.0, Training acc 0.97562,