Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion format/format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
cd "$(dirname "$0")" || exit

# force ruff/isort to be same version as setup.py
pip install -U ruff==0.13.0
pip install -U ruff==0.14.2
#isort==6.0.1

ruff check ../gptqmodel/models ../gptqmodel/nn_modules ../gptqmodel/quantization ../gptqmodel/utils ../gptqmodel/__init__.py ../examples ../tests ../setup.py --fix --unsafe-fixes
Expand Down
91 changes: 91 additions & 0 deletions gptqmodel/quantization/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@

import torch


try:
from torchao.prototype.mx_formats.kernels import f4_unpacked_to_f32, unpack_uint4
except Exception:
unpack_uint4 = None
f4_unpacked_to_f32 = None

__all__ = [
"device_supports_native_fp8",
"dequantize_f8_e4m3",
"dequantize_f4_e2m1",
]


Expand Down Expand Up @@ -162,3 +170,86 @@ def _expand_scale(scale_tensor: torch.Tensor, *, axis_hint: Optional[int]) -> to
result = result / scale_tensor

return result


def dequantize_f4_e2m1(
tensor: torch.Tensor,
*,
scale: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
axis: Optional[int] = 0,
target_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
"""Dequantize FP4 (E2M1) values packed as two nibbles per byte."""

if unpack_uint4 is None or f4_unpacked_to_f32 is None:
raise RuntimeError("torchao with nvfp4 support is required for FP4 dequantization")

if scale is not None and scale_inv is not None:
raise ValueError("Provide either scale or scale_inv, not both")

if tensor.dtype is not torch.uint8:
raise ValueError("FP4 packed tensors must use torch.uint8 storage")

orig_shape = list(tensor.shape)
if not orig_shape:
raise ValueError("Tensor must have at least one dimension")

unpacked = unpack_uint4(tensor.reshape(-1))
expanded_shape = orig_shape[:-1] + [orig_shape[-1] * 2]
unpacked = unpacked.view(*expanded_shape)

result = f4_unpacked_to_f32(unpacked).to(target_dtype)

def _expand_scale_fp4(scale_tensor: torch.Tensor, *, axis_hint: Optional[int]) -> torch.Tensor:
if scale_tensor.ndim == 0:
return scale_tensor

target_shape = result.shape

if scale_tensor.shape == target_shape:
return scale_tensor

if scale_tensor.ndim == 2 and len(target_shape) == 2:
blocks_r, blocks_c = scale_tensor.shape
rows, cols = target_shape
if rows % blocks_r == 0 and cols % blocks_c == 0:
repeat_r = rows // blocks_r
repeat_c = cols // blocks_c
expanded = scale_tensor.repeat_interleave(repeat_r, dim=0)
expanded = expanded.repeat_interleave(repeat_c, dim=1)
return expanded

if scale_tensor.ndim == result.ndim:
expanded = scale_tensor
for dim, (target_size, current_size) in enumerate(zip(result.shape, expanded.shape)):
if target_size == current_size:
continue
if current_size == 1:
expanded = expanded.expand(*[
target_size if i == dim else expanded.shape[i]
for i in range(expanded.ndim)
])
continue
if target_size % current_size != 0:
raise ValueError(
f"Cannot broadcast scale dimension {current_size} to target {target_size}"
)
repeat = target_size // current_size
expanded = expanded.repeat_interleave(repeat, dim=dim)
return expanded

reshaped = _reshape_for_axis(scale_tensor, axis_hint, result.ndim)
return reshaped.expand(result.shape)

if scale is not None:
scale_tensor = _expand_scale_fp4(scale.to(result.dtype), axis_hint=axis)
result = result * scale_tensor
elif scale_inv is not None:
scale_tensor = _expand_scale_fp4(scale_inv.to(result.dtype), axis_hint=axis)
if torch.max(torch.abs(scale_tensor)) <= 1:
result = result * scale_tensor
else:
result = result / scale_tensor

return result
Loading