diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 6e4a44268..dbc18b746 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -576,6 +576,9 @@ def cache_inputs(self, layers, calibration_data, use_cache): cur_layer_device = get_device(layers[0]) data_device = cur_layer_device + # make sure the turtle model is reloaded and ready for aliasing + self.gptq_model.wait_for_turtle_reload() + # TODO HookLinear add register_forward_pre_hook() def store_input_hook(module, args, kwargs): # Positional arguments.