Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 178 additions & 28 deletions gptqmodel/utils/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,27 +214,154 @@ def print_module_tree(
):
"""
Pretty-print a module tree with sizes, devices, dtypes, and optional param/buffer details.
Each depth uses a distinct color for better readability.
Visual/UI features:
• Depth-based colors for module names (each level cycles its own color)
• Distinct colors for dtype tokens and device tokens
• Params/Buffers are indented one level deeper under each module for clarity
"""

# Color palette per depth (cycles if deeper)
# ------------------------------------------------------------------
# Depth color palette (portable 16-color ANSI for consistent display)
# ------------------------------------------------------------------
DEPTH_COLORS = [
"\033[36m", # cyan
"\033[33m", # yellow
"\033[35m", # magenta
"\033[32m", # green
"\033[34m", # blue
"\033[31m", # red
]

def depth_color(depth: int) -> str:
return DEPTH_COLORS[depth % len(DEPTH_COLORS)]

_ = re.compile(filter_regex) if filter_regex else None
# ------------------------------------------------------------------
# Token color maps (dtype/device) — 16-color ANSI with clear labels
# ------------------------------------------------------------------
DTYPE_COLOR = {
"float32": "\033[36m", # cyan
"float": "\033[36m", # cyan (alias)
"bfloat16":"\033[35m", # magenta
"float16": "\033[33m", # yellow
"half": "\033[33m", # yellow (alias)
"float8_e4m3fn": "\033[34m", # blue
"float8_e5m2": "\033[34m", # blue
"MXFP4": "\033[36m", # cyan (sentinel 4-bit)
"NVFP4": "\033[36m", # cyan (sentinel 4-bit)
"int8": "\033[31m", # red
"uint8": "\033[31m", # red
"int16": "\033[31m", # red
"short": "\033[31m", # red
"int32": "\033[31m", # red
"int": "\033[31m", # red
"bool": "\033[37m", # white/gray
"-": "\033[37m", # white/gray (unknown)
}
DEVICE_COLOR = {
"cpu": "\033[37m", # white/gray
"cuda": "\033[32m", # green
"xpu": "\033[34m", # blue
"npu": "\033[35m", # magenta
"mps": "\033[33m", # yellow
"hip": "\033[31m", # red
"privateuseone":"\033[36m", # cyan
"meta": "\033[90m", # dim gray
"-": "\033[37m", # white/gray (unknown)
}

def color_dtype(dtype_name: str) -> str:
code = DTYPE_COLOR.get(dtype_name, "")
return f"{code}{dtype_name}{RESET}" if (color and code) else dtype_name

def color_device(device_str: str) -> str:
# Accept full device strings like "cuda:0" -> key "cuda"
key = device_str.split(":")[0] if device_str else "-"
code = DEVICE_COLOR.get(key, "")
return f"{code}{device_str}{RESET}" if (color and code) else device_str

# ------------------------------------------------------------------
# Local helpers (annotation + param printing with colored tokens)
# ------------------------------------------------------------------
def colorize_annotation(annot: str) -> str:
"""
_annotate(mod) returns strings like:
"[cuda:0 | float16 | ~123MB]" or "[mixed[cuda:0, cpu] | mixed[float16, bfloat16] | ~...]"
We color the device token(s) and dtype token(s) in-place.
"""
if not color or "[" not in annot or "]" not in annot:
return annot
try:
# extract inside [...] and split by ' | '
left = annot[:annot.find("[")]
inner = annot[annot.find("[")+1 : annot.rfind("]")]
right = annot[annot.rfind("]")+1:]
parts = [p.strip() for p in inner.split("|")]
if len(parts) >= 2:
# devices
dev = parts[0]
if dev.startswith("mixed[") and dev.endswith("]"):
items = dev[6:-1]
colored = ", ".join(color_device(s.strip()) for s in items.split(","))
parts[0] = f"mixed[{colored}]"
else:
parts[0] = color_device(dev)

# dtypes
dt = parts[1]
if dt.startswith("mixed[") and dt.endswith("]"):
items = dt[6:-1]
colored = ", ".join(color_dtype(s.strip()) for s in items.split(","))
parts[1] = f"mixed[{colored}]"
else:
parts[1] = color_dtype(dt)

return left + "[" + " | ".join(parts) + "]" + right
except Exception:
return annot
return annot

def print_params_with_colors(indent: str, mod: nn.Module, *, include_buffers: bool):
"""
Local printer for params/buffers with colored dtype/device tokens and sizes.
Mirrors _print_params but adds coloring.
"""
def _line(kind: str, name: str, t: torch.Tensor) -> str:
# device
is_meta = bool(getattr(t, "is_meta", False) or (hasattr(t, "device") and t.device.type == "meta"))
dev_str = "meta" if is_meta else (str(t.device) if hasattr(t, "device") else "-")
dev_col = color_device(dev_str)

# dtype
dt_raw = getattr(t, "dtype", None)
dt_name = (str(dt_raw).replace("torch.", "")) if dt_raw is not None else "-"
dt_col = color_dtype(dt_name)

# size
if not is_meta and hasattr(t, "element_size"):
esize = t.element_size()
else:
esize = _elem_size(dt_raw) or 0.0
sizeb = t.numel() * (esize or 0.0)

kind_c = _maybe(kind, FG_CYAN, color=color) # "param"/"buffer" label (cyan)
name_c = _maybe(name, FG_GRAY, color=color) # parameter/buffer name (gray)
size_y = _maybe(_human_bytes(sizeb), FG_YELLOW, color=color) # size (yellow)
return f"{indent}{kind_c}: {name_c} shape={tuple(t.shape)} dtype={dt_col} device={dev_col} ~{size_y}"

for n, p in mod.named_parameters(recurse=False):
print(_line("param", n, p))
if include_buffers:
for n, b in mod.named_buffers(recurse=False):
print(_line("buffer", n, b))

# ------------------------------------------------------------------
# Setup + utilities
# ------------------------------------------------------------------
_ = re.compile(filter_regex) if filter_regex else None # reserved for future
experts_name_re = re.compile(experts_regex) if collapse_experts else None
seen: Set[int] = set()

total_p = sum(p.numel() for p in model.parameters())
total_b = sum(b.numel() for b in model.buffers())
total_b = sum(b.numel() for b in model.buffers()) # fixed loop variable

def should_collapse(qual_name: str, container: nn.Module) -> bool:
if not experts_name_re:
Expand All @@ -249,19 +376,21 @@ def should_collapse(qual_name: str, container: nn.Module) -> bool:
return all(n.isdigit() for n in names) and len(names) > max(0, experts_show)

def _format_line(prefix: str, trunk: str, qual_name: str, mod: nn.Module,
show_counts: bool, color: bool, depth: int) -> str:
show_counts: bool, depth: int) -> str:
cls = mod.__class__.__name__
left = _maybe(prefix + trunk, FG_GRAY, color=color)
# Apply depth-based color for the name
name = _maybe(qual_name, depth_color(depth), color=color)
klass = _maybe(cls, DIM, color=color)
left = _maybe(prefix + trunk, FG_GRAY, color=color) # tree trunk (gray)
name = _maybe(qual_name, depth_color(depth), color=color) # module name (depth-based color)
klass = _maybe(cls, DIM, color=color) # class name (dim)
if show_counts:
p, b = _counts_for_module(mod)
counts = _maybe(f"(P={human_count(p)} B={human_count(b)})", FG_YELLOW, color=color)
counts = _maybe(f"(P={human_count(p)} B={human_count(b)})", FG_YELLOW, color=color) # counts (yellow)
return f"{left}{name}: {klass} {counts}"
else:
return f"{left}{name}: {klass}"

# ------------------------------------------------------------------
# Recursive printer
# ------------------------------------------------------------------
def rec(mod: nn.Module, name: str, depth: int, prefix: str, is_last: bool):
if max_depth is not None and depth > max_depth:
return
Expand All @@ -270,18 +399,21 @@ def rec(mod: nn.Module, name: str, depth: int, prefix: str, is_last: bool):
seen.add(mod_id)

trunk = "└─ " if is_last else "├─ "
line = _format_line(prefix, trunk, name, mod, show_counts=True, color=color, depth=depth)
print(line + " " + _annotate(mod, color=color) + shared)
line = _format_line(prefix, trunk, name, mod, show_counts=True, depth=depth)
annot = colorize_annotation(_annotate(mod, color=color))
print(line + " " + annot + shared)
if shared:
return

# child base indent (same depth as module trunk)
indent = prefix + (" " if is_last else "│ ")
# params/buffers indent: one level deeper so they clearly nest under the module
param_indent = indent + (" " if is_last else "│ ")

if show_all:
_print_params(param_indent, mod, include_buffers=True, color=color)
print_params_with_colors(param_indent, mod, include_buffers=True)
elif show_params or show_buffers:
_print_params(param_indent, mod, include_buffers=show_buffers, color=color)
print_params_with_colors(param_indent, mod, include_buffers=show_buffers)

children = list(mod.named_children())
n = len(children)
Expand All @@ -292,21 +424,31 @@ def rec(mod: nn.Module, name: str, depth: int, prefix: str, is_last: bool):

if should_collapse(display_name, child):
line2 = _format_line(child_prefix, "└─ " if last else "├─ ",
display_name, child, True, color, depth+1)
print(line2 + " " + _annotate(child, color=color))
display_name, child, True, depth+1)
annot2 = colorize_annotation(_annotate(child, color=color))
print(line2 + " " + annot2)

sub_children = list(child.named_children())
total_k = len(sub_children)
k_show = max(0, min(experts_show, total_k))

for j, (sub_name, sub_mod) in enumerate(sub_children[:k_show]):
sub_last = (j == k_show - 1) and (k_show == total_k)
sub_prefix = child_prefix + (" " if last else "│ ")
sub_trunk = "└─ " if sub_last else "├─ "
line3 = _format_line(sub_prefix, sub_trunk,
f"{display_name}.{sub_name}",
sub_mod, True, color, depth+2)
print(line3 + " " + _annotate(sub_mod, color=color))
rec(sub_mod, f"{display_name}.{sub_name}",
depth + 2, child_prefix + (" " if last else "│ "), sub_last)
sub_mod, True, depth+2)
annot3 = colorize_annotation(_annotate(sub_mod, color=color))
print(line3 + " " + annot3)
rec(
sub_mod,
f"{display_name}.{sub_name}",
depth + 2,
child_prefix + (" " if last else "│ "),
sub_last,
)

if k_show < total_k and total_k > 0:
p_one, b_one = _param_summary(sub_children[0][1], recurse=True)
collapsed = (
Expand All @@ -315,29 +457,37 @@ def rec(mod: nn.Module, name: str, depth: int, prefix: str, is_last: bool):
)
print(_maybe(child_prefix + (" " if last else "│ ") + collapsed, DIM, color=color))
continue

rec(child, display_name, depth + 1, child_prefix, last)

# Print root
print(_format_line("", "", root_name, model, show_counts=True, color=color, depth=0)
+ " " + _annotate(model, color=color))
# ------------------------------------------------------------------
# Root print + recursion
# ------------------------------------------------------------------
root_line = _format_line("", "", root_name, model, show_counts=True, depth=0)
root_annot = colorize_annotation(_annotate(model, color=color))
print(root_line + " " + root_annot)

# Root params/buffers appear one level deeper than child trunks
root_trunk_indent = " "
root_param_indent = root_trunk_indent + " "

if show_all:
_print_params(root_param_indent, model, include_buffers=True, color=color)
print_params_with_colors(root_param_indent, model, include_buffers=True)
elif show_params or show_buffers:
_print_params(root_param_indent, model, include_buffers=show_buffers, color=color)
print_params_with_colors(root_param_indent, model, include_buffers=show_buffers)

for i, (child_name, child) in enumerate(model.named_children()):
last = (i == len(list(model.named_children())) - 1)
children_root = list(model.named_children())
for i, (child_name, child) in enumerate(children_root):
last = (i == len(children_root) - 1)
rec(child, f"{root_name}.{child_name}", 1, "", last)

# Footer totals
print("\nTotal parameters:", human_count(total_p), " | Total buffers:", human_count(total_b))
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = total_p - trainable
print("Trainable:", human_count(trainable), " | Frozen:", human_count(frozen))


def _get_qualified_name(root: torch.nn.Module, obj: torch.nn.Module) -> str:
for name, mod in root.named_modules():
if mod is obj:
Expand Down