From c6dd79cf4e873b58ce41348b200ff89b140a56ba Mon Sep 17 00:00:00 2001
From: LRL2-ModelCloud
Date: Fri, 31 Oct 2025 13:54:00 +0800
Subject: [PATCH 1/2] fix device_map

---
 gptqmodel/models/loader.py | 46 +++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 41 insertions(+), 5 deletions(-)

diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py
index 017d0719f..1901baa54 100644
--- a/gptqmodel/models/loader.py
+++ b/gptqmodel/models/loader.py
@@ -12,6 +12,7 @@
 import torch
 import transformers
+from itertools import chain
 
 from ..utils.structure import print_module_tree
 
@@ -561,9 +562,6 @@ def skip(*args, **kwargs):
         )
 
-
-import torch
-
 
 def build_layerwise_device_map(
     model,
     device,
@@ -625,15 +623,35 @@ def assign(mod, device_id):
     assign(in_emb, device_ids[0])
 
     # Alternating layers
+    layer_name2devid: Dict[str, int] = {}
+
     for i, layer in enumerate(layers):
         gpu = device_ids[i % num_gpus]
         assign(layer, gpu)
+        lname = mod2name.get(layer)
+        if lname is not None:
+            layer_name2devid[lname] = gpu
     # Ignored modules - skip input embeddings to avoid overriding GPU 0 assignment
     for mod in ignore_modules:
         if in_emb is not None and mod is in_emb:
             continue  # Skip input embedding to preserve GPU 0 assignment
-        assign(mod, device_ids[-1])
+        name = mod2name.get(mod)
+        if name is None:
+            continue
+        # walk up ancestors to find the nearest repeating layer
+        owner = name
+        dev_id = None
+        while owner:
+            if owner in layer_name2devid:
+                dev_id = layer_name2devid[owner]
+                break
+            if "." not in owner:
+                break
+            owner = owner.rsplit(".", 1)[0]
+        if dev_id is None:
+            dev_id = device_ids[-1]
+        assign(mod, dev_id)
 
     # -------------------------------------------------------------
     # 4. Handle lm_head / output projection explicitly
     # -------------------------------------------------------------
@@ -660,7 +678,7 @@ def assign(mod, device_id):
     # 5. Safety check: ensure all params are covered
     # -------------------------------------------------------------
     missing = [
-        n for n, _ in model.named_parameters()
+        n for n, _ in chain(model.named_parameters(), model.named_buffers())
         if not any(n == k or n.startswith(k + ".") for k in device_map)
     ]
     module_names = set(mod2name.values())
@@ -694,6 +712,24 @@ def assign(mod, device_id):
             log.info(f"Loader: dropping parent '{name}' from device_map to preserve child placements.")
             device_map.pop(name, None)
 
+    missing_after_prune = [
+        n for n, _ in chain(model.named_parameters(), model.named_buffers())
+        if not any(n == k or n.startswith(k + ".") for k in device_map)
+    ]
+    if missing_after_prune:
+        fallback_device = device_ids[-1]
+        for param_name in missing_after_prune:
+            owner = param_name
+            while owner and owner not in module_names:
+                if "." not in owner:
+                    owner = ""
+                else:
+                    owner = owner.rsplit(".", 1)[0]
+            if owner:
+                device_map.setdefault(owner, device_strs[fallback_device])
+            else:
+                log.info(f"Loader: unable to map param '{param_name}' to a module; skipping fallback assignment.")
+
     # optional logging for debug
     log.info(f"Loader: Built map across {num_gpus} GPU(s), "
              f"{len(device_map)} entries. First 8: {list(device_map.items())[:8]}")

From e753d188864a133177fb661b6b637660a37a673b Mon Sep 17 00:00:00 2001
From: LRL2-ModelCloud
Date: Fri, 31 Oct 2025 15:21:19 +0800
Subject: [PATCH 2/2] add comments

---
 gptqmodel/models/loader.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py
index 1901baa54..e1dd85b98 100644
--- a/gptqmodel/models/loader.py
+++ b/gptqmodel/models/loader.py
@@ -633,13 +633,16 @@ def assign(mod, device_id):
             layer_name2devid[lname] = gpu
     # Ignored modules - skip input embeddings to avoid overriding GPU 0 assignment
+    # Iterate over modules that should be ignored during default layer-wise mapping
     for mod in ignore_modules:
+        # Preserve GPU-0 placement for the input embedding module if it exists
         if in_emb is not None and mod is in_emb:
             continue  # Skip input embedding to preserve GPU 0 assignment
+        # Retrieve the module's fully-qualified name
         name = mod2name.get(mod)
         if name is None:
             continue
-        # walk up ancestors to find the nearest repeating layer
+        # Walk up the module hierarchy to find the closest ancestor that already has a device assignment
         owner = name
         dev_id = None
         while owner:
             if owner in layer_name2devid:
                 dev_id = layer_name2devid[owner]
                 break
             if "." not in owner:
                 break
             owner = owner.rsplit(".", 1)[0]
+        # If no ancestor is found, fall back to the last GPU
         if dev_id is None:
             dev_id = device_ids[-1]
+        # Assign the current module to the determined device
         assign(mod, dev_id)
-
     # -------------------------------------------------------------
     # 4. Handle lm_head / output projection explicitly
     # -------------------------------------------------------------
@@ -712,24 +716,27 @@ def assign(mod, device_id):
             log.info(f"Loader: dropping parent '{name}' from device_map to preserve child placements.")
             device_map.pop(name, None)
 
+    # Collect parameters/buffers that were not assigned to any device in the current device_map
     missing_after_prune = [
         n for n, _ in chain(model.named_parameters(), model.named_buffers())
        if not any(n == k or n.startswith(k + ".") for k in device_map)
     ]
+    # If any tensors remain unmapped, assign them to the last GPU as a fallback
     if missing_after_prune:
         fallback_device = device_ids[-1]
         for param_name in missing_after_prune:
+            # Walk up the module tree until we find a module name that exists in module_names
             owner = param_name
             while owner and owner not in module_names:
                 if "." not in owner:
                     owner = ""
                 else:
                     owner = owner.rsplit(".", 1)[0]
+            # Map the closest owning module to the fallback device
             if owner:
                 device_map.setdefault(owner, device_strs[fallback_device])
             else:
                 log.info(f"Loader: unable to map param '{param_name}' to a module; skipping fallback assignment.")
-
     # optional logging for debug
     log.info(f"Loader: Built map across {num_gpus} GPU(s), "
              f"{len(device_map)} entries. First 8: {list(device_map.items())[:8]}")
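
Note (not part of the patches): a minimal standalone sketch of the ancestor-walk fallback that PATCH 1/2 introduces, for reviewers who want to see the lookup behaviour in isolation. The module names, device ids, and the helper resolve_device below are made up for illustration; layer_name2devid mirrors the mapping the patch builds while assigning the repeating transformer layers.

# Hypothetical mapping from repeating-layer names to GPU ids, as built in the patch.
layer_name2devid = {
    "model.layers.0": 0,
    "model.layers.1": 1,
}

def resolve_device(name: str, fallback: int) -> int:
    # Walk up the dotted module name until an already-mapped ancestor is found;
    # otherwise fall back to the given device (the patch uses the last GPU).
    owner = name
    while owner:
        if owner in layer_name2devid:
            return layer_name2devid[owner]
        if "." not in owner:
            break
        owner = owner.rsplit(".", 1)[0]
    return fallback

print(resolve_device("model.layers.1.self_attn.rotary_emb", fallback=1))  # -> 1 (inherits layer 1's device)
print(resolve_device("model.norm", fallback=1))                           # -> 1 (no mapped ancestor, fallback)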