Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion gptqmodel/looper/dequantize_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ..looper.loop_processor import LoopProcessor
from ..looper.named_module import NamedModule
from ..models import BaseQModel
from ..nn_modules.qlinear.torch import TorchQuantLinear
from ..utils.logger import setup_logger

Expand Down Expand Up @@ -47,7 +48,7 @@ def process(self, module: NamedModule):
"wq": wq,
})

def submodule_finalize(self, module: NamedModule):
def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
module.state.pop("w", None) # no need for these weights now
module.state.pop("wq", None) # no need for these weights now

Expand Down
2 changes: 1 addition & 1 deletion gptqmodel/looper/native_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
def process(self, module: NamedModule):
module.state[NATIVE_INPUTS_STATE_KEY] = self.native_inp_caches.pop(module.name)

def submodule_finalize(self, module: NamedModule):
def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
module.state.pop(NATIVE_INPUTS_STATE_KEY, None)

def finalize(self, model: BaseQModel, **kwargs):
Expand Down
8 changes: 7 additions & 1 deletion gptqmodel/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,13 @@ def quantize(

if self.quantize_config.v2 is True:
from ..looper.native_processor import NativeProcessor
args_clone = copy.deepcopy(args)

# During the deepcopy process, self.prepare_dataset will be deeply copied along with self. However,
# self has a threading.RLock() , which is not serializable.
args_to_copy = {k: v for k, v in args.items() if k != "prepare_dataset_func"}
args_clone = copy.deepcopy(args_to_copy)
args_clone["prepare_dataset_func"] = args["prepare_dataset_func"]

args_clone.pop("calculate_w_wq_diff", None)
quantize_processor.insert(0, NativeProcessor(**args_clone))

Expand Down