# 0. Import Library

In [1]:
import os
if not os.path.exists("./tfdet"):
    !git clone -q http://github.com/burf/tfdetection.git
    !mv ./tfdetection/tfdet ./tfdet
    !rm -rf ./tfdetection

In [2]:
#ignore warning
import warnings, os
warnings.filterwarnings(action = "ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

import tfdet

# 1. Build Detector

In [3]:
import numpy as np

image_shape = [256, 256]
n_class = 21
batch_size = 1

sample_data = np.random.random([batch_size, *image_shape, 3]).astype(np.float32)

In [4]:
save_path = "./model.h5"

x = tf.keras.layers.Input(shape = [*image_shape, 3], batch_size = batch_size) #init batch_size for tensorrt
feature = tfdet.model.backbone.resnet50(x, weights = None)

out = tfdet.model.detector.upernet(feature, n_class = n_class)
out = tf.keras.layers.UpSampling2D((4, 4))(out)
model = tf.keras.Model(x, out)
try:
    model.load_weights(save_path) #load weight
except:
    pass

model.predict(sample_data, verbose = 0).shape

(1, 256, 256, 21)

# 2. TF2TF-Lite

2-1. Convert

In [5]:
save_path = "./model.tflite"
tfdet.export.tf2lite(model, save_path, dtype = tf.float32)



'./model.tflite'

2-2. Load

In [6]:
convert_model = tfdet.export.load_tflite(save_path)
pred = convert_model(sample_data)
del convert_model
pred.shape

(1, 256, 256, 3)

# 3. TF2ONNX

3-1. Convert

In [7]:
save_path = "./model.onnx"
tfdet.export.tf2onnx(model, save_path, opset = 13)

'./model.onnx'

3-2. Load

In [8]:
convert_model = tfdet.export.load_onnx(save_path, gpu = 0)
pred = convert_model(sample_data)
del convert_model
pred.shape

(1, 256, 256, 21)

# 4. TF2TensorRT

3-1. convert(tf > onnx > tensorrt)

In [9]:
save_path = "./model.onnx"
tfdet.export.tf2onnx(model, save_path, opset = 13)

'./model.onnx'

In [10]:
save_trt_path = "./model.trt"
tfdet.export.onnx2trt(save_path, save_trt_path, dtype = "FP32", memory_limit = 1)

'./model.trt'

3-2. load

In [11]:
convert_model = tfdet.export.load_trt(save_trt_path)
pred = convert_model(sample_data)
del convert_model
pred.shape

(1, 256, 256, 21)