9 changes: 3 additions & 6 deletions gptqmodel/models/writer.py
@@ -26,7 +26,6 @@

from ..adapter.adapter import HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX, Lora
from ..adapter.peft import LoraConfig
from ..nn_modules.hooked_linear import HookedLinear
from ..quantization.config import (FORMAT, META_FIELD_ACT_GROUP_AWARE, META_FIELD_DAMP_AUTO_INCREMENT,
META_FIELD_DAMP_PERCENT, META_FIELD_MSE, META_FIELD_QUANTIZER,
META_FIELD_STATIC_GROUPS, META_FIELD_TRUE_SEQUENTIAL, META_FIELD_URI,
@@ -36,6 +35,7 @@
from ..utils.logger import setup_logger
from ..utils.model import (convert_gptq_v2_to_v1_format, copy_py_files, find_modules, get_model_files_size,
get_state_dict_for_save, load_checkpoint_in_model_then_tie_weights, make_quant)
from ..utils.structure import alias_all_from_turtle_if_meta
from ..utils.torch import torch_empty_cache
from ..version import __version__
from ._const import DEFAULT_MAX_SHARD_SIZE
@@ -257,12 +257,9 @@ def debug_saved_config(path):
self.processor.save_pretrained(save_dir)
# --- end config save block ---

# TODO FIX ME..remove this ugly patch and find core issue why output_embedding is not retied after offload/undo_offload
output_embed = model.get_output_embeddings()
if isinstance(output_embed, HookedLinear):
model.set_output_embeddings(model.get_input_embeddings())
# Due to shell/turtle state, we need to sync the modules from turtle to shell
alias_all_from_turtle_if_meta(shell_model=model, turtle_model=self.turtle_model)

# model.to(CPU) <-- do we need to do this?
state_dict = get_state_dict_for_save(model)

model_base_name = "model"
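For context on the writer.py change above, here is a minimal sketch (not part of this PR; the layer names are illustrative) of why meta-device placeholders must be re-aliased to real modules before the state dict is collected: a meta tensor carries shape and dtype but no storage, so it has nothing to contribute to a checkpoint.

from torch import nn

shell_layer = nn.Linear(4, 4, device="meta")   # placeholder: shape/dtype only, no data
turtle_layer = nn.Linear(4, 4)                 # fully materialized weights

print(shell_layer.weight.is_meta)    # True  -> nothing usable for a saved state_dict
print(turtle_layer.weight.is_meta)   # False -> safe to serialize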
4 changes: 2 additions & 2 deletions gptqmodel/utils/offload.py
@@ -77,8 +77,8 @@ def offload_to_disk(module: List[str] | nn.Module, model: nn.Module, disk_path:

_offload_disk(module=module, name=full_name, disk_path=disk_path)

if hasattr(module, "config") and hasattr(module.config,
"tie_word_embeddings") and module.config.tie_word_embeddings:
if hasattr(module, "config") and getattr(module.config,
"tie_word_embeddings", False):
module.tie_weights() # makes lm_head.weight point to embed_tokens.weight again after offload

# print("offload_disk: list item tree")
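The tie_weights() call above re-links the output projection to the input embedding after the offload round-trip. A rough plain-PyTorch illustration of what "tying" means (the module names are made up for the example, not taken from this library):

from torch import nn

embed_tokens = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)

lm_head.weight = embed_tokens.weight            # the tie: both attributes share one Parameter
assert lm_head.weight is embed_tokens.weight    # updating or reloading one affects both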
103 changes: 103 additions & 0 deletions gptqmodel/utils/structure.py
@@ -30,6 +30,8 @@
import torch
from torch import nn

from ..utils.logger import setup_logger

# =========================
# ANSI color helpers
# =========================
@@ -39,6 +41,8 @@
FG_CYAN = "\033[36m"
FG_YELLOW = "\033[33m"

log = setup_logger()

def _maybe(s: str, code: str, *, color: bool) -> str:
return f"{code}{s}{RESET}" if color else s

@@ -368,3 +372,102 @@ def alias_from_turtle_for_submodule(

# return the *target* submodule, which is the injected result
return target_submodule

def _is_meta_tensor(t: torch.Tensor) -> bool:
return bool(getattr(t, "is_meta", False)) or (hasattr(t, "device") and t.device.type == "meta")

def _module_all_meta(mod: nn.Module) -> bool:
"""True if the module has at least one tensor and *all* its params/buffers are meta."""
saw_any = False
for _, p in mod.named_parameters(recurse=False):
saw_any = True
if not _is_meta_tensor(p):
return False
for _, b in mod.named_buffers(recurse=False):
saw_any = True
if not _is_meta_tensor(b):
return False
return saw_any # modules with no tensors aren't considered 'meta' targets

def _is_leaf(mod: nn.Module) -> bool:
return next(mod.named_children(), None) is None

def alias_all_from_turtle_if_meta(
shell_model: nn.Module,
turtle_model: nn.Module,
*,
require_class_match: bool = True,
verify_shapes: bool = True,
tie_after: bool = True,
) -> int:
"""
Replace (alias) leaf submodules in `shell_model` with the corresponding submodules
from `turtle_model` when the shell submodule's tensors are on meta.

Logs each swap via log.info().
"""
turtle_map = dict(turtle_model.named_modules())
swapped = 0

for qname, shell_sub in list(shell_model.named_modules()):
if not qname: # skip root
continue
if not _is_leaf(shell_sub):
continue
if not _module_all_meta(shell_sub):
continue

turtle_sub = turtle_map.get(qname, None)
if turtle_sub is None:
# log.info(f"Module: Skipped {qname}: not found in turtle model")
continue

if require_class_match and (shell_sub.__class__ is not turtle_sub.__class__):
# log.info(
# f"Module: Skipped {qname}: class mismatch "
# f"(shell={shell_sub.__class__.__name__}, turtle={turtle_sub.__class__.__name__})"
# )
continue

        if verify_shapes:
            shell_ps = dict(shell_sub.named_parameters(recurse=False))
            turtle_ps = dict(turtle_sub.named_parameters(recurse=False))
            shell_bs = dict(shell_sub.named_buffers(recurse=False))
            turtle_bs = dict(turtle_sub.named_buffers(recurse=False))

            # Skip the module if any shared parameter or buffer disagrees in shape.
            if any(shell_ps[n].shape != turtle_ps[n].shape for n in set(shell_ps) & set(turtle_ps)):
                # log.info(f"Module: Skipped {qname}: parameter shape mismatch (shell vs turtle)")
                continue
            if any(shell_bs[n].shape != turtle_bs[n].shape for n in set(shell_bs) & set(turtle_bs)):
                # log.info(f"Module: Skipped {qname}: buffer shape mismatch (shell vs turtle)")
                continue

        parent, leaf = _get_parent_and_leaf_by_path(shell_model, qname)
        setattr(parent, leaf, turtle_sub)
        swapped += 1
        log.info(f"Module: Sync {qname} <- from turtle ({turtle_sub.__class__.__name__})")

if tie_after and hasattr(shell_model, "tie_weights") and getattr(shell_model.config, "tie_word_embeddings", False):
try:
shell_model.tie_weights()
log.info("Module: Re-tied embedding weights on shell model after full sync")
except Exception as e:
log.info(f"Module: tie_weights failed: {e}")

log.info(f"Module: Total synced modules: {swapped}")
return swapped
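A hedged usage sketch for the new helper (an assumed call pattern; TinyModel is a stand-in, not a class from this repository): the "shell" is constructed under the meta device, the "turtle" holds real weights, and the helper aliases the meta leaves back into the shell before saving.

import torch
from torch import nn

from gptqmodel.utils.structure import alias_all_from_turtle_if_meta

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

turtle = TinyModel()               # materialized weights on CPU
with torch.device("meta"):
    shell = TinyModel()            # identical structure, meta placeholders

swapped = alias_all_from_turtle_if_meta(shell_model=shell, turtle_model=turtle)
print(swapped)                     # expected 1: shell.proj now aliases the turtle's nn.Linear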