From c1b26696b556f8d3afae7150f3a2ec7206b445c6 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 26 Sep 2025 01:43:21 +0000 Subject: [PATCH] diff colors per dtype/device Signed-off-by: Qubitium --- gptqmodel/utils/structure.py | 206 ++++++++++++++++++++++++++++++----- 1 file changed, 178 insertions(+), 28 deletions(-) diff --git a/gptqmodel/utils/structure.py b/gptqmodel/utils/structure.py index be7c5e4d9..1b203588d 100644 --- a/gptqmodel/utils/structure.py +++ b/gptqmodel/utils/structure.py @@ -214,27 +214,154 @@ def print_module_tree( ): """ Pretty-print a module tree with sizes, devices, dtypes, and optional param/buffer details. - Each depth uses a distinct color for better readability. + Visual/UI features: + • Depth-based colors for module names (each level cycles its own color) + • Distinct colors for dtype tokens and device tokens + • Params/Buffers are indented one level deeper under each module for clarity """ - # Color palette per depth (cycles if deeper) + # ------------------------------------------------------------------ + # Depth color palette (portable 16-color ANSI for consistent display) + # ------------------------------------------------------------------ DEPTH_COLORS = [ "\033[36m", # cyan "\033[33m", # yellow "\033[35m", # magenta "\033[32m", # green "\033[34m", # blue + "\033[31m", # red ] - def depth_color(depth: int) -> str: return DEPTH_COLORS[depth % len(DEPTH_COLORS)] - _ = re.compile(filter_regex) if filter_regex else None + # ------------------------------------------------------------------ + # Token color maps (dtype/device) — 16-color ANSI with clear labels + # ------------------------------------------------------------------ + DTYPE_COLOR = { + "float32": "\033[36m", # cyan + "float": "\033[36m", # cyan (alias) + "bfloat16":"\033[35m", # magenta + "float16": "\033[33m", # yellow + "half": "\033[33m", # yellow (alias) + "float8_e4m3fn": "\033[34m", # blue + "float8_e5m2": "\033[34m", # blue + "MXFP4": "\033[36m", # cyan (sentinel 4-bit) + "NVFP4": "\033[36m", # cyan (sentinel 4-bit) + "int8": "\033[31m", # red + "uint8": "\033[31m", # red + "int16": "\033[31m", # red + "short": "\033[31m", # red + "int32": "\033[31m", # red + "int": "\033[31m", # red + "bool": "\033[37m", # white/gray + "-": "\033[37m", # white/gray (unknown) + } + DEVICE_COLOR = { + "cpu": "\033[37m", # white/gray + "cuda": "\033[32m", # green + "xpu": "\033[34m", # blue + "npu": "\033[35m", # magenta + "mps": "\033[33m", # yellow + "hip": "\033[31m", # red + "privateuseone":"\033[36m", # cyan + "meta": "\033[90m", # dim gray + "-": "\033[37m", # white/gray (unknown) + } + + def color_dtype(dtype_name: str) -> str: + code = DTYPE_COLOR.get(dtype_name, "") + return f"{code}{dtype_name}{RESET}" if (color and code) else dtype_name + + def color_device(device_str: str) -> str: + # Accept full device strings like "cuda:0" -> key "cuda" + key = device_str.split(":")[0] if device_str else "-" + code = DEVICE_COLOR.get(key, "") + return f"{code}{device_str}{RESET}" if (color and code) else device_str + + # ------------------------------------------------------------------ + # Local helpers (annotation + param printing with colored tokens) + # ------------------------------------------------------------------ + def colorize_annotation(annot: str) -> str: + """ + _annotate(mod) returns strings like: + "[cuda:0 | float16 | ~123MB]" or "[mixed[cuda:0, cpu] | mixed[float16, bfloat16] | ~...]" + We color the device token(s) and dtype token(s) in-place. + """ + if not color or "[" not in annot or "]" not in annot: + return annot + try: + # extract inside [...] and split by ' | ' + left = annot[:annot.find("[")] + inner = annot[annot.find("[")+1 : annot.rfind("]")] + right = annot[annot.rfind("]")+1:] + parts = [p.strip() for p in inner.split("|")] + if len(parts) >= 2: + # devices + dev = parts[0] + if dev.startswith("mixed[") and dev.endswith("]"): + items = dev[6:-1] + colored = ", ".join(color_device(s.strip()) for s in items.split(",")) + parts[0] = f"mixed[{colored}]" + else: + parts[0] = color_device(dev) + + # dtypes + dt = parts[1] + if dt.startswith("mixed[") and dt.endswith("]"): + items = dt[6:-1] + colored = ", ".join(color_dtype(s.strip()) for s in items.split(",")) + parts[1] = f"mixed[{colored}]" + else: + parts[1] = color_dtype(dt) + + return left + "[" + " | ".join(parts) + "]" + right + except Exception: + return annot + return annot + + def print_params_with_colors(indent: str, mod: nn.Module, *, include_buffers: bool): + """ + Local printer for params/buffers with colored dtype/device tokens and sizes. + Mirrors _print_params but adds coloring. + """ + def _line(kind: str, name: str, t: torch.Tensor) -> str: + # device + is_meta = bool(getattr(t, "is_meta", False) or (hasattr(t, "device") and t.device.type == "meta")) + dev_str = "meta" if is_meta else (str(t.device) if hasattr(t, "device") else "-") + dev_col = color_device(dev_str) + + # dtype + dt_raw = getattr(t, "dtype", None) + dt_name = (str(dt_raw).replace("torch.", "")) if dt_raw is not None else "-" + dt_col = color_dtype(dt_name) + + # size + if not is_meta and hasattr(t, "element_size"): + esize = t.element_size() + else: + esize = _elem_size(dt_raw) or 0.0 + sizeb = t.numel() * (esize or 0.0) + + kind_c = _maybe(kind, FG_CYAN, color=color) # "param"/"buffer" label (cyan) + name_c = _maybe(name, FG_GRAY, color=color) # parameter/buffer name (gray) + size_y = _maybe(_human_bytes(sizeb), FG_YELLOW, color=color) # size (yellow) + return f"{indent}{kind_c}: {name_c} shape={tuple(t.shape)} dtype={dt_col} device={dev_col} ~{size_y}" + + for n, p in mod.named_parameters(recurse=False): + print(_line("param", n, p)) + if include_buffers: + for n, b in mod.named_buffers(recurse=False): + print(_line("buffer", n, b)) + + # ------------------------------------------------------------------ + # Setup + utilities + # ------------------------------------------------------------------ + _ = re.compile(filter_regex) if filter_regex else None # reserved for future experts_name_re = re.compile(experts_regex) if collapse_experts else None seen: Set[int] = set() total_p = sum(p.numel() for p in model.parameters()) - total_b = sum(b.numel() for b in model.buffers()) + total_b = sum(b.numel() for b in model.buffers()) # fixed loop variable def should_collapse(qual_name: str, container: nn.Module) -> bool: if not experts_name_re: @@ -249,19 +376,21 @@ def should_collapse(qual_name: str, container: nn.Module) -> bool: return all(n.isdigit() for n in names) and len(names) > max(0, experts_show) def _format_line(prefix: str, trunk: str, qual_name: str, mod: nn.Module, - show_counts: bool, color: bool, depth: int) -> str: + show_counts: bool, depth: int) -> str: cls = mod.__class__.__name__ - left = _maybe(prefix + trunk, FG_GRAY, color=color) - # Apply depth-based color for the name - name = _maybe(qual_name, depth_color(depth), color=color) - klass = _maybe(cls, DIM, color=color) + left = _maybe(prefix + trunk, FG_GRAY, color=color) # tree trunk (gray) + name = _maybe(qual_name, depth_color(depth), color=color) # module name (depth-based color) + klass = _maybe(cls, DIM, color=color) # class name (dim) if show_counts: p, b = _counts_for_module(mod) - counts = _maybe(f"(P={human_count(p)} B={human_count(b)})", FG_YELLOW, color=color) + counts = _maybe(f"(P={human_count(p)} B={human_count(b)})", FG_YELLOW, color=color) # counts (yellow) return f"{left}{name}: {klass} {counts}" else: return f"{left}{name}: {klass}" + # ------------------------------------------------------------------ + # Recursive printer + # ------------------------------------------------------------------ def rec(mod: nn.Module, name: str, depth: int, prefix: str, is_last: bool): if max_depth is not None and depth > max_depth: return @@ -270,18 +399,21 @@ def rec(mod: nn.Module, name: str, depth: int, prefix: str, is_last: bool): seen.add(mod_id) trunk = "└─ " if is_last else "├─ " - line = _format_line(prefix, trunk, name, mod, show_counts=True, color=color, depth=depth) - print(line + " " + _annotate(mod, color=color) + shared) + line = _format_line(prefix, trunk, name, mod, show_counts=True, depth=depth) + annot = colorize_annotation(_annotate(mod, color=color)) + print(line + " " + annot + shared) if shared: return + # child base indent (same depth as module trunk) indent = prefix + (" " if is_last else "│ ") + # params/buffers indent: one level deeper so they clearly nest under the module param_indent = indent + (" " if is_last else "│ ") if show_all: - _print_params(param_indent, mod, include_buffers=True, color=color) + print_params_with_colors(param_indent, mod, include_buffers=True) elif show_params or show_buffers: - _print_params(param_indent, mod, include_buffers=show_buffers, color=color) + print_params_with_colors(param_indent, mod, include_buffers=show_buffers) children = list(mod.named_children()) n = len(children) @@ -292,21 +424,31 @@ def rec(mod: nn.Module, name: str, depth: int, prefix: str, is_last: bool): if should_collapse(display_name, child): line2 = _format_line(child_prefix, "└─ " if last else "├─ ", - display_name, child, True, color, depth+1) - print(line2 + " " + _annotate(child, color=color)) + display_name, child, True, depth+1) + annot2 = colorize_annotation(_annotate(child, color=color)) + print(line2 + " " + annot2) + sub_children = list(child.named_children()) total_k = len(sub_children) k_show = max(0, min(experts_show, total_k)) + for j, (sub_name, sub_mod) in enumerate(sub_children[:k_show]): sub_last = (j == k_show - 1) and (k_show == total_k) sub_prefix = child_prefix + (" " if last else "│ ") sub_trunk = "└─ " if sub_last else "├─ " line3 = _format_line(sub_prefix, sub_trunk, f"{display_name}.{sub_name}", - sub_mod, True, color, depth+2) - print(line3 + " " + _annotate(sub_mod, color=color)) - rec(sub_mod, f"{display_name}.{sub_name}", - depth + 2, child_prefix + (" " if last else "│ "), sub_last) + sub_mod, True, depth+2) + annot3 = colorize_annotation(_annotate(sub_mod, color=color)) + print(line3 + " " + annot3) + rec( + sub_mod, + f"{display_name}.{sub_name}", + depth + 2, + child_prefix + (" " if last else "│ "), + sub_last, + ) + if k_show < total_k and total_k > 0: p_one, b_one = _param_summary(sub_children[0][1], recurse=True) collapsed = ( @@ -315,29 +457,37 @@ def rec(mod: nn.Module, name: str, depth: int, prefix: str, is_last: bool): ) print(_maybe(child_prefix + (" " if last else "│ ") + collapsed, DIM, color=color)) continue + rec(child, display_name, depth + 1, child_prefix, last) - # Print root - print(_format_line("", "", root_name, model, show_counts=True, color=color, depth=0) - + " " + _annotate(model, color=color)) + # ------------------------------------------------------------------ + # Root print + recursion + # ------------------------------------------------------------------ + root_line = _format_line("", "", root_name, model, show_counts=True, depth=0) + root_annot = colorize_annotation(_annotate(model, color=color)) + print(root_line + " " + root_annot) + # Root params/buffers appear one level deeper than child trunks root_trunk_indent = " " root_param_indent = root_trunk_indent + " " if show_all: - _print_params(root_param_indent, model, include_buffers=True, color=color) + print_params_with_colors(root_param_indent, model, include_buffers=True) elif show_params or show_buffers: - _print_params(root_param_indent, model, include_buffers=show_buffers, color=color) + print_params_with_colors(root_param_indent, model, include_buffers=show_buffers) - for i, (child_name, child) in enumerate(model.named_children()): - last = (i == len(list(model.named_children())) - 1) + children_root = list(model.named_children()) + for i, (child_name, child) in enumerate(children_root): + last = (i == len(children_root) - 1) rec(child, f"{root_name}.{child_name}", 1, "", last) + # Footer totals print("\nTotal parameters:", human_count(total_p), " | Total buffers:", human_count(total_b)) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) frozen = total_p - trainable print("Trainable:", human_count(trainable), " | Frozen:", human_count(frozen)) + def _get_qualified_name(root: torch.nn.Module, obj: torch.nn.Module) -> str: for name, mod in root.named_modules(): if mod is obj: