diff --git a/tests/test_simple_quant.py b/tests/test_simple_quant.py index 2d7871d1c..944b0ac15 100644 --- a/tests/test_simple_quant.py +++ b/tests/test_simple_quant.py @@ -81,7 +81,7 @@ def get_calib_data(tokenizer, rows: int): with tempfile.TemporaryDirectory() as tmp_dir: results = GPTQModel.eval( QUANT_SAVE_PATH, - tasks=[EVAL.LM_EVAL.GSM8K_PLATINUM_COT], #, EVAL.LM_EVAL.GSM8K_PLATINUM_COT], + tasks=[EVAL.LM_EVAL.GSM8K_COT], #, EVAL.LM_EVAL.GSM8K_PLATINUM_COT], apply_chat_template=True, random_seed=898, output_path= tmp_dir,