Skip to content

OneFlow XRT如何使用低精度计算

Houjiang Chen edited this page Jul 11, 2022 · 5 revisions

大多数硬件都提供比单精度(32bit)更低的计算精度,通常情况下低精度具有较高的峰值计算性能,降低功耗的同时还能减少内存占用。OneFlow-XRT支持开启FP16自动混合精度和INT8两种低精度计算方式。

  • 自动混合精度

    自动混合精度指的是根据算子对精度的敏感度,自动将部分算子从单精度浮点(FP32)转换到半精度浮点(FP16)进行计算。在OneFlow-XRT中,只有TensorRT原生支持自动混合精度,XLA和OpenVINO都不支持,因此XLA和OpenVINO的自动混合精度依赖OneFlow插入cast算子来实现。

    对于任意engine,设置use_fp16=True可以开启自动混合精度,

    import oneflow as flow
    import oneflow_xrt as ofrt
    
    m = flow.nn.Linear(3, 4).to("cuda")
    # 设置use_fp16=True开启自动混合精度
    m = ofrt.XRTModule(m, engine=["tensorrt"], use_fp16=True)
  • INT8量化计算

    目前只有TensorRT engine支持INT8量化计算,TensorRT根据量化校准表将算子的输入和输出转换成INT8,并选择INT8的kernel来计算。OneFlow-XRT支持离线加载和在线生成量化校准表两种方式来启动INT8的量化计算。离线加载的方式需要用户提前生成一个TensorRT格式的量化校准表,该量化校准表通常可以被重复使用,而在线生成的方式则是在推理过程中,使用一部分数据进行量化校准表的生成(该过程使用正常精度计算),一旦校准表生成后,就会在下一个迭代中自动切换到INT8精度计算。

    • 生成离线量化校准表

      import oneflow as flow
      import oneflow_xrt as ofrt
      
      m = flow.nn.Linear(3, 4).to("cuda")
      m = ofrt.XRTModule(m, engine=["tensorrt"])
      # 进入量化校准表生成模式,
      # 量化校准表生成后将保存在./int8_calibration目录中
      with ofrt.ptq_calibration_mode(cache_path="./int8_calibration"):
          for calib_data in calibration_dataset():
              m(calib_data)
    • 加载离线量化表并进行INT8计算

      import oneflow as flow
      import oneflow_xrt as ofrt
      
      m = flow.nn.Linear(3, 4).to("cuda")
      # 设置use_int8=True开启INT8量化计算,同时指定量化校准表的路径
      m = ofrt.XRTModule(m, engine=["tensorrt"], use_int8=True, int8_calibration="./int8_calibration")
      # 下面为INT8计算模式
      for test_data in test_dataset():
          m(test_data)
    • 在线生成量化校准表并进行INT8计算

      import oneflow as flow
      import oneflow_xrt as ofrt
      
      m = flow.nn.Linear(3, 4).to("cuda")
      m = ofrt.XRTModule(m, engine=["tensorrt"])
      # 进入量化校准表生成模式,
      # 量化校准表生成后将cache在内存中
      with ofrt.ptq_calibration_mode():
          for calib_data in calibration_dataset():
              m(calib_data)
      # 下面为INT8计算模式
      for test_data in test_dataset():
          m(test_data)