# Optimize baseline with quantization and pruning

In [1]:
import numpy as np
import pandas as pd
import keras
import tensorflow as tf
import cv2

import warnings
warnings.filterwarnings('ignore')

In [2]:
def tflite_predict(interp, sample, input_index, output_index):
    
    ''' lite model predict function'''
    
    interp.set_tensor(input_index, sample)
    interp.invoke()
    return interp.get_tensor(output_index)

## Code of pruning and opt operations

In [None]:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

batch_size = 16
epochs = 32
end_step = 28

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.5,
        final_sparsity=0.8,
        begin_step=0,
        end_step=end_step
    )
}

model = keras.models.load_model('''there was some model''')
model_prn = prune_low_magnitude(model, **pruning_params)

model_prn.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy']
)

logdir = tempfile.mkdtemp()

callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir)
]

model_prn.fit(
    imgs_train, y_imgs_train,
    batch_size=batch_size, epochs=epochs,
    callbacks=callbacks,
    validation_data=[imgs_test, y_imgs_test]
)

    Baseline test accuracy: 0.9300000071525574
    Pruned test accuracy: 0.9075000286102295


then convert model into tflite format

In [None]:
converter = tf.lite.TFLiteConverter.from_keras_model(export_model)
lite_model_prnd = converter.convert() # return a bite-file

# save pruned lite version of model
with open('tf_models/cnn_cifar_cars__prnd_lite/cnn_lite.tflite', 'wb') as f:
    f.write(lite_model_prnd)

In [None]:
converter = tf.lite.TFLiteConverter.from_saved_model('tf_models/cnn_cifar_cars__prnd') # base on keras model

def representative_dataset():
    for data in tf.data.Dataset.from_tensor_slices((imgs_train)).batch(1).take(500):
        yield [tf.dtypes.cast(data, tf.float32)]

    # quantization
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
converter.representative_dataset = representative_dataset

lq_model = converter.convert() # convert

# save
with open('tf_models/cnn_cifar_cars__prnd_lite/cnn_lite_quant.tflite', 'wb') as f:
    f.write(lq_model)

### load this model

In [5]:
interp_lite = tf.lite.Interpreter('tf_models/cnn_cifar_cars__prnd_lite/cnn_lite_quant.tflite')
interp_lite.allocate_tensors()

input_details_lq = interp_lite.get_input_details()
output_details_lq = interp_lite.get_output_details()
input_index_lq = input_details_lq[0]['index']
output_index_lq = output_details_lq[0]['index']

___

# New processing time

In [3]:
VIDEO_PATH = '../../another_datasets/car_det_vids/traffic01.mp4'

In [6]:
%%time
video = cv2.VideoCapture(VIDEO_PATH)
subs = cv2.createBackgroundSubtractorMOG2(history=100, varThreshold=240, detectShadows=10)

stop = 0
bxs = []
while True:
    
    ok, frame = video.read()
    if not ok:
        break
    mask = subs.apply(frame)
    
    '''find countours'''
    
    cs = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[1]
    dets = []
    cars_coors = []

    for contour in cs:
        area = cv2.contourArea(contour)

        if area > 700:
            x, y, w, h = cv2.boundingRect(contour)
            
            img = frame[y:y+h,x:x+w][...,::-1]
            detd_imgs.append(img)
            try:
                img = cv2.resize(img, (32,32), interpolation=cv2.INTER_AREA)

                # lite-quant model
                pred = tflite_predict(interp_lite, img.reshape(-1,32,32,3).astype('float32'),
                                      input_index=input_index_lq, output_index=output_index_lq)
                if pred > .1:
                    stop += 1
                    cars_coors.append([x, y, w, h])
                    
            except:
                pass

            if w < 300 and h < 300:
                dets.append([x, y, w, h])
    
    '''print mask area > 700 rectangle'''
    for box in dets:
        x, y, w, h = box
        bxs.append(box)
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 255), 1)
        
    '''print model detected rectangle'''
    for car in cars_coors:
        x, y, w, h = car
        bxs.append(car)
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 0, 255), 2)
    

    
    cv2.imshow('video', frame)

    if cv2.waitKey(1) & 0xFF == ord(' '):
        break
    
video.release()
cv2.destroyAllWindows()

Wall time: 22.5 s


### compare with old model

In [7]:
old_model = keras.models.load_model('tf_models/cnn_cifar_cars')

old_model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)

In [8]:
%%time
video = cv2.VideoCapture(VIDEO_PATH)
subs = cv2.createBackgroundSubtractorMOG2(history=100, varThreshold=240, detectShadows=10)

stop = 0
bxs = []
while True:
    
    ok, frame = video.read()
    if not ok:
        break
    mask = subs.apply(frame)
    
    '''find countours'''
    
    cs = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[1]
    dets = []
    cars_coors = []

    for contour in cs:
        area = cv2.contourArea(contour)

        if area > 700:
            x, y, w, h = cv2.boundingRect(contour)
            
            img = frame[y:y+h,x:x+w][...,::-1]
            detd_imgs.append(img)
            try:
                img = cv2.resize(img, (32,32), interpolation=cv2.INTER_AREA)

                # keras model
                pred = old_model.predict(img.reshape(-1,32,32,3).astype('float32'))
                
                if pred > .1:
                    stop += 1
                    cars_coors.append([x, y, w, h])
                    
            except:
                pass

            if w < 300 and h < 300:
                dets.append([x, y, w, h])
    
    '''print mask area > 700 rectangle'''
    for box in dets:
        x, y, w, h = box
        bxs.append(box)
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 255), 1)
        
    '''print model detected rectangle'''
    for car in cars_coors:
        x, y, w, h = car
        bxs.append(car)
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 0, 255), 2)
    

    
    cv2.imshow('video', frame)

    if cv2.waitKey(1) & 0xFF == ord(' '):
        break
    
video.release()
cv2.destroyAllWindows()

Wall time: 1min 16s
