Cleanup #2185 (Merged)

gptqmodel/models/base.py (3 changes: 1 addition & 2 deletions)
@@ -975,9 +975,8 @@ def awq_get_modules_for_scaling(self, module, input_feat, module_kwargs):
last_module_name = None
last_module_root = None # self_attn.* has root == self_attn, mlp.* has root == mlp

-num_experts = None
if self.model.config is not None and self.dynamic_expert_index is not None:
-num_experts = self.get_num_experts(self.model.config)
+self.get_num_experts(self.model.config)

def strip_non_quantize_flags(module_name):
for flag in NON_QUANTIZE_FLAGS:
gptqmodel/models/definitions/gpt_oss.py (2 changes: 1 addition & 1 deletion)
@@ -152,4 +152,4 @@ def before_model_load(self, load_quantized_model=False):
import transformers.models.gpt_oss.modeling_gpt_oss as gpt_oss_modeling

gpt_oss_modeling.GptOssExperts = GptOssExpertsNew
-gpt_oss_modeling.GptOssTopKRouter = GptOssTopKRouterNew
+gpt_oss_modeling.GptOssTopKRouter = GptOssTopKRouterNew
gptqmodel/models/definitions/llama4.py (1 change: 0 additions & 1 deletion)
@@ -81,4 +81,3 @@ def forward(self, hidden_states: torch.Tensor):
return out, router_logits

llama4_modeling.Llama4TextMoe = SequentialLlama4TextMoe

gptqmodel/nn_modules/converter.py (5 changes: 3 additions & 2 deletions)
@@ -8,6 +8,7 @@ def convert_gpt_oss_expert_converter(module, config):
import torch.nn as nn
import transformers.models.gpt_oss.modeling_gpt_oss as gpt_oss_modeling
from transformers.integrations.hub_kernels import use_kernel_forward_from_hub

from ..models.definitions.gpt_oss import GptOssExpertsNew

@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
@@ -27,7 +28,7 @@ def forward(self, hidden_states):
for name, sub_module in module.named_modules():
if isinstance(sub_module, gpt_oss_modeling.GptOssMLP):
new_module = GptOssMLPNew(config=config, ori_mlp=sub_module)
-setattr(module, name, new_module)
+setattr(module, name, new_module)

return module

@@ -96,7 +97,7 @@ def forward(self, hidden_states: torch.Tensor):
for name, sub_module in module.named_modules():
if isinstance(sub_module, Llama4TextMoe):
new_module = SequentialLlama4TextMoe(config=config.get_text_config(), original=sub_module)
-setattr(module, name, new_module)
+setattr(module, name, new_module)

return module

gptqmodel/quantization/gptq.py (2 changes: 1 addition & 1 deletion)
@@ -577,7 +577,7 @@ def materialize_global_hessian(self, target_device: Optional[torch.device] = None
except:
log.warn(f"Quantization: Module `{self.name}` -> Retry partial.to 1/2 in 0.25s")
time.sleep(0.25)
-try:
+try:
result_accum.add_(partial.to(device=result_accum.device, dtype=torch.float32))
except:
log.warn(f"Quantization: Module `{self.name}` -> Retry partial.to 2/2 in 0.75s")
gptqmodel/utils/model_dequant.py (244 changes: 243 additions & 1 deletion)
@@ -12,7 +12,7 @@
import shutil
from collections import defaultdict
from pathlib import Path
-from typing import Dict, Iterable, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, Optional, Tuple

import torch
from safetensors import safe_open
@@ -24,6 +24,10 @@

LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization.quant_scheme import QuantizationScheme


def load_json(path: Path) -> dict:
if not path.exists():
@@ -76,6 +80,137 @@ def normalize_device(device: Optional[str]) -> Optional[str]:
return f"cuda:{dev.index}"


def _get_compressed_tensors_dependencies() -> dict:
try:
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.model_compressors.model_compressor import (
map_module_to_scheme,
)
from compressed_tensors.quantization import QuantizationConfig
from compressed_tensors.quantization.lifecycle.apply import apply_quantization_config
except ImportError as exc: # pragma: no cover - exercised when dependency missing
raise RuntimeError(
"Support for compressed-tensors quantized models requires the "
"'compressed-tensors' package. Install it with 'pip install compressed-tensors'."
) from exc

try:
from accelerate import init_empty_weights
except ImportError as exc: # pragma: no cover - exercised when dependency missing
raise RuntimeError(
"Support for compressed-tensors quantized models requires the "
"'accelerate' package. Install it with 'pip install accelerate'."
) from exc

try:
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
)
except ImportError as exc: # pragma: no cover - exercised when dependency missing
raise RuntimeError(
"Support for compressed-tensors quantized models requires the "
"'transformers' package."
) from exc

return {
"QuantizationConfig": QuantizationConfig,
"apply_quantization_config": apply_quantization_config,
"BaseCompressor": BaseCompressor,
"map_module_to_scheme": map_module_to_scheme,
"init_empty_weights": init_empty_weights,
"AutoConfig": AutoConfig,
"AutoModel": AutoModel,
"AutoModelForCausalLM": AutoModelForCausalLM,
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
}


def _discover_compressed_tensors_module_schemes(
model_path: Path,
quant_config,
*,
deps: dict,
) -> Dict[str, "QuantizationScheme"]:
AutoConfig = deps["AutoConfig"]
init_empty_weights = deps["init_empty_weights"]
apply_quantization_config = deps["apply_quantization_config"]
map_module_to_scheme = deps["map_module_to_scheme"]

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

loader_candidates = (
deps["AutoModelForCausalLM"],
deps["AutoModelForSeq2SeqLM"],
deps["AutoModel"],
)

loader_errors: list[tuple[str, Exception]] = []
model = None
for loader in loader_candidates:
if loader is None:
continue

with init_empty_weights(include_buffers=True):
try:
model = loader.from_config(
config,
trust_remote_code=True,
torch_dtype=torch.float32,
)
except Exception as exc: # pragma: no cover - depends on available loaders
loader_errors.append((loader.__name__, exc))
continue
else:
break

if model is None:
if loader_errors:
names = ", ".join(name for name, _ in loader_errors)
last_error = loader_errors[-1][1]
raise RuntimeError(
f"Failed to instantiate model from '{model_path}' while inspecting "
f"compressed-tensors modules. Loaders attempted: {names}."
) from last_error
raise RuntimeError(
f"Failed to instantiate model from '{model_path}' while inspecting "
"compressed-tensors modules."
)

try:
apply_quantization_config(model, quant_config, run_compressed=False)
module_to_scheme = map_module_to_scheme(model)
finally:
del model

return dict(module_to_scheme)


def _prepare_compressed_tensors_context(
model_path: Path, quant_cfg: dict
) -> tuple["QuantizationConfig", Dict[str, "QuantizationScheme"], "BaseCompressor"]: # noqa
deps = _get_compressed_tensors_dependencies()

QuantizationConfig = deps["QuantizationConfig"]
quant_config = QuantizationConfig.model_validate(quant_cfg)
quant_format = (quant_config.format or "").lower()
if quant_format != "pack-quantized":
raise ValueError(
f"Unsupported compressed-tensors format '{quant_config.format}'. "
"Only 'pack-quantized' is currently supported."
)

module_to_scheme = _discover_compressed_tensors_module_schemes(
model_path, quant_config, deps=deps
)

BaseCompressor = deps["BaseCompressor"]
compressor = BaseCompressor.load_from_registry(quant_format, config=quant_config)
return quant_config, module_to_scheme, compressor


def resolve_block_size(config: dict) -> Optional[Tuple[int, int]]:
quant_cfg = config.get("quantization_config", {}) or {}
block_size = quant_cfg.get("weight_block_size")
@@ -199,6 +334,13 @@ def detect_format(model_path: Path, config: dict) -> str:
if tensor.dtype == torch.uint8 and (key + "_scale") in keys:
LOG.debug("Detected NVFP4 weights via dtype on tensor '%s'", key)
return "nvfp4"
if any(k.endswith(".weight_packed") for k in keys):
LOG.debug(
"Detected compressed-tensors pack-quantized format via '.weight_packed' "
"metadata in shard '%s'",
files[0],
)
return "compressed-pack"
if any(k.endswith(".weight_scale") for k in keys):
LOG.debug("Detected NVFP4 format via '.weight_scale' metadata in shard '%s'", files[0])
return "nvfp4"
@@ -223,6 +365,15 @@ def detect_format(model_path: Path, config: dict) -> str:
if method == "awq":
LOG.debug("Detected AWQ format via quant_method=%s", method)
return "awq"
if method == "compressed-tensors":
fmt_name = (quant_cfg.get("format") or "").lower()
if fmt_name == "pack-quantized":
LOG.debug(
"Detected compressed-tensors format via quant_method=%s and format=%s",
method,
fmt_name,
)
return "compressed-pack"

raise ValueError("Unable to detect quantization format for model")
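
For orientation only (this sketch is not part of the diff): the config-based branch above reads just two strings, quant_method and format. A minimal quantization_config fragment that detect_format maps to "compressed-pack", written as a Python dict; real compressed-tensors configs carry additional fields that this detector does not consult, and the surrounding values are illustrative assumptions:

# Illustrative sketch only: the smallest config fragment the new branch recognizes.
# A real compressed-tensors quantization_config has more fields; detect_format
# only reads quant_method and format here.
config = {
    "quantization_config": {
        "quant_method": "compressed-tensors",
        "format": "pack-quantized",
    }
}
# If the shard-key heuristics earlier in the function did not already match,
# detect_format(model_path, config) now returns "compressed-pack".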

@@ -440,6 +591,72 @@ def convert_gptq_file(path: Path, target_dtype: torch.dtype, config: dict, device
return tensors


def convert_compressed_pack_file(
path: Path,
target_dtype: torch.dtype,
*,
device: str,
module_to_scheme: Dict[str, "QuantizationScheme"],
compressor: "BaseCompressor",
) -> Dict[str, torch.Tensor]:
tensors: Dict[str, torch.Tensor] = {}
module_buffers: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)

with safe_open(path, framework="pt", device=device) as reader:
for key in reader.keys():
tensor = reader.get_tensor(key)
if key.endswith(".weight_packed"):
prefix = key[: -len(".weight_packed")]
module_buffers[prefix]["weight_packed"] = tensor
LOG.debug("Collected compressed weight tensor '%s'", key)
elif key.endswith(".weight_scale"):
prefix = key[: -len(".weight_scale")]
module_buffers[prefix]["weight_scale"] = tensor
LOG.debug("Collected compressed scale tensor '%s'", key)
elif key.endswith(".weight_zero_point"):
prefix = key[: -len(".weight_zero_point")]
module_buffers[prefix]["weight_zero_point"] = tensor
LOG.debug("Collected compressed zero-point tensor '%s'", key)
elif key.endswith(".weight_g_idx"):
prefix = key[: -len(".weight_g_idx")]
module_buffers[prefix]["weight_g_idx"] = tensor
LOG.debug("Collected compressed group-index tensor '%s'", key)
elif key.endswith(".weight_shape"):
prefix = key[: -len(".weight_shape")]
module_buffers[prefix]["weight_shape"] = tensor
LOG.debug("Collected compressed shape tensor '%s'", key)
else:
tensors[key] = finalize_for_save(tensor, target_dtype)

for prefix, buf in module_buffers.items():
scheme = module_to_scheme.get(prefix)
if scheme is None:
raise KeyError(
f"No quantization scheme registered for compressed module '{prefix}'."
)

if scheme.weights is None:
raise ValueError(
f"Module '{prefix}' does not define weight quantization parameters."
)

required_fields = {"weight_packed", "weight_scale", "weight_shape"}
missing = required_fields.difference(buf)
if missing:
raise KeyError(
f"Compressed tensors for module '{prefix}' in shard '{path.name}' "
f"are missing required fields: {sorted(missing)}"
)

weight = compressor.decompress_weight(
compressed_data=buf,
quantization_args=scheme.weights,
)
tensors[prefix + ".weight"] = finalize_for_save(weight, target_dtype)

return tensors
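
As a reading aid (not part of the change), a minimal sketch of how the helpers added in this file fit together for a pack-quantized checkpoint. The directory path and target dtype are hypothetical assumptions; load_json, _prepare_compressed_tensors_context and convert_compressed_pack_file are the functions defined above:

# Hypothetical driver; illustrates intended usage under assumed paths and dtypes.
from pathlib import Path

import torch

model_path = Path("/models/pack-quantized-checkpoint")  # hypothetical location
config = load_json(model_path / "config.json")
quant_cfg = config.get("quantization_config", {}) or {}

# Build the compressed-tensors context once per model
# (validates that format == "pack-quantized" and maps module names to schemes).
_, module_to_scheme, compressor = _prepare_compressed_tensors_context(model_path, quant_cfg)

# Decompress each safetensors shard; non-quantized tensors pass through
# finalize_for_save unchanged inside convert_compressed_pack_file.
for shard in sorted(model_path.glob("*.safetensors")):
    dense_tensors = convert_compressed_pack_file(
        shard,
        torch.bfloat16,  # hypothetical target dtype
        device="cpu",
        module_to_scheme=module_to_scheme,
        compressor=compressor,
    )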


def copy_aux_files(model_path: Path, output_path: Path, skip: Iterable[str]) -> None:
for item in model_path.iterdir():
if item.name in skip:
@@ -488,6 +705,21 @@ def dequantize_model(
else:
LOG.debug("No explicit FP8 block size found; will infer from scale tensors if needed")

compressed_module_to_scheme: Dict[str, "QuantizationScheme"] = {}
compressed_compressor: Optional["BaseCompressor"] = None
if fmt == "compressed-pack":
if not quant_cfg:
raise ValueError(
"compressed-tensors model requires a populated 'quantization_config' entry."
)
_, compressed_module_to_scheme, compressed_compressor = _prepare_compressed_tensors_context(
model_path, quant_cfg
)
LOG.debug(
"Prepared compressed-tensors context with %d modules",
len(compressed_module_to_scheme),
)

log = setup_logger()
LOG.debug(
"Starting dequantization for model '%s' (format=%s, target_dtype=%s, device=%s)",
@@ -516,6 +748,16 @@
tensors = convert_awq_file(path, target_dtype, open_device)
elif fmt == "gptq":
tensors = convert_gptq_file(path, target_dtype, quant_cfg, open_device)
elif fmt == "compressed-pack":
if compressed_compressor is None:
raise RuntimeError("Compressed-tensors compressor was not initialized")
tensors = convert_compressed_pack_file(
path,
target_dtype,
device=open_device,
module_to_scheme=compressed_module_to_scheme,
compressor=compressed_compressor,
)
else:
raise ValueError(f"Unsupported format {fmt}")
