In [1]:
import tensorflow as tf
from tensorflow.keras.models import load_model,Model
from tensorflow.keras.layers import Input
import glob
import os
import numpy as np
import skimage.io as io

In [2]:
# Load the trained Keras model. A plain relative path resolves on every OS;
# the previous '.\\model.h5' only worked on Windows (elsewhere the backslash
# becomes part of the file name).
model = load_model('model.h5')
config = model.get_config()
# model.input_shape reads the input shape directly instead of relying on the
# layer/key layout of get_config(), which differs between Keras versions.
print(model.input_shape)

(None, 256, 256, 1)


In [3]:
def adjustData(img,mask,flag_multi_class,num_class):
    """Normalise an image/mask pair and convert the mask into a training target.

    Parameters
    ----------
    img : np.ndarray
        Image array; presumably raw 0..255 values or already in [0, 1].
    mask : np.ndarray
        Mask with a trailing channel axis: either a single mask (h, w, c)
        or a batch (b, h, w, c).
    flag_multi_class : bool
        True  -> one-hot encode ``mask`` into ``num_class`` channels and
                 flatten its spatial axes.
        False -> binarise ``mask`` at 0.5.
    num_class : int
        Number of classes for the one-hot encoding (multi-class mode only).

    Returns
    -------
    (img, mask) : tuple of np.ndarray
        Multi-class: ``img / 255`` and a one-hot mask of shape
        (h*w, num_class) for a single mask or (b, h*w, num_class) for a batch.
        Binary: ``img`` divided by 255 if its max exceeds 1, and a new float
        mask containing only 0.0 and 1.0.
    """
    if flag_multi_class:
        img = img / 255
        # Drop the trailing channel axis: (b, h, w, c) -> (b, h, w) or (h, w, c) -> (h, w).
        mask = mask[:, :, :, 0] if mask.ndim == 4 else mask[:, :, 0]
        one_hot = np.zeros(mask.shape + (num_class,))
        for i in range(num_class):
            # Set channel i wherever the mask pixel holds class i.
            one_hot[mask == i, i] = 1
        # Flatten the spatial axes. Dispatch on ndim: the old code tested
        # flag_multi_class here (always True in this branch), so a 3-D
        # per-image mask indexed one_hot.shape[3] and raised IndexError.
        if one_hot.ndim == 4:
            mask = one_hot.reshape(one_hot.shape[0], one_hot.shape[1] * one_hot.shape[2], one_hot.shape[3])
        else:
            mask = one_hot.reshape(one_hot.shape[0] * one_hot.shape[1], one_hot.shape[2])
        return (img, mask)

    # Scale image and mask independently. Gating the mask on the image's range
    # wiped out 0/1 label masks (divided by 255, then thresholded to zero) and
    # left 0..255 masks un-binarised when the image was already normalised.
    if np.max(img) > 1:
        img = img / 255
    if np.max(mask) > 1:
        mask = mask / 255
    # np.where returns a new array, so the caller's mask is never modified in place.
    mask = np.where(mask > 0.5, 1.0, 0.0)
    return (img, mask)


def genTrainNpy(image_path,mask_path,flag_multi_class = False,num_class = 2,image_prefix = "image",mask_prefix = "mask",image_as_gray = True,mask_as_gray = True):
    """Load every ``<image_prefix>*.png`` in ``image_path`` together with its mask.

    For an image file named ``<image_prefix><suffix>`` the mask is read from
    ``<mask_path>/<mask_prefix><suffix>``. When the ``*_as_gray`` flag is set,
    a trailing channel axis is appended to the loaded array. Each pair is then
    passed through ``adjustData`` and the results are stacked with ``np.array``.

    Returns
    -------
    (image_arr, mask_arr) : tuple of np.ndarray
        One entry per matched image file, in sorted file-name order.
    """
    # Sort so the sample order (and anything that takes the first N samples,
    # e.g. a quantization calibration set) does not depend on filesystem order.
    image_files = sorted(glob.glob(os.path.join(image_path, "%s*.png" % image_prefix)))
    image_arr = []
    mask_arr = []
    for image_file in image_files:
        img = io.imread(image_file, as_gray=image_as_gray)
        if image_as_gray:
            img = np.reshape(img, img.shape + (1,))  # add a trailing channel axis
        # Derive the mask path from the basename only. A str.replace over the
        # full path fails when glob joins with a different separator than
        # image_path uses (e.g. "dir\\image_0.png" on Windows), and it also
        # rewrites any directory component that happens to contain image_prefix.
        suffix = os.path.basename(image_file)[len(image_prefix):]
        mask_file = os.path.join(mask_path, mask_prefix + suffix)
        mask = io.imread(mask_file, as_gray=mask_as_gray)
        if mask_as_gray:
            mask = np.reshape(mask, mask.shape + (1,))
        img, mask = adjustData(img, mask, flag_multi_class, num_class)
        image_arr.append(img)
        mask_arr.append(mask)
    return np.array(image_arr), np.array(mask_arr)


