diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py
index 2d961205d..db0100133 100644
--- a/gptqmodel/looper/gptq_processor.py
+++ b/gptqmodel/looper/gptq_processor.py
@@ -185,6 +185,9 @@ def process(
 
         wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize()
 
+        workspace_summary = getattr(g, "_borrow_workspace_last_summary", None)
+        workspace_totals = getattr(g, "_borrow_workspace_totals", None)
+
         module.stream_state_payload_to_cpu(
             {
                 "q_scales": q_scales,
@@ -235,6 +238,25 @@ def process(
             PROCESS_USED_MEMORY: self.device_memory_report(),
         }
 
+        if workspace_summary:
+            requests = int(workspace_summary.get("requests", 0) or 0)
+            if requests:
+                hit_rate = float(workspace_summary.get("hit_rate", 0.0) or 0.0)
+                chunk_rows = workspace_summary.get("chunk_rows")
+                stat["workspace_cache_requests"] = str(requests)
+                stat["workspace_cache_hit_rate"] = f"{hit_rate:.1%}"
+                stat["workspace_stage_dtype"] = workspace_summary.get("staging_dtype", "")
+                if chunk_rows is not None:
+                    stat["workspace_chunk_rows"] = str(chunk_rows)
+
+        if workspace_totals:
+            total_requests = int(workspace_totals.get("requests", 0) or 0)
+            if total_requests:
+                cumulative_hit_rate = (
+                    float(workspace_totals.get("materialized_hits", 0) or 0.0) / total_requests
+                )
+                stat["workspace_total_requests"] = str(total_requests)
+                stat["workspace_total_hit_rate"] = f"{cumulative_hit_rate:.1%}"
+
         if self.qcfg.dynamic is not None:
             stat["dynamic"] = self.qcfg.dynamic_get(layer_name=module.full_name)
 
@@ -244,6 +266,8 @@ def process(
         # Log the new row
         self.log_new_row(stat)
 
+        g.log_workspace_stats(context="gptq_process")
+
         if self.calculate_w_wq_diff:
             # diff in float32
             w_wq_diff = module.weight.data.to(dtype=torch.float32) - wq.to(dtype=torch.float32)
diff --git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py
index 87a552488..fe52bb053 100644
--- a/gptqmodel/looper/loop_processor.py
+++ b/gptqmodel/looper/loop_processor.py
@@ -439,9 +439,11 @@ def device_memory_report(self) -> str:
             return "n/a"
 
         def _format_gib(value: float) -> str:
-            text = f"{value:.1f}"
-            if text.endswith(".0"):
+            text = f"{value:.2f}"
+            if text.endswith("00"):
                 text = text[:-2]
+            elif text.endswith("0"):
+                text = text[:-1]
             return f"{text}G"
 
         grouped: Dict[str, List[Tuple[str, float, int]]] = {}
diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index 018bc2d3a..e787d8e29 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -71,16 +71,27 @@ def _needs_workspace_resize(
 
 
 @contextlib.contextmanager
-def _lease_workspace(device: torch.device, dtype: torch.dtype, cols: int, required_rows: int):
+def _lease_workspace(
+    device: torch.device,
+    dtype: torch.dtype,
+    cols: int,
+    required_rows: int,
+) -> Tuple[torch.Tensor, bool]:
    key = _workspace_cache_key(device)
    lock = _WORKSPACE_LOCKS.setdefault(key, threading.Lock())
    with lock:
        workspace = _WORKSPACE_CACHE.pop(key, None)
-        if _needs_workspace_resize(workspace, dtype, required_rows, cols):
+        reused = workspace is not None and not _needs_workspace_resize(
+            workspace,
+            dtype,
+            required_rows,
+            cols,
+        )
+        if not reused:
            rows = max(required_rows, 1)
            workspace = torch.empty((rows, cols), dtype=dtype, device=device)
    try:
-        yield workspace
+        yield workspace, reused
    finally:
        with lock:
            _WORKSPACE_CACHE[key] = workspace
@@ -172,7 +183,7 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
        else:
            self._final_hessian_device_hint = torch.device(module_device)
 
-        self._validate_module(self.module)
+        self.validate_module(self.module)
 
        self.qcfg = qcfg if qcfg else QuantizeConfig()  # HF compat will not pass qcfg
 
@@ -196,8 +207,28 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
        self._device_sample_counts: Dict[torch.device, int] = {}
        self._hessian_dirty: bool = False
 
+        self._borrow_workspace_stats = {
+            "requests": 0,
+            "staging_requests": 0,
+            "staging_hits": 0,
+            "staging_misses": 0,
+            "materialized_requests": 0,
+            "materialized_hits": 0,
+            "materialized_misses": 0,
+        }
+        self._borrow_workspace_totals = {
+            "requests": 0,
+            "materialized_hits": 0,
+            "materialized_misses": 0,
+            "staging_hits": 0,
+            "staging_misses": 0,
+        }
+        self._borrow_workspace_last_summary: Optional[Dict[str, object]] = None
+        self._borrow_workspace_stage_dtype: Optional[torch.dtype] = None
+        self._borrow_workspace_last_chunk_rows: Optional[int] = None
+
    @staticmethod
-    def _validate_module(module):
+    def validate_module(module):
        assert isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, transformers.Conv1D)), f"We supports only linear and convolutional layers. actual = `{module}`"
 
@@ -213,14 +244,14 @@ def shape(self):
        else:
            return (0, 0)
 
-    def _mock_hessian_inverse(self, H: torch.Tensor):
+    def mock_hessian_inverse(self, H: torch.Tensor):
        """Mock hessian inverse for fast testing"""
        damp = self.qcfg.damp_percent
 
        # Return identity matrix instead of complex inversion
        identity = torch.eye(H.shape[0], dtype=torch.float32, device=H.device)
        return identity, damp
 
-    def _clone_module(self, copy=True, device: torch.device = None):
+    def clone_module(self, copy=True, device: torch.device = None):
        if not device:
            device = self.module.weight.data.device
 
@@ -243,7 +274,7 @@ def _clone_module(self, copy=True, device: torch.device = None):
        return clone.float()
 
    @staticmethod
-    def _truncate_last_dim(tensor: torch.Tensor, length: int) -> torch.Tensor:
+    def truncate_last_dim(tensor: torch.Tensor, length: int) -> torch.Tensor:
        if tensor.dim() == 0:
            return tensor
 
@@ -274,7 +305,7 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor, batch_index: Optional[
        self.nsamples += batch_token_size
        self._hessian_dirty = True
 
-    def _preferred_staging_dtype(self, input_dtype: torch.dtype, device: torch.device) -> torch.dtype:
+    def preferred_staging_dtype(self, input_dtype: torch.dtype, device: torch.device) -> torch.dtype:
        device = torch.device(device)
 
        if not self.qcfg.hessian_use_bfloat16_staging:
@@ -288,7 +319,7 @@ def _preferred_staging_dtype(self, input_dtype: torch.dtype, device: torch.devic
 
        return torch.bfloat16
 
-    def _resolve_hessian_chunk_size(self, rows: int, stage_dtype: torch.dtype) -> Optional[int]:
+    def resolve_hessian_chunk_size(self, rows: int, stage_dtype: torch.dtype) -> Optional[int]:
        if rows == 0:
            return None
 
@@ -308,7 +339,7 @@ def _resolve_hessian_chunk_size(self, rows: int, stage_dtype: torch.dtype) -> Op
        return None
 
    @contextlib.contextmanager
-    def _borrow_materialized_chunk_fp32(
+    def borrow_materialized_chunk_fp32(
        self,
        chunk: torch.Tensor,
        rows: int,
@@ -318,20 +349,52 @@ def _borrow_materialized_chunk_fp32(
            return
 
        device = chunk.device
-        stage_dtype = self._preferred_staging_dtype(chunk.dtype, device)
+        stage_dtype = self.preferred_staging_dtype(chunk.dtype, device)
+
+        stats = self._borrow_workspace_stats
+        stats["requests"] += 1
+
+        with _lease_workspace(device, stage_dtype, self.columns, rows) as (
+            staging_workspace,
+            staging_reused,
+        ):
+            stats["staging_requests"] += 1
+            if staging_reused:
+                stats["staging_hits"] += 1
+            else:
+                stats["staging_misses"] += 1
 
-        with _lease_workspace(device, stage_dtype, self.columns, rows) as staging_workspace:
            staging_view = staging_workspace[:rows, :]
            staging_view.copy_(chunk.to(dtype=stage_dtype))
 
            if stage_dtype == torch.float32:
+                stats["materialized_requests"] += 1
+                if staging_reused:
+                    stats["materialized_hits"] += 1
+                else:
+                    stats["materialized_misses"] += 1
+
                try:
                    yield staging_view
                finally:
                    if device.type == "cuda":
                        torch.cuda.current_stream(device).synchronize()
            else:
-                with _lease_workspace(device, torch.float32, self.columns, rows) as fp32_workspace:
+                with _lease_workspace(
+                    device,
+                    torch.float32,
+                    self.columns,
+                    rows,
+                ) as (
+                    fp32_workspace,
+                    fp32_reused,
+                ):
+                    stats["materialized_requests"] += 1
+                    if fp32_reused:
+                        stats["materialized_hits"] += 1
+                    else:
+                        stats["materialized_misses"] += 1
+
                    try:
                        fp32_view = fp32_workspace[:rows, :]
                        fp32_view.copy_(staging_view.to(torch.float32))
@@ -340,13 +403,15 @@ def _borrow_materialized_chunk_fp32(
                        if device.type == "cuda":
                            torch.cuda.current_stream(device).synchronize()
 
-    def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor:
+    def compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor:
        rows = matrix.shape[0]
        if rows == 0:
            return torch.zeros((self.columns, self.columns), dtype=torch.float32, device=matrix.device)
 
-        stage_dtype = self._preferred_staging_dtype(matrix.dtype, matrix.device)
-        chunk_size = self._resolve_hessian_chunk_size(rows, stage_dtype)
+        stage_dtype = self.preferred_staging_dtype(matrix.dtype, matrix.device)
+        chunk_size = self.resolve_hessian_chunk_size(rows, stage_dtype)
+        self._borrow_workspace_stage_dtype = stage_dtype
+        self._borrow_workspace_last_chunk_rows = chunk_size if chunk_size is not None else rows
 
        if chunk_size is None:
            mat32 = matrix.to(dtype=torch.float32)
@@ -359,7 +424,7 @@ def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor:
        for start in range(0, rows, chunk_size):
            rows_this = min(chunk_size, rows - start)
            source = matrix[start:start + rows_this]
-            with self._borrow_materialized_chunk_fp32(source, rows_this) as materialized:
+            with self.borrow_materialized_chunk_fp32(source, rows_this) as materialized:
                materialized32 = materialized
                xtx_accum.add_(torch.matmul(materialized32.T, materialized32))
 
@@ -423,7 +488,7 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor],
            return 0, None, canonical_device
 
        try:
-            xtx = self._compute_hessian_xtx(reshaped_inp).to(dtype=torch.float32)
+            xtx = self.compute_hessian_xtx(reshaped_inp).to(dtype=torch.float32)
        except RuntimeError as exc:
            if (
                torch.device(inp_device).type == "cuda"
@@ -438,7 +503,7 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor],
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                canonical_device = torch.device("cpu")
-                xtx = self._compute_hessian_xtx(reshaped_inp_cpu).to(dtype=torch.float32)
+                xtx = self.compute_hessian_xtx(reshaped_inp_cpu).to(dtype=torch.float32)
                xtx = xtx.detach()
                del reshaped_inp_cpu
            else:
@@ -448,6 +513,7 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor],
        xtx = xtx.detach()
        del reshaped_inp
 
+        self._snapshot_borrow_workspace_stats(context="process_batch")
        return batch_token_size, xtx, canonical_device
 
    def _select_hessian_target_device(self, requested: Optional[torch.device]) -> torch.device:
@@ -665,7 +731,7 @@ def quantize(
 
        if self.qcfg.mock_quantization:
            # Use simplified hessian inverse (identity matrix)
-            self.hessian_inverse = self._mock_hessian_inverse
+            self.hessian_inverse = self.mock_hessian_inverse
 
            # if self.device.type not in ["mps", "cpu"]:
            #     self.module.weight.data = self.module.weight.data.cpu()
@@ -677,7 +743,7 @@ def quantize(
 
        if self.module_copy is None:
            # log.info("copy W to cuda_1")
-            W = self._clone_module(device=self.H.device)
+            W = self.clone_module(device=self.H.device)
        else:
            W = self.module_copy.to(device=self.H.device)
            del self.module_copy
@@ -952,8 +1018,8 @@ def quantize(
 
        if self._tp_pad_cols:
            valid_cols = self._original_columns
-            scale = self._truncate_last_dim(scale, valid_cols)
-            zero = self._truncate_last_dim(zero, valid_cols)
+            scale = self.truncate_last_dim(scale, valid_cols)
+            zero = self.truncate_last_dim(zero, valid_cols)
 
        Q = Q.to(device=self.module.weight.data.device, non_blocking=False)
 
@@ -961,6 +1027,98 @@ def quantize(
 
        return Q, scale, zero, g_idx, duration, avg_loss, damp, self.nsamples
 
+    def borrow_materialized_chunk_stats(self, reset: bool = False) -> Dict[str, int]:
+        stats = dict(self._borrow_workspace_stats)
+        if reset:
+            for key in self._borrow_workspace_stats:
+                self._borrow_workspace_stats[key] = 0
+        return stats
+
+    def _snapshot_borrow_workspace_stats(self, *, context: str) -> None:
+        stats = self.borrow_materialized_chunk_stats(reset=True)
+        total_requests = int(stats.get("requests", 0) or 0)
+        if total_requests == 0:
+            return
+
+        materialized_hits = int(stats.get("materialized_hits", 0) or 0)
+        materialized_misses = int(stats.get("materialized_misses", 0) or 0)
+        staging_hits = int(stats.get("staging_hits", 0) or 0)
+        staging_misses = int(stats.get("staging_misses", 0) or 0)
+        chunk_rows = self._borrow_workspace_last_chunk_rows
+        stage_dtype = self._borrow_workspace_stage_dtype
+        stage_dtype_str = str(stage_dtype) if stage_dtype is not None else "n/a"
+        hit_rate = materialized_hits / total_requests if total_requests else 0.0
+
+        summary = {
+            "context": context,
+            "requests": total_requests,
+            "materialized_hits": materialized_hits,
+            "materialized_misses": materialized_misses,
+            "staging_hits": staging_hits,
+            "staging_misses": staging_misses,
+            "chunk_rows": chunk_rows,
+            "staging_dtype": stage_dtype_str,
+            "hit_rate": hit_rate,
+        }
+        self._borrow_workspace_last_summary = summary
+
+        totals = self._borrow_workspace_totals
+        totals["requests"] += total_requests
+        totals["materialized_hits"] += materialized_hits
+        totals["materialized_misses"] += materialized_misses
+        totals["staging_hits"] += staging_hits
+        totals["staging_misses"] += staging_misses
+
+    def log_workspace_stats(self, *, context: str, reset: bool = True) -> None:
+        totals = self._borrow_workspace_totals
+        total_requests = int(totals.get("requests", 0) or 0)
+        if total_requests == 0:
+            if reset:
+                self.reset_workspace_stats()
+            return
+
+        total_hits = int(totals.get("materialized_hits", 0) or 0)
+        total_misses = int(totals.get("materialized_misses", 0) or 0)
+        total_hit_rate = total_hits / total_requests if total_requests else 0.0
+
+        last = self._borrow_workspace_last_summary or {}
+        last_requests = int(last.get("requests", 0) or 0)
+        last_hits = int(last.get("materialized_hits", 0) or 0)
+        last_misses = int(last.get("materialized_misses", 0) or 0)
+        last_hit_rate = float(last.get("hit_rate", 0.0) or 0.0)
+        rows_label = last.get("chunk_rows", "n/a")
+        stage_dtype = last.get("staging_dtype", "n/a")
+
+        log.info(
+            "GPTQ workspace cache [%s]: module=%s rows=%s staging_dtype=%s "
+            "requests=%d hits=%d misses=%d hit_rate=%.2f total_requests=%d "
+            "total_hits=%d total_misses=%d total_hit_rate=%.2f",
+            context,
+            getattr(self, "name", ""),
+            rows_label,
+            stage_dtype,
+            last_requests,
+            last_hits,
+            last_misses,
+            last_hit_rate,
+            total_requests,
+            total_hits,
+            total_misses,
+            total_hit_rate,
+        )
+
+        if reset:
+            self.reset_workspace_stats()
+
+    def reset_workspace_stats(self) -> None:
+        for key in self._borrow_workspace_stats:
+            self._borrow_workspace_stats[key] = 0
+        for key in self._borrow_workspace_totals:
+            self._borrow_workspace_totals[key] = 0
+        self._borrow_workspace_last_summary = None
+        self._borrow_workspace_stage_dtype = None
+        self._borrow_workspace_last_chunk_rows = None
+
    def free(self):
        if hasattr(self, "H"):
            del self.H
diff --git a/gptqmodel/quantization/gptqv2.py b/gptqmodel/quantization/gptqv2.py
index 640b708b0..af9167894 100644
--- a/gptqmodel/quantization/gptqv2.py
+++ b/gptqmodel/quantization/gptqv2.py
@@ -135,7 +135,7 @@ def quantize(
 
        if self.module_copy is None:
            # log.info("copy W to cuda_1")
-            W = self._clone_module(device=self.H.device)
+            W = self.clone_module(device=self.H.device)
        else:
            W = self.module_copy
            self.module_copy = None
diff --git a/tests/models/model_test.py b/tests/models/model_test.py
index cae7b2fce..b4612b723 100644
--- a/tests/models/model_test.py
+++ b/tests/models/model_test.py
@@ -114,6 +114,7 @@ class ModelTest(unittest.TestCase):
    DAMP_PERCENT = 0.05
    MSE = 0.0
    DYNAMIC = None
+    HESSIAN_CHUNK_SIZE = None
 
    SAVE_PATH = None  # default is temp folder
 
@@ -790,6 +791,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
            damp_percent=self.DAMP_PERCENT,
            mse=self.MSE,
            dynamic=self.DYNAMIC,
+            hessian_chunk_size=self.HESSIAN_CHUNK_SIZE,
        )
 
        log.info(f"Quant config: {quantize_config}")
diff --git a/tests/models/test_qwen2_5.py b/tests/models/test_qwen2_5.py
index 4c5e2c072..695d757d1 100644
--- a/tests/models/test_qwen2_5.py
+++ b/tests/models/test_qwen2_5.py
@@ -16,6 +16,7 @@
 # | gsm8k_plat :: exact,flexible | 0.2963 |
 class TestQwen2_5(ModelTest):
    GROUP_SIZE = 32
+    HESSIAN_CHUNK_SIZE = 256 * 1024 * 1024
    NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct"
    EVAL_BATCH_SIZE = 64
    DATASET_CONCAT_SIZE = 2048
diff --git a/tests/models/test_qwen3_moe.py b/tests/models/test_qwen3_moe.py
index d77262830..3d5846a22 100644
--- a/tests/models/test_qwen3_moe.py
+++ b/tests/models/test_qwen3_moe.py
@@ -13,6 +13,7 @@
 # | arc_challenge :: acc,none | 0.5094 |
 # | arc_challenge :: acc_norm,none | 0.5486 |
 class TestQwen3Moe(ModelTest):
+    # HESSIAN_CHUNK_SIZE = 256 * 1024 * 1024
    NATIVE_MODEL_ID = "/monster/data/model/Qwen3-30B-A3B"
    EVAL_TASKS = {
        EVAL.LM_EVAL.ARC_CHALLENGE: {
diff --git a/tests/test_gptq_hessian_chunking.py b/tests/test_gptq_hessian_chunking.py
index 11ba9b409..8eac0923b 100644
--- a/tests/test_gptq_hessian_chunking.py
+++ b/tests/test_gptq_hessian_chunking.py
@@ -67,7 +67,7 @@ def _one_batch(idx: int):
 
    activation_mb = (batch_size * seq_len * hidden_dim * 2) / (1024**2)
    peak_delta_mb = max(0.0, (peak_alloc - baseline_alloc) / (1024**2))
-    chunk_rows = gptq._resolve_hessian_chunk_size(batch_size * seq_len, torch.float32)
+    chunk_rows = gptq.resolve_hessian_chunk_size(batch_size * seq_len, torch.float32)
 
    return {
        "chunk_bytes": chunk_bytes,
diff --git a/tests/test_hessian_chunk.py b/tests/test_hessian_chunk.py
index 5149f0845..3e5c1d69b 100644
--- a/tests/test_hessian_chunk.py
+++ b/tests/test_hessian_chunk.py
@@ -40,7 +40,7 @@ def _clone_module(module: torch.nn.Module) -> torch.nn.Module:
 
 
 def _instrument_chunks(gptq: GPTQ) -> None:
-    original = gptq._borrow_materialized_chunk_fp32
+    original = gptq.borrow_materialized_chunk_fp32
 
    @contextlib.contextmanager
    def wrapped(self, chunk, rows):
@@ -49,7 +49,7 @@ def wrapped(self, chunk, rows):
            yield materialized
 
    gptq._chunk_invocations = 0
-    gptq._borrow_materialized_chunk_fp32 = types.MethodType(wrapped, gptq)
+    gptq.borrow_materialized_chunk_fp32 = types.MethodType(wrapped, gptq)
 
 
 def test_hessian_chunk_consistency_matches_full_precision():
@@ -100,10 +100,37 @@ def test_hessian_chunk_invocations_and_workspace_shape():
 
    large_gptq.process_batch(calib.clone())
    assert large_gptq._chunk_invocations == 1
+
+    large_summary = getattr(large_gptq, "_borrow_workspace_last_summary", None)
+    assert large_summary is not None
+    assert large_summary["requests"] == 1
+    assert large_summary["materialized_hits"] == 0
+    assert large_summary["materialized_misses"] == 1
+    assert large_summary["staging_misses"] == 1
+    large_totals = getattr(large_gptq, "_borrow_workspace_totals", {})
+    assert large_totals.get("requests") == 1
+    assert large_totals.get("materialized_misses") == 1
+    large_gptq.log_workspace_stats(context="test_hessian_chunk", reset=True)
+    assert getattr(large_gptq, "_borrow_workspace_totals", {}).get("requests") == 0
 
    small_gptq.process_batch(calib.clone())
    expected_chunks = math.ceil(calib.shape[0] / small_cfg.hessian_chunk_size)
    assert small_gptq._chunk_invocations == expected_chunks
+
+    small_summary = getattr(small_gptq, "_borrow_workspace_last_summary", None)
+    assert small_summary is not None
+    assert small_summary["requests"] == expected_chunks
+    assert small_summary["materialized_hits"] + small_summary["materialized_misses"] == expected_chunks
+    assert small_summary["materialized_hits"] >= expected_chunks - 1
+    assert pytest.approx(
+        small_summary["hit_rate"],
+        rel=1e-6,
+    ) == small_summary["materialized_hits"] / expected_chunks
+    small_totals = getattr(small_gptq, "_borrow_workspace_totals", {})
+    assert small_totals.get("requests") == expected_chunks
+    assert small_totals.get("materialized_hits") >= expected_chunks - 1
+    small_gptq.log_workspace_stats(context="test_hessian_chunk", reset=True)
+    assert getattr(small_gptq, "_borrow_workspace_totals", {}).get("requests") == 0
 
    device = torch.device(base.weight.device)
    cache_key = gptq_impl._workspace_cache_key(device)
@@ -116,7 +143,7 @@ def test_hessian_chunk_invocations_and_workspace_shape():
    small_workspace = gptq_impl._WORKSPACE_CACHE[cache_key]
    assert small_workspace is large_workspace
 
-    staging_dtype = small_gptq._preferred_staging_dtype(calib.dtype, device)
+    staging_dtype = small_gptq.preferred_staging_dtype(calib.dtype, device)
    if staging_dtype == torch.bfloat16:
        staged_workspace = gptq_impl._WORKSPACE_CACHE[cache_key]
        assert staged_workspace.dtype == torch.bfloat16
@@ -188,7 +215,7 @@ def worker(task_id: int) -> None:
    assert cached_workspace.shape[0] >= expected_rows
    assert cached_workspace.shape[1] == cols
 
-    stage_dtype = gptq_workers[0]._preferred_staging_dtype(torch.float16, device)
+    stage_dtype = gptq_workers[0].preferred_staging_dtype(torch.float16, device)
    if stage_dtype == torch.bfloat16:
        assert cached_workspace.dtype == torch.bfloat16
    else:
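
For reference, a minimal sketch of how the workspace counters introduced in this patch could be exercised by hand. It is not part of the patch itself; the import paths, the QuantizeConfig defaults, and the chunk-size value are assumptions — only the GPTQ attribute and method names (process_batch, _borrow_workspace_last_summary, log_workspace_stats, hessian_chunk_size) come from the diff above.

import torch

from gptqmodel import QuantizeConfig          # top-level export; path assumed
from gptqmodel.quantization.gptq import GPTQ  # module path assumed

layer = torch.nn.Linear(128, 128)
# A deliberately tiny chunk budget so resolve_hessian_chunk_size() should take the chunked X^T X path.
qcfg = QuantizeConfig(hessian_chunk_size=64 * 1024)
gptq = GPTQ(layer, qcfg)

gptq.process_batch(torch.randn(512, 128))

# Per-batch snapshot written by _snapshot_borrow_workspace_stats(); stays None when the
# unchunked path was taken for the batch.
print(gptq._borrow_workspace_last_summary)

# Cumulative counters are logged via log.info and then cleared, mirroring what
# gptq_processor.process() now does per quantized module.
gptq.log_workspace_stats(context="example", reset=True)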