# Visualising UNET #

### Abstract ###

- Self-explanatory. Uses `torchview` as the main library.
- Some notations are borrowed from [mega_cmp.ipynb](./v2a/mmega_cmp.ipynb)
- **Warning:** [Make sure your C Drive is huge.](https://huggingface.co/docs/datasets/cache) Default directory is `C:\Users\User\.cache\huggingface\hub`
- Also make sure your token is present in `C:\Users\User\.cache\huggingface\token`.
- Because of the inconsistent results for `["sdxl", "sd1", "sd2"]`, **which were overestimated by 2.37x** (the others differ by < 0.1%), I also implemented the `diffusers` and native `torch` approaches based on [this Stack Overflow post](https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model). Issues [#262](https://github.com/TylerYep/torchinfo/issues/262), [#303](https://github.com/TylerYep/torchinfo/issues/303), [#312](https://github.com/TylerYep/torchinfo/issues/312) were reported in `torchinfo`, which made me panic a bit. *Hopefully they can explain any inconsistent results in the future.*
- Refer to [diffusers.num_parameters](https://huggingface.co/docs/diffusers/api/models/overview#diffusers.ModelMixin.num_parameters) [and its code](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/modeling_utils.py#L1040), [nn.Parameter](https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html), and [torch.numel](https://pytorch.org/docs/stable/generated/torch.numel.html) for how parameters are counted. The count will very likely **MISMATCH** other sources (e.g. `torchview` and `torchinfo` here, referred to as ["model summary"](https://stackoverflow.com/questions/42480111/how-do-i-print-the-model-summary-in-pytorch)).
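As a rough illustration of the counting rule referenced above: `numel` is the product of a tensor's dimensions, and a model total is the sum over its parameter tensors. The sketch below reproduces that arithmetic in plain Python. The helper names and the two-tensor "model" are hypothetical; the bias shape `(320,)` is an assumption matching a convolution with 320 output channels.

```python
from math import prod

def numel(shape):
    """Number of elements in a tensor of the given shape (product of its dimensions)."""
    return prod(shape)

def count_params(shapes):
    """Sum of element counts over a {name: shape} mapping."""
    return sum(numel(shape) for shape in shapes.values())

# Hypothetical two-tensor "model": a 3x3 convolution from 4 to 320 channels.
toy_shapes = {
    "conv_in.weight": (320, 4, 3, 3),
    "conv_in.bias": (320,),
}

total = count_params(toy_shapes)
print(total)  # weight elements plus bias elements
```

A summary tool that walks modules rather than parameter tensors may count a shared or nested tensor more than once, which is one possible source of diverging totals.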

### Required libraries ###

- ~~Should be the common ML pack we're using. Also with [SD webui's dependency](https://github.com/AUTOMATIC1111/stable-diffusion-webui).~~

- [torchview](https://torchview.dev/)
- [safetensors](https://huggingface.co/docs/safetensors/index)
- [diffusers](https://huggingface.co/docs/diffusers/installation)
- [omegaconf](https://anaconda.org/conda-forge/omegaconf)
- [pytorch](https://pytorch.org/get-started/locally/#windows-python)
- [Graphviz](https://graphviz.org/)
- [torchinfo](https://pypi.org/project/torchinfo/)
- [prettytable](https://pypi.org/project/prettytable/)
- ~~[ultralytics-thop](https://github.com/ultralytics/thop/tree/main)~~

### Some layer names to interpret ###

- The whole model combined is called `DiffusionPipeline` in [Diffusers](https://huggingface.co/docs/diffusers/index).

|Layer name|Description|Class name in Diffusers|
|---|---|---|
|`first_stage_model`|VAE|`AutoencoderKL`|
|`cond_stage_model`|Text Encoder (SD1, SD2)|`CLIPTextModel`|
|`conditioner.embedders.0`|Text Encoder 1 (SDXL)|`CLIPTextModel`|
|`conditioner.embedders.1`|Text Encoder 2 (SDXL)|`CLIPTextModelWithProjection`|
|`model.diffusion_model`|UNET|`UNet2DConditionModel`|
|`model_ema`|EMA model for training|n/a|
|`cumprod`, `betas`, `alphas`|Diffusion noise schedule buffers|n/a|
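The table can be turned into a small lookup. This is a minimal sketch that assumes the layer names above appear as key prefixes in a checkpoint's state dict; `COMPONENT_PREFIXES` and `component_of` are hypothetical names introduced only for illustration.

```python
# Prefix -> description, copied from the table above (assumed to be key prefixes).
COMPONENT_PREFIXES = {
    "first_stage_model.": "VAE",
    "cond_stage_model.": "Text Encoder (SD1, SD2)",
    "conditioner.embedders.0.": "Text Encoder 1 (SDXL)",
    "conditioner.embedders.1.": "Text Encoder 2 (SDXL)",
    "model.diffusion_model.": "UNET",
    "model_ema.": "EMA model for training",
}

def component_of(key):
    """Return the component description for a checkpoint key, or 'other'."""
    for prefix, description in COMPONENT_PREFIXES.items():
        if key.startswith(prefix):
            return description
    return "other"

print(component_of("model.diffusion_model.input_blocks.0.0.weight"))
print(component_of("betas"))
```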

### Some notation (Useful in the bin chart) ###
- `attn1`: `sattn` = *Self attention*
- `attn2`: `xattn` = *Cross attention*
- `ff`: *Feed forward*
- `norm`: [Normalisation layer](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html). `elementwise_affine=True` introduces trainable `bias` and `weight`. 
- `proj`: *Projection*
- `emb_layers`: *Embedding layers*
- `mlp`: *Multilayer perceptron*
- `others`: `ff` + `norm` + `proj` + `emb_layers`
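The notation above could be applied to parameter names roughly as follows. This is a sketch, not the binning used for the chart: `classify_layer` is a hypothetical helper, and the matching rules (exact tokens for `attn1`, `attn2`, `ff`, `emb_layers`, `mlp`; prefixes for `norm` and `proj`) are assumptions. Matching on dot-separated tokens avoids a pitfall of plain substring checks, where `ff` would also match `diffusion_model`.

```python
def classify_layer(name):
    """Map a parameter name to a chart bin by inspecting its dot-separated tokens."""
    tokens = name.split(".")
    if "attn1" in tokens:
        return "sattn"   # self attention
    if "attn2" in tokens:
        return "xattn"   # cross attention
    if "mlp" in tokens:
        return "mlp"
    for token in tokens:
        # ff + norm + proj + emb_layers are grouped as "others".
        if token in ("ff", "emb_layers") or token.startswith(("norm", "proj")):
            return "others"
    return "unclassified"

# Illustrative parameter names in the original Stable Diffusion key layout.
base = "model.diffusion_model.input_blocks.1.1.transformer_blocks.0."
for suffix in ("attn1.to_k.weight", "attn2.to_k.weight", "ff.net.0.proj.weight", "norm1.weight"):
    print(suffix, "->", classify_layer(base + suffix))
```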

## Importing libraries ##

In [1]:
import os
import torch
import diffusers
import accelerate
import transformers
import huggingface_hub
from safetensors.torch import load_file #safe_open
from diffusers import UNet2DConditionModel, SD3Transformer2DModel, FluxTransformer2DModel, AuraFlowTransformer2DModel, HunyuanDiT2DModel

from torchview import draw_graph
from torchinfo import summary

from prettytable import PrettyTable

#from thop import profile

import graphviz
graphviz.set_jupyter_format('png')

'svg'

In [2]:
print(torch.__version__)
print(diffusers.__version__)
print(transformers.__version__)
print(accelerate.__version__)
print(huggingface_hub.__version__)

2.4.0+cu124
0.31.0
4.44.0
0.33.0
0.24.5


In [3]:
# Fix for OMP: Error #15
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [4]:
# CPU for Flux.
g_device = "cuda:0" #"cpu"
# Currently for generating graph only.
g_seed = 114514

In [5]:
# Model path
model_path = {
    "sd1": "runwayml/stable-diffusion-v1-5",
    "sd2": "stabilityai/stable-diffusion-2-1",
    "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
    "sd3": "stabilityai/stable-diffusion-3-medium-diffusers",
    "sd35":"stabilityai/stable-diffusion-3.5-large",
    "flux": "black-forest-labs/FLUX.1-dev",
    "af": "fal/AuraFlow-v0.2",
    "hy": "Tencent-Hunyuan/HunyuanDiT-Diffusers",
}

model_type = torch.float16 if "cuda" in g_device else torch.float # CPU doesn't support FP16 / FP8
long_type = torch.long # torch.long is an alias of torch.int64 on every device

Only online models are available.

`load_single_file` failed (versioning hell, omitted), so I load the online models instead.

In [6]:
unet_instance = None # Clear
unet_instance = {} # Clear

In [7]:
# We run it later
# for k in model_path.keys():
#    unet_instance[k] = UNet2DConditionModel.from_pretrained(model_path[k], subfolder="unet",  torch_dtype=torch.float16).to(g_device)

Input size was found by trial and error.

Actually, that is not necessary: we can read `config.json` from the official model on HuggingFace, and the `nn.Module` is already created with that config.

Originally the configs were scattered across different Git repos, but HF does a great job here.

- `sd1`: [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json)
- `sd2`: [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json)
- `sdxl`: [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json)
- `sd3`: [stabilityai/stable-diffusion-3-medium-diffusers](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/blob/main/transformer/config.json)
- `flux`: [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/transformer/config.json)
- `af`: [fal/AuraFlow-v0.2](https://huggingface.co/fal/AuraFlow-v0.2/blob/main/transformer/config.json)
- `hy`: [Tencent-Hunyuan/HunyuanDiT-Diffusers](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-Diffusers/blob/main/transformer/config.json)
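To show how a config translates into dummy-input shapes, here is a minimal sketch for the UNet-style models. `unet_input_shapes` is a hypothetical helper, and the `sd1` values are hard-coded assumptions (chosen to agree with the `4x64x64` latent and `77x768` text shapes used in this notebook) rather than values fetched from the hub.

```python
def unet_input_shapes(config, batch=1, sequence_length=77):
    """Derive dummy-input shapes for a UNet-style model from its config values."""
    size = config["sample_size"]
    return {
        "sample": (batch, config["in_channels"], size, size),
        "timestep": (1,),
        "encoder_hidden_states": (batch, sequence_length, config["cross_attention_dim"]),
    }

# Assumed sd1 values; in practice they would come from the model's config.json.
sd1_config = {"sample_size": 64, "in_channels": 4, "cross_attention_dim": 768}

shapes = unet_input_shapes(sd1_config)
for name, shape in shapes.items():
    print(name, shape)
```

The transformer-based models use different input names (for example `hidden_states` instead of `sample`), so this helper would need per-model branches to cover them.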

In [8]:
# Not Used.
input_data_mapping_sample = {
    "sd1": {
        'sample': torch.rand(1, 4, 64, 64).type(model_type).to(g_device),
        'timestep': torch.rand(1).type(model_type).to(g_device),
        'encoder_hidden_states': torch.rand(1, 77, 768).type(model_type).to(g_device),
    },
    "sd2": {
        'sample': torch.rand(1, 4, 96, 96).type(model_type).to(g_device),
        'timestep': torch.rand(1).type(model_type).to(g_device),
        'encoder_hidden_states': torch.rand(1, 77, 1024).type(model_type).to(g_device),
    },
    "sdxl": {
        'sample': torch.rand(2, 4, 128, 128).type(model_type).to(g_device),
        'timestep': torch.rand(1).type(model_type).to(g_device),
        'encoder_hidden_states': torch.rand(2, 77, 2048).type(model_type).to(g_device),
        'added_cond_kwargs': {
            'text_embeds': torch.rand(2, 1280).type(model_type).to(g_device),
            'time_ids': torch.rand(2, 6).type(model_type).to(g_device),
        },
    },
    "sd3": {
        'hidden_states': torch.rand(1, 16, 128, 128).type(model_type).to(g_device),
        'timestep': torch.ones((1, )).type(long_type).to(g_device),
        'encoder_hidden_states': torch.rand(1, 77, 4096).type(model_type).to(g_device),
        'pooled_projections': torch.rand(1, 2048).type(model_type).to(g_device)
    },
    "sd35": {
        'hidden_states': torch.rand(1, 16, 128, 128).type(model_type).to(g_device),
        'timestep': torch.ones((1, )).type(long_type).to(g_device),
        'encoder_hidden_states': torch.rand(1, 77, 4096).type(model_type).to(g_device),
        'pooled_projections': torch.rand(1, 2048).type(model_type).to(g_device)
    },
    "flux": {
        'hidden_states': torch.rand(1, 4096, 64).type(model_type).to(g_device),
        'timestep': torch.ones((1, )).type(long_type).to(g_device),
        'guidance': torch.zeros((1, )).type(long_type).to(g_device),
        'encoder_hidden_states': torch.rand(1, 256, 4096).type(model_type).to(g_device),
        'pooled_projections': torch.rand(1, 768).type(model_type).to(g_device),
        'txt_ids': torch.rand(1, 256, 3).type(model_type).to(g_device),
        'img_ids': torch.rand(1, 4096, 3).type(model_type).to(g_device)
    },
    "af": {
        'hidden_states': torch.rand(2, 4, 128, 128).type(model_type).to(g_device),
        'encoder_hidden_states': torch.rand(2, 256, 2048).type(model_type).to(g_device),
        'timestep': torch.ones((2, )).type(long_type).to(g_device)
    },        
    "hy": {
        'hidden_states': torch.rand(2, 4, 128, 128).type(model_type).to(g_device),
        'timestep': torch.ones((2, )).type(long_type).to(g_device),
        'encoder_hidden_states': torch.rand(2, 77, 1024).type(model_type).to(g_device),
        'text_embedding_mask': torch.rand(2, 77).type(model_type).to(g_device),        
        'encoder_hidden_states_t5': torch.rand(2, 256, 2048).type(model_type).to(g_device),
        'text_embedding_mask_t5': torch.rand(2, 256).type(model_type).to(g_device),
        'image_meta_size': torch.rand(2, 6).type(model_type).to(g_device),
        'style': torch.rand((2, )).type(long_type).to(g_device),
        # Don't know how it becomes 2D.
        'image_rotary_emb': [torch.rand(4096, 88).type(model_type).to(g_device), torch.rand(4096, 88).type(model_type).to(g_device)]
    },
}

Set the graph output paths.

In [9]:
filename_paths = {
    #"sd1": "./sd1_unet",
    #"sd2": "./sd2_unet",
    #"sdxl": "./sdxl_unet", 
    #"sd3": "./sd3_mmdit",
    "sd35": "./sd35_mmdit",
    #"flux": "./flux_mmdit", # Good for RTX 3090, 31 minutes on CPU
    #"af": "./af_mmdit",
    #"hy": "./hy_mmdit",
}

In [10]:
# For in place generation (requested)
png_results = {}

Build the native `torch` approach first.

In [11]:
def count_parameters_native(model):
    """Count parameters three ways and return a printable report."""

    # Official approach straight from diffusers (ModelMixin.num_parameters).
    diffuser_total_params = model.num_parameters()

    # Native approach, listing every trainable parameter with its name.
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue  # skip frozen parameters
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params

    # Another native approach in a single line (frozen parameters included).
    pytorch_total_params = sum(p.numel() for p in model.parameters())

    result_str = "\r\n".join([
        table.get_string(),
        f"Total Params (diffuser): {diffuser_total_params}",
        f"Total Trainable Params (torch): {total_params}",
        f"Total Params (torch): {pytorch_total_params}"
    ])
    return result_str

Main loop. Note that `input_data` is generated in place. The expected dimensions are already available in `model.config`, as mentioned above.

Notice that we need both `depth=1` and `depth=2` to **link MBW layers**.

This also keeps the diagram elegant (however, the counts won't match!)

In [12]:
# Intercepted from StableDiffusionXLPipeline
# C:\Users\User\.conda\envs\sklearn-env\Lib\site-packages\diffusers\pipelines\stable_diffusion_xl
if False:
    #torch.Size([2, 4, 128, 128])
    print(latent_model_input.shape)
    #torch.Size([])
    print(t.shape)
    #torch.Size([2, 77, 2048])
    print(prompt_embeds.shape)
    #torch.Size([2, 1280])
    print(added_cond_kwargs["text_embeds"].shape)
    #torch.Size([2, 6])
    print(added_cond_kwargs["time_ids"].shape)
    raise Exception("We are here.")

# Intercepted from FluxPipeline
# C:\ProgramData\Miniconda3\envs\novelai-env\Lib\site-packages\diffusers\pipelines\flux
if False:
    #torch.Size([1, 4096, 64])
    print(hidden_states.shape)
    #tensor([1000.0000,  904.5308,  759.5109,  512.8441], device='cuda:0') / 1000
    print(timestep.shape)
    #tensor([0.], device='cuda:0')
    print(guidance.shape)
    #torch.Size([1, 768])
    print(pooled_projections.shape)
    #torch.Size([1, 256, 4096])
    print(encoder_hidden_states.shape)
    #torch.Size([1, 256, 3])
    print(txt_ids.shape)
    #torch.Size([1, 4096, 3])
    print(img_ids.shape)
    #None
    print(joint_attention_kwargs)                
    raise Exception("We are here.")

# Intercepted from AuraFlowPipeline. Noticed that the variable name is not the one to the DiT.
# C:\ProgramData\Miniconda3\envs\novelai-env\Lib\site-packages\diffusers\pipelines\aura_flow
if False:
    #torch.Size([2, 4, 128, 128])
    print(latent_model_input.shape)
    #torch.Size([2, 256, 2048])
    print(prompt_embeds.shape)
    #torch.Size([2])
    print(timestep.shape)
    raise Exception("We are here.")

# Intercepted from HunyuanDiTPipeline. Noticed that the variable name is not the one to the DiT.
# C:\ProgramData\Miniconda3\envs\novelai-env\Lib\site-packages\diffusers\pipelines\hunyuandit
if False:
    #torch.Size([2, 4, 128, 128])  
    print(latent_model_input.shape)
    #torch.Size([2])
    print(t_expand.shape)
    #torch.Size([2, 77, 1024])
    print(prompt_embeds.shape)
    #torch.Size([2, 77])
    print(prompt_attention_mask.shape)
    #torch.Size([2, 256, 2048])
    print(prompt_embeds_2.shape)
    #torch.Size([2, 256])
    print(prompt_attention_mask_2.shape)
    #torch.Size([2, 6])
    print(add_time_ids.shape)
    #torch.Size([2])
    print(style.shape)
    #[torch.Size([4096, 88])]
    print(image_rotary_emb.shape)
    raise Exception("We are here.")

In [13]:
def get_mbw_component(cur_unet):
    if cur_unet == "sd1":
        return UNet2DConditionModel.from_pretrained(model_path[cur_unet], subfolder="unet", torch_dtype=model_type).to(g_device)
    elif cur_unet == "sd2":
        return UNet2DConditionModel.from_pretrained(model_path[cur_unet], subfolder="unet", torch_dtype=model_type).to(g_device)
    elif cur_unet == "sdxl":
        return UNet2DConditionModel.from_pretrained(model_path[cur_unet], subfolder="unet", torch_dtype=model_type).to(g_device)
    elif cur_unet == "sd3":
        return SD3Transformer2DModel.from_pretrained(model_path[cur_unet], subfolder="transformer", torch_dtype=model_type).to(g_device)
    elif cur_unet == "sd35":
        return SD3Transformer2DModel.from_pretrained(model_path[cur_unet], subfolder="transformer", torch_dtype=model_type).to(g_device)
    elif cur_unet == "flux":
        return FluxTransformer2DModel.from_pretrained(model_path[cur_unet], subfolder="transformer", guidance_embeds=True, torch_dtype=model_type).to(g_device)
    elif cur_unet == "af":
        return AuraFlowTransformer2DModel.from_pretrained(model_path[cur_unet], subfolder="transformer", torch_dtype=model_type).to(g_device)
    elif cur_unet == "hy":
        return HunyuanDiT2DModel.from_pretrained(model_path[cur_unet], subfolder="transformer", torch_dtype=model_type).to(g_device)

def get_feature_dim(cur_unet):
    # Should fit CLIPTextModel.hidden_size, 2048 = 768 + 1280 for SDXL
    if cur_unet == "sd1":
        return unet_instance[cur_unet].config.cross_attention_dim
    elif cur_unet == "sd2":
        return unet_instance[cur_unet].config.cross_attention_dim
    elif cur_unet == "sdxl":
        return unet_instance[cur_unet].config.cross_attention_dim
    elif cur_unet == "sd3":
        return unet_instance[cur_unet].config.joint_attention_dim
    elif cur_unet == "sd35":
        return unet_instance[cur_unet].config.joint_attention_dim
    elif cur_unet == "flux":
        return unet_instance[cur_unet].config.joint_attention_dim
    elif cur_unet == "af":
        return unet_instance[cur_unet].config.joint_attention_dim
    elif cur_unet == "hy":
        return unet_instance[cur_unet].config.cross_attention_dim

def get_sample_height(cur_unet):
    # IDK why Flux doesn't include the "128" as sample_size
    if cur_unet == "sd1":
        return unet_instance[cur_unet].config.sample_size
    elif cur_unet == "sd2":
        return unet_instance[cur_unet].config.sample_size
    elif cur_unet == "sdxl":
        return unet_instance[cur_unet].config.sample_size
    elif cur_unet == "sd3":
        return unet_instance[cur_unet].config.sample_size
    elif cur_unet == "sd35":
        return unet_instance[cur_unet].config.sample_size
    elif cur_unet == "flux":
        # Will drop eventually.
        return 64 #default_sample_size
    elif cur_unet == "af":
        return unet_instance[cur_unet].config.sample_size
    elif cur_unet == "hy":
        return unet_instance[cur_unet].config.sample_size
    
def get_sequence_length(cur_unet):
    # Token sequence length of the text encoder output for each model
    if cur_unet == "sd1":
        return 77 # See CLIPTextModel.max_position_embeddings
    elif cur_unet == "sd2":
        return 77 # See CLIPTextModel.max_position_embeddings
    elif cur_unet == "sdxl":
        return 77 # See CLIPTextModel.max_position_embeddings
    elif cur_unet == "sd3":
        return 77 # See CLIPTextModel.max_position_embeddings
    elif cur_unet == "sd35":
        return 77 # See CLIPTextModel.max_position_embeddings
    elif cur_unet == "flux":
        return 256 #base_seq_len
    elif cur_unet == "af":
        return 256 #max_sequence_length
    elif cur_unet == "hy":
        return unet_instance[cur_unet].config.text_len #77

def main_loop(cur_unet):
    #240807: It is no longer "UNet" but we can still treat it as "that particular part of diffusion model"
    unet_instance[cur_unet] = get_mbw_component(cur_unet) if cur_unet not in unet_instance else unet_instance[cur_unet]

    sequence_length = get_sequence_length(cur_unet)
    feature_dim = get_feature_dim(cur_unet)
    height = get_sample_height(cur_unet)
    width = height #square is fine
    channel = unet_instance[cur_unet].config.in_channels
    step = 1 #arbitrary single float
    batch = 1 #1bs

    inplace_input_data = {
        'sample': torch.rand(batch, channel, height, width).type(model_type).to(g_device),
        'timestep': torch.rand(step).type(model_type).to(g_device),
        'encoder_hidden_states': torch.rand(batch, sequence_length, feature_dim).type(model_type).to(g_device),
    }
    
    # SDXL special
    if cur_unet == "sdxl":
        addition_time_embed_dim = unet_instance[cur_unet].config.addition_time_embed_dim
        projection_class_embeddings_input_dim = unet_instance[cur_unet].config.projection_class_embeddings_input_dim
        conv_in_kernel = unet_instance[cur_unet].config.conv_in_kernel
        conv_out_kernel = unet_instance[cur_unet].config.conv_out_kernel
        time_sequence_length = int((projection_class_embeddings_input_dim - addition_time_embed_dim) / 2)
        time_id_length = conv_in_kernel + conv_out_kernel
        inplace_input_data['added_cond_kwargs'] = {
            'text_embeds': torch.rand(batch, time_sequence_length).type(model_type).to(g_device),
            'time_ids': torch.rand(batch, time_id_length).type(model_type).to(g_device),
        }

    # SD3 / SD3.5 / Flux / Hunyuan special
    if cur_unet in ("sd3", "sd35", "flux", "hy"):
        projection_dim = unet_instance[cur_unet].config.pooled_projection_dim

        # Drop sample
        del inplace_input_data['sample']
        inplace_input_data['hidden_states'] = torch.rand(batch, channel, height, width).type(model_type).to(g_device)      
        inplace_input_data['pooled_projections'] = torch.rand(batch, projection_dim).type(model_type).to(g_device)
        inplace_input_data['timestep'] = torch.ones((batch, )).type(long_type).to(g_device) #torch.ones((inner_dim, ), dtype=torch.long).to(g_device)

        if (cur_unet == "flux"):
            #height * width (64*64) instead of feature_dim (4096)
            inplace_input_data['hidden_states'] = torch.rand(batch, height * width, channel).type(model_type).to(g_device)
            inplace_input_data['guidance'] = torch.zeros((batch, )).type(model_type).to(g_device)
            inplace_input_data['txt_ids'] = torch.zeros(batch, sequence_length, 3).type(model_type).to(g_device)
            inplace_input_data['img_ids'] = torch.zeros(batch, height * width, 3).type(model_type).to(g_device)

        if (cur_unet == "hy"):
            attention_head_dim = unet_instance[cur_unet].config.attention_head_dim
            sequence_length_t5 = unet_instance[cur_unet].config.text_len_t5
            feature_dim_t5 = unet_instance[cur_unet].config.cross_attention_dim_t5
            inplace_input_data['encoder_hidden_states_t5'] = torch.rand(batch, sequence_length_t5, feature_dim_t5).type(model_type).to(g_device)
            inplace_input_data['text_embedding_mask'] = torch.rand(batch, sequence_length).type(model_type).to(g_device)
            inplace_input_data['text_embedding_mask_t5'] = torch.rand(batch, sequence_length_t5).type(model_type).to(g_device)
            del inplace_input_data['pooled_projections']
            inplace_input_data['timestep'] = torch.ones((batch, )).type(model_type).to(g_device)
            inplace_input_data['image_meta_size'] = torch.rand(batch, 6).type(model_type).to(g_device)
            inplace_input_data['style'] = torch.rand((batch, )).type(long_type).to(g_device)
            # Don't know how it becomes 2D.
            inplace_input_data['image_rotary_emb'] = [
                torch.rand(feature_dim * 4, attention_head_dim).type(model_type).to(g_device), 
                torch.rand(feature_dim * 4, attention_head_dim).type(model_type).to(g_device)
            ]
        
    # AF special: just build from scratch.
    if (cur_unet == "af"):      
        do_cfg = 2 # do_classifier_free_guidance
        inplace_input_data = {
            'hidden_states': torch.rand(batch * do_cfg, channel, height * do_cfg, width * do_cfg).type(model_type).to(g_device),
            'encoder_hidden_states': torch.rand(batch * do_cfg, sequence_length, feature_dim).type(model_type).to(g_device),
            'timestep': torch.ones((batch * do_cfg, )).type(long_type).to(g_device)
        }
        
    model_summary = summary(unet_instance[cur_unet], 
        input_data=inplace_input_data, 
        col_names=("input_size", "output_size", "num_params")
    )

    with open(filename_paths[cur_unet] + '.txt', 'w') as the_file:
        the_file.write(str(model_summary))

    unet_png = draw_graph(unet_instance[cur_unet], 
        input_data=inplace_input_data, 
        graph_name=model_path[cur_unet], 
        device=g_device, mode="eval", 
        depth=1,     
        
        roll=True,        
        save_graph=True,
        filename=filename_paths[cur_unet]
    ) #expand_nested=True, hide_inner_tensors=False,   
    png_results[cur_unet] = unet_png

    unet_png_2 = draw_graph(unet_instance[cur_unet], 
        input_data=inplace_input_data, 
        graph_name=model_path[cur_unet], 
        device=g_device, mode="eval", 
        depth=2,       
        expand_nested=True,
        roll=True,        
        save_graph=True,
        filename='{}_2'.format(filename_paths[cur_unet])
    ) #hide_inner_tensors=False, 

    # Completely disable because of the ugly (*arg)
    if False:
        thop_input = [v for k, v in inplace_input_data.items()]
        macs, params = profile(unet_instance[cur_unet], inputs=(*thop_input, ))
        with open(filename_paths[cur_unet] + '.thop.txt', 'w') as the_file:
            the_file.write("{},{}".format(str(macs),str(params)))

    native_count = count_parameters_native(unet_instance[cur_unet])
    with open(filename_paths[cur_unet] + '.native.txt', 'w') as the_file:
        the_file.write(str(native_count))

- `sd1`: Should run with no problem.
- `sd2`: May OOM; however, I'm using an RTX 3090 now.
- `sdxl`: This is tricky: no documentation. [Read the code](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py) for a workaround.
- `sd3`: Login required. Paste the token in `C:\Users\User\.cache\huggingface\token`.
- `flux`: Same as `sd3`. It is also huge because I'm using the dev version. CPU mode only! It also requires the newest HF libraries. No docs, only code in [HF](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py) and the [OG repo](https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py).
- `af`: Even less documentation than `flux`. Intercept `latent` directly.
- `hy`: A bit more documentation, but still incomplete. Intercept `latent` directly as well.
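For the `sdxl` workaround, the extra inputs can be cross-checked with some arithmetic. The sketch below assumes the SDXL base config values `addition_time_embed_dim = 256` and `projection_class_embeddings_input_dim = 2816` (hard-coded assumptions, not read from the model here), and assumes the added embedding concatenates the pooled text embedding with one time embedding per `time_ids` entry.

```python
# Assumed SDXL base UNet config values (not loaded from the hub here).
addition_time_embed_dim = 256
projection_class_embeddings_input_dim = 2816

# Shapes intercepted from the pipeline: text_embeds [2, 1280], time_ids [2, 6].
text_embed_dim = 1280
num_time_ids = 6

# Each time id is embedded to addition_time_embed_dim values and concatenated
# with the pooled text embedding before the projection layer.
concat_dim = text_embed_dim + num_time_ids * addition_time_embed_dim
print(concat_dim)

# The shortcut used in main_loop gives the same text embedding width for these values.
shortcut_dim = (projection_class_embeddings_input_dim - addition_time_embed_dim) // 2
print(shortcut_dim)
```

The shortcut appears to agree only because of these particular numbers, so `projection_class_embeddings_input_dim - 6 * addition_time_embed_dim` may be a more robust way to derive the width for other configs. Similarly, `conv_in_kernel + conv_out_kernel` happens to equal 6 only when both kernels are 3.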

In [14]:
for k in filename_paths:
    main_loop(k)

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

  hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)


dot: graph is too large for cairo-renderer bitmaps. Scaling by 0.715484 to fit


Now generate the images in place.

In [15]:
#png_results["sd1"].visual_graph

In [16]:
#png_results["sd2"].visual_graph

In [17]:
#png_results["sdxl"].visual_graph

We may extend the work below later. For now, we block the code.

In [18]:
raise Exception("ERROR")

Exception: ERROR

All diagrams and summaries are now generated. Next, we try to map the MBW layers (`IN00-11`, `MID`, `OUT00-11`) to the diagram.

Note that local models are used here instead, and only `safetensors` files are used.

This time the models are read from files, and I almost hit OOM.

In [None]:
local_model_path = {
    "sd1": "../../stable-diffusion-webui/tmp/view_unet/21b-AstolfoMix-2020b.safetensors",
    "sd2": "../../stable-diffusion-webui/tmp/view_unet/wd-1-5-beta2-fp16.safetensors",
    "sdxl": "../../stable-diffusion-webui/tmp/view_unet/wdxl-aesthetic-0.9.safetensors",
}

local_models = {}

for k in local_model_path.keys():
    local_models[k] = load_file(local_model_path[k], device='cpu')

# I don't even have time for a sneak peek at SDXL models... Let's crack them here
max_layers = 100

`load_file` only returns a HUGE `dict`. You can use a JSON library to visualize it, but that is not useful: there is no linkage between the layers.

First, make a nice name finder.

In [None]:
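def do_a_keyword_in_a_layer(layer, *keywords):
    # True only if every keyword is a substring of the layer name.
    for k in keywords:
        if k not in layer:
            return False
    return True

def do_some_keywords_in_some_layers(model, *keywords):
    # Return the first key containing all keywords, or False if none matches.
    for layer in list(model.keys()):
        if do_a_keyword_in_a_layer(layer, *keywords):
            return layer
    return False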
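A quick way to see how such a finder behaves is to run it on a toy dict. The sketch below uses a hypothetical compact equivalent (`first_key_with`, returning `None` instead of `False`) and a stand-in dict whose values are shapes rather than tensors.

```python
def first_key_with(state_dict, *keywords):
    """Return the first key that contains every keyword, or None when nothing matches."""
    for key in state_dict:
        if all(k in key for k in keywords):
            return key
    return None

# Toy stand-in for the dict returned by load_file; values are shapes, not tensors.
toy_state_dict = {
    "model.diffusion_model.input_blocks.0.0.weight": (320, 4, 3, 3),
    "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": (320, 320),
    "model.diffusion_model.input_blocks.3.0.op.bias": (320,),
    "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": (1280,),
}

hit = first_key_with(toy_state_dict, "input_blocks.1.", "transformer_blocks")
miss = first_key_with(toy_state_dict, "input_blocks.2.", "op")
print(hit)
print(miss)
```

The trailing dot in `"input_blocks.1."` matters: without it, the keyword would also be a substring of `input_blocks.10`.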

Now we read the summary and the PNG. There are some noticeable layers:

|Layer Name|IN / MID / OUT?|Presence of Identifiable Layers?|Layer in `sd1` and `sd2`|Layer in SDXL|
|---|---|---|---|---|
|`Conv2d`|`input_blocks`|n/a|`IN00`|`IN00`|
|`CrossAttnDownBlock2D` > `Transformer2DModel`|`input_blocks`|`transformer_blocks`|`IN01`,`IN02`,`IN04`,`IN05`,`IN07`,`IN08`|`IN04`,`IN05`,`IN07`,`IN08`|
|`CrossAttnDownBlock2D` > `DownSample2D`|`input_blocks`|`op`|`IN03`,`IN06`,`IN09`|`IN03`,`IN06`|
|`DownBlock2D` > `ResnetBlock2D`|`input_blocks`|`in_layers`|`IN10`,`IN11`|`IN01`,`IN02`|
|`UNetMidBlock2DCrossAttn`|`middle_block`|`transformer_blocks`|`MID`|`MID`|
|`UpBlock2D` > `ResnetBlock2D`|`output_blocks`|`in_layers`|`OUT00`,`OUT01`,`OUT02`|`OUT06`,`OUT07`,`OUT08`|
|`CrossAttnUpBlock2D` > `Transformer2DModel`|`output_blocks`|`transformer_blocks`|`OUT03`,`OUT04`,`OUT05`,`OUT06`,`OUT07`,`OUT08`,`OUT09`,`OUT10`,`OUT11`|`OUT00`,`OUT01`,`OUT02`,`OUT03`,`OUT04`,`OUT05`|
|`UpBlock2D` > `UpSample2D`|`output_blocks`|`conv`|`OUT02`|n/a|
|`CrossAttnUpBlock2D` > `UpSample2D`|`output_blocks`|`conv`|`OUT05`,`OUT08`|`OUT02`,`OUT05`|
|`GroupNorm`, `Conv2d`|`output_blocks`|n/a|`OUT`|`OUT`|
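The MBW labels in the table follow the block index in the checkpoint key, so the translation can be sketched with a regular expression. `mbw_label` is a hypothetical helper; it assumes keys of the form `...input_blocks.N.`, `...middle_block.` and `...output_blocks.N.`, and it does not handle the final `out` layers.

```python
import re

# Matches ".input_blocks.<N>." or ".output_blocks.<N>." inside a checkpoint key.
_BLOCK_RE = re.compile(r"\.(input_blocks|output_blocks)\.(\d+)\.")

def mbw_label(key):
    """Translate a checkpoint key into an MBW label such as IN03, MID or OUT05."""
    if ".middle_block." in key:
        return "MID"
    match = _BLOCK_RE.search(key)
    if match is None:
        return None  # not part of the numbered UNet blocks
    prefix = "IN" if match.group(1) == "input_blocks" else "OUT"
    return "{}{:02d}".format(prefix, int(match.group(2)))

print(mbw_label("model.diffusion_model.input_blocks.3.0.op.bias"))
print(mbw_label("model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight"))
```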

### Strategy ###

Note that `block_types` are tested in sequence. Larger layers are identified first (usually spotted in layer 1), then the smaller layers in layer 2. Most of them have some identifiers (not unique) by which they can be identified.

### Findings ###

IN00, IN03, IN06 and IN09 are surprisingly low in layer count.

**The tensor size may not match the summary / PNG.** This is because the checkpoint stores layer weights, not activations: a `Conv2d` weight is `[out_channels, in_channels, kernel_height, kernel_width]`, so my 64x64 "image size" shows up as just the 3x3 kernel, and my "batch size" of "1" is replaced by the number of output channels (e.g. 320). **However, the 2nd dimension remains consistent** (the weight's `in_channels` equals the channel count of the input), as shown in most images on the internet.

For most neuron layers, it is $Y=\mathrm{SiLU}(WX+B)$ ([SiLU activation](https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html) for [common terms](https://medium.com/@okanyenigun/building-a-neural-network-from-scratch-in-python-a-step-by-step-guide-8f8cab064c8a)), in which `weight` is a multi-dimensional tensor and `bias` is only a 1-D vector. We look for the tensor size.

|Model|Size in actual tensor|Size in summary|
|---|---|---|
|`sd1`|`[320, 4, 3, 3]`|`[1, 4, 64, 64]`|
|`sd1`|`[320, 320, 3, 3]`|`[1, 320, 32, 32]`|
|`sd1`|`[640, 640, 3, 3]`|`[1, 640, 16, 16]`|
|`sd1`|`[1280, 1280, 3, 3]`|`[1, 1280, 8, 8]`|
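The two columns can be reproduced with shape arithmetic alone. This sketch assumes square kernels, `padding=1`, no dilation, and that the summary reports activation shapes as `[batch, channels, height, width]`; both helper names are hypothetical.

```python
def conv2d_weight_shape(in_channels, out_channels, kernel):
    """Shape of a Conv2d weight tensor: [out_channels, in_channels, kH, kW]."""
    return (out_channels, in_channels, kernel, kernel)

def conv2d_output_shape(input_shape, out_channels, kernel, stride=1, padding=1):
    """Shape of the activation a Conv2d produces for a [N, C, H, W] input."""
    batch, _, height, width = input_shape
    out_h = (height + 2 * padding - kernel) // stride + 1
    out_w = (width + 2 * padding - kernel) // stride + 1
    return (batch, out_channels, out_h, out_w)

# First convolution: 4 latent channels in, 320 channels out, 3x3 kernel.
print(conv2d_weight_shape(4, 320, 3))               # what the checkpoint stores
print(conv2d_output_shape((1, 4, 64, 64), 320, 3))  # what flows through the graph

# A stride-2 downsampling convolution halves the spatial size.
print(conv2d_weight_shape(320, 320, 3))
print(conv2d_output_shape((1, 320, 64, 64), 320, 3, stride=2))
```

In both cases the weight's second dimension equals the channel count of the input activation, which is the one number the two columns share.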

`IN00` doesn't have `op` marked. Also, it has `4` input channels, therefore it could be `Conv2d`.

For `CrossAttnDownBlock2D` in layer 1, `DownSample2D` comes after 2x `Transformer2DModel`, which resembles `IN03` coming after `IN01` and `IN02` and before `IN04`. This is plausible.

The same pattern applies to `DownBlock2D`, so `IN10` and `IN11` are identified.

However, `UpBlock2D` is a bit different: `UpSample2D` doesn't mirror `DownSample2D`. It has `conv` instead of `op`, and the id is different.

`OUT11` doesn't have `UpSample2D` even though it is in a `CrossAttnUpBlock2D` in layer 1. *What an asymmetry*.

`OUT`, like `IN00`, has no suffix either. However, `OUT` has *2 distinct layers* with the same size: `GroupNorm` and `Conv2d`.

SD2 is almost identical to SD1, with only slight differences in size.

SDXL is **different** from SD1. The greatest difference is that `UpBlock2D` is now **after** `CrossAttnUpBlock2D` instead of before it. Also, `UpBlock2D` doesn't have `UpSample2D`.

In [None]:
def list_interested_layers(model, layers_count = max_layers):
    block_layers = ['input_blocks.0','input_blocks','middle_block','output_blocks','.out']
    block_types = ['transformer_blocks','in_layers','conv','op']

    for b_l in block_layers:
        for i in range(layers_count):
            # Search for direct layers (not much)
            direct_layer = do_some_keywords_in_some_layers(model, "{}.{}.{}".format(b_l, i, 'weight'))
            if direct_layer:
                print('{}: {}'.format(direct_layer, model[direct_layer].size()))
            else:
                for b_t in block_types:
                    search_result = do_some_keywords_in_some_layers(model, "{}.{}.".format(b_l, i), b_t)
                    if search_result:
                        print('{}: {}'.format(search_result, model[search_result].size()))
                        break
                # Special case: UpSample2D are embedded in some output_blocks
                if (b_l == 'output_blocks'):
                    is_upsample2d = do_some_keywords_in_some_layers(model, "{}.{}.".format(b_l, i), 'conv')
                    if is_upsample2d:
                        print('{}: {}'.format(is_upsample2d, model[is_upsample2d].size()))

In [None]:
list_interested_layers(local_models['sd1'])

model.diffusion_model.input_blocks.0.0.weight: torch.Size([320, 4, 3, 3])
model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([320, 320])
model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([320, 320])
model.diffusion_model.input_blocks.3.0.op.bias: torch.Size([320])
model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([640, 640])
model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([640, 640])
model.diffusion_model.input_blocks.6.0.op.bias: torch.Size([640])
model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([1280, 1280])
model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([1280, 1280])
model.diffusion_model.input_blocks.9.0.op.bias: torch.Size([1280])
model.diffusion_model.input_blocks.10.0.in_layers.0.bias: torch.Size([1280])
model.diffusion_model.input_blocks

In [None]:
list_interested_layers(local_models['sd2'])

model.diffusion_model.input_blocks.0.0.weight: torch.Size([320, 4, 3, 3])
model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([320, 320])
model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([320, 320])
model.diffusion_model.input_blocks.3.0.op.bias: torch.Size([320])
model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([640, 640])
model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([640, 640])
model.diffusion_model.input_blocks.6.0.op.bias: torch.Size([640])
model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([1280, 1280])
model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([1280, 1280])
model.diffusion_model.input_blocks.9.0.op.bias: torch.Size([1280])
model.diffusion_model.input_blocks.10.0.in_layers.0.bias: torch.Size([1280])
model.diffusion_model.input_blocks

In [None]:
list_interested_layers(local_models['sdxl'])

model.diffusion_model.input_blocks.0.0.weight: torch.Size([320, 4, 3, 3])
model.diffusion_model.input_blocks.1.0.in_layers.0.bias: torch.Size([320])
model.diffusion_model.input_blocks.2.0.in_layers.0.bias: torch.Size([320])
model.diffusion_model.input_blocks.3.0.op.bias: torch.Size([320])
model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([640, 640])
model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([640, 640])
model.diffusion_model.input_blocks.6.0.op.bias: torch.Size([640])
model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([1280, 1280])
model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([1280, 1280])
model.diffusion_model.middle_block.0.in_layers.0.bias: torch.Size([1280])
model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight: torch.Size([1280, 1280])
model.diffusion_model.middle_block.2.in_layers.0.bias: t