In [4]:
# Load the augmented training pairs (images and masks share one folder) and
# cast both arrays to float32 before they are fed to the TFLite converter.
imgs_train, imgs_mask_train = (
    arr.astype('float32')
    for arr in genTrainNpy("data/membrane/train/aug/", "data/membrane/train/aug/")
)

In [5]:
def representative_dataset():
    """Yield up to 30 single-sample batches of ``imgs_train`` for int8 calibration.

    Each item is a one-element list wrapping a tensor of batch size 1.
    """
    calibration_batches = tf.data.Dataset.from_tensor_slices(imgs_train).batch(1).take(30)
    for batch in calibration_batches:
        yield [batch]

In [6]:
# Post-training full-integer quantization of the Keras model, with int8
# input and output tensors.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# NOTE(review): fixed_model is built but never handed to the converter (which
# already wraps `model` above), so it has no effect on the written .tflite file.
# Input((256,256,1)) also leaves the batch dimension as None. Kept so the
# names stay defined for any later cell.
fixed_input=Input((256,256,1))
fixed_model=Model(fixed_input,model(fixed_input))
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Samples used to calibrate activation quantization ranges.
converter.representative_dataset = representative_dataset
# Integer-only builtin ops: conversion errors out rather than keeping float ops.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8  # or tf.uint8
converter.inference_output_type = tf.int8  # or tf.uint8
tflite_quant_model = converter.convert()
# Context manager closes (and flushes) the file deterministically; the old
# bare open(...).write(...) left the handle for the garbage collector.
with open("model.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_quant_model)
len(tflite_quant_model)  # bytes written; same value the bare write() call displayed



INFO:tensorflow:Assets written to: C:\Users\adikr\AppData\Local\Temp\tmprgim6jda\assets


INFO:tensorflow:Assets written to: C:\Users\adikr\AppData\Local\Temp\tmprgim6jda\assets


1914000

In [7]:
# Build an interpreter directly from the in-memory flatbuffer and fetch the
# metadata (name, shape, dtype, quantization) of its input and output tensors.
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

In [8]:
# Show the input tensor metadata. Under TFLite's quantization scheme,
# real = scale * (q - zero_point), so float inputs need quantizing with this
# (scale, zero_point) pair before being set on the interpreter.
input_details

[{'name': 'serving_default_input_2:0',
  'index': 0,
  'shape': array([  1, 256, 256,   1]),
  'shape_signature': array([ -1, 256, 256,   1]),
  'dtype': numpy.int8,
  'quantization': (0.003921568859368563, -128),
  'quantization_parameters': {'scales': array([0.00392157], dtype=float32),
   'zero_points': array([-128]),
   'quantized_dimension': 0},
  'sparsity_parameters': {}}]

In [9]:
# Show the output tensor metadata. Under TFLite's quantization scheme, raw
# outputs dequantize as real = scale * (q - zero_point) using this pair.
output_details

[{'name': 'StatefulPartitionedCall:0',
  'index': 55,
  'shape': array([  1, 256, 256,   1]),
  'shape_signature': array([ -1, 256, 256,   1]),
  'dtype': numpy.int8,
  'quantization': (0.00390625, -128),
  'quantization_parameters': {'scales': array([0.00390625], dtype=float32),
   'zero_points': array([-128]),
   'quantized_dimension': 0},
  'sparsity_parameters': {}}]