diff --git a/gptqmodel/looper/dequantize_processor.py b/gptqmodel/looper/dequantize_processor.py index b6951cbdd..96d3cc1cf 100644 --- a/gptqmodel/looper/dequantize_processor.py +++ b/gptqmodel/looper/dequantize_processor.py @@ -9,6 +9,7 @@ from ..looper.loop_processor import LoopProcessor from ..looper.named_module import NamedModule +from ..models import BaseQModel from ..nn_modules.qlinear.torch import TorchQuantLinear from ..utils.logger import setup_logger @@ -47,7 +48,7 @@ def process(self, module: NamedModule): "wq": wq, }) - def submodule_finalize(self, module: NamedModule): + def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): module.state.pop("w", None) # no need for these weights now module.state.pop("wq", None) # no need for these weights now diff --git a/gptqmodel/looper/native_processor.py b/gptqmodel/looper/native_processor.py index 0c8de5bb0..01142929b 100644 --- a/gptqmodel/looper/native_processor.py +++ b/gptqmodel/looper/native_processor.py @@ -83,7 +83,7 @@ def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): def process(self, module: NamedModule): module.state[NATIVE_INPUTS_STATE_KEY] = self.native_inp_caches.pop(module.name) - def submodule_finalize(self, module: NamedModule): + def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): module.state.pop(NATIVE_INPUTS_STATE_KEY, None) def finalize(self, model: BaseQModel, **kwargs): diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index bc88cdfd9..e9d70bd70 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -922,7 +922,13 @@ def quantize( if self.quantize_config.v2 is True: from ..looper.native_processor import NativeProcessor - args_clone = copy.deepcopy(args) + + # copy.deepcopy(args) would also deep-copy the bound method self.prepare_dataset and, through it, self. However, # self holds a threading.RLock(), which cannot be pickled, so "prepare_dataset_func" is excluded from the copy and re-attached by reference. 
+ args_to_copy = {k: v for k, v in args.items() if k != "prepare_dataset_func"} + args_clone = copy.deepcopy(args_to_copy) + args_clone["prepare_dataset_func"] = args["prepare_dataset_func"] + args_clone.pop("calculate_w_wq_diff", None) quantize_processor.insert(0, NativeProcessor(**args_clone))