In [1]:
from diffusers import StableDiffusionXLPipeline
import torch
from collections import OrderedDict
from collections import defaultdict
import yaml 

In [7]:
def _convert_shape_tuple_to_str(shape):
    separator = ', '
    result_string = "(" + separator.join(map(str, shape)) + ")"
    return result_string

In [2]:
# Load the "stabilityai/stable-diffusion-xl-base-1.0" pipeline, requesting the
# "fp16" variant with float16 tensors and safetensors files.
# Only `pipeline.unet` is inspected by the cells below.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
# for i, item in enumerate([item for item in dir(pipeline) if not item.startswith('_')]):
#     print("{} : {}".format(i, item))

In [3]:
# Show the concrete class of the pipeline's `unet` attribute.
print(type(pipeline.unet))

<class 'diffusers.models.unet_2d_condition.UNet2DConditionModel'>


In [4]:
# Grab the UNet's state dict; the following cells iterate its keys and read
# each value's `.shape`.
state_dict = pipeline.unet.state_dict()

In [5]:
# Keep only the entries whose key ends in ".weight" and whose tensor has more
# than one dimension; insertion order follows `state_dict`.
weights_state_dict = OrderedDict(
    (key, tensor)
    for key, tensor in state_dict.items()
    if key.endswith(".weight") and len(tensor.shape) > 1
)

In [6]:
# Number of entries that survived the ".weight" / more-than-one-dimension filter.
print(len(weights_state_dict))
# Shape of the first kept tensor, shown as the cell's output value.
tuple(list(weights_state_dict.values())[0].shape)

794


(320, 4, 3, 3)

In [8]:
# Summary keyed first by the leading component of the parameter name (the text
# before the first "."), then by the shape string:
#   {identifier: {"(d0, d1, ...)": {"count": <tensors with that shape>,
#                                   "numel": <elements per tensor, in millions>}}}
weights_info_hierarchical = {}
weights_info_flat = {}  # created here but not populated in this cell

for key, v in weights_state_dict.items():
    shape_string = _convert_shape_tuple_to_str(tuple(v.shape))
    weight_size = v.numel() / 1e6  # element count expressed in millions
    weight_identifier = key.split(".", 1)[0]

    per_shape = weights_info_hierarchical.setdefault(weight_identifier, {})
    entry = per_shape.setdefault(shape_string, {"count": 0, "numel": 0})

    entry["count"] += 1
    # Assigned, not accumulated: tensors sharing a shape have the same element
    # count, so this stays the per-tensor size rather than a running total.
    entry["numel"] = weight_size
     


In [9]:
# Persist the hierarchical summary as block-style YAML (default_flow_style=False).
with open("weights_info_hierarchical.yaml", mode="w") as summary_file:
    yaml.dump(weights_info_hierarchical, stream=summary_file, default_flow_style=False)