The goal of this notebook is to play around with the unet in controlnet.

---

In [1]:
import torch
import torch.nn as nn

In [2]:
device = "cpu"
device_dtype = torch.float32

In [3]:
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

In [4]:
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=device_dtype)

In [5]:
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=device_dtype
).to(device)


`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.


In [6]:
for m in pipe.unet.down_blocks: print(type(m))
print('--')
print(type(pipe.unet.mid_block))
print('--')
for m in pipe.unet.up_blocks: print(type(m))

<class 'diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D'>
<class 'diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D'>
<class 'diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D'>
<class 'diffusers.models.unet_2d_blocks.DownBlock2D'>
--
<class 'diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn'>
--
<class 'diffusers.models.unet_2d_blocks.UpBlock2D'>
<class 'diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D'>
<class 'diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D'>
<class 'diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D'>
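To summarize the stages compactly, here is a small sketch that labels each block by whether it uses cross-attention. It relies on the `has_cross_attention` attribute, which shows up in the attribute listing of `CrossAttnDownBlock2D` further down. I'm not sure whether plain blocks such as `DownBlock2D` define it, so the sketch assumes a missing attribute means `False`. The `SimpleNamespace` objects are stand-ins so the snippet runs without downloading weights, and `summarize_blocks` is a hypothetical helper.

```python
from types import SimpleNamespace

def summarize_blocks(blocks):
    """Label each block 'attn' or 'plain'.

    Assumption: cross-attention blocks expose has_cross_attention=True;
    blocks without the attribute are treated as plain.
    """
    return ['attn' if getattr(b, 'has_cross_attention', False) else 'plain'
            for b in blocks]

# Stand-ins mirroring the SD 1.5 down path printed above:
# 3x CrossAttnDownBlock2D followed by 1x DownBlock2D
down = [SimpleNamespace(has_cross_attention=True) for _ in range(3)] + [SimpleNamespace()]
print(summarize_blocks(down))  # ['attn', 'attn', 'attn', 'plain']
```

With the real model this would be called as `summarize_blocks(pipe.unet.down_blocks)`.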


This is the architecture from the ControlNet-XS repo:

**Down**

    Conv2d      --> appending (   4, 320)
    ResBlock    --> appending ( 320, 320)
    ResBlock    --> appending ( 320, 320)
    Downsample  --> appending ( 320, 320)
    ResBlock    --> appending ( 320, 640)
    ResBlock    --> appending ( 640, 640)
    Downsample  --> appending ( 640, 640)
    ResBlock    --> appending ( 640,1280)
    ResBlock    --> appending (1280,1280)

**Mid**

    (1280, 1280)

**Up**

    ResBlock    --> appending (2560,1280)
    ResBlock    --> appending (2560,1280)
    ResBlock    --> appending (1920,1280)
    ResBlock    --> appending (1920, 640)
    ResBlock    --> appending (1280, 640)
    ResBlock    --> appending ( 960, 640)
    ResBlock    --> appending ( 960, 320)
    ResBlock    --> appending ( 640, 320)
    ResBlock    --> appending ( 640, 320)
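The "appending" labels suggest that every encoder output is pushed onto a skip stack which the decoder consumes in reverse order. Under that reading (an assumption: each decoder ResBlock concatenates the previous output with the most recently appended skip, and there are 3 decoder ResBlocks per level with outputs 1280, 640, 320), the decoder input channels follow directly from the encoder list. `decoder_channels` is a hypothetical helper for this sketch.

```python
# Encoder entries from the CN-XS listing: (in_channels, out_channels).
# The out channels are what gets "appended" to the skip stack.
ENC = [(4, 320), (320, 320), (320, 320), (320, 320), (320, 640),
       (640, 640), (640, 640), (640, 1280), (1280, 1280)]
MID_OUT = 1280
# Assumption: 3 decoder ResBlocks per level, level outputs 1280 -> 640 -> 320
DEC_OUT = [1280] * 3 + [640] * 3 + [320] * 3

def decoder_channels(enc, mid_out, dec_out):
    """Each decoder ResBlock takes prev_out + last appended skip as input."""
    skips = [out for _, out in enc]  # stack of appended channel counts
    prev = mid_out
    result = []
    for out in dec_out:
        skip = skips.pop()           # LIFO: last appended skip is used first
        result.append((prev + skip, out))
        prev = out
    assert not skips, "every appended skip should be consumed exactly once"
    return result

print(decoder_channels(ENC, MID_OUT, DEC_OUT))
# [(2560, 1280), (2560, 1280), (1920, 1280), (1920, 640), (1280, 640),
#  (960, 640), (960, 320), (640, 320), (640, 320)]
```

This reproduces the **Up** list exactly; for example, the first entry is 1280 (mid output) + 1280 (last appended skip) = 2560.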

---

In [7]:
for m in pipe.unet.down_blocks[0].attentions: print(type(m))
for m in pipe.unet.down_blocks[0].resnets: print(type(m))
for m in pipe.unet.down_blocks[0].downsamplers: print(type(m))

<class 'diffusers.models.transformer_2d.Transformer2DModel'>
<class 'diffusers.models.transformer_2d.Transformer2DModel'>
<class 'diffusers.models.resnet.ResnetBlock2D'>
<class 'diffusers.models.resnet.ResnetBlock2D'>
<class 'diffusers.models.resnet.Downsample2D'>


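It seems we have the following correspondence from CN-XS to diffusers:
- ResBlock -> ResnetBlock2D
- Downsample -> Downsample2D

The `Transformer2DModel`s are the attention layers (from `.attentions`), which don't appear in the CN-XS listing at all. Is that mapping really correct? I'm not sure... Let's try to map via channel numbers.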

In [8]:
for m in pipe.unet.down_blocks:
    print(m.in_channels)

AttributeError: 'CrossAttnDownBlock2D' object has no attribute 'in_channels'

Okay, there is no attribute 'in_channels' (and I assume also no 'out_channels') that is directly accessible. Where could they be? 🤔

In [9]:
def public_attrs(o): print([v for v in dir(o) if not v.startswith('_')])

In [10]:
blam = pipe.unet.down_blocks[0]

In [11]:
public_attrs(blam)

['T_destination', 'add_module', 'apply', 'attentions', 'bfloat16', 'buffers', 'call_super_init', 'children', 'cpu', 'cuda', 'double', 'downsamplers', 'dump_patches', 'eval', 'extra_repr', 'float', 'forward', 'get_buffer', 'get_extra_state', 'get_parameter', 'get_submodule', 'gradient_checkpointing', 'half', 'has_cross_attention', 'ipu', 'load_state_dict', 'modules', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'num_attention_heads', 'parameters', 'register_backward_hook', 'register_buffer', 'register_forward_hook', 'register_forward_pre_hook', 'register_full_backward_hook', 'register_full_backward_pre_hook', 'register_load_state_dict_post_hook', 'register_module', 'register_parameter', 'register_state_dict_pre_hook', 'requires_grad_', 'resnets', 'set_extra_state', 'share_memory', 'state_dict', 'to', 'to_empty', 'train', 'training', 'type', 'xpu', 'zero_grad']


None of the other attribute names look like they hold the `in_channels`/`out_channels` info.

Looking at the code of `CrossAttnDownBlock2D`, we can get `in_channels`/`out_channels` from the resnets.

The 1st resnet changes the channel number from `in_channels` to `out_channels`, so we can get both from the 1st resnet block.
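Instead of scanning the full `dir()` output by eye, a variant of `public_attrs` can keep only the names that mention "channel". This is a small sketch; `channel_attrs` is a hypothetical helper, and the `SimpleNamespace` stand-in only mimics a resnet block so the snippet runs without loading a model.

```python
from types import SimpleNamespace

def channel_attrs(o):
    """Public attribute names of `o` that contain the substring 'channel'."""
    return [v for v in dir(o) if not v.startswith('_') and 'channel' in v]

# Stand-in for a resnet block such as blam.resnets[0]
fake_resnet = SimpleNamespace(in_channels=320, out_channels=320, conv1=None)
print(channel_attrs(fake_resnet))  # ['in_channels', 'out_channels']
```

On `blam` itself this returns an empty list, consistent with the attribute listing above, which is why the search has to go one level down into `blam.resnets`.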

In [12]:
blam.resnets[0].in_channels, blam.resnets[0].out_channels

(320, 320)

In [13]:
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D

def print_channels(m):
    # For all four block types, report the channel pair of the block's 1st resnet
    if isinstance(m, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
        ni, no = m.resnets[0].in_channels, m.resnets[0].out_channels
    else:
        print(f'Channel inspection not implemented for type {type(m)}, brah')
        return

    print(f'({ni}, {no})')

In [14]:
print_channels(blam)

(320, 320)



In [16]:
all_blocks = list(pipe.unet.down_blocks) + [pipe.unet.mid_block] + list(pipe.unet.up_blocks)

In [17]:
for m in all_blocks: print(type(m))

<class 'diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D'>
<class 'diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D'>
<class 'diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D'>
<class 'diffusers.models.unet_2d_blocks.DownBlock2D'>
<class 'diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn'>
<class 'diffusers.models.unet_2d_blocks.UpBlock2D'>
<class 'diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D'>
<class 'diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D'>
<class 'diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D'>


In [18]:
for m in all_blocks: print_channels(m)

(320, 320)
(320, 640)
(640, 1280)
(1280, 1280)
Channel inspection not implemented for type <class 'diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn'>, brah
(2560, 1280)
(2560, 1280)
(1920, 640)
(960, 320)


**Q-1:** We have channel sizes for all blocks in the down part of the unet. Is that enough, or for which blocks do we actually need to record the channel numbers?

**Down**

    Conv2d      --> appending (   4, 320)
    ResBlock    --> appending ( 320, 320)
    ResBlock    --> appending ( 320, 320)
    Downsample  --> appending ( 320, 320)
    ResBlock    --> appending ( 320, 640)
    ResBlock    --> appending ( 640, 640)
    Downsample  --> appending ( 640, 640)
    ResBlock    --> appending ( 640,1280)
    ResBlock    --> appending (1280,1280)

**Mid**

    (1280, 1280)

**Up**

    ResBlock    --> appending (2560,1280)
    ResBlock    --> appending (2560,1280)
    ResBlock    --> appending (1920,1280)
    ResBlock    --> appending (1920, 640)
    ResBlock    --> appending (1280, 640)
    ResBlock    --> appending ( 960, 640)
    ResBlock    --> appending ( 960, 320)
    ResBlock    --> appending ( 640, 320)
    ResBlock    --> appending ( 640, 320)

**No**, we don't have all channel sizes for the down part. We only have 4 pairs of numbers (one per block, taken from its 1st resnet), but CN-XS has 9.

In [19]:
for m in pipe.unet.down_blocks:
    for r in m.resnets: print('ResBlock', r.in_channels, r.out_channels)
    if m.downsamplers: print('Downsample', m.downsamplers[0].channels, m.downsamplers[0].out_channels)
    print('--')

ResBlock 320 320
ResBlock 320 320
Downsample 320 320
--
ResBlock 320 640
ResBlock 640 640
Downsample 640 640
--
ResBlock 640 1280
ResBlock 1280 1280
Downsample 1280 1280
--
ResBlock 1280 1280
ResBlock 1280 1280
--


Okay, now the numbers match: the first 8 pairs line up with the CN-XS entries after its `Conv2d`! Still left to do:
- Where do I get the `Conv2d` channels from?
- Why does the unet have extra blocks at the end: 1 downsample + 2 resnets?

___

**Hypothesis:** diffusers / CN-XS use different versions of Stable Diffusion.<br/>
**Test:** run the same examination for SD 2.1 & SDXL.<br/>
**Result:** yes. I used SD 1.5 for diffusers, but the default params (SDXL) in CN-XS.
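The size difference can be checked with a bit of arithmetic. Assuming each encoder level has `layers_per_block` ResBlocks, every level except the last ends in a Downsample, and one input conv comes first (the layout seen in the printouts in this notebook), the number of encoder (in, out) entries is 1 + levels × layers_per_block + (levels − 1). `num_encoder_entries` is a hypothetical helper for this sketch.

```python
def num_encoder_entries(num_levels, layers_per_block):
    """Count the (in, out) pairs on the encoder side.

    Assumes: 1 input conv, `layers_per_block` ResBlocks per level,
    and a Downsample after every level except the last.
    """
    return 1 + num_levels * layers_per_block + (num_levels - 1)

print(num_encoder_entries(4, 2))  # SD 1.5 / 2.1 down path: 12
print(num_encoder_entries(3, 2))  # SDXL down path: 9
```

The difference of 3 is exactly the extra 1 downsample + 2 resnets of SD 1.5's fourth level, and SDXL's 9 matches the CN-XS count.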

In [20]:
pipe21 = StableDiffusionControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=device_dtype
).to(device)


In [21]:
for m in pipe21.unet.down_blocks:
    for r in m.resnets: print('ResBlock', r.in_channels, r.out_channels)
    if m.downsamplers: print('Downsample', m.downsamplers[0].channels, m.downsamplers[0].out_channels)
    print('--')

ResBlock 320 320
ResBlock 320 320
Downsample 320 320
--
ResBlock 320 640
ResBlock 640 640
Downsample 640 640
--
ResBlock 640 1280
ResBlock 1280 1280
Downsample 1280 1280
--
ResBlock 1280 1280
ResBlock 1280 1280
--


In [22]:
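# Only pipexl.unet is inspected below. For actual SDXL generation, diffusers
# provides StableDiffusionXLControlNetPipeline (used with an SDXL ControlNet).
pipexl = StableDiffusionControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=device_dtype
).to(device)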


You have disabled the safety checker for <class 'diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


In [23]:
for m in pipexl.unet.down_blocks:
    for r in m.resnets: print('ResBlock', r.in_channels, r.out_channels)
    if m.downsamplers: print('Downsample', m.downsamplers[0].channels, m.downsamplers[0].out_channels)
    print('--')

ResBlock 320 320
ResBlock 320 320
Downsample 320 320
--
ResBlock 320 640
ResBlock 640 640
Downsample 640 640
--
ResBlock 640 1280
ResBlock 1280 1280
--


___

Let's use the SDXL pipeline, as SDXL is what the CN-XS code uses by default.

In [24]:
all_blocks = list(pipexl.unet.down_blocks) + [pipexl.unet.mid_block] + list(pipexl.unet.up_blocks)

In [25]:
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D

def get_channels_for_downblocks(ms):
    channels = []

    for m in ms:
        if isinstance(m, (CrossAttnDownBlock2D, DownBlock2D)):
            for r in m.resnets: channels.append((r.in_channels, r.out_channels))
            if m.downsamplers:  channels.append((m.downsamplers[0].channels, m.downsamplers[0].out_channels))

        else: print(f'Encountered unknown block of type {type(m)}, brah')
    
    return channels

In [26]:
get_channels_for_downblocks(pipexl.unet.down_blocks)

[(320, 320),
 (320, 320),
 (320, 320),
 (320, 640),
 (640, 640),
 (640, 640),
 (640, 1280),
 (1280, 1280)]

In [27]:
from diffusers.models.unet_2d_blocks import UNetMidBlock2DCrossAttn

def get_channels_for_midblock(m):
    assert isinstance(m, UNetMidBlock2DCrossAttn)

    return (m.resnets[0].in_channels, m.resnets[0].out_channels)

In [28]:
get_channels_for_midblock(pipexl.unet.mid_block)

(1280, 1280)

In [29]:
from diffusers.models.unet_2d_blocks import CrossAttnUpBlock2D, UpBlock2D

def get_channels_for_upblocks(ms):
    channels = []

    for m in ms:
        if isinstance(m, (CrossAttnUpBlock2D, UpBlock2D)):
            for r in m.resnets: channels.append((r.in_channels, r.out_channels))

        else: print(f'Encountered unknown block of type {type(m)}, brah')
    
    return channels

In [30]:
get_channels_for_upblocks(pipexl.unet.up_blocks)

[(2560, 1280),
 (2560, 1280),
 (1920, 1280),
 (1920, 640),
 (1280, 640),
 (960, 640),
 (960, 320),
 (640, 320),
 (640, 320)]

In [31]:
channel_nums = {
    'enc': get_channels_for_downblocks(pipexl.unet.down_blocks),
    'mid': get_channels_for_midblock  (pipexl.unet.mid_block),
    'dec': get_channels_for_upblocks  (pipexl.unet.up_blocks)
}

for k,v in channel_nums.items():
    print(f'{k}:\t{str(v)}')

enc:	[(320, 320), (320, 320), (320, 320), (320, 640), (640, 640), (640, 640), (640, 1280), (1280, 1280)]
mid:	(1280, 1280)
dec:	[(2560, 1280), (2560, 1280), (1920, 1280), (1920, 640), (1280, 640), (960, 640), (960, 320), (640, 320), (640, 320)]


For comparison, here are the channel sizes from cnxs:

In [32]:
channel_nums_reference = {
    'enc': [(4, 320), (320, 320), (320, 320), (320, 320), (320, 640), (640, 640), (640, 640), (640, 1280), (1280, 1280)],
    'mid': [(1280, 1280)],
    'dec': [(2560, 1280), (2560, 1280), (1920, 1280), (1920, 640), (1280, 640), (960, 640), (960, 320), (640, 320), (640, 320)]
}
for k,v in channel_nums_reference.items():
    print(f'{k}:\t{str(v)}')

enc:	[(4, 320), (320, 320), (320, 320), (320, 320), (320, 640), (640, 640), (640, 640), (640, 1280), (1280, 1280)]
mid:	[(1280, 1280)]
dec:	[(2560, 1280), (2560, 1280), (1920, 1280), (1920, 640), (1280, 640), (960, 640), (960, 320), (640, 320), (640, 320)]
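To find such differences programmatically instead of by eye, `difflib.SequenceMatcher` from the standard library works on lists of any hashable items, including `(in, out)` tuples. `channel_diff` is a hypothetical helper for this sketch; the lists are copied from the outputs above.

```python
import difflib

def channel_diff(ours, ref):
    """Return the non-matching regions between two (in, out) channel lists."""
    sm = difflib.SequenceMatcher(a=ours, b=ref, autojunk=False)
    return [(tag, ours[i1:i2], ref[j1:j2])
            for tag, i1, i2, j1, j2 in sm.get_opcodes() if tag != 'equal']

enc_ours = [(320, 320), (320, 320), (320, 320), (320, 640), (640, 640),
            (640, 640), (640, 1280), (1280, 1280)]
enc_ref = [(4, 320), (320, 320), (320, 320), (320, 320), (320, 640),
           (640, 640), (640, 640), (640, 1280), (1280, 1280)]
print(channel_diff(enc_ours, enc_ref))  # [('insert', [], [(4, 320)])]
```

An `'insert'` entry means the reference has items that our list lacks at that position; identical lists give an empty result.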


Almost identical! Only the `(4, 320)` at the beginning is missing. This is the initial `Conv2d`. Where is it hiding in diffusers? 🕵🏽

In [33]:
type(pipexl.unet)

diffusers.models.unet_2d_condition.UNet2DConditionModel

In [34]:
pipexl.unet.conv_in.in_channels, pipexl.unet.conv_in.out_channels

(4, 320)

There we have it!
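With `conv_in` found, the encoder list can be completed so that it lines up with the CN-XS reference. A minimal sketch: `get_channels_for_encoder` is a hypothetical helper, and the `SimpleNamespace` stand-in only mimics `pipexl.unet.conv_in`. With the real model the call would be `get_channels_for_encoder(pipexl.unet, get_channels_for_downblocks(pipexl.unet.down_blocks))`.

```python
from types import SimpleNamespace

def get_channels_for_encoder(unet, down_channels):
    """Prepend the (in, out) pair of unet.conv_in to the down-block channels."""
    conv = unet.conv_in
    return [(conv.in_channels, conv.out_channels)] + list(down_channels)

# Stand-in for pipexl.unet; down-block channels copied from In [26]
fake_unet = SimpleNamespace(conv_in=SimpleNamespace(in_channels=4, out_channels=320))
down = [(320, 320), (320, 320), (320, 320), (320, 640), (640, 640),
        (640, 640), (640, 1280), (1280, 1280)]
print(get_channels_for_encoder(fake_unet, down))
```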
