In [1]:
import os

import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt

In [2]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up
# front. Must run before any GPU op initializes the devices.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("Not enough GPU hardware devices available")
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)
    growth_enabled = tf.config.experimental.get_memory_growth(device)
    print('{} memory growth: {}'.format(device, growth_enabled))

PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU') memory growth: True


In [3]:
# Root directory for exported models; MobileNetV2Model.save() writes
# <MODEL_BASE_PATH>/<name>/0 and <MODEL_BASE_PATH>/<name>_trt/0 under it.
# Relative to the notebook's working directory.
MODEL_BASE_PATH = '../models/'

In [4]:
class MobileNetV2Model(tf.keras.Model):
    """ImageNet-pretrained MobileNetV2 wrapped as a functional Keras model.

    The model maps a batch of 224x224x3 float32 images to the output of
    ``tf.keras.applications.MobileNetV2(include_top=True)``. Calling
    ``save()`` with no arguments exports it twice under ``MODEL_BASE_PATH``:
    once as a plain SavedModel and once converted with TF-TRT at FP16.
    """

    def __init__(self, name: str):
        """Build the wrapped network.

        Args:
            name: Keras model name; ``save()`` also uses it as the export
                directory name.
        """
        shape = (224, 224, 3)
        base_model = tf.keras.applications.MobileNetV2(input_shape=shape, include_top=True, weights='imagenet')
        inputs = tf.keras.Input(shape)
        # NOTE(review): no preprocessing is applied here. The ImageNet weights
        # presumably expect inputs scaled with
        # tf.keras.applications.mobilenet_v2.preprocess_input — confirm that
        # callers do this before inference.
        outputs = base_model(inputs)

        super().__init__(inputs=inputs, outputs=outputs, name=name)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.float32, name="imgs")]
    )
    def serving_fn(self, imgs) -> tf.Tensor:
        """Serving entry point with a fixed input signature.

        ``save()`` exports this function as the ``serving_default`` signature.
        Its single input is named ``imgs``.

        Args:
            imgs: float32 tensor of shape ``[batch, 224, 224, 3]``.

        Returns:
            The model output for ``imgs``.
        """
        return self(imgs)

    def save(self, *args, **kwargs):
        """Export the model as a SavedModel and as a TF-TRT FP16 SavedModel.

        Called with no arguments, writes:

        * ``<MODEL_BASE_PATH>/<name>/0``: a SavedModel whose
          ``serving_default`` signature is ``serving_fn``.
        * ``<MODEL_BASE_PATH>/<name>_trt/0``: the same model converted by
          ``trt.TrtGraphConverterV2`` at FP16 precision.

        This method shadows ``tf.keras.Model.save(filepath, ...)``. Any call
        that passes arguments, such as one from a Keras callback like
        ``ModelCheckpoint``, is forwarded to the base implementation. That
        keeps the standard Keras contract working.
        """
        if args or kwargs:
            return super().save(*args, **kwargs)

        # Derive both paths from self.name (not a notebook global) so the
        # exports always belong to this instance.
        tf_saved_model_path = os.path.join(MODEL_BASE_PATH, self.name, '0')
        trt_saved_model_path = os.path.join(MODEL_BASE_PATH, '{}_trt'.format(self.name), '0')

        signatures = {"serving_default": self.serving_fn}
        tf.saved_model.save(self, tf_saved_model_path, signatures=signatures)

        params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
            precision_mode=trt.TrtPrecisionMode.FP16,
            # Max batch size for the TensorRT engines; presumably must cover
            # the batch sizes fed to serving_fn — confirm for production use.
            max_batch_size=10
        )
        converter = trt.TrtGraphConverterV2(input_saved_model_dir=tf_saved_model_path, conversion_params=params)
        converter.convert()
        # converter.build() is not called, so TensorRT engines are not
        # pre-built. Per the TF-TRT log, they are built and cached at runtime.
        converter.save(trt_saved_model_path)

In [5]:
# Build the model; 'test1' is also the export directory name used by save().
model = MobileNetV2Model(name='test1')

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224.h5


In [6]:
# Export both the plain SavedModel (../models/test1/0) and the TF-TRT FP16
# SavedModel (../models/test1_trt/0), as shown in the log below.
model.save()

INFO:tensorflow:Assets written to: ../models/test1/0/assets
INFO:tensorflow:Linked TensorRT version: (7, 1, 3)
INFO:tensorflow:Loaded TensorRT version: (7, 1, 3)
INFO:tensorflow:Could not find TRTEngineOp_0_0 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.
INFO:tensorflow:Assets written to: ../models/test1_trt/0/assets
