Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gptqmodel/nn_modules/qlinear/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,8 @@ def post_init(self, **kwargs):
dtype=t.int32,
).reshape(1, 3, 12).to(device=self.g_idx.device)

self.wf_unsqueeze_zero = wf.unsqueeze(0).to(device=self.g_idx.device)
self.wf_unsqueeze_neg_one = wf.unsqueeze(-1).to(device=self.g_idx.device)
self.register_buffer("wf_unsqueeze_zero", wf.unsqueeze(0).to(device=self.g_idx.device), persistent=False)
self.register_buffer("wf_unsqueeze_neg_one", wf.unsqueeze(-1).to(device=self.g_idx.device), persistent=False)

super().post_init(**kwargs)

Expand Down
4 changes: 2 additions & 2 deletions gptqmodel/nn_modules/qlinear/torch_fused.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class TorchFusedQuantLinear(PackableQuantLinear):
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1]

# optimized for XPU but should run on all
SUPPORTS_DEVICES = [DEVICE.XPU] # change this to XPU to limit to Intel XPU
SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU] # change this to XPU to limit to Intel XPU
SUPPORTS_PLATFORM = [PLATFORM.ALL]
SUPPORTS_PACK_DTYPES = [torch.int32]
SUPPORTS_ADAPTERS = [Lora]
Expand Down Expand Up @@ -174,7 +174,7 @@ def forward(self, x: torch.Tensor):
def _forward(self, x, out_shape):
num_itr = self.g_idx.shape[0] // x.shape[-1]

if not self.training and not self.transformed and TORCH_HAS_XPU_FUSED_OPS:
if not self.training and not self.transformed and TORCH_HAS_XPU_FUSED_OPS and "xpu" == x.device.type:
# one-time transform per module for xpu aten fused ops
self.transform(x.dtype)
self.transformed = True
Expand Down
10 changes: 5 additions & 5 deletions tests/test_kernel_output_torch_fused.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TestKernelOutput(unittest.TestCase):
BACKEND.TORCH_FUSED: TorchFusedQuantLinear,
}
target = 'model.layers.6.self_attn.v_proj'
device_map = "cpu"
device = "cpu"
m = [1, 16, 64, 256, 1024]
k = 2048
dtype = torch.float16
Expand All @@ -29,7 +29,7 @@ class TestKernelOutput(unittest.TestCase):

@classmethod
def setUp(self):
self.torch_model = GPTQModel.load(self.model_path, backend=BACKEND.TORCH, device_map=self.device_map, dtype=self.dtype)
self.torch_model = GPTQModel.load(self.model_path, backend=BACKEND.TORCH, device=self.device, dtype=self.dtype)
self.x = []
self.torch_kernel_outs = []
for dim_0 in self.m:
Expand Down Expand Up @@ -61,8 +61,8 @@ def assert_on_mismatch(self, a: Tensor, b: Tensor, rtol=0.00005, atol=0.00005):
(BACKEND.TORCH_FUSED, r_tolerance, a_tolerance),
])
def test_kernel_output(self, backend: BACKEND, r_tolerance: float, a_tolerance: float):
model = GPTQModel.load(self.model_path, backend=backend, device_map=self.device_map, dtype=self.dtype)
log.info(f"device_map: {self.device_map} ")
model = GPTQModel.load(self.model_path, backend=backend, device=self.device, dtype=self.dtype)
log.info(f"device: {self.device} ")
log.info(f"backend: {backend} ")
for i in range(len(self.x)):
out = self.forward(model, self.x[i], backend=backend)
Expand All @@ -75,7 +75,7 @@ class TestKernelOutputBFloat16(TestKernelOutput):

@unittest.skipUnless(hasattr(torch, "xpu") and torch.xpu.is_available(), reason="Test requires XPU")
class TestKernelOutputXPU(TestKernelOutput):
device_map = "xpu:0"
device = "xpu:0"
a_tolerance = 0.0005


Expand Down