Check Numeric Overflow

In [None]:
import onnx
import numpy as np

def is_overflow_in_fp16(tensor, atol=0.001):
    """Return True if any element of *tensor* changes by more than *atol* in FP16.

    The data is cast float32 -> float16 -> float32 and compared element-wise
    with the float32 original. Finite values too large for FP16 come back as
    ``inf``, so they exceed any finite tolerance.

    Args:
        tensor: Array-like of numeric values; converted to float32 first.
        atol: Absolute round-trip error above which an element is reported.
            Defaults to 0.001.

    Returns:
        bool: True when at least one element exceeds the tolerance.
    """
    original_data = np.asarray(tensor, dtype=np.float32)
    # Out-of-range values are exactly what this check looks for, so the
    # floating-point warnings NumPy may emit for the casts, the subtraction
    # and the comparison are silenced here.
    with np.errstate(over="ignore", invalid="ignore"):
        back_converted_data = original_data.astype(np.float16).astype(np.float32)
        diff = np.abs(back_converted_data - original_data)
        exceeds = diff > atol
    return bool(np.any(exceeds))

def is_truncated_in_fp16(tensor, threshold=0.0000001):
    """Return True if *tensor* holds non-zero values too small for FP16.

    An element is reported when its magnitude is greater than zero and at most
    *threshold*. Exact zeros are not reported: zero is exactly representable
    in FP16, so nothing is lost when it is converted.

    Args:
        tensor: Array-like of numeric values; converted to float32 first.
        threshold: Largest magnitude still considered too small for FP16.
            Defaults to 1e-7.

    Returns:
        bool: True when at least one non-zero element is within the threshold.
    """
    original_data = np.asarray(tensor, dtype=np.float32)
    magnitude = np.abs(original_data)

    # Tiny but non-zero weights are the ones that vanish (or lose almost all
    # precision) in FP16; genuine zeros are excluded to avoid false positives.
    return bool(np.any((magnitude > 0) & (magnitude <= threshold)))


model = onnx.load("AsymFormer.onnx")  # Load ONNX model

# Index the initializers by name once instead of scanning the whole list for
# every node input. setdefault keeps the first initializer for a given name,
# which is what a front-to-back search would return.
initializers = {}
for init in model.graph.initializer:
    initializers.setdefault(init.name, init)

overflow_list = []  # names of nodes whose weights are unsafe in FP16
flagged_names = set()  # guards against listing the same node more than once
for node in model.graph.node:
    for input_name in node.input:
        weight = initializers.get(input_name)
        if weight is None:  # this input is not a stored weight
            continue
        weights = onnx.numpy_helper.to_array(weight)
        has_problem = False
        if is_overflow_in_fp16(weights):
            print(f"Node {node.name} ({node.op_type}): Weight overflow in fp16")
            has_problem = True
        if is_truncated_in_fp16(weights):
            print(f"Node {node.name} ({node.op_type}): Weight truncated in fp16")
            has_problem = True
        # Record each node only once so the count below is the number of
        # affected nodes, not the number of individual findings.
        if has_problem and node.name not in flagged_names:
            flagged_names.add(node.name)
            overflow_list.append(node.name)

# Print the number of flagged nodes.
print('个数：',len(overflow_list))

Generate Mixed Precision TensorRT Model

In [None]:
import os
import tensorrt as trt

def build_engine(onnx_file_path, engine_file_path, overflow_list, flop=16):
    """Build a mixed-precision TensorRT engine from an ONNX file and save it.

    The network is built with the FP16 and STRICT_TYPES builder flags. Every
    layer whose name contains one of the names in *overflow_list* has its
    precision set to FP32 before the build.

    Args:
        onnx_file_path: Path of the ONNX model to parse.
        engine_file_path: Path the serialized engine is written to. An
            existing file at this path is removed before building.
        overflow_list: Iterable of layer-name substrings to keep in FP32.
            Empty names are skipped and duplicates are collapsed.
        flop: Unused; kept so existing callers keep working.

    Returns:
        The built engine, or None if the ONNX file could not be parsed or the
        engine could not be built.
    """
    trt_logger = trt.Logger(trt.Logger.WARNING)  # trt.Logger.ERROR
    builder = trt.Builder(trt_logger)

    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )

    parser = trt.OnnxParser(network, trt_logger)
    # parse ONNX
    with open(onnx_file_path, 'rb') as model:
        if not parser.parse(model.read()):
            print('ERROR: Failed to parse the ONNX file.')
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None
    print("Completed parsing ONNX file")

    # default = 1 for fixed batch size
    builder.max_batch_size = 1

    if os.path.isfile(engine_file_path):
        try:
            os.remove(engine_file_path)
        except OSError:
            print("Cannot remove existing file: ",
                engine_file_path)

    print("Creating Tensorrt Engine")

    # An empty name is a substring of every layer name and would silently
    # force the whole network to FP32, so empty entries are dropped. Using a
    # set also avoids re-processing a layer for duplicated names.
    fp32_names = {name for name in overflow_list if name}
    for layer in network:
        if any(name in layer.name for name in fp32_names):
            layer.precision = trt.float32
            print(f'Network Layer: {layer.name}, {layer.type}, {layer.precision}, is_set: {layer.precision_is_set}')

    # set mixed flop computation for the best performance
    config = builder.create_builder_config()
    config.set_tactic_sources(1 << int(trt.TacticSource.CUBLAS))
    config.max_workspace_size = 2 << 30
    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.STRICT_TYPES)

    print('config.flags: ', config.flags)

    engine = builder.build_engine(network, config)
    if engine is None:
        # Without this check a failed build surfaces as an AttributeError on
        # engine.serialize() below; report it like a parse failure instead.
        print("ERROR: Failed to build the TensorRT engine.")
        return None

    with open(engine_file_path, "wb") as f:
        f.write(engine.serialize())
    print("Serialized Engine Saved at: ", engine_file_path)
    return engine


In [None]:
# Input ONNX model and destination of the serialized TensorRT engine.
ONNX_SIM_MODEL_PATH = "AsymFormer.onnx"
TENSORRT_ENGINE_PATH_PY = "AsymFormer.engine"

# Build the engine; layers matching the names in overflow_list are set to FP32.
build_engine(
    onnx_file_path=ONNX_SIM_MODEL_PATH,
    engine_file_path=TENSORRT_ENGINE_PATH_PY,
    overflow_list=overflow_list,
)