In [None]:
# Install xxd if it is not available
! apt-get -qq install xxd
#Clone the repository
!git clone --recursive https://github.com/mlcommons/tiny.git
%cd tiny/benchmark/training/image_classification

In [2]:
%%capture
# Install the required dependencies to run the training
!pip install -r requirements.txt

In [None]:
# Download training dataset
!wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
!tar -xvf cifar-10-python.tar.gz

In [None]:
#Train the model
!python train.py

In [None]:
#Test the model 
!python test.py

In [None]:
# Convert the model to TFlite with quantization
import tensorflow as tf
import numpy as np
import train
import keras_model

# Load the trained Keras model whose base file name is supplied by keras_model.
model_name = keras_model.get_quant_model_name()
tfmodel_path = ''.join(['trained_models/', model_name, '.h5'])
tfmodel = tf.keras.models.load_model(tfmodel_path)
# Directory produced by extracting the CIFAR-10 archive; passed to
# train.load_cifar_10_data by the calibration generator.
cifar_10_dir = 'cifar-10-batches-py'


def representative_dataset_generator():
    """Yield calibration inputs for post-training int8 quantization.

    Loads CIFAR-10 from the module-level ``cifar_10_dir`` via
    ``train.load_cifar_10_data`` (which returns seven values). Then, for every
    index stored in ``calibration_samples_idxs.npy``, yields a one-element list
    holding that test image. The image is copied into a float32 array and given
    a leading axis of size 1.
    """
    # Only the test images (4th returned value) are used for calibration.
    _, _, _, test_images, _, _, _ = train.load_cifar_10_data(cifar_10_dir)
    calibration_indices = np.load('calibration_samples_idxs.npy')
    for idx in calibration_indices:
        # np.array copies; [np.newaxis, ...] adds the size-1 batch axis in front.
        batch = np.array(test_images[idx], dtype=np.float32)[np.newaxis, ...]
        yield [batch]

# Plain conversion with no optimizations set, saved as a reference model.
converter = tf.lite.TFLiteConverter.from_keras_model(tfmodel)
tflite_model = converter.convert()
# `with` closes (and flushes) the file deterministically instead of leaking the handle
# returned by a bare open(...).write(...).
with open('trained_models/' + model_name + '.tflite', 'wb') as tflite_file:
    tflite_file.write(tflite_model)

# Reuse the same converter for full-integer post-training quantization:
# - only int8 builtin ops are allowed;
# - calibration samples come from representative_dataset_generator;
# - the model's input and output tensors are int8 too.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = representative_dataset_generator
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_quant_model = converter.convert()
with open('trained_models/' + model_name + '_quant.tflite', 'wb') as tflite_quant_file:
    tflite_quant_file.write(tflite_quant_model)


In [None]:
#Test quantized model quality
!python tflite_test.py

In [None]:
#Generate .cc and .h file to be used for hardware implementation
import keras_model
model_name = keras_model.get_quant_model_name()

!xxd -i /content/tiny/benchmark/training/image_classification/trained_models/{model_name}_quant.tflite > /content/network_model_data_tmp.cc
!wget -O /content/convert_model_cc.py https://www.dropbox.com/s/1q8n4jm9fk4gzzf/convert_model_cc.py?dl=0
%cd /content
!python /content/convert_model_cc.py --network="resnet" --application="image_classify"
!rm -rf /content/network_model_data_tmp.cc