In [None]:
# suppress unnecessary warnings
import warnings
warnings.filterwarnings('ignore')

import os

# Limit TensorFlow to the 0-th GPU (use "-1" to run on the CPU only).
# CUDA reads CUDA_VISIBLE_DEVICES only once, when it is initialized, so the variable is set
# BEFORE importing TensorFlow and the project modules below (any of which could trigger
# CUDA initialization at import time) to make sure the restriction actually takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf

import scripts.mnasnet.combined_tools as ct
from scripts import helpers as hp
from scripts.thresholds import eval_thresholds
from scripts.data.data_generator import DataGenerator
from scripts.trainer.trainer import Trainer

## Prepare model data and environment

1. Set `weights_path` to the `*.pickle` file containing the weights of the MNasNet model you want to quantize.
2. Set `output_dir_name` to the output folder where the results produced by this notebook will be stored.

The input image size is retrieved automatically from the model name (the last number in the name is used). If the size cannot be retrieved, either specify the `input_size` parameter explicitly or the default input size of 224px will be used.

In [None]:
weights_path = "../storage/mnasnet_original/mnasnet_1.0_128_weights.pickle"
output_dir_name = "../storage/mnasnet_processed/"

# Derive the input shape and all output artifact locations (checkpoint folders, thresholds
# file, fake-quantized .pb) for this model. Per this notebook's instructions, input_size=None
# asks the helper to infer the input resolution from the model file name; pass an int to override.
(INPUT_SHAPE,
 ckpt,
 best_ckpt,
 thresholds_path,
 fakequant_pb) = ct.prepare_mnasnet_environment(weights_path, output_dir_name, input_size=None)

## Prepare dataset

`train_labels` and `val_labels` are paths to the files containing lists in the following format:
```
rel/path/to_img_1.JPEG label_1
rel/path/to_img_2.JPEG label_2
...
```
Each image path listed in these files must be relative to the `train_images` or `val_images` folder, respectively.

In [None]:
# Dataset locations -- replace these placeholders with real paths before running.
train_images = "/path/to/train/images/folder/"
train_labels = "/path/to/train/data/list.txt"
val_images = "/path/to/validation/images/folder/"
val_labels = "/path/to/validation/data/list.txt"

# ------------------------------------------------------------------------------------------------
# Verify up front that every dataset location exists
# (NOTE(review): presumably reports/raises on a missing path -- confirm in hp.check_paths_exist).
hp.check_paths_exist(train_images,
                     train_labels,
                     val_images,
                     val_labels)

### Create Validation and Training data generators

In [None]:
# Both generators use the same preprocessing function and the same image size, taken from
# INPUT_SHAPE[1] (presumably the height of an NHWC input shape -- TODO confirm).
image_size = INPUT_SHAPE[1]
train_gen = DataGenerator(train_images, train_labels, hp.googlenet_preprocess, imsize=image_size)
valid_gen = DataGenerator(val_images, val_labels, hp.googlenet_preprocess, imsize=image_size)

### Calibration data generator

The MNasNet calibrator uses fewer batched images than training and does not use the labels.
We wrap the training data generator in the `cal_data_generator` function, which adapts the `train_gen` output (dropping the labels) before it is fed to the calibrator.

In [None]:
number_of_calibration_batches = 1
imgs_per_calibration_batch = 100


def cal_data_generator():
    """Yield image batches from ``train_gen`` for threshold calibration.

    Each item produced by ``train_gen.generate_batches`` is an ``(images, labels)`` pair;
    calibration does not use labels, so only the images are yielded.
    """
    batches = train_gen.generate_batches(imgs_per_calibration_batch, number_of_calibration_batches)
    for batch in batches:
        images, _labels = batch
        yield images

## Calculate initial quantization thresholds from the original model

In [None]:
%%time

WEIGHTS = hp.load_pickle(weights_path)

float_model = ct.create_float_model(ct.create_input_node(INPUT_SHAPE), WEIGHTS)

with tf.Session(graph=float_model.graph) as sess:
    THRESHOLDS = eval_thresholds(sess, float_model.input_node, float_model.reference_nodes, cal_data_generator)

## Train

In [None]:
train_input_node = ct.create_input_node(INPUT_SHAPE)
train_graph = train_input_node.graph

# Float and quantized variants of the network, both built on the same input node
# and initialized from the calibrated THRESHOLDS.
float_model, quant_model = ct.create_adjustable_model(train_input_node, WEIGHTS, THRESHOLDS)

train_configuration_data = hp.load_json("settings_config/train.json")

# Override the checkpoint folders from the configuration file, for flexibility only: this way the
# output folders change automatically for different mnasnet models.
# Remove the following statement to use the ckpt paths specified by the configuration file.
train_configuration_data.update(save_dir=ckpt, best_ckpt_dir=best_ckpt)

# Start from an empty checkpoint folder.
hp.clear_dir(ckpt)

my_trainer = Trainer(train_graph,
                     train_gen,
                     valid_gen,
                     train_input_node,
                     float_model.output_node,
                     quant_model.output_node,
                     **train_configuration_data)

with tf.Session(graph=train_graph) as sess, sess.graph.as_default():
    sess.run(quant_model.initializer)

    # Baseline accuracy of the original model
    # (NOTE(review): the False argument presumably selects the float path -- confirm in Trainer.validate).
    print("Check accuracy of the original model ...")
    original_top1, _ = my_trainer.validate(sess, False)
    print("Original top 1:", original_top1)

    # Accuracy of the quantized model before its thresholds are trained.
    print("Check accuracy  of the non-trained quantized model ...")
    initial_top1, _ = my_trainer.validate(sess)
    print("Initial top 1:", initial_top1)

    print("Training thresholds of the quantized model ...")
    my_trainer.train(sess)

    # Fetch the trained thresholds and save them to `thresholds_path`;
    # the next cell also reuses them from memory.
    adjusted_thresholds = sess.run(quant_model.adjusted_thresholds)
    hp.save_pickle(adjusted_thresholds, thresholds_path)

## Save model with fakequant nodes based on adjusted quantization thresholds

In [None]:
# Rebuild the network with fake-quantization nodes driven by the trained thresholds
# and save its graph to `fakequant_pb`.
fakequant_input = ct.create_input_node(INPUT_SHAPE)
fakequant_model = ct.create_fakequant_model(fakequant_input, WEIGHTS, adjusted_thresholds)
hp.save_pb(fakequant_model.graph, fakequant_pb)