diff --git a/gptqmodel/quantization/gptqv2.py b/gptqmodel/quantization/gptqv2.py
index b40969c9f..77deea81f 100644
--- a/gptqmodel/quantization/gptqv2.py
+++ b/gptqmodel/quantization/gptqv2.py
@@ -33,6 +33,11 @@ def __init__(self, module: NamedModule, qcfg: Optional[QuantizeConfig] = None):
         self.native_inps = module.state.pop(NATIVE_INPUTS_STATE_KEY)
 
+    def add_batch(self, inp: torch.Tensor, out: torch.Tensor, batch_index: Optional[int] = None):
+        with self.lock:
+            self.fwd_counter += 1
+            self.process_batch(inp)
+
     # TODO FIXME: using v1 new process_batch kills v2 quantization quality, use original process_batch
     # sample counter based on batch request
     # instead of batched token #.
     # def process_batch(self, inp):