From 541b47e317da805a586599a809594dfd9044eadc Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 3 Oct 2025 13:38:08 +0000 Subject: [PATCH 1/5] fix awq_gemm not passing register_buffers Signed-off-by: Qubitium --- gptqmodel/nn_modules/qlinear/__init__.py | 3 +- gptqmodel/nn_modules/qlinear/awq_gemm.py | 2 + gptqmodel/nn_modules/qlinear/awq_gemv.py | 55 ++++++------ gptqmodel/nn_modules/qlinear/awq_gemv_fast.py | 55 ++++++------ gptqmodel/nn_modules/qlinear/awq_marlin.py | 84 ++++++++++--------- tests/models/model_test.py | 3 +- tests/models/test_llama3_2_awq.py | 38 +++++++++ tests/test_awq.py | 4 +- 8 files changed, 146 insertions(+), 98 deletions(-) create mode 100644 tests/models/test_llama3_2_awq.py diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 69a5dbb8f..9379f9fee 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -789,7 +789,6 @@ class AWQuantLinear(BaseQuantLinear): def __init__(self, bias: bool = False, use_bf16: bool = False, - register_awq_buffers: bool = True, register_buffers: bool = False, **kwargs): super().__init__(bias=bias, register_buffers=False, **kwargs) @@ -799,7 +798,7 @@ def __init__(self, in_features = self.in_features out_features = self.out_features - if register_awq_buffers: + if register_buffers: self.register_buffer( "qweight", t.zeros((in_features, out_features // (self.pack_dtype_bits // self.bits)), dtype=self.pack_dtype), diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm.py b/gptqmodel/nn_modules/qlinear/awq_gemm.py index 58f7943d4..913c69568 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemm.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemm.py @@ -49,6 +49,7 @@ def __init__( bias: bool = False, pack_dtype: torch.dtype = torch.int32, adapter: Adapter = None, + register_buffers: bool = False, **kwargs, ): super().__init__( @@ -62,6 +63,7 @@ def __init__( pack_dtype=pack_dtype, backend=kwargs.pop("backend", BACKEND.GEMM), adapter=adapter, + register_buffers=register_buffers, **kwargs) def post_init(self): diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv.py b/gptqmodel/nn_modules/qlinear/awq_gemv.py index 46fbbe41f..d754aa4cb 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemv.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemv.py @@ -50,8 +50,10 @@ def __init__( bias: bool = False, pack_dtype: torch.dtype = torch.int32, adapter: Adapter = None, + register_buffers: bool = False, **kwargs, ): + backend = kwargs.pop("backend", BACKEND.GEMV) super().__init__( bits=bits, group_size=group_size, @@ -61,38 +63,39 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.GEMV), + backend=backend, adapter=adapter, - register_awq_buffers=False, + register_buffers=False, **kwargs) self.split_k_iters = 8 - self.register_buffer( - "qweight", - torch.zeros((out_features, in_features // self.pack_factor), dtype=self.pack_dtype), - ) - self.register_buffer( - "qzeros", - torch.zeros( - out_features, - calculate_zeros_width(in_features, self.group_size), - dtype=self.pack_dtype, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - out_features, - calculate_zeros_width(in_features, self.group_size) * self.pack_factor, - dtype=torch.float16, - ), - ) + self.bias = None - if bias: - self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float16)) - else: - self.bias = None + if register_buffers: + self.register_buffer( + "qweight", + torch.zeros((out_features, in_features // self.pack_factor), 
dtype=self.pack_dtype), + ) + self.register_buffer( + "qzeros", + torch.zeros( + out_features, + calculate_zeros_width(in_features, self.group_size), + dtype=self.pack_dtype, + ), + ) + self.register_buffer( + "scales", + torch.zeros( + out_features, + calculate_zeros_width(in_features, self.group_size) * self.pack_factor, + dtype=torch.float16, + ), + ) + + if bias: + self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float16)) def post_init(self): # if self.padded_infeatures != self.in_features: diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py b/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py index 3649b6900..79548e56b 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py @@ -50,8 +50,10 @@ def __init__( bias: bool = False, pack_dtype: torch.dtype = torch.int32, adapter: Adapter = None, + register_buffers: bool = False, **kwargs, ): + backend = kwargs.pop("backend", BACKEND.GEMV_FAST) super().__init__( bits=bits, group_size=group_size, @@ -61,9 +63,9 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.GEMV_FAST), + backend=backend, adapter=adapter, - register_awq_buffers=False, + register_buffers=False, **kwargs) self.split_k_iters = 8 @@ -71,31 +73,32 @@ def __init__( int32_pack_factor = 32 // self.bits - self.register_buffer( - "qweight", - torch.zeros((out_features // self.interleave, in_features // self.pack_factor * self.interleave), dtype=self.pack_dtype), - ) - self.register_buffer( - "qzeros", - torch.zeros( - calculate_zeros_width(in_features, self.group_size) * int32_pack_factor, - out_features, - dtype=torch.float16, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - calculate_zeros_width(in_features, self.group_size) * int32_pack_factor, - out_features, - dtype=torch.float16, - ), - ) + self.bias = None - if bias: - self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float16)) - else: - self.bias = None + if register_buffers: + self.register_buffer( + "qweight", + torch.zeros((out_features // self.interleave, in_features // self.pack_factor * self.interleave), dtype=self.pack_dtype), + ) + self.register_buffer( + "qzeros", + torch.zeros( + calculate_zeros_width(in_features, self.group_size) * int32_pack_factor, + out_features, + dtype=torch.float16, + ), + ) + self.register_buffer( + "scales", + torch.zeros( + calculate_zeros_width(in_features, self.group_size) * int32_pack_factor, + out_features, + dtype=torch.float16, + ), + ) + + if bias: + self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float16)) def post_init(self): # if self.padded_infeatures != self.in_features: diff --git a/gptqmodel/nn_modules/qlinear/awq_marlin.py b/gptqmodel/nn_modules/qlinear/awq_marlin.py index 586f98436..c964ea98c 100644 --- a/gptqmodel/nn_modules/qlinear/awq_marlin.py +++ b/gptqmodel/nn_modules/qlinear/awq_marlin.py @@ -76,6 +76,7 @@ def __init__( bias: bool = False, pack_dtype: torch.dtype = torch.int32, adapter: Adapter = None, + register_buffers=False, **kwargs): if marlin_import_exception is not None: raise ValueError( @@ -95,54 +96,55 @@ def __init__( pack_dtype=pack_dtype, backend=kwargs.pop("backend", BACKEND.MARLIN), adapter=adapter, - register_awq_buffers=False, + register_buffers=False, **kwargs) - self.register_parameter( - "qweight", - torch.nn.Parameter( - torch.empty( - self.in_features, - self.out_features // self.pack_factor, - dtype=torch.int32, + if register_buffers: + 
self.register_parameter( + "qweight", + torch.nn.Parameter( + torch.empty( + self.in_features, + self.out_features // self.pack_factor, + dtype=torch.int32, + ), + requires_grad=False ), - requires_grad=False - ), - ) - self.register_parameter( - "qzeros", - torch.nn.Parameter( - torch.empty( - self.in_features // self.group_size, - self.out_features // self.pack_factor, - dtype=torch.int32, - ), - requires_grad=False ) - ) - - self.register_parameter( - "scales", - torch.nn.Parameter( - torch.empty( - self.in_features // self.group_size, - self.out_features, - dtype=torch.float16, - ), - requires_grad=False + self.register_parameter( + "qzeros", + torch.nn.Parameter( + torch.empty( + self.in_features // self.group_size, + self.out_features // self.pack_factor, + dtype=torch.int32, + ), + requires_grad=False + ) ) - ) - if bias: - self.register_buffer( - "bias", - torch.zeros( - (out_features), - dtype=torch.float16, - ), + self.register_parameter( + "scales", + torch.nn.Parameter( + torch.empty( + self.in_features // self.group_size, + self.out_features, + dtype=torch.float16, + ), + requires_grad=False + ) ) - else: - self.bias = None + + if bias: + self.register_buffer( + "bias", + torch.zeros( + (out_features), + dtype=torch.float16, + ), + ) + else: + self.bias = None self.is_lm_head = False if kwargs.get("name") is not None and kwargs.get("lm_head_name") is not None: diff --git a/tests/models/model_test.py b/tests/models/model_test.py index e88cdc6cf..d3345ade9 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -222,7 +222,8 @@ def run_arc_challenge_eval(self, model, backend, trust_remote_code=False): def perform_post_quant_validation(self, model_path, trust_remote_code=False): inference_records = {} arc_records = {} - for backend in (BACKEND.MARLIN, BACKEND.TORCH): + compare_backends = (BACKEND.MARLIN, BACKEND.TORCH) if self.FORMAT is FORMAT.GPTQ else (BACKEND.MARLIN, BACKEND.GEMM) + for backend in compare_backends: log.info(f"Loading post-quant model with backend `{backend.name}`") model = self.loadQuantModel( model_path, diff --git a/tests/models/test_llama3_2_awq.py b/tests/models/test_llama3_2_awq.py new file mode 100644 index 000000000..35a07728a --- /dev/null +++ b/tests/models/test_llama3_2_awq.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from model_test import ModelTest +from gptqmodel.quantization import METHOD, FORMAT + + +# a100:0 +# desc_act = False, act_group_aware = False 0.2500/0.2841 +# desc_act = False, act_group_aware = True 0.3063/0.3456 +# desc_act = True, 0.3089/0.3328 +class TestLlama3_2(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" + NATIVE_ARC_CHALLENGE_ACC = 0.3234 + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3524 + QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36 + APPLY_CHAT_TEMPLATE = True + V2 = False + DEBUG = True + ACT_GROUP_AWARE = False + DESC_ACT = False + DATASET_SIZE = 1024 + DATASET_SORT = "desc" + QUANT_BATCH_SIZE = 4 + FORMAT = FORMAT.GEMM + METHOD = METHOD.AWQ + # USE_FLASH_ATTN = False + # EORA = Lora( + # # for quant, path is save path. 
for load, it is loading path + # path="./eora_test", + # rank=128, + # ) + # b1 = 0.315, b4 = 0.3106, b8 = 0.3148, b32 = 0.3148, b16 = 0.3234 + + def test_llama3_2(self): + self.quant_lm_eval() diff --git a/tests/test_awq.py b/tests/test_awq.py index 42a612690..ed65b8506 100644 --- a/tests/test_awq.py +++ b/tests/test_awq.py @@ -40,7 +40,7 @@ def setUpClass(self): traindata = load_dataset("json", data_files="/monster/data/model/dataset/c4-train.00000-of-01024.json.gz", split="train") - self.calibration_dataset = [self.tokenizer(example["text"]) for example in traindata.select(range(1024))] + self.calibration_dataset = traindata.select(range(1024)) # def test_load_group_128(self): # model = GPTQModel.load( @@ -71,7 +71,7 @@ def test_quant_and_inference(self, checkpoint_format, backend, group_size: int): quantize_config=quantize_config, debug=True, ) - model.quantize(self.calibration_dataset, batch_size=1, calibration_concat_size=2048) + model.quantize(self.calibration_dataset, batch_size=1, calibration_concat_size=0) with tempfile.TemporaryDirectory() as tmp_dir_name: model.save(tmp_dir_name) From 2c73146079b6644daeb19b199da71617d41d5321 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 3 Oct 2025 14:04:48 +0000 Subject: [PATCH 2/5] patch Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 10 +++++++++- gptqmodel/nn_modules/qlinear/awq_gemm.py | 3 ++- gptqmodel/utils/model.py | 2 +- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index c718870f5..a7bfdd073 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -17,6 +17,7 @@ import threading import time +from contextlib import nullcontext from typing import Dict, List, Optional import torch @@ -711,7 +712,14 @@ def loop(self, fail_safe: bool = False, **kwargs): processor=processor, fail_safe=fail_safe) named_childs.update(named_modules) - processor.layer_quantize(module, cur_layer_device, named_childs) + + lock_ctx = nullcontext() + device_for_ctx = cur_layer_device if getattr(cur_layer_device, 'type', None) != 'meta' else None + if device_for_ctx is not None: + lock_ctx = self.pool.read_lock(cur_layer_device) + with lock_ctx: + with device_ctx(device_for_ctx): + processor.layer_quantize(module, cur_layer_device, named_childs) continue layer_inputs = processor.inputs_cache.layer_inputs diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm.py b/gptqmodel/nn_modules/qlinear/awq_gemm.py index 913c69568..fac0bedb5 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemm.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemm.py @@ -78,7 +78,8 @@ def post_init(self): # device=self.g_idx.device) # awq only accepts float16 - self.scales = self.scales.to(dtype=torch.float16) + if self.scales: + self.scales = self.scales.to(dtype=torch.float16) super().post_init() diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 3dd7275ea..5a3b4da90 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -973,7 +973,7 @@ def gptqmodel_post_init(model, use_act_order: bool, quantize_config: QuantizeCon # The buffers need to have been initialized first before calling make_q4. 
for _, submodule in model.named_modules(): - if isinstance(submodule, ExllamaV2QuantLinear) or isinstance(submodule, AwqExllamaV2QuantLinear): + if isinstance(submodule, (ExllamaV2QuantLinear, AwqExllamaV2QuantLinear)): device = submodule.qweight.device submodule.post_init(scratch_space=model.device_tensors[device]) elif isinstance(submodule, BaseQuantLinear): From 236822995d82658cf6a5274205c938198ddbfdba Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 3 Oct 2025 14:10:23 +0000 Subject: [PATCH 3/5] use new multi ctx helper Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 6 +- gptqmodel/utils/ctx.py | 34 +++++++++ tests/test_ctx.py | 119 ++++++++++++++++++++++++++++++ 3 files changed, 156 insertions(+), 3 deletions(-) create mode 100644 gptqmodel/utils/ctx.py create mode 100644 tests/test_ctx.py diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index a7bfdd073..b534f7416 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -33,6 +33,7 @@ from ..nn_modules.hooked_linear import (STOP_FORWARD_EXCEPTION, HookedLinear, StopForward, replace_module_with_hooked_legacy) from ..utils.attn_mask import apply_keep_mask_bt, normalize_seq_mask +from ..utils.ctx import ctx from ..utils.device import get_device, get_device_new from ..utils.logger import setup_logger from ..utils.looper_helpers import ( @@ -717,9 +718,8 @@ def loop(self, fail_safe: bool = False, **kwargs): device_for_ctx = cur_layer_device if getattr(cur_layer_device, 'type', None) != 'meta' else None if device_for_ctx is not None: lock_ctx = self.pool.read_lock(cur_layer_device) - with lock_ctx: - with device_ctx(device_for_ctx): - processor.layer_quantize(module, cur_layer_device, named_childs) + with ctx(lock_ctx, device_ctx(device_for_ctx)): + processor.layer_quantize(module, cur_layer_device, named_childs) continue layer_inputs = processor.inputs_cache.layer_inputs diff --git a/gptqmodel/utils/ctx.py b/gptqmodel/utils/ctx.py new file mode 100644 index 000000000..b06abf9ef --- /dev/null +++ b/gptqmodel/utils/ctx.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +from contextlib import AbstractContextManager, ExitStack, contextmanager +from typing import Any, Iterator + +ContextArg = AbstractContextManager[Any] | None + + +@contextmanager +def ctx(*contexts: ContextArg) -> Iterator[Any | tuple[Any, ...] | None]: + """Enter each context in ``contexts`` and yield their ``__enter__`` values. + + The helper lets callers replace nested ``with`` blocks with ``with ctx(...)`` + while gracefully ignoring ``None`` entries. 
+ """ + with ExitStack() as stack: + entered: list[Any] = [] + for context in contexts: + if context is None: + continue + entered.append(stack.enter_context(context)) + + if not entered: + yield None + elif len(entered) == 1: + yield entered[0] + else: + yield tuple(entered) + diff --git a/tests/test_ctx.py b/tests/test_ctx.py new file mode 100644 index 000000000..0de1f9114 --- /dev/null +++ b/tests/test_ctx.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +import threading +from concurrent.futures import ThreadPoolExecutor +from contextlib import AbstractContextManager +from typing import Any, List, Tuple + +import pytest + +from gptqmodel.utils.ctx import ctx + + +class RecordingContext(AbstractContextManager[Any]): + def __init__(self, name: str, log: List[Tuple[str, str, int]]): + self.name = name + self.log = log + + def __enter__(self) -> str: + self.log.append(("enter", self.name, threading.get_ident())) + return self.name + + def __exit__(self, exc_type, exc, tb) -> None: + self.log.append(("exit", self.name, threading.get_ident())) + return False + + +def test_ctx_enters_in_order_and_exits_in_reverse() -> None: + log: List[Tuple[str, str, int]] = [] + contexts = [RecordingContext("first", log), RecordingContext("second", log), RecordingContext("third", log)] + + with ctx(*contexts) as values: + assert values == ("first", "second", "third") + assert log == [ + ("enter", "first", threading.get_ident()), + ("enter", "second", threading.get_ident()), + ("enter", "third", threading.get_ident()), + ] + + assert log[3:] == [ + ("exit", "third", threading.get_ident()), + ("exit", "second", threading.get_ident()), + ("exit", "first", threading.get_ident()), + ] + + +def test_ctx_skips_none_and_returns_single_value() -> None: + log: List[Tuple[str, str, int]] = [] + single = RecordingContext("solo", log) + + with ctx(None, single, None) as value: + assert value == "solo" + assert log == [("enter", "solo", threading.get_ident())] + + assert log[1:] == [("exit", "solo", threading.get_ident())] + + +def test_ctx_returns_none_when_no_contexts() -> None: + with ctx() as value: + assert value is None + + +def test_ctx_exits_all_contexts_when_exception_raised() -> None: + log: List[Tuple[str, str, int]] = [] + contexts = [RecordingContext("first", log), RecordingContext("second", log)] + + class ExpectedError(RuntimeError): + pass + + with pytest.raises(ExpectedError): + with ctx(*contexts): + raise ExpectedError() + + assert log == [ + ("enter", "first", threading.get_ident()), + ("enter", "second", threading.get_ident()), + ("exit", "second", threading.get_ident()), + ("exit", "first", threading.get_ident()), + ] + + +def test_ctx_threaded_execution_respects_order(monkeypatch) -> None: + monkeypatch.setenv("PYTHON_GIL", "0") + + def worker(thread_index: int) -> List[Tuple[str, str, int]]: + log: List[Tuple[str, str, int]] = [] + contexts = [ + RecordingContext(f"c{thread_index}-1", log), + None, + RecordingContext(f"c{thread_index}-2", log), + ] + with ctx(*contexts) as values: + assert values == (f"c{thread_index}-1", f"c{thread_index}-2") + return log + + with ThreadPoolExecutor(max_workers=4) as executor: + logs = list(executor.map(worker, range(4))) + + for idx, log in enumerate(logs): + entries = [item for item in log if item[0] == "enter"] + exits = [item for item in log if 
item[0] == "exit"] + assert len(entries) == 2 + assert len(exits) == 2 + enter_thread = entries[0][2] + exit_thread = exits[0][2] + assert enter_thread == exit_thread + assert entries == [ + ("enter", f"c{idx}-1", enter_thread), + ("enter", f"c{idx}-2", enter_thread), + ] + assert exits == [ + ("exit", f"c{idx}-2", exit_thread), + ("exit", f"c{idx}-1", exit_thread), + ] + From 52fdf5b83e23a5dae889910414cbea1f649cc549 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 3 Oct 2025 14:15:29 +0000 Subject: [PATCH 4/5] none check Signed-off-by: Qubitium --- gptqmodel/nn_modules/qlinear/awq_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm.py b/gptqmodel/nn_modules/qlinear/awq_gemm.py index fac0bedb5..af88aa237 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemm.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemm.py @@ -78,7 +78,7 @@ def post_init(self): # device=self.g_idx.device) # awq only accepts float16 - if self.scales: + if self.scales is not None: self.scales = self.scales.to(dtype=torch.float16) super().post_init() From 03b786753b7730d1053096e4eeb9c437ff457904 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 3 Oct 2025 14:18:12 +0000 Subject: [PATCH 5/5] use ctx Signed-off-by: Qubitium --- gptqmodel/looper/awq_processor.py | 18 +++---- gptqmodel/looper/module_looper.py | 14 +++--- gptqmodel/utils/model.py | 83 +++++++++++++++---------------- gptqmodel/utils/threadx.py | 16 +++--- 4 files changed, 65 insertions(+), 66 deletions(-) diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index 0f856996f..56de99ef5 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -28,6 +28,7 @@ from ..quantization.awq.utils.utils import get_best_device from ..quantization.config import FORMAT, METHOD, QuantizeConfig from ..utils.logger import setup_logger +from ..utils.ctx import ctx from ..utils.model import get_module_by_name_prefix, move_to from ..utils.torch import CPU, tf32_disable_guard, tf32_enable_guard, torch_sync @@ -295,17 +296,16 @@ def _search_best_scale( # [STEP 3]: Compute output of module module_kwargs = self._sanitize_kwargs(kwargs, module2inspect) - with torch.inference_mode(): - with tf32_enable_guard(): - fp16_output = self._module_forward(inp, module2inspect, module_kwargs) + with ctx(torch.inference_mode(), tf32_enable_guard()): + fp16_output = self._module_forward(inp, module2inspect, module_kwargs) - with tf32_disable_guard(): - fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max) + with tf32_disable_guard(): + fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max) - # [STEP 4]: Compute loss - best_scales, loss = self._compute_best_scale( - inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs - ) + # [STEP 4]: Compute loss + best_scales, loss = self._compute_best_scale( + inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs + ) return ( get_op_name(module, prev_op), diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index b534f7416..d1ef44fad 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -566,12 +566,14 @@ def store_input_hook(module, args, kwargs): example["attention_mask"] = example["attention_mask"].long() # Ensure initial caches (like RoPE) are created on the quant device - with self.pool.read_lock(self.gptq_model.quantize_config.device): - with 
device_ctx(self.gptq_model.quantize_config.device): - if self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS: - self.gptq_model.model.generate(**example, **self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS) - else: - self.gptq_model.model(**example, use_cache=use_cache) + with ctx( + self.pool.read_lock(self.gptq_model.quantize_config.device), + device_ctx(self.gptq_model.quantize_config.device), + ): + if self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS: + self.gptq_model.model.generate(**example, **self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS) + else: + self.gptq_model.model(**example, use_cache=use_cache) except StopForward: pass diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 5a3b4da90..5a31382c7 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -18,6 +18,7 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +from .ctx import ctx from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Type, Union @@ -687,24 +688,22 @@ def pack_module( quant_result: Optional[Dict[str, Any]] = None, ): # Limit pack() thread usage to avoid auto-parallizataion regression - with tctl.threadpool_limits(limits=1): - with lock: - layer = layers[name] - module = qModules[name] - + with ctx(tctl.threadpool_limits(limits=1), lock): + layer = layers[name] + module = qModules[name] - module = module.to(CPU) + module = module.to(CPU) - layer = layer.to(CPU) - q_scales = q_scales.to(CPU) - q_zeros = q_zeros.to(CPU) + layer = layer.to(CPU) + q_scales = q_scales.to(CPU) + q_zeros = q_zeros.to(CPU) - if q_g_idx is not None: - q_g_idx = q_g_idx.to(CPU) + if q_g_idx is not None: + q_g_idx = q_g_idx.to(CPU) - with lock: - layers[name] = layer - qModules[name] = module + with lock: + layers[name] = layer + qModules[name] = module # TODO FIX ME..remove hard coded qqq pack if quant_linear_cls.QUANT_TYPE == "qqq": @@ -784,24 +783,23 @@ def pack_model( else: max_packers = 1 # due to gil, there is no point packing with more than 1 thread - with ThreadPoolExecutor(max_workers=max_packers) as executor: - with log.pb(names).manual() as pb: - def wrapper(name): - # TODO FIX, thread pool executor does not advance iterator - pb.next() - pb.title(f"Packing {name}").draw() - pack_module( - name=name, - qModules=qModules, - quant_result=quant_result, - layers=modules, - quant_linear_cls=quant_linear_cls, - lock=lock, - quantize_config=qcfg, - ) + with ctx(ThreadPoolExecutor(max_workers=max_packers), log.pb(names).manual()) as (executor, pb): + def wrapper(name): + # TODO FIX, thread pool executor does not advance iterator + pb.next() + pb.title(f"Packing {name}").draw() + pack_module( + name=name, + qModules=qModules, + quant_result=quant_result, + layers=modules, + quant_linear_cls=quant_linear_cls, + lock=lock, + quantize_config=qcfg, + ) - for _ in executor.map(wrapper, names): - pass + for _ in executor.map(wrapper, names): + pass log.info("Model packed.") return quant_linear_cls @@ -1394,18 +1392,17 @@ def load_checkpoint_in_model_then_tie_weights(model, *args, **kwargs): _STREAM_BUFFER_LOCK = threading.Lock() def _copy_file_stream(src_path: str, dst_fh, length: int, *, offset: int = 0) -> None: - with open(src_path, "rb", buffering=0) as src: - with _STREAM_BUFFER_LOCK: - if offset: - src.seek(offset) - remaining = length - while remaining > 0: - chunk_size = min(_STREAM_BUFFER_SIZE, remaining) - read = src.readinto(_STREAM_BUFFER[:chunk_size]) - if not read: - raise IOError(f"Unexpected EOF while copying from {src_path}") - 
dst_fh.write(_STREAM_BUFFER[:read]) - remaining -= read + with ctx(open(src_path, "rb", buffering=0), _STREAM_BUFFER_LOCK) as (src, _): + if offset: + src.seek(offset) + remaining = length + while remaining > 0: + chunk_size = min(_STREAM_BUFFER_SIZE, remaining) + read = src.readinto(_STREAM_BUFFER[:chunk_size]) + if not read: + raise IOError(f"Unexpected EOF while copying from {src_path}") + dst_fh.write(_STREAM_BUFFER[:read]) + remaining -= read def _write_tensor_bytes(out, tensor: torch.Tensor, dtype: torch.dtype) -> None: diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index a1b20440e..dc3757938 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -16,6 +16,7 @@ from .. import DEBUG_ON from ..utils.logger import setup_logger +from ..utils.ctx import ctx log = setup_logger() @@ -356,15 +357,14 @@ def _run(self): use_inference = self._inference_mode if override_inference is None else bool(override_inference) # Tasks take a **read** lock so janitor's write lock can't interleave - with self.rwlock.reader(): - with _device_ctx(self.device): - inference_ctx = torch.inference_mode() if use_inference else contextlib.nullcontext() - with inference_ctx: - if stream is not None and self.device.type == "cuda": - with torch.cuda.stream(stream): - result = fn(*args, **kwargs) - else: + with ctx(self.rwlock.reader(), _device_ctx(self.device)): + inference_ctx = torch.inference_mode() if use_inference else contextlib.nullcontext() + with inference_ctx: + if stream is not None and self.device.type == "cuda": + with torch.cuda.stream(stream): result = fn(*args, **kwargs) + else: + result = fn(*args, **kwargs) # Counters must be updated before resolving futures to prevent # tests reading stats mid-transition and seeing stale totals. self._on_task_finished(self.key)
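
Note (illustrative, not part of the patches above): a minimal sketch of how the `ctx` helper added in PATCH 3/5 (gptqmodel/utils/ctx.py) collapses nested `with` blocks, which is the pattern PATCH 5/5 applies across module_looper.py, model.py, and threadx.py. The lock and flag names below are hypothetical; only `gptqmodel.utils.ctx.ctx` itself comes from this series.

    import threading

    import torch

    from gptqmodel.utils.ctx import ctx

    lock = threading.Lock()
    run_inference = True

    # Before this series: nested `with` blocks.
    with lock:
        with torch.inference_mode():
            x = torch.ones(4)

    # After: a single ctx(...) call. None arguments are skipped before entering,
    # so an optional context (like the per-device read lock in module_looper.py)
    # needs no if/else around the `with` statement.
    maybe_inference = torch.inference_mode() if run_inference else None
    with ctx(lock, maybe_inference) as entered:
        y = torch.ones(4)
        # `entered` is None if nothing was entered, a single __enter__ value for
        # one context, otherwise a tuple -- the behavior exercised by tests/test_ctx.py.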