diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 03ad2392a..891e4512b 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -450,6 +450,7 @@ def loop(self, fail_safe: bool = False, **kwargs):
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
             futures = []

+            @torch.inference_mode()
             def process_module(name, m):
                 # prevent cuda sync memory ctx bugs
                 m_device = get_device(m)
@@ -541,6 +542,7 @@ def process_module(name, m):
                 torch_sync()
         for reverse_p in reversed(self.processors):
             for name in processed_subset:
+                @torch.inference_mode()
                 def finalize_module(module):
                     # prevent cuda sync memory ctx bugs
                     m_device = get_device(module)
diff --git a/tests/test_qqq_inference.py b/tests/test_qqq_inference.py
index 283788c72..4bf5e1870 100644
--- a/tests/test_qqq_inference.py
+++ b/tests/test_qqq_inference.py
@@ -5,4 +5,4 @@
     framework=EVAL.LM_EVAL,
     tasks=[EVAL.LM_EVAL.ARC_CHALLENGE])

-print(f"{eval_results}")
\ No newline at end of file
+print(f"{eval_results}")
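
For context, here is a minimal sketch of why decorating the worker functions with `@torch.inference_mode()` matters, rather than relying on an `inference_mode()` context on the calling thread: PyTorch's grad/inference modes are thread-local, so a context entered on the thread that submits work does not apply inside `ThreadPoolExecutor` worker threads. The module, shapes, and names below are illustrative only, not taken from gptqmodel.

```python
from concurrent.futures import ThreadPoolExecutor

import torch

# Grad/inference modes are thread-local in PyTorch: a `with torch.inference_mode():`
# entered on the main thread does NOT cover code executing in executor workers.
# Decorating the function enables inference mode in whichever thread runs it.
@torch.inference_mode()
def process_module(name: str, m: torch.nn.Module) -> torch.Tensor:
    # Autograd is fully disabled here, even inside a worker thread.
    x = torch.randn(4, 8)
    out = m(x)
    assert not out.requires_grad  # inference mode is active in this thread
    return out

m = torch.nn.Linear(8, 8)
with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(process_module, f"mod{i}", m) for i in range(2)]
    results = [f.result() for f in futures]
print([r.shape for r in results])
```

This is the same reason the diff decorates both `process_module` and the per-module `finalize_module` closure: each may run on a pool thread where a caller-side context manager would have no effect.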