diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py index f3f2c0504..32253c8db 100644 --- a/gptqmodel/__init__.py +++ b/gptqmodel/__init__.py @@ -48,6 +48,7 @@ from .utils.exllama import exllama_set_max_input_length from .version import __version__ + setup_logger().info("\n%s", ASCII_LOGO) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 22c2b6cb6..85ccea470 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -55,6 +55,8 @@ from .. import DEVICE_THREAD_POOL from .awq_processor import AWQProcessor from .qqq_processor import QQQProcessor +from .stage_inputs_capture import StageInputsCapture +from .stage_layer import run_layer_stage log = setup_logger() @@ -962,195 +964,13 @@ def hook(module, inputs, output): return hook def cache_inputs(self, layers, calibration_data, use_cache): - layer_inputs = [] - attention_masks = [] - position_ids = [] - layer_input_kwargs = [] - - timer = getattr(self.gptq_model, "quant_region_timer", None) - layer_label = None - if layers: - first_layer = layers[0] - layer_label = getattr(first_layer, "full_name", None) - if layer_label is None: - layer_label = getattr(getattr(first_layer, "__class__", None), "__name__", None) - if layer_label is None: - layer_label = type(first_layer).__name__ - capture_source = f"cache_inputs:{layer_label}" - else: - capture_source = "cache_inputs" - start_time = time.perf_counter() if timer else None - - try: - calibration_batches = len(calibration_data) - except (TypeError, AttributeError): - calibration_batches = None - - if calibration_batches is None: - log.info("ModuleLooper: capturing layer inputs (batch count unknown)") - else: - log.info( - f"ModuleLooper: capturing layer inputs from {calibration_batches} calibration batches" - ) - - cur_layer_device = get_device(layers[0]) - data_device = cur_layer_device - - cache_forward_pb: ProgressBar = None - processed_rows = 0 - cache_total_batches = None - if calibration_batches is not None and calibration_batches > 0: - cache_total_batches = int(calibration_batches) - cache_forward_pb = ( - log.pb(range(max(cache_total_batches, 1))) - .manual() - .set(show_left_steps=False) - ) - cache_title = ( - f"Forward cached inputs (Pre {layer_label})" - if layer_label - else "Forward cached inputs" - ) - cache_forward_pb.title(cache_title).subtitle( - f"Batch 0/{cache_total_batches}" - ).draw() - - # TODO HookLinear add register_forward_pre_hook() - def store_input_hook(module, args, kwargs): - # Positional arguments. - layer_input = [] - if kwargs.get("hidden_states") is not None: - layer_input.append(move_to(kwargs["hidden_states"], device=data_device)) - else: - # If hidden_states is not in kwargs, get it from the first positional argument - # If error occurs here, check the model's modeling code - layer_input.append(move_to(args[0], device=data_device)) - - layer_inputs.append(layer_input) - - # Keyword arguments. 
- # Always capture attention_mask so downstream masking can drop padded tokens - if kwargs.get("attention_mask") is not None: - attention_masks.append(kwargs["attention_mask"].to(device=data_device)) - else: - attention_masks.append(None) - - pos_ids = kwargs.get("position_ids", None) - if pos_ids is not None: - position_ids.append(move_to(pos_ids, device=data_device)) - one_kwargs = {} - for (k, v) in kwargs.items(): # make sure other arguments also be captured - if k not in ["hidden_states", "attention_mask", "position_ids"]: - one_kwargs[k] = nested_move_to(v, device=data_device) - layer_input_kwargs.append(one_kwargs) - - raise STOP_FORWARD_EXCEPTION - - # move layer to target device - if cur_layer_device == META: - layers[0] = self.gptq_model.shell_module_materialize( - target_submodule=layers[0], - device=self.gptq_model.quantize_config.device, - ) - cur_layer_device = self.gptq_model.quantize_config.device - else: - layers[0] = layers[0].to(self.gptq_model.quantize_config.device) - - ori_outside_layer_module_devices = {} - for module_name in self.gptq_model.get_base_modules(self.gptq_model.model): - module, _ = get_module_by_name_prefix(self.gptq_model.model, [module_name]) - - if module is None: - continue - - m_device = get_device(module) - ori_outside_layer_module_devices[module_name] = CPU if m_device == META else m_device - if module is not None: - self.gptq_model.shell_module_materialize( - target_submodule=module, - device=cur_layer_device, - ) - - handle = layers[0].register_forward_pre_hook(store_input_hook, with_kwargs=True) - - # TODO FIX ME.. remove hard coded Ovis code - is_ovis = self.gptq_model.__class__.__name__ == "OvisGPTQ" - - # LifeCycle: start pre-first layer embedding hook - self.gptq_model.pre_quantize_generate_hook_start() - - try: - for batch_index, example in enumerate(calibration_data, start=1): - for k, v in example.items(): - if self.gptq_model.ATTENTION_MASKS_REQUIRED_FOR_INPUT: - data_device = self.gptq_model.quantize_config.device - else: - data_device = self.gptq_model.quantize_config.device if k == "pixel_values" else cur_layer_device - if isinstance(v, list): - for index in range(len(v)): - if len(v[index].shape) == 1: - v[index] = v[index].unsqueeze(0) - v[index] = move_to( - v[index].to(self.gptq_model.model.visual_tokenizer.dtype) if is_ovis else v[index], - device=data_device, - ) - else: - if len(v.shape) == 1: - v = v.unsqueeze(0) - example[k] = move_to(v, device=data_device) - try: - if self.gptq_model.ATTENTION_MASKS_DTYPE is torch.long: - example["attention_mask"] = example["attention_mask"].long() - - # Ensure initial caches (like RoPE) are created on the quant device - with ctx( - DEVICE_THREAD_POOL.read_lock(self.gptq_model.quantize_config.device), - device_ctx(self.gptq_model.quantize_config.device), - ): - if self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS: - self.gptq_model.model.generate(**example, **self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS) - else: - self.gptq_model.model(**example, use_cache=use_cache) - except StopForward: - pass - finally: - processed_batches = batch_index - if cache_forward_pb is not None: - rows_for_batch = 0 - if batch_index <= len(layer_inputs): - rows_for_batch = self._batch_row_count(layer_inputs[batch_index - 1]) - if rows_for_batch <= 0: - rows_for_batch = 1 - processed_rows += rows_for_batch - cache_forward_pb.current_iter_step = processed_batches - subtitle = f"Batch {processed_batches}/{cache_total_batches}" - if processed_rows > 0: - subtitle += f" rows {processed_rows}" - 
cache_forward_pb.subtitle(subtitle).draw() - finally: - if cache_forward_pb is not None: - cache_forward_pb.close() - - # LifeCycle: pre-first layer embedding hook - self.gptq_model.pre_quantize_generate_hook_end() - handle.remove() - - result = InputCache( - layer_inputs=layer_inputs, - layer_input_kwargs=layer_input_kwargs, - position_ids=position_ids, - attention_masks=attention_masks, + capture_stage = StageInputsCapture(self, logger=log) + return capture_stage.cache_inputs( + layers=layers, + calibration_data=calibration_data, + use_cache=use_cache, ) - if timer is not None and start_time is not None: - timer.record( - "capture_inputs", - time.perf_counter() - start_time, - source=capture_source, - ) - - return result - def loop(self, fail_safe: bool = False, **kwargs): with tf32_high_precision_guard(): return self._loop_impl(fail_safe=fail_safe, **kwargs) @@ -1240,727 +1060,19 @@ def _loop_impl(self, fail_safe: bool = False, **kwargs): parent = getattr(parent, part) setattr(parent, module_path[-1], hooked_lm_head) - for layer_index in pb: - if self._check_loop_stop(): - break - is_lm_head_module = layer_index >= layer_count - - if is_lm_head_module: - layer_title = "Quantizing lm_head" - module = get_module(self.gptq_model.model, key=self.gptq_model.lm_head) - else: - layer_title = f"Quantizing layer {layer_index} of {layer_count - 1}" - module = layers[layer_index] - - pb.title(layer_title).subtitle("").draw() - - if module.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower(): - # TODO FIXME: currently we not support quantizing cross attention layer (pixel_values) - continue - - module = self.gptq_model.pre_quantize(module) - - if is_lm_head_module: - layer_descriptor = self.gptq_model.lm_head - elif layers_prefix: - layer_descriptor = f"{layers_prefix}.{layer_index}" - else: - layer_descriptor = str(layer_index) - - cur_layer_device = get_device(module) - full = find_modules(module, name=self.gptq_model.lm_head if is_lm_head_module else "") - - for p_index, processor in enumerate(self.processors): - processor.log_call_count = 0 # reset - processor.collect_memory_info(layer_index) - - modules = [[self.gptq_model.lm_head]] if is_lm_head_module else layer_modules - - # for NativeProcessor we process one time forward on all grouped module subsets - if processor.fwd_all_modules_in_single_pass: - # merge all subsets into one - modules = [sum(modules, [])] - - # AWQ does per-layer itself; skip here - if isinstance(processor, AWQProcessor): - named_childs = dict() - for index, names in enumerate(modules): - named_modules = self.crate_named_modules(full=full, - is_lm_head_module=is_lm_head_module, - layer_index=layer_index, layers_prefix=layers_prefix, - names=names, - processor=processor, - fail_safe=fail_safe) - named_childs.update(named_modules) - - lock_ctx = nullcontext() - device_for_ctx = cur_layer_device if getattr(cur_layer_device, 'type', None) != 'meta' else None - if device_for_ctx is not None: - lock_ctx = DEVICE_THREAD_POOL.read_lock(cur_layer_device) - with ctx(lock_ctx, device_ctx(device_for_ctx)): - processor.layer_quantize(module, cur_layer_device, named_childs) - if p_index == len(self.processors) - 1: - self._emit_layer_complete( - layer_idx=layer_index, - submodule_finalized=False, - raise_in_place=True, - ) - self._emit_layer_complete( - layer_idx=layer_index, - submodule_finalized=True, - raise_in_place=True, - ) - continue - - layer_inputs = processor.inputs_cache.layer_inputs - if is_lm_head_module: - layer_inputs = 
self.gptq_model.lm_head_pre_quantize_generate_hook(layer_inputs) - layer_input_kwargs = processor.inputs_cache.layer_input_kwargs - position_ids = processor.inputs_cache.position_ids - attention_masks = processor.inputs_cache.attention_masks - - processed_subset = {} - - for index, names in enumerate(modules): - subset = self.crate_named_modules(full=full, is_lm_head_module=is_lm_head_module, - layer_index=layer_index, layers_prefix=layers_prefix, - names=names, - processor=processor, - fail_safe=fail_safe) - - if len(subset) == 0: - continue - - moe_group_keys_all: List[str] = [] - forward_device_map: Dict[str, torch.device] = {} - subset_forward_serial = False - - attention_subset = bool(subset) and all( - self._is_attention_module_name(name) for name in subset - ) - - moe_group_key_by_name: Dict[str, Optional[str]] = { - name: self._extract_moe_group_key(name) - for name in subset - } - moe_module_names = [ - name for name, group_key in moe_group_key_by_name.items() - if group_key is not None - ] - moe_modules_set = set(moe_module_names) - is_moe_subset = len(moe_module_names) >= self._moe_subset_threshold - - if is_moe_subset: - expert_groups: Dict[str, List[str]] = {} - combined_names: List[str] = list(subset.keys()) - if full is not None: - for candidate in full.keys(): - if candidate not in subset: - combined_names.append(candidate) - - for sub_name in combined_names: - group_key = self._extract_moe_group_key(sub_name) - if group_key is None: - continue - expert_groups.setdefault(group_key, []).append(sub_name) - - moe_group_keys_all = list(expert_groups.keys()) - - for name, named_module in subset.items(): - setattr(named_module, "moe_enabled", name in moe_modules_set) - - if self._vram_strategy == VRAMStrategy.BALANCED: - devices = [ - dev for dev in self._quant_devices - if dev is not None and getattr(dev, "type", None) != "cpu" - ] - if len(devices) > 1 and expert_groups: - assignable_group_keys: List[str] = [] - for group_key, module_names in expert_groups.items(): - suffixes = {name.rsplit(".", 1)[-1] for name in module_names} - if {"gate_proj", "up_proj"}.issubset(suffixes): - assignable_group_keys.append(group_key) - - if assignable_group_keys: - groups_per_device = max( - math.ceil(len(assignable_group_keys) / len(devices)), 1 - ) - for group_index, group_key in enumerate(assignable_group_keys): - device_idx = min(group_index // groups_per_device, len(devices) - 1) - target_device = devices[device_idx] - for module_name in expert_groups[group_key]: - forward_device_map[module_name] = target_device - - subset_forward_serial = self._vram_strategy == VRAMStrategy.BALANCED - if subset_forward_serial: - active_group_count = len(moe_group_keys_all) - if active_group_count == 0: - subset_forward_serial = False - elif attention_subset and active_group_count <= self._moe_subset_threshold: - subset_forward_serial = False - else: - for named_module in subset.values(): - setattr(named_module, "moe_enabled", False) - - handle = [] - subset_total = len(modules) - batch_count = self._resolve_batch_total( - getattr(processor, "num_batches", None), - layer_inputs, - ) - forward_row_counts = list(self._collect_row_counts(layer_inputs)) - if not forward_row_counts and batch_count > 0: - forward_row_counts = [1] * batch_count - if len(forward_row_counts) > batch_count: - forward_row_counts = forward_row_counts[:batch_count] - forward_total_rows = sum(forward_row_counts) if forward_row_counts else batch_count - forward_total_rows = max(forward_total_rows, 1) - if len(forward_row_counts) < 
batch_count: - forward_row_counts.extend([1] * (batch_count - len(forward_row_counts))) - - subset_size = len(subset) - for idx, (name, m) in enumerate(subset.items()): - is_last = (idx == subset_size - 1) - hook_source = getattr(m, "full_name", None) - if hook_source is None: - hook_source = getattr(m, "name", name) - if hook_source is None: - hook_source = str(name) - - # Wrap the processor hook with masking - if hasattr(subset[name], 'forward_hook'): - original_hook = processor.pre_process_fwd_hook(name) - subset[name].forward_hook = self._masked_hook_wrapper(processor, original_hook, hook_source) - if is_last and processor.fwd_after_process: - subset[name].forward_hook_last = True - else: - # Older registration path - original_hook = processor.pre_process_fwd_hook(name) - handle.append(subset[name].register_forward_hook( - self._masked_hook_wrapper(processor, original_hook, hook_source) - )) - - # ---- Start Pre-Quantized Forward ---- - fwd_start = time.perf_counter() - forward_source = f"{layer_descriptor}:subset{index + 1}/{subset_total}" - - need_outputs = not processor.fwd_after_process - reuse_kv = bool(getattr(module, "reuse_kv", False)) - forward_msg = ( - "Forward: " - f"Layer=`{layer_descriptor}`, subset={index + 1}/{subset_total}, " - f"batches={batch_count}" - ) - forward_pb = ( - log.pb(range(forward_total_rows)) - .manual() - .set(show_left_steps=False) - ) - forward_pb.title(forward_msg).subtitle( - f"Row 0/{forward_total_rows}" - ).draw() - # Drain any background work so the forward spike does not race pooled tasks. - # DEVICE_THREAD_POOL.wait() - # try to cleanup recent objects before forward - #timed_gc_collect(1) - - previous_forward_devices: Dict[str, torch.device] = {} - preserve_devices = bool(forward_device_map) - if forward_device_map: - previous_forward_devices = self._apply_forward_device_overrides( - subset, - forward_device_map, - fallback_modules=full, - ) - - # if log.isEnabledFor(logging.DEBUG): - # device_snapshot = [] - # for name, named_module in subset.items(): - # target_device = getattr(named_module, "target_device", None) - # if target_device is None: - # try: - # target_device = get_device(named_module.module) - # except Exception: - # target_device = None - # target_device_str = str(target_device) if target_device is not None else "unknown" - # device_snapshot.append(f"{name}:{target_device_str}") - # log.debug( - # "ModuleLooper: Forward subset device snapshot (layer=`%s`, subset=%d/%d, serial=%s) %s", - # layer_descriptor, - # index + 1, - # subset_total, - # subset_forward_serial, - # ", ".join(device_snapshot), - # ) - - try: - forward_outputs = self._run_forward_batches( - module=module, - processor=processor, - layer_inputs=layer_inputs, - layer_input_kwargs=layer_input_kwargs, - position_ids=position_ids, - attention_masks=attention_masks, - cur_layer_device=cur_layer_device, - is_lm_head_module=is_lm_head_module, - shared_kv_cache_dict=shared_kv_cache_dict, - layer_index=layer_index, - need_outputs=need_outputs, - reuse_kv=reuse_kv, - progress_pb=forward_pb, - progress_title=forward_msg, - progress_stage="Forward", - progress_rows_per_batch=forward_row_counts, - progress_total_rows=forward_total_rows, - force_serial=subset_forward_serial, - preserve_module_devices=preserve_devices, - ) - finally: - if forward_device_map: - self._restore_forward_device_overrides( - subset, - previous_forward_devices, - fallback_modules=full, - ) - if forward_pb is not None: - forward_pb.close() - if need_outputs: - 
processor.receive_layer_inputs(forward_outputs) - layer_inputs = processor.inputs_cache.layer_inputs - del forward_outputs - - fwd_time = time.perf_counter() - fwd_start - processor.set_fwd_time(fwd_time) - if region_timer is not None: - region_timer.record( - "pre_quant_forward", - fwd_time, - source=forward_source, - ) - - pb.title(layer_title).subtitle("").draw() - - for h in handle: - h.remove() - - for name in subset: - if hasattr(subset[name], 'forward_hook'): - subset[name].forward_hook = None - subset[name].forward_hook_last = False - - # MoE coverage check for GPTQ - moe_skip_modules = [] - if isinstance(processor, GPTQProcessor): - for name in subset: - if processor.tasks[name].fwd_counter == 0: - log.error(f"`{name}` was not invoked, if it is a MoE module, it may lack sufficient calibration data routed to it.") - moe_skip_modules.append(name) - - if not fail_safe: - for name in moe_skip_modules: - subset.pop(name) - task_map = getattr(processor, "tasks", None) - if task_map is not None: - task_map.pop(name, None) - - # ---- Start Process Hook (via DeviceThreadPool) ---- - quant_target_devices: Dict[str, torch.device] = {} - for name, named_module in subset.items(): - task_map = getattr(processor, "tasks", None) - has_task = bool(task_map and task_map.get(name) is not None) - - if has_task: - target_device = self._prepare_named_module_for_quantization( - processor=processor, - named_module=named_module, - fallback_device=cur_layer_device, - ) - else: - target_device = get_device(named_module.module) - setattr(named_module, "target_device", target_device) - setattr(named_module.module, "target_device", target_device) - - quant_target_devices[name] = target_device - - futures = [] - - @torch.inference_mode() - def _process_on_worker( - proc: LoopProcessor, - nm: NamedModule, - expected_device: torch.device, - ): - module_label = getattr(nm, "full_name", getattr(nm, "name", repr(nm))) - module_ref = nm.module if isinstance(nm, NamedModule) else nm - module_weight = getattr(module_ref, "weight", None) - if module_weight is not None and expected_device is not None: - target_device = expected_device if isinstance(expected_device, torch.device) else torch.device(expected_device) - actual_device = get_device(module_weight) - assert actual_device == target_device, ( - f"Device mismatch for '{module_label}' process task: " - f"module weight on {actual_device}, thread target {target_device}." 
- ) - - # Run processor.process for this NamedModule - timer = getattr(self.gptq_model, "quant_region_timer", None) - start = time.perf_counter() if timer else None - try: - proc.process(module=nm) - finally: - if timer is not None and start is not None: - timer.record( - "process_quant", - time.perf_counter() - start, - source=module_label, - ) - return nm.name, nm - - for name, m in subset.items(): - tgt_dev = quant_target_devices.get(name, cur_layer_device) - futures.append( - DEVICE_THREAD_POOL.submit(tgt_dev, _process_on_worker, processor, m, tgt_dev) - ) - - for fut in futures: - name, m = fut.result() - processed_subset[name] = m - torch_sync() - # ---- End Process Hook ---- - - is_last_module = layer_index == len(pb) - 1 - layer_outputs: List[List[torch.Tensor]] = [] - # second forward after process() - if not is_last_module and processor.fwd_after_process: - replay_batch_count = self._resolve_batch_total( - getattr(processor, "num_batches", None), - layer_inputs, - ) - replay_row_counts = list(self._collect_row_counts(layer_inputs)) - if not replay_row_counts and replay_batch_count > 0: - replay_row_counts = [1] * replay_batch_count - if len(replay_row_counts) > replay_batch_count: - replay_row_counts = replay_row_counts[:replay_batch_count] - replay_total_rows = sum(replay_row_counts) if replay_row_counts else replay_batch_count - replay_total_rows = max(replay_total_rows, 1) - if len(replay_row_counts) < replay_batch_count: - replay_row_counts.extend([1] * (replay_batch_count - len(replay_row_counts))) - replay_msg = ( - "Forward replay " - f"(layer=`{layer_descriptor}`, batches={replay_batch_count}, rows={replay_total_rows})" - ) - replay_pb = ( - log.pb(range(replay_total_rows)) - .manual() - .set(show_left_steps=False) - ) - replay_pb.title(replay_msg).subtitle( - f"Forward replay Row 0/{replay_total_rows}" - ).draw() - # Forward replay shares the same VRAM spike; block until the pool drains first. 
- # DEVICE_THREAD_POOL.wait() - # try to cleanup recent objects before forward - #timed_gc_collect(1) - - replay_start = time.perf_counter() - replay_source = f"{layer_descriptor}:subset{index + 1}/{subset_total}" - - replay_prev_devices: Dict[str, torch.device] = {} - if forward_device_map: - replay_prev_devices = self._apply_forward_device_overrides( - subset, - forward_device_map, - fallback_modules=full, - ) - - # if log.isEnabledFor(logging.DEBUG): - # replay_snapshot = [] - # for name, named_module in subset.items(): - # target_device = getattr(named_module, "target_device", None) - # if target_device is None: - # try: - # target_device = get_device(named_module.module) - # except Exception: - # target_device = None - # target_device_str = str(target_device) if target_device is not None else "unknown" - # replay_snapshot.append(f"{name}:{target_device_str}") - # log.debug( - # "ModuleLooper: Forward replay device snapshot (layer=`%s`, subset=%d/%d, serial=%s) %s", - # layer_descriptor, - # index + 1, - # subset_total, - # subset_forward_serial, - # ", ".join(replay_snapshot), - # ) - - try: - layer_outputs = self._run_forward_batches( - module=module, - processor=processor, - layer_inputs=layer_inputs, - layer_input_kwargs=layer_input_kwargs, - position_ids=position_ids, - attention_masks=attention_masks, - cur_layer_device=cur_layer_device, - is_lm_head_module=is_lm_head_module, - shared_kv_cache_dict=shared_kv_cache_dict, - layer_index=layer_index, - need_outputs=True, - reuse_kv=False, - progress_pb=replay_pb, - progress_title=replay_msg, - progress_stage="Forward replay", - progress_rows_per_batch=replay_row_counts, - progress_total_rows=replay_total_rows, - force_serial=subset_forward_serial, - preserve_module_devices=preserve_devices, - ) - finally: - if forward_device_map: - self._restore_forward_device_overrides( - subset, - replay_prev_devices, - fallback_modules=full, - ) - if replay_pb is not None: - replay_pb.close() - if region_timer is not None: - region_timer.record( - "post_quant_forward", - time.perf_counter() - replay_start, - source=replay_source, - ) - - # Finalize module after last processor - if p_index == len(self.processors) - 1: - torch_sync() - - if not is_lm_head_module: - layers[layer_index] = self.gptq_model.post_quantize(module) - else: - self.gptq_model.post_quantize(module) - - for finalized in processed_subset.values(): - if isinstance(finalized, NamedModule): - setattr(finalized, "target_device", CPU) - inner_module = getattr(finalized, "module", None) - else: - inner_module = finalized - - if inner_module is not None and hasattr(inner_module, "target_device"): - setattr(inner_module, "target_device", CPU) - - if region_timer is not None: - region_timer.flush() - - if processor.fwd_after_process: - processor.clear_cache_data() - processor.receive_layer_inputs(layer_outputs) - layer_inputs = processor.inputs_cache.layer_inputs - - pb.title(layer_title).subtitle("").draw() - - if p_index == len(self.processors) - 1: - torch_sync() - - # Gather finalize tasks (can offload to disk); run them via the pool - finalize_tasks = [] - - for reverse_p in reversed(self.processors): - for module in processed_subset.values(): - actual_module = module.module if isinstance(module, NamedModule) else module - - get_device_new( - actual_module, - recursive=True, - assert_mode=True, - expected=CPU, - ) - with self._quant_device_lock: - key = getattr(module, "full_name", getattr(module, "name", None)) - if key is not None: - self._module_device_map[key] = CPU - - 
target_dev = CPU - module_label = getattr(module, "full_name", getattr(module, "name", "")) - layer_idx = getattr(module, "layer_index", None) - finalize_tasks.append((reverse_p, module, module_label, target_dev, layer_idx)) - - finalize_count = len(finalize_tasks) - finalize_futures = [] - finalize_pb = log.pb(range(finalize_count)).manual().set(show_left_steps=False) - - @torch.inference_mode() - def _finalize_on_worker(process, module, idx, total, module_label, layer_idx): - resolved_label = module_label or getattr(module, "full_name", getattr(module, "name", "")) - start = time.perf_counter() if region_timer is not None else None - try: - with log_time_block( - "submodule_finalize", - logger=log, - module_name=resolved_label, - ): - process.submodule_finalize(module, self.gptq_model) - - # Disk offload (lifecycle TODO note preserved) - if isinstance(process, (GPTQProcessor, QQQProcessor, AWQProcessor)): - quant_config = getattr(self.gptq_model, "quantize_config", None) - if quant_config and getattr(quant_config, "offload_to_disk", False): - offload_path = getattr(quant_config, "offload_to_disk_path", None) - if offload_path: - module_full_name = getattr(module, "full_name", None) - target_module = ( - self.gptq_model.model.get_submodule(module_full_name) - if module_full_name - else module - ) - offload_start = time.perf_counter() if region_timer is not None else None - with log_time_block( - "disk_offload", - logger=log, - module_name=resolved_label, - ): - offload_to_disk( - model=self.gptq_model.model, - module=target_module, - disk_path=offload_path, - ) - if region_timer is not None and offload_start is not None: - region_timer.record( - "submodule_finalize_offload", - time.perf_counter() - offload_start, - source=resolved_label, - ) - else: - log.warning( - "Skipping disk offload for %s: no offload path configured", - module_label, - ) - finally: - if region_timer is not None and start is not None: - region_timer.record( - "submodule_finalize", - time.perf_counter() - start, - source=resolved_label, - ) - process_name = process.name() if process is not None else "" - return FinalizeProgressInfo(module_label, process_name, layer_idx) - - # pb.subtitle( - # f"{process.name()}: layer:{layer_idx} Finalized {idx}/{total} {module_label}" - # ).draw() - - for index, (process, module, module_label, target_dev, layer_idx) in enumerate(finalize_tasks, start=1): - future = DEVICE_THREAD_POOL.submit( - target_dev, - _finalize_on_worker, - process, - module, - index, - finalize_count, - module_label, - layer_idx, - ) - finalize_futures.append((future, index, module_label, process, layer_idx)) - - finalize_futures_snapshot = list(finalize_futures) - - self._emit_layer_complete( - layer_idx=layer_index, - submodule_finalized=False, - raise_in_place=True, - ) - - if finalize_futures_snapshot: - known_layers = sorted( - { - layer_idx - for _, _, _, _, layer_idx in finalize_futures_snapshot - if layer_idx is not None - } - ) - includes_unknown = any( - layer_idx is None - for _, _, _, _, layer_idx in finalize_futures_snapshot - ) - - layer_heading = "Layer ?" - if known_layers: - sample_layers = ", ".join(str(idx) for idx in known_layers[:3]) - if len(known_layers) > 3: - sample_layers += ", …" - suffix = ", ?" if includes_unknown else "" - prefix = "Layer" if len(known_layers) == 1 else "Layers" - layer_heading = f"{prefix} {sample_layers}{suffix}" - elif includes_unknown: - layer_heading = "Layer ?" 
- - finalize_pb.title( - f"{layer_heading} Submodule finalize 0/{finalize_count}" - ).subtitle("Waiting for completions...").draw() - - def _drain_finalize_futures( - futures, - finalize_pb_local, - finalize_count_local, - layer_idx_for_callback, - ): - completed_local = 0 - try: - for future in as_completed(futures): - try: - result = future.result() - except BaseException as exc: - log.exception("Submodule finalize task raised an exception") - self._request_loop_stop(exc) - return - - if isinstance(result, FinalizeProgressInfo): - module_label = result.module_label - process_name = result.process_name - layer_idx = result.layer_idx - elif isinstance(result, tuple) and len(result) == 3: - module_label, process_name, layer_idx = result - else: - module_label = None - process_name = "" - layer_idx = None - - layer_label = f"Layer {layer_idx}" if layer_idx is not None else "Layer ?" - display_module = module_label or "" - subtitle = f"{process_name}: {display_module}" - - completed_local += 1 - finalize_pb_local.next() - finalize_pb_local.title( - f"{layer_label} Finalize {completed_local}/{finalize_count_local}" - ).subtitle(subtitle).draw() - finally: - finalize_pb_local.close() - self._emit_layer_complete( - layer_idx=layer_idx_for_callback, - submodule_finalized=True, - raise_in_place=False, - ) - - if finalize_futures_snapshot: - # Drain finalize futures asynchronously so the main loop can continue scheduling work. - threading.Thread( - target=_drain_finalize_futures, - args=( - [future for future, *_ in finalize_futures_snapshot], - finalize_pb, - finalize_count, - layer_index, - ), - name="SubmoduleFinalizeWatcher", - daemon=True, - ).start() - else: - self._emit_layer_complete( - layer_idx=layer_index, - submodule_finalized=True, - raise_in_place=True, - ) + run_layer_stage( + self, + layers=layers, + layer_modules=layer_modules, + layers_prefix=layers_prefix, + fail_safe=fail_safe, + shared_kv_cache_dict=shared_kv_cache_dict, + pb=pb, + layer_count=layer_count, + region_timer=region_timer, + finalize_progress_cls=FinalizeProgressInfo, + logger=log, + ) # LifeCycle: All sub-modules have finalized meaning quantization work is complete self._check_loop_stop() diff --git a/gptqmodel/looper/stage_inputs_capture.py b/gptqmodel/looper/stage_inputs_capture.py new file mode 100644 index 000000000..5a8c36614 --- /dev/null +++ b/gptqmodel/looper/stage_inputs_capture.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +"""Isolated stage for capturing calibration inputs prior to quantization.""" + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Sequence + +import torch + +from .. 
import DEVICE_THREAD_POOL +from ..looper.input_cache import InputCache +from ..nn_modules.hooked_linear import STOP_FORWARD_EXCEPTION, StopForward +from ..utils.ctx import ctx +from ..utils.device import get_device +from ..utils.looper_helpers import device_ctx +from ..utils.logger import setup_logger +from ..utils.model import get_module_by_name_prefix, move_to, nested_move_to +from ..utils.torch import CPU, META + +if TYPE_CHECKING: # pragma: no cover - import for typing only + from .module_looper import ModuleLooper + + +class StageInputsCapture: + """Capture layer inputs so processors can reuse cached activations.""" + + def __init__(self, looper: ModuleLooper, logger=None) -> None: + self.looper = looper + self.gptq_model = looper.gptq_model + self.logger = logger or setup_logger() + + def cache_inputs( + self, + layers: Sequence[torch.nn.Module], + calibration_data: Iterable[Dict[str, torch.Tensor]], + use_cache: bool, + ) -> InputCache: + layer_inputs: List[List[torch.Tensor]] = [] + attention_masks: List[torch.Tensor | None] = [] + position_ids: List[torch.Tensor] = [] + layer_input_kwargs: List[Dict[str, Any]] = [] + + timer = getattr(self.gptq_model, "quant_region_timer", None) + layer_label = None + if layers: + first_layer = layers[0] + layer_label = getattr(first_layer, "full_name", None) + if layer_label is None: + layer_label = getattr(getattr(first_layer, "__class__", None), "__name__", None) + if layer_label is None: + layer_label = type(first_layer).__name__ + capture_source = f"cache_inputs:{layer_label}" + else: + capture_source = "cache_inputs" + start_time = time.perf_counter() if timer else None + + try: + calibration_batches = len(calibration_data) # type: ignore[arg-type] + except (TypeError, AttributeError): + calibration_batches = None + + if calibration_batches is None: + self.logger.info("ModuleLooper: capturing layer inputs (batch count unknown)") + else: + self.logger.info( + "ModuleLooper: capturing layer inputs from %s calibration batches", + calibration_batches, + ) + + cur_layer_device = get_device(layers[0]) + data_device = cur_layer_device + + cache_forward_pb = None + processed_rows = 0 + cache_total_batches = None + if calibration_batches is not None and calibration_batches > 0: + cache_total_batches = int(calibration_batches) + cache_forward_pb = ( + self.logger.pb(range(max(cache_total_batches, 1))) + .manual() + .set(show_left_steps=False) + ) + cache_title = ( + f"Forward cached inputs (Pre {layer_label})" + if layer_label + else "Forward cached inputs" + ) + cache_forward_pb.title(cache_title).subtitle( + f"Batch 0/{cache_total_batches}" + ).draw() + + def store_input_hook(module, args, kwargs): + layer_input: List[torch.Tensor] = [] + if kwargs.get("hidden_states") is not None: + layer_input.append(move_to(kwargs["hidden_states"], device=data_device)) + else: + layer_input.append(move_to(args[0], device=data_device)) + + layer_inputs.append(layer_input) + + if kwargs.get("attention_mask") is not None: + attention_masks.append(kwargs["attention_mask"].to(device=data_device)) + else: + attention_masks.append(None) + + pos_ids = kwargs.get("position_ids", None) + if pos_ids is not None: + position_ids.append(move_to(pos_ids, device=data_device)) + one_kwargs: Dict[str, Any] = {} + for (k, v) in kwargs.items(): + if k not in ["hidden_states", "attention_mask", "position_ids"]: + one_kwargs[k] = nested_move_to(v, device=data_device) + layer_input_kwargs.append(one_kwargs) + + raise STOP_FORWARD_EXCEPTION + + if cur_layer_device == META: + layers[0] = 
self.gptq_model.shell_module_materialize( + target_submodule=layers[0], + device=self.gptq_model.quantize_config.device, + ) + cur_layer_device = self.gptq_model.quantize_config.device + else: + layers[0] = layers[0].to(self.gptq_model.quantize_config.device) + + ori_outside_layer_module_devices: Dict[str, torch.device] = {} + for module_name in self.gptq_model.get_base_modules(self.gptq_model.model): + module, _ = get_module_by_name_prefix(self.gptq_model.model, [module_name]) + + if module is None: + continue + + m_device = get_device(module) + ori_outside_layer_module_devices[module_name] = CPU if m_device == META else m_device + if module is not None: + self.gptq_model.shell_module_materialize( + target_submodule=module, + device=cur_layer_device, + ) + + handle = layers[0].register_forward_pre_hook(store_input_hook, with_kwargs=True) + + is_ovis = self.gptq_model.__class__.__name__ == "OvisGPTQ" + + self.gptq_model.pre_quantize_generate_hook_start() + + try: + for batch_index, example in enumerate(calibration_data, start=1): + for k, v in example.items(): + if self.gptq_model.ATTENTION_MASKS_REQUIRED_FOR_INPUT: + data_device = self.gptq_model.quantize_config.device + else: + data_device = ( + self.gptq_model.quantize_config.device + if k == "pixel_values" + else cur_layer_device + ) + if isinstance(v, list): + for index in range(len(v)): + if len(v[index].shape) == 1: + v[index] = v[index].unsqueeze(0) + v[index] = move_to( + v[index].to(self.gptq_model.model.visual_tokenizer.dtype) + if is_ovis + else v[index], + device=data_device, + ) + else: + if len(v.shape) == 1: + v = v.unsqueeze(0) + example[k] = move_to(v, device=data_device) + try: + if self.gptq_model.ATTENTION_MASKS_DTYPE is torch.long: + example["attention_mask"] = example["attention_mask"].long() + + with ctx( + DEVICE_THREAD_POOL.read_lock(self.gptq_model.quantize_config.device), + device_ctx(self.gptq_model.quantize_config.device), + ): + if self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS: + self.gptq_model.model.generate( + **example, + **self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS, + ) + else: + self.gptq_model.model(**example, use_cache=use_cache) + except StopForward: + pass + finally: + processed_batches = batch_index + if cache_forward_pb is not None: + rows_for_batch = 0 + if batch_index <= len(layer_inputs): + rows_for_batch = self.looper._batch_row_count( + layer_inputs[batch_index - 1] + ) + if rows_for_batch <= 0: + rows_for_batch = 1 + processed_rows += rows_for_batch + cache_forward_pb.current_iter_step = processed_batches + subtitle = f"Batch {processed_batches}/{cache_total_batches}" + if processed_rows > 0: + subtitle += f" rows {processed_rows}" + cache_forward_pb.subtitle(subtitle).draw() + finally: + if cache_forward_pb is not None: + cache_forward_pb.close() + + self.gptq_model.pre_quantize_generate_hook_end() + handle.remove() + + result = InputCache( + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + ) + + if timer is not None and start_time is not None: + timer.record( + "capture_inputs", + time.perf_counter() - start_time, + source=capture_source, + ) + + return result diff --git a/gptqmodel/looper/stage_layer.py b/gptqmodel/looper/stage_layer.py new file mode 100644 index 000000000..49dc14df8 --- /dev/null +++ b/gptqmodel/looper/stage_layer.py @@ -0,0 +1,507 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# 
Contact: qubitium@modelcloud.ai, x.com/qubitium + +"""Layer execution stage extracted from ModuleLooper.""" + +from __future__ import annotations + +import threading +import time +from concurrent.futures import as_completed +from contextlib import nullcontext +from typing import TYPE_CHECKING, Dict, List, Optional + +import torch + +from .. import DEVICE_THREAD_POOL +from ..looper.awq_processor import AWQProcessor +from ..looper.gptq_processor import GPTQProcessor +from ..looper.named_module import NamedModule +from ..looper.qqq_processor import QQQProcessor +from ..utils.ctx import ctx +from ..utils.device import get_device, get_device_new +from ..utils.logger import log_time_block, setup_logger +from ..utils.looper_helpers import device_ctx +from ..utils.model import find_modules, get_module +from ..utils.offload import offload_to_disk +from ..utils.torch import CPU, torch_sync +from .stage_subset import SubsetForwardContext, run_subset_stage + +if TYPE_CHECKING: # pragma: no cover - type hints only + from .module_looper import ModuleLooper + + +def run_layer_stage( + looper: 'ModuleLooper', + *, + layers: List[torch.nn.Module], + layer_modules: List[List[str]], + layers_prefix: Optional[str], + fail_safe: bool, + shared_kv_cache_dict: Dict[int, torch.Tensor], + pb, + layer_count: int, + region_timer, + finalize_progress_cls, + logger=None, +) -> None: + """Execute the main per-layer quantization loop.""" + log = logger or setup_logger() + for layer_index in pb: + if looper._check_loop_stop(): + break + is_lm_head_module = layer_index >= layer_count + + if is_lm_head_module: + layer_title = "Quantizing lm_head" + module = get_module(looper.gptq_model.model, key=looper.gptq_model.lm_head) + else: + layer_title = f"Quantizing layer {layer_index} of {layer_count - 1}" + module = layers[layer_index] + + pb.title(layer_title).subtitle("").draw() + + if module.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower(): + # TODO FIXME: currently we not support quantizing cross attention layer (pixel_values) + continue + + module = looper.gptq_model.pre_quantize(module) + + if is_lm_head_module: + layer_descriptor = looper.gptq_model.lm_head + elif layers_prefix: + layer_descriptor = f"{layers_prefix}.{layer_index}" + else: + layer_descriptor = str(layer_index) + + cur_layer_device = get_device(module) + full = find_modules(module, name=looper.gptq_model.lm_head if is_lm_head_module else "") + + for p_index, processor in enumerate(looper.processors): + processor.log_call_count = 0 # reset + processor.collect_memory_info(layer_index) + + modules = [[looper.gptq_model.lm_head]] if is_lm_head_module else layer_modules + + # for NativeProcessor we process one time forward on all grouped module subsets + if processor.fwd_all_modules_in_single_pass: + # merge all subsets into one + modules = [sum(modules, [])] + + # AWQ does per-layer itself; skip here + if isinstance(processor, AWQProcessor): + named_childs = dict() + for index, names in enumerate(modules): + named_modules = looper.crate_named_modules(full=full, + is_lm_head_module=is_lm_head_module, + layer_index=layer_index, layers_prefix=layers_prefix, + names=names, + processor=processor, + fail_safe=fail_safe) + named_childs.update(named_modules) + + lock_ctx = nullcontext() + device_for_ctx = cur_layer_device if getattr(cur_layer_device, 'type', None) != 'meta' else None + if device_for_ctx is not None: + lock_ctx = DEVICE_THREAD_POOL.read_lock(cur_layer_device) + with ctx(lock_ctx, device_ctx(device_for_ctx)): + 
processor.layer_quantize(module, cur_layer_device, named_childs) + if p_index == len(looper.processors) - 1: + looper._emit_layer_complete( + layer_idx=layer_index, + submodule_finalized=False, + raise_in_place=True, + ) + looper._emit_layer_complete( + layer_idx=layer_index, + submodule_finalized=True, + raise_in_place=True, + ) + continue + + layer_inputs = processor.inputs_cache.layer_inputs + if is_lm_head_module: + layer_inputs = looper.gptq_model.lm_head_pre_quantize_generate_hook(layer_inputs) + layer_input_kwargs = processor.inputs_cache.layer_input_kwargs + position_ids = processor.inputs_cache.position_ids + attention_masks = processor.inputs_cache.attention_masks + + processed_subset: Dict[str, NamedModule] = {} + last_subset_context: Optional[SubsetForwardContext] = None + subset_total = len(modules) + + for index, names in enumerate(modules): + subset_result = run_subset_stage( + looper, + processor=processor, + module=module, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + layer_descriptor=layer_descriptor, + layer_title=layer_title, + layer_index=layer_index, + layers_prefix=layers_prefix, + subset_names=names, + subset_index=index, + subset_total=subset_total, + full=full, + fail_safe=fail_safe, + shared_kv_cache_dict=shared_kv_cache_dict, + pb=pb, + log=log, + region_timer=region_timer, + ) + + layer_inputs = subset_result.layer_inputs + processed_subset.update(subset_result.processed_subset) + if subset_result.forward_context is not None: + last_subset_context = subset_result.forward_context + + is_last_module = layer_index == len(pb) - 1 + layer_outputs: List[List[torch.Tensor]] = [] + subset_context = last_subset_context + forward_device_map = subset_context.forward_device_map if subset_context else {} + subset_forward_serial = subset_context.subset_forward_serial if subset_context else False + subset_reference_total = subset_context.subset_total if subset_context else subset_total + subset_reference_index = subset_context.subset_index if subset_context else max(subset_total - 1, 0) + subset_for_overrides = subset_context.subset if subset_context else {} + preserve_devices = bool(forward_device_map) + + # second forward after process() + if not is_last_module and processor.fwd_after_process and subset_context is not None: + replay_batch_count = looper._resolve_batch_total( + getattr(processor, "num_batches", None), + layer_inputs, + ) + replay_row_counts = list(looper._collect_row_counts(layer_inputs)) + if not replay_row_counts and replay_batch_count > 0: + replay_row_counts = [1] * replay_batch_count + if len(replay_row_counts) > replay_batch_count: + replay_row_counts = replay_row_counts[:replay_batch_count] + replay_total_rows = sum(replay_row_counts) if replay_row_counts else replay_batch_count + replay_total_rows = max(replay_total_rows, 1) + if len(replay_row_counts) < replay_batch_count: + replay_row_counts.extend([1] * (replay_batch_count - len(replay_row_counts))) + replay_msg = ( + "Forward replay " + f"(layer=`{layer_descriptor}`, batches={replay_batch_count}, rows={replay_total_rows})" + ) + replay_pb = ( + log.pb(range(replay_total_rows)) + .manual() + .set(show_left_steps=False) + ) + replay_pb.title(replay_msg).subtitle( + f"Forward replay Row 0/{replay_total_rows}" + ).draw() + # Forward replay shares the same VRAM spike; block until the pool drains first. 
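+                # Replay re-runs this layer with its freshly quantized submodules so the
+                # captured outputs become the cached inputs handed to the next layer's processors.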
+ # DEVICE_THREAD_POOL.wait() + # try to cleanup recent objects before forward + #timed_gc_collect(1) + + replay_start = time.perf_counter() + replay_source = f"{layer_descriptor}:subset{subset_reference_index + 1}/{subset_reference_total}" + + replay_prev_devices: Dict[str, torch.device] = {} + if forward_device_map: + replay_prev_devices = looper._apply_forward_device_overrides( + subset_for_overrides, + forward_device_map, + fallback_modules=full, + ) + + # if log.isEnabledFor(logging.DEBUG): + # replay_snapshot = [] + # for name, named_module in subset.items(): + # target_device = getattr(named_module, "target_device", None) + # if target_device is None: + # try: + # target_device = get_device(named_module.module) + # except Exception: + # target_device = None + # target_device_str = str(target_device) if target_device is not None else "unknown" + # replay_snapshot.append(f"{name}:{target_device_str}") + # log.debug( + # "ModuleLooper: Forward replay device snapshot (layer=`%s`, subset=%d/%d, serial=%s) %s", + # layer_descriptor, + # index + 1, + # subset_total, + # subset_forward_serial, + # ", ".join(replay_snapshot), + # ) + + try: + layer_outputs = looper._run_forward_batches( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=True, + reuse_kv=False, + progress_pb=replay_pb, + progress_title=replay_msg, + progress_stage="Forward replay", + progress_rows_per_batch=replay_row_counts, + progress_total_rows=replay_total_rows, + force_serial=subset_forward_serial, + preserve_module_devices=preserve_devices, + ) + finally: + if forward_device_map: + looper._restore_forward_device_overrides( + subset_for_overrides, + replay_prev_devices, + fallback_modules=full, + ) + if replay_pb is not None: + replay_pb.close() + if region_timer is not None: + region_timer.record( + "post_quant_forward", + time.perf_counter() - replay_start, + source=replay_source, + ) + + # Finalize module after last processor + if p_index == len(looper.processors) - 1: + torch_sync() + + if not is_lm_head_module: + layers[layer_index] = looper.gptq_model.post_quantize(module) + else: + looper.gptq_model.post_quantize(module) + + for finalized in processed_subset.values(): + if isinstance(finalized, NamedModule): + setattr(finalized, "target_device", CPU) + inner_module = getattr(finalized, "module", None) + else: + inner_module = finalized + + if inner_module is not None and hasattr(inner_module, "target_device"): + setattr(inner_module, "target_device", CPU) + + if region_timer is not None: + region_timer.flush() + + if processor.fwd_after_process: + processor.clear_cache_data() + processor.receive_layer_inputs(layer_outputs) + layer_inputs = processor.inputs_cache.layer_inputs + + pb.title(layer_title).subtitle("").draw() + + if p_index == len(looper.processors) - 1: + torch_sync() + + # Gather finalize tasks (can offload to disk); run them via the pool + finalize_tasks = [] + + for reverse_p in reversed(looper.processors): + for module in processed_subset.values(): + actual_module = module.module if isinstance(module, NamedModule) else module + + get_device_new( + actual_module, + recursive=True, + assert_mode=True, + expected=CPU, + ) + with looper._quant_device_lock: + key = getattr(module, "full_name", getattr(module, "name", 
None)) + if key is not None: + looper._module_device_map[key] = CPU + + target_dev = CPU + module_label = getattr(module, "full_name", getattr(module, "name", "")) + layer_idx = getattr(module, "layer_index", None) + finalize_tasks.append((reverse_p, module, module_label, target_dev, layer_idx)) + + finalize_count = len(finalize_tasks) + finalize_futures = [] + finalize_pb = log.pb(range(finalize_count)).manual().set(show_left_steps=False) + + @torch.inference_mode() + def _finalize_on_worker(process, module, idx, total, module_label, layer_idx): + resolved_label = module_label or getattr(module, "full_name", getattr(module, "name", "")) + start = time.perf_counter() if region_timer is not None else None + try: + with log_time_block( + "submodule_finalize", + logger=log, + module_name=resolved_label, + ): + process.submodule_finalize(module, looper.gptq_model) + + # Disk offload (lifecycle TODO note preserved) + if isinstance(process, (GPTQProcessor, QQQProcessor, AWQProcessor)): + quant_config = getattr(looper.gptq_model, "quantize_config", None) + if quant_config and getattr(quant_config, "offload_to_disk", False): + offload_path = getattr(quant_config, "offload_to_disk_path", None) + if offload_path: + module_full_name = getattr(module, "full_name", None) + target_module = ( + looper.gptq_model.model.get_submodule(module_full_name) + if module_full_name + else module + ) + offload_start = time.perf_counter() if region_timer is not None else None + with log_time_block( + "disk_offload", + logger=log, + module_name=resolved_label, + ): + offload_to_disk( + model=looper.gptq_model.model, + module=target_module, + disk_path=offload_path, + ) + if region_timer is not None and offload_start is not None: + region_timer.record( + "submodule_finalize_offload", + time.perf_counter() - offload_start, + source=resolved_label, + ) + else: + log.warning( + "Skipping disk offload for %s: no offload path configured", + module_label, + ) + finally: + if region_timer is not None and start is not None: + region_timer.record( + "submodule_finalize", + time.perf_counter() - start, + source=resolved_label, + ) + process_name = process.name() if process is not None else "" + return finalize_progress_cls(module_label, process_name, layer_idx) + + # pb.subtitle( + # f"{process.name()}: layer:{layer_idx} Finalized {idx}/{total} {module_label}" + # ).draw() + + for index, (process, module, module_label, target_dev, layer_idx) in enumerate(finalize_tasks, start=1): + future = DEVICE_THREAD_POOL.submit( + target_dev, + _finalize_on_worker, + process, + module, + index, + finalize_count, + module_label, + layer_idx, + ) + finalize_futures.append((future, index, module_label, process, layer_idx)) + + finalize_futures_snapshot = list(finalize_futures) + + looper._emit_layer_complete( + layer_idx=layer_index, + submodule_finalized=False, + raise_in_place=True, + ) + + if finalize_futures_snapshot: + known_layers = sorted( + { + layer_idx + for _, _, _, _, layer_idx in finalize_futures_snapshot + if layer_idx is not None + } + ) + includes_unknown = any( + layer_idx is None + for _, _, _, _, layer_idx in finalize_futures_snapshot + ) + + layer_heading = "Layer ?" + if known_layers: + sample_layers = ", ".join(str(idx) for idx in known_layers[:3]) + if len(known_layers) > 3: + sample_layers += ", …" + suffix = ", ?" if includes_unknown else "" + prefix = "Layer" if len(known_layers) == 1 else "Layers" + layer_heading = f"{prefix} {sample_layers}{suffix}" + elif includes_unknown: + layer_heading = "Layer ?" 
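+                    # Seed the finalize progress bar with the affected layer label(s) before any
+                    # worker completes; the watcher thread below updates it as futures drain.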
+ + finalize_pb.title( + f"{layer_heading} Submodule finalize 0/{finalize_count}" + ).subtitle("Waiting for completions...").draw() + + def _drain_finalize_futures( + futures, + finalize_pb_local, + finalize_count_local, + layer_idx_for_callback, + ): + completed_local = 0 + try: + for future in as_completed(futures): + try: + result = future.result() + except BaseException as exc: + log.exception("Submodule finalize task raised an exception") + looper._request_loop_stop(exc) + return + + if isinstance(result, finalize_progress_cls): + module_label = result.module_label + process_name = result.process_name + layer_idx = result.layer_idx + elif isinstance(result, tuple) and len(result) == 3: + module_label, process_name, layer_idx = result + else: + module_label = None + process_name = "" + layer_idx = None + + layer_label = f"Layer {layer_idx}" if layer_idx is not None else "Layer ?" + display_module = module_label or "" + subtitle = f"{process_name}: {display_module}" + + completed_local += 1 + finalize_pb_local.next() + finalize_pb_local.title( + f"{layer_label} Finalize {completed_local}/{finalize_count_local}" + ).subtitle(subtitle).draw() + finally: + finalize_pb_local.close() + looper._emit_layer_complete( + layer_idx=layer_idx_for_callback, + submodule_finalized=True, + raise_in_place=False, + ) + + if finalize_futures_snapshot: + # Drain finalize futures asynchronously so the main loop can continue scheduling work. + threading.Thread( + target=_drain_finalize_futures, + args=( + [future for future, *_ in finalize_futures_snapshot], + finalize_pb, + finalize_count, + layer_index, + ), + name="SubmoduleFinalizeWatcher", + daemon=True, + ).start() + else: + looper._emit_layer_complete( + layer_idx=layer_index, + submodule_finalized=True, + raise_in_place=True, + ) diff --git a/gptqmodel/looper/stage_subset.py b/gptqmodel/looper/stage_subset.py new file mode 100644 index 000000000..befcc6dd1 --- /dev/null +++ b/gptqmodel/looper/stage_subset.py @@ -0,0 +1,364 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +"""Subset-level processing stage extracted from ModuleLooper.""" + +from __future__ import annotations + +import math +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional + +import torch + +from .. 
import DEVICE_THREAD_POOL +from ..looper.gptq_processor import GPTQProcessor +from ..looper.loop_processor import LoopProcessor +from ..looper.named_module import NamedModule +from ..quantization.config import VRAMStrategy +from ..utils.device import get_device +from ..utils.logger import setup_logger +from ..utils.torch import torch_sync + +if TYPE_CHECKING: # pragma: no cover - typing only + from .module_looper import ModuleLooper + + +@dataclass +class SubsetForwardContext: + subset: Dict[str, NamedModule] + forward_device_map: Dict[str, torch.device] + subset_forward_serial: bool + subset_total: int + subset_index: int + + +@dataclass +class SubsetStageResult: + processed_subset: Dict[str, NamedModule] + layer_inputs: List[List[torch.Tensor]] + forward_context: Optional[SubsetForwardContext] + + +def run_subset_stage( + looper: 'ModuleLooper', + *, + processor: LoopProcessor, + module: torch.nn.Module, + layer_inputs: List[List[torch.Tensor]], + layer_input_kwargs: List[Dict[str, torch.Tensor]], + position_ids: List[torch.Tensor], + attention_masks: List[torch.Tensor], + cur_layer_device: torch.device, + is_lm_head_module: bool, + layer_descriptor: str, + layer_title: str, + layer_index: int, + layers_prefix: Optional[str], + subset_names: List[str], + subset_index: int, + subset_total: int, + full, + fail_safe: bool, + shared_kv_cache_dict: Dict[int, torch.Tensor], + pb, + log=None, + region_timer=None, +) -> SubsetStageResult: + """Process a single subset of modules within the layer quantization loop.""" + logger = log or setup_logger() + + subset = looper.crate_named_modules( + full=full, + is_lm_head_module=is_lm_head_module, + layer_index=layer_index, + layers_prefix=layers_prefix, + names=subset_names, + processor=processor, + fail_safe=fail_safe, + ) + + if len(subset) == 0: + return SubsetStageResult(processed_subset={}, layer_inputs=layer_inputs, forward_context=None) + + moe_group_keys_all: List[str] = [] + forward_device_map: Dict[str, torch.device] = {} + subset_forward_serial = False + + attention_subset = bool(subset) and all( + looper._is_attention_module_name(name) for name in subset + ) + + moe_group_key_by_name: Dict[str, Optional[str]] = { + name: looper._extract_moe_group_key(name) + for name in subset + } + moe_module_names = [ + name for name, group_key in moe_group_key_by_name.items() + if group_key is not None + ] + moe_modules_set = set(moe_module_names) + is_moe_subset = len(moe_module_names) >= looper._moe_subset_threshold + + if is_moe_subset: + expert_groups: Dict[str, List[str]] = {} + combined_names: List[str] = list(subset.keys()) + if full is not None: + for candidate in full.keys(): + if candidate not in subset: + combined_names.append(candidate) + + for sub_name in combined_names: + group_key = looper._extract_moe_group_key(sub_name) + if group_key is None: + continue + expert_groups.setdefault(group_key, []).append(sub_name) + + moe_group_keys_all = list(expert_groups.keys()) + + for name, named_module in subset.items(): + setattr(named_module, "moe_enabled", name in moe_modules_set) + + if looper._vram_strategy == VRAMStrategy.BALANCED: + devices = [ + dev for dev in looper._quant_devices + if dev is not None and getattr(dev, "type", None) != "cpu" + ] + if len(devices) > 1 and expert_groups: + assignable_group_keys: List[str] = [] + for group_key, module_names in expert_groups.items(): + suffixes = {name.rsplit(".", 1)[-1] for name in module_names} + if {"gate_proj", "up_proj"}.issubset(suffixes): + assignable_group_keys.append(group_key) + + if 
assignable_group_keys: + groups_per_device = max( + math.ceil(len(assignable_group_keys) / len(devices)), 1 + ) + for group_index, group_key in enumerate(assignable_group_keys): + device_idx = min(group_index // groups_per_device, len(devices) - 1) + target_device = devices[device_idx] + for module_name in expert_groups[group_key]: + forward_device_map[module_name] = target_device + + subset_forward_serial = looper._vram_strategy == VRAMStrategy.BALANCED + if subset_forward_serial: + active_group_count = len(moe_group_keys_all) + if active_group_count == 0: + subset_forward_serial = False + elif attention_subset and active_group_count <= looper._moe_subset_threshold: + subset_forward_serial = False + else: + for named_module in subset.values(): + setattr(named_module, "moe_enabled", False) + + handle = [] + batch_count = looper._resolve_batch_total( + getattr(processor, "num_batches", None), + layer_inputs, + ) + forward_row_counts = list(looper._collect_row_counts(layer_inputs)) + if not forward_row_counts and batch_count > 0: + forward_row_counts = [1] * batch_count + if len(forward_row_counts) > batch_count: + forward_row_counts = forward_row_counts[:batch_count] + forward_total_rows = sum(forward_row_counts) if forward_row_counts else batch_count + forward_total_rows = max(forward_total_rows, 1) + if len(forward_row_counts) < batch_count: + forward_row_counts.extend([1] * (batch_count - len(forward_row_counts))) + + subset_size = len(subset) + for idx, (name, m) in enumerate(subset.items()): + is_last = (idx == subset_size - 1) + hook_source = getattr(m, "full_name", None) + if hook_source is None: + hook_source = getattr(m, "name", name) + if hook_source is None: + hook_source = str(name) + + if hasattr(subset[name], 'forward_hook'): + original_hook = processor.pre_process_fwd_hook(name) + subset[name].forward_hook = looper._masked_hook_wrapper(processor, original_hook, hook_source) + if is_last and processor.fwd_after_process: + subset[name].forward_hook_last = True + else: + original_hook = processor.pre_process_fwd_hook(name) + handle.append(subset[name].register_forward_hook( + looper._masked_hook_wrapper(processor, original_hook, hook_source) + )) + + fwd_start = time.perf_counter() + forward_source = f"{layer_descriptor}:subset{subset_index + 1}/{subset_total}" + + need_outputs = not processor.fwd_after_process + reuse_kv = bool(getattr(module, "reuse_kv", False)) + forward_msg = ( + "Forward: " + f"Layer=`{layer_descriptor}`, subset={subset_index + 1}/{subset_total}, " + f"batches={batch_count}" + ) + forward_pb = ( + logger.pb(range(forward_total_rows)) + .manual() + .set(show_left_steps=False) + ) + forward_pb.title(forward_msg).subtitle( + f"Row 0/{forward_total_rows}" + ).draw() + + previous_forward_devices: Dict[str, torch.device] = {} + preserve_devices = bool(forward_device_map) + if forward_device_map: + previous_forward_devices = looper._apply_forward_device_overrides( + subset, + forward_device_map, + fallback_modules=full, + ) + + try: + forward_outputs = looper._run_forward_batches( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=need_outputs, + reuse_kv=reuse_kv, + progress_pb=forward_pb, + progress_title=forward_msg, + progress_stage="Forward", + 
progress_rows_per_batch=forward_row_counts, + progress_total_rows=forward_total_rows, + force_serial=subset_forward_serial, + preserve_module_devices=preserve_devices, + ) + finally: + if forward_device_map: + looper._restore_forward_device_overrides( + subset, + previous_forward_devices, + fallback_modules=full, + ) + if forward_pb is not None: + forward_pb.close() + if need_outputs: + processor.receive_layer_inputs(forward_outputs) + layer_inputs = processor.inputs_cache.layer_inputs + del forward_outputs + + fwd_time = time.perf_counter() - fwd_start + processor.set_fwd_time(fwd_time) + if region_timer is not None: + region_timer.record( + "pre_quant_forward", + fwd_time, + source=forward_source, + ) + + pb.title(layer_title).subtitle("").draw() + + for h in handle: + h.remove() + + for name in subset: + if hasattr(subset[name], 'forward_hook'): + subset[name].forward_hook = None + subset[name].forward_hook_last = False + + moe_skip_modules = [] + if isinstance(processor, GPTQProcessor): + for name in subset: + if processor.tasks[name].fwd_counter == 0: + logger.error(f"`{name}` was not invoked, if it is a MoE module, it may lack sufficient calibration data routed to it.") + moe_skip_modules.append(name) + + if not fail_safe: + for name in moe_skip_modules: + subset.pop(name) + task_map = getattr(processor, "tasks", None) + if task_map is not None: + task_map.pop(name, None) + + quant_target_devices: Dict[str, torch.device] = {} + for name, named_module in subset.items(): + task_map = getattr(processor, "tasks", None) + has_task = bool(task_map and task_map.get(name) is not None) + + if has_task: + target_device = looper._prepare_named_module_for_quantization( + processor=processor, + named_module=named_module, + fallback_device=cur_layer_device, + ) + else: + target_device = get_device(named_module.module) + setattr(named_module, "target_device", target_device) + setattr(named_module.module, "target_device", target_device) + + quant_target_devices[name] = target_device + + processed_subset: Dict[str, NamedModule] = {} + futures = [] + + @torch.inference_mode() + def _process_on_worker( + proc: LoopProcessor, + nm: NamedModule, + expected_device: torch.device, + ): + module_label = getattr(nm, "full_name", getattr(nm, "name", repr(nm))) + module_ref = nm.module if isinstance(nm, NamedModule) else nm + module_weight = getattr(module_ref, "weight", None) + if module_weight is not None and expected_device is not None: + target_device = expected_device if isinstance(expected_device, torch.device) else torch.device(expected_device) + actual_device = get_device(module_weight) + assert actual_device == target_device, ( + f"Device mismatch for '{module_label}' process task: " + f"module weight on {actual_device}, thread target {target_device}." 
+ ) + + timer = getattr(looper.gptq_model, "quant_region_timer", None) + start = time.perf_counter() if timer else None + try: + proc.process(module=nm) + finally: + if timer is not None and start is not None: + timer.record( + "process_quant", + time.perf_counter() - start, + source=module_label, + ) + return nm.name, nm + + for name, named_module in subset.items(): + tgt_dev = quant_target_devices.get(name, cur_layer_device) + futures.append( + DEVICE_THREAD_POOL.submit(tgt_dev, _process_on_worker, processor, named_module, tgt_dev) + ) + + for fut in futures: + name, named_module = fut.result() + processed_subset[name] = named_module + torch_sync() + + context = SubsetForwardContext( + subset=subset, + forward_device_map=forward_device_map, + subset_forward_serial=subset_forward_serial, + subset_total=subset_total, + subset_index=subset_index, + ) + + return SubsetStageResult( + processed_subset=processed_subset, + layer_inputs=layer_inputs, + forward_context=context, + ) diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index 9af1cca4c..faa6793b2 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -13,15 +13,15 @@ log = setup_logger() ASCII_LOGO = r""" -_____/\\\\\\\\\\\\__/\\\\\\\\\\\\\____/\\\\\\\\\\\\\\\______________________/\\\________/\\\\____________/\\\\_______________________/\\\__________________/\\\\\\____ - ___/\\\//////////__\/\\\/////////\\\_\///////\\\/////____________________/\\\\/\\\\____\/\\\\\\________/\\\\\\______________________\/\\\_________________\////\\\____ - __/\\\_____________\/\\\_______\/\\\_______\/\\\_______________________/\\\//\////\\\__\/\\\//\\\____/\\\//\\\______________________\/\\\____________________\/\\\____ - _\/\\\____/\\\\\\\_\/\\\\\\\\\\\\\/________\/\\\________/\\\\\\\\\\\__/\\\______\//\\\_\/\\\\///\\\/\\\/_\/\\\_____/\\\\\___________\/\\\______/\\\\\\\\_____\/\\\____ - _\/\\\___\/////\\\_\/\\\/////////__________\/\\\_______\///////////__\//\\\______/\\\__\/\\\__\///\\\/___\/\\\___/\\\///\\\____/\\\\\\\\\____/\\\/////\\\____\/\\\____ - _\/\\\_______\/\\\_\/\\\___________________\/\\\______________________\///\\\\/\\\\/___\/\\\____\///_____\/\\\__/\\\__\//\\\__/\\\////\\\___/\\\\\\\\\\\_____\/\\\____ - _\/\\\_______\/\\\_\/\\\___________________\/\\\________________________\////\\\//_____\/\\\_____________\/\\\_\//\\\__/\\\__\/\\\__\/\\\__\//\\///////______\/\\\____ - _\//\\\\\\\\\\\\/__\/\\\___________________\/\\\___________________________\///\\\\\\__\/\\\_____________\/\\\__\///\\\\\/___\//\\\\\\\/\\__\//\\\\\\\\\\__/\\\\\\\\\_ - __\////////////____\///____________________\///______________________________\//////___\///______________\///_____\/////______\///////\//____\//////////__\/////////__ +_____/\\\\\\\\\\\\__/\\\\\\\\\\\\\____/\\\\\\\\\\\\\\\______________________/\\\________/\\\\____________/\\\\_______________________/\\\__________________/\\\\\\____ + ___/\\\//////////__\/\\\/////////\\\_\///////\\\/////____________________/\\\\/\\\\____\/\\\\\\________/\\\\\\______________________\/\\\_________________\////\\\____ + __/\\\_____________\/\\\_______\/\\\_______\/\\\_______________________/\\\//\////\\\__\/\\\//\\\____/\\\//\\\______________________\/\\\____________________\/\\\____ + _\/\\\____/\\\\\\\_\/\\\\\\\\\\\\\/________\/\\\________/\\\\\\\\\\\__/\\\______\//\\\_\/\\\\///\\\/\\\/_\/\\\_____/\\\\\___________\/\\\______/\\\\\\\\_____\/\\\____ + 
_\/\\\___\/////\\\_\/\\\/////////__________\/\\\_______\///////////__\//\\\______/\\\__\/\\\__\///\\\/___\/\\\___/\\\///\\\____/\\\\\\\\\____/\\\/////\\\____\/\\\____ + _\/\\\_______\/\\\_\/\\\___________________\/\\\______________________\///\\\\/\\\\/___\/\\\____\///_____\/\\\__/\\\__\//\\\__/\\\////\\\___/\\\\\\\\\\\_____\/\\\____ + _\/\\\_______\/\\\_\/\\\___________________\/\\\________________________\////\\\//_____\/\\\_____________\/\\\_\//\\\__/\\\__\/\\\__\/\\\__\//\\///////______\/\\\____ + _\//\\\\\\\\\\\\/__\/\\\___________________\/\\\___________________________\///\\\\\\__\/\\\_____________\/\\\__\///\\\\\/___\//\\\\\\\/\\__\//\\\\\\\\\\__/\\\\\\\\\_ + __\////////////____\///____________________\///______________________________\//////___\///______________\///_____\/////______\///////\//____\//////////__\/////////__ """ # if not os.environ.get("PYTHON_GIL", None): diff --git a/gptqmodel/models/definitions/glm4_moe.py b/gptqmodel/models/definitions/glm4_moe.py index 96c876b8f..9cc91d359 100644 --- a/gptqmodel/models/definitions/glm4_moe.py +++ b/gptqmodel/models/definitions/glm4_moe.py @@ -4,7 +4,6 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from ..base import BaseQModel -from ...quantization.config import VRAMStrategy class GLM4MoEGPTQ(BaseQModel): diff --git a/gptqmodel/models/definitions/qwen3_moe.py b/gptqmodel/models/definitions/qwen3_moe.py index ca8d59149..750f5f62e 100644 --- a/gptqmodel/models/definitions/qwen3_moe.py +++ b/gptqmodel/models/definitions/qwen3_moe.py @@ -4,7 +4,6 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from ...quantization import METHOD -from ...quantization.config import VRAMStrategy from ..base import BaseQModel diff --git a/gptqmodel/models/definitions/qwen3_next.py b/gptqmodel/models/definitions/qwen3_next.py index b2202d9d8..c1e571ee4 100644 --- a/gptqmodel/models/definitions/qwen3_next.py +++ b/gptqmodel/models/definitions/qwen3_next.py @@ -4,7 +4,6 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from ...quantization import METHOD -from ...quantization.config import VRAMStrategy from ..base import BaseQModel diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 1dfbb9946..58cad6730 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -586,7 +586,7 @@ def build_layerwise_device_map( device_ids = list(range(num_gpus)) device_map: Dict[str, str] = {} mod2name = {m: n for n, m in model.named_modules()} - + if device == DEVICE.CUDA: if torch.cuda.is_available(): device_strs = [f"cuda:{i}" for i in range(num_gpus)] @@ -599,7 +599,7 @@ def build_layerwise_device_map( raise RuntimeError("XPU is not available") else: device_strs = ["cpu"] * num_gpus - + def assign(mod, device_id): if mod is None: return diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index 53b815b20..6ad7aa6c8 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -16,8 +16,8 @@ import pcre as re import torch import transformers -from safetensors.torch import save_file from safetensors import safe_open +from safetensors.torch import save_file from transformers import AutoConfig, PreTrainedTokenizerFast, ProcessorMixin from transformers.modeling_utils import no_init_weights from transformers.models.auto.tokenization_auto import get_tokenizer_config diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py index db9adadae..f5f455c18 100644 --- a/gptqmodel/nn_modules/qlinear/bitblas.py +++ b/gptqmodel/nn_modules/qlinear/bitblas.py @@ -6,9 +6,9 
@@ import ctypes import os +import sys from dataclasses import dataclass from pathlib import Path -import sys from typing import List, Optional, Tuple, Union import torch diff --git a/gptqmodel/utils/colors.py b/gptqmodel/utils/colors.py index f5822bd18..d6f43ad75 100644 --- a/gptqmodel/utils/colors.py +++ b/gptqmodel/utils/colors.py @@ -10,6 +10,7 @@ from enum import Enum from typing import Optional, Union + ANSI_RESET = "\033[0m" diff --git a/gptqmodel/utils/eval.py b/gptqmodel/utils/eval.py index 2a6fa3325..fe3bdf1b5 100644 --- a/gptqmodel/utils/eval.py +++ b/gptqmodel/utils/eval.py @@ -24,6 +24,7 @@ class LM_EVAL(str, Enum): GSM8K_PLATINUM_COT = "gsm8k_platinum_cot" HELLASWAG = "hellaswag" MMLU = "mmlu" + MMLU_STEM = "mmlu_stem" GPQA = "gpqa" ARC_EASY = "arc_easy" BOOLQ = "boolq" @@ -179,4 +180,3 @@ def evalplus_make_table(results): print("|-------------|------------|--------------------|") for task, metrics in results.items(): print(f"| {task} | {metrics['base tests']} | {metrics['base + extra tests']} |") - diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index 2fb476871..cf5d33e65 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -1631,7 +1631,7 @@ def _on_task_finished(self, key: str) -> None: if not hasattr(self, "_gc_pending_physical"): self._gc_pending_physical = {} elif not isinstance(self._gc_pending_physical, dict): - self._gc_pending_physical = {k: 1 for k in self._gc_pending_physical} + self._gc_pending_physical = dict.fromkeys(self._gc_pending_physical, 1) if not hasattr(self, "_last_gc_done_physical"): self._last_gc_done_physical = {} if not hasattr(self, "_physical_children"): @@ -2092,7 +2092,7 @@ def _janitor_loop(self): with self._stats_lock: pending_map = self._gc_pending_physical if not isinstance(pending_map, dict): - pending_map = {k: 1 for k in pending_map} + pending_map = dict.fromkeys(pending_map, 1) self._gc_pending_physical = pending_map for key in processed_devices: pending_map.pop(key, None) diff --git a/setup.py b/setup.py index e3589f4ec..10c8ac90b 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ from setuptools import find_namespace_packages, find_packages, setup from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel + CUTLASS_VERSION = "3.5.0" CUTLASS_RELEASE_URL = f"https://github.com/NVIDIA/cutlass/archive/refs/tags/v{CUTLASS_VERSION}.tar.gz" diff --git a/tests/models/model_test.py b/tests/models/model_test.py index da834c62b..6ee99e2ad 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -81,7 +81,6 @@ class ModelTest(unittest.TestCase): VRAM_STRATEGY = VRAMStrategy.EXCLUSIVE TRUST_REMOTE_CODE = False - APPLY_CHAT_TEMPLATE = False TORCH_DTYPE = "auto" EVAL_BATCH_SIZE = "auto" QUANT_BATCH_SIZE = 1 @@ -191,6 +190,9 @@ def _legacy_arc_tasks(self): lookup = getattr(self, "_resolved_task_lookup", None) if isinstance(lookup, dict): lookup[normalized] = EVAL.LM_EVAL.ARC_CHALLENGE + chat_lookup = getattr(self, "_task_chat_template", None) + if isinstance(chat_lookup, dict): + chat_lookup[normalized] = False return baselines def _normalize_metric_spec(self, spec): @@ -226,18 +228,30 @@ def _normalize_metric_spec(self, spec): def get_eval_tasks(self): self._resolved_task_lookup = {} + self._task_chat_template = {} if self.EVAL_TASKS: baselines = {} for task, metrics in self.EVAL_TASKS.items(): resolved_task = self._resolve_task_enum(task) normalized_task = self._normalize_task_identifier(resolved_task) self._resolved_task_lookup[normalized_task] = resolved_task + + 
metrics_dict = dict(metrics or {}) + chat_template = bool(metrics_dict.pop("chat_template", False)) + self._task_chat_template[normalized_task] = chat_template + baselines[normalized_task] = { metric_name: self._normalize_metric_spec(spec) - for metric_name, spec in metrics.items() + for metric_name, spec in metrics_dict.items() } return baselines - return self._legacy_arc_tasks() + + baselines = self._legacy_arc_tasks() + if isinstance(baselines, dict): + for task_name in baselines.keys(): + if task_name not in self._task_chat_template: + self._task_chat_template[task_name] = False + return baselines @staticmethod def _flatten_task_metrics(task_results): @@ -349,7 +363,6 @@ def run_eval_tasks(self, model, backend, trust_remote_code=False): try: task_results = self.lm_eval( model=model, - apply_chat_template=self.APPLY_CHAT_TEMPLATE, trust_remote_code=self.TRUST_REMOTE_CODE, delete_quantized_model=False, ) @@ -920,7 +933,7 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa return model - def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, delete_quantized_model=False, extra_args:dict=None): + def lm_eval(self, model, trust_remote_code=False, delete_quantized_model=False, extra_args:dict=None): try: task_names = self._normalize_task_list() aggregated_results = {} @@ -948,6 +961,8 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del task_groups = EVAL.get_task_groups_from_tasks(task_names) + chat_template_lookup = getattr(self, "_task_chat_template", {}) or {} + for framework, tasks in task_groups.items(): active_backend = self._current_load_backend() log.info(f"TEST: EVAL starting: backend = {active_backend.name}") @@ -973,37 +988,45 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del resolved_lookup[normalized_task] = original_task eval_tasks.append(original_task) - results = GPTQModel.eval( - model_or_id_or_path=eval_target, - llm_backend="vllm" if self.USE_VLLM else "gptqmodel", - model_args=model_args, - output_path=tmp_dir, - backend=active_backend, - framework=framework, - tasks=eval_tasks, - apply_chat_template=apply_chat_template, - trust_remote_code=trust_remote_code, - batch_size=self.EVAL_BATCH_SIZE, - gen_kwargs="temperature=0.0,top_k=50", - random_seed=RAND_SEED, - task_manager=TaskManager(include_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "../tasks"), include_defaults=False) - ) - - print('--------Eval Result---------') - print(make_table(results)) - if "groups" in results: - print(make_table(results, "groups")) - print('--------Eval Result End---------') - for task_name in eval_tasks: - normalized_task_name = self._normalize_task_identifier(task_name) - metrics = results["results"].get(normalized_task_name, {}) - filtered_metrics = { - metric: value - for metric, value in metrics.items() - if metric != "alias" and "stderr" not in metric - } - aggregated_results[normalized_task_name] = filtered_metrics - print({normalized_task_name: filtered_metrics}) + grouped_tasks: Dict[bool, List] = {} + for task in eval_tasks: + normalized_name = self._normalize_task_identifier(task) + apply_chat = bool(chat_template_lookup.get(normalized_name, False)) + grouped_tasks.setdefault(apply_chat, []).append(task) + + for apply_chat_template, grouped in grouped_tasks.items(): + results = GPTQModel.eval( + model_or_id_or_path=eval_target, + llm_backend="vllm" if self.USE_VLLM else "gptqmodel", + model_args=model_args, + output_path=tmp_dir, + 
backend=active_backend, + framework=framework, + tasks=grouped, + apply_chat_template=apply_chat_template, + trust_remote_code=trust_remote_code, + batch_size=self.EVAL_BATCH_SIZE, + gen_kwargs="temperature=0.0,top_k=50", + random_seed=RAND_SEED, + task_manager=TaskManager(include_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "../tasks"), include_defaults=False) + ) + + print('--------Eval Result---------') + print(make_table(results)) + if "groups" in results: + print(make_table(results, "groups")) + print('--------Eval Result End---------') + + for task_name in grouped: + normalized_task_name = self._normalize_task_identifier(task_name) + metrics = results["results"].get(normalized_task_name, {}) + filtered_metrics = { + metric: value + for metric, value in metrics.items() + if metric != "alias" and "stderr" not in metric + } + aggregated_results[normalized_task_name] = filtered_metrics + print({normalized_task_name: filtered_metrics}) self._cleanup_quantized_model(model, enabled=delete_quantized_model) return aggregated_results @@ -1020,9 +1043,9 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del if int(self.EVAL_BATCH_SIZE) > 0: self.lm_eval(model=model, - apply_chat_template=apply_chat_template, trust_remote_code=trust_remote_code, - delete_quantized_model=delete_quantized_model) + delete_quantized_model=delete_quantized_model, + extra_args=extra_args) print(f"set batch size to {self.EVAL_BATCH_SIZE}, passed") else: print(f"set batch size to {self.EVAL_BATCH_SIZE}, failed") @@ -1054,7 +1077,6 @@ def quant_lm_eval(self): else: task_results = self.lm_eval( model=self.SAVE_PATH if self.SAVE_PATH else self.model, - apply_chat_template=self.APPLY_CHAT_TEMPLATE, trust_remote_code=self.TRUST_REMOTE_CODE, delete_quantized_model=self.DELETE_QUANTIZED_MODEL, ) diff --git a/tests/models/test_act_group_aware.py b/tests/models/test_act_group_aware.py index 6263b2882..dffcc77a0 100644 --- a/tests/models/test_act_group_aware.py +++ b/tests/models/test_act_group_aware.py @@ -12,11 +12,11 @@ class TestHybridActOrder(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.3140, "floor_pct": 0.05}, "acc_norm": {"value": 0.3439, "floor_pct": 0.05}, }, } - APPLY_CHAT_TEMPLATE = True V2 = False ACT_GROUP_AWARE = True diff --git a/tests/models/test_apertus.py b/tests/models/test_apertus.py index a842fced4..9554d94bb 100644 --- a/tests/models/test_apertus.py +++ b/tests/models/test_apertus.py @@ -13,12 +13,12 @@ class TestApertus(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Apertus-8B-Instruct-2509/" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.5145, "floor_pct": 0.2}, "acc_norm": {"value": 0.5256, "floor_pct": 0.2}, }, } TRUST_REMOTE_CODE = False - APPLY_CHAT_TEMPLATE = True EVAL_BATCH_SIZE = 6 LOAD_BACKEND = BACKEND.TORCH diff --git a/tests/models/test_deepseekv2_lite.py b/tests/models/test_deepseekv2_lite.py index e6f715476..5a7a7f957 100644 --- a/tests/models/test_deepseekv2_lite.py +++ b/tests/models/test_deepseekv2_lite.py @@ -5,15 +5,22 @@ from model_test import ModelTest +from gptqmodel.utils.eval import EVAL + class TestDeepseekV2Lite(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/DeepSeek-Coder-V2-Lite-Instruct" # "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" NATIVE_ARC_CHALLENGE_ACC = 0.4753 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4855 - APPLY_CHAT_TEMPLATE = 
True TRUST_REMOTE_CODE = True + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, + "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, + "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, + }, + } def test_deepseekv2lite(self): self.quant_lm_eval() - diff --git a/tests/models/test_dream.py b/tests/models/test_dream.py index 7af5f7ebc..b53e1e600 100644 --- a/tests/models/test_dream.py +++ b/tests/models/test_dream.py @@ -12,11 +12,11 @@ class TestDream(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Dream-v0-Instruct-7B" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.3567, "floor_pct": 0.36}, "acc_norm": {"value": 0.3805, "floor_pct": 0.36}, }, } - APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 1 BITS = 8 diff --git a/tests/models/test_falcon.py b/tests/models/test_falcon.py index 07c3392f8..3745ebf15 100644 --- a/tests/models/test_falcon.py +++ b/tests/models/test_falcon.py @@ -11,11 +11,11 @@ class TestFalcon(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/falcon-7b-instruct" # "tiiuae/falcon-7b-instruct" - APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = False TORCH_DTYPE = torch.float16 EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.3993, "floor_pct": 0.52}, "acc_norm": {"value": 0.4292, "floor_pct": 0.52}, }, diff --git a/tests/models/test_glm.py b/tests/models/test_glm.py index 27d8139c7..db9aaddc0 100644 --- a/tests/models/test_glm.py +++ b/tests/models/test_glm.py @@ -12,7 +12,7 @@ # |--------------------------------|----------| # | arc_challenge :: acc,none | 0.5154 | # | arc_challenge :: acc_norm,none | 0.535 | -# | mmlu :: acc,none | 0.6325 | +# | mmlu_stem :: acc,none | 0.6325 | class TestGlm(ModelTest): GROUP_SIZE = 32 # real: THUDM/glm-4-9b-chat-hf @@ -22,7 +22,7 @@ class TestGlm(ModelTest): "acc": {"value": 0.5154, "floor_pct": 0.04}, "acc_norm": {"value": 0.5350, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU: { + EVAL.LM_EVAL.MMLU_STEM: { "acc": {"value": 0.6325, "floor_pct": 0.04}, }, } diff --git a/tests/models/test_glm4_moe.py b/tests/models/test_glm4_moe.py index bf2e32561..02870bb89 100644 --- a/tests/models/test_glm4_moe.py +++ b/tests/models/test_glm4_moe.py @@ -18,10 +18,9 @@ class TestGlm4Moe(ModelTest): "acc": {"value": 0.5026, "floor_pct": 0.04}, "acc_norm": {"value": 0.5171, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU: { + EVAL.LM_EVAL.MMLU_STEM: { "acc": {"value": 0.6362, "floor_pct": 0.04}, }, } def test_glm4moe(self): self.quant_lm_eval() - diff --git a/tests/models/test_gpt_oss.py b/tests/models/test_gpt_oss.py index d7e3f99f2..75ea1c4d2 100644 --- a/tests/models/test_gpt_oss.py +++ b/tests/models/test_gpt_oss.py @@ -12,12 +12,12 @@ class TestGPTOSS(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/gpt-oss-20b-BF16/" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": False, "acc": {"value": 0.4411, "floor_pct": 0.2}, "acc_norm": {"value": 0.4718, "floor_pct": 0.2}, }, } TRUST_REMOTE_CODE = False - APPLY_CHAT_TEMPLATE = False EVAL_BATCH_SIZE = 6 USE_VLLM = False diff --git a/tests/models/test_granite.py b/tests/models/test_granite.py index 2b22089eb..89e932da3 100644 --- a/tests/models/test_granite.py +++ b/tests/models/test_granite.py @@ -10,10 +10,10 @@ class TestGranite(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/granite-3.0-2b-instruct" # "ibm-granite/granite-3.0-2b-instruct" - APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = True EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": 
{"value": 0.4505, "floor_pct": 0.2}, "acc_norm": {"value": 0.4770, "floor_pct": 0.2}, }, diff --git a/tests/models/test_hymba.py b/tests/models/test_hymba.py index e4f1d51df..455386d14 100644 --- a/tests/models/test_hymba.py +++ b/tests/models/test_hymba.py @@ -12,13 +12,13 @@ class TestHymba(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Hymba-1.5B-Instruct/" # "baichuan-inc/Baichuan2-7B-Chat" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.2073, "floor_pct": 0.75}, "acc_norm": {"value": 0.2713, "floor_pct": 0.75}, }, } MODEL_MAX_LEN = 8192 TRUST_REMOTE_CODE = True - APPLY_CHAT_TEMPLATE = True # Hymba currently only supports a batch size of 1. # See https://huggingface.co/nvidia/Hymba-1.5B-Instruct EVAL_BATCH_SIZE = 1 diff --git a/tests/models/test_internlm2_5.py b/tests/models/test_internlm2_5.py index a1da5dc24..2cef74a66 100644 --- a/tests/models/test_internlm2_5.py +++ b/tests/models/test_internlm2_5.py @@ -5,15 +5,23 @@ from model_test import ModelTest +from gptqmodel.utils.eval import EVAL + class TestInternlm2_5(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/internlm2_5-1_8b-chat" # "internlm/internlm2_5-1_8b-chat" NATIVE_ARC_CHALLENGE_ACC = 0.3217 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3575 - APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 USE_VLLM = False + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, + "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, + "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, + }, + } def test_internlm2_5(self): @@ -21,4 +29,3 @@ def test_internlm2_5(self): self.quant_lm_eval() - diff --git a/tests/models/test_ling.py b/tests/models/test_ling.py index 1af8547b4..5303f5e06 100644 --- a/tests/models/test_ling.py +++ b/tests/models/test_ling.py @@ -12,12 +12,12 @@ class TestLing(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Ling-mini-2.0/" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.5009, "floor_pct": 0.2}, "acc_norm": {"value": 0.5137, "floor_pct": 0.2}, }, } TRUST_REMOTE_CODE = True - APPLY_CHAT_TEMPLATE = True # EVAL_BATCH_SIZE = 6 V2 = False DEBUG = True diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index 5860f3133..c4a315c42 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -18,32 +18,38 @@ # |--------------------------------|----------| # | arc_challenge :: acc,none | 0.3174 | # | arc_challenge :: acc_norm,none | 0.3601 | -# | mmlu :: acc,none | 0.3186 | +# | mmlu_stem :: acc,none | 0.3186 | class TestLlama3_2(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" + EVAL_BATCH_SIZE = 64 EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": { - "value": 0.3174, + "value": 0.3191, "floor_pct": 0.04, "ceil_pct": 0.10, }, "acc_norm": { - "value": 0.3601, + "value": 0.3507, "floor_pct": 0.04, "ceil_pct": 0.10, }, }, - EVAL.LM_EVAL.MMLU: { + EVAL.LM_EVAL.MMLU_STEM: { + "chat_template": False, "acc": { - "value": 0.3186, + "value": 0.2978, "floor_pct": 0.04, "ceil_pct": 0.10, }, }, } - APPLY_CHAT_TEMPLATE = True - QUANT_BATCH_SIZE = 4 + + # llama 3.2 Instruct requires chat = true to have normal ARC scores + # mmlu requires chat = false + # APPLY_CHAT_TEMPLATE = True + # QUANT_BATCH_SIZE = 4 # EORA = Lora( # # for quant, path is save path. 
for load, it is loading path diff --git a/tests/models/test_llama3_2_awq.py b/tests/models/test_llama3_2_awq.py index 14e371d58..380ed3066 100644 --- a/tests/models/test_llama3_2_awq.py +++ b/tests/models/test_llama3_2_awq.py @@ -17,11 +17,11 @@ class TestLlama3_2(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.3234, "floor_pct": 0.36}, "acc_norm": {"value": 0.3524, "floor_pct": 0.36}, }, } - APPLY_CHAT_TEMPLATE = True V2 = False DEBUG = True ACT_GROUP_AWARE = False diff --git a/tests/models/test_llama4.py b/tests/models/test_llama4.py index 26cd5c0ef..8a7a8b1e5 100644 --- a/tests/models/test_llama4.py +++ b/tests/models/test_llama4.py @@ -12,11 +12,11 @@ class TestLlama4(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-4-Scout-17B-16E-Instruct" # "meta-llama/Llama-4-Scout-17B-16E-Instruct" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.3567, "floor_pct": 0.36}, "acc_norm": {"value": 0.3805, "floor_pct": 0.36}, }, } - APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = False def test_llama4(self): diff --git a/tests/models/test_mimo.py b/tests/models/test_mimo.py index 1485cbfc4..3f90d68ea 100644 --- a/tests/models/test_mimo.py +++ b/tests/models/test_mimo.py @@ -12,12 +12,12 @@ class TestMimo(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/MiMo-7B-RL" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.2739, "floor_pct": 0.2}, "acc_norm": {"value": 0.3055, "floor_pct": 0.2}, }, } TRUST_REMOTE_CODE = True - APPLY_CHAT_TEMPLATE = True EVAL_BATCH_SIZE = 6 def test_mimo(self): diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 6e06003ca..016d350d2 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -5,14 +5,22 @@ from model_test import ModelTest +from gptqmodel.utils.eval import EVAL + class TestMistral(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Mistral-7B-Instruct-v0.2" # "mistralai/Mistral-7B-Instruct-v0.2" NATIVE_ARC_CHALLENGE_ACC = 0.5427 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5597 - APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, + "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, + "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, + }, + } def test_mistral(self): self.quant_lm_eval() diff --git a/tests/models/test_mixtral.py b/tests/models/test_mixtral.py index 9e21dadcc..178a0b080 100644 --- a/tests/models/test_mixtral.py +++ b/tests/models/test_mixtral.py @@ -5,14 +5,22 @@ from model_test import ModelTest +from gptqmodel.utils.eval import EVAL + class TestMixtral(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Mixtral-8x7B-Instruct-v0.1" # "mistralai/Mixtral-8x7B-Instruct-v0.1" NATIVE_ARC_CHALLENGE_ACC = 0.5213 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5247 - APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, + "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, + "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, + }, + } def test_mixtral(self): self.quant_lm_eval() diff --git a/tests/models/test_mpt.py b/tests/models/test_mpt.py index c940bc5b0..3f2c7b29b 100644 --- a/tests/models/test_mpt.py +++ b/tests/models/test_mpt.py @@ -5,15 +5,23 @@ from model_test import ModelTest +from gptqmodel.utils.eval import EVAL + 
class TestMpt(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/mpt-7b-instruct" # "mosaicml/mpt-7b-instruct" NATIVE_ARC_CHALLENGE_ACC = 0.4275 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4454 - APPLY_CHAT_TEMPLATE = False TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 USE_FLASH_ATTN = False + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": False, + "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, + "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, + }, + } def test_mpt(self): self.quant_lm_eval() diff --git a/tests/models/test_multi_vs_single_gpu.py b/tests/models/test_multi_vs_single_gpu.py index 2a8d4254d..8394ea610 100644 --- a/tests/models/test_multi_vs_single_gpu.py +++ b/tests/models/test_multi_vs_single_gpu.py @@ -48,11 +48,11 @@ class TestMultiVsSingleGPU(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.3311, "floor_pct": 0.05}, "acc_norm": {"value": 0.3549, "floor_pct": 0.05}, }, } - APPLY_CHAT_TEMPLATE = True V2 = False DEBUG = True ACT_GROUP_AWARE = False diff --git a/tests/models/test_nemotron_ultra.py b/tests/models/test_nemotron_ultra.py index 023a68bc1..926daef59 100644 --- a/tests/models/test_nemotron_ultra.py +++ b/tests/models/test_nemotron_ultra.py @@ -12,11 +12,11 @@ class TestNemotronUltra(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3_1-Nemotron-Ultra-253B-v1" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.3567, "floor_pct": 0.36}, "acc_norm": {"value": 0.3805, "floor_pct": 0.36}, }, } - APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = True def test_nemotron_ultra(self): diff --git a/tests/models/test_ovis2.py b/tests/models/test_ovis2.py index 324e1ccbb..9b223be39 100644 --- a/tests/models/test_ovis2.py +++ b/tests/models/test_ovis2.py @@ -14,7 +14,6 @@ class Test(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Ovis2-1B" TRUST_REMOTE_CODE = True - APPLY_CHAT_TEMPLATE = False EVAL_BATCH_SIZE = 1 def test_ovis(self): diff --git a/tests/models/test_ovis_1_6_llama.py b/tests/models/test_ovis_1_6_llama.py index b899f6704..d681c773a 100644 --- a/tests/models/test_ovis_1_6_llama.py +++ b/tests/models/test_ovis_1_6_llama.py @@ -14,7 +14,6 @@ class TestOvis1_6_Llama(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Ovis1.6-Llama3.2-3B" TRUST_REMOTE_CODE = True - APPLY_CHAT_TEMPLATE = False EVAL_BATCH_SIZE = 1 USE_FLASH_ATTN = False diff --git a/tests/models/test_phi_3.py b/tests/models/test_phi_3.py index a23a1c410..dcbac750f 100644 --- a/tests/models/test_phi_3.py +++ b/tests/models/test_phi_3.py @@ -5,13 +5,21 @@ from model_test import ModelTest +from gptqmodel.utils.eval import EVAL + class TestPhi_3(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Phi-3-mini-4k-instruct" # "microsoft/Phi-3-mini-4k-instruct" NATIVE_ARC_CHALLENGE_ACC = 0.5401 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5674 - APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = True + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, + "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, + "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, + }, + } def test_phi_3(self): self.quant_lm_eval() diff --git a/tests/models/test_phi_3_moe.py b/tests/models/test_phi_3_moe.py index fb53e3669..ceb16af9c 100644 --- a/tests/models/test_phi_3_moe.py +++ b/tests/models/test_phi_3_moe.py @@ -5,13 +5,21 @@ from model_test import ModelTest +from gptqmodel.utils.eval import EVAL + class TestPhi_3(ModelTest): NATIVE_MODEL_ID = 
"/monster/data/model/Phi-3.5-MoE-instruct" # microsoft/Phi-3.5-MoE-instruct NATIVE_ARC_CHALLENGE_ACC = 0.5401 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5674 - APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = True + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, + "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, + "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, + }, + } def test_phi_3(self): self.quant_lm_eval() diff --git a/tests/models/test_phi_4.py b/tests/models/test_phi_4.py index 6084e3d43..0c6a11a13 100644 --- a/tests/models/test_phi_4.py +++ b/tests/models/test_phi_4.py @@ -5,15 +5,23 @@ from model_test import ModelTest +from gptqmodel.utils.eval import EVAL + class TestPhi_4(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Phi-4-multimodal-instruct" # "microsoft/Phi-3-mini-4k-instruct" NATIVE_ARC_CHALLENGE_ACC = 0.5401 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5674 - APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = True USE_FLASH_ATTN = False BATCH_SIZE = 1 + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, + "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, + "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, + }, + } def test_phi_4(self): self.quant_lm_eval() diff --git a/tests/models/test_qwen2_5.py b/tests/models/test_qwen2_5.py index a96f00777..e2ef3c0b8 100644 --- a/tests/models/test_qwen2_5.py +++ b/tests/models/test_qwen2_5.py @@ -12,16 +12,18 @@ # |--------------------------------|----------| # | arc_challenge :: acc,none | 0.2892 | # | arc_challenge :: acc_norm,none | 0.3302 | -# | mmlu :: acc,none | 0.4351 | +# | mmlu_stem :: acc,none | 0.4351 | class TestQwen2_5(ModelTest): GROUP_SIZE = 32 NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct" + EVAL_BATCH_SIZE = 64 EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.2892, "floor_pct": 0.04}, "acc_norm": {"value": 0.3302, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU: { + EVAL.LM_EVAL.MMLU_STEM: { "acc": {"value": 0.4351, "floor_pct": 0.04}, }, } diff --git a/tests/models/test_qwen2_5_omni.py b/tests/models/test_qwen2_5_omni.py index 2da963cf4..12b010154 100644 --- a/tests/models/test_qwen2_5_omni.py +++ b/tests/models/test_qwen2_5_omni.py @@ -15,12 +15,12 @@ class TestQwen2_5_Omni(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-Omni-3B" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.2329, "floor_pct": 0.2}, "acc_norm": {"value": 0.2765, "floor_pct": 0.2}, }, } TRUST_REMOTE_CODE = False - APPLY_CHAT_TEMPLATE = True EVAL_BATCH_SIZE = 6 def test_qwen2_5_omni(self): diff --git a/tests/models/test_qwen2_5_vl.py b/tests/models/test_qwen2_5_vl.py index 65fc80364..60b5d5c51 100644 --- a/tests/models/test_qwen2_5_vl.py +++ b/tests/models/test_qwen2_5_vl.py @@ -13,12 +13,12 @@ class TestQwen2_VL(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-VL-3B-Instruct" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.4309, "floor_pct": 0.2}, "acc_norm": {"value": 0.4113, "floor_pct": 0.2}, }, } TRUST_REMOTE_CODE = False - APPLY_CHAT_TEMPLATE = True EVAL_BATCH_SIZE = 6 def test_qwen2_vl(self): @@ -70,7 +70,6 @@ def test_qwen2_vl(self): self.check_kernel(model, self.KERNEL_INFERENCE) task_results = self.lm_eval(model=model, - apply_chat_template=self.APPLY_CHAT_TEMPLATE, trust_remote_code=self.TRUST_REMOTE_CODE, delete_quantized_model=self.DELETE_QUANTIZED_MODEL) self.check_results(task_results) diff --git a/tests/models/test_qwen2_moe_quant.py 
b/tests/models/test_qwen2_moe_quant.py index 8b81c84e5..316594e6d 100644 --- a/tests/models/test_qwen2_moe_quant.py +++ b/tests/models/test_qwen2_moe_quant.py @@ -12,12 +12,12 @@ class TestQwen2_5_Moe(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen1.5-MoE-A2.7B" # Qwen/Qwen1.5-MoE-A2.7B EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.2739, "floor_pct": 0.2}, "acc_norm": {"value": 0.3055, "floor_pct": 0.2}, }, } TRUST_REMOTE_CODE = False - APPLY_CHAT_TEMPLATE = True EVAL_BATCH_SIZE = 6 def test_qwen2_5(self): diff --git a/tests/models/test_qwen2_vl.py b/tests/models/test_qwen2_vl.py index 9a20ee182..e9d234321 100644 --- a/tests/models/test_qwen2_vl.py +++ b/tests/models/test_qwen2_vl.py @@ -13,12 +13,12 @@ class TestQwen2_VL(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2-VL-2B-Instruct" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.3524, "floor_pct": 0.2}, "acc_norm": {"value": 0.3763, "floor_pct": 0.2}, }, } TRUST_REMOTE_CODE = False - APPLY_CHAT_TEMPLATE = True EVAL_BATCH_SIZE = 6 def test_qwen2_vl(self): @@ -70,7 +70,6 @@ def test_qwen2_vl(self): self.check_kernel(model, self.KERNEL_INFERENCE) task_results = self.lm_eval(model=model, - apply_chat_template=self.APPLY_CHAT_TEMPLATE, trust_remote_code=self.TRUST_REMOTE_CODE, delete_quantized_model=self.DELETE_QUANTIZED_MODEL) self.check_results(task_results) diff --git a/tests/models/test_qwen3_moe.py b/tests/models/test_qwen3_moe.py index 051a0d747..d77262830 100644 --- a/tests/models/test_qwen3_moe.py +++ b/tests/models/test_qwen3_moe.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from gptqmodel.quantization.config import VRAMStrategy from model_test import ModelTest +from gptqmodel.quantization.config import VRAMStrategy from gptqmodel.utils.eval import EVAL diff --git a/tests/models/test_qwen3_next.py b/tests/models/test_qwen3_next.py index 96dee55ca..3d6f75f86 100644 --- a/tests/models/test_qwen3_next.py +++ b/tests/models/test_qwen3_next.py @@ -3,9 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from gptqmodel.quantization.config import VRAMStrategy from model_test import ModelTest +from gptqmodel.quantization.config import VRAMStrategy from gptqmodel.utils.eval import EVAL @@ -13,7 +13,7 @@ # |--------------------------------|----------| # | arc_challenge :: acc,none | 0.6271 | # | arc_challenge :: acc_norm,none | 0.6613 | -# | mmlu :: acc,none | 0.8403 | +# | mmlu_stem :: acc,none | 0.8403 | class TestQwen3Next(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen3-Next-80B-A3B-Instruct" EVAL_TASKS = { @@ -21,7 +21,7 @@ class TestQwen3Next(ModelTest): "acc": {"value": 0.6271, "floor_pct": 0.04}, "acc_norm": {"value": 0.6613, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU: { + EVAL.LM_EVAL.MMLU_STEM: { "acc": {"value": 0.8403, "floor_pct": 0.04}, }, } diff --git a/tests/models/test_seed_oss.py b/tests/models/test_seed_oss.py index fe9933cd4..eb231b621 100644 --- a/tests/models/test_seed_oss.py +++ b/tests/models/test_seed_oss.py @@ -12,12 +12,12 @@ class TestSeedOSS(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Seed-OSS-36B-Instruct/" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.2739, "floor_pct": 0.2}, "acc_norm": {"value": 0.3055, "floor_pct": 0.2}, }, } TRUST_REMOTE_CODE = False - APPLY_CHAT_TEMPLATE = True 
EVAL_BATCH_SIZE = 6 def test_seed_oss(self): diff --git a/tests/models/test_telechat2.py b/tests/models/test_telechat2.py index ca7c396da..0f38add01 100644 --- a/tests/models/test_telechat2.py +++ b/tests/models/test_telechat2.py @@ -4,16 +4,24 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from model_test import ModelTest +from gptqmodel.utils.eval import EVAL + class TestTeleChat_2(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/TeleChat2-7B/" # "Tele-AI/TeleChat2-7B" NATIVE_ARC_CHALLENGE_ACC = 0.3677 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3831 - APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 USE_VLLM = False USE_FLASH_ATTN = False + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, + "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, + "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, + }, + } def test_telechat2(self): diff --git a/tests/models/test_xverse.py b/tests/models/test_xverse.py index 7d77ed9df..0f1cff63d 100644 --- a/tests/models/test_xverse.py +++ b/tests/models/test_xverse.py @@ -12,12 +12,12 @@ class TestXVerse(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/XVERSE-7B-Chat" # "xverse/XVERSE-7B-Chat" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.4198, "floor_pct": 0.2}, "acc_norm": {"value": 0.4044, "floor_pct": 0.2}, }, } TRUST_REMOTE_CODE = True - APPLY_CHAT_TEMPLATE = True EVAL_BATCH_SIZE = 6 USE_VLLM = False USE_FLASH_ATTN = False diff --git a/tests/models/test_yi.py b/tests/models/test_yi.py index 2eda7be3e..4be697f35 100644 --- a/tests/models/test_yi.py +++ b/tests/models/test_yi.py @@ -5,14 +5,22 @@ from model_test import ModelTest +from gptqmodel.utils.eval import EVAL + class TestYi(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Yi-Coder-1.5B-Chat" # "01-ai/Yi-Coder-1.5B-Chat" NATIVE_ARC_CHALLENGE_ACC = 0.2679 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2986 TRUST_REMOTE_CODE = True - APPLY_CHAT_TEMPLATE = True EVAL_BATCH_SIZE = 4 + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, + "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, + "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, + }, + } def test_yi(self): self.quant_lm_eval() diff --git a/tests/test_awq.py b/tests/test_awq.py index 1afa83198..38e6d46ea 100644 --- a/tests/test_awq.py +++ b/tests/test_awq.py @@ -22,7 +22,6 @@ from gptqmodel.nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME from gptqmodel.utils.machete import _validate_machete_device_support, machete_import_exception -from gptqmodel.utils.torch import torch_empty_cache os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" diff --git a/tests/test_benchmark_gar.py b/tests/test_benchmark_gar.py index 4a2f74e51..616a1c5e9 100644 --- a/tests/test_benchmark_gar.py +++ b/tests/test_benchmark_gar.py @@ -6,8 +6,7 @@ import torch from tabulate import tabulate -from gptqmodel.quantization import gar -from gptqmodel.quantization import gar_ref +from gptqmodel.quantization import gar, gar_ref def _benchmark_fn(label, fn, device, warmup_runs=3, measured_runs=10): diff --git a/tests/test_bits_new.py b/tests/test_bits_new.py index 40c9eb4f9..b9951f383 100644 --- a/tests/test_bits_new.py +++ b/tests/test_bits_new.py @@ -54,7 +54,7 @@ def bench(path: str, backend: BACKEND, adapter: Optional[Lora]): bench_result = GPTQModel.eval( model_or_id_or_path=model, framework=EVAL.LM_EVAL, - tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.MMLU], + tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, 
EVAL.LM_EVAL.MMLU_STEM], batch_size=16, ) diff --git a/tests/test_gptqv2.py b/tests/test_gptqv2.py index 284c1dfea..7aed4695d 100644 --- a/tests/test_gptqv2.py +++ b/tests/test_gptqv2.py @@ -12,12 +12,12 @@ class TestQwen2_5_GPTQv2(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.2739, "floor_pct": 0.2}, "acc_norm": {"value": 0.3055, "floor_pct": 0.2}, }, } TRUST_REMOTE_CODE = False - APPLY_CHAT_TEMPLATE = True EVAL_BATCH_SIZE = 6 V2 = True diff --git a/tests/test_lm_head.py b/tests/test_lm_head.py index 98bada2ca..d34eea421 100644 --- a/tests/test_lm_head.py +++ b/tests/test_lm_head.py @@ -40,7 +40,6 @@ def test_eval(self): class TestLmHeadQuant(ModelTest): - APPLY_CHAT_TEMPLATE = True EXPECT_LM_HEAD_LOSS = 0.0094 sample_length = 1024 @@ -57,6 +56,7 @@ def setUpClass(cls): def test_quant_lm_head(self): self.EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.3148464163822526, "floor_pct": 0.2}, "acc_norm": {"value": 0.3310580204778157, "floor_pct": 0.2}, }, @@ -83,7 +83,6 @@ def test_quant_lm_head(self): ) task_results = self.lm_eval(model=model, - apply_chat_template=self.APPLY_CHAT_TEMPLATE, trust_remote_code=self.TRUST_REMOTE_CODE, delete_quantized_model=self.DELETE_QUANTIZED_MODEL) self.check_results(task_results) diff --git a/tests/test_out_of_model_tensor_files.py b/tests/test_out_of_model_tensor_files.py index e0fbfec76..a8241dfc0 100644 --- a/tests/test_out_of_model_tensor_files.py +++ b/tests/test_out_of_model_tensor_files.py @@ -132,7 +132,7 @@ def _fake_get_state_dict_for_save(*_args, **_kwargs): def _fake_streaming_state_dict_to_shards(state_dict, save_dir, model_base_name, single_file_name, metadata, *_args, **_kwargs): file_path = os.path.join(save_dir, single_file_name) save_file(state_dict, file_path, metadata=metadata) - tensor_to_filename = {name: single_file_name for name in state_dict.keys()} + tensor_to_filename = dict.fromkeys(state_dict.keys(), single_file_name) total_size = os.path.getsize(file_path) return [single_file_name], tensor_to_filename, total_size diff --git a/tests/test_post_quant_eora.py b/tests/test_post_quant_eora.py index 6c18300ab..57b379685 100644 --- a/tests/test_post_quant_eora.py +++ b/tests/test_post_quant_eora.py @@ -53,7 +53,7 @@ def bench(path: str, backend: BACKEND, adapter: Optional[Lora]): bench_result = GPTQModel.eval( model_or_id_or_path=model, framework=EVAL.LM_EVAL, - tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.MMLU] + tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.MMLU_STEM] ) del model diff --git a/tests/test_quant_and_eora.py b/tests/test_quant_and_eora.py index 1f49550dc..fb02249bf 100644 --- a/tests/test_quant_and_eora.py +++ b/tests/test_quant_and_eora.py @@ -40,12 +40,12 @@ class Test(ModelTest): EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, "acc": {"value": 0.3183, "floor_pct": 0.05}, "acc_norm": {"value": 0.3404, "floor_pct": 0.05}, }, } - APPLY_CHAT_TEMPLATE = True QUANT_BATCH_SIZE = 4 # V2 = False # DEBUG = True @@ -149,7 +149,7 @@ def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]): tasks=[EVAL.LM_EVAL.ARC_CHALLENGE], apply_chat_template=True, # MMLU is too slow for ci test - # EVAL.LM_EVAL.MMLU + # EVAL.LM_EVAL.MMLU_STEM ) del model diff --git a/tests/test_quant_and_eora_transformers.py b/tests/test_quant_and_eora_transformers.py index 3bf333ae7..fd2a987aa 100644 --- a/tests/test_quant_and_eora_transformers.py +++ 
b/tests/test_quant_and_eora_transformers.py @@ -197,7 +197,7 @@ def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]): bench_result = GPTQModel.eval( model_or_id_or_path=model, framework=EVAL.LM_EVAL, - tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.MMLU], + tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.MMLU_STEM], ) del model diff --git a/tests/test_stage_modules.py b/tests/test_stage_modules.py new file mode 100644 index 000000000..ae17fd8ad --- /dev/null +++ b/tests/test_stage_modules.py @@ -0,0 +1,315 @@ +import threading +import types + +import torch + +from gptqmodel.looper.module_looper import FinalizeProgressInfo, ModuleLooper +from gptqmodel.looper.stage_inputs_capture import StageInputsCapture +from gptqmodel.looper.stage_layer import run_layer_stage +from gptqmodel.looper.stage_subset import SubsetForwardContext, SubsetStageResult + + +class _DummyQModel: + def __init__(self): + self.support_batch_quantize = False + self.quantize_config = types.SimpleNamespace(device=None, vram_strategy="exclusive") + self.layer_callback = None + + +def _make_looper(): + processors = [types.SimpleNamespace(layer_count=0, pb=None)] + return ModuleLooper(model=_DummyQModel(), processors=processors) + + +def test_cache_inputs_delegates_to_stage_capture(monkeypatch): + looper = _make_looper() + sentinel = object() + captured = {} + + class FakeStage: + def __init__(self, looper_arg, logger): + captured["looper"] = looper_arg + captured["logger"] = logger + + def cache_inputs(self, **kwargs): + captured["kwargs"] = kwargs + return sentinel + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.StageInputsCapture", + FakeStage, + ) + + layers = [object()] + data = [{"hidden_states": torch.zeros(1, 2, 2)}] + result = looper.cache_inputs(layers=layers, calibration_data=data, use_cache=False) + + assert result is sentinel + assert captured["looper"] is looper + assert captured["kwargs"]["layers"] == layers + assert captured["kwargs"]["calibration_data"] is data + + +class _TinyLayer(torch.nn.Module): + def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs): + return hidden_states + + +class _TinyModel(torch.nn.Module): + def __init__(self, layer): + super().__init__() + self.layer = layer + self.visual_tokenizer = types.SimpleNamespace(dtype=torch.float32) + + def forward(self, *, hidden_states, attention_mask=None, position_ids=None, use_cache=False, **kwargs): + return self.layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + **kwargs, + ) + + +class _TinyGptqModel: + ATTENTION_MASKS_REQUIRED_FOR_INPUT = False + ATTENTION_MASKS_DTYPE = torch.long + INPUT_EMBEDDING_EXTRA_ARGS = {} + + def __init__(self): + self.layer = _TinyLayer() + self.model = _TinyModel(self.layer) + self.quantize_config = types.SimpleNamespace(device=torch.device("cpu")) + self._hook_started = False + self._hook_finished = False + + def shell_module_materialize(self, target_submodule, device): + target_submodule.to(device) + return target_submodule + + def get_base_modules(self, model): + return [] + + def pre_quantize_generate_hook_start(self): + self._hook_started = True + + def pre_quantize_generate_hook_end(self): + self._hook_finished = True + + +class _TinyLooper: + def __init__(self, gptq_model): + self.gptq_model = gptq_model + + def _batch_row_count(self, batch_inputs): + if not batch_inputs: + return 0 + tensor = batch_inputs[0] + return int(tensor.shape[0]) if tensor.ndim > 0 else int(tensor.numel()) + + +def 
test_stage_inputs_capture_collects_real_inputs(): + gptq_model = _TinyGptqModel() + looper = _TinyLooper(gptq_model) + stage = StageInputsCapture(looper, logger=None) + + hidden = torch.ones(1, 2, 3) + attention = torch.ones(1, 2) + position_ids = torch.arange(2).unsqueeze(0) + extra = torch.tensor([5.0]) + + dataset = [ + { + "hidden_states": hidden.clone(), + "attention_mask": attention.clone(), + "position_ids": position_ids.clone(), + "extra": extra.clone(), + } + ] + + cache = stage.cache_inputs(layers=[gptq_model.layer], calibration_data=dataset, use_cache=False) + + assert len(cache.layer_inputs) == 1 + assert torch.equal(cache.layer_inputs[0][0], hidden) + assert torch.equal(cache.attention_masks[0], attention.long()) + assert torch.equal(cache.position_ids[0], position_ids) + assert torch.equal(cache.layer_input_kwargs[0]["extra"], extra.unsqueeze(0)) + assert gptq_model._hook_started is True + assert gptq_model._hook_finished is True + + +def test_run_layer_stage_invokes_subset_stage(monkeypatch): + calls = [] + + def fake_run_subset_stage(looper, **kwargs): + calls.append(kwargs["subset_index"]) + return SubsetStageResult( + processed_subset={}, + layer_inputs=kwargs["layer_inputs"], + forward_context=SubsetForwardContext( + subset={}, + forward_device_map={}, + subset_forward_serial=False, + subset_total=kwargs["subset_total"], + subset_index=kwargs["subset_index"], + ), + ) + + monkeypatch.setattr("gptqmodel.looper.stage_layer.run_subset_stage", fake_run_subset_stage) + monkeypatch.setattr("gptqmodel.looper.stage_layer.find_modules", lambda *_, **__: {}) + + class DummyPB: + def __init__(self, iterable): + self._iterable = list(iterable) + self.current_iter_step = 0 + + def __iter__(self): + return iter(self._iterable) + + def __len__(self): + return len(self._iterable) + + def manual(self): + return self + + def set(self, **kwargs): + return self + + def title(self, *_): + return self + + def subtitle(self, *_): + return self + + def draw(self): + return self + + def next(self): + return self + + def close(self): + return self + + class DummyLogger: + def pb(self, iterable): + return DummyPB(iterable) + + def info(self, *_, **__): + return None + + def debug(self, *_, **__): + return None + + def warning(self, *_, **__): + return None + + warn = warning + + def error(self, *_, **__): + return None + + class DummyProcessor: + fwd_all_modules_in_single_pass = False + fwd_after_process = False + + def __init__(self): + tensor = torch.zeros(1, 1, 1) + self.inputs_cache = types.SimpleNamespace( + layer_inputs=[[tensor]], + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[], + ) + self.calibration_dataset = [] + self.log = [] + self.tasks = {} + + def collect_memory_info(self, *_): + return None + + def pre_process_fwd_hook(self, *_): + return lambda *a, **k: None + + def process(self, *_, **__): + return None + + def clear_cache_data(self): + return None + + def receive_layer_inputs(self, inputs): + self.inputs_cache.layer_inputs = inputs + + def set_fwd_time(self, *_): + return None + + def name(self): + return "dummy" + + def submodule_finalize(self, *_, **__): + return None + + def finalize(self, *_, **__): + return None + + def log_plotly(self): + return None + + class DummyGptqModel: + def __init__(self): + self.model = torch.nn.Module() + self.quantize_config = types.SimpleNamespace(lm_head=False) + self.lm_head = None + + def pre_quantize(self, module): + return module + + def post_quantize(self, module): + return module + + def 
lm_head_pre_quantize_generate_hook(self, value): + return value + + class DummyLooper: + def __init__(self): + self.gptq_model = DummyGptqModel() + self.processors = [DummyProcessor()] + self._quant_devices = [torch.device("cpu")] + self._module_device_map = {} + self._quant_device_lock = threading.Lock() + self._moe_subset_threshold = 16 + self._vram_strategy = types.SimpleNamespace() + self._layer_events = [] + + def _check_loop_stop(self): + return False + + def _emit_layer_complete(self, *, layer_idx, submodule_finalized, raise_in_place): + self._layer_events.append((layer_idx, submodule_finalized, raise_in_place)) + + def _request_loop_stop(self, exc): + self._stop_exc = exc + + looper = DummyLooper() + processor = looper.processors[0] + pb = DummyPB(range(1)) + processor.layer_count = 1 + processor.pb = pb + + layers = [torch.nn.Identity()] + layer_modules = [["foo"]] + logger = DummyLogger() + + run_layer_stage( + looper, + layers=layers, + layer_modules=layer_modules, + layers_prefix="model.layers", + fail_safe=True, + shared_kv_cache_dict={}, + pb=pb, + layer_count=1, + region_timer=None, + finalize_progress_cls=FinalizeProgressInfo, + logger=logger, + ) + + assert calls == [0]
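
Note on the test-suite convention introduced above: the class-level APPLY_CHAT_TEMPLATE switch is removed and each task inside EVAL_TASKS now carries its own "chat_template" flag. ModelTest.get_eval_tasks() pops that key into a per-task lookup, and lm_eval() groups tasks by the flag so that tasks with different chat-template settings run as separate GPTQModel.eval() calls. A minimal sketch of a test written against this convention follows; the class name, model path, and baseline numbers are placeholders for illustration only, not values from this change set.

from model_test import ModelTest

from gptqmodel.utils.eval import EVAL


class TestExampleModel(ModelTest):
    # Hypothetical model path used only to illustrate the new EVAL_TASKS layout.
    NATIVE_MODEL_ID = "/monster/data/model/Example-Instruct"
    TRUST_REMOTE_CODE = False
    EVAL_BATCH_SIZE = 6
    EVAL_TASKS = {
        # "chat_template" is read per task; tasks whose flags differ are
        # evaluated in separate GPTQModel.eval() invocations.
        EVAL.LM_EVAL.ARC_CHALLENGE: {
            "chat_template": True,
            "acc": {"value": 0.30, "floor_pct": 0.05},       # placeholder baseline
            "acc_norm": {"value": 0.33, "floor_pct": 0.05},  # placeholder baseline
        },
        EVAL.LM_EVAL.MMLU_STEM: {
            "chat_template": False,
            "acc": {"value": 0.30, "floor_pct": 0.05},       # placeholder baseline
        },
    }

    def test_example_model(self):
        self.quant_lm_eval()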