In [1]:
from diffusers import StableDiffusionXLPipeline
import torch
from collections import OrderedDict
from collections import defaultdict
import yaml 

In [2]:
def _convert_shape_tuple_to_str(shape):
    separator = ', '
    result_string = "(" + separator.join(map(str, shape)) + ")"
    return result_string

In [3]:
# Load the SDXL base pipeline, requesting float16 tensors, the "fp16" weight
# variant and safetensors-format files.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [4]:
# for i, item in enumerate([item for item in dir(pipeline) if not item.startswith('_')]):
#     print("{} : {}".format(i, item))

In [5]:
# Show which class implements the pipeline's `unet` component.
print(type(pipeline.unet))

<class 'diffusers.models.unet_2d_condition.UNet2DConditionModel'>


In [6]:
# Fetch the UNet's state dict; the cells below treat it as a mapping from
# dotted parameter names to tensors (values are read via `.shape`).
state_dict = pipeline.unet.state_dict()

In [7]:
# Keep only entries whose name ends in ".weight" and whose tensor has more
# than one dimension, preserving the state dict's original ordering.
weights_state_dict = OrderedDict(
    (name, tensor)
    for name, tensor in state_dict.items()
    if name.endswith(".weight") and len(tensor.shape) > 1
)

In [8]:
# Sanity check: how many weight tensors survived the filter ...
print(len(weights_state_dict))
# ... and the shape of the first one, shown as the cell's output value.
tuple(list(weights_state_dict.values())[0].shape)

794


(320, 4, 3, 3)

In [9]:
# Build two summaries of the filtered weight tensors:
#   weights_info_hierarchical: first name component -> shape string -> stats
#   weights_info_flat:         shape string -> stats
# In each stats dict, "count" is the number of tensors with that shape and
# "numel" is the element count of one such tensor, in millions. Tensors of
# equal shape have equal element counts, so "numel" is simply overwritten.
weights_info_hierarchical = {}
weights_info_flat = {}

for i, (key, v) in enumerate(weights_state_dict.items()):
    shape = tuple(v.shape)
    shape_string = _convert_shape_tuple_to_str(shape)
    weight_size = torch.numel(v) / 1e6
    # Debug aid: print("{}: {} : {:.2f}M : {}".format(i, key, weight_size, shape))

    # Group by the part of the parameter name before the first dot.
    weight_identifier = key.split(".")[0]
    per_identifier = weights_info_hierarchical.setdefault(weight_identifier, {})

    # Create each stats record on first sight, then apply the same update to both.
    for stats in (
        per_identifier.setdefault(shape_string, {"count": 0, "numel": 0}),
        weights_info_flat.setdefault(shape_string, {"count": 0, "numel": 0}),
    ):
        stats["count"] += 1
        stats["numel"] = weight_size


In [10]:
# Write both summaries to YAML files in the current working directory,
# using block style (default_flow_style=False) for nested mappings.
for yaml_path, summary in (
    ("weights_info_hierarchical.yaml", weights_info_hierarchical),
    ("weights_info_flat.yaml", weights_info_flat),
):
    with open(yaml_path, 'w') as yaml_file:
        yaml.dump(summary, yaml_file, default_flow_style=False)

In [20]:
# Unpack the flat per-shape summary into three parallel lists.
list_of_shapes, list_of_counts, list_of_sizes = [], [], []
for shape_str, shape_stats in weights_info_flat.items():
    list_of_shapes.append(shape_str)
    list_of_counts.append(shape_stats["count"])
    list_of_sizes.append(shape_stats["numel"])

# Order by per-tensor size, largest first. To order by frequency instead,
# use list_of_counts.__getitem__ as the key.
sorted_indices = sorted(
    range(len(list_of_sizes)),
    key=list_of_sizes.__getitem__,
    reverse=True,
)

# Apply the same permutation to all three lists so they stay aligned.
list_of_shapes = [list_of_shapes[idx] for idx in sorted_indices]
list_of_counts = [list_of_counts[idx] for idx in sorted_indices]
list_of_sizes = [list_of_sizes[idx] for idx in sorted_indices]

# One report line per distinct shape.
for i, (shape_text, shape_count, size_in_millions) in enumerate(
    zip(list_of_shapes, list_of_counts, list_of_sizes)
):
    print("shape : {}, count = {}, size: {:.2f}M".format(
        shape_text,
        shape_count,
        size_in_millions,
    ))

shape : (1280, 2560, 3, 3), count = 2, size: 29.49M
shape : (1280, 1920, 3, 3), count = 1, size: 22.12M
shape : (1280, 1280, 3, 3), count = 11, size: 14.75M
shape : (10240, 1280), count = 60, size: 13.11M
shape : (640, 1920, 3, 3), count = 1, size: 11.06M
shape : (1280, 640, 3, 3), count = 1, size: 7.37M
shape : (640, 1280, 3, 3), count = 1, size: 7.37M
shape : (1280, 5120), count = 60, size: 6.55M
shape : (640, 960, 3, 3), count = 1, size: 5.53M
shape : (640, 640, 3, 3), count = 8, size: 3.69M
shape : (1280, 2816), count = 1, size: 3.60M
shape : (5120, 640), count = 10, size: 3.28M
shape : (1280, 2560, 1, 1), count = 2, size: 3.28M
shape : (320, 960, 3, 3), count = 1, size: 2.76M
shape : (1280, 2048), count = 120, size: 2.62M
shape : (1280, 1920, 1, 1), count = 1, size: 2.46M
shape : (640, 320, 3, 3), count = 1, size: 1.84M
shape : (320, 640, 3, 3), count = 2, size: 1.84M
shape : (1280, 1280), count = 381, size: 1.64M
shape : (640, 2560), count = 10, size: 1.64M
shape : (640, 2048), c