gptqmodel/looper/gptq_processor.py (2 changes: 1 addition & 1 deletion)

@@ -21,7 +21,7 @@
 from ..utils.logger import setup_logger
 from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module
 from ..utils.offload import undo_offload_to_disk
-from ..utils.torch import torch_streamCtx, torch_sync
+from ..utils.torch import HAS_CUDA, torch_streamCtx, torch_sync

 log = setup_logger()
 lock = threading.Lock()
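
Note: this hunk only imports HAS_CUDA; its definition lives in gptqmodel/utils/torch.py and is not part of the diff. A minimal sketch of the usual pattern for such a flag, assuming it is probed once at import time (illustration only, not the repo's actual definition):

    import torch

    # Resolve once at import so hot paths can guard with a cheap bool check.
    HAS_CUDA = torch.cuda.is_available()
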
gptqmodel/looper/module_looper.py (12 changes: 11 additions & 1 deletion)

@@ -28,7 +28,7 @@
 from ..utils.model import find_modules, get_device, get_module, get_module_by_name_prefix, move_to, nested_move_to
 from ..utils.offload import offload_to_disk
 from ..utils.structure import print_module_tree
-from ..utils.torch import (ALL_DEVICES, CPU, DEFAULT_BALANCE_STRATEGY, META, BalanceStrategy,
+from ..utils.torch import (ALL_DEVICES, CPU, DEFAULT_BALANCE_STRATEGY, HAS_CUDA, META, BalanceStrategy,
                            device_next, device_next_reset, torch_empty_cache, torch_sync)
 from .awq_processor import AWQProcessor

@@ -451,6 +451,11 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal
             futures = []

             def process_module(name, m):
+                # prevent cuda sync memory ctx bugs
+                m_device = get_device(m)
+                if HAS_CUDA and m_device is not None and m_device.type == "cuda":
+                    torch.cuda.set_device(m.weight.device)
+
                 processor.process(module=m, auto_gc=auto_gc)
                 return name, m

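Context for this hunk and the finalize_module hunk below: CUDA's "current device" is per-thread, so worker threads spawned by the looper do not inherit the main thread's device and would otherwise issue kernels and syncs against cuda:0. A self-contained sketch of the pattern being guarded, with a hypothetical worker rather than the looper's actual code:

    import threading
    import torch

    def worker(m: torch.nn.Linear):
        # Pin this thread's current device to the module's device before any
        # CUDA work, mirroring the torch.cuda.set_device() call in the diff.
        torch.cuda.set_device(m.weight.device)
        x = torch.randn(4, m.in_features, device=m.weight.device)
        m(x)
        torch.cuda.synchronize(m.weight.device)

    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        m = torch.nn.Linear(8, 8).to("cuda:1")
        t = threading.Thread(target=worker, args=(m,))
        t.start()
        t.join()
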
Expand Down Expand Up @@ -544,6 +549,11 @@ def process_module(name, m):
for reverse_p in reversed(self.processors):
for name in processed_subset:
def finalize_module(module):
# prevent cuda sync memory ctx bugs
m_device = get_device(module)
if HAS_CUDA and m_device is not None and m_device.type == "cuda":
torch.cuda.set_device(module.weight.device)

reverse_p.submodule_finalize(module, self.gptq_model)

# checking for disk offloading
Expand Down
gptqmodel/utils/model.py (2 changes: 1 addition & 1 deletion)

@@ -76,7 +76,7 @@ def recurse_setattr(module, name, value):
         recurse_setattr(getattr(module, name), rest, value)


-def get_device(obj: torch.Tensor | nn.Module):
+def get_device(obj: torch.Tensor | nn.Module) -> torch.device:
     if isinstance(obj, torch.Tensor):
         return obj.device
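
The diff view truncates get_device before the nn.Module branch. A plausible continuation consistent with the new -> torch.device return annotation (a sketch only; the repo's actual fallback logic is not shown in this diff):

    # Hypothetical continuation inside get_device for the nn.Module case:
    # a module has no single .device attribute, so report the device of its
    # first parameter or buffer.
    for p in obj.parameters(recurse=True):
        return p.device
    for b in obj.buffers(recurse=True):
        return b.device
    return torch.device("cpu")  # assumed fallback for parameterless modules
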
gptqmodel/utils/structure.py (96 changes: 74 additions & 22 deletions)

@@ -212,45 +212,87 @@ def print_module_tree(
     experts_regex: str = r"(^|\.)experts($|\.)",
     experts_show: int = 1,
 ):
-    _ = re.compile(filter_regex) if filter_regex else None  # reserved for future use
+    """
+    Pretty-print a module tree with sizes, devices, dtypes, and optional param/buffer details.
+    Each depth uses a distinct color for better readability.
+    """
+
+    # Color palette per depth (cycles if deeper)
+    DEPTH_COLORS = [
+        "\033[36m",  # cyan
+        "\033[33m",  # yellow
+        "\033[35m",  # magenta
+        "\033[32m",  # green
+        "\033[34m",  # blue
+    ]
+
+    def depth_color(depth: int) -> str:
+        return DEPTH_COLORS[depth % len(DEPTH_COLORS)]
+
+    _ = re.compile(filter_regex) if filter_regex else None
     experts_name_re = re.compile(experts_regex) if collapse_experts else None
     seen: Set[int] = set()

     total_p = sum(p.numel() for p in model.parameters())
     total_b = sum(b.numel() for b in model.buffers())

     def should_collapse(qual_name: str, container: nn.Module) -> bool:
-        if not experts_name_re: return False
-        if not experts_name_re.search(qual_name): return False
-        if not isinstance(container, (nn.ModuleList, nn.Sequential)): return False
+        if not experts_name_re:
+            return False
+        if not experts_name_re.search(qual_name):
+            return False
+        if not isinstance(container, (nn.ModuleList, nn.Sequential)):
+            return False
         names = [n for n, _ in container.named_children()]
-        if not names: return False
+        if not names:
+            return False
         return all(n.isdigit() for n in names) and len(names) > max(0, experts_show)

+    def _format_line(prefix: str, trunk: str, qual_name: str, mod: nn.Module,
+                     show_counts: bool, color: bool, depth: int) -> str:
         cls = mod.__class__.__name__
         left = _maybe(prefix + trunk, FG_GRAY, color=color)
+        # Apply depth-based color for the name
+        name = _maybe(qual_name, depth_color(depth), color=color)
         klass = _maybe(cls, DIM, color=color)
         if show_counts:
             p, b = _counts_for_module(mod)
             counts = _maybe(f"(P={human_count(p)} B={human_count(b)})", FG_YELLOW, color=color)
             return f"{left}{name}: {klass} {counts}"
         else:
             return f"{left}{name}: {klass}"

     def rec(mod: nn.Module, name: str, depth: int, prefix: str, is_last: bool):
-        if max_depth is not None and depth > max_depth: return
+        if max_depth is not None and depth > max_depth:
+            return
         mod_id = id(mod)
         shared = "" if mod_id not in seen else " ↩ shared ref"
         seen.add(mod_id)

         trunk = "└─ " if is_last else "├─ "
-        line = _format_line(prefix, trunk, name, mod, show_counts=True, color=color)
+        line = _format_line(prefix, trunk, name, mod, show_counts=True, color=color, depth=depth)
         print(line + " " + _annotate(mod, color=color) + shared)
-        if shared: return
+        if shared:
+            return

         indent = prefix + ("    " if is_last else "│   ")
+        param_indent = indent + ("    " if is_last else "│   ")

         if show_all:
-            _print_params(indent, mod, include_buffers=True, color=color)
+            _print_params(param_indent, mod, include_buffers=True, color=color)
         elif show_params or show_buffers:
-            _print_params(indent, mod, include_buffers=show_buffers, color=color)
+            _print_params(param_indent, mod, include_buffers=show_buffers, color=color)

         children = list(mod.named_children())
         n = len(children)
         for i, (child_name, child) in enumerate(children):
             last = (i == n - 1)
             child_prefix = prefix + ("    " if is_last else "│   ")
             display_name = f"{name}.{child_name}" if name else child_name

             if should_collapse(display_name, child):
-                line2 = _format_line(child_prefix, "└─ " if last else "├─ ", display_name, child, True, color)
+                line2 = _format_line(child_prefix, "└─ " if last else "├─ ",
+                                     display_name, child, True, color, depth+1)
                 print(line2 + " " + _annotate(child, color=color))
                 sub_children = list(child.named_children())
                 total_k = len(sub_children)

@@ -259,26 +301,36 @@ def rec(mod: nn.Module, name: str, depth: int, prefix: str, is_last: bool):
                 sub_last = (j == k_show - 1) and (k_show == total_k)
                 sub_prefix = child_prefix + ("    " if last else "│   ")
                 sub_trunk = "└─ " if sub_last else "├─ "
-                line3 = _format_line(sub_prefix, sub_trunk, f"{display_name}.{sub_name}", sub_mod, True, color)
+                line3 = _format_line(sub_prefix, sub_trunk,
+                                     f"{display_name}.{sub_name}",
+                                     sub_mod, True, color, depth+2)
                 print(line3 + " " + _annotate(sub_mod, color=color))
-                rec(sub_mod, f"{display_name}.{sub_name}", depth + 2, child_prefix + ("    " if last else "│   "), sub_last)
-            if k_show < total_k:
+                rec(sub_mod, f"{display_name}.{sub_name}",
+                    depth + 2, child_prefix + ("    " if last else "│   "), sub_last)
+            if k_show < total_k and total_k > 0:
                 p_one, b_one = _param_summary(sub_children[0][1], recurse=True)
-                collapsed = f"• … collapsed (repeats {k_show}..{total_k-1}, per-expert P={human_count(p_one)} B={human_count(b_one)})"
+                collapsed = (
+                    f"• … collapsed (repeats {k_show}..{total_k-1}, "
+                    f"per-expert P={human_count(p_one)} B={human_count(b_one)})"
+                )
                 print(_maybe(child_prefix + ("    " if last else "│   ") + collapsed, DIM, color=color))
                 continue
             rec(child, display_name, depth + 1, child_prefix, last)

-    print(_format_line("", "", root_name, model, show_counts=True, color=color) + " " + _annotate(model, color=color))
-    root_indent = "  "
+    # Print root
+    print(_format_line("", "", root_name, model, show_counts=True, color=color, depth=0)
+          + " " + _annotate(model, color=color))
+
+    root_trunk_indent = "  "
+    root_param_indent = root_trunk_indent + "  "

     if show_all:
-        _print_params(root_indent, model, include_buffers=True, color=color)
+        _print_params(root_param_indent, model, include_buffers=True, color=color)
     elif show_params or show_buffers:
-        _print_params(root_indent, model, include_buffers=show_buffers, color=color)
+        _print_params(root_param_indent, model, include_buffers=show_buffers, color=color)

-    children_root = list(model.named_children())
-    for i, (child_name, child) in enumerate(children_root):
-        last = (i == len(children_root) - 1)
+    for i, (child_name, child) in enumerate(model.named_children()):
+        last = (i == len(list(model.named_children())) - 1)
         rec(child, f"{root_name}.{child_name}", 1, "", last)

     print("\nTotal parameters:", human_count(total_p), " | Total buffers:", human_count(total_b))
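
A short usage sketch of the updated print_module_tree. Parameter names are taken from the hunks above; the positional order, defaults, and collapse behavior are assumptions, not confirmed by this diff:

    import torch.nn as nn
    from gptqmodel.utils.structure import print_module_tree

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            # An attribute literally named "experts" with digit-named children,
            # so it matches the default experts_regex and can be collapsed.
            self.experts = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)])

    model = nn.Sequential(Block(), nn.Linear(16, 100))
    # Depth-colored tree; expert lists collapse after `experts_show` entries.
    print_module_tree(model, root_name="model", experts_show=2)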