# In [None]:
import tensorrt as trt
import sys

# Logger handed to the TensorRT objects created in this module; messages below
# WARNING severity are suppressed.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def build_engine(onnx_file_path, max_batch_size=1, max_workspace_size=1<<30):
    """
    Convert an ONNX model into a TensorRT engine using the TensorRT Builder.

    The legacy ``builder.max_workspace_size`` / ``builder.max_batch_size`` /
    ``builder.build_cuda_engine`` APIs no longer exist in current TensorRT
    releases. The build is therefore configured through an ``IBuilderConfig``,
    and the newest build API available in the installed TensorRT is used.

    Parameters:
        onnx_file_path (str): path to the ONNX model file.
        max_batch_size (int): largest batch size the engine must accept. The
            network is created in explicit-batch mode, so this only has an
            effect for inputs whose leading (batch) dimension is dynamic (-1).
            An optimization profile covering batch sizes 1..max_batch_size is
            added for those inputs. Static-shape inputs are unaffected.
        max_workspace_size (int): builder workspace size in bytes,
            e.g. 1<<30 means 1 GiB.

    Returns:
        The built TensorRT engine object, or None if parsing or building failed
        (details are printed / reported through TRT_LOGGER).

    Raises:
        OSError: if the ONNX file cannot be opened.
    """
    network_flags = _explicit_batch_flags()
    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network(network_flags) as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser:

        # Read and parse the ONNX model into the network definition.
        with open(onnx_file_path, 'rb') as model_file:
            if not parser.parse(model_file.read()):
                print("ERROR: Failed to parse the ONNX file.", file=sys.stderr)
                for i in range(parser.num_errors):
                    print(parser.get_error(i), file=sys.stderr)
                return None

        config = builder.create_builder_config()
        _set_workspace_limit(config, max_workspace_size)
        if not _add_batch_profile(builder, network, config, max_batch_size):
            return None

        return _build_cuda_engine(builder, network, config)


def _explicit_batch_flags():
    """Return network-creation flags that request an explicit-batch network."""
    # The flag is deprecated in TensorRT 10 (all networks are explicit-batch
    # there), so tolerate its absence instead of failing with AttributeError.
    flag = getattr(trt.NetworkDefinitionCreationFlag, "EXPLICIT_BATCH", None)
    return 0 if flag is None else 1 << int(flag)


def _set_workspace_limit(config, max_workspace_size):
    """Cap the builder workspace memory of *config* at *max_workspace_size* bytes."""
    if hasattr(config, "set_memory_pool_limit"):
        # Newer TensorRT: IBuilderConfig.max_workspace_size is deprecated/removed.
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace_size)
    else:
        config.max_workspace_size = max_workspace_size


def _add_batch_profile(builder, network, config, max_batch_size):
    """Add an optimization profile for inputs whose only dynamic dim is dim 0.

    Returns True when the config is ready to build. Returns False after printing
    an error if some input has dynamic non-batch dimensions, since no shape range
    for those can be derived from max_batch_size alone.
    """
    profile = None
    for idx in range(network.num_inputs):
        tensor = network.get_input(idx)
        shape = tuple(tensor.shape)
        if all(dim >= 0 for dim in shape):
            continue  # fully static input: no profile entry needed
        if any(dim < 0 for dim in shape[1:]):
            print("ERROR: input '{}' has dynamic non-batch dimensions {}; "
                  "an explicit optimization profile is required."
                  .format(tensor.name, shape), file=sys.stderr)
            return False
        if profile is None:
            profile = builder.create_optimization_profile()
        rest = shape[1:]
        profile.set_shape(tensor.name,
                          (1,) + rest,               # min
                          (max_batch_size,) + rest,  # opt
                          (max_batch_size,) + rest)  # max
    if profile is not None:
        config.add_optimization_profile(profile)
    return True


def _build_cuda_engine(builder, network, config):
    """Build *network* with *config*; return an engine object or None on failure."""
    if hasattr(builder, "build_serialized_network"):
        # TensorRT >= 8: build a serialized plan, then deserialize it so callers
        # still get an engine object (build_engine is removed in TensorRT 10).
        plan = builder.build_serialized_network(network, config)
        if plan is None:
            return None
        runtime = trt.Runtime(TRT_LOGGER)
        return runtime.deserialize_cuda_engine(plan)
    # Older TensorRT (7.x) only offers build_engine(network, config).
    return builder.build_engine(network, config)

def save_engine(engine, engine_file_path):
    """
    Serialize a TensorRT engine and write the resulting plan to a file.

    The engine is serialized *before* the output file is opened. A failed
    serialization therefore leaves any existing file at *engine_file_path*
    intact instead of truncating it.

    Parameters:
        engine: engine object exposing ``serialize()`` (e.g. a TensorRT engine).
        engine_file_path (str): destination path; overwritten if it exists.

    Raises:
        RuntimeError: if ``engine.serialize()`` returns None.
        OSError: if the file cannot be written.
    """
    serialized_engine = engine.serialize()
    if serialized_engine is None:
        raise RuntimeError("Failed to serialize the TensorRT engine.")
    with open(engine_file_path, "wb") as f:
        f.write(serialized_engine)

if __name__ == "__main__":
    import argparse

    # Paths and build settings are configurable from the command line; the
    # defaults are the values that used to be hard-coded here.
    arg_parser = argparse.ArgumentParser(
        description="Convert an ONNX model into a serialized TensorRT engine.",
        allow_abbrev=False)
    arg_parser.add_argument(
        "--onnx", default="/home/bydguikong/yy_ws/PlanScope/onnx/model.onnx",
        help="path of the input ONNX model")
    arg_parser.add_argument(
        "--engine", default="/home/bydguikong/yy_ws/PlanScope/onnx/model.trt",
        help="path where the serialized engine is written")
    arg_parser.add_argument(
        "--max-batch-size", type=int, default=1,
        help="largest batch size for inputs with a dynamic batch dimension")
    arg_parser.add_argument(
        "--workspace", type=int, default=1 << 30,
        help="builder workspace size in bytes")
    # parse_known_args ignores extra arguments (e.g. the '-f <kernel.json>'
    # that Jupyter passes) so the cell still runs inside a notebook.
    args, _unknown_args = arg_parser.parse_known_args()

    engine = build_engine(args.onnx,
                          max_batch_size=args.max_batch_size,
                          max_workspace_size=args.workspace)
    if engine is None:
        print("Failed to build the engine!", file=sys.stderr)
        sys.exit(1)

    save_engine(engine, args.engine)
    print("Engine built and saved successfully at:", args.engine)
