In [12]:
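# Baseline CNN for MNIST, used as the starting point for pruning and quantization.
# Adapted from an AlexNet tutorial (the network defined below is a small 3-conv CNN, not AlexNet):
# source https://towardsdatascience.com/implementing-alexnet-cnn-architecture-using-tensorflow-2-0-and-keras-2113e090ad98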
from sklearn.model_selection import KFold

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import SGD
from tensorflow.python.framework import type_spec as type_spec_module
import os
import numpy as np
import time

from ode.tf_pruner import TfPruner
from ode.tf_quantizer import TfQuantizer


# load train and test dataset
def load_dataset(limit=1000):
    """Load MNIST, add a channel axis to the images and one-hot encode the labels.

    Args:
        limit: keep only the first ``limit`` train and test samples (default
            1000, which keeps the notebook fast). ``None`` keeps the full splits.

    Returns:
        ``(trainX, trainY, testX, testY)``. Images are reshaped to
        ``(n, 28, 28, 1)`` and still unscaled (see ``prep_pixels``). Labels are
        one-hot encoded.
    """
    (trainX, trainY), (testX, testY) = mnist.load_data()
    # Add an explicit single-channel axis to match Conv2D's input_shape=(28, 28, 1).
    trainX = trainX.reshape((trainX.shape[0], 28, 28, 1))
    testX = testX.reshape((testX.shape[0], 28, 28, 1))
    # One-hot encode on the full splits *before* truncating, so the encoding
    # width is derived from the whole label set rather than a small subset.
    trainY = to_categorical(trainY)
    testY = to_categorical(testY)

    # Truncate to keep experiments fast; slicing with None keeps every sample.
    trainX, trainY = trainX[:limit], trainY[:limit]
    testX, testY = testX[:limit], testY[:limit]

    print(f'trainX.shape: {trainX.shape}')
    return trainX, trainY, testX, testY


# scale pixels
def prep_pixels(train, test):
    """Return float32 copies of *train* and *test* divided by 255.

    For 8-bit pixel data this maps values from [0, 255] into [0, 1].
    The inputs are not modified.
    """
    def _scale(images):
        # astype() makes a float32 copy; dividing by a Python float keeps float32.
        return images.astype('float32') / 255.0

    return _scale(train), _scale(test)


def compile_model(model, learning_rate=0.01, momentum=0.9):
    """Compile *model* in place for multi-class classification.

    Uses SGD with momentum, categorical cross-entropy loss (labels are one-hot)
    and an accuracy metric. It is called for the freshly built network and
    again for the pruned network.

    Args:
        model: the Keras model to compile.
        learning_rate: SGD step size (default 0.01).
        momentum: SGD momentum (default 0.9).
    """
    # `lr` is a deprecated alias that newer Keras optimizers reject;
    # `learning_rate` is the supported keyword across TF2 releases.
    opt = SGD(learning_rate=learning_rate, momentum=momentum)
    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])


def define_model():
    """Build and compile the baseline MNIST CNN.

    Layout: three Conv2D(3x3, He-uniform) -> BatchNormalization -> ReLU stages
    with 32, 64 and 32 filters. A 2x2 max-pool follows the first and third
    stage. The head is Flatten -> Dense(100, ReLU) -> Dense(10, softmax).

    ReLU is added as a separate Activation layer rather than via
    ``Conv2D(activation=...)``. Presumably this lets TfPruner follow
    conv -> BN -> activation paths (its log lists such paths); confirm.
    """
    def conv_bn_relu(filters, **conv_extra):
        # One conv stage; extra kwargs (e.g. input_shape) go to the Conv2D only.
        return [
            keras.layers.Conv2D(filters, (3, 3), kernel_initializer='he_uniform', **conv_extra),
            keras.layers.BatchNormalization(),
            keras.layers.Activation('relu'),
        ]

    # Layers are constructed in the same order as they appear in the model.
    stack = (
        conv_bn_relu(32, input_shape=(28, 28, 1))
        + [keras.layers.MaxPooling2D((2, 2))]
        + conv_bn_relu(64)
        + conv_bn_relu(32)
        + [
            keras.layers.MaxPooling2D((2, 2)),
            keras.layers.Flatten(),
            keras.layers.Dense(100, activation='relu', kernel_initializer='he_uniform'),
            keras.layers.Dense(10, activation='softmax'),
        ]
    )
    model = keras.Sequential(stack)

    compile_model(model)
    return model


def current_milli_time():
    """Return the current wall-clock time (time.time()) in whole milliseconds."""
    # round() with no ndigits already returns an int.
    return round(time.time() * 1000)

# Load MNIST (truncated by load_dataset) and scale pixels before any training.
train_ds_X, train_ds_Y, test_ds_X, test_ds_Y = load_dataset()
train_ds_X, test_ds_X = prep_pixels(train_ds_X, test_ds_X)

# 5 shuffled folds with a fixed seed, so every kfold.split() call below yields the same folds.
kfold = KFold(n_splits=5, shuffle=True, random_state=1)

# One model instance is shared by all folds of the loop below, so training accumulates across folds.
model = define_model()

# Train the shared model on each fold and report its accuracy on that fold's eval slice.
# NOTE(review): the fold indices come from splitting train_ds_X but are also used
# to index test_ds_X. This only works because both sets were truncated to the same
# length; confirm this is intended rather than train_ds_X[test_ix].
for train_ix, test_ix in kfold.split(train_ds_X):
    trainX = train_ds_X[train_ix]
    trainY = train_ds_Y[train_ix]
    testX = test_ds_X[test_ix]
    testY = test_ds_Y[test_ix]

    history = model.fit(
        trainX,
        trainY,
        epochs=10,
        batch_size=32,
        validation_data=(testX, testY),
        verbose=0,
    )

    _, acc = model.evaluate(testX, testY, verbose=0)
    print(f'> {acc * 100.0:.3f}')

    # Remember the most recent fold's split; the latency benchmark below uses it.
    latest_trainX, latest_trainY = trainX, trainY
    latest_testX, latest_testY = testX, testY

# Mean single-image latency (ms) of the original model over the last fold's eval images.
# Each measurement is rounded to whole milliseconds by current_milli_time().
elapsed_ms = 0
for img in latest_testX:
    batch = np.expand_dims(img, 0)
    t1 = current_milli_time()
    prediction = model.predict(batch)
    elapsed_ms += current_milli_time() - t1
t2 = elapsed_ms / float(len(latest_testX))

# `acc` is the accuracy from the last cross-validation fold.
print(f'> Original Model Accuracy: {acc * 100.0:.3f}')
print(f'> Original Model Inference Time: {t2}')

model.summary()
model.save('mnist_base.h5')

# Structurally prune the trained model with pct=0.8. The captured summaries show
# conv2d_21 going 32 -> 7 filters and conv2d_22 going 64 -> 13.
pruner = TfPruner(model)
pruned_model = pruner.st_prune(pct=0.8)
pruned_model.summary()

# Recompile the pruned network, then fine-tune it on the same folds as the original.
compile_model(pruned_model)

for train_ix, test_ix in kfold.split(train_ds_X):
    trainX, trainY = train_ds_X[train_ix], train_ds_Y[train_ix]
    testX, testY = test_ds_X[test_ix], test_ds_Y[test_ix]

    history = pruned_model.fit(
        trainX, trainY,
        validation_data=(testX, testY),
        epochs=10, batch_size=32, verbose=0,
    )
    _, acc = pruned_model.evaluate(testX, testY, verbose=0)
    print(f'> {acc * 100.0:.3f}')

# Mean single-image latency (ms) of the *pruned* model, measured on the same
# images as the original-model benchmark (the last fold of the first CV loop).
# Bug fix: this previously timed `model` (the unpruned network), so the
# "Pruned Model Inference Time" line reported the original model's latency.
# Each measurement is rounded to whole milliseconds by current_milli_time().
t2 = 0.0
for img in latest_testX:
    batch = np.expand_dims(img, 0)
    t1 = current_milli_time()
    prediction = pruned_model.predict(batch)
    t2 += current_milli_time() - t1

t2 /= float(len(latest_testX))

# `acc` is the accuracy from the last fine-tuning fold of the pruned model.
print('> Pruned Model Accuracy: %.3f' % (acc * 100.0))
print('> Pruned Model Inference Time: {}'.format(t2))

pruned_model.save('mnist_pruned.h5')

# Post-training int8 quantization of the fine-tuned pruned model.
# `testX` (the last fine-tuning fold's eval slice) is passed as the third
# argument; presumably that is the representative/calibration data. Confirm
# against TfQuantizer.post_training_quantization.
q = TfQuantizer(pruned_model)
quantized_model, err = q.post_training_quantization(pruned_model, tf.int8, testX)
if not err:
    # NOTE(review): the captured output shows this call logging
    # "Permission denied: '/post_quant.tflite'" without raising. Confirm how
    # save_quantized_model resolves the path and whether it reports failure.
    q.save_quantized_model(quantized_model, 'post_quant.tflite')
else:
    print('Could not quantize the model...')
    # exit() is the site-module helper meant for the interactive shell (absent
    # under `python -S`). It also exits with status 0, hiding the failure, so
    # raise SystemExit with a non-zero code instead.
    raise SystemExit(1)

trainX.shape: (1000, 28, 28, 1)
> 86.500
> 94.500
> 95.500
> 94.500
> 96.000


INFO:ode.tf_struct_pruning_engine:(in)conv2d_21
	Shape: (None, 28, 28, 1) 
	dtype: <dtype: 'float32'> 
	name: conv2d_21_input:0
INFO:ode.tf_struct_pruning_engine:(out)conv2d_21
	Shape: (None, 26, 26, 32) 
	dtype: <dtype: 'float32'> 
	name: conv2d_21/BiasAdd:0
INFO:ode.tf_struct_pruning_engine:(in)batch_normalization_21
	Shape: (None, 26, 26, 32) 
	dtype: <dtype: 'float32'> 
	name: conv2d_21/BiasAdd:0
INFO:ode.tf_struct_pruning_engine:(out)batch_normalization_21
	Shape: (None, 26, 26, 32) 
	dtype: <dtype: 'float32'> 
	name: batch_normalization_21/cond/Identity:0
INFO:ode.tf_struct_pruning_engine:(in)activation_21
	Shape: (None, 26, 26, 32) 
	dtype: <dtype: 'float32'> 
	name: batch_normalization_21/cond/Identity:0
INFO:ode.tf_struct_pruning_engine:(out)activation_21
	Shape: (None, 26, 26, 32) 
	dtype: <dtype: 'float32'> 
	name: activation_21/Relu:0
INFO:ode.tf_struct_pruning_engine:(in)max_pooling2d_14
	Shape: (None, 26, 26, 32) 
	dtype: <dtype: 'float32'> 
	name: activation_21/Relu:0
IN

Original Inference Time: 22.155
Model: "sequential_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_21 (Conv2D)           (None, 26, 26, 32)        320       
_________________________________________________________________
batch_normalization_21 (Batc (None, 26, 26, 32)        128       
_________________________________________________________________
activation_21 (Activation)   (None, 26, 26, 32)        0         
_________________________________________________________________
max_pooling2d_14 (MaxPooling (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 11, 11, 64)        18496     
_________________________________________________________________
batch_normalization_22 (Batc (None, 11, 11, 64)        256       
_________________________________________________________________
activation_22 (Activat

INFO:ode.tf_struct_pruning_engine:conv2d_22 has 0 paths to one of ['Add', 'Concat']
INFO:ode.tf_struct_pruning_engine:Computing norm for conv2d_22 with shape=(64, 3, 3, 32)
INFO:ode.tf_struct_pruning_engine:conv2d_22 has the following n=5 min items: [(19.261688, 59), (19.616585, 43), (19.686619, 2), (19.815145, 52), (19.843084, 53)]
INFO:ode.tf_struct_pruning_engine:Paths: ['batch_normalization_22', 'activation_22', 'conv2d_23']
INFO:ode.tf_struct_pruning_engine:Cs: 51
INFO:ode.tf_struct_pruning_engine:conv2d_22 - Expected shape: (3, 3, 7, 13), weight shape: (3, 3, 7, 13)
INFO:ode.tf_struct_pruning_engine:Randy Marsh: (1, 13, 13, 7)
INFO:ode.tf_struct_pruning_engine:batch_normalization_22 - BatchNorm Weights Shape: 4
INFO:ode.tf_struct_pruning_engine:batch_normalization_22 - Randy Marsh: (1, 11, 11, 13)
INFO:ode.tf_struct_pruning_engine:conv2d_23 - Expected shape: (3, 3, 13, 32), weight shape: (3, 3, 13, 32)
INFO:ode.tf_struct_pruning_engine:conv2d_23 - Randy Marsh: (1, 11, 11, 13)
INF

Model: "sequential_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_21 (Conv2D)           (None, 26, 26, 7)         70        
_________________________________________________________________
batch_normalization_21 (Batc (None, 26, 26, 7)         28        
_________________________________________________________________
activation_21 (Activation)   multiple                  0         
_________________________________________________________________
max_pooling2d_14 (MaxPooling multiple                  0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 11, 11, 13)        832       
_________________________________________________________________
batch_normalization_22 (Batc (None, 11, 11, 13)        52        
_________________________________________________________________
activation_22 (Activation)   multiple                 

INFO:ode.tf_quantizer:Running post-training quantization...


> 95.500
> Pruned model acc = 95.500
INFO:tensorflow:Assets written to: C:\Users\lbath\AppData\Local\Temp\tmp8xykugct\assets


INFO:tensorflow:Assets written to: C:\Users\lbath\AppData\Local\Temp\tmp8xykugct\assets


INFO:tensorflow:Assets written to: C:\Users\lbath\AppData\Local\Temp\tmpfj36uby8\assets


INFO:tensorflow:Assets written to: C:\Users\lbath\AppData\Local\Temp\tmpfj36uby8\assets
INFO:ode.tf_quantizer:Temp orig model file saved to: C:\Users\lbath\AppData\Local\Temp\tmpde_4ryta.tflite
INFO:ode.tf_quantizer:Temp model file saved to: C:\Users\lbath\AppData\Local\Temp\tmpcu07flga.tflite
ERROR:ode.tf_quantizer:Could not save model due to [Errno 13] Permission denied: '/post_quant.tflite'.


Model: "sequential_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_21 (Conv2D)           (None, 26, 26, 7)         70        
_________________________________________________________________
batch_normalization_21 (Batc (None, 26, 26, 7)         28        
_________________________________________________________________
activation_21 (Activation)   multiple                  0         
_________________________________________________________________
max_pooling2d_14 (MaxPooling multiple                  0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 11, 11, 13)        832       
_________________________________________________________________
batch_normalization_22 (Batc (None, 11, 11, 13)        52        
_________________________________________________________________
activation_22 (Activation)   multiple                 