I have a bug when using SD 2.1 (SD21): its UNet up-blocks can't be split into sub-blocks, although the same code works for SDXL.

So let's compare both.

In [14]:
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers.models.controlnetxs import to_sub_blocks

from util import simple_describe, cls_name

In [15]:
# Load both pipelines from local single-file checkpoints; only their UNets are inspected below.
sd21_pipe = StableDiffusionPipeline.from_single_file('weights/sd21/v2-1_512-ema-pruned.safetensors')
sdxl_pipe = StableDiffusionXLPipeline.from_single_file('weights/sdxl/sd_xl_base_1.0_0.9vae.safetensors')

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


In [16]:
# Keep handles to the UNets; their `up_blocks` are split and described in later cells.
sd21_unet = sd21_pipe.unet
sdxl_unet = sdxl_pipe.unet

In [41]:
from torch import nn
from itertools import zip_longest

class EmbedSequential(nn.ModuleList):
    """Ordered container of sub-block modules, intended to hand the time and
    conditioning embeddings to whichever children accept them.

    NOTE: ``forward`` is still a placeholder. It prints a notice and returns
    ``x`` unchanged, without running any child module.
    """

    def __init__(self, ms, *args, **kwargs):
        # A lone module is wrapped in a list so nn.ModuleList always gets an iterable.
        modules = ms if is_iterable(ms) else [ms]
        super().__init__(modules, *args, **kwargs)

    def forward(self, x, temb, cemb, attention_mask, cross_attention_kwargs):
        # Placeholder: the children are not applied yet.
        print('Actually, Im not implemented, brother')
        return x
        

def is_iterable(o):
    """Report whether ``o`` supports ``iter()``; ``str`` is deliberately treated as non-iterable."""
    if not isinstance(o, str):
        try:
            iter(o)
        except TypeError:
            pass
        else:
            return True
    return False

def to_sub_blocks(blocks):
    """Split UNet blocks into a flat list of sub-blocks (lists of modules).

    Within each block, every resnet is paired with the attention at the same
    index. Unmatched resnets (e.g. a block without attentions, or with more
    resnets than attentions) and unmatched attentions each become their own
    sub-block. Upsamplers are appended to the block's last sub-block, or form a
    new one if the block produced none. Each downsampler is its own sub-block.

    NOTE: this local definition shadows the ``to_sub_blocks`` imported from
    ``diffusers.models.controlnetxs`` earlier in the notebook.

    Args:
        blocks: a single block or an iterable of blocks. A block is any object
            that may expose ``resnets``, ``attentions``, ``upsamplers`` and/or
            ``downsamplers`` sequences. All but ``resnets`` may be ``None``.

    Returns:
        list[list]: the sub-blocks of all blocks, in order.
    """
    if not is_iterable(blocks):
        blocks = [blocks]

    sub_blocks = []
    for b in blocks:
        # Build this block's sub-blocks separately, so its upsamplers can never be
        # attached to a sub-block left over from the previous block.
        block_subs = []
        if hasattr(b, "resnets"):
            attentions = getattr(b, "attentions", None) or []
            # zip_longest keeps surplus resnets/attentions as single-module
            # sub-blocks instead of indexing past the end of the shorter list.
            for pair in zip_longest(b.resnets, attentions):
                block_subs.append([m for m in pair if m is not None])

        # Upsamplers belong to the same sub-block as the block's last modules.
        upsamplers = list(getattr(b, "upsamplers", None) or [])
        if upsamplers:
            if not block_subs:
                block_subs.append([])
            block_subs[-1].extend(upsamplers)

        # Each downsampler forms its own sub-block.
        for d in getattr(b, "downsamplers", None) or []:
            block_subs.append([d])

        sub_blocks.extend(block_subs)

    # TODO: return list(map(EmbedSequential, sub_blocks)) once EmbedSequential.forward is implemented.
    return sub_blocks

In [42]:
from types import SimpleNamespace

# Lightweight stand-ins exposing only the attributes to_sub_blocks reads.
# block1/block2 pair two resnets with two attentions plus an upsampler; block3 is resnet-only.
# NOTE(review): no block here has an `attentions` list shorter than `resnets`,
# so that branch of to_sub_blocks is not exercised by this fixture.
block1 = SimpleNamespace(resnets=['r1','r2'], attentions=['a1','a2'], upsamplers=['u'])
block2 = SimpleNamespace(resnets=['r1','r2'], attentions=['a1','a2'], upsamplers=['u'])
block3 = SimpleNamespace(resnets=['r1','r2'])

dummy_ups = [block1, block2, block3]

In [43]:
# Each upsampler is appended to the last resnet/attention pair of its own block.
to_sub_blocks(dummy_ups)

