[Qwen3.5][Feature] support fla triton kernel for qwen3.5#7024
wanderHZ wants to merge 4 commits into PaddlePaddle:develop
Conversation
Thanks for your contribution!
Pull request overview
This PR introduces a Triton-based FLA (Flash Linear Attention) inference kernel implementation for the GatedDeltaNet Attention in Qwen3.5, together with matching unit tests, providing high-performance GDN SSM + causal conv1d computation for both the Prefill and Decode paths.
Changes:
- Added the fastdeploy/model_executor/ops/triton_ops/fla/ FLA Triton kernel package: core operators plus index/utility helpers for both the chunked prefill (WY 6-step) path and the fused recurrent decode path.
- Added causal_conv1d.py: Triton causal conv1d implementations for Prefill (varlen) and Decode (single-token, pool-index).
- Added tests/model_executor/ops/triton_ops/test_gdn_kernels.py: baseline-alignment tests verifying the GDN recurrent/chunk and causal conv1d kernel outputs.
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tests/model_executor/ops/triton_ops/test_gdn_kernels.py | Adds correctness-alignment unit tests for the GDN/conv1d Triton kernels with a pure-Paddle baseline. |
| fastdeploy/model_executor/ops/triton_ops/fla/__init__.py | Exports the public API of the FLA kernel package. |
| fastdeploy/model_executor/ops/triton_ops/fla/utils.py | Provides input_guard (contiguous), a simple tensor_cache, and Triton environment capability probing. |
| fastdeploy/model_executor/ops/triton_ops/fla/op.py | Triton-side basic math/safety functions and gather capability adaptation. |
| fastdeploy/model_executor/ops/triton_ops/fla/index.py | Utilities for generating varlen chunk indices/offsets. |
| fastdeploy/model_executor/ops/triton_ops/fla/cumsum.py | Chunk-local cumsum Triton kernels (scalar/vector). |
| fastdeploy/model_executor/ops/triton_ops/fla/l2norm.py | L2Norm Triton kernel (inference forward). |
| fastdeploy/model_executor/ops/triton_ops/fla/chunk_scaled_dot_kkt.py | Triton kernel computing beta·K·Kᵀ within a chunk. |
| fastdeploy/model_executor/ops/triton_ops/fla/solve_tril.py | Blocked Triton implementation of the lower-triangular inverse (I + A)^{-1}. |
| fastdeploy/model_executor/ops/triton_ops/fla/wy_fast.py | Triton kernel + wrapper for recomputing W/U in the WY decomposition. |
| fastdeploy/model_executor/ops/triton_ops/fla/chunk_delta_h.py | Triton kernel + wrapper for inter-chunk state propagation. |
| fastdeploy/model_executor/ops/triton_ops/fla/chunk_o.py | Triton kernel + wrapper for chunk output computation. |
| fastdeploy/model_executor/ops/triton_ops/fla/fused_recurrent.py | Decode-path fused recurrent kernels (standard interface and pool-index interface). |
| fastdeploy/model_executor/ops/triton_ops/fla/chunk.py | Prefill-path 6-step chunked WY algorithm orchestration and public API. |
| fastdeploy/model_executor/ops/triton_ops/causal_conv1d.py | Triton causal conv1d implementations: Prefill and Decode interfaces. |
```python
    beta: Optional[paddle.Tensor] = None,
    scale: Optional[float] = None,
    ssm_pool: Optional[paddle.Tensor] = None,
    ssm_indices: Optional[paddle.Tensor] = None,
    cu_seqlens: Optional[paddle.Tensor] = None,
```
fused_recurrent_gated_delta_rule_update declares ssm_pool/ssm_indices as Optional with a default of None, but the implementation then passes them straight into the Triton kernel (which does tl.load(h0_indices) and pointer writebacks). If a caller omits these two arguments, the result is a runtime crash rather than a readable exception. Consider explicitly checking at this public API entry point that ssm_pool/ssm_indices are non-None and shape-compatible, raising a ValueError otherwise.
```python
    Returns:
        x: [dim, total_tokens] (channel-last layout)
        weight: [dim, kernel_width]
        bias: [dim,]
        conv_pool: [max_seqs, dim, state_len]
```
The docstring of _make_varlen_inputs says "channel-last layout", but x actually has shape [dim, total_tokens] (dim first, which is closer to channel-first). The wording should be made consistent so later readers don't assume the wrong layout and pass an incorrect shape/stride when calling causal_conv1d_fn.
```python
    _initial_state = paddle.zeros([B, H, K, V], dtype=k.dtype)
    _initial_state_indices = paddle.arange(B, dtype=paddle.int32)
```
In varlen mode (cu_seqlens != None) with initial_state=None, the dummy initial_state / initial_state_indices are built from B: _initial_state = zeros([B, H, K, V]), _initial_state_indices = arange(B). But chunk_gated_delta_rule_fwd_h launches its grid with N = cu_seqlens.shape[0] - 1 and does tl.load on initial_state_indices[i_n]; when N > 1 and B == 1 this reads/writes out of bounds, producing wrong results or illegal memory accesses. Size the dummies by N instead (shape [N, H, K, V] with indices = arange(N)), or raise an explicit error in the varlen + initial_state=None case and require the caller to pass state/indices.
Suggested change:
```diff
-    _initial_state = paddle.zeros([B, H, K, V], dtype=k.dtype)
-    _initial_state_indices = paddle.arange(B, dtype=paddle.int32)
+    if cu_seqlens is not None:
+        # varlen mode: grid size is N = cu_seqlens.shape[0] - 1, so
+        # dummy initial_state/indices must be sized by N instead of B.
+        N = cu_seqlens.shape[0] - 1
+        _initial_state = paddle.zeros([N, H, K, V], dtype=k.dtype)
+        _initial_state_indices = paddle.arange(N, dtype=paddle.int32)
+    else:
+        _initial_state = paddle.zeros([B, H, K, V], dtype=k.dtype)
+        _initial_state_indices = paddle.arange(B, dtype=paddle.int32)
```
```python
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if beta is None:
        beta = paddle.ones(q.shape[:-1], dtype=q.dtype)  # [B, T, HV]
```
When beta is None, the default is paddle.ones(q.shape[:-1]), which has shape [B, T, H]. But the interface docs and the kernel expect beta to be [B, T, HV] (HV may exceed H in GQA/GVA scenarios), and the kernel strides its pointers by HV. If HV != H this causes out-of-bounds reads or wrong results. Derive the default beta from v's head dimension (i.e. [B, T, HV]) and, ideally, add a shape assertion (beta.shape[2] == HV).
Suggested change:
```diff
-        beta = paddle.ones(q.shape[:-1], dtype=q.dtype)  # [B, T, HV]
+        # When beta is not provided, create an all-ones tensor with shape [B, T, HV].
+        # HV is derived from v to properly support HV != H (e.g. GQA/GVA scenarios).
+        beta = paddle.ones(v.shape[:3], dtype=v.dtype)
+    else:
+        # Validate that beta matches [B, T, HV] derived from v to avoid kernel shape mismatch
+        if (
+            beta.shape[0] != v.shape[0]
+            or beta.shape[1] != v.shape[1]
+            or beta.shape[2] != v.shape[2]
+        ):
+            raise ValueError(
+                f"beta must have shape [B, T, HV] matching v, but got "
+                f"beta.shape={beta.shape}, v.shape[:3]={v.shape[:3]}"
+            )
```
```python
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if beta is None:
        beta = paddle.ones(q.shape[:-1], dtype=q.dtype)
```
When beta is None, the default paddle.ones(q.shape[:-1]) has shape [B, T, H], but the update kernel expects beta to be [B, T, HV] (its pointer strides are computed from HV). With HV != H (e.g. GVA/GQA) this causes out-of-bounds reads or wrong results. Derive the default beta from v's head dimension and add an assertion that beta.shape[2] == HV.
Suggested change:
```diff
-        beta = paddle.ones(q.shape[:-1], dtype=q.dtype)
+        # Default beta should match v's head dimension: [B, T, HV]
+        beta = paddle.ones(v.shape[:-1], dtype=v.dtype)
+    else:
+        # Validate beta shape to prevent out-of-bounds access in the kernel
+        assert beta.ndim == 3, "beta must be 3D tensor of shape [B, T, HV]"
+        assert beta.shape == v.shape[:-1], (
+            f"beta shape {beta.shape} must match v.shape[:-1] {v.shape[:-1]} "
+            "for fused_recurrent_gated_delta_rule_update"
+        )
```
```python
def fused_recurrent_gated_delta_rule_update(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    g: paddle.Tensor,
    beta: Optional[paddle.Tensor] = None,
    scale: Optional[float] = None,
    ssm_pool: Optional[paddle.Tensor] = None,
    ssm_indices: Optional[paddle.Tensor] = None,
    cu_seqlens: Optional[paddle.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
    disable_state_update: bool = False,
) -> paddle.Tensor:
```
The public fused_recurrent_gated_delta_rule_update interface accepts None for ssm_pool/ssm_indices (and defaults both to None), yet passes them unconditionally into the Triton kernel; this fails at runtime with an unhelpful error. Validate before entering fused_recurrent_gated_delta_rule_update_fwd that ssm_pool and ssm_indices are non-None with valid shape/dtype (and, in varlen mode, that N is consistent with cu_seqlens).
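A minimal sketch of such a guard (the helper name is hypothetical, and NumPy arrays stand in for paddle Tensors here):

```python
import numpy as np

def check_ssm_pool_args(ssm_pool, ssm_indices, cu_seqlens=None):
    """Fail fast with a readable error instead of an opaque Triton-side crash."""
    if ssm_pool is None or ssm_indices is None:
        raise ValueError(
            "fused_recurrent_gated_delta_rule_update requires ssm_pool and "
            "ssm_indices to be provided, got None"
        )
    if cu_seqlens is not None:
        # In varlen mode the kernel launches one program per sequence,
        # so there must be exactly one pool index per sequence.
        n_seqs = cu_seqlens.shape[0] - 1
        if ssm_indices.shape[0] != n_seqs:
            raise ValueError(
                f"ssm_indices has {ssm_indices.shape[0]} entries but "
                f"cu_seqlens implies {n_seqs} sequences"
            )
```
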
```python
    def setUp(self):
        paddle.seed(42)
        self.dtype = paddle.bfloat16
        self.B, self.T = 2, 8
        self.H, self.K, self.V = 4, 64, 64
```
This test file runs Triton kernels directly but does not check for a CUDA/Triton environment in setUp and skip when unavailable; on CPU-only machines or without Triton installed, the tests fail outright. Other GPU tests in the repo typically use if not paddle.is_compiled_with_cuda(): self.skipTest(...) together with paddle.set_device('gpu'). Add these guards to each TestCase's setUp to avoid environment-dependent false failures.
```python
    Returns:
        x: [dim, total_tokens] (channel-last layout)
        weight: [dim, kernel_width]
        bias: [dim,]
        conv_pool: [max_seqs, dim, state_len]
        slot_ids: [N]
        has_initial_state: [N] bool
        query_start_loc: [N+1]
        seq_lens_cpu: List[int]
    """
    dim, width, state_len = self.dim, self.kernel_width, self.state_len
    N = len(seq_lens)
    total = sum(seq_lens)
    # channel-last: (dim, total_tokens)
    x = paddle.randn([dim, total], dtype=paddle.float32).cast(self.dtype)
    weight = paddle.randn([dim, width], dtype=paddle.float32).cast(self.dtype)
```
The comment/docstring is inconsistent with the layout: x is actually constructed as [dim, total_tokens] (channel-first), but the comment says "channel-last layout". Align the wording so later callers don't pass arguments assuming the wrong layout.
```python
def input_guard(fn: Callable) -> Callable:
    """
    Ensure all input Tensors are contiguous and run on the correct CUDA device.
```
The input_guard docstring claims it ensures inputs "run on the correct CUDA device", but the wrapper only makes inputs contiguous and performs no device check or placement. Update the docstring to match the actual behavior, or add the device-consistency check, to avoid misleading callers.
Suggested change:
```diff
-    Ensure all input Tensors are contiguous and run on the correct CUDA device.
+    Ensure all input Tensors are contiguous.
```
```python
    # The kernel always loads initial_state_indices even when USE_INITIAL_STATE=False,
    # so dummy values are needed to avoid NoneType errors when initial_state is None.
    B, T, H, K = k.shape
    V = u.shape[-1]
    _initial_state = initial_state
    _initial_state_indices = initial_state_indices
    if _initial_state is None:
        # dummy: zero state, indices pointing to slot 0
        _initial_state = paddle.zeros([B, H, K, V], dtype=k.dtype)
        _initial_state_indices = paddle.arange(B, dtype=paddle.int32)
    h, v_new = chunk_gated_delta_rule_fwd_h(
```
When initial_state is None, a dummy state is created here and passed as _initial_state into chunk_gated_delta_rule_fwd_h, so its internal USE_INITIAL_STATE = initial_state is not None is always True, incurring an extra read from the dummy state (all zeros, but still bandwidth). Consider decoupling the "use initial state" semantics from "pass a valid placeholder Tensor to satisfy the kernel's pointer requirements": always pass a valid Tensor, but control USE_INITIAL_STATE with a separate flag to avoid the unnecessary memory reads.
Motivation
To support the GatedDeltaNet Attention computation in the Qwen3.5 model, this PR adds Triton kernel implementations based on FLA (Flash Linear Attention).
GatedDeltaNet (GDN) is the linear attention mechanism used in Qwen3.5; its Prefill stage uses a 6-step chunked WY algorithm, and its Decode stage uses a fused recurrent algorithm. This PR ports the FLA Triton kernels (validated via SGLang) to FastDeploy (PaddlePaddle): all Triton kernel bodies are identical to the SGLang/FLA originals, and only the Python wrappers were adapted from PyTorch to PaddlePaddle.
In addition, a causal_conv1d Triton kernel is added, providing both Decode (update) and Prefill (varlen fn) paths for the causal convolution in GDN. A fused_gdn_gating Triton kernel is also added, fusing the GDN gating computation (softplus + exp + sigmoid) into a single kernel launch to avoid the overhead of multiple GPU kernel launches.
Modifications
New files

FLA Triton kernel package (fastdeploy/model_executor/ops/triton_ops/fla/, 14 files):
- __init__.py: exports the package's public API
- utils.py: environment probes (is_nvidia_hopper, is_gather_supported), the input_guard decorator, and Triton kernel compile-cache management
- op.py: Triton-side math helpers (exp, log, safe_exp, gather)
- index.py: varlen indexing utilities (prepare_lens, prepare_chunk_indices, prepare_chunk_offsets)
- cumsum.py: chunk-local cumsum kernels
- l2norm.py: L2Norm kernel
- chunk_scaled_dot_kkt.py: in-chunk beta·K·Kᵀ computation
- solve_tril.py: blocked lower-triangular inverse (I + A)^{-1}
- wy_fast.py: W/U recomputation for the WY decomposition
- chunk_delta_h.py: inter-chunk state propagation
- chunk_o.py: chunk output computation
- fused_recurrent.py: decode-path fused recurrent kernels
- chunk.py: prefill-path 6-step chunked WY orchestration
- fused_gdn_gating.py: fused gating, g = -exp(A_log) * softplus(a + dt_bias), beta = sigmoid(b)

Causal conv1d Triton kernel (fastdeploy/model_executor/ops/triton_ops/causal_conv1d.py, 1 file):
- causal_conv1d_update: Decode path
- causal_conv1d_fn: Prefill (varlen) path, with query_start_loc / has_initial_state parameters

Unit tests (tests/model_executor/ops/triton_ops/test_gdn_kernels.py, 1 file): 14 test cases covering 5 classes: TestFusedRecurrentGDN, TestChunkGDN, TestCausalConv1dUpdate, TestCausalConv1dFn, TestFusedGDNGating

Porting strategy
- torch.Tensor → paddle.Tensor
- torch.empty / torch.empty_like → paddle.empty / paddle.empty_like
- torch.float32 → paddle.float32
- tensor.stride(i) → tensor.strides[i]
- Removed torch.autograd.Function (inference needs no backward pass)
- Removed the einops.rearrange dependency
- chunk.py: g_cumsum is forced to output_dtype=paddle.float32 to avoid bf16 precision overflow
- chunk.py: when initial_state=None, dummy zero-state and index tensors are created so the Triton kernel never receives a null pointer

Usage or Command
1. GDN SSM Kernel — Prefill (chunk algorithm)
2. GDN SSM Kernel — Decode (fused recurrent, pool-index)
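To make the decode-path semantics concrete, here is a schematic NumPy reference of a single recurrent step, following the standard gated delta rule formulation (per head, state laid out as `[K, V]`; the function name and layout are illustrative, not the Triton API):

```python
import numpy as np

def gated_delta_rule_step(S, q, k, v, g, beta, scale):
    """One decode step: gated decay, delta-rule correction, then read-out."""
    S = np.exp(g) * S                          # gated decay of the [K, V] state
    S = S + np.outer(k, beta * (v - k @ S))    # delta rule: move state toward v along k
    o = (scale * q) @ S                        # output for this token: [V]
    return S, o
```
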
3. Causal Conv1d — Decode (single-token update)
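The single-token update semantics can be sketched in NumPy (a schematic reference, not the Triton kernel's API or memory layout):

```python
import numpy as np

def causal_conv1d_update_ref(x, conv_state, weight, bias):
    """x: [dim] new token, conv_state: [dim, width-1], weight: [dim, width]."""
    window = np.concatenate([conv_state, x[:, None]], axis=1)  # [dim, width]
    y = (window * weight).sum(axis=1) + bias                   # per-channel causal dot product
    new_state = window[:, 1:]                                  # shifted window for the next token
    return y, new_state
```
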
4. Causal Conv1d — Prefill (varlen)
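The varlen semantics can likewise be sketched in NumPy: each sequence delimited by `query_start_loc` is convolved causally and independently (zero initial state assumed; schematic, not the Triton API):

```python
import numpy as np

def causal_conv1d_fn_ref(x, weight, bias, query_start_loc):
    """x: [dim, total_tokens]; per-sequence causal conv1d with zero initial state."""
    dim, width = weight.shape
    y = np.zeros_like(x)
    for s, e in zip(query_start_loc[:-1], query_start_loc[1:]):
        # left-pad each sequence so every output position sees `width` inputs
        seq = np.concatenate([np.zeros((dim, width - 1)), x[:, s:e]], axis=1)
        for t in range(e - s):
            y[:, s + t] = (seq[:, t : t + width] * weight).sum(axis=1) + bias
    return y
```
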
5. Fused GDN Gating — fused gating computation
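The gating math that fused_gdn_gating fuses (softplus + exp + sigmoid, per the description above) can be written as a pure-NumPy reference (schematic; the Triton kernel performs this elementwise in a single launch):

```python
import numpy as np

def fused_gdn_gating_ref(A_log, a, dt_bias, b):
    """Elementwise: g = -exp(A_log) * softplus(a + dt_bias), beta = sigmoid(b)."""
    softplus = np.log1p(np.exp(a + dt_bias))   # stable enough for moderate inputs
    g = -np.exp(A_log) * softplus
    beta = 1.0 / (1.0 + np.exp(-b))
    return g, beta
```
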
Run the unit tests:
```shell
cd FastDeploy
python -m pytest tests/model_executor/ops/triton_ops/test_gdn_kernels.py -v
```
Accuracy Tests
The baseline reference implementation (pure Paddle, ported from the PyTorch reference implementation in HuggingFace Transformers) covers the GDN recurrent / chunk / conv1d operations.
Test accuracy (bf16 inputs):
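A typical alignment metric for this kind of kernel-vs-baseline comparison, shown only to make the acceptance criterion concrete (the helper is hypothetical, not code from this PR):

```python
import numpy as np

def max_relative_error(out, ref, eps=1e-6):
    """Max elementwise |out - ref| / (|ref| + eps), computed in float64."""
    out = np.asarray(out, dtype=np.float64)
    ref = np.asarray(ref, dtype=np.float64)
    return float(np.max(np.abs(out - ref) / (np.abs(ref) + eps)))
```
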
Checklist
- PR title is prefixed with [Qwen3.5][Feature].
- pre-commit was run before commit.