# Visualising UNET #

### Abstract ###

- Self-explanatory. Uses `torchview` as the main library.
- The document below is mainly copied from [mega_cmp.ipynb](./v2a/mmega_cmp.ipynb).

### Required libraries ###

- ~~Should be the common ML pack we're using. Also with [SD webui's dependency](https://github.com/AUTOMATIC1111/stable-diffusion-webui).~~

- [torchview](https://torchview.dev/)
- [scikit-learn](https://scikit-learn.org/stable/install.html)
- [NetworkX](https://networkx.org/documentation/stable/release/release_3.0.html)
- [safetensors](https://huggingface.co/docs/safetensors/index)
- [diffusers](https://huggingface.co/docs/diffusers/installation)
- [omegaconf](https://anaconda.org/conda-forge/omegaconf)
- [pytorch](https://pytorch.org/get-started/locally/#windows-python)
- [matplotlib](https://matplotlib.org/stable/api/matplotlib_configuration_api.html)
- [numpy](https://numpy.org/)
- [torchinfo](https://pypi.org/project/torchinfo/)

### Some layer names to interpret (for SD1.5) ###
- `first_stage_model`: VAE
- `cond_stage_model`: Text Encoder
- `model.diffusion_model`: Diffusion model
- `model_ema`: EMA model for training
- `cumprod`, `betas`, `alphas`: diffusion noise schedule buffers (e.g. `betas`, `alphas_cumprod`), not trainable weights
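To check which component a tensor belongs to, a checkpoint's keys can be grouped by the prefixes above. This is a minimal sketch, assuming an original-format SD1.5 checkpoint; `count_by_component` and `COMPONENT_PREFIXES` are hypothetical helpers, not part of any library:

```python
from collections import defaultdict
from typing import Dict, Mapping

# Key prefixes taken from the list above (original SD1.5 checkpoint layout).
# Note: "model." does not match "model_ema." because of the trailing dot.
COMPONENT_PREFIXES = {
    "first_stage_model.": "VAE",
    "cond_stage_model.": "Text Encoder",
    "model.diffusion_model.": "Diffusion model",
    "model_ema.": "EMA",
}


def count_by_component(numel_by_key: Mapping[str, int]) -> Dict[str, int]:
    """Sum element counts per component, grouping state_dict keys by prefix."""
    totals: Dict[str, int] = defaultdict(int)
    for key, numel in numel_by_key.items():
        for prefix, component in COMPONENT_PREFIXES.items():
            if key.startswith(prefix):
                totals[component] += numel
                break
        else:
            # Anything else, e.g. schedule buffers such as "betas".
            totals["other"] += numel
    return dict(totals)
```

With a real checkpoint, feed it element counts, e.g. `count_by_component({k: v.numel() for k, v in load_file(path).items()})`, where `path` is a placeholder for your local file. Keys outside the four prefixes (such as the schedule buffers) land in `other`.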

### Some notation (Useful in the bin chart) ###
- `attn1`: `sattn` = *Self attention*
- `attn2`: `xattn` = *Cross attention*
- `ff`: *Feed forward*
- `norm`: [Normalisation layer](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html). `elementwise_affine=True` introduces trainable `bias` and `weight`. 
- `proj`: *Projection*
- `emb_layers`: *Embedding layers*
- `others`: `ff` + `norm` + `proj` + `emb_layers`
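The bins can be assigned from parameter names alone, since `attn1`, `attn2`, `ff`, `norm*`, `proj*` and `emb_layers` appear as dot-separated parts of each key. A rough sketch, assuming that naming (shared by the original SD1.5 keys and diffusers); `bin_of` and the `unbinned` label for everything else (e.g. convolutions) are my own illustration:

```python
from typing import List

# Prefixes of the parts that make up the "others" bin, per the notation above.
OTHER_PREFIXES = ("ff", "norm", "proj", "emb_layers")


def bin_of(param_name: str) -> str:
    """Return "sattn", "xattn", "others", or the hypothetical "unbinned"."""
    parts: List[str] = param_name.split(".")
    if "attn1" in parts:
        return "sattn"
    if "attn2" in parts:
        return "xattn"
    if any(part.startswith(OTHER_PREFIXES) for part in parts):
        return "others"
    return "unbinned"  # e.g. ResNet convolutions, conv_in / conv_out
```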

## Importing libraries ##

```python
import os
import torch
from safetensors.torch import load_file  # safe_open
from diffusers import UNet2DConditionModel

from torchview import draw_graph
from torchinfo import summary

import graphviz
graphviz.set_jupyter_format('png')  # returns the previous format: 'svg'
```

```python
torch.__version__
# '2.0.1+cu118'
```

```python
# Fix for "OMP: Error #15" (duplicate OpenMP runtime initialised)
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
```

```python
# TODO: Support 'cuda', but 'cpu' is already fast.
g_device = "cuda:0"  # "cpu"
# Currently for generating graph only.
g_seed = 114514
```

```python
# Model path (Hugging Face Hub repo IDs)
model_path = {
    "sd1": "runwayml/stable-diffusion-v1-5",
    "sd2": "stabilityai/stable-diffusion-2-1",
    "sdxl": "stabilityai/stable-diffusion-xl-base-1.0"
}

model_type = torch.float16 if "cuda" in g_device else torch.float  # CPU doesn't support FP16 / FP8
```


Only online models (Hugging Face Hub repo IDs) are supported here.

Loading a local checkpoint with `load_single_file` failed (version conflicts; details omitted), so I load the online models instead.

```python
unet_instance = {}  # Clear: cache of loaded UNets, keyed by model name
```

```python
# We run it later (inside main_loop)
# for k in model_path.keys():
#     unet_instance[k] = UNet2DConditionModel.from_pretrained(model_path[k], subfolder="unet", torch_dtype=torch.float16).to(g_device)
```

The input sizes started out as trial and error.

Actually, they don't have to be: we can read them from the `config.json` of each official model on Hugging Face. Originally these configs were scattered across different Git repos, but HF does a great job of collecting them.
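Here is a sketch of the same shape arithmetic that `main_loop` uses, without torch, so the dummy input shapes can be checked straight from a `config.json`. `load_unet_config` and `dummy_input_shapes` are hypothetical helpers; the choices (last `block_out_channels` entry as batch, `attention_head_dim` as the latent side) simply copy `main_loop`:

```python
import json
from typing import Any, Dict, Tuple

SEQUENCE_LENGTH = 77  # CLIP token length, as used in main_loop


def load_unet_config(path: str) -> Dict[str, Any]:
    """Read a UNet config.json (e.g. the unet/config.json of a diffusers repo)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def dummy_input_shapes(cfg: Dict[str, Any]) -> Dict[str, Tuple[int, ...]]:
    """Shapes of the dummy UNet inputs, using the same choices as main_loop."""
    batch = cfg["block_out_channels"][-1]  # reused as an (arbitrary) batch size
    head_dim = cfg["attention_head_dim"]   # int (SD1.5) or per-level list (SD2.1, SDXL)
    side = head_dim if isinstance(head_dim, int) else head_dim[-1]
    return {
        "sample": (batch, cfg["in_channels"], side, side),
        "timestep": (1,),
        "encoder_hidden_states": (batch, SEQUENCE_LENGTH, cfg["cross_attention_dim"]),
    }
```

For SD1.5 (`block_out_channels=[320, 640, 1280, 1280]`, `in_channels=4`, `attention_head_dim=8`, `cross_attention_dim=768`) this gives a `sample` of `(1280, 4, 8, 8)` and `encoder_hidden_states` of `(1280, 77, 768)`, the same as the hard-coded `sd1` entry in the sample mapping.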

```python
# Not Used. Hard-coded equivalent of the inputs that main_loop builds from each config.
input_data_mapping_sample = {
    "sd1": {
        'sample': torch.rand(1280, 4, 8, 8, dtype=model_type).to(g_device),
        'timestep': torch.rand(1, dtype=model_type).to(g_device),
        'encoder_hidden_states': torch.rand(1280, 77, 768, dtype=model_type).to(g_device),
    },
    "sd2": {
        'sample': torch.rand(1280, 4, 20, 20, dtype=model_type).to(g_device),
        'timestep': torch.rand(1, dtype=model_type).to(g_device),
        'encoder_hidden_states': torch.rand(1280, 77, 1024, dtype=model_type).to(g_device),
    },
    "sdxl": {
        'sample': torch.rand(1280, 4, 20, 20, dtype=model_type).to(g_device),
        'timestep': torch.rand(1, dtype=model_type).to(g_device),
        'encoder_hidden_states': torch.rand(1280, 77, 2048, dtype=model_type).to(g_device),
        'added_cond_kwargs': {
            'text_embeds': torch.rand(1280, 2560, dtype=model_type).to(g_device),
            'time_ids': torch.rand(1280, dtype=model_type).to(g_device),
        },
    },
}
```

Set the output paths for the graphs and summaries.

```python
filename_paths = {
    "sd1": "./sd1_unet",
    "sd2": "./sd2_unet",
    "sdxl": "./sdxl_unet",
}
```

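Main loop: load each UNet (cached in `unet_instance`), build random dummy inputs from its config, then save a `torchinfo` summary (`.txt`) and a `torchview` graph.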

```python
def main_loop(cur_unet):
    # Load the UNet once and cache it in unet_instance.
    if cur_unet not in unet_instance:
        unet_instance[cur_unet] = UNet2DConditionModel.from_pretrained(
            model_path[cur_unet], subfolder="unet", torch_dtype=model_type
        ).to(g_device)
    unet = unet_instance[cur_unet]
    config = unet.config

    # Dummy input sizes, read from the model's config.json.
    batch = config.block_out_channels[-1]  # reused as the (arbitrary) batch size
    channel = config.in_channels
    sequence_length = 77  # See CLIP
    feature_dim = config.cross_attention_dim
    attention_head_dim = config.attention_head_dim
    # int for SD1.5, per-level list for SD2.1 / SDXL; reused as the latent height
    height = attention_head_dim if isinstance(attention_head_dim, int) else attention_head_dim[-1]
    width = height  # square is fine
    step = 1  # arbitrary single float

    inplace_input_data = {
        'sample': torch.rand(batch, channel, height, width, dtype=model_type).to(g_device),
        'timestep': torch.rand(step, dtype=model_type).to(g_device),
        'encoder_hidden_states': torch.rand(batch, sequence_length, feature_dim, dtype=model_type).to(g_device),
    }

    # SDXL special: extra conditioning inputs for the additional embedding.
    if cur_unet == "sdxl":
        addition_time_embed_dim = config.addition_time_embed_dim
        projection_class_embeddings_input_dim = config.projection_class_embeddings_input_dim
        # Width of text_embeds, chosen so that text_embeds plus one embedded
        # time id per sample add up to projection_class_embeddings_input_dim.
        time_sequence_length = projection_class_embeddings_input_dim - addition_time_embed_dim
        inplace_input_data['added_cond_kwargs'] = {
            'text_embeds': torch.rand(batch, time_sequence_length, dtype=model_type).to(g_device),
            'time_ids': torch.rand(batch, dtype=model_type).to(g_device),
        }

    # Layer-by-layer text summary (torchinfo), saved as <path>.txt
    model_summary = summary(unet,
        input_data=inplace_input_data,
        col_names=("input_size", "output_size", "num_params")
    )
    with open(filename_paths[cur_unet] + '.txt', 'w') as the_file:
        the_file.write(str(model_summary))

    # Computation graph (torchview), saved under filename_paths[cur_unet]
    unet_png = draw_graph(unet,
        input_data=inplace_input_data,
        graph_name=model_path[cur_unet],
        device=g_device, mode="eval",
        depth=1,
        roll=True,
        save_graph=True,
        filename=filename_paths[cur_unet]
    )  # , expand_nested=True, hide_inner_tensors=False,

    # unet_png.visual_graph
```

- `sd1`: Should run without problems.
- `sd2`: May run out of memory (OOM); I'm using an RTX 3090 now, which handles it.
- `sdxl`: This is tricky: the extra inputs are not documented. [Read the code](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py) for the workaround.
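The workaround boils down to one width constraint. As far as I can read the linked source, for SDXL's `text_time` additional embedding the UNet flattens `time_ids`, embeds each value into `addition_time_embed_dim` features, reshapes them per sample and concatenates them with `text_embeds`; the concatenation must be exactly `projection_class_embeddings_input_dim` wide. A small sketch of that arithmetic (`sdxl_added_cond_shapes` is a hypothetical helper):

```python
from typing import Dict, Tuple


def sdxl_added_cond_shapes(
    projection_class_embeddings_input_dim: int,
    addition_time_embed_dim: int,
    batch: int,
    time_ids_per_sample: int,
) -> Dict[str, Tuple[int, ...]]:
    """Shapes for added_cond_kwargs so that the concatenated embedding fits.

    Constraint:
        text_width + time_ids_per_sample * addition_time_embed_dim
            == projection_class_embeddings_input_dim
    """
    time_width = time_ids_per_sample * addition_time_embed_dim
    text_width = projection_class_embeddings_input_dim - time_width
    if text_width <= 0:
        raise ValueError("too many time ids for this projection width")
    return {
        "text_embeds": (batch, text_width),
        "time_ids": (batch, time_ids_per_sample),
    }
```

With the SDXL base config (`projection_class_embeddings_input_dim=2816`, `addition_time_embed_dim=256`), the usual 6 time ids per sample (original size, crop coordinates, target size) leave 1280 for `text_embeds`, while `main_loop`'s single time id per sample leaves 2560.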

```python
for k in filename_paths:
    main_loop(k)
```
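If one model runs out of memory, the loop above stops and the remaining models are skipped. A minimal sketch of a more forgiving loop, assuming you would rather log failures and continue; `run_each` is a hypothetical helper, and the default relies on `torch.cuda.OutOfMemoryError` being a subclass of `RuntimeError` (the case in recent PyTorch, including the 2.0.1 used here):

```python
from typing import Callable, Dict, Iterable, Optional, Tuple, Type


def run_each(
    keys: Iterable[str],
    step: Callable[[str], object],
    cleanup: Optional[Callable[[], None]] = None,
    catch: Tuple[Type[BaseException], ...] = (RuntimeError,),
) -> Dict[str, str]:
    """Call step(key) for every key; record caught errors instead of stopping."""
    failures: Dict[str, str] = {}
    for key in keys:
        try:
            step(key)
        except catch as err:
            failures[key] = str(err)
        finally:
            # e.g. torch.cuda.empty_cache, to release cached GPU blocks
            if cleanup is not None:
                cleanup()
    return failures
```

Usage: `failures = run_each(filename_paths, main_loop, cleanup=torch.cuda.empty_cache)`. Note that `empty_cache` only releases cached blocks; the UNets kept in `unet_instance` still hold their memory.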

Output (the same warning is printed once per model):

```
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with:
pip install accelerate
```



All diagrams and summaries are now generated. Next, we can try to map the MBW (Merge Block Weighted) layers (`IN00-11`, `MID`, `OUT00-11`) onto the diagrams.
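As a starting point, here is one way to express that mapping for SD1.x/2.x in diffusers module names. It assumes the MBW indices follow the original LDM `input_blocks` / `output_blocks` numbering (12 input blocks, one middle block, 12 output blocks), split into `down_blocks` / `up_blocks` the same way diffusers' checkpoint conversion does. It does not cover SDXL, whose UNet has fewer levels, nor `time_embedding` / `conv_out`. `mbw_to_diffusers` is a hypothetical helper:

```python
from typing import List


def mbw_to_diffusers(block: str) -> List[str]:
    """Map an MBW block name (IN00-IN11, MID, OUT00-OUT11) to diffusers module prefixes.

    Assumes the SD1.x/2.x UNet layout: 4 resolution levels, 2 ResNets per down
    level and 3 per up level, attention on every level except the lowest resolution.
    """
    if block == "MID":
        return ["mid_block"]
    kind, digits = block[:-2], block[-2:]
    if kind not in ("IN", "OUT") or not digits.isdigit() or not 0 <= int(digits) <= 11:
        raise ValueError(f"unknown MBW block: {block!r}")
    idx = int(digits)

    if kind == "IN":
        if idx == 0:
            return ["conv_in"]
        level, pos = divmod(idx - 1, 3)
        if pos == 2:  # IN03, IN06, IN09 are the downsamplers
            return [f"down_blocks.{level}.downsamplers.0"]
        names = [f"down_blocks.{level}.resnets.{pos}"]
        if level < 3:  # the lowest-resolution level has no attention
            names.append(f"down_blocks.{level}.attentions.{pos}")
        return names

    level, pos = divmod(idx, 3)
    names = [f"up_blocks.{level}.resnets.{pos}"]
    if level > 0:  # the lowest-resolution level has no attention
        names.append(f"up_blocks.{level}.attentions.{pos}")
    if pos == 2 and level < 3:  # last block of each level except the highest-resolution one upsamples
        names.append(f"up_blocks.{level}.upsamplers.0")
    return names
```

The returned prefixes can be matched against `unet.named_parameters()` names (e.g. `down_blocks.0.resnets.0.conv1.weight`) to collect the parameters behind each MBW block.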

```python
# load_file(path, device=device)
```