diff --git a/tests/test_inference_result_xpu.py b/tests/test_inference_result_xpu.py
index 3498ea45c..73dd7faa3 100644
--- a/tests/test_inference_result_xpu.py
+++ b/tests/test_inference_result_xpu.py
@@ -25,16 +25,15 @@ class TestInferenceResultXPU(ModelTest):
             (BACKEND.TORCH, DEVICE.XPU, False),
         ]
     )
-    def testTritonXPU(self, backend, device, template):
+    def testTorchXPU(self, backend, device, template):
         origin_model = GPTQModel.load(
             self.NATIVE_MODEL_ID,
-            quantize_config=QuantizeConfig(),
+            quantize_config=QuantizeConfig(device=device),
             backend=backend,
-            device=device,
         )
         tokenizer = self.load_tokenizer(self.NATIVE_MODEL_ID)
         calibration_dataset = self.load_dataset(tokenizer, rows=128)
-        origin_model.quantize(calibration_dataset, backend=BACKEND.TRITON)
+        origin_model.quantize(calibration_dataset, backend=BACKEND.TORCH)
 
         with tempfile.TemporaryDirectory() as tmpdir:
             origin_model.save(tmpdir)
diff --git a/tests/test_triton_xpu.py b/tests/test_torch_xpu.py
similarity index 92%
rename from tests/test_triton_xpu.py
rename to tests/test_torch_xpu.py
index bb761ab62..0547b79af 100644
--- a/tests/test_triton_xpu.py
+++ b/tests/test_torch_xpu.py
@@ -18,25 +18,25 @@ from gptqmodel.models._const import DEVICE # noqa: E402
 
 
-class TestTritonXPU(ModelTest):
+class TestTorchXPU(ModelTest):
     NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct"
 
     def test(self):
         origin_model = GPTQModel.load(
             self.NATIVE_MODEL_ID,
             quantize_config=QuantizeConfig(),
-            backend=BACKEND.TRITON,
+            backend=BACKEND.TORCH,
             device=DEVICE.XPU,
         )
         tokenizer = self.load_tokenizer(self.NATIVE_MODEL_ID)
         calibration_dataset = self.load_dataset(tokenizer, self.DATASET_SIZE)
-        origin_model.quantize(calibration_dataset, backend=BACKEND.TRITON)
+        origin_model.quantize(calibration_dataset, backend=BACKEND.TORCH)
 
         with tempfile.TemporaryDirectory() as tmpdir:
             origin_model.save(tmpdir)
 
             model = GPTQModel.load(
                 tmpdir,
-                backend=BACKEND.TRITON,
+                backend=BACKEND.TORCH,
                 device=DEVICE.XPU,
             )
             generate_str = tokenizer.decode(model.generate(**tokenizer("The capital of France is is", return_tensors="pt").to(model.device), max_new_tokens=2)[0])