## 0. Imports and Paths


In [None]:
# Use tensorflow 2.5 for conversion!!!
!conda install tensorflow==2.5

In [1]:
import os
import tensorflow as tf
from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

In [None]:
# Paths are relative to the notebook's working directory. os.path.join inserts the
# platform separator itself, so no literal '\\' is needed (it breaks on POSIX).
WORKSPACE_PATH = os.path.join('..', '02_Workspace')

# Label file embedded into the .tflite as an associated file.
LABELMAP_FILE = os.path.join(WORKSPACE_PATH, 'annotations', 'label_map.txt')

# Final artifacts, written to the current directory.
OUTPUT_MODEL_PATH = "licence_model.tflite"
OUTPUT_MODEL_PATH_JSON = "licence_model.json"

# SavedModel that is fed to the TFLite converter.
SAVED_MODEL_PATH = os.path.join(
    WORKSPACE_PATH, 'models', 'my_ssd_mobilenet', 'tfliteexport', 'saved_model')

# File the converter writes. It must be the same file the metadata step reads via
# OUTPUT_MODEL_PATH; joining a name onto OUTPUT_MODEL_PATH would treat the output
# file as a directory, and the write would fail.
TFLITE_MODEL_PATH = OUTPUT_MODEL_PATH

## 1. Convert Saved Model to TFLite Model


In [None]:
# Turn the exported SavedModel into an in-memory TFLite flatbuffer (raw bytes)...
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)
tflite_model = converter.convert()

# ...and dump those bytes to TFLITE_MODEL_PATH unchanged.
with open(TFLITE_MODEL_PATH, 'wb') as tflite_file:
    tflite_file.write(tflite_model)

## 2. Add Metadata to TFLite Model


In [None]:
# Model-level metadata: identity, purpose, version and authorship of the detector.
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "my_ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8"
model_meta.version = "1.0"
model_meta.author = "Jin-Jin Lee"
model_meta.description = (
    "This model detects car licence plates and is based on the "
    "ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8 model"
)

In [None]:
# Input tensor metadata: the single RGB image fed to the detector.
input_meta = _metadata_fb.TensorMetadataT()
input_meta.name = "image"
# Adjacent string literals are joined with no separator, so every sentence ends
# with its own "\n" or space. Channel order matches the RGB colour space below.
input_meta.description = (
    "Input image in which to detect licence plates.\n"
    "One input: image, as a float32 tensor of shape [1, 320, 320, 3].\n"
    "The expected image is 320 x 320, with three channels (red, green, and blue) per pixel. "
    "Input image is *normalized*."
)

# Declare the tensor content as an RGB image.
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
    _metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.ImageProperties)

# Normalization for consumers to apply: (pixel - mean) / std with 127.5 / 127.5,
# which maps raw [0, 255] pixel values onto [-1, 1].
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
    _metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]

input_meta.processUnits = [input_normalization]

# Value range of the raw (pre-normalization) pixels.
input_stats = _metadata_fb.StatsT()
input_stats.max = [255]
input_stats.min = [0]

input_meta.stats = input_stats

In [None]:
# Output tensor metadata #1: bounding-box coordinates of each detection.
location_meta = _metadata_fb.TensorMetadataT()
location_meta.name = "locations"
location_meta.description = "The locations of the detected boxes."

location_meta.content = _metadata_fb.ContentT()
location_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.BoundingBoxProperties
location_meta.content.contentProperties = _metadata_fb.BoundingBoxPropertiesT()
# A BOUNDARIES box is (left, top, right, bottom) in its default order. Per the
# metadata schema, index [1, 0, 3, 2] declares the order (top, left, bottom, right),
# i.e. each box is stored as (y_min, x_min, y_max, x_max).
location_meta.content.contentProperties.type = _metadata_fb.BoundingBoxType.BOUNDARIES
location_meta.content.contentProperties.index = [1, 0, 3, 2]

# ValueRange names the tensor dimension(s) the box content occupies: only dimension 2
# (presumably the 4-coordinate axis of a [1, num_boxes, 4] tensor — confirm).
location_meta.content.range = _metadata_fb.ValueRangeT()
location_meta.content.range.max = 2
location_meta.content.range.min = 2

In [None]:
# Create second output tensor info (classes): per-detection class indices,
# resolved to text through the associated label file.
classes_meta = _metadata_fb.TensorMetadataT()
classes_meta.name = "classes"
classes_meta.description = "The classes of the detected boxes."

classes_meta.content = _metadata_fb.ContentT()
classes_meta.content.contentProperties = _metadata_fb.FeaturePropertiesT()
classes_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.FeatureProperties
# Same per-detection value range as the locations tensor, matching the reference
# TFLite object-detection metadata (category/score tensors also carry min == max == 2).
classes_meta.content.range = _metadata_fb.ValueRangeT()
classes_meta.content.range.min = 2
classes_meta.content.range.max = 2

# TENSOR_VALUE_LABELS: the tensor's values index into this label file. The name
# must match the basename of the file later loaded into the metadata populator.
label_file = _metadata_fb.AssociatedFileT()
label_file.name = os.path.basename(LABELMAP_FILE)
label_file.description = "Labels for objects that the model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_VALUE_LABELS

classes_meta.associatedFiles = [label_file]

In [None]:
# Create third output tensor info (scores): per-detection confidence values.
scores_meta = _metadata_fb.TensorMetadataT()
scores_meta.name = "scores"
scores_meta.description = "The scores of the detected boxes."

scores_meta.content = _metadata_fb.ContentT()
scores_meta.content.contentProperties = _metadata_fb.FeaturePropertiesT()
scores_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.FeatureProperties
# Same per-detection value range as the locations tensor, matching the reference
# TFLite object-detection metadata.
scores_meta.content.range = _metadata_fb.ValueRangeT()
scores_meta.content.range.min = 2
scores_meta.content.range.max = 2

In [None]:
# Output tensor metadata #4: how many of the returned boxes are valid detections.
num_meta = _metadata_fb.TensorMetadataT()
num_meta.name = "number of detections"
num_meta.description = "The number of the detected boxes."

# Plain feature content: no image/bounding-box semantics attached.
num_meta.content = _metadata_fb.ContentT()
num_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.FeatureProperties
num_meta.content.contentProperties = _metadata_fb.FeaturePropertiesT()

In [None]:
# Assemble the subgraph metadata. Per the metadata schema, these lists must follow
# the order of the model's own input/output tensors. This assumes the converted SSD
# model emits [locations, classes, scores, number of detections] —
# NOTE(review): confirm against the converted model's output details.
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [location_meta, classes_meta, scores_meta, num_meta]
model_meta.subgraphMetadata = [subgraph]

# Serialize the whole ModelMetadataT tree into a flatbuffer. The builder starts
# empty (size 0) and grows on demand; the buffer is finished with the file
# identifier that MetadataPopulator defines for metadata buffers.
builder = flatbuffers.Builder(0)
builder.Finish(
    model_meta.Pack(builder),
    _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER,
)
metadata_buf = builder.Output()

In [None]:
# Embed the serialized metadata and the label file into the model at
# OUTPUT_MODEL_PATH (the populator writes back into the file it was opened with).
populator = _metadata.MetadataPopulator.with_model_file(OUTPUT_MODEL_PATH)
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files([LABELMAP_FILE])  # basename must match label_file.name
populator.populate()

In [None]:
# Read the metadata back out of the populated model and save it as JSON, as a
# human-readable record of what was embedded.
displayer = _metadata.MetadataDisplayer.with_model_file(OUTPUT_MODEL_PATH)

json_file = displayer.get_metadata_json()  # a JSON *string*, not a file handle
# Write with an explicit encoding so the output does not depend on the platform's
# locale default (e.g. cp1252 on Windows).
with open(OUTPUT_MODEL_PATH_JSON, "w", encoding="utf-8") as f:
    f.write(json_file)