From a31d47cd4be025ea1c7bc00d11d19bc537309c9c Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Thu, 25 Sep 2025 12:39:22 +0000
Subject: [PATCH 1/2] sync shell model with turtle before save

Signed-off-by: Qubitium
---
 gptqmodel/models/writer.py   |   9 +--
 gptqmodel/utils/offload.py   |   4 +-
 gptqmodel/utils/structure.py | 103 +++++++++++++++++++++++++++++++++++
 3 files changed, 108 insertions(+), 8 deletions(-)

diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py
index 0b197e767..5261d0b81 100644
--- a/gptqmodel/models/writer.py
+++ b/gptqmodel/models/writer.py
@@ -26,7 +26,6 @@
 from ..adapter.adapter import HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX, Lora
 from ..adapter.peft import LoraConfig
-from ..nn_modules.hooked_linear import HookedLinear
 from ..quantization.config import (FORMAT, META_FIELD_ACT_GROUP_AWARE, META_FIELD_DAMP_AUTO_INCREMENT,
                                    META_FIELD_DAMP_PERCENT, META_FIELD_MSE, META_FIELD_QUANTIZER,
                                    META_FIELD_STATIC_GROUPS, META_FIELD_TRUE_SEQUENTIAL, META_FIELD_URI,
@@ -36,6 +35,7 @@
 from ..utils.logger import setup_logger
 from ..utils.model import (convert_gptq_v2_to_v1_format, copy_py_files, find_modules, get_model_files_size,
                            get_state_dict_for_save, load_checkpoint_in_model_then_tie_weights, make_quant)
+from ..utils.structure import alias_all_from_turtle_if_meta
 from ..utils.torch import torch_empty_cache
 from ..version import __version__
 from ._const import DEFAULT_MAX_SHARD_SIZE
@@ -257,12 +257,9 @@ def debug_saved_config(path):
                 self.processor.save_pretrained(save_dir)
             # --- end config save block ---
 
-            # TODO FIX ME..remove this ugly patch and find core issue why output_embedding is not retied after offload/undo_offload
-            output_embed = model.get_output_embeddings()
-            if isinstance(output_embed, HookedLinear):
-                model.set_output_embeddings(model.get_input_embeddings())
+            # Due to shell/turtle state, we need to sync the modules from turtle to shell
+            alias_all_from_turtle_if_meta(shell_model=model, turtle_model=self.turtle_model)
 
-            # model.to(CPU) <-- do we need to do this?
             state_dict = get_state_dict_for_save(model)
 
             model_base_name = "model"
diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py
index 949ad1ae4..42613b0b8 100644
--- a/gptqmodel/utils/offload.py
+++ b/gptqmodel/utils/offload.py
@@ -77,8 +77,8 @@ def offload_to_disk(module: List[str] | nn.Module, model: nn.Module, disk_path:
 
     _offload_disk(module=module, name=full_name, disk_path=disk_path)
 
-    if hasattr(module, "config") and hasattr(module.config,
-                                             "tie_word_embeddings") and module.config.tie_word_embeddings:
+    if hasattr(module, "config") and getattr(module.config,
+                                             "tie_word_embeddings", False):
         module.tie_weights() # makes lm_head.weight point to embed_tokens.weight again after offload
 
     # print("offload_disk: list item tree")
diff --git a/gptqmodel/utils/structure.py b/gptqmodel/utils/structure.py
index fae668732..d8af66c95 100644
--- a/gptqmodel/utils/structure.py
+++ b/gptqmodel/utils/structure.py
@@ -30,6 +30,8 @@
 import torch
 from torch import nn
 
+from ..utils.logger import setup_logger
+
 # =========================
 # ANSI color helpers
 # =========================
@@ -39,6 +41,8 @@
 FG_CYAN = "\033[36m"
 FG_YELLOW = "\033[33m"
 
+log = setup_logger()
+
 def _maybe(s: str, code: str, *, color: bool) -> str:
     return f"{code}{s}{RESET}" if color else s
 
@@ -368,3 +372,102 @@ def alias_from_turtle_for_submodule(
 
     # return the *target* submodule, which is the injected result
     return target_submodule
+
+def _is_meta_tensor(t: torch.Tensor) -> bool:
+    return bool(getattr(t, "is_meta", False)) or (hasattr(t, "device") and t.device.type == "meta")
+
+def _module_all_meta(mod: nn.Module) -> bool:
+    """True if the module has at least one tensor and *all* its params/buffers are meta."""
+    saw_any = False
+    for _, p in mod.named_parameters(recurse=False):
+        saw_any = True
+        if not _is_meta_tensor(p):
+            return False
+    for _, b in mod.named_buffers(recurse=False):
+        saw_any = True
+        if not _is_meta_tensor(b):
+            return False
+    return saw_any # modules with no tensors aren't considered 'meta' targets
+
+def _is_leaf(mod: nn.Module) -> bool:
+    return next(mod.named_children(), None) is None
+
+def alias_all_from_turtle_if_meta(
+    shell_model: nn.Module,
+    turtle_model: nn.Module,
+    *,
+    require_class_match: bool = True,
+    verify_shapes: bool = True,
+    tie_after: bool = True,
+) -> int:
+    """
+    Replace (alias) leaf submodules in `shell_model` with the corresponding submodules
+    from `turtle_model` when the shell submodule's tensors are on meta.
+
+    Logs each swap via log.info().
+ """ + turtle_map = dict(turtle_model.named_modules()) + swapped = 0 + + for qname, shell_sub in list(shell_model.named_modules()): + if not qname: # skip root + continue + if not _is_leaf(shell_sub): + continue + if not _module_all_meta(shell_sub): + continue + + turtle_sub = turtle_map.get(qname, None) + if turtle_sub is None: + # log.info(f"Module: Skipped {qname}: not found in turtle model") + continue + + if require_class_match and (shell_sub.__class__ is not turtle_sub.__class__): + # log.info( + # f"Module: Skipped {qname}: class mismatch " + # f"(shell={shell_sub.__class__.__name__}, turtle={turtle_sub.__class__.__name__})" + # ) + continue + + if verify_shapes: + shell_ps = dict(shell_sub.named_parameters(recurse=False)) + turtle_ps = dict(turtle_sub.named_parameters(recurse=False)) + for n in set(shell_ps.keys()) & set(turtle_ps.keys()): + if shell_ps[n].shape != turtle_ps[n].shape: + # log.info( + # f"Module: Skipped {qname}: parameter shape mismatch at '{n}' " + # f"(shell={tuple(shell_ps[n].shape)}, turtle={tuple(turtle_ps[n].shape)})" + # ) + break + else: + shell_bs = dict(shell_sub.named_buffers(recurse=False)) + turtle_bs = dict(turtle_sub.named_buffers(recurse=False)) + for n in set(shell_bs.keys()) & set(turtle_bs.keys()): + if shell_bs[n].shape != turtle_bs[n].shape: + # log.info( + # f"Module: Skipped {qname}: buffer shape mismatch at '{n}' " + # f"(shell={tuple(shell_bs[n].shape)}, turtle={tuple(turtle_bs[n].shape)})" + # ) + break + else: + parent, leaf = _get_parent_and_leaf_by_path(shell_model, qname) + setattr(parent, leaf, turtle_sub) + swapped += 1 + log.info(f"Module:: Sync {qname} with ({turtle_sub.__class__.__name__})") + continue + continue + + parent, leaf = _get_parent_and_leaf_by_path(shell_model, qname) + setattr(parent, leaf, turtle_sub) + swapped += 1 + log.info(f"Module:: Sync {qname} with ({turtle_sub.__class__.__name__})") + + if tie_after and hasattr(shell_model, "tie_weights") and getattr(shell_model.config, "tie_word_embeddings", False): + try: + shell_model.tie_weights() + log.info("Module: Re-tied embedding weights on shell model after full sync") + except Exception as e: + log.info(f"Module: tie_weights failed: {e}") + + log.info(f"Module: Total synced modules: {swapped}") + return swapped From b665c30a3138a6fd266d6ddefab393cdc180a79e Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 25 Sep 2025 12:53:34 +0000 Subject: [PATCH 2/2] cleanup Signed-off-by: Qubitium --- gptqmodel/utils/structure.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gptqmodel/utils/structure.py b/gptqmodel/utils/structure.py index d8af66c95..331f8237b 100644 --- a/gptqmodel/utils/structure.py +++ b/gptqmodel/utils/structure.py @@ -453,14 +453,14 @@ def alias_all_from_turtle_if_meta( parent, leaf = _get_parent_and_leaf_by_path(shell_model, qname) setattr(parent, leaf, turtle_sub) swapped += 1 - log.info(f"Module:: Sync {qname} with ({turtle_sub.__class__.__name__})") + log.info(f"Module: Sync {qname} <- from turtle ({turtle_sub.__class__.__name__})") continue continue parent, leaf = _get_parent_and_leaf_by_path(shell_model, qname) setattr(parent, leaf, turtle_sub) swapped += 1 - log.info(f"Module:: Sync {qname} with ({turtle_sub.__class__.__name__})") + log.info(f"Module:: Sync {qname} <- from turtle ({turtle_sub.__class__.__name__})") if tie_after and hasattr(shell_model, "tie_weights") and getattr(shell_model.config, "tie_word_embeddings", False): try: