# Step 0 — Specify config  and build TensorRT engines

Before running CAICE inference, please **choose a component-level precision** and **build TensorRT engines** from your FP32 ONNX.

### 1) Specify component precision

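Specify one entry per component in the form `<component>-<precision>`, where the component is `ga`, `gs`, `ha`, or `hs` (the model's `g_a`, `g_s`, `h_a`, and `h_s` sub-networks) and the precision is `fp8` or `fp16`: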

* `ga-fp8`, `ga-fp16`
* `gs-fp8`, `gs-fp16`
* `ha-fp8`, `ha-fp16`
* `hs-fp8`, `hs-fp16`

Default config:

* `ga-fp8,gs-fp16,ha-fp8,hs-fp8` (the `ha` and `hs` entries apply only if the model has those components)

> Note: For FP8, the build script will quantize the corresponding sub-ONNX (Q/DQ) and then build a **strongly-typed** TensorRT engine.
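
A precision config string in this format can be parsed and validated in a few lines. The sketch below is illustrative only: `parse_precision_config` is a hypothetical helper, not part of the build script, and it assumes entries are comma-separated `<component>-<precision>` pairs.

```python
VALID_COMPONENTS = ("ga", "gs", "ha", "hs")
VALID_PRECISIONS = ("fp8", "fp16")


def parse_precision_config(config: str) -> dict:
    """Parse 'ga-fp8,gs-fp16' into {'ga': 'fp8', 'gs': 'fp16'} (hypothetical helper)."""
    result = {}
    for entry in config.split(","):
        entry = entry.strip()
        if not entry:
            continue  # tolerate trailing commas
        component, sep, precision = entry.partition("-")
        if not sep or component not in VALID_COMPONENTS or precision not in VALID_PRECISIONS:
            raise ValueError(f"invalid entry {entry!r}")
        if component in result:
            raise ValueError(f"duplicate component {component!r}")
        result[component] = precision
    return result


print(parse_precision_config("ga-fp8,gs-fp16,ha-fp8,hs-fp8"))
```

Rejecting unknown components and precisions up front avoids discovering a typo only after a long engine build has started.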

### 2) Define component boundaries (`--boundaries` JSON)

After choosing the precision, you must define **graph boundaries** for each component in a JSON file, which is passed to the build script via `--boundaries`.
The build script uses these boundaries to **extract sub-graphs** (g_a / g_s / h_a / h_s) from the full FP32 ONNX before quantizing and building engines.
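
Sub-graph extraction of this kind can be done with `onnx.utils.extract_model`, which cuts a model between named input and output tensors. The sketch below is an assumption about how such a step could look, not the build script's actual code. `subonnx_path` and `extract_component` are hypothetical names, and the output layout simply mirrors the `[OK] Extracted ...` paths in the build log later in this notebook.

```python
from pathlib import Path


def subonnx_path(out_dir: str, model_tag: str, component: str, precision: str) -> Path:
    # Layout taken from the build log: <out_dir>/subonnx/<model_tag>/<component>_<precision>.onnx
    return Path(out_dir) / "subonnx" / model_tag / f"{component}_{precision}.onnx"


def extract_component(full_onnx: str, dst: Path, inputs: list, outputs: list) -> None:
    """Cut the sub-graph between `inputs` and `outputs` out of `full_onnx` and save it to `dst`."""
    import onnx.utils  # imported lazily so the path helper works without onnx installed

    dst.parent.mkdir(parents=True, exist_ok=True)
    onnx.utils.extract_model(full_onnx, str(dst), inputs, outputs)


print(subonnx_path("out_engines", "mbt2018-mean-q1", "ga", "fp16").as_posix())
```

For example, `extract_component(onnx_fp32, subonnx_path(out_dir, tag, "ga", "fp16"), ["input"], ["/g_a/g_a.6/Conv_output_0"])` would write only the `g_a` part of the graph.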

#### What to provide for each component

For every component, specify:

* `inputs`: a list of **input tensor names** in the ONNX graph
* `outputs`: a list of **output tensor names** in the ONNX graph

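Example (JSON), using the tensor names that appear in the `mbt2018-mean` build log later in this notebook; `<ha_out_tensor>` and `<hs_out_tensor>` are placeholders for names you must look up in your own graph:

```json
{
  "ga": { "inputs": ["input"], "outputs": ["/g_a/g_a.6/Conv_output_0"] },
  "ha": { "inputs": ["/g_a/g_a.6/Conv_output_0"], "outputs": ["<ha_out_tensor>"] },
  "hs": { "inputs": ["/entropy_bottleneck/Transpose_1_output_0"], "outputs": ["<hs_out_tensor>"] },
  "gs": { "inputs": ["/gaussian_conditional/Add_output_0"], "outputs": ["x_hat"] }
}
```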
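
Before starting a build, it can help to check that the boundaries file has the expected structure. The following is a minimal sketch under the assumption that the file is a JSON object mapping each component to `inputs` and `outputs` lists. `validate_boundaries` is a hypothetical helper, and treating `<...>` names as unfilled placeholders is a convention of this example only.

```python
import json


def validate_boundaries(boundaries: dict) -> list:
    """Return a list of problems; an empty list means the structure looks fine."""
    problems = []
    for component, spec in boundaries.items():
        for key in ("inputs", "outputs"):
            names = spec.get(key) if isinstance(spec, dict) else None
            if not isinstance(names, list) or not names:
                problems.append(f"{component}: '{key}' must be a non-empty list")
            elif not all(isinstance(n, str) and n for n in names):
                problems.append(f"{component}: '{key}' must contain non-empty strings")
            elif any(n.startswith("<") and n.endswith(">") for n in names):
                problems.append(f"{component}: '{key}' still contains a placeholder")
    return problems


example = json.loads("""
{
  "ga": {"inputs": ["input"], "outputs": ["/g_a/g_a.6/Conv_output_0"]},
  "ha": {"inputs": ["/g_a/g_a.6/Conv_output_0"], "outputs": ["<ha_out_tensor>"]}
}
""")
for problem in validate_boundaries(example):
    print(problem)
```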

#### Tips

* Tensor names must match **exactly** what appears in the exported ONNX (case-sensitive).
* You can inspect tensor names using **Netron** or by printing ONNX graph I/O names in Python.
* If you only plan to accelerate a subset of components, you can still define all boundaries now and only build engines for the components listed in your precision config.
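
As a rough illustration of the second tip, the sketch below collects every tensor name in an ONNX graph and reports boundary names that do not occur in it. The helper names are hypothetical. The functions only rely on the standard `GraphProto` fields `input`, `initializer`, and `node[i].output`, and `onnx` is imported lazily inside `check_files`.

```python
def graph_tensor_names(graph) -> set:
    """Collect graph inputs, initializers, and all node outputs of an ONNX GraphProto."""
    names = {i.name for i in graph.input}
    names.update(init.name for init in graph.initializer)
    for node in graph.node:
        names.update(node.output)
    return names


def missing_boundary_names(graph, boundaries: dict) -> dict:
    """Map each component to the boundary tensor names that are absent from the graph."""
    known = graph_tensor_names(graph)
    missing = {}
    for component, spec in boundaries.items():
        absent = [n for n in spec["inputs"] + spec["outputs"] if n not in known]
        if absent:
            missing[component] = absent
    return missing


def check_files(onnx_path: str, boundaries_path: str) -> dict:
    import json
    import onnx  # lazy import: requires the `onnx` package

    with open(boundaries_path) as f:
        boundaries = json.load(f)
    return missing_boundary_names(onnx.load(onnx_path).graph, boundaries)
```

Calling `check_files("model-f32.onnx", "boundaries.json")` (both paths are placeholders) returns an empty dict when every boundary name matches exactly, including case.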


### 3) Prepare calibration data and input shapes

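For **FP8 components**, calibration data (passed via `--calib_npy`) is required to determine quantization scales.

You must also specify the **exact input shape** used to build the engines (`--input_shape`, e.g. `512,3,128,128`). The build script fixes each sub-ONNX to a static shape, so the resulting engines only accept that shape.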
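
To make the shape bookkeeping concrete, the sketch below parses an `N,C,H,W` shape string and derives the per-component input shapes. The channel counts (192 and 128) and downsampling factors (16 and 64) are read off the `[Shape]` lines of the `mbt2018-mean` quality-1 build log later in this notebook. They are assumptions for that model only, and both helpers are hypothetical.

```python
import math


def parse_shape(text: str) -> tuple:
    """Parse '512,3,128,128' into (512, 3, 128, 128)."""
    dims = tuple(int(d) for d in text.split(","))
    if len(dims) != 4 or any(d <= 0 for d in dims):
        raise ValueError(f"expected a positive N,C,H,W shape, got {text!r}")
    return dims


def latent_shape(input_shape: tuple, channels: int, stride: int) -> tuple:
    """Shape of a latent tensor that is spatially downsampled by `stride`."""
    n, _, h, w = input_shape
    if h % stride or w % stride:
        raise ValueError(f"H and W must be multiples of {stride}")
    return (n, channels, h // stride, w // stride)


shape = parse_shape("512,3,128,128")
print("ga input   :", shape)
print("ha/gs input:", latent_shape(shape, channels=192, stride=16))
print("hs input   :", latent_shape(shape, channels=128, stride=64))
print("fp32 bytes :", math.prod(shape) * 4)
```

The byte count printed on the last line equals the `input_bytes` value reported by the benchmark in Step 1.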


In [2]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # synchronous launches to locate where execution stalls; remove once confirmed

import torch
from compressai.zoo import bmshj2018_hyperprior, mbt2018_mean

quality = 1
project_dir = "/hwj"
device = "cuda:0"
model_name = "mbt2018-mean"

# load model (CPU -> GPU)
model = mbt2018_mean(quality=quality, pretrained=False)
state = torch.load(f"{project_dir}/data/model/{model_name}-{quality}.pth", map_location="cpu")
model.load_state_dict(state)
model = model.to(device).eval()

print("Load model done.")

class Wrap(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.m = m

    # # bmshj2018_hyperprior
    # def forward(self, x):
    #     y = self.m.g_a(x)
    #     z = self.m.h_a(torch.abs(y))
    #     z_hat, _ = self.m.entropy_bottleneck(z)
    #     scales_hat = self.m.h_s(z_hat)
    #     y_hat, _ = self.m.gaussian_conditional(y, scales_hat)
    #     x_hat = self.m.g_s(y_hat)
    #     return x_hat, scales_hat
    # mbt2018_mean
    def forward(self, x):
        y = self.m.g_a(x)
        z = self.m.h_a(y)
        z_hat, _ = self.m.entropy_bottleneck(z)
        gaussian_params = self.m.h_s(z_hat)
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        y_hat, _ = self.m.gaussian_conditional(y, scales_hat, means=means_hat)
        x_hat = self.m.g_s(y_hat)
        return x_hat, scales_hat

wrapped = Wrap(model).eval()

dummy_input = torch.randn((512, 3, 128, 128), device=device, dtype=torch.float32)
print("Input data prepared.")

# First confirm that the forward pass actually runs to completion
with torch.no_grad():
    print("Running warmup forward...")
    _ = wrapped(dummy_input)
    torch.cuda.synchronize()
print("Warmup done.")

onnx_path = f"{project_dir}/data/model/onnx/{model_name}-{quality}-f32.onnx"
print("Exporting to ONNX...")

with torch.no_grad():
    torch.onnx.export(
        wrapped,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["x_hat", "scales_hat"],
        opset_version=17,
        # do_constant_folding=False,
    )
    
model = model.eval().to(device).to(torch.float16)
dummy_input = torch.randn((512, 3, 128, 128), device=device, dtype=torch.float16)
onnx_path = f"{project_dir}/data/model/onnx/{model_name}-{quality}-f16.onnx"
wrapped = Wrap(model).eval()

with torch.no_grad():
    torch.onnx.export(
        wrapped,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["x_hat", "scales_hat"],
        opset_version=17,
        # do_constant_folding=False,
    )

print("Export done:", onnx_path)




Load model done.
Input data prepared.
Running warmup forward...
Warmup done.
Exporting to ONNX...




Export done: /hwj/data/model/onnx/mbt2018-mean-1-f16.onnx


In [4]:
!python ./utils/build_engines.py \
--onnx_fp32 /hwj/data/model/onnx/mbt2018-mean-1-f32.onnx \
--onnx_fp16 /hwj/data/model/onnx/mbt2018-mean-1-f16.onnx \
--input_shape 512,3,128,128 \
--config ga-fp16,gs-fp16,ha-fp16,hs-fp16 \
--boundaries /hwj/project/CompressAI-Science/examples/config-mbt2018-mean-q1.json \
--calib_npy /hwj/project/aiz-accelerate/data/nyx-dark_matter_density.npy \
--out_dir /hwj/project/CompressAI-Science/examples/out_engines \
--model_tag mbt2018-mean-q1 \
--max_calib_samples 512 \
--prefer_cuda_ort

[OK] Extracted ga (fp16): /hwj/project/CompressAI-Science/examples/out_engines/subonnx/mbt2018-mean-q1/ga_fp16.onnx
[OK] Extracted ha (fp16): /hwj/project/CompressAI-Science/examples/out_engines/subonnx/mbt2018-mean-q1/ha_fp16.onnx
[OK] Extracted hs (fp16): /hwj/project/CompressAI-Science/examples/out_engines/subonnx/mbt2018-mean-q1/hs_fp16.onnx
[OK] Extracted gs (fp16): /hwj/project/CompressAI-Science/examples/out_engines/subonnx/mbt2018-mean-q1/gs_fp16.onnx
[Shape] ha fixed input shape = (512, 192, 8, 8) (reuse calib file)
[Shape] hs fixed input shape = (512, 128, 2, 2) (reuse calib file)
[Shape] gs fixed input shape = (512, 192, 8, 8) (reuse calib file)
[FixShape] ga_fp16.onnx input=input -> (512, 3, 128, 128)
[FixShape] ha_fp16.onnx input=/g_a/g_a.6/Conv_output_0 -> (512, 192, 8, 8)
[FixShape] hs_fp16.onnx input=/entropy_bottleneck/Transpose_1_output_0 -> (512, 128, 2, 2)
[FixShape] gs_fp16.onnx input=/gaussian_conditional/Add_output_0 -> (512, 192, 8, 8)

[Engine] Building ga engi

# Step 1 — Run Benchmark

In [1]:
import numpy as np
import torch
from compressai.zoo import bmshj2018_factorized, bmshj2018_hyperprior, mbt2018_mean
from compressai.runtime import build_runtime
from compressai.runtime.config import RuntimeConfig
from compressai.runtime.codecs import GpuPackedEntropyCodec
from compressai.runtime.utils.benchmark import run_e2e

device = "cuda:0"

# 1) load net
net = mbt2018_mean(quality=1, pretrained=False).to(device).eval()
state = torch.load("/hwj/data/model/mbt2018-mean-1.pth", map_location=device)
net.load_state_dict(state)

# 2) codec (in runtime)
codec = GpuPackedEntropyCodec(
    net.entropy_bottleneck,
    gaussian_conditional=net.gaussian_conditional,
    P=12
)

# 3) runtime (TRT, dtype auto-infer)
cfg = RuntimeConfig(
    model_name="mbt2018_mean",
    ga_input_dtype=torch.float16,
    gs_input_dtype=torch.float16,
    ha_input_dtype=torch.float16,
    hs_input_dtype=torch.float16,
    codec_input_dtype=torch.float32,
    trt_engines={
        "ga": "/hwj/project/CompressAI-Science/examples/out_engines/engines/mbt2018-mean-q1/ga/fp16.engine",
        "gs": "/hwj/project/CompressAI-Science/examples/out_engines/engines/mbt2018-mean-q1/gs/fp16.engine",
        "ha": "/hwj/project/CompressAI-Science/examples/out_engines/engines/mbt2018-mean-q1/ha/fp16.engine",
        "hs": "/hwj/project/CompressAI-Science/examples/out_engines/engines/mbt2018-mean-q1/hs/fp16.engine",
    },
)
engine = build_runtime(net, codec, cfg)

# 4) data
arr = np.load("/hwj/project/aiz-accelerate/data/nyx-dark_matter_density.npy")
x = torch.from_numpy(arr).float().to(device)

# 5) benchmark (auto stream)
stats, x_hat, x = run_e2e(engine, codec, x, warmup=5, iters=10)
stats


[01/14/2026-13:55:23] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
strings_bytes_list: [378056.0]


{'input_bytes': 100663296.0,
 'enc_ms': 13.131430435180665,
 'dec_ms': 15.05866231918335,
 'enc_GBps': 7.1393593000220745,
 'dec_GBps': 6.225652585394064,
 'strings_bytes': 378056.0,
 'state_bytes': 104.0,
 'total_bytes': 378160.0,
 'bpp_strings': 0.36054229736328125,
 'bpp_total': 0.3606414794921875,
 'cr_strings': 266.26556912203483,
 'cr_total': 266.19234186587687,
 'rmse': 0.09700117260217667,
 'nrmse': 0.09738306701183319,
 'maxe': 0.8133085370063782,
 'psnr': 20.230331420898438}
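
The derived fields in `stats` can be reproduced from the raw ones, which is a useful sanity check when reading the numbers. The formulas below are inferred from the reported values, not taken from the source of `run_e2e`. In particular, the throughput appears to be in GiB/s (2^30 bytes), bits per pixel appears to use N*H*W pixels, and NRMSE and PSNR appear to use the data range `x.max() - x.min()`.

```python
import math

# Values copied from the `stats` output above.
input_bytes = 100663296.0
enc_ms, dec_ms = 13.131430435180665, 15.05866231918335
strings_bytes, state_bytes = 378056.0, 104.0
rmse = 0.09700117260217667

# Assumptions: pixel count is N*H*W of the 512,3,128,128 input, and the data
# range is x.max() - x.min() as printed by the next cell.
num_pixels = 512 * 128 * 128
data_range = 0.9960784316062927 - 0.0

total_bytes = strings_bytes + state_bytes
derived = {
    "enc_GBps": input_bytes / 2**30 / (enc_ms / 1e3),
    "dec_GBps": input_bytes / 2**30 / (dec_ms / 1e3),
    "bpp_strings": strings_bytes * 8 / num_pixels,
    "bpp_total": total_bytes * 8 / num_pixels,
    "cr_strings": input_bytes / strings_bytes,
    "cr_total": input_bytes / total_bytes,
    "nrmse": rmse / data_range,
    "psnr": 20 * math.log10(data_range / rmse),
}
for name, value in derived.items():
    print(f"{name:12s} {value:.6f}")
```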

In [4]:
print("x:", x.min().item(), x.max().item(), torch.isnan(x).any().item(), torch.isinf(x).any().item())
pack = engine.compress(x)
x_hat = engine.decompress(pack)
print("x_hat:", x_hat.min().item(), x_hat.max().item(), torch.isnan(x_hat).any().item(), torch.isinf(x_hat).any().item())


x: 0.0 0.9960784316062927 False False
[01/13/2026-14:15:48] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
x_hat: -0.1768798828125 1.21875 False False


In [6]:
# Repeat the round trip to check that the output range is stable across runs
for _ in range(5):
    pack = engine.compress(x)
    x_hat = engine.decompress(pack)

print("x_hat:", x_hat.min().item(), x_hat.max().item(), torch.isnan(x_hat).any().item(), torch.isinf(x_hat).any().item())

x_hat: -0.1768798828125 1.21875 False False