[['r1', 'a1'],
 ['r2', 'a2', 'u'],
 ['r1', 'a1'],
 ['r2', 'a2', 'u'],
 ['r1'],
 ['r2']]

In [44]:
# Split SDXL's UNet up-blocks into sub-blocks.
sdxl_up_subblocks = to_sub_blocks(sdxl_unet.up_blocks)

In [46]:
# Class names of the modules in each SDXL up sub-block.
[[cls_name(module) for module in sub_block] for sub_block in sdxl_up_subblocks]

[['ResnetBlock2D', 'Transformer2DModel'],
 ['ResnetBlock2D', 'Transformer2DModel'],
 ['ResnetBlock2D', 'Transformer2DModel', 'Upsample2D'],
 ['ResnetBlock2D', 'Transformer2DModel'],
 ['ResnetBlock2D', 'Transformer2DModel'],
 ['ResnetBlock2D', 'Transformer2DModel', 'Upsample2D'],
 ['ResnetBlock2D'],
 ['ResnetBlock2D'],
 ['ResnetBlock2D']]

In [47]:
# Split SD2.1's UNet up-blocks into sub-blocks, for comparison with SDXL.
sd21_up_subblocks = to_sub_blocks(sd21_unet.up_blocks)

In [48]:
# Class names of the modules in each SD2.1 up sub-block.
[[cls_name(module) for module in sub_block] for sub_block in sd21_up_subblocks]

[['ResnetBlock2D'],
 ['ResnetBlock2D'],
 ['ResnetBlock2D', 'Upsample2D'],
 ['ResnetBlock2D', 'Transformer2DModel'],
 ['ResnetBlock2D', 'Transformer2DModel'],
 ['ResnetBlock2D', 'Transformer2DModel', 'Upsample2D'],
 ['ResnetBlock2D', 'Transformer2DModel'],
 ['ResnetBlock2D', 'Transformer2DModel'],
 ['ResnetBlock2D', 'Transformer2DModel', 'Upsample2D'],
 ['ResnetBlock2D', 'Transformer2DModel'],
 ['ResnetBlock2D', 'Transformer2DModel'],
 ['ResnetBlock2D', 'Transformer2DModel']]

In [19]:
# Module hierarchy of SDXL's up-blocks; per the output, the UpBlock2D (no attentions) comes last.
simple_describe(sdxl_unet.up_blocks)

 ModuleList 
	 CrossAttnUpBlock2D 
		 ResnetBlock2D (2560, 1280)
		 Transformer2DModel (1280, 1280)
		 ResnetBlock2D (2560, 1280)
		 Transformer2DModel (1280, 1280)
		 ResnetBlock2D (1920, 1280)
		 Transformer2DModel (1280, 1280)
		 Upsample2D (1280, 1280)
	 CrossAttnUpBlock2D 
		 ResnetBlock2D (1920, 640)
		 Transformer2DModel (640, 640)
		 ResnetBlock2D (1280, 640)
		 Transformer2DModel (640, 640)
		 ResnetBlock2D (960, 640)
		 Transformer2DModel (640, 640)
		 Upsample2D (640, 640)
	 UpBlock2D 
		 ResnetBlock2D (960, 320)
		 ResnetBlock2D (640, 320)
		 ResnetBlock2D (640, 320)


In [20]:
# Module hierarchy of SD2.1's up-blocks; per the output, the UpBlock2D (no attentions) comes first.
simple_describe(sd21_unet.up_blocks)

 ModuleList 
	 UpBlock2D 
		 ResnetBlock2D (2560, 1280)
		 ResnetBlock2D (2560, 1280)
		 ResnetBlock2D (2560, 1280)
		 Upsample2D (1280, 1280)
	 CrossAttnUpBlock2D 
		 ResnetBlock2D (2560, 1280)
		 Transformer2DModel (1280, 1280)
		 ResnetBlock2D (2560, 1280)
		 Transformer2DModel (1280, 1280)
		 ResnetBlock2D (1920, 1280)
		 Transformer2DModel (1280, 1280)
		 Upsample2D (1280, 1280)
	 CrossAttnUpBlock2D 
		 ResnetBlock2D (1920, 640)
		 Transformer2DModel (640, 640)
		 ResnetBlock2D (1280, 640)
		 Transformer2DModel (640, 640)
		 ResnetBlock2D (960, 640)
		 Transformer2DModel (640, 640)
		 Upsample2D (640, 640)
	 CrossAttnUpBlock2D 
		 ResnetBlock2D (960, 320)
		 Transformer2DModel (320, 320)
		 ResnetBlock2D (640, 320)
		 Transformer2DModel (320, 320)
		 ResnetBlock2D (640, 320)
		 Transformer2DModel (320, 320)
