55 changes: 49 additions & 6 deletions gptqmodel/models/loader.py
@@ -12,6 +12,7 @@

import torch
import transformers
from itertools import chain

from ..utils.structure import print_module_tree

@@ -561,9 +562,6 @@ def skip(*args, **kwargs):
)



def build_layerwise_device_map(
model,
device,
@@ -625,16 +623,40 @@ def assign(mod, device_id):
assign(in_emb, device_ids[0])

# Alternating layers
layer_name2devid: Dict[str, int] = {}

for i, layer in enumerate(layers):
gpu = device_ids[i % num_gpus]
assign(layer, gpu)
lname = mod2name.get(layer)
if lname is not None:
layer_name2devid[lname] = gpu

# Ignored modules: map each to the device of its closest already-assigned ancestor
for mod in ignore_modules:
# Preserve GPU-0 placement for the input embedding module if it exists
if in_emb is not None and mod is in_emb:
continue

# Retrieve the module’s fully-qualified name
name = mod2name.get(mod)
if name is None:
continue
# Walk up the module hierarchy to find the closest ancestor that already has a device assignment
owner = name
dev_id = None
while owner:
if owner in layer_name2devid:
dev_id = layer_name2devid[owner]
break
if "." not in owner:
break
owner = owner.rsplit(".", 1)[0]
# If no ancestor is found, fall back to the last GPU
if dev_id is None:
dev_id = device_ids[-1]
# Assign the current module to the determined device
assign(mod, dev_id)
# -------------------------------------------------------------
# 4. Handle lm_head / output projection explicitly
# -------------------------------------------------------------
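Aside: the ignored-module loop above resolves each module's placement by walking up its dotted name until it hits an ancestor that already has a device assignment, falling back to the last GPU otherwise. A minimal standalone sketch of that resolution step (not part of this diff; resolve_device and the example names are illustrative only):

# Standalone sketch (not part of this diff): resolve an ignored module's device
# by walking up its dotted name until an ancestor with a known placement is found.
from typing import Dict, List

def resolve_device(name: str, layer_name2devid: Dict[str, int], device_ids: List[int]) -> int:
    owner = name
    while owner:
        if owner in layer_name2devid:
            return layer_name2devid[owner]   # closest assigned ancestor wins
        if "." not in owner:
            break
        owner = owner.rsplit(".", 1)[0]      # drop the last dotted component
    return device_ids[-1]                    # no assigned ancestor: last GPU

layer_map = {"model.layers.3": 1}
print(resolve_device("model.layers.3.post_attention_layernorm", layer_map, [0, 1]))  # 1 (follows its layer)
print(resolve_device("model.rotary_emb", layer_map, [0, 1]))                         # 1 (fallback to last GPU)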
@@ -660,7 +682,7 @@ def assign(mod, device_id):
# 5. Safety check: ensure all params are covered
# -------------------------------------------------------------
missing = [
n for n, _ in chain(model.named_parameters(), model.named_buffers())
if not any(n == k or n.startswith(k + ".") for k in device_map)
]
module_names = set(mod2name.values())
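The coverage check now iterates chain(model.named_parameters(), model.named_buffers()), so non-parameter tensors registered as buffers (rotary caches, norm statistics, and the like) are also verified against the device_map. A small illustration of why parameters alone are not enough (illustrative module, not from the PR):

# Illustrative only (not part of this diff): buffers are not parameters, so a
# coverage check over named_parameters() alone would miss them.
from itertools import chain

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)                      # weight & bias -> parameters
        self.register_buffer("inv_freq", torch.ones(2))  # rotary-style cache -> buffer

m = TinyBlock()
params = {n for n, _ in m.named_parameters()}                           # {'proj.weight', 'proj.bias'}
both = {n for n, _ in chain(m.named_parameters(), m.named_buffers())}
print(both - params)                                                    # {'inv_freq'}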
@@ -694,6 +716,27 @@ def assign(mod, device_id):
log.info(f"Loader: dropping parent '{name}' from device_map to preserve child placements.")
device_map.pop(name, None)

# Collect parameters/buffers that were not assigned to any device in the current device_map
missing_after_prune = [
n for n, _ in chain(model.named_parameters(), model.named_buffers())
if not any(n == k or n.startswith(k + ".") for k in device_map)
]
# If any tensors remain unmapped, assign them to the last GPU as a fallback
if missing_after_prune:
fallback_device = device_ids[-1]
for param_name in missing_after_prune:
# Walk up the module tree until we find a module name that exists in module_names
owner = param_name
while owner and owner not in module_names:
if "." not in owner:
owner = ""
else:
owner = owner.rsplit(".", 1)[0]
# Map the closest owning module to the fallback device
if owner:
device_map.setdefault(owner, device_strs[fallback_device])
else:
log.info(f"Loader: unable to map param '{param_name}' to a module; skipping fallback assignment.")
# Optional debug logging
log.info(f"Loader: Built map across {num_gpus} GPU(s), "
f"{len(device_map)} entries. First 8: {list(device_map.items())[:8]}")