415 changes: 376 additions & 39 deletions gptqmodel/looper/module_looper.py

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion gptqmodel/models/base.py
@@ -43,7 +43,7 @@
from ..nn_modules.qlinear.lookahead import configure_default_lookahead
from ..nn_modules.qlinear.torch import TorchQuantLinear
from ..quantization import QuantizeConfig
from ..quantization.config import FORMAT, METHOD, QUANTIZE_BLACK_LIST, dynamic_get
from ..quantization.config import FORMAT, METHOD, QUANTIZE_BLACK_LIST, VRAMStrategy, dynamic_get
from ..quantization.rotation.rotation import fuse_layer_norms, rotate_model
from ..utils.backend import BACKEND
from ..utils.data import collate_data
@@ -180,6 +180,9 @@ class BaseQModel(nn.Module):
# monkey patch api for trust_remote_code=True models that have broken transformer compat
require_monkeypatch = False

# VRAM allocation strategies this model supports
supported_vram_strategies: List[VRAMStrategy] = [VRAMStrategy.EXCLUSIVE, VRAMStrategy.BALANCED]

# some models have broken attention mask codes so we need to only use batch 1 with no masks
support_batch_quantize = True

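For context, a minimal sketch of how a model definition might narrow the strategies it accepts. Only supported_vram_strategies and the VRAMStrategy values are taken from this diff; the class name below is hypothetical, and the BaseQModel default already allows both values.

# Hypothetical model definition; illustrates the new class attribute only.
from typing import List

from gptqmodel.models.base import BaseQModel
from gptqmodel.quantization.config import VRAMStrategy


class ExampleDenseQModel(BaseQModel):
    # Advertise only the exclusive strategy (the QuantizeConfig default in this PR).
    supported_vram_strategies: List[VRAMStrategy] = [VRAMStrategy.EXCLUSIVE]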
1 change: 1 addition & 0 deletions gptqmodel/models/definitions/glm4_moe.py
@@ -4,6 +4,7 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from ..base import BaseQModel
from ...quantization.config import VRAMStrategy


class GLM4MoEGPTQ(BaseQModel):
3 changes: 3 additions & 0 deletions gptqmodel/models/definitions/qwen3_moe.py
@@ -4,10 +4,13 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from ...quantization import METHOD
from ...quantization.config import VRAMStrategy
from ..base import BaseQModel


class Qwen3MoeQModel(BaseQModel):
require_monkeypatch = False

# allow dynamic expert index for layer_modules so we don't need to write out 64 layers here
# config.num_experts contains the actual expert count used for index
dynamic_expert_index = "num_experts"
14 changes: 13 additions & 1 deletion gptqmodel/models/definitions/qwen3_next.py
@@ -3,6 +3,8 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from ...quantization import METHOD
from ...quantization.config import VRAMStrategy
from ..base import BaseQModel


@@ -37,7 +39,7 @@ class Qwen3NextGPTQ(BaseQModel):
# MLP / MoE
"mlp": {
# MoE router + shared expert (Qwen3NextSparseMoeBlock)
"gate": ("gate",), # router gate linear
"gate": ("gate:!",), # router gate linear
"shared_expert_gate": ("shared_expert_gate:!",), # <-- single (1, N) logic projections should not be quantized
"shared_expert": ("gate_proj", "up_proj", "down_proj"),

@@ -48,3 +50,13 @@
},
},
]

module_tree_overrides = {
METHOD.AWQ: [
{
"mlp": {
"gate": ("gate",),
}
}
]
}
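
As the shared_expert_gate comment above indicates, the ":!" suffix marks a module-tree entry that should not be quantized, and module_tree_overrides lets a quantization method such as AWQ substitute its own entry (here, the AWQ override appears to keep the router gate quantizable). The helper below is an illustrative toy, not gptqmodel's actual module-tree parser, showing how such entries could be interpreted.

# Illustrative toy only; the real handling lives in gptqmodel's module-tree code.
from typing import Tuple


def split_quant_flag(entry: str) -> Tuple[str, bool]:
    """Return (module_name, should_quantize) for a module-tree entry string."""
    if entry.endswith(":!"):
        return entry[:-2], False   # ":!" marks the module as not-to-be-quantized
    return entry, True


assert split_quant_flag("gate:!") == ("gate", False)                              # GPTQ tree in this diff
assert split_quant_flag("gate") == ("gate", True)                                 # AWQ override above
assert split_quant_flag("shared_expert_gate:!") == ("shared_expert_gate", False)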
36 changes: 31 additions & 5 deletions gptqmodel/nn_modules/hooked_linear.py
@@ -40,13 +40,19 @@ def from_conv1d(m: transformers.Conv1D):

@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
original_device = input.device
target_device = self.weight.data.device
if original_device != target_device:
input = input.to(device=target_device)
output = super().forward(input)

if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
raise STOP_FORWARD_EXCEPTION.with_traceback(None)

if output.device != original_device:
output = output.to(device=original_device)
return output

class HookedConv1d(torch.nn.Conv1d):
@@ -98,12 +104,17 @@ def from_conv1d(m: torch.nn.Conv1d):

@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
original_device = input.device
target_device = self.weight.data.device
if original_device != target_device:
input = input.to(device=target_device)
output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
raise STOP_FORWARD_EXCEPTION.with_traceback(None)
if output.device != original_device:
output = output.to(device=original_device)
return output

# Models using conv2d: ovis
@@ -156,12 +167,17 @@ def from_conv2d(m: torch.nn.Conv2d):

@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
original_device = input.device
target_device = self.weight.data.device
if original_device != target_device:
input = input.to(device=target_device)
output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
raise STOP_FORWARD_EXCEPTION.with_traceback(None)
if output.device != original_device:
output = output.to(device=original_device)
return output

# Models using transformers.conv1d: gpt2
@@ -182,12 +198,17 @@ def from_conv1d(conv1d: transformers.Conv1D):

@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
original_device = input.device
target_device = self.weight.data.device
if original_device != target_device:
input = input.to(device=target_device)
output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
raise STOP_FORWARD_EXCEPTION.with_traceback(None)
if output.device != original_device:
output = output.to(device=original_device)
return output

class HookedLinear(torch.nn.Linear):
@@ -209,12 +230,17 @@ def from_linear(linear: torch.nn.Linear):

@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
original_device = input.device
target_device = self.weight.data.device
if original_device != target_device:
input = input.to(device=target_device)
output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
raise STOP_FORWARD_EXCEPTION.with_traceback(None)
if output.device != original_device:
output = output.to(device=original_device)
return output


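All five hooked wrappers in this file now follow the same device round-trip: move the input only when it is not already on the weight's device, then return the output on the caller's original device. A standalone sketch of that pattern using plain torch.nn.Linear (no gptqmodel hook machinery), for readers skimming the repeated hunks:

import torch

# Minimal reproduction of the device round-trip added in this PR: the wrapper's
# weights may live on a different device than the activations streaming through it.
class DevicePreservingLinear(torch.nn.Linear):
    @torch.inference_mode()
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        original_device = input.device
        target_device = self.weight.data.device
        if original_device != target_device:      # skip a no-op .to() when devices already match
            input = input.to(device=target_device)
        output = super().forward(input)
        if output.device != original_device:      # hand the result back on the caller's device
            output = output.to(device=original_device)
        return output


# Usage: CPU input through a module whose weights sit on CUDA (when available).
layer = DevicePreservingLinear(8, 4)
if torch.cuda.is_available():
    layer = layer.to("cuda")
x = torch.randn(2, 8)              # stays on CPU
y = layer(x)
assert y.device == x.device        # output comes back where the input started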
20 changes: 20 additions & 0 deletions gptqmodel/quantization/config.py
@@ -78,6 +78,11 @@ class METHOD(str, Enum):
AWQ = "awq"


class VRAMStrategy(str, Enum):
EXCLUSIVE = "exclusive"
BALANCED = "balanced"


QUANT_METHOD_FORMAT_MAPPING = {
METHOD.GPTQ: {
FORMAT.GPTQ,
@@ -234,6 +239,9 @@ class QuantizeConfig():
hessian_chunk_bytes: Optional[int] = field(default=None, metadata={"help": "Memory budget (in bytes) for Hessian chunk staging"})
hessian_use_bfloat16_staging: bool = field(default=False, metadata={"help": "Stage Hessian chunks in bfloat16 when supported"})

# VRAM allocation strategy for MoE-heavy subsets
vram_strategy: VRAMStrategy = field(default=VRAMStrategy.EXCLUSIVE)

def __post_init__(self):
fields_info = fields(self)

@@ -368,6 +376,18 @@ def __post_init__(self):
self.offload_to_disk_path = f"./gptqmodel_offload/{path_key}/"
log.info(f"QuantizeConfig: offload_to_disk_path auto set to `{self.offload_to_disk_path}`")

if isinstance(self.vram_strategy, str):
try:
self.vram_strategy = VRAMStrategy(self.vram_strategy.lower())
except ValueError as exc:
raise ValueError(
f"QuantizeConfig: `vram_strategy` must be one of {[v.value for v in VRAMStrategy]}."
) from exc
elif not isinstance(self.vram_strategy, VRAMStrategy):
raise ValueError(
f"QuantizeConfig: `vram_strategy` must be one of {[v.value for v in VRAMStrategy]}."
)

def extension_set(self, key: str, value: Any):
if self.adapter is None:
self.adapter = {}
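A short usage sketch for the new field. vram_strategy is the only argument introduced by this diff; bits and group_size are assumed from the existing QuantizeConfig API. Strings are normalized to the enum in __post_init__, so both spellings below end up as VRAMStrategy.BALANCED:

# Sketch: both forms yield VRAMStrategy.BALANCED after __post_init__ normalization.
from gptqmodel.quantization.config import QuantizeConfig, VRAMStrategy

cfg_from_enum = QuantizeConfig(bits=4, group_size=128, vram_strategy=VRAMStrategy.BALANCED)
cfg_from_str = QuantizeConfig(bits=4, group_size=128, vram_strategy="balanced")  # lower-cased before lookup

assert cfg_from_enum.vram_strategy is VRAMStrategy.BALANCED
assert cfg_from_str.vram_strategy is VRAMStrategy.BALANCED

# An unknown value raises ValueError listing the accepted strategies.
try:
    QuantizeConfig(bits=4, group_size=128, vram_strategy="spread")
except ValueError as err:
    print(err)  # "... `vram_strategy` must be one of ['exclusive', 'balanced']."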
2 changes: 1 addition & 1 deletion gptqmodel/version.py
@@ -3,4 +3,4 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

__version__ = "5.0.0"
__version__ = "5.1.0-dev"
4 changes: 3 additions & 1 deletion tests/models/model_test.py
@@ -61,7 +61,7 @@ def is_flash_attn_2_available(): # type: ignore
from gptqmodel.models.base import BaseQModel # noqa: E402
from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402
from gptqmodel.quantization import FORMAT, METHOD # noqa: E402
from gptqmodel.quantization.config import QuantizeConfig # noqa: E402
from gptqmodel.quantization.config import QuantizeConfig, VRAMStrategy # noqa: E402
from gptqmodel.utils.eval import EVAL # noqa: E402
from gptqmodel.utils.model import MODALITY # noqa: E402
from gptqmodel.utils.torch import torch_empty_cache # noqa: E402
@@ -79,6 +79,7 @@ def is_flash_attn_2_available(): # type: ignore
class ModelTest(unittest.TestCase):
DEBUG = True # enable extra debug output

VRAM_STRATEGY = VRAMStrategy.EXCLUSIVE
TRUST_REMOTE_CODE = False
APPLY_CHAT_TEMPLATE = False
TORCH_DTYPE = "auto"
@@ -717,6 +718,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
v2=self.V2,
adapter=self.EORA,
pack_impl="cpu",
vram_strategy=self.VRAM_STRATEGY,
)

log.info(f"Quant config: {quantize_config}")
4 changes: 3 additions & 1 deletion tests/models/test_qwen3_moe.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from gptqmodel.quantization.config import VRAMStrategy
from model_test import ModelTest

from gptqmodel.utils.eval import EVAL
@@ -20,6 +20,8 @@ class TestQwen3Moe(ModelTest):
"acc_norm": {"value": 0.5486, "floor_pct": 0.04},
},
}

VRAM_STRATEGY = VRAMStrategy.BALANCED
# TRUST_REMOTE_CODE = False
# APPLY_CHAT_TEMPLATE = True
# EVAL_BATCH_SIZE = 6
4 changes: 4 additions & 0 deletions tests/models/test_qwen3_next.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from gptqmodel.quantization.config import VRAMStrategy
from model_test import ModelTest

from gptqmodel.utils.eval import EVAL
@@ -24,6 +25,9 @@ class TestQwen3Next(ModelTest):
"acc": {"value": 0.8403, "floor_pct": 0.04},
},
}

VRAM_STRATEGY = VRAMStrategy.BALANCED
# DATASET_SIZE = 2048
# TRUST_REMOTE_CODE = True
# APPLY_CHAT_TEMPLATE = True
# EVAL_BATCH_SIZE = 4