Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions gptqmodel/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ def skip(*args, **kwargs):

def build_layerwise_device_map(
model,
device,
layers: List[torch.nn.Module],
ignore_modules: List[torch.nn.Module],
num_gpus: Optional[int] = None,
Expand All @@ -581,10 +582,16 @@ def build_layerwise_device_map(
device_map: Dict[str, str] = {}
mod2name = {m: n for n, m in model.named_modules()}

if torch.cuda.is_available():
device_strs = [f"cuda:{i}" for i in range(num_gpus)]
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device_strs = [f"xpu:{i}" for i in range(num_gpus)]
if device == DEVICE.CUDA:
if torch.cuda.is_available():
device_strs = [f"cuda:{i}" for i in range(num_gpus)]
else:
raise RuntimeError("CUDA is not available")
elif device == DEVICE.XPU:
if hasattr(torch, "xpu") and torch.xpu.is_available():
device_strs = [f"xpu:{i}" for i in range(num_gpus)]
else:
raise RuntimeError("XPU is not available")
else:
device_strs = ["cpu"] * num_gpus

Expand Down Expand Up @@ -653,7 +660,7 @@ def assign(mod, device_id):
else:
owner = owner.rsplit(".", 1)[0]
if owner:
device_map.setdefault(owner, fallback_device)
device_map.setdefault(owner, device_strs[fallback_device])
else:
log.info(f"Loader: unable to map param '{param_name}' to a module; skipping fallback assignment.")

Expand Down Expand Up @@ -685,7 +692,7 @@ def assign(mod, device_id):
num_gpus = torch.cuda.device_count()
elif device is DEVICE.XPU:
num_gpus = torch.xpu.device_count()
device_map = build_layerwise_device_map(model, layers, ignore_modules, num_gpus)
device_map = build_layerwise_device_map(model, device, layers, ignore_modules, num_gpus)
log.info(f"Loader: device_map = {device_map}")

load_checkpoint_in_model = True
Expand Down