diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py
index ea34b286d..da850dd45 100644
--- a/gptqmodel/looper/awq_processor.py
+++ b/gptqmodel/looper/awq_processor.py
@@ -76,6 +76,7 @@ def __init__(
             require_fwd=require_fwd,
             fwd_after_process=True,
             subset_forward_early_stop=True,
+            enable_activation_capture_flag=True,
         )
 
         self.calculate_w_wq_diff = calculate_w_wq_diff
diff --git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py
index 7c2f536da..6a6c6fd3e 100644
--- a/gptqmodel/looper/loop_processor.py
+++ b/gptqmodel/looper/loop_processor.py
@@ -63,6 +63,7 @@ def __init__(
         fwd_after_process: bool = True,
         fwd_all_modules_in_single_pass: bool = False,
         subset_forward_early_stop: bool = False,
+        enable_activation_capture_flag: bool = False,
     ):
         # process level lock
         self.lock = threading.Lock()
@@ -91,6 +92,8 @@ def __init__(
         self.fwd_all_modules_in_single_pass = fwd_all_modules_in_single_pass  # default False
         # when True, stop the layer forward immediately after the final module in a subset fires
         self.subset_forward_early_stop = subset_forward_early_stop
+        # enable capture-only hooks (e.g. ':?') for processors that require activations
+        self.enable_activation_capture = enable_activation_capture_flag
 
         self.inputs_cache: InputCache = InputCache(None, None, None, None)
         self.tasks = {}
diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index c9129086a..69f061b97 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -1129,10 +1129,14 @@ def _loop_impl(self, fail_safe: bool = False, **kwargs):
         region_timer.flush()
 
         is_awq_quantize = any(isinstance(proc, AWQProcessor) for proc in self.processors)
+        requires_activation_capture = any(
+            getattr(proc, "enable_activation_capture", False) for proc in self.processors
+        )
         layer_modules = self.gptq_model.simple_layer_modules(
             model_config=self.gptq_model.model.config,
             quantize_config=self.gptq_model.quantize_config,
             is_awq_quantize=is_awq_quantize,
+            include_capture_only=requires_activation_capture,
         )
 
         # true-sequential will replay the quantized activations after each subset has been quantized to be used for next subset quantization
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index b02d8f9d6..0fda18350 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -127,6 +127,8 @@ def apply_module_tree_override(module_tree, override):
 
 NOT_QUANTIZE_FLAG = ":!"
+CAPTURE_ONLY_FLAG = ":?"
+NON_QUANTIZE_FLAGS = (NOT_QUANTIZE_FLAG, CAPTURE_ONLY_FLAG)
 
 
 # Fix cpu memory leak.
@@ -362,7 +364,11 @@ def get_num_experts(cls, model_config):
     @classmethod
     def filter_not_quantize_module(cls, layer_modules, quantize_config):
         layer_modules = [
-            [name for name in block if NOT_QUANTIZE_FLAG not in name]
+            [
+                name
+                for name in block
+                if all(flag not in name for flag in NON_QUANTIZE_FLAGS)
+            ]
             for block in layer_modules
         ]
         layer_modules = [block for block in layer_modules if block]  # 去掉空 block
@@ -384,8 +390,8 @@ def filter_not_quantize_module(cls, layer_modules, quantize_config):
     # List them in the order executed in model forward() code
     # Many models have same execution order of: attention (q_k_v) projection, attention (output) projection, mlp (n) projections
     @classmethod
-    def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bool = False):
-        layer_modules = cls.build_layer_modules(cls.module_tree)
+    def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bool = False, include_capture_only: bool = False):
+        layer_modules = cls.build_layer_modules(cls.module_tree, include_capture_only=include_capture_only)
 
         layer_modules = cls.build_moe_modules_if_need(model_config, layer_modules, is_awq_quantize)
@@ -395,8 +401,8 @@ def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bo
         return layer_modules
 
     @classmethod
-    def full_layer_modules(cls, model_config=None, is_awq_quantize: bool = False):
-        full = cls.build_layer_modules(cls.module_tree)
+    def full_layer_modules(cls, model_config=None, is_awq_quantize: bool = False, include_capture_only: bool = False):
+        full = cls.build_layer_modules(cls.module_tree, include_capture_only=include_capture_only)
         full = cls.build_moe_modules_if_need(model_config, full, is_awq_quantize)
         # print(f"full layer_modules: {full}")
         return full
@@ -960,19 +966,19 @@ def awq_get_modules_for_scaling(self, module, input_feat, module_kwargs):
         if self.model.config is not None and self.dynamic_expert_index is not None:
             num_experts = self.get_num_experts(self.model.config)
 
-        def strip_not_quantize_flag(module_name):
-            if NOT_QUANTIZE_FLAG in module_name:
-                return module_name[:module_name.find(NOT_QUANTIZE_FLAG)]
-            else:
-                return module_name
+        def strip_non_quantize_flags(module_name):
+            for flag in NON_QUANTIZE_FLAGS:
+                if flag in module_name:
+                    module_name = module_name.replace(flag, "")
+            return module_name
 
         def _select_feature_name(names):
             """Return the first quantized child that has captured activations."""
             for raw in names:
-                stripped = strip_not_quantize_flag(raw)
+                stripped = strip_non_quantize_flags(raw)
                 if stripped in input_feat:
                     return stripped
-            return strip_not_quantize_flag(names[0]) if names else None
+            return strip_non_quantize_flags(names[0]) if names else None
 
         def _try_update_last_module(candidate_name: str) -> bool:
             nonlocal last_module, last_module_name, last_module_root
@@ -992,18 +998,22 @@ def _try_update_last_module(candidate_name: str) -> bool:
                 last_module_root = candidate_name.split(".", 1)[0]
                 return True
 
-        full_layer_modules = self.full_layer_modules(self.model.config, is_awq_quantize=True)
+        full_layer_modules = self.full_layer_modules(
+            self.model.config,
+            is_awq_quantize=True,
+            include_capture_only=True,
+        )
         for i, block in enumerate(full_layer_modules):
-            not_quantized = all(NOT_QUANTIZE_FLAG in name for name in block)
+            not_quantized = all(any(flag in name for flag in NON_QUANTIZE_FLAGS) for name in block)
             if not_quantized:
                 # If both the current block and the previous one are marked as not quantized,
                 # skip remembering the current block. This ensures that when two consecutive
                 # blocks are not quantized, only the first one is remembered as last_module.
-                if i > 0 and all(NOT_QUANTIZE_FLAG in name for name in full_layer_modules[i - 1]):
+                if i > 0 and all(any(flag in name for flag in NON_QUANTIZE_FLAGS) for name in full_layer_modules[i - 1]):
                     continue
                 # Remember the latest norm (use the last entry if multiple are present)
-                candidate_name = strip_not_quantize_flag(block[-1])
+                candidate_name = strip_non_quantize_flags(block[-1])
                 _try_update_last_module(candidate_name)
                 continue
@@ -1034,7 +1044,7 @@ def _try_update_last_module(candidate_name: str) -> bool:
             subset = []  # preserve execution order while collecting quantizable modules
             skip = False
             for name in block:
-                if NOT_QUANTIZE_FLAG not in name:
+                if all(flag not in name for flag in NON_QUANTIZE_FLAGS):
                     m, _ = get_module_by_name_prefix(module, name)
                     # If the Model uses GQA (Grouped Query Attention), attention out will be skipped.
                     # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
@@ -1060,7 +1070,7 @@ def _try_update_last_module(candidate_name: str) -> bool:
                 continue
 
             # Match the activation bucket to the first quantized child in this block
-            feature_name = _select_feature_name(block) or strip_not_quantize_flag(block[0])
+            feature_name = _select_feature_name(block) or strip_non_quantize_flags(block[0])
             root_split = feature_name.split(".")
             module2inspect = None
             if len(root_split) >= 2:
@@ -1091,7 +1101,7 @@ def _try_update_last_module(candidate_name: str) -> bool:
                 nodes.append(n)
 
             # Update tracker to the LAST item of this block
-            candidate_name = strip_not_quantize_flag(block[-1])
+            candidate_name = strip_non_quantize_flags(block[-1])
             _try_update_last_module(candidate_name)
 
         import torch
@@ -1305,12 +1315,13 @@ def shell_module_materialize(
     #     else:
     #         log.info(f"{self.__class__.__name__}: `MODEL switching to eval mode.")
 
     @classmethod
-    def build_layer_modules(cls, tree):
+    def build_layer_modules(cls, tree, include_capture_only: bool = False):
        """
        tree format: [, , "#", { parent_module: ( "child[:!][:grp]", ... ), ... }]
        Rules:
          - ':!' means participates in inference but is NOT quantized; keep this marker in output.
+         - ':?' marks capture-only nodes; activations are recorded but the module is not quantized.
          - ':' means grouping; children with the same group id are emitted in the same block.
          - Both can appear together, e.g. 'module_name:!:2'.
          - Supports nested dict structures for MoE models with experts.
@@ -1329,57 +1340,81 @@ def build_layer_modules(cls, tree):
         out_blocks = []
 
-        def process_entries(parent, entries, parent_group_offset=0):
+        def _parse_token(token: str) -> tuple[str, List[str]]:
+            parts = token.split(":")
+            name = parts[0]
+            flags = [p for p in parts[1:] if p]
+            return name, flags
+
+        def _group_from_flags(flags: List[str]) -> int:
+            for flag in flags:
+                if flag.isdigit():
+                    return int(flag)
+            return 0
+
+        def process_entries(parent_token: str, entries, parent_group_offset: int = 0):
             """Process entries recursively to handle nested dict structures for MoE"""
-            groups = defaultdict(list)
+            groups: defaultdict[int, List[tuple[str, bool, bool]]] = defaultdict(list)
+
+            parent_name, parent_flags = _parse_token(parent_token)
+            parent_group = parent_group_offset + _group_from_flags(parent_flags)
+            parent_has_bang = "!" in parent_flags
+            parent_capture_only = "?" in parent_flags
+
+            child_group_offset = parent_group_offset
+            add_parent = parent_has_bang or (parent_capture_only and include_capture_only)
+            if add_parent:
+                groups[parent_group].append((parent_name, parent_has_bang, parent_capture_only))
+                child_group_offset = max(child_group_offset, parent_group + 1)
 
             # Handle tuple/list of strings (traditional format)
             if isinstance(entries, (tuple, list)):
                 for ent in entries:
-                    parts = ent.split(':')
-                    child = parts[0]
+                    child_name, child_flags = _parse_token(ent)
 
-                    flags = parts[1:]
-                    has_bang = ('!' in flags)
+                    has_bang = "!" in child_flags
+                    capture_only = "?" in child_flags
                     # first numeric tag is the group id; default 0
-                    grp = next((int(p) for p in flags if p.isdigit()), 0)
+                    grp = child_group_offset + _group_from_flags(child_flags)
                     # Apply parent group offset to avoid conflicts between different nesting levels
-                    grp += parent_group_offset
-                    # Store the full path including parent for later use
                     # Special case: if parent ends with the same name as child, don't duplicate
-                    if parent.endswith(f".{child}"):
-                        full_path = parent
+                    if parent_name.endswith(f".{child_name}") or parent_name == child_name:
+                        full_path = parent_name
+                    elif parent_name:
+                        full_path = f"{parent_name}.{child_name}"
                     else:
-                        full_path = f"{parent}.{child}" if parent != child else child
-                    groups[grp].append((full_path, has_bang))
+                        full_path = child_name
+
+                    if capture_only and not include_capture_only:
+                        continue
+                    groups[grp].append((full_path, has_bang, capture_only))
 
-            # Handle nested dict structure (MoE format)
             elif isinstance(entries, dict):
                 # Calculate max group number used at current level to avoid conflicts
                 max_current_group = 0
                 for sub_parent, sub_entries in entries.items():
                     if isinstance(sub_entries, (tuple, list)):
                         for ent in sub_entries:
-                            parts = ent.split(':')
-                            flags = parts[1:]
-                            grp = next((int(p) for p in flags if p.isdigit()), 0)
-                            max_current_group = max(max_current_group, grp)
+                            _, ent_flags = _parse_token(ent)
+                            max_current_group = max(max_current_group, _group_from_flags(ent_flags))
 
                 # Process nested entries with appropriate group offset
-                current_offset = parent_group_offset
+                current_offset = child_group_offset
                 for sub_parent, sub_entries in entries.items():
                     if sub_parent == "#":
                         # Special case: "#" means expert index placeholder
                         # Create a template path that will be expanded later by simple_layer_modules
-                        template_parent = f"{parent}.{EXPERT_INDEX_PLACEHOLDER}"
+                        template_parent = (
+                            f"{parent_name}.{EXPERT_INDEX_PLACEHOLDER}"
+                            if parent_name else EXPERT_INDEX_PLACEHOLDER
+                        )
                         # Use a higher offset for expert modules to avoid conflicts with parent level
                         expert_offset = current_offset + max_current_group + 100  # Large offset to avoid conflicts
 
                         # Handle special case where sub_entries is ("#",) or "#" - this means use the parent path directly
                         if (isinstance(sub_entries, (tuple, list)) and len(sub_entries) == 1 and sub_entries[0] == "#") or sub_entries == "#":
                             # For ("#",) or "#" format, use the template_parent directly with default group 0
-                            groups[expert_offset].append((template_parent, False))
+                            groups[expert_offset].append((template_parent, False, False))
                         else:
                             sub_groups = process_entries(template_parent, sub_entries, expert_offset)
                             for grp, items in sub_groups.items():
@@ -1388,9 +1423,12 @@ def process_entries(parent, entries, parent_group_offset=0):
                         # Nested structure: process recursively with full path
                         # Special case: empty string key means use parent path directly
                         if sub_parent == "":
-                            full_sub_parent = parent
+                            full_sub_parent = parent_name
                         else:
-                            full_sub_parent = f"{parent}.{sub_parent}"
+                            full_sub_parent = (
+                                f"{parent_name}.{sub_parent}"
+                                if parent_name else sub_parent
+                            )
                         sub_groups = process_entries(full_sub_parent, sub_entries, current_offset)
                         for grp, items in sub_groups.items():
                             groups[grp].extend(items)
@@ -1412,11 +1450,14 @@ def process_entries(parent, entries, parent_group_offset=0):
             #     continue
 
             block = []
-            for full_path, has_bang in items:
+            for full_path, has_bang, capture_only in items:
                 # The full path is already constructed in process_entries
+                name = full_path
                 if has_bang:
-                    full_path += NOT_QUANTIZE_FLAG
-                block.append(full_path)
+                    name += NOT_QUANTIZE_FLAG
+                if capture_only and include_capture_only:
+                    name += CAPTURE_ONLY_FLAG
+                block.append(name)
 
             out_blocks.append(block)
@@ -1458,8 +1499,8 @@ def get_base_modules(cls, model):
     def generate_layers_modules_tree_simple(self, node):
         """
         Recursively walk a nested list/dict structure and:
-        1. Drop dict entries where *all* values are ':!' flagged.
-        2. Remove ':!' and ':' markers from strings.
+        1. Drop dict entries where *all* values are ':!' or ':?' flagged.
+        2. Remove ':!' / ':?' and ':' markers from strings.
         """
 
         # If it's a list, recurse into each element
@@ -1473,7 +1514,7 @@ def generate_layers_modules_tree_simple(self, node):
             # Expand tuple-of-strings blocks (special handling)
             if isinstance(v, (tuple, list)) and all(isinstance(x, str) for x in v):
                 # Rule 1: check if ALL entries are :!
-                if all(any(p == "!" for p in x.split(":")[1:]) for x in v):
+                if all(any(p in {"!", "?"} for p in x.split(":")[1:]) for x in v):
                     continue  # skip this parent entirely
 
                 # Rule 2: strip :! and :digit markers
diff --git a/gptqmodel/models/definitions/qwen3_moe.py b/gptqmodel/models/definitions/qwen3_moe.py
index 6e07fab60..87f6610ff 100644
--- a/gptqmodel/models/definitions/qwen3_moe.py
+++ b/gptqmodel/models/definitions/qwen3_moe.py
@@ -27,7 +27,7 @@ class Qwen3MoeQModel(BaseQModel):
             "input_layernorm": ("input_layernorm:!",),
             "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"),
             "post_attention_layernorm": ("post_attention_layernorm:!",),
-            "mlp": {
+            "mlp:?": {
                 "gate": ("gate:!",),  # <-- 0.5MB per layer. Not worth quantizing
                 "experts": {
                     "#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
diff --git a/tests/test_subset_parsing.py b/tests/test_subset_parsing.py
new file mode 100644
index 000000000..614547e62
--- /dev/null
+++ b/tests/test_subset_parsing.py
@@ -0,0 +1,319 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+import os
+from types import SimpleNamespace
+from typing import Callable, Dict, List, Optional
+
+import torch
+from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM
+from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
+
+from gptqmodel.looper.awq_processor import AWQProcessor
+from gptqmodel.looper.loop_processor import LoopProcessor
+from gptqmodel.looper.module_looper import ModuleLooper
+from gptqmodel.looper.stage_subset import run_subset_stage
+from gptqmodel.looper.named_module import NamedModule
+from gptqmodel.quantization import FORMAT, METHOD
+from gptqmodel.quantization.config import QuantizeConfig, VRAMStrategy
+from gptqmodel.nn_modules.hooked_linear import replace_module_with_hooked_legacy
+from gptqmodel.utils.model import find_modules, get_module_by_name_prefix
+from gptqmodel.models.definitions.qwen3_moe import Qwen3MoeQModel
+
+
+# honour the request to bind the test harness to GPU index 5 when CUDA is available
+os.environ.setdefault("CUDA_VISIBLE_DEVICES", "5")
+
+
+def _prepare_dataset_func(**kwargs):
+    return kwargs["calibration_dataset"]
+
+
+def _make_quant_config(device: torch.device | str = "cpu") -> QuantizeConfig:
+    return QuantizeConfig(
+        bits=4,
+        group_size=128,
+        quant_method=METHOD.AWQ,
+        format=FORMAT.GEMM,
+        device=device,
+        vram_strategy=VRAMStrategy.EXCLUSIVE,
+    )
+
+
+def test_mlp_capture_flag_propagates_to_layer_modules():
+    cfg = Qwen3MoeConfig(
+        hidden_size=16,
+        intermediate_size=32,
+        num_attention_heads=2,
+        num_hidden_layers=1,
+        num_key_value_heads=2,
+        num_experts_per_tok=1,
+        num_experts=2,
+    )
+    model = Qwen3MoeForCausalLM(cfg)
+
+    model_config = cfg
+    quant_cfg = SimpleNamespace(dynamic=None)
+
+    tree = Qwen3MoeQModel.build_layer_modules(
+        Qwen3MoeQModel.module_tree,
+        include_capture_only=True,
+    )
+    assert any("mlp:?" in group for group in tree)
+
+    baseline_tree = Qwen3MoeQModel.build_layer_modules(Qwen3MoeQModel.module_tree)
+    assert all(":?" not in name for block in baseline_tree for name in block)
+
+    full = Qwen3MoeQModel.full_layer_modules(
+        model_config=model_config,
+        is_awq_quantize=True,
+        include_capture_only=True,
+    )
+    capture_blocks = [block for block in full if any(":?" in name for name in block)]
+    assert capture_blocks and capture_blocks[0] == ["mlp:?"]
+
+    simple = Qwen3MoeQModel.simple_layer_modules(
+        model_config=model_config,
+        quantize_config=quant_cfg,
+        is_awq_quantize=True,
+    )
+    assert all(":?" not in name for block in simple for name in block)
+
+    layer = model.model.layers[0]
+    mlp_module, _ = get_module_by_name_prefix(layer, "mlp")
+    assert isinstance(mlp_module, Qwen3MoeSparseMoeBlock)
+
+
+def test_awq_processor_enables_subset_early_stop():
+    calibration = [{"input_ids": torch.tensor([1, 2, 3])}]
+    qcfg = _make_quant_config()
+    dummy_gptq_model = SimpleNamespace()
+    dummy_model = torch.nn.Linear(3, 3)
+
+    processor = AWQProcessor(
+        tokenizer=None,
+        qcfg=qcfg,
+        calibration=calibration,
+        prepare_dataset_func=_prepare_dataset_func,
+        calibration_concat_size=None,
+        calibration_sort=None,
+        calibration_concat_separator=None,
+        batch_size=1,
+        gptq_model=dummy_gptq_model,
+        model=dummy_model,
+    )
+
+    assert processor.subset_forward_early_stop is True
+
+
+def test_module_looper_subset_callback_invoked():
+    quant_cfg = SimpleNamespace(
+        device=torch.device("cpu"),
+        vram_strategy=VRAMStrategy.EXCLUSIVE,
+        true_sequential=True,
+        lm_head=False,
+    )
+    dummy_model = SimpleNamespace(
+        support_batch_quantize=False,
+        quantize_config=quant_cfg,
+        layer_callback=None,
+        subset_callback=None,
+        supported_vram_strategies=[VRAMStrategy.EXCLUSIVE],
+    )
+
+    looper = ModuleLooper(model=dummy_model, processors=[])
+
+    events: List[Dict[str, object]] = []
+    looper.register_subset_callback(lambda **payload: events.append(payload))
+
+    looper._subset_event_dispatch(
+        stage="forward_start",
+        layer_idx=0,
+        subset_index=0,
+        subset_total=1,
+        module_names=["self_attn.q_proj"],
+        processor="stub",
+    )
+
+    assert events and events[0]["module_names"] == ["self_attn.q_proj"]
+
+
+class _DummyProgress:
+    def title(self, *_args, **_kwargs):
+        return self
+
+    def subtitle(self, *_args, **_kwargs):
+        return self
+
+    def draw(self):
+        return self
+
+
+class _SubsetRecorder:
+    def __init__(self):
+        self.events: List[Dict[str, object]] = []
+
+    def __call__(
+        self,
+        *,
+        stage: str,
+        layer_idx: int,
+        subset_index: int,
+        subset_total: int,
+        module_names: List[str],
+        processor: str,
+    ):
+        self.events.append(
+            dict(
+                stage=stage,
+                layer_idx=layer_idx,
+                subset_index=subset_index,
+                subset_total=subset_total,
+                module_names=module_names,
+                processor=processor,
+            )
+        )
+
+
+class _StubAWQProcessor(LoopProcessor):
+    def __init__(self, qcfg: QuantizeConfig):
+        calibration = [{"input_ids": torch.tensor([1, 2, 3])}]
+        super().__init__(
+            tokenizer=None,
+            qcfg=qcfg,
+            calibration=calibration,
+            prepare_dataset_func=_prepare_dataset_func,
+            batch_size=1,
+            require_fwd=True,
+            fwd_after_process=False,
+            subset_forward_early_stop=True,
+        )
+        self.hook_calls: List[str] = []
+        self.process_calls: List[str] = []
+
+    @classmethod
+    def name(cls) -> str:
+        return "stub-awq"
+
+    def preprocess(self, module: NamedModule, fail_safe: Optional[bool] = None):
+        self.tasks[module.name] = {"inputs": []}
+
+    def pre_process_fwd_hook(self, name: str) -> Callable[[torch.nn.Module, tuple, torch.Tensor], None]:
+        def _hook(_module, _inp, _out):
+            self.hook_calls.append(name)
+        return _hook
+
+    def process(
+        self,
+        module: NamedModule,
+        device: torch.device = None,
+        subset: Optional[Dict[str, NamedModule]] = None,
+        previous_subset: Optional[Dict[str, NamedModule]] = None,
+        subset_index: Optional[int] = None,
+        subset_total: Optional[int] = None,
+    ):
+        self.process_calls.append(module.name)
+
+    def verify_calibration_dataset(self, processor_index: int) -> bool:
+        return True
+
+
+class _MiniSelfAttn(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.q_proj = torch.nn.Linear(4, 4)
+        self.k_proj = torch.nn.Linear(4, 4)
+        self.v_proj = torch.nn.Linear(4, 4)
+        self.o_proj = torch.nn.Linear(4, 4)
+
+
+class _MiniLayer(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.self_attn = _MiniSelfAttn()
+        self.after_o_proj_called = False
+
+    def forward(self, hidden_states, **kwargs):
+        x = self.self_attn.q_proj(hidden_states)
+        x = self.self_attn.k_proj(x)
+        x = self.self_attn.v_proj(x)
+        x = self.self_attn.o_proj(x)
+        self.after_o_proj_called = True
+        return (x,)
+
+
+def test_stage_subset_early_stop_and_callbacks():
+    quant_cfg = _make_quant_config()
+    mini_layer = _MiniLayer()
+    replace_module_with_hooked_legacy(mini_layer)
+
+    dummy_quant_cfg = SimpleNamespace(
+        device=torch.device("cpu"),
+        vram_strategy=VRAMStrategy.EXCLUSIVE,
+        true_sequential=True,
+        lm_head=False,
+    )
+    dummy_model = SimpleNamespace(
+        support_batch_quantize=False,
+        quantize_config=dummy_quant_cfg,
+        layer_callback=None,
+        subset_callback=None,
+        supported_vram_strategies=[VRAMStrategy.EXCLUSIVE, VRAMStrategy.BALANCED],
+        layer_modules_strict=True,
+        lm_head="lm_head",
+    )
+
+    processor = _StubAWQProcessor(quant_cfg)
+    looper = ModuleLooper(model=dummy_model, processors=[processor])
+
+    recorder = _SubsetRecorder()
+    looper.register_subset_callback(recorder)
+
+    layer_inputs = [[torch.randn(2, 4)]]
+    layer_input_kwargs = [{}]
+    position_ids: List[Optional[torch.Tensor]] = [None]
+    attention_masks: List[Optional[torch.Tensor]] = [None]
+    shared_kv_cache_dict: Dict[int, torch.Tensor] = {}
+
+    full_modules = find_modules(mini_layer)
+    subset_names = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
+
+    run_subset_stage(
+        looper=looper,
+        processor=processor,
+        module=mini_layer,
+        layer_inputs=layer_inputs,
+        layer_input_kwargs=layer_input_kwargs,
+        position_ids=position_ids,
+        attention_masks=attention_masks,
+        cur_layer_device=torch.device("cpu"),
+        is_lm_head_module=False,
+        layer_descriptor="layers.0",
+        layer_title="subset-check",
+        layer_index=0,
+        layers_prefix="layers",
+        subset_names=subset_names,
+        subset_index=0,
+        subset_total=2,
+        full=full_modules,
+        fail_safe=False,
+        shared_kv_cache_dict=shared_kv_cache_dict,
+        pb=_DummyProgress(),
+        log=None,
+        region_timer=None,
+        previous_processed_subset=None,
+        subset_event_cb=looper._subset_event_dispatch,
+    )
+
+    assert mini_layer.after_o_proj_called is False
+    assert recorder.events and [evt["stage"] for evt in recorder.events] == [
+        "forward_start",
+        "forward_end",
+        "quant_start",
+        "quant_complete",
+    ]
+    assert recorder.events[0]["module_names"] == subset_names
+    assert processor.hook_calls and processor.hook_calls[-1] == subset_names[-1]
+    assert set(processor.process_calls) == set(subset_names)
+    assert len(processor.process_calls) == len(subset_names)