In [None]:
# Base directory of this project; the snapshot paths used in the cells below
# are built as f"{WORKING_DIR}/models/<snapshot>.pkl".
WORKING_DIR = "/home/xavier/Documents/DAE_project"

# Print the model summaries with the StyleGAN2 repo's own utility

In [15]:
import torch
import dnnlib
import legacy
import os

# The StyleGAN2 repo's own utility for printing model summaries
from torch_utils import misc

# Select the C / C++ compilers through the CC and CXX environment variables
# (set here in case custom CUDA extensions need them).
os.environ.update({'CC': "/usr/bin/gcc-9", 'CXX': "/usr/bin/g++-9"})


def analyze_full_model_architecture(network_pkl_path: str):
    """
    Print architecture summaries for the networks stored in a snapshot pickle.

    The pickle is opened with dnnlib.util.open_url() and decoded with
    legacy.load_network_pkl(). Its 'G_ema', 'D' and 'E_ema' entries are moved
    to the selected device, switched to eval mode and then handed, one after
    another, to misc.print_module_summary() together with dummy inputs.

    If anything fails while loading, the error is printed and the function
    returns without producing any summary.
    """
    print(f"Loading models from '{network_pkl_path}'...")
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f"Using device: {device}")

    try:
        with dnnlib.util.open_url(network_pkl_path) as fp:
            models = legacy.load_network_pkl(fp)
            # Same order as the names in the success message: G, D, E.
            G, D, E = (models[key].to(device).eval() for key in ('G_ema', 'D', 'E_ema'))
    except Exception as e:
        print(f"Error loading models: {e}")
        return

    print("Models G, D, and E loaded successfully.")

    # Batch size of the dummy inputs fed through each network.
    batch_gpu = 1
    print(f"Using batch size: {batch_gpu} for summary generation.\n")

    # Dummy inputs are sized from the Generator's own attributes.
    resolution = G.img_resolution
    z = torch.empty([batch_gpu, G.z_dim], device=device)
    c = torch.empty([batch_gpu, G.c_dim], device=device)
    img = torch.empty([batch_gpu, G.img_channels, resolution, resolution], device=device)

    separator = "\n" + "=" * 80 + "\n"
    jobs = [
        ("--- Encoder (E) Summary ---", E, [img, c]),
        ("--- Generator (G) Summary ---", G, [z, c]),
        # The Discriminator is called with a pair of images and a pair of labels.
        ("--- Discriminator (D) Summary ---", D, [[img, img], [c, c]]),
    ]
    for position, (title, network, inputs) in enumerate(jobs):
        if position > 0:
            print(separator)
        print(title)
        misc.print_module_summary(network, inputs)

    print("\nSummary generation complete.")


if __name__ == '__main__':
    # Snapshot to analyse -- edit this path to point at your own model.
    network_pkl = WORKING_DIR + "/models/network-snapshot-001512.pkl"

    analyze_full_model_architecture(network_pkl)



Loading models from '/home/xavier/PycharmProjects/training-runs/e13-post/from302kimgs/00001-stylegan2-trainingset2-gpus4-batch112-gamma10/network-snapshot-001512.pkl'...
Using device: cuda
Models G, D, and E loaded successfully.
Using batch size: 1 for summary generation.

--- Encoder (E) Summary ---

Discriminator2       Parameters  Buffers  Output shape        Datatype
---                  ---         ---      ---                 ---     
b512.fromrgb         64          16       [1, 64, 512, 512]   float16 
b512.skip            8192        16       [1, 128, 256, 256]  float16 
b512.conv0           36864       16       [1, 64, 512, 512]   float16 
b512.conv1           73728       16       [1, 128, 256, 256]  float16 
b512                 -           16       [1, 128, 256, 256]  float16 
b256.skip            32768       16       [1, 256, 128, 128]  float16 
b256.conv0           147456      16       [1, 128, 256, 256]  float16 
b256.conv1           294912      16       [1, 256, 128, 12

# Detailed print

In [25]:
import torch
import dnnlib
import legacy
import os
from collections import OrderedDict

# Select the C / C++ compilers through the CC and CXX environment variables
# (set here in case custom CUDA extensions need them).
for _var, _compiler in (('CC', "/usr/bin/gcc-9"), ('CXX', "/usr/bin/g++-9")):
    os.environ[_var] = _compiler


def generate_publication_summary(model_name: str, model: torch.nn.Module, dummy_inputs: list,
                                 output_format: str = 'text'):
    """
    Generates a detailed, publication-ready architecture summary for a given model.

    A forward hook is attached to every sub-module, the model is run once on
    ``dummy_inputs`` under ``torch.no_grad()``, and the shape / dtype of each
    module's output is recorded. The recorded data is then rendered as plain
    text or as a LaTeX table.

    Args:
        model_name: Display name used in the header line (and, for LaTeX, in
            the table caption and label).
        model: The network to summarise.
        dummy_inputs: Positional arguments for one forward pass; unpacked as
            ``model(*dummy_inputs)``.
        output_format: ``'latex'`` prints a LaTeX table; any other value
            prints the plain-text table.

    Raises:
        Whatever the model's forward pass raises. The hooks registered here
        are removed in that case too, so the model is left unmodified.
    """
    if output_format == 'latex':
        print(f"% --- LaTeX Summary for {model_name} (Copy and paste into your .tex file) ---")
    else:
        print(f"--- {model_name} Summary ---")

    # 1. Use hooks to capture output shapes and datatypes
    summary_data = OrderedDict()

    def first_tensor_info(data):
        """Return (shape, dtype) of the first tensor found in a possibly nested list/tuple."""
        if isinstance(data, torch.Tensor):
            return list(data.shape), str(data.dtype).replace('torch.', '')
        if isinstance(data, (list, tuple)):
            for item in data:
                shape, dtype = first_tensor_info(item)
                if shape is not None:
                    return shape, dtype
        return None, None

    def record_output(module, inputs, output):
        # A module invoked several times keeps the info of its last call.
        output_shape, output_dtype = first_tensor_info(output)
        summary_data[module] = {"output_shape": output_shape, "output_dtype": output_dtype}

    hooks = []
    try:
        for module in model.modules():
            hooks.append(module.register_forward_hook(record_output))
        with torch.no_grad():
            model(*dummy_inputs)
    finally:
        # Always detach the hooks, even when the forward pass fails; otherwise
        # they would stay on the model and fire on every later call.
        for hook in hooks:
            hook.remove()

    # 2. Generate and print the summary in the chosen format
    if output_format == 'latex':
        generate_latex_table(model_name, model, summary_data)
    else:
        generate_text_table(model, summary_data)


def generate_text_table(model, summary_data):
    """
    Prints the architecture summary as a plain text table.

    One row is printed per named sub-module that owns parameters directly
    (``recurse=False``), plus every module whose class name contains
    ``"Network"``. The root module itself gets no row. Rows are indented by
    nesting depth.

    Args:
        model: The ``torch.nn.Module`` being summarised.
        summary_data: Mapping ``module -> {"output_shape": ..., "output_dtype": ...}``.
            Modules missing from the mapping are shown with ``N/A``.
    """
    header = (
        f"{'Layer (type)':<50} {'Output Shape':<25} {'Kernel/Stride':<15} "
        f"{'Activation':<12} {'Datatype':<10} {'Parameters':<15}"
    )
    rule = "=" * len(header)
    print(header)
    print(rule)

    for name, module in model.named_modules():
        if name == "":
            continue
        params = sum(p.numel() for p in module.parameters(recurse=False))
        layer_type, details_str, activation_str = get_module_details(module)
        # Skip parameter-free modules, but keep the major "...Network" parents.
        if params == 0 and "Network" not in layer_type:
            continue

        io_info = summary_data.get(module) or {}
        output_shape = io_info.get('output_shape')
        output_dtype = io_info.get('output_dtype')
        # `is not None` so that a scalar output (shape []) is still displayed.
        output_shape_str = str(output_shape) if output_shape is not None else 'N/A'
        output_dtype_str = str(output_dtype) if output_dtype else 'N/A'

        indent = "  " * name.count('.')
        name_str = indent + f"{name.split('.')[-1]} ({layer_type})"
        print(
            f"{name_str:<50} {output_shape_str:<25} {details_str:<15} "
            f"{activation_str:<12} {output_dtype_str:<10} {params:<15,}"
        )

    # Count over the whole model: this includes parameters owned directly by
    # the root module (which has no row) and counts shared parameters once.
    # NOTE: every registered parameter is counted, regardless of requires_grad.
    total_params = sum(p.numel() for p in model.parameters())
    print(rule)
    print(f"Total Trainable Parameters: {total_params:,}")


def generate_latex_table(model_name, model, summary_data):
    """
    Prints the architecture summary as a LaTeX table.

    One row is emitted per named sub-module that owns parameters directly
    (``recurse=False``), plus every module whose class name contains
    ``"Network"`` (rendered in bold, without the type in parentheses). The
    root module itself gets no row. Rows are indented with ``\\quad`` per
    nesting level, and a final row gives the model's total parameter count.

    Args:
        model_name: Display name used in the caption and, lower-cased with
            spaces replaced by underscores, in the ``tab:arch_...`` label.
        model: The ``torch.nn.Module`` being summarised.
        summary_data: Mapping ``module -> {"output_shape": ..., "output_dtype": ...}``.
            Modules missing from the mapping are shown with ``N/A``.
    """
    # Single-pass translation, so characters introduced by one replacement are
    # never escaped again. '&' and '#' would otherwise break a tabular row.
    escape_table = str.maketrans({
        '_': r'\_', '%': r'\%', '$': r'\$',
        '&': r'\&', '#': r'\#', '{': r'\{', '}': r'\}',
    })

    def escape_latex(s):
        return str(s).translate(escape_table)

    latex_output = [
        r"\begin{table*}[ht]",
        r"  \centering",
        fr"  \caption{{Detailed architecture of the {escape_latex(model_name)} network.}}",
        fr"  \label{{tab:arch_{model_name.lower().replace(' ', '_')}}}",
        r"  \begin{tabular}{l l l l l r}",
        r"    \hline",
        r"    \textbf{Layer (type)} & \textbf{Output Shape} & \textbf{Kernel/Stride} & \textbf{Activation} & \textbf{Datatype} & \textbf{Parameters} \\",
        r"    \hline"
    ]

    for name, module in model.named_modules():
        if name == "":
            continue
        params = sum(p.numel() for p in module.parameters(recurse=False))
        layer_type, details_str, activation_str = get_module_details(module)

        # Don't print rows for modules with no parameters, unless it's a major structural block
        is_major_block = "Network" in layer_type
        if params == 0 and not is_major_block:
            continue

        io_info = summary_data.get(module) or {}
        output_shape = io_info.get('output_shape')
        output_dtype = io_info.get('output_dtype')
        # `is not None` so that a scalar output (shape []) is still displayed.
        output_shape_str = str(output_shape) if output_shape is not None else 'N/A'
        output_dtype_str = str(output_dtype) if output_dtype else 'N/A'

        indent = r"\quad " * name.count('.')
        short_name = name.split('.')[-1]
        if is_major_block:
            # Major blocks are bold and omit the layer type in parentheses.
            name_str = indent + r"\textbf{" + escape_latex(short_name) + "}"
        else:
            name_str = indent + escape_latex(f"{short_name} ({layer_type})")

        row_items = [
            name_str,
            escape_latex(output_shape_str),
            escape_latex(details_str),
            escape_latex(activation_str),
            escape_latex(output_dtype_str),
            f"{params:,}" if params > 0 else "-"
        ]
        latex_output.append("    " + " & ".join(row_items) + r" \\")

    # Count over the whole model: this includes parameters owned directly by
    # the root module (which has no row) and counts shared parameters once.
    total_params = sum(p.numel() for p in model.parameters())
    latex_output.extend([
        r"    \hline",
        fr"    \textbf{{Total}} & & & & & \textbf{{{total_params:,}}} \\",
        r"    \hline",
        r"  \end{tabular}",
        r"\end{table*}"
    ])
    print("\n".join(latex_output))


def _format_kernel_size(module):
    """Return the last two dims of a 4-D ``module.weight`` as 'HxW', or 'N/A' if unavailable."""
    weight = getattr(module, 'weight', None)
    shape = getattr(weight, 'shape', ())
    if len(shape) < 4:
        return "N/A"
    return f"{shape[2]}x{shape[3]}"


def get_module_details(module):
    """
    Helper function to extract details from a specific module.

    The module is identified by its class *name*, so no project classes need
    to be imported here.

    Returns:
        A ``(layer_type, details_str, activation_str)`` tuple of strings:
        the class name, a kernel/stride (or in/out feature) description, and
        the value of the module's ``activation`` attribute. Fields that do
        not apply are ``"N/A"``; for ``MappingNetwork`` / ``SynthesisNetwork``
        the last two are empty strings.
    """
    layer_type = module.__class__.__name__
    details_str = "N/A"
    activation_str = "N/A"

    if layer_type == 'SynthesisLayer':
        details_str = f"{_format_kernel_size(module)} / {module.up} (up)"
        activation_str = getattr(module, 'activation', 'N/A')
    elif layer_type == 'ToRGBLayer':
        # ToRGBLayer does not upsample, so the stride is reported as 1.
        details_str = f"{_format_kernel_size(module)} / 1"
        activation_str = getattr(module, 'activation', 'N/A')
    elif layer_type == 'Conv2dLayer':
        stride_info = f"{module.down} (down)" if module.down > 1 else f"{module.up} (up)"
        details_str = f"{_format_kernel_size(module)} / {stride_info}"
        activation_str = getattr(module, 'activation', 'N/A')
    elif layer_type == 'FullyConnectedLayer':
        weight = getattr(module, 'weight', None)
        if weight is not None and len(weight.shape) == 2:
            out_f, in_f = weight.shape
            details_str = f"({in_f}, {out_f})"
        activation_str = getattr(module, 'activation', 'N/A')
    elif layer_type in ['MappingNetwork', 'SynthesisNetwork']:
        # For major blocks, we don't show these details in the parent row
        details_str, activation_str = "", ""

    # Callers apply string format specs / str methods to this value.
    return layer_type, details_str, str(activation_str)


def analyze_full_model_architecture(network_pkl_path: str, output_format: str = 'latex'):
    """
    Loads G, D, and E and generates a detailed summary for each.

    The snapshot is opened with dnnlib.util.open_url() and decoded with
    legacy.load_network_pkl(). Its 'G_ema', 'D' and 'E_ema' entries are moved
    to the selected device and switched to eval mode. Each network is then
    passed to generate_publication_summary() with dummy inputs sized from the
    Generator's attributes.

    Args:
        network_pkl_path: Path (or URL) of the network snapshot pickle.
        output_format: Forwarded to generate_publication_summary();
            ``'latex'`` for LaTeX tables, anything else for plain text.

    If loading fails, the error is printed and the function returns early.
    """
    print(f"Loading models from '{network_pkl_path}'...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    try:
        with dnnlib.util.open_url(network_pkl_path) as fp:
            models = legacy.load_network_pkl(fp)
            G = models['G_ema'].to(device).eval()
            D = models['D'].to(device).eval()
            E = models['E_ema'].to(device).eval()
    except Exception as e:
        # Best-effort notebook helper: report the problem instead of raising.
        print(f"Error loading models: {e}")
        return

    print(f"Models G, D, and E loaded successfully. Generating {output_format} summaries...\n")

    batch_gpu = 1
    z = torch.empty([batch_gpu, G.z_dim], device=device)
    c = torch.empty([batch_gpu, G.c_dim], device=device)
    img = torch.empty([batch_gpu, G.img_channels, G.img_resolution, G.img_resolution], device=device)

    generate_publication_summary("Encoder (E)", E, [img, c], output_format)
    print("\n")
    generate_publication_summary("Generator (G)", G, [z, c], output_format)
    print("\n")
    # The Discriminator is called with a pair of images and a pair of labels.
    generate_publication_summary("Discriminator (D)", D, [[img, img], [c, c]], output_format)

    # Only mention LaTeX / Overleaf when LaTeX was actually produced.
    if output_format == 'latex':
        print("\nSummary generation complete. You can now copy the LaTeX code into your Overleaf project.")
    else:
        print("\nSummary generation complete.")


if __name__ == '__main__':
    network_pkl = "/".join([WORKING_DIR, "models", "network-snapshot-001512.pkl"])
    # Pass output_format='text' here instead to get the plain-text tables.
    analyze_full_model_architecture(network_pkl, output_format='latex')



Loading models from '/home/xavier/PycharmProjects/training-runs/e13-post/from302kimgs/00001-stylegan2-trainingset2-gpus4-batch112-gamma10/network-snapshot-001512.pkl'...
Using device: cuda
Models G, D, and E loaded successfully. Generating latex summaries...

% --- LaTeX Summary for Encoder (E) (Copy and paste into your .tex file) ---
\begin{table*}[ht]
  \centering
  \caption{Detailed architecture of the Encoder (E) network.}
  \label{tab:arch_encoder_(e)}
  \begin{tabular}{l l l l l r}
    \hline
    \textbf{Layer (type)} & \textbf{Output Shape} & \textbf{Kernel/Stride} & \textbf{Activation} & \textbf{Datatype} & \textbf{Parameters} \\
    \hline
    \quad fromrgb (Conv2dLayer) & [1, 64, 512, 512] & 1x1 / 1 (up) & lrelu & float16 & 64 \\
    \quad conv0 (Conv2dLayer) & [1, 64, 512, 512] & 3x3 / 1 (up) & lrelu & float16 & 36,864 \\
    \quad conv1 (Conv2dLayer) & [1, 128, 256, 256] & 3x3 / 2 (down) & lrelu & float16 & 73,728 \\
    \quad skip (Conv2dLayer) & [1, 128, 256, 256] & 1x1 /