Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/models/core/qwen/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,15 @@ mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickst

```

### NVFP4 quantization

TensorRT-LLM supports NVFP4 precision with a block size of 16 for both activations and GEMM weights.
To run the Qwen3-Next model with NVFP4 precision, use the following command:
```bash
mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_advanced.py --model_dir <YOUR_MODEL_DIR> --kv_cache_fraction 0.6 --disable_kv_cache_reuse --max_batch_size 1 --tp_size 2 --trust_remote_code

```

## Notes and Troubleshooting

- **Model Directory:** Update `<YOUR_MODEL_DIR>` with the actual path where the model weights reside.
Expand Down
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,9 @@ def _attn_impl(

out_scale = None
out_scale_sf = None
if self.has_quant_scale:
if self.has_quant_scale and not self.attn_output_gate:
out_scale = self.o_proj.inv_input_scale
if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output and not self.attn_output_gate:
out_scale_sf = self.o_proj.input_scale

kv_scales_sf = None
Expand Down
47 changes: 46 additions & 1 deletion tensorrt_llm/_torch/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ...models.modeling_utils import QuantConfig
from ..cublaslt_utils import IS_CUBLASLT_AVAILABLE
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from ..utils import Fp4QuantizedTensor
from ..utils import Fp4QuantizedTensor, unswizzle_sf


class WeightMode(str, enum.Enum):
Expand Down Expand Up @@ -824,6 +824,9 @@ def apply(self, module: Linear, input: torch.Tensor,
act_sf,
module.weight_scale,
module.alpha, module.dtype)
# Take the dim of out_features if padded.
if output.shape[-1] > module.out_features:
output = output[..., :module.out_features]

if bias is not None:
output = output + bias
Expand Down Expand Up @@ -957,6 +960,48 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
copy_weight(module.alpha, alpha)
module.scalar_alpha = alpha.item()

def post_load_weights(self, module: Linear):
    """Pad weight and weight_scale to meet torch trtllm NVFP4 GEMM alignment.

    Runs the parent ``post_load_weights`` hook first. Then, if
    ``module.weight`` is not already aligned to 32 rows and 16 columns, it is
    zero-padded up to the next multiples, and ``module.weight_scale`` is
    unswizzled, zero-padded to match, and re-interleaved. Both attributes are
    replaced with new ``Parameter`` objects created with
    ``requires_grad=False``. When the weight is already aligned nothing is
    modified.

    Args:
        module: Linear layer whose ``weight`` and ``weight_scale`` are padded.

    Raises:
        AssertionError: If twice the padded weight column count is not
            divisible by ``module.scaling_vector_size``.
    """
    super().post_load_weights(module)

    # Alignment values are fixed here; they are not caller-configurable.
    row_alignment, col_alignment = 32, 16
    row_pad_size = (row_alignment - module.weight.size(0)) % row_alignment
    col_pad_size = (col_alignment - module.weight.size(1)) % col_alignment
    if row_pad_size == 0 and col_pad_size == 0:
        # Already aligned: leave weight and weight_scale untouched.
        return

    # Pad weight to meet NVFP4 GEMM kernel alignment requirements.
    # F.pad takes (left, right, top, bottom) for the last two dims, so the
    # zeros are appended after the existing columns and rows.
    module.weight = Parameter(F.pad(module.weight,
                                    (0, col_pad_size, 0, row_pad_size),
                                    mode='constant',
                                    value=0),
                              requires_grad=False)
    weight_col_size = module.weight.size(1)
    # NOTE(review): the `* 2` presumably converts packed weight columns to a
    # count of FP4 elements (two per stored column) -- confirm.
    assert (
        weight_col_size * 2
    ) % module.scaling_vector_size == 0, f"weight column size after padding {weight_col_size} must be divisible by scaling_vector_size {module.scaling_vector_size}"

    # Pad weight_scale to match padded weight dimensions.
    # Padding should be performed on the unswizzled weight_scale tensor, so
    # unswizzle first, pad, then interleave again.
    scale_rows = fp4_utils.pad_up(module.out_features, 128)
    scale_cols = fp4_utils.pad_up(
        module.in_features // module.scaling_vector_size, 4)
    weight_scale_unswizzle = unswizzle_sf(module.weight_scale.data,
                                          scale_rows, scale_cols,
                                          module.scaling_vector_size)
    weight_scale_unswizzle_pad = F.pad(
        weight_scale_unswizzle,
        (0, (col_pad_size * 2) // module.scaling_vector_size, 0,
         row_pad_size),
        mode='constant',
        value=0)
    module.weight_scale = Parameter(
        torch.ops.trtllm.block_scale_interleave(
            weight_scale_unswizzle_pad),
        requires_grad=False)


class W4A8NVFP4FP8LinearMethod(LinearMethodBase):

Expand Down