diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 314eb8579..917800ecb 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -47,7 +47,7 @@ def __init__(self, model: BaseGPTQModel, processors: List[LoopProcessor]): self.support_batch_quantize = model.support_batch_quantize self.lock = threading.Lock() - def cache_inputs(self, layers, auto_gc, calibration_data, calibration_enable_gpu_cache): + def cache_inputs(self, layers, auto_gc, calibration_data, calibration_enable_gpu_cache, use_cache): layer_inputs = [] attention_masks = [] position_ids = [] @@ -123,7 +123,7 @@ def store_input_hook(module, args, kwargs): if str(type(layers[0])) == "": self.gptq_model.model.generate(**example, return_audio=False) else: - self.gptq_model.model(**example) + self.gptq_model.model(**example, use_cache=use_cache) except ValueError: pass self.gptq_model.pre_quantize_generate_hook_end() @@ -182,7 +182,8 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal input_cache = self.cache_inputs(layers=layers, auto_gc=auto_gc, calibration_data=processor.calibration_dataset, - calibration_enable_gpu_cache=calibration_enable_gpu_cache) + calibration_enable_gpu_cache=calibration_enable_gpu_cache, + use_cache=False) processor.receive_input_cache(input_cache) # release calibration_dataset