#### Libraries

In [1]:
import os
import numpy as np
from glob import glob
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Conv2D, Activation, BatchNormalization, Lambda, concatenate, Add
from tensorflow.keras.layers import UpSampling2D, SeparableConv2D, Input, Concatenate, MaxPooling2D, Conv2DTranspose
from tensorflow.keras.layers import Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import Sequential
from tensorflow.keras.layers.experimental import preprocessing
from skimage.metrics import structural_similarity
from tensorflow import keras
#Check TensorFlow version:
print("TensorFlow Version: ", tf.__version__)
import tensorflow_model_optimization as tfmot
#Check if #### Process DatasetGPU is being used:
import tensorflow as tf
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

TensorFlow Version:  2.6.0


ModuleNotFoundError: No module named 'keras'

#### Process Dataset

In [6]:
IMG_SIZE = 256
path_read_scene_and_flares = "/PATH/Merged_images/"
path_read_scenes = "/PATH/Merged_images_saturated/"
EPOCHS = 300
BATCH = 12
LR = 1e-3

def load_img(path_img, path_target, split):
    #Obtain all the file paths for the input images and output targets. 
    images = sorted(glob(os.path.join(path_img, "*")))
    target = sorted(glob(os.path.join(path_target, "*")))
    #Randomly select 10% of the entire Dataset as Validation data.  
    train_x, valid_x = train_test_split(images, test_size=int(split * len(images)), random_state=42)
    train_y, valid_y = train_test_split(target, test_size=int(split * len(images)), random_state=42)
    #Randomly select 10% of the entire Dataset as Testing data.
    train_x, test_x = train_test_split(train_x, test_size=int(split * len(images)), random_state=42)
    train_y, test_y = train_test_split(train_y, test_size=int(split * len(images)), random_state=42)
    return (train_x, train_y), (valid_x, valid_y), (test_x, test_y)

def read_image(path):
    #Get path from file image.
    path = path.decode()
    #Read image from path using OpenCV.
    img = cv2.imread(path)
    #Resize image to 255x255.
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    #Normalize image.
    img = img/255.0
    return img

def read_target(path):
    #Get path from file target.
    path = path.decode()
    #Read image from path as greyscale using OpenCV.
    target = cv2.imread(path)
    #Resize target to 255x255.
    target = cv2.resize(target, (IMG_SIZE, IMG_SIZE))
    #Define target at floating point.
    target = target/255.0
    return target

def call_convert(img, msk):
    def _convert(img, msk):
        #Call data processing functions.
        img = read_image(img)
        msk = read_target(msk)
        return img, msk
        
    img, msk = tf.numpy_function(_convert, [img, msk], [tf.float64, tf.float64])
    img.set_shape([IMG_SIZE, IMG_SIZE, 3])
    msk.set_shape([IMG_SIZE, IMG_SIZE, 3])

    return img, msk

def parse_dataset(img, msk, BATCH):
    
    data_set = tf.data.Dataset.from_tensor_slices((img, msk))
    data_set = data_set.map(call_convert)
    data_set = (data_set
    .shuffle(BATCH*100)
    .batch(BATCH)
    .map(lambda i, j: (i, j), num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE))
    data_set = data_set.repeat()
    return data_set

(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_img(path_read_scene_and_flares, path_read_scenes, split=0.1)

print("Training data: ", len(train_x))
print("Validation data: ", len(valid_x))
print("Testing data: ", len(test_x))

Training data:  25426
Validation data:  3178
Testing data:  3178


**Quantization Aware Training**

In [8]:
EPOCHS = 5
BATCH = 12
LR = 1e-4

def SSIMLoss(y_true, y_pred):
    SSIMLoss = 1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0))
    return SSIMLoss

test_model = tf.keras.models.load_model("/PATH/FlareNet_xxx.h5", custom_objects={'SSIMLoss':SSIMLoss})

quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(test_model)
opt = tf.keras.optimizers.Nadam(LR)
metrics = [SSIMLoss]
q_aware_model.compile(loss=SSIMLoss, optimizer = opt, metrics=metrics)
q_aware_model.summary()

train_steps = len(train_x)//BATCH
if len(train_x) % BATCH != 0: train_steps += 1
valid_steps = len(valid_x)//BATCH
if len(valid_x) % BATCH != 0: valid_steps += 1

train_dataset = parse_dataset(train_x, train_y, BATCH)
valid_dataset = parse_dataset(valid_x, valid_y, BATCH)

#Quantization Aware Training:
history = q_aware_model.fit(train_dataset, validation_data=valid_dataset, epochs=EPOCHS, steps_per_epoch=train_steps, validation_steps=valid_steps)



Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
quantize_layer (QuantizeLayer)  (None, 256, 256, 3)  3           input_5[0][0]                    
__________________________________________________________________________________________________
quant_inputs (QuantizeWrapperV2 (None, 256, 256, 16) 483         quantize_layer[0][0]             
__________________________________________________________________________________________________
quant_max_pooling2d_16 (Quantiz (None, 128, 128, 16) 1           quant_inputs[0][0]               
____________________________________________________________________________________________

#### Convert Model to TFLite format to export.

In [9]:
#Convert to TFLite
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()

# Save the model.
path = "/PATH/FlareNet_quant_xxx.tflite"
with open(path, 'wb') as f:
    f.write(quantized_tflite_model)



**Evaluate Performance of Quantized Model**

In [None]:
import time
from scipy import stats
mae_predict = 0
mse_predict = 0
mae_input = 0
mse_input = 0
ssmi_input_vs_original = 0
ssmi_input_vs_predicted = 0

num_test = len(test_x)

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

def read_image(path):
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
    x = cv2.resize(x, (IMG_SIZE, IMG_SIZE))
    x = x/255.0
    return x

def read_target(path):
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
    x = cv2.resize(x, (IMG_SIZE, IMG_SIZE))
    x = x/255.0
    return x

for i, (x, y) in enumerate(zip(test_x[:num_test], test_y[:num_test])):

    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    
    x = read_image(x)
    y = read_target(y)
    
    test_image = np.expand_dims(x, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)
    #Run inference.
    interpreter.invoke()
    y_pred = interpreter.get_tensor(output_index)
    y_pred = np.squeeze(y_pred)
    
    h, w, _ = x.shape
    white_line = np.ones((h, 20, 3))
    all_images = [x, white_line, y, white_line, y_pred]

    (score, diff) = structural_similarity(y, x, full=True, multichannel=True)
    diff = (diff * 255).astype("uint8")
    ssmi_input_vs_original = ssmi_input_vs_original + score
    #print("SSIM Original vs Input: {}".format(score))

    (score, diff) = structural_similarity(y, y_pred, full=True, multichannel=True)
    diff = (diff * 255).astype("uint8")
    ssmi_input_vs_predicted = ssmi_input_vs_predicted + score
    #print("SSIM Original vs Cleaned: {}".format(score))
    
    mae_predict = mae_predict + np.mean(np.abs(y_pred - y))
    mse_predict = mse_predict + np.mean((y_pred - y) ** 2)
    
    mae_input = mae_input + np.mean(np.abs(x - y))
    mse_input = mse_input + np.mean((x - y) ** 2)
    
    #image = np.concatenate(all_images, axis=1)
    #fig = plt.figure(figsize=(12, 12))
    #a = fig.add_subplot(1, 1, 1)
    #imgplot = plt.imshow(image)

print("Num Test:", num_test)

print("Average SSIM Input vs Original:", ssmi_input_vs_original/num_test)
print("Average SSIM Input vs Predicted:", ssmi_input_vs_predicted/num_test)

print("Average MAE Input vs Original:", mae_input/num_test)
print("Average MAE Input vs Predicted:", mae_predict/num_test)

print("Average MSE Input vs Original", mse_input/num_test)
print("Average MSE Input vs Predicted", mse_predict/num_test)

  (score, diff) = structural_similarity(y, x, full=True, multichannel=True)
  (score, diff) = structural_similarity(y, y_pred, full=True, multichannel=True)


**Difference in Memory Consumption between Original and Quantized Model**

In [12]:
import tempfile
# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(test_model)
float_tflite_model = float_converter.convert()

# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
    f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
    f.write(float_tflite_model)

print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))

INFO:tensorflow:Assets written to: C:\Users\David\AppData\Local\Temp\tmp6hjicrri\assets


INFO:tensorflow:Assets written to: C:\Users\David\AppData\Local\Temp\tmp6hjicrri\assets


Float model in Mb: 0.3625679016113281
Quantized model in Mb: 0.11336517333984375
