
Optimize Qwen Image VAE Performance Using TensorRT (Static & Multi-Profile) #909

Merged
helloyongyang merged 18 commits into ModelTC:main from fuheaven:qwen_text_encoder on Mar 5, 2026
Conversation

@fuheaven (Contributor) commented Feb 28, 2026

Description

This PR replaces native PyTorch operators with TensorRT engines, delivering a substantial performance boost for the Qwen Image model's VAE (Encoder / Decoder). It introduces a two-pronged TRT acceleration scheme tailored to two very different task types, T2I (fixed sizes) and I2I (variable sizes), and refactors the underlying loading components.

Key Features

  1. Unified low-level TensorRT VAE loader (vae_trt.py)

    • A single trt_engine_path argument plus a vae_type: "tensorrt" config switch.
    • A robust PyTorch fallback mechanism: if environment probing fails, the engine file is missing, or pre-allocating GPU memory hits OOM, inference automatically falls back to the native PyTorch VAE operators, keeping the inference path safe and robust.
  2. T2I: Static Shape engines + on-demand loading (Lazy Load)

    • Because T2I output images come in a small, fixed set of aspect ratios, a dedicated static engine is pre-built per resolution, eliminating dynamic-shape execution overhead entirely.
    • A Lazy Load strategy loads an engine pair (~5 GB of GPU memory per pair) only when its resolution is first requested; switching resolutions automatically releases the old engines and loads the new ones. Compared to loading everything up front (~25 GB), this greatly reduces memory usage and remains compatible with end-to-end inference.
  3. I2I: Multi-Profile dynamic engine integration

    • For arbitrary, uncontrolled input sizes, a single engine carries 9 classic opt shapes (including 512x512, 1024x1024, 720p, 1080p, etc.).
    • At inference time the closest profile is matched dynamically, so TensorRT allocates the best memory layout and kernel paths.
    • The engine stays resident in GPU memory; Encoder + Decoder together take ~1.0-1.2 GB.
  4. Accompanying documentation (QwenImageVAETensorRT.md)

    • Adds a configuration and best-practices guide for the VAE TRT optimization.
    • Includes two sets of benchmarks (standalone and end-to-end serving mode) plus a root-cause analysis of the performance differences.
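The per-resolution lazy-load strategy described above can be sketched roughly as follows. This is an illustrative assumption, not the PR's actual API: `LazyEngineCache` and the engine names are made up for the example.

```python
# Illustrative sketch of per-resolution lazy loading (hypothetical API):
# keep at most one (encoder, decoder) engine pair resident at a time.
class LazyEngineCache:
    def __init__(self, load_fn):
        self.load_fn = load_fn      # resolution -> engine pair (~5 GB each)
        self.current_key = None
        self.current_pair = None

    def get(self, resolution):
        if resolution != self.current_key:
            # Drop the old pair before loading the new one, so peak memory
            # stays near one pair (~5 GB) instead of all pairs (~25 GB).
            self.current_pair = None
            self.current_pair = self.load_fn(resolution)
            self.current_key = resolution
        return self.current_pair


loads = []
cache = LazyEngineCache(lambda r: (loads.append(r), f"engines-{r}")[1])
cache.get("16:9")
cache.get("16:9")   # cache hit: no reload for the same resolution
cache.get("1:1")    # resolution switch: old pair released, new pair loaded
print(loads)        # → ['16:9', '1:1'] — each resolution loaded only once
```

The one-time cost of a resolution switch corresponds to the "First Load" column in the end-to-end benchmark below.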

Performance Benchmark

All measurements were taken on a single NVIDIA H100 (80 GB) GPU.

1. T2I Static Shape — standalone VAE test

| Ratio | PT Enc (ms) | TRT Enc (ms) | Enc Speedup | PT Dec (ms) | TRT Dec (ms) | Dec Speedup |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| 16:9 | 66.53 | 32.70 | 2.03x | 103.65 | 49.66 | 2.09x |
| 9:16 | 65.72 | 32.22 | 2.04x | 103.02 | 50.71 | 2.03x |
| 1:1 | 78.16 | 41.95 | 1.86x | 121.91 | 61.52 | 1.98x |
| 4:3 | 73.99 | 37.23 | 1.99x | 114.45 | 54.75 | 2.09x |
| 3:4 | 31.74 | 17.33 | 1.83x | 50.77 | 26.86 | 1.89x |

Encoder ~1.95x, Decoder ~2.02x

2. T2I Static Shape — end-to-end serving mode (Qwen-Image-2512, 5 steps, VAE Decoder)

T2I has no VAE Encoder, so only the Decoder is measured.

| Ratio | PT Dec (ms) | TRT Dec (ms) | Dec Speedup | First Load (ms) |
| :---: | :---: | :---: | :---: | :---: |
| 16:9 | 189.3 | 88.4 | 2.14x | 343.9 |
| 9:16 | 179.6 | 85.6 | 2.10x | 226.4 |
| 1:1 | 157.6 | 106.2 | 1.48x | 304.1 |
| 4:3 | 148.7 | 94.7 | 1.57x | 238.0 |
| 3:4 | 70.4 | 46.1 | 1.53x | 178.2 |

Decoder average ~1.8x. "First Load" is the one-time cost of a Lazy Load resolution switch; subsequent requests at the same resolution do not pay it.

3. I2I Multi-Profile — standalone VAE test (average of 10 runs)

Encoder:

| Resolution | PT Enc (ms) | TRT Enc (ms) | Speedup |
| :---: | :---: | :---: | :---: |
| 512x512 | 11.00 | 8.53 | 1.29x |
| 1024x1024 | 42.85 | 27.56 | 1.55x |
| 480p 16:9 | 17.25 | 12.00 | 1.44x |
| 720p 16:9 | 38.00 | 25.35 | 1.50x |
| 768p 4:3 | 31.98 | 21.76 | 1.47x |

Encoder average ~1.45x

Decoder:

| Resolution | PT Dec (ms) | TRT Dec (ms) | Speedup |
| :---: | :---: | :---: | :---: |
| 512x512 | 17.60 | 12.78 | 1.38x |
| 1024x1024 | 68.16 | 44.93 | 1.52x |
| 480p 16:9 | 27.67 | 18.85 | 1.47x |
| 720p 16:9 | 60.24 | 40.80 | 1.48x |
| 768p 4:3 | 51.14 | 34.92 | 1.46x |

Decoder average ~1.46x; overall ~1.45x

4. I2I Multi-Profile — end-to-end serving mode (qwen-image-edit-251130, 4 steps)

| Resolution | PT Enc → TRT Enc | Enc Speedup | PT Dec → TRT Dec | Dec Speedup |
| :---: | :---: | :---: | :---: | :---: |
| 512x512 | 48.5 → 28.8 | 1.68x | 138.4 → 134.0 | 1.03x |
| 1024x1024 | 48.2 → 28.4 | 1.70x | 152.7 → 133.3 | 1.15x |
| 480p 16:9 | 48.7 → 29.6 | 1.64x | 140.4 → 134.4 | 1.04x |
| 720p 16:9 | 48.6 → 30.1 | 1.62x | 139.0 → 134.2 | 1.04x |
| 768p 4:3 | 49.2 → 29.8 | 1.65x | 152.8 → 134.8 | 1.13x |

Encoder ~1.66x, Decoder ~1.08x

The decoder speedup is lower than in the standalone test because postprocess(output_type="pil") adds a constant ~80-90 ms of CPU overhead (the tensor → PIL conversion) that TRT cannot accelerate, which arithmetically dilutes the ratio. Refer to the standalone results for the speedup of the TRT engine kernels themselves.
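That dilution argument can be checked against the 1024x1024 rows from the tables above: subtracting the standalone kernel times from the end-to-end times recovers a near-constant overhead on both paths.

```python
# Sanity-check the dilution explanation with the 1024x1024 decoder numbers.
pt_kernel, trt_kernel = 68.16, 44.93   # standalone decoder (ms)
pt_e2e, trt_e2e = 152.7, 133.3         # end-to-end decoder (ms)

# Implied constant postprocess overhead, consistent with the ~80-90 ms claim:
print(round(pt_e2e - pt_kernel, 2))    # → 84.54
print(round(trt_e2e - trt_kernel, 2))  # → 88.37

# A constant cost added to both sides shrinks the ratio:
print(round(pt_kernel / trt_kernel, 2))  # → 1.52 (kernel speedup)
print(round(pt_e2e / trt_e2e, 2))        # → 1.15 (diluted end-to-end speedup)
```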


Changes Made

  • Refactored lightx2v/models/video_encoders/trt/qwen_image/vae_trt.py
    • Unified Static / Multi-Profile loading logic
    • Implemented Lazy Load for T2I static engines (auto load/release per resolution)
    • PyTorch fallback mechanism
  • Added T2I TRT config: configs/qwen_image/qwen_image_t2i_2512_trt.json
  • Added I2I TRT config: configs/qwen_image/qwen_image_i2i_2511_trt.json
  • Added shell scripts: scripts/qwen_image/qwen_image_t2i_2512_trt.sh, scripts/qwen_image/qwen_image_i2i_2511_trt.sh
  • Added Documentation: examples/BeginnerGuide/ZH_CN/QwenImageVAETensorRT.md, examples/BeginnerGuide/EN/QwenImageVAETensorRT.md
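The PyTorch fallback mechanism listed above can be sketched as follows. The function names and return values are hypothetical, not the actual vae_trt.py interface; the point is that any failure while probing or loading TensorRT degrades gracefully to the native VAE.

```python
# Hypothetical sketch of the TRT-or-PyTorch fallback path (not the real API).
def load_trt_engine(path):
    # Stand-in for real engine deserialization; here we simulate a missing
    # engine file, one of the failure modes the fallback must cover.
    raise FileNotFoundError(path)


def load_vae(config):
    if config.get("vae_type") != "tensorrt":
        return "pytorch-vae"                 # TRT not requested
    try:
        return load_trt_engine(config["trt_engine_path"])
    except Exception as exc:
        # Environment probe failure, missing engine, or OOM all land here.
        print(f"TRT unavailable ({exc!r}); falling back to PyTorch VAE")
        return "pytorch-vae"


print(load_vae({"vae_type": "tensorrt", "trt_engine_path": "/engines/vae"}))
# → pytorch-vae (after logging the fallback reason)
```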

@gemini-code-assist (Contributor) commented

Summary of Changes

This pull request integrates TensorRT acceleration for the VAE component of the Qwen Image model to substantially improve inference performance. It adopts optimization strategies tailored to the different input characteristics of text-to-image (T2I) and image-to-image (I2I) tasks: T2I uses static-shape engines for maximum speedup and memory savings, while I2I uses a Multi-Profile dynamic engine to handle arbitrary input sizes. A unified VAE loader with a robust PyTorch fallback mechanism keeps the system stable in all cases, and the changes ship with detailed benchmarks and a configuration guide.

Highlights

  • VAE TensorRT acceleration: Introduces TensorRT engines for the Qwen Image model's VAE (Encoder / Decoder) for a significant performance gain.
  • Dual acceleration strategy: Static-shape engines with lazy loading for T2I (fixed sizes) and a Multi-Profile dynamic engine with multiple optimization profiles for I2I (variable sizes).
  • Unified VAE loader with PyTorch fallback: Refactors the underlying VAE loading component (vae_trt.py) around a single trt_engine_path argument and implements a robust PyTorch fallback mechanism for resilience.
  • Performance benchmarks: On NVIDIA H100, the T2I static scheme averages roughly 2.0x end-to-end speedup and the I2I dynamic scheme roughly 1.6x.
  • Documentation: Adds Chinese and English guides (QwenImageVAETensorRT.md) detailing configuration and best practices for the VAE TRT optimization.
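The "match the closest profile at inference time" idea behind the Multi-Profile engine can be sketched as below. The profile table and the area-based distance metric are illustrative assumptions, not the exact PROFILE_CONFIGS from vae_trt.py.

```python
# Illustrative nearest-profile selection (hypothetical table and metric).
PROFILES = {
    "512x512": (512, 512),
    "1024x1024": (1024, 1024),
    "16_9_720p": (720, 1280),    # (height, width)
    "16_9_1080p": (1088, 1920),
    "4_3_768p": (768, 1024),
}

def closest_profile(height, width):
    # Pick the profile whose H*W pixel area is nearest to the input's;
    # TensorRT then executes with that profile's tuned layout and kernels.
    area = height * width
    return min(PROFILES, key=lambda k: abs(PROFILES[k][0] * PROFILES[k][1] - area))

print(closest_profile(720, 1260))   # → 16_9_720p
print(closest_profile(1000, 1000))  # → 1024x1024
```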


Changelog
  • configs/qwen_image/qwen_image_i2i_2511_trt.json
    • Adds the TensorRT VAE Multi-Profile configuration for I2I tasks.
  • configs/qwen_image/qwen_image_t2i_2512_trt.json
    • Adds the TensorRT VAE static-shape configuration for T2I tasks.
  • examples/BeginnerGuide/EN/QwenImageVAETensorRT.md
    • Adds detailed English documentation for Qwen Image VAE TensorRT acceleration.
  • examples/BeginnerGuide/ZH_CN/QwenImageVAETensorRT.md
    • Adds detailed Chinese documentation for Qwen Image VAE TensorRT acceleration.
  • lightx2v/models/runners/qwen_image/qwen_image_runner.py
    • Updates the VAE loading logic to support TensorRT integration and the PyTorch fallback.
    • Adjusts image shape handling for T2I tasks.
  • lightx2v/models/video_encoders/trt/qwen_image/vae_trt.py
    • Adds the TensorRT VAE implementation module, supporting static and Multi-Profile engines, with PyTorch fallback logic.
  • lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py
    • Updates the Flash Attention implementation to use sparse attention conditionally based on an environment variable.
    • Adds a model_cls parameter to the apply method.
  • scripts/qwen_image/qwen_image_i2i_2511_trt.sh
    • Adds a shell script for running I2I inference with the TensorRT VAE configuration.
  • scripts/qwen_image/qwen_image_t2i_2512_trt.sh
    • Adds a shell script for running T2I inference with the TensorRT VAE configuration.
Activity
  • This pull request was created by fuheaven.

@gemini-code-assist bot (Contributor) left a comment


Code Review

This PR optimizes Qwen Image VAE performance by introducing TensorRT engines, which is an excellent performance improvement. Designing separate Static Shape and Multi-Profile schemes for the T2I and I2I task scenarios is well considered, and the code refactoring together with the fallback mechanism strengthens the system's robustness.

The review found a few areas for improvement:

  • Several resolution descriptions in the documentation appear inconsistent with the definitions in the code and could mislead users.
  • vae_trt.py has some optimizable spots, such as repeatedly creating tensors on the hot path.
  • flash_attn.py has a small bug that prevents an environment variable setting from taking effect.
  • The t2i shell script is inconsistent with the i2i one in how environment variables are exported.

See the comments below for specific suggestions.

q,
k,
v,
topk=0.4,

high

The topk parameter here is hardcoded to 0.4 instead of using the topk_value read from the SPARSE_ATTN_TOPK environment variable, so that environment variable has no effect. Use topk_value instead.

Suggested change
- topk=0.4,
+ topk=topk_value,

Comment on lines +169 to +175
3. `16_9_480p` (480x848)
4. `16_9_720p` (720x1280)
5. `16_9_1080p` (1080x1920)
6. `9_16_720p` (1280x720)
7. `9_16_1080p` (1920x1080)
8. `4_3_768p` (768x1024)
9. `3_2_1080p` (1088x1620)

medium

In the Built-in Optimization Profiles list, the resolution (WxH) for some profiles appears to be swapped, and some values do not exactly match the code (in PROFILE_CONFIGS in lightx2v/models/video_encoders/trt/qwen_image/vae_trt.py). This could be confusing for users.

For example:

  • 16_9_720p is height=720, width=1280 in the code, so it should be (1280x720) in the documentation, but it is currently (720x1280).
  • 16_9_1080p is height=1088, width=1920 in the code, so it should be (1920x1088) in the documentation, but it is currently (1080x1920), and the height value is also incorrect.

It is recommended to update this list according to the PROFILE_CONFIGS definition in vae_trt.py to ensure consistency between the documentation and the code implementation.

Suggested change
- 3. `16_9_480p` (480x848)
- 4. `16_9_720p` (720x1280)
- 5. `16_9_1080p` (1080x1920)
- 6. `9_16_720p` (1280x720)
- 7. `9_16_1080p` (1920x1080)
- 8. `4_3_768p` (768x1024)
- 9. `3_2_1080p` (1088x1620)
+ 3. `16_9_480p` (848x480)
+ 4. `16_9_720p` (1280x720)
+ 5. `16_9_1080p` (1920x1088)
+ 6. `9_16_720p` (720x1280)
+ 7. `9_16_1080p` (1088x1920)
+ 8. `4_3_768p` (1024x768)
+ 9. `3_2_1080p` (1620x1088)

Comment on lines +169 to +175
3. `16_9_480p` (480x848)
4. `16_9_720p` (720x1280)
5. `16_9_1080p` (1080x1920)
6. `9_16_720p` (1280x720)
7. `9_16_1080p` (1920x1080)
8. `4_3_768p` (768x1024)
9. `3_2_1080p` (1088x1620)

medium

Hi, the built-in Optimization Profile list in the Chinese documentation has the same issues as the English one: some resolutions (WxH) are written in reverse order, and some values do not match the code (PROFILE_CONFIGS in lightx2v/models/video_encoders/trt/qwen_image/vae_trt.py).

To avoid confusing users, please correct the sizes here according to the code implementation.

Suggested change
- 3. `16_9_480p` (480x848)
- 4. `16_9_720p` (720x1280)
- 5. `16_9_1080p` (1080x1920)
- 6. `9_16_720p` (1280x720)
- 7. `9_16_1080p` (1920x1080)
- 8. `4_3_768p` (768x1024)
- 9. `3_2_1080p` (1088x1620)
+ 3. `16_9_480p` (848x480)
+ 4. `16_9_720p` (1280x720)
+ 5. `16_9_1080p` (1920x1088)
+ 6. `9_16_720p` (720x1280)
+ 7. `9_16_1080p` (1088x1920)
+ 8. `4_3_768p` (1024x768)
+ 9. `3_2_1080p` (1620x1088)

Comment on lines +67 to +68
self.vae_latents_mean = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921]
self.vae_latents_std = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916]

medium

For performance, initialize these normalization constants as tensors once in __init__ rather than recreating them from Python lists on every call to encode_vae_image and decode. This avoids unnecessary overhead.

Suggested change
- self.vae_latents_mean = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921]
- self.vae_latents_std = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916]
+ self.vae_latents_mean = torch.tensor([-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921], device=self.device).view(1, self.latent_channels, 1, 1, 1)
+ self.vae_latents_std = torch.tensor([2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916], device=self.device).view(1, self.latent_channels, 1, 1, 1)

Comment on lines +338 to +340
latents_mean = torch.tensor(self.vae_latents_mean).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_std = torch.tensor(self.vae_latents_std).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
image_latents = (image_latents - latents_mean) / latents_std

medium

Use the tensors pre-created in __init__ for the normalization step, to improve performance.

Suggested change
- latents_mean = torch.tensor(self.vae_latents_mean).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
- latents_std = torch.tensor(self.vae_latents_std).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
- image_latents = (image_latents - latents_mean) / latents_std
+ image_latents = (image_latents - self.vae_latents_mean.to(image_latents.device, image_latents.dtype)) / self.vae_latents_std.to(image_latents.device, image_latents.dtype)

Comment on lines +363 to +365
latents_mean = torch.tensor(self.vae_latents_mean).view(1, self.latent_channels, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / torch.tensor(self.vae_latents_std).view(1, self.latent_channels, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean

medium

Use the tensors pre-created in __init__ for the de-normalization step, to improve performance.

Suggested change
- latents_mean = torch.tensor(self.vae_latents_mean).view(1, self.latent_channels, 1, 1, 1).to(latents.device, latents.dtype)
- latents_std = 1.0 / torch.tensor(self.vae_latents_std).view(1, self.latent_channels, 1, 1, 1).to(latents.device, latents.dtype)
- latents = latents / latents_std + latents_mean
+ latents = latents * self.vae_latents_std.to(latents.device, latents.dtype) + self.vae_latents_mean.to(latents.device, latents.dtype)

Comment on lines +4 to +5
lightx2v_path=
model_path=

medium

For consistency with the qwen_image_i2i_2511_trt.sh script, and to ensure the environment variables are inherited by child processes, export lightx2v_path and model_path here with the export keyword.

Suggested change
- lightx2v_path=
- model_path=
+ export lightx2v_path=
+ export model_path=

@helloyongyang helloyongyang merged commit e06f5fa into ModelTC:main Mar 5, 2026
1 check passed
helloyongyang pushed a commit that referenced this pull request Mar 6, 2026
…ofile) (#909)
