From ff59eec137e2a062f2e8650e0927d48420f9352f Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 26 Sep 2025 23:23:41 +0000 Subject: [PATCH 1/6] add eora toggle to ci test Signed-off-by: Qubitium --- tests/models/model_test.py | 7 +++++-- tests/models/test_llama3_2.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/models/model_test.py b/tests/models/model_test.py index d8b4e09a4..585252afe 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -75,6 +75,7 @@ class ModelTest(unittest.TestCase): V2 = False ACT_GROUP_AWARE = True FAIL_SAFE = True + EORA = None SAVE_PATH = None # default is temp folder @@ -89,7 +90,6 @@ class ModelTest(unittest.TestCase): EXPECT_LM_HEAD_LOSS = None - def assertInference(self, model, tokenizer=None, keywords=None, prompt=INFERENCE_PROMPT): # gptqmodel can auto init tokenizer internally if keywords is None: @@ -181,7 +181,8 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne act_group_aware=self.ACT_GROUP_AWARE, fail_safe=self.FAIL_SAFE, sym=self.SYM, - v2=self.V2 + v2=self.V2, + adapter=self.EORA, ) log.info(f"Quant config: {quantize_config}") @@ -267,6 +268,7 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa backend=self.LOAD_BACKEND, device_map={"": "cpu"} if self.LOAD_BACKEND == BACKEND.TORCH_FUSED else "auto", debug=self.DEBUG, + adapter=self.EORA, **kargs ) @@ -296,6 +298,7 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del for framework, tasks in task_groups.items(): log.info(f"TEST: EVAL starting: backend = {self.LOAD_BACKEND}") + log.info(f"Inference from model path: {model.model_local_path}") results = GPTQModel.eval( model_or_id_or_path=model.model_local_path, llm_backend="vllm" if self.USE_VLLM else "gptqmodel", diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index f8175aebc..2b3adcd8b 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -4,6 +4,7 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from model_test import ModelTest +from gptqmodel.adapter.adapter import Lora # a100:0 @@ -23,6 +24,11 @@ class TestLlama3_2(ModelTest): DATASET_SIZE = 1024 DATASET_SORT = "desc" QUANT_BATCH_SIZE = 4 + # EORA = Lora( + # # for quant, path is save path. for load, it is loading path + # path="./eora_test", + # rank=128, + # ) # b1 = 0.315, b4 = 0.3106, b8 = 0.3148, b32 = 0.3148, b16 = 0.3234 def test_llama3_2(self): From d0dad38e8ecdf334468c726784b87ad2b0aa0a40 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 27 Sep 2025 01:06:38 +0000 Subject: [PATCH 2/6] memory tracker v1 Signed-off-by: Qubitium --- gptqmodel/utils/memory.py | 147 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 gptqmodel/utils/memory.py diff --git a/gptqmodel/utils/memory.py b/gptqmodel/utils/memory.py new file mode 100644 index 000000000..482fcbb1b --- /dev/null +++ b/gptqmodel/utils/memory.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations +import threading +from typing import Dict, Iterable, Tuple + +import torch +import torch.nn as nn + +Obj = nn.Module | torch.Tensor # Python 3.10+ union syntax + + +class MemTracker: + """ + Tracks memory attributed to Modules or Tensors by **device instance** (e.g., 'cuda:0', 'cuda:1', 'cpu'). + Query with a torch.device to aggregate by **type** (torch.device('cuda')) or a specific index + (torch.device('cuda:1')). + """ + + def __init__(self) -> None: + self._allocated_by_dev: Dict[str, int] = {} + self._freed_by_dev: Dict[str, int] = {} + self._lock = threading.Lock() + + # ---------- Public API ---------- + def allocate(self, ob: Obj) -> None: + sizes = self._sizes_by_device_instance(ob) + with self._lock: + for dev_key, b in sizes.items(): + self._allocated_by_dev[dev_key] = self._allocated_by_dev.get(dev_key, 0) + b + + def free(self, ob: Obj) -> None: + sizes = self._sizes_by_device_instance(ob) + with self._lock: + for dev_key, b in sizes.items(): + if b <= 0: + continue + self._allocated_by_dev[dev_key] = max(0, self._allocated_by_dev.get(dev_key, 0) - b) + self._freed_by_dev[dev_key] = self._freed_by_dev.get(dev_key, 0) + b + + def reset(self) -> None: + with self._lock: + self._allocated_by_dev.clear() + self._freed_by_dev.clear() + + def allocated(self, device: torch.device | None = None) -> Tuple[int, str]: + """Return (raw_bytes, formatted_string) for allocated memory.""" + with self._lock: + if device is None: + val = sum(self._allocated_by_dev.values()) + else: + val = _sum_for_device(self._allocated_by_dev, device) + return val, format_bytes(val) + + def freed(self, device: torch.device | None = None) -> Tuple[int, str]: + """Return (raw_bytes, formatted_string) for freed memory.""" + with self._lock: + if device is None: + val = sum(self._freed_by_dev.values()) + else: + val = _sum_for_device(self._freed_by_dev, device) + return val, format_bytes(val) + + # ---------- Helpers ---------- + def _sizes_by_device_instance(self, ob: Obj) -> Dict[str, int]: + tensors = list(self._gather_tensors(ob)) + return self._sum_by_devkey_dedup(tensors) + + def _gather_tensors(self, ob: Obj) -> Iterable[torch.Tensor]: + if isinstance(ob, torch.Tensor): + yield ob + return + for p in ob.parameters(recurse=True): + yield p.data + for b in ob.buffers(recurse=True): + yield b + + def _sum_by_devkey_dedup(self, tensors: Iterable[torch.Tensor]) -> Dict[str, int]: + seen_keys: set[tuple[int, int]] = set() + by_dev: Dict[str, int] = {} + + def _accumulate_dense(t: torch.Tensor) -> None: + if not isinstance(t, torch.Tensor): + return + dev = t.device + if dev.type == "meta": + return + dev_key = str(dev) # e.g., 'cuda:0', 'cpu' + try: + st = t.untyped_storage() + key = (st.data_ptr(), st.nbytes()) + if key in seen_keys: + return + seen_keys.add(key) + by_dev[dev_key] = by_dev.get(dev_key, 0) + int(st.nbytes()) + except RuntimeError: + nbytes = int(t.numel() * t.element_size()) + key = (t.data_ptr(), nbytes) + if key in seen_keys: + return + seen_keys.add(key) + by_dev[dev_key] = by_dev.get(dev_key, 0) + nbytes + + for t in tensors: + if t.is_sparse: + _accumulate_dense(t.indices()) + _accumulate_dense(t.values()) + elif t.layout == torch.sparse_csr: + _accumulate_dense(t.crow_indices()) + _accumulate_dense(t.col_indices()) + _accumulate_dense(t.values()) + elif getattr(torch, "sparse_csc", None) is not None and t.layout == torch.sparse_csc: + _accumulate_dense(t.ccol_indices()) + _accumulate_dense(t.row_indices()) + _accumulate_dense(t.values()) + else: + _accumulate_dense(t) + + return by_dev + + +def _sum_for_device(table: Dict[str, int], device: torch.device) -> int: + dev_type = device.type + idx = device.index + if idx is None: + total = 0 + prefix = f"{dev_type}:" + for k, v in table.items(): + if k == dev_type or k.startswith(prefix): + total += v + return total + else: + key = f"{dev_type}:{idx}" + return table.get(key, 0) + + +# ---------- Optional utility ---------- +def format_bytes(n: int) -> str: + units = ["B", "KiB", "MiB", "GiB", "TiB"] + x = float(n) + for u in units: + if x < 1024 or u == units[-1]: + return f"{x:.2f} {u}" + x /= 1024.0 From 4e5fe3ba6a0da3c6bb278272e3a45a79c3713c3f Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 27 Sep 2025 05:01:49 +0000 Subject: [PATCH 3/6] fix bad device check Signed-off-by: Qubitium --- gptqmodel/nn_modules/qlinear/marlin.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/marlin.py b/gptqmodel/nn_modules/qlinear/marlin.py index 3c9b05574..ced343537 100644 --- a/gptqmodel/nn_modules/qlinear/marlin.py +++ b/gptqmodel/nn_modules/qlinear/marlin.py @@ -207,16 +207,15 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: @classmethod def validate_device(cls, device: DEVICE): super().validate_device(device) - CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES") if device == DEVICE.CUDA: if IS_ROCM: raise NotImplementedError("Marlin kernel is not supported on ROCm.") - if CUDA_VISIBLE_DEVICES is None: - has_cuda_v8 = all(torch.cuda.get_device_capability(i)[0] >= 8 for i in range(torch.cuda.device_count())) - else: - has_cuda_v8 = all( - torch.cuda.get_device_capability(i)[0] >= 8 for i in range(len(CUDA_VISIBLE_DEVICES.split(",")))) + # Directly check capabilities of all currently visible CUDA devices + has_cuda_v8 = all( + torch.cuda.get_device_capability(i)[0] >= 8 + for i in range(torch.cuda.device_count()) + ) if not has_cuda_v8: raise NotImplementedError("Marlin kernel only supports compute capability >= 8.0.") From 5ee59c1a54ac40eeb61da8dcb57c387ee49e5cf9 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 27 Sep 2025 05:07:49 +0000 Subject: [PATCH 4/6] memory v3 Signed-off-by: Qubitium --- gptqmodel/utils/memory.py | 269 ++++++++++++++++++++++++++++++-------- 1 file changed, 217 insertions(+), 52 deletions(-) diff --git a/gptqmodel/utils/memory.py b/gptqmodel/utils/memory.py index 482fcbb1b..cffdfd0c1 100644 --- a/gptqmodel/utils/memory.py +++ b/gptqmodel/utils/memory.py @@ -1,73 +1,162 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-FileCopyrightText: 2025 ModelCloud.ai # SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium from __future__ import annotations import threading -from typing import Dict, Iterable, Tuple +from typing import Dict, Iterable, Tuple, Generator import torch import torch.nn as nn -Obj = nn.Module | torch.Tensor # Python 3.10+ union syntax +# ---------- ANSI COLORS ---------- +RESET = "\033[0m" +RED = "\033[91m" +GREEN = "\033[92m" +YELLOW = "\033[93m" +CYAN = "\033[96m" +MAGENTA = "\033[95m" + +# ---------- TYPE ALIASES ---------- +Obj = nn.Module | torch.Tensor +ObjOrTuple = Obj | tuple[Obj, ...] class MemTracker: """ - Tracks memory attributed to Modules or Tensors by **device instance** (e.g., 'cuda:0', 'cuda:1', 'cpu'). - Query with a torch.device to aggregate by **type** (torch.device('cuda')) or a specific index - (torch.device('cuda:1')). + Tracks memory per device instance (torch.device). + + - allocate(obj|tuple): track allocations (nn.Module or torch.Tensor). + - free(obj|tuple): track frees; optional auto-GC per device when freed bytes exceed threshold. + - allocated(device?): -> (raw_bytes, human_str). + - freed(device?): -> (raw_bytes, human_str). + - set_auto_gc(bytes|None): enable/disable/change threshold for auto-GC per device. + + Debug logging (ANSI colored): + - allocate -> red + - free -> green + - summaries -> cyan + - auto-GC -> yellow + - reset -> magenta + + Summary FIX: + - Summaries now list **all known device+index pairs** every time (not just affected devices), + then show per-type aggregates. """ - def __init__(self) -> None: - self._allocated_by_dev: Dict[str, int] = {} - self._freed_by_dev: Dict[str, int] = {} + def __init__(self, auto_gc_bytes: int | None = None) -> None: + self._allocated_by_dev: Dict[torch.device, int] = {} + self._freed_by_dev: Dict[torch.device, int] = {} + self._auto_gc_bytes: int | None = auto_gc_bytes + + # GC accounting + self._gc_count_by_dev: Dict[torch.device, int] = {} + self._gc_total_count: int = 0 + self._lock = threading.Lock() # ---------- Public API ---------- - def allocate(self, ob: Obj) -> None: - sizes = self._sizes_by_device_instance(ob) + def allocate(self, ob: ObjOrTuple) -> None: + sizes = self._sizes_for_many(ob) + affected: set[torch.device] = set() + with self._lock: - for dev_key, b in sizes.items(): - self._allocated_by_dev[dev_key] = self._allocated_by_dev.get(dev_key, 0) + b + # Apply deltas + for dev, b in sizes.items(): + self._allocated_by_dev[dev] = self._allocated_by_dev.get(dev, 0) + b + affected.add(dev) + print(f"{RED}[allocate]{RESET} +{format_bytes(b)} on {dev}") + + # Build **complete** device set and per-type aggregates + all_devs = self._all_known_devices_locked() + type_totals = self._totals_by_type_locked(self._allocated_by_dev) + + # Print summaries outside lock + self._print_full_device_summary( + header=f"{CYAN}[allocate-summary]{RESET}", + per_device_map=self._allocated_by_dev, + all_devices=all_devs, + type_totals=type_totals, + ) + + def free(self, ob: ObjOrTuple) -> None: + sizes = self._sizes_for_many(ob) + affected: set[torch.device] = set() - def free(self, ob: Obj) -> None: - sizes = self._sizes_by_device_instance(ob) with self._lock: - for dev_key, b in sizes.items(): + # Apply deltas + for dev, b in sizes.items(): if b <= 0: continue - self._allocated_by_dev[dev_key] = max(0, self._allocated_by_dev.get(dev_key, 0) - b) - self._freed_by_dev[dev_key] = self._freed_by_dev.get(dev_key, 0) + b + self._allocated_by_dev[dev] = max(0, self._allocated_by_dev.get(dev, 0) - b) + self._freed_by_dev[dev] = self._freed_by_dev.get(dev, 0) + b + affected.add(dev) + print(f"{GREEN}[free]{RESET} released {format_bytes(b)} on {dev}") + + # Build **complete** device set and per-type aggregates for freed + all_devs = self._all_known_devices_locked() + freed_type_totals = self._totals_by_type_locked(self._freed_by_dev) + + # Print summaries outside lock + self._print_full_device_summary( + header=f"{CYAN}[free-summary]{RESET}", + per_device_map=self._freed_by_dev, + all_devices=all_devs, + type_totals=freed_type_totals, + ) + + # Auto-GC checks (may reset a specific device's freed counter) + if self._auto_gc_bytes is not None and self._auto_gc_bytes > 0: + for dev in affected: + self._maybe_auto_gc(dev) def reset(self) -> None: with self._lock: self._allocated_by_dev.clear() self._freed_by_dev.clear() + self._gc_count_by_dev.clear() + self._gc_total_count = 0 + print(f"{MAGENTA}[reset]{RESET} counters cleared") def allocated(self, device: torch.device | None = None) -> Tuple[int, str]: - """Return (raw_bytes, formatted_string) for allocated memory.""" with self._lock: - if device is None: - val = sum(self._allocated_by_dev.values()) - else: - val = _sum_for_device(self._allocated_by_dev, device) + val = sum(self._allocated_by_dev.values()) if device is None else _sum_for_device(self._allocated_by_dev, device) + print(f"{CYAN}[allocated]{RESET} query={device}, result={format_bytes(val)}") return val, format_bytes(val) def freed(self, device: torch.device | None = None) -> Tuple[int, str]: - """Return (raw_bytes, formatted_string) for freed memory.""" with self._lock: - if device is None: - val = sum(self._freed_by_dev.values()) - else: - val = _sum_for_device(self._freed_by_dev, device) + val = sum(self._freed_by_dev.values()) if device is None else __sum_for_device(self._freed_by_dev, device) + print(f"{CYAN}[freed]{RESET} query={device}, result={format_bytes(val)}") return val, format_bytes(val) + def set_auto_gc(self, bytes_threshold: int | None) -> None: + with self._lock: + self._auto_gc_bytes = bytes_threshold + print(f"{YELLOW}[set_auto_gc]{RESET} threshold={bytes_threshold}") + # ---------- Helpers ---------- - def _sizes_by_device_instance(self, ob: Obj) -> Dict[str, int]: + def _sizes_for_many(self, ob: ObjOrTuple) -> Dict[torch.device, int]: + agg: Dict[torch.device, int] = {} + for item in self._iter_objs(ob): + for dev, b in self._sizes_by_device_instance(item).items(): + agg[dev] = agg.get(dev, 0) + b + return agg + + def _iter_objs(self, ob: ObjOrTuple) -> Generator[Obj, None, None]: + if isinstance(ob, tuple): + for x in ob: + if isinstance(x, (nn.Module, torch.Tensor)): + yield x + else: + raise TypeError(f"Unsupported type in tuple: {type(x)}") + elif isinstance(ob, (nn.Module, torch.Tensor)): + yield ob + else: + raise TypeError(f"Unsupported type: {type(ob)}") + + def _sizes_by_device_instance(self, ob: Obj) -> Dict[torch.device, int]: tensors = list(self._gather_tensors(ob)) - return self._sum_by_devkey_dedup(tensors) + return self._sum_by_dev_dedup(tensors) def _gather_tensors(self, ob: Obj) -> Iterable[torch.Tensor]: if isinstance(ob, torch.Tensor): @@ -78,31 +167,32 @@ def _gather_tensors(self, ob: Obj) -> Iterable[torch.Tensor]: for b in ob.buffers(recurse=True): yield b - def _sum_by_devkey_dedup(self, tensors: Iterable[torch.Tensor]) -> Dict[str, int]: + def _sum_by_dev_dedup(self, tensors: Iterable[torch.Tensor]) -> Dict[torch.device, int]: + """ + Sum bytes by device (torch.device), deduping shared storages via (data_ptr, nbytes). + Counts sparse indices/values; ignores 'meta'. + """ seen_keys: set[tuple[int, int]] = set() - by_dev: Dict[str, int] = {} + by_dev: Dict[torch.device, int] = {} def _accumulate_dense(t: torch.Tensor) -> None: - if not isinstance(t, torch.Tensor): - return dev = t.device if dev.type == "meta": return - dev_key = str(dev) # e.g., 'cuda:0', 'cpu' try: st = t.untyped_storage() key = (st.data_ptr(), st.nbytes()) if key in seen_keys: return seen_keys.add(key) - by_dev[dev_key] = by_dev.get(dev_key, 0) + int(st.nbytes()) + by_dev[dev] = by_dev.get(dev, 0) + int(st.nbytes()) except RuntimeError: nbytes = int(t.numel() * t.element_size()) key = (t.data_ptr(), nbytes) if key in seen_keys: return seen_keys.add(key) - by_dev[dev_key] = by_dev.get(dev_key, 0) + nbytes + by_dev[dev] = by_dev.get(dev, 0) + nbytes for t in tensors: if t.is_sparse: @@ -121,23 +211,93 @@ def _accumulate_dense(t: torch.Tensor) -> None: return by_dev + # ---- Device sets & summaries (BUGFIX area) ---- + def _all_known_devices_locked(self) -> list[torch.device]: + """ + Build a sorted list of **all** devices we know about from both allocated and freed maps. + """ + all_set = set(self._allocated_by_dev.keys()) | set(self._freed_by_dev.keys()) + # Sort by (type, index_or_-1) + return sorted(all_set, key=lambda d: (d.type, -1 if d.index is None else d.index)) + + def _totals_by_type_locked(self, table: Dict[torch.device, int]) -> Dict[str, int]: + out: Dict[str, int] = {} + for d, v in table.items(): + out[d.type] = out.get(d.type, 0) + v + return out + + def _print_full_device_summary( + self, + header: str, + per_device_map: Dict[torch.device, int], + all_devices: list[torch.device], + type_totals: Dict[str, int], + ) -> None: + """ + Print a full summary: + - one line per **device instance** (for every known device) + - one line per **device type** aggregate + """ + # Per-device lines (ALL known devices) + for dev in all_devices: + val = per_device_map.get(dev, 0) + print(f"{header} {dev}: {format_bytes(val)}") + + # Per-type aggregates + for dtype in sorted(type_totals.keys()): + print(f"{header} {dtype}: {format_bytes(type_totals[dtype])}") -def _sum_for_device(table: Dict[str, int], device: torch.device) -> int: - dev_type = device.type - idx = device.index - if idx is None: - total = 0 - prefix = f"{dev_type}:" - for k, v in table.items(): - if k == dev_type or k.startswith(prefix): - total += v - return total + # ---- Auto-GC ---- + def _maybe_auto_gc(self, dev: torch.device) -> None: + threshold = self._auto_gc_bytes + if threshold is None or threshold <= 0: + return + + with self._lock: + current_freed = self._freed_by_dev.get(dev, 0) + + if current_freed < threshold: + return + + if _run_backend_gc(dev): + with self._lock: + self._freed_by_dev[dev] = 0 + self._gc_count_by_dev[dev] = self._gc_count_by_dev.get(dev, 0) + 1 + self._gc_total_count += 1 + per_dev_count = self._gc_count_by_dev[dev] + total_count = self._gc_total_count + + print(f"{YELLOW}[auto_gc]{RESET} {dev}: ran GC (count={per_dev_count}), total across devices={total_count}") + + +def _run_backend_gc(dev: torch.device) -> bool: + try: + if dev.type == "cuda": + if dev.index is not None: + torch.cuda.set_device(dev.index) + torch.cuda.empty_cache() + return True + if dev.type == "xpu" and hasattr(torch, "xpu"): + torch.xpu.empty_cache() # type: ignore[attr-defined] + return True + if dev.type == "mps" and hasattr(torch, "mps"): + torch.mps.empty_cache() # type: ignore[attr-defined] + return True + if dev.type == "npu" and hasattr(torch, "npu"): + torch.npu.empty_cache() # type: ignore[attr-defined] + return True + return False + except Exception: + return False + + +def _sum_for_device(table: Dict[torch.device, int], query: torch.device) -> int: + if query.index is None: + return sum(v for d, v in table.items() if d.type == query.type) else: - key = f"{dev_type}:{idx}" - return table.get(key, 0) + return table.get(torch.device(query.type, query.index), 0) -# ---------- Optional utility ---------- def format_bytes(n: int) -> str: units = ["B", "KiB", "MiB", "GiB", "TiB"] x = float(n) @@ -145,3 +305,8 @@ def format_bytes(n: int) -> str: if x < 1024 or u == units[-1]: return f"{x:.2f} {u}" x /= 1024.0 + + + +# default to auto gc interval for every ~256GB of freed memory +MEM_LORD = MemTracker(auto_gc_bytes=1024 * 1024 * 1024) \ No newline at end of file From 943d3140d013ecad21fb8b728d3e981b0041e1ca Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 27 Sep 2025 05:22:30 +0000 Subject: [PATCH 5/6] auto threshold for gc Signed-off-by: Qubitium --- gptqmodel/utils/memory.py | 151 +++++++++++++++++++++----------------- 1 file changed, 84 insertions(+), 67 deletions(-) diff --git a/gptqmodel/utils/memory.py b/gptqmodel/utils/memory.py index cffdfd0c1..67ad93ec3 100644 --- a/gptqmodel/utils/memory.py +++ b/gptqmodel/utils/memory.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +import os import threading from typing import Dict, Iterable, Tuple, Generator @@ -16,37 +17,22 @@ CYAN = "\033[96m" MAGENTA = "\033[95m" +# ---------- DEBUG FLAG ---------- +DEBUG_MODE = os.environ.get("DEBUG", "0") == "1" + +def _log(msg: str) -> None: + if DEBUG_MODE: + print(msg) + # ---------- TYPE ALIASES ---------- Obj = nn.Module | torch.Tensor ObjOrTuple = Obj | tuple[Obj, ...] class MemTracker: - """ - Tracks memory per device instance (torch.device). - - - allocate(obj|tuple): track allocations (nn.Module or torch.Tensor). - - free(obj|tuple): track frees; optional auto-GC per device when freed bytes exceed threshold. - - allocated(device?): -> (raw_bytes, human_str). - - freed(device?): -> (raw_bytes, human_str). - - set_auto_gc(bytes|None): enable/disable/change threshold for auto-GC per device. - - Debug logging (ANSI colored): - - allocate -> red - - free -> green - - summaries -> cyan - - auto-GC -> yellow - - reset -> magenta - - Summary FIX: - - Summaries now list **all known device+index pairs** every time (not just affected devices), - then show per-type aggregates. - """ - - def __init__(self, auto_gc_bytes: int | None = None) -> None: + def __init__(self, auto_gc_bytes: int | str | None = "auto") -> None: self._allocated_by_dev: Dict[torch.device, int] = {} self._freed_by_dev: Dict[torch.device, int] = {} - self._auto_gc_bytes: int | None = auto_gc_bytes # GC accounting self._gc_count_by_dev: Dict[torch.device, int] = {} @@ -54,28 +40,28 @@ def __init__(self, auto_gc_bytes: int | None = None) -> None: self._lock = threading.Lock() + # Resolve threshold + self._auto_gc_bytes: int | None = None + self._resolve_and_set_auto_gc(auto_gc_bytes, context="init") + # ---------- Public API ---------- def allocate(self, ob: ObjOrTuple) -> None: sizes = self._sizes_for_many(ob) - affected: set[torch.device] = set() - with self._lock: - # Apply deltas for dev, b in sizes.items(): self._allocated_by_dev[dev] = self._allocated_by_dev.get(dev, 0) + b - affected.add(dev) - print(f"{RED}[allocate]{RESET} +{format_bytes(b)} on {dev}") + _log(f"{RED}[allocate]{RESET} +{format_bytes(b)} on {dev}") - # Build **complete** device set and per-type aggregates all_devs = self._all_known_devices_locked() type_totals = self._totals_by_type_locked(self._allocated_by_dev) + type_counts = self._counts_by_type_locked() - # Print summaries outside lock self._print_full_device_summary( header=f"{CYAN}[allocate-summary]{RESET}", per_device_map=self._allocated_by_dev, all_devices=all_devs, type_totals=type_totals, + type_counts=type_counts, ) def free(self, ob: ObjOrTuple) -> None: @@ -83,28 +69,26 @@ def free(self, ob: ObjOrTuple) -> None: affected: set[torch.device] = set() with self._lock: - # Apply deltas for dev, b in sizes.items(): if b <= 0: continue self._allocated_by_dev[dev] = max(0, self._allocated_by_dev.get(dev, 0) - b) self._freed_by_dev[dev] = self._freed_by_dev.get(dev, 0) + b affected.add(dev) - print(f"{GREEN}[free]{RESET} released {format_bytes(b)} on {dev}") + _log(f"{GREEN}[free]{RESET} released {format_bytes(b)} on {dev}") - # Build **complete** device set and per-type aggregates for freed all_devs = self._all_known_devices_locked() freed_type_totals = self._totals_by_type_locked(self._freed_by_dev) + type_counts = self._counts_by_type_locked() - # Print summaries outside lock self._print_full_device_summary( header=f"{CYAN}[free-summary]{RESET}", per_device_map=self._freed_by_dev, all_devices=all_devs, type_totals=freed_type_totals, + type_counts=type_counts, ) - # Auto-GC checks (may reset a specific device's freed counter) if self._auto_gc_bytes is not None and self._auto_gc_bytes > 0: for dev in affected: self._maybe_auto_gc(dev) @@ -115,26 +99,67 @@ def reset(self) -> None: self._freed_by_dev.clear() self._gc_count_by_dev.clear() self._gc_total_count = 0 - print(f"{MAGENTA}[reset]{RESET} counters cleared") + _log(f"{MAGENTA}[reset]{RESET} counters cleared") def allocated(self, device: torch.device | None = None) -> Tuple[int, str]: with self._lock: val = sum(self._allocated_by_dev.values()) if device is None else _sum_for_device(self._allocated_by_dev, device) - print(f"{CYAN}[allocated]{RESET} query={device}, result={format_bytes(val)}") + _log(f"{CYAN}[allocated]{RESET} query={device}, result={format_bytes(val)}") return val, format_bytes(val) def freed(self, device: torch.device | None = None) -> Tuple[int, str]: with self._lock: - val = sum(self._freed_by_dev.values()) if device is None else __sum_for_device(self._freed_by_dev, device) - print(f"{CYAN}[freed]{RESET} query={device}, result={format_bytes(val)}") + val = sum(self._freed_by_dev.values()) if device is None else _sum_for_device(self._freed_by_dev, device) + _log(f"{CYAN}[freed]{RESET} query={device}, result={format_bytes(val)}") return val, format_bytes(val) - def set_auto_gc(self, bytes_threshold: int | None) -> None: + def set_auto_gc(self, bytes_threshold: int | str | None) -> None: + self._resolve_and_set_auto_gc(bytes_threshold, context="set_auto_gc") + + # ---------- Auto threshold ---------- + def _resolve_and_set_auto_gc(self, val: int | str | None, context: str) -> None: + auto_requested = (val is None) or (isinstance(val, str) and val.lower() == "auto") + + if not auto_requested: + if not isinstance(val, int) or val < 0: + raise ValueError("auto_gc_bytes must be an int >= 0, 'auto', or None") + with self._lock: + self._auto_gc_bytes = val + _log(f"{YELLOW}[{context}]{RESET} auto_gc_bytes set to {format_bytes(val)} (explicit)") + return + + threshold, debug_msg = self._compute_auto_threshold() with self._lock: - self._auto_gc_bytes = bytes_threshold - print(f"{YELLOW}[set_auto_gc]{RESET} threshold={bytes_threshold}") + self._auto_gc_bytes = threshold - # ---------- Helpers ---------- + if threshold is None or threshold <= 0: + _log(f"{YELLOW}[{context}]{RESET} auto_gc_bytes: CUDA not available; auto-GC disabled. {debug_msg}") + else: + _log(f"{YELLOW}[{context}]{RESET} auto_gc_bytes (auto): {debug_msg} -> {format_bytes(threshold)}") + + def _compute_auto_threshold(self) -> tuple[int | None, str]: + try: + if not torch.cuda.is_available(): + return None, "torch.cuda.is_available() == False" + count = torch.cuda.device_count() + if count <= 0: + return None, "torch.cuda.device_count() == 0" + totals = [] + parts = [] + for i in range(count): + props = torch.cuda.get_device_properties(i) + total = int(getattr(props, "total_memory", 0)) + totals.append(total) + parts.append(f"{i}:{format_bytes(total)}") + if not totals: + return None, "No visible CUDA totals found" + min_total = min(totals) + threshold = min_total // 3 + return threshold, f"visible CUDA -> [{', '.join(parts)}]; min={format_bytes(min_total)}; min/3={format_bytes(threshold)}" + except Exception as e: + return None, f"auto detection error: {e}" + + # ---------- Memory accounting helpers ---------- def _sizes_for_many(self, ob: ObjOrTuple) -> Dict[torch.device, int]: agg: Dict[torch.device, int] = {} for item in self._iter_objs(ob): @@ -168,10 +193,6 @@ def _gather_tensors(self, ob: Obj) -> Iterable[torch.Tensor]: yield b def _sum_by_dev_dedup(self, tensors: Iterable[torch.Tensor]) -> Dict[torch.device, int]: - """ - Sum bytes by device (torch.device), deduping shared storages via (data_ptr, nbytes). - Counts sparse indices/values; ignores 'meta'. - """ seen_keys: set[tuple[int, int]] = set() by_dev: Dict[torch.device, int] = {} @@ -211,13 +232,9 @@ def _accumulate_dense(t: torch.Tensor) -> None: return by_dev - # ---- Device sets & summaries (BUGFIX area) ---- + # ---------- Summaries ---------- def _all_known_devices_locked(self) -> list[torch.device]: - """ - Build a sorted list of **all** devices we know about from both allocated and freed maps. - """ all_set = set(self._allocated_by_dev.keys()) | set(self._freed_by_dev.keys()) - # Sort by (type, index_or_-1) return sorted(all_set, key=lambda d: (d.type, -1 if d.index is None else d.index)) def _totals_by_type_locked(self, table: Dict[torch.device, int]) -> Dict[str, int]: @@ -226,28 +243,30 @@ def _totals_by_type_locked(self, table: Dict[torch.device, int]) -> Dict[str, in out[d.type] = out.get(d.type, 0) + v return out + def _counts_by_type_locked(self) -> Dict[str, int]: + counts: Dict[str, int] = {} + for d in set(self._allocated_by_dev.keys()) | set(self._freed_by_dev.keys()): + counts[d.type] = counts.get(d.type, 0) + 1 + return counts + def _print_full_device_summary( self, header: str, per_device_map: Dict[torch.device, int], all_devices: list[torch.device], type_totals: Dict[str, int], + type_counts: Dict[str, int], ) -> None: - """ - Print a full summary: - - one line per **device instance** (for every known device) - - one line per **device type** aggregate - """ - # Per-device lines (ALL known devices) + if not DEBUG_MODE: + return for dev in all_devices: val = per_device_map.get(dev, 0) print(f"{header} {dev}: {format_bytes(val)}") - - # Per-type aggregates for dtype in sorted(type_totals.keys()): - print(f"{header} {dtype}: {format_bytes(type_totals[dtype])}") + if type_counts.get(dtype, 0) > 1: + print(f"{header} {dtype}: {format_bytes(type_totals[dtype])}") - # ---- Auto-GC ---- + # ---------- Auto-GC ---------- def _maybe_auto_gc(self, dev: torch.device) -> None: threshold = self._auto_gc_bytes if threshold is None or threshold <= 0: @@ -267,7 +286,7 @@ def _maybe_auto_gc(self, dev: torch.device) -> None: per_dev_count = self._gc_count_by_dev[dev] total_count = self._gc_total_count - print(f"{YELLOW}[auto_gc]{RESET} {dev}: ran GC (count={per_dev_count}), total across devices={total_count}") + _log(f"{YELLOW}[auto_gc]{RESET} {dev}: ran GC (count={per_dev_count}), total across devices={total_count}") def _run_backend_gc(dev: torch.device) -> bool: @@ -306,7 +325,5 @@ def format_bytes(n: int) -> str: return f"{x:.2f} {u}" x /= 1024.0 - - -# default to auto gc interval for every ~256GB of freed memory -MEM_LORD = MemTracker(auto_gc_bytes=1024 * 1024 * 1024) \ No newline at end of file +# default to auto gc interval for every 8GB of freed memory +MEM_LORD = MemTracker(auto_gc_bytes="auto") \ No newline at end of file From a3c90db95e3556975404717ff5c467aa43922362 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 27 Sep 2025 05:22:50 +0000 Subject: [PATCH 6/6] cleanup Signed-off-by: Qubitium --- gptqmodel/looper/gptq_processor.py | 6 ++++++ gptqmodel/quantization/gptq.py | 2 ++ gptqmodel/utils/model.py | 7 ++++++- gptqmodel/utils/torch.py | 4 ++-- 4 files changed, 16 insertions(+), 3 deletions(-) diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index 973205a86..58fc5b40f 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -19,6 +19,7 @@ from ..quantization.config import METHOD, QuantizeConfig from ..utils.importer import select_quant_linear from ..utils.logger import setup_logger +from ..utils.memory import MEM_LORD from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module from ..utils.offload import undo_offload_to_disk from ..utils.torch import HAS_CUDA, torch_streamCtx, torch_sync @@ -126,6 +127,7 @@ def process(self, module: NamedModule): g = self.tasks[module.name] wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize() + MEM_LORD.free((q_scales, q_zeros, q_g_idx)) with self.lock: module.state.update({"q_scales": q_scales}) @@ -196,6 +198,7 @@ def process(self, module: NamedModule): "wq": wq, # fp16, quantized weight but not int4 (packed qweight) }) + MEM_LORD.free(module.weight) module.weight.data = wq # submodule_finalized is called in reverse after all next sequential processes are called @@ -248,6 +251,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): with self.lock: self.result_pop(module.full_name) + # MEM_LORD.free(module.weight) module.unregister_parameter("weight") def finalize(self, model: BaseQModel, **kwargs): @@ -256,6 +260,8 @@ def finalize(self, model: BaseQModel, **kwargs): torch_sync() model.model = undo_offload_to_disk(module=model.model, include_buffers=True, delete_offload_folders=True) + MEM_LORD.free(model.model) + # print("finalize") # print_module_tree(model.model) diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 2d3231ba0..61cda7533 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -25,6 +25,7 @@ from ..utils.torch import HAS_CUDA, HAS_XPU, device_next from .gar import compose_final_perm, compute_global_perm, compute_local_perms, invert_perm from .quantizer import HF_OPTIMUM, Quantizer +from ..utils.memory import MEM_LORD log = setup_logger() @@ -522,6 +523,7 @@ def quantize( avg_loss = 999999999 del Losses + MEM_LORD.free(self.H) del self.H group_size = self.qcfg.group_size if self.qcfg.group_size != -1 else self.columns diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 1987c2e8b..90f887a6f 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -48,6 +48,7 @@ from .importer import select_quant_linear from .logger import setup_logger from .torch import torch_empty_cache, torch_new_stream_ctx +from ..utils.memory import MEM_LORD log = setup_logger() @@ -79,6 +80,7 @@ def recurse_setattr(module, name, value): def move_to(obj: torch.Tensor | nn.Module, device: torch.device, dtype: torch.dtype = None, stream: bool = False): if get_device(obj) != device: + MEM_LORD.free(obj) if stream: # we cannot support changing dtype and stream at the same time assert dtype is None, f"streaming does not support changing dtype: actual = `{dtype}" @@ -584,12 +586,15 @@ def pack_module(name, qModules, q_scales, q_zeros, q_g_idx, layers, quant_linear layer = layers[name] module = qModules[name] + module = module.to(CPU) layer = layer.to(CPU) q_scales = q_scales.to(CPU) q_zeros = q_zeros.to(CPU) - q_g_idx = q_g_idx.to(CPU) if q_g_idx is not None else None + + if q_g_idx is not None: + q_g_idx = q_g_idx.to(CPU) with lock: layers[name] = layer diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index bc61ce941..9dbd46761 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -140,10 +140,10 @@ def torch_empty_cache(device: torch.device = None, gc: bool = True): # check all backends if device is None: if HAS_CUDA: - torch.cuda.synchronize() + # torch.cuda.synchronize() torch.cuda.empty_cache() if HAS_XPU: - torch.xpu.synchronize() + # torch.xpu.synchronize() torch.xpu.empty_cache() if HAS_MPS: torch.mps.empty_cache()