1 change: 1 addition & 0 deletions gptqmodel/looper/awq_processor.py
@@ -76,6 +76,7 @@ def __init__(
require_fwd=require_fwd,
fwd_after_process=True,
subset_forward_early_stop=True,
enable_activation_capture_flag=True,
)

self.calculate_w_wq_diff = calculate_w_wq_diff
3 changes: 3 additions & 0 deletions gptqmodel/looper/loop_processor.py
@@ -63,6 +63,7 @@ def __init__(
fwd_after_process: bool = True,
fwd_all_modules_in_single_pass: bool = False,
subset_forward_early_stop: bool = False,
enable_activation_capture_flag: bool = False,
):
# process level lock
self.lock = threading.Lock()
@@ -91,6 +92,8 @@ def __init__(
self.fwd_all_modules_in_single_pass = fwd_all_modules_in_single_pass # default False
# when True, stop the layer forward immediately after the final module in a subset fires
self.subset_forward_early_stop = subset_forward_early_stop
# enable capture-only hooks (e.g. ':?') for processors that require activations
self.enable_activation_capture = enable_activation_capture_flag

self.inputs_cache: InputCache = InputCache(None, None, None, None)
self.tasks = {}
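For orientation, a minimal sketch of how the new flag is meant to flow (the class names below are simplified stand-ins for illustration, not the real constructors): a processor that needs raw activations, such as AWQ, passes enable_activation_capture_flag=True to the base LoopProcessor, and the looper later ORs the stored attribute across all processors.

# Hedged sketch only; the real signatures live in loop_processor.py / awq_processor.py.
class LoopProcessorSketch:
    def __init__(self, enable_activation_capture_flag: bool = False):
        # capture-only modules (':?') are included only when this is True
        self.enable_activation_capture = enable_activation_capture_flag

class AWQProcessorSketch(LoopProcessorSketch):
    def __init__(self):
        # AWQ needs input activations to search per-channel scales
        super().__init__(enable_activation_capture_flag=True)

processors = [LoopProcessorSketch(), AWQProcessorSketch()]
requires_activation_capture = any(
    getattr(proc, "enable_activation_capture", False) for proc in processors
)
print(requires_activation_capture)  # True: one AWQ-style processor is enough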
4 changes: 4 additions & 0 deletions gptqmodel/looper/module_looper.py
@@ -1129,10 +1129,14 @@ def _loop_impl(self, fail_safe: bool = False, **kwargs):
region_timer.flush()

is_awq_quantize = any(isinstance(proc, AWQProcessor) for proc in self.processors)
requires_activation_capture = any(
getattr(proc, "enable_activation_capture", False) for proc in self.processors
)
layer_modules = self.gptq_model.simple_layer_modules(
model_config=self.gptq_model.model.config,
quantize_config=self.gptq_model.quantize_config,
is_awq_quantize=is_awq_quantize,
include_capture_only=requires_activation_capture,
)

# true-sequential replays the quantized activations after each subset is quantized so they can be used for the next subset's quantization
141 changes: 91 additions & 50 deletions gptqmodel/models/base.py
@@ -127,6 +127,8 @@ def apply_module_tree_override(module_tree, override):


NOT_QUANTIZE_FLAG = ":!"
CAPTURE_ONLY_FLAG = ":?"
NON_QUANTIZE_FLAGS = (NOT_QUANTIZE_FLAG, CAPTURE_ONLY_FLAG)


# Fix cpu memory leak.
@@ -362,7 +364,11 @@ def get_num_experts(cls, model_config):
@classmethod
def filter_not_quantize_module(cls, layer_modules, quantize_config):
layer_modules = [
[name for name in block if NOT_QUANTIZE_FLAG not in name]
[
name
for name in block
if all(flag not in name for flag in NON_QUANTIZE_FLAGS)
]
for block in layer_modules
]
layer_modules = [block for block in layer_modules if block] # drop empty blocks
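# Illustrative trace of the filter above (hypothetical module names):
#   input : [["self_attn.q_proj", "input_layernorm:!"], ["mlp:?"], []]
#   drop names carrying ':!' or ':?' -> [["self_attn.q_proj"], [], []]
#   drop empty blocks                -> [["self_attn.q_proj"]]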
@@ -384,8 +390,8 @@ def filter_not_quantize_module(cls, layer_modules, quantize_config):
# List them in the order they are executed in the model's forward() code
# Many models share the same execution order: attention (q_k_v) projections, attention (output) projection, mlp (n) projections
@classmethod
def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bool = False):
layer_modules = cls.build_layer_modules(cls.module_tree)
def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bool = False, include_capture_only: bool = False):
layer_modules = cls.build_layer_modules(cls.module_tree, include_capture_only=include_capture_only)

layer_modules = cls.build_moe_modules_if_need(model_config, layer_modules, is_awq_quantize)

@@ -395,8 +401,8 @@ def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bo
return layer_modules

@classmethod
def full_layer_modules(cls, model_config=None, is_awq_quantize: bool = False):
full = cls.build_layer_modules(cls.module_tree)
def full_layer_modules(cls, model_config=None, is_awq_quantize: bool = False, include_capture_only: bool = False):
full = cls.build_layer_modules(cls.module_tree, include_capture_only=include_capture_only)
full = cls.build_moe_modules_if_need(model_config, full, is_awq_quantize)
# print(f"full layer_modules: {full}")
return full
@@ -960,19 +966,19 @@ def awq_get_modules_for_scaling(self, module, input_feat, module_kwargs):
if self.model.config is not None and self.dynamic_expert_index is not None:
num_experts = self.get_num_experts(self.model.config)

def strip_not_quantize_flag(module_name):
if NOT_QUANTIZE_FLAG in module_name:
return module_name[:module_name.find(NOT_QUANTIZE_FLAG)]
else:
return module_name
def strip_non_quantize_flags(module_name):
for flag in NON_QUANTIZE_FLAGS:
if flag in module_name:
module_name = module_name.replace(flag, "")
return module_name

def _select_feature_name(names):
"""Return the first quantized child that has captured activations."""
for raw in names:
stripped = strip_not_quantize_flag(raw)
stripped = strip_non_quantize_flags(raw)
if stripped in input_feat:
return stripped
return strip_not_quantize_flag(names[0]) if names else None
return strip_non_quantize_flags(names[0]) if names else None
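# Hedged illustration with hypothetical names:
#   strip_non_quantize_flags("mlp:?")             -> "mlp"
#   strip_non_quantize_flags("input_layernorm:!") -> "input_layernorm"
# _select_feature_name then returns the first stripped name that actually appears in input_feat.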

def _try_update_last_module(candidate_name: str) -> bool:
nonlocal last_module, last_module_name, last_module_root
@@ -992,18 +998,22 @@ def _try_update_last_module(candidate_name: str) -> bool:
last_module_root = candidate_name.split(".", 1)[0]
return True

full_layer_modules = self.full_layer_modules(self.model.config, is_awq_quantize=True)
full_layer_modules = self.full_layer_modules(
self.model.config,
is_awq_quantize=True,
include_capture_only=True,
)
for i, block in enumerate(full_layer_modules):
not_quantized = all(NOT_QUANTIZE_FLAG in name for name in block)
not_quantized = all(any(flag in name for flag in NON_QUANTIZE_FLAGS) for name in block)
if not_quantized:
# If both the current block and the previous one are marked as not quantized,
# skip remembering the current block. This ensures that when two consecutive
# blocks are not quantized, only the first one is remembered as last_module.
if i > 0 and all(NOT_QUANTIZE_FLAG in name for name in full_layer_modules[i - 1]):
if i > 0 and all(any(flag in name for flag in NON_QUANTIZE_FLAGS) for name in full_layer_modules[i - 1]):
continue

# Remember the latest norm (use the last entry if multiple are present)
candidate_name = strip_not_quantize_flag(block[-1])
candidate_name = strip_non_quantize_flags(block[-1])
_try_update_last_module(candidate_name)
continue

@@ -1034,7 +1044,7 @@ def _try_update_last_module(candidate_name: str) -> bool:
subset = [] # preserve execution order while collecting quantizable modules
skip = False
for name in block:
if NOT_QUANTIZE_FLAG not in name:
if all(flag not in name for flag in NON_QUANTIZE_FLAGS):
m, _ = get_module_by_name_prefix(module, name)
# If the model uses GQA (Grouped Query Attention), the attention output projection will be skipped.
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
@@ -1060,7 +1070,7 @@ def _try_update_last_module(candidate_name: str) -> bool:
continue

# Match the activation bucket to the first quantized child in this block
feature_name = _select_feature_name(block) or strip_not_quantize_flag(block[0])
feature_name = _select_feature_name(block) or strip_non_quantize_flags(block[0])
root_split = feature_name.split(".")
module2inspect = None
if len(root_split) >= 2:
@@ -1091,7 +1101,7 @@ def _try_update_last_module(candidate_name: str) -> bool:
nodes.append(n)

# Update tracker to the LAST item of this block
candidate_name = strip_not_quantize_flag(block[-1])
candidate_name = strip_non_quantize_flags(block[-1])
_try_update_last_module(candidate_name)

import torch
@@ -1305,12 +1315,13 @@ def shell_module_materialize(
# else:
# log.info(f"{self.__class__.__name__}: `MODEL switching to eval mode.")
@classmethod
def build_layer_modules(cls, tree):
def build_layer_modules(cls, tree, include_capture_only: bool = False):
"""
tree format:
[<model_name>, <submodule>, "#", { parent_module: ( "child[:!|:?][:grp]", ... ), ... }]
Rules:
- ':!' means participates in inference but is NOT quantized; keep this marker in output.
- ':?' marks capture-only nodes; activations are recorded but the module is not quantized.
- ':<digit>' means grouping; children with the same group id are emitted in the same block.
- Markers can be combined, e.g. 'module_name:!:2'.
- Supports nested dict structures for MoE models with experts.
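Example (illustrative, simplified):
{ "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"),
"post_attention_layernorm": ("post_attention_layernorm:!",) }
yields blocks like
[["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"], ["self_attn.o_proj"], ["post_attention_layernorm:!"]];
a ':?' parent such as "mlp:?" is emitted (flag kept) only when include_capture_only=True.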
@@ -1329,57 +1340,81 @@ def build_layer_modules(cls, tree):

out_blocks = []

def process_entries(parent, entries, parent_group_offset=0):
def _parse_token(token: str) -> tuple[str, List[str]]:
parts = token.split(":")
name = parts[0]
flags = [p for p in parts[1:] if p]
return name, flags

def _group_from_flags(flags: List[str]) -> int:
for flag in flags:
if flag.isdigit():
return int(flag)
return 0

def process_entries(parent_token: str, entries, parent_group_offset: int = 0):
"""Process entries recursively to handle nested dict structures for MoE"""
groups = defaultdict(list)
groups: defaultdict[int, List[tuple[str, bool, bool]]] = defaultdict(list)

parent_name, parent_flags = _parse_token(parent_token)
parent_group = parent_group_offset + _group_from_flags(parent_flags)
parent_has_bang = "!" in parent_flags
parent_capture_only = "?" in parent_flags

child_group_offset = parent_group_offset
add_parent = parent_has_bang or (parent_capture_only and include_capture_only)
if add_parent:
groups[parent_group].append((parent_name, parent_has_bang, parent_capture_only))
child_group_offset = max(child_group_offset, parent_group + 1)

# Handle tuple/list of strings (traditional format)
if isinstance(entries, (tuple, list)):
for ent in entries:
parts = ent.split(':')
child = parts[0]
child_name, child_flags = _parse_token(ent)

flags = parts[1:]
has_bang = ('!' in flags)
has_bang = "!" in child_flags
capture_only = "?" in child_flags
# first numeric tag is the group id; default 0
grp = next((int(p) for p in flags if p.isdigit()), 0)
grp = child_group_offset + _group_from_flags(child_flags)
# Apply parent group offset to avoid conflicts between different nesting levels
grp += parent_group_offset

# Store the full path including parent for later use
# Special case: if parent ends with the same name as child, don't duplicate
if parent.endswith(f".{child}"):
full_path = parent
if parent_name.endswith(f".{child_name}") or parent_name == child_name:
full_path = parent_name
elif parent_name:
full_path = f"{parent_name}.{child_name}"
else:
full_path = f"{parent}.{child}" if parent != child else child
groups[grp].append((full_path, has_bang))
full_path = child_name

if capture_only and not include_capture_only:
continue
groups[grp].append((full_path, has_bang, capture_only))

# Handle nested dict structure (MoE format)
elif isinstance(entries, dict):
# Calculate max group number used at current level to avoid conflicts
max_current_group = 0
for sub_parent, sub_entries in entries.items():
if isinstance(sub_entries, (tuple, list)):
for ent in sub_entries:
parts = ent.split(':')
flags = parts[1:]
grp = next((int(p) for p in flags if p.isdigit()), 0)
max_current_group = max(max_current_group, grp)
_, ent_flags = _parse_token(ent)
max_current_group = max(max_current_group, _group_from_flags(ent_flags))

# Process nested entries with appropriate group offset
current_offset = parent_group_offset
current_offset = child_group_offset
for sub_parent, sub_entries in entries.items():
if sub_parent == "#":
# Special case: "#" means expert index placeholder
# Create a template path that will be expanded later by simple_layer_modules
template_parent = f"{parent}.{EXPERT_INDEX_PLACEHOLDER}"
template_parent = (
f"{parent_name}.{EXPERT_INDEX_PLACEHOLDER}"
if parent_name else EXPERT_INDEX_PLACEHOLDER
)
# Use a higher offset for expert modules to avoid conflicts with parent level
expert_offset = current_offset + max_current_group + 100 # Large offset to avoid conflicts

# Handle special case where sub_entries is ("#",) or "#" - this means use the parent path directly
if (isinstance(sub_entries, (tuple, list)) and len(sub_entries) == 1 and sub_entries[0] == "#") or sub_entries == "#":
# For ("#",) or "#" format, use the template_parent directly with default group 0
groups[expert_offset].append((template_parent, False))
groups[expert_offset].append((template_parent, False, False))
else:
sub_groups = process_entries(template_parent, sub_entries, expert_offset)
for grp, items in sub_groups.items():
@@ -1388,9 +1423,12 @@ def process_entries(parent, entries, parent_group_offset=0):
# Nested structure: process recursively with full path
# Special case: empty string key means use parent path directly
if sub_parent == "":
full_sub_parent = parent
full_sub_parent = parent_name
else:
full_sub_parent = f"{parent}.{sub_parent}"
full_sub_parent = (
f"{parent_name}.{sub_parent}"
if parent_name else sub_parent
)
sub_groups = process_entries(full_sub_parent, sub_entries, current_offset)
for grp, items in sub_groups.items():
groups[grp].extend(items)
@@ -1412,11 +1450,14 @@ def process_entries(parent, entries, parent_group_offset=0):
# continue

block = []
for full_path, has_bang in items:
for full_path, has_bang, capture_only in items:
# The full path is already constructed in process_entries
name = full_path
if has_bang:
full_path += NOT_QUANTIZE_FLAG
block.append(full_path)
name += NOT_QUANTIZE_FLAG
if capture_only and include_capture_only:
name += CAPTURE_ONLY_FLAG
block.append(name)

out_blocks.append(block)

@@ -1458,8 +1499,8 @@ def get_base_modules(cls, model):
def generate_layers_modules_tree_simple(self, node):
"""
Recursively walk a nested list/dict structure and:
1. Drop dict entries where *all* values are ':!' flagged.
2. Remove ':!' and ':<digit>' markers from strings.
1. Drop dict entries where *all* values are ':!' or ':?' flagged.
2. Remove ':!' / ':?' and ':<digit>' markers from strings.
"""

# If it's a list, recurse into each element
@@ -1473,7 +1514,7 @@ def generate_layers_modules_tree_simple(self, node):
# Expand tuple-of-strings blocks (special handling)
if isinstance(v, (tuple, list)) and all(isinstance(x, str) for x in v):
# Rule 1: check if ALL entries are :!
if all(any(p == "!" for p in x.split(":")[1:]) for x in v):
if all(any(p in {"!", "?"} for p in x.split(":")[1:]) for x in v):
continue # skip this parent entirely

# Rule 2: strip :! and :digit markers
2 changes: 1 addition & 1 deletion gptqmodel/models/definitions/qwen3_moe.py
@@ -27,7 +27,7 @@ class Qwen3MoeQModel(BaseQModel):
"input_layernorm": ("input_layernorm:!",),
"self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"),
"post_attention_layernorm": ("post_attention_layernorm:!",),
"mlp": {
"mlp:?": {
"gate": ("gate:!",), # <-- 0.5MB per layer. Not worth quantizing
"experts": {
"#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
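The practical effect of the new marker here, sketched under stated assumptions: "mlp" now carries ':?', so when an AWQ processor requests activation capture the MoE parent's inputs are recorded for scale search, while the tiny gate stays ':!'-excluded from quantization. A minimal, hypothetical check of the flag parsing (token names taken from the tree above):

token = "mlp:?"
name, flags = token.split(":")[0], token.split(":")[1:]
assert name == "mlp" and flags == ["?"]  # capture-only: record activations, do not quantize
# compare: "gate:!" runs in forward but is never quantized; "q_proj:0" is quantized in group 0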