This notebook tests new code in diffusers: `ControlledUNet2DConditionModel.__init__`.

In [1]:
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

In [2]:
from diffusers.configuration_utils import ConfigMixin
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.utils import BaseOutput, logging

from diffusers.models.embeddings import (
    GaussianFourierProjection,
    TimestepEmbedding,
    Timesteps,
    get_timestep_embedding
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    DownBlock2D,
    CrossAttnUpBlock2D,
    UpBlock2D,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel

In [3]:
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """Output container returned by ``ControlledUNet2DConditionModel.forward``."""
    # Tensor carried by the output; defaults to None when not supplied.
    sample: torch.FloatTensor = None

In [4]:
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and hand the module back.

    Args:
        module: An ``nn.Module`` whose parameters should be zeroed.

    Returns:
        The very same ``module`` object, so the call can be used inline.
    """
    # Zeroing is an initialisation step, so it must not be tracked by autograd.
    with torch.no_grad():
        for weight in module.parameters():
            weight.zero_()
    return module

In [37]:
from diffusers.models import UNet2DConditionModel, ControlNetModel

In [20]:
!ls -a "../../../../.hf-cache/CVL-Heidelberg"

[34m.[m[m                               sd21_encD_depth_14m.ckpt
[34m..[m[m                              sdxl_encD_canny_48m.safetensors
sd21_encD_canny_14m.ckpt        sdxl_encD_depth_48m.safetensors


In [22]:
checkpoint = torch.load('../../../../.hf-cache/CVL-Heidelberg/sd21_encD_canny_14m.ckpt', map_location=torch.device('cpu'))

In [45]:
# Keep only the entries whose key mentions 'control_model.' and strip that marker
# from the key, so the remaining names are relative to the control model.
ctrl_model_state_dict = {}
for full_key, tensor in checkpoint['state_dict'].items():
    if 'control_model.' in full_key:
        ctrl_model_state_dict[full_key.replace('control_model.', '')] = tensor

In [59]:
def print_state_dict(sdict, lv=None):
    """Print the keys of a state dict, optionally truncated to a nesting level.

    Args:
        sdict: Mapping whose keys are dot-separated names.
        lv: If ``None``, every key is printed in full. Otherwise only the first
            ``lv + 1`` dot-separated components of each key are kept and each
            distinct prefix is printed once, in first-seen order.
    """
    if lv is None:
        keys = list(sdict.keys())
    else:
        # dict.fromkeys de-duplicates while preserving insertion order (a set would not).
        keys = list(dict.fromkeys('.'.join(k.split('.')[:lv + 1]) for k in sdict.keys()))
    for k in keys:
        print(k)

In [60]:
print_state_dict(ctrl_model_state_dict, lv=0)

time_embed
input_blocks
middle_block


In [68]:
# NOTE(review): `cnet` is only created further down (`cnet = ControlNetModel()`, cell In [38]);
# on a fresh kernel that cell must run before this one.
print_state_dict(cnet.state_dict(), lv=0)

conv_in
time_embedding
controlnet_cond_embedding
down_blocks
controlnet_down_blocks
controlnet_mid_block
mid_block


A `ControlNetModel` in diffusers already contains both a base and a ctrl model. **My assumption that it's just the ctrl model was wrong.** This means **I have to define a new class for a ControlModel** in the sense of CN-XS, i.e. a model that exists in addition to the base model.

In [69]:
print_state_dict(ctrl_model_state_dict, lv=1)

time_embed.0
time_embed.2
input_blocks.0
input_blocks.1
input_blocks.2
input_blocks.3
input_blocks.4
input_blocks.5
input_blocks.6
input_blocks.7
input_blocks.8
input_blocks.9
input_blocks.10
input_blocks.11
middle_block.0
middle_block.1
middle_block.2


In [67]:
print_state_dict(cnet.state_dict(), lv=1)

conv_in.weight
conv_in.bias
time_embedding.linear_1
time_embedding.linear_2
controlnet_cond_embedding.conv_in
controlnet_cond_embedding.blocks
controlnet_cond_embedding.conv_out
down_blocks.0
down_blocks.1
down_blocks.2
down_blocks.3
controlnet_down_blocks.0
controlnet_down_blocks.1
controlnet_down_blocks.2
controlnet_down_blocks.3
controlnet_down_blocks.4
controlnet_down_blocks.5
controlnet_down_blocks.6
controlnet_down_blocks.7
controlnet_down_blocks.8
controlnet_down_blocks.9
controlnet_down_blocks.10
controlnet_down_blocks.11
controlnet_mid_block.weight
controlnet_mid_block.bias
mid_block.attentions
mid_block.resnets


In [38]:
cnet = ControlNetModel()

In [62]:
#cnet.load_state_dict(ctrl_model_state_dict)

In [None]:
# NOTE(review): this cell was never executed (In [None]); `unet_model` is not defined in any
# visible cell and 'unet_state_dict' is a placeholder key, so running it as-is is expected to fail.
unet_model.load_state_dict(checkpoint['unet_state_dict'])  # Replace 'unet_state_dict' with the actual key

In [63]:
#ctrl_model  = UNet2DConditionModel.from_pretrained('CVL-Heidelberg/ControlNet-XS', variant='sd21_encD_canny_14m.ckpt')

In [7]:
class ControlledUNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    """Work-in-progress ControlNet-XS style model pairing a base UNet with a control UNet.

    The constructor instantiates two default ``UNet2DConditionModel`` objects (base and
    control), records the (in, out) channel sizes of their sub-blocks, and creates the
    zero-initialised 1x1 convolutions connecting the two networks plus the hint block.
    ``forward`` is still a stub that returns its ``sample`` argument unchanged.

    Args:
        in_channels: Stored as ``self.in_channels``; not otherwise used yet.
        model_channels: Stored as ``self.model_channels``; together with
            ``control_model_ratio`` it sets the output width of ``input_hint_block``.
        out_channels: Stored as ``self.out_channels``; not otherwise used yet.
        hint_channels: Input channel count of the first convolution in ``input_hint_block``.
        num_res_blocks: Accepted but currently unused.
        attention_resolutions: Accepted but currently unused.

    Raises:
        ValueError: If a down/up block of the underlying UNets is not one of the
            supported block types.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        hint_channels,
        num_res_blocks,
        attention_resolutions,
    ):
        super().__init__()

        # 1 - Save parameters
        # TODO make variables
        self.control_mode = "canny"
        self.learn_embedding = False
        self.infusion2control = "cat"
        self.infusion2base = "add"
        # Derived from the configured infusion mode rather than from a literal-vs-literal comparison.
        self.in_ch_factor = 1 if self.infusion2control == 'add' else 2
        self.guiding = "encoder"
        self.two_stream_mode = "cross"
        self.control_model_ratio = 1.0
        # Constructor arguments; make_zero_conv no longer overwrites these.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dims = 2
        self.model_channels = model_channels
        self.no_control = False
        self.control_scale = 1.0

        self.hint_model = None

        # 2 - Create base and control model
        # TODO 1. create base model, or 2. pass it
        self.base_model = base_model = UNet2DConditionModel()
        # TODO create control model
        self.control_model = ctrl_model = UNet2DConditionModel()

        # 3 - Gather Channel Sizes
        ch_inout_ctrl = {'enc': [], 'mid': [], 'dec': []}
        ch_inout_base = {'enc': [], 'mid': [], 'dec': []}

        # 3.1 - input convolution
        ch_inout_ctrl['enc'].append((ctrl_model.conv_in.in_channels, ctrl_model.conv_in.out_channels))
        ch_inout_base['enc'].append((base_model.conv_in.in_channels, base_model.conv_in.out_channels))

        # 3.2 - encoder blocks (control first, then base, so errors surface in the same order as before)
        ch_inout_ctrl['enc'].extend(self._encoder_channels(ctrl_model))
        ch_inout_base['enc'].extend(self._encoder_channels(base_model))

        # 3.3 - middle block
        # NOTE(review): both tuple entries read `in_channels`; presumably the mid block keeps its
        # channel count - confirm, or switch the second entry to `out_channels`.
        ch_inout_ctrl['mid'].append((ctrl_model.mid_block.resnets[0].in_channels, ctrl_model.mid_block.resnets[0].in_channels))
        ch_inout_base['mid'].append((base_model.mid_block.resnets[0].in_channels, base_model.mid_block.resnets[0].in_channels))

        # 3.4 - decoder blocks (only the base model contributes decoder entries)
        for module in base_model.up_blocks:
            if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)):
                for r in module.resnets:
                    ch_inout_base['dec'].append((r.in_channels, r.out_channels))
            else:
                raise ValueError(f'Encountered unknown module of type {type(module)} while creating ControlNet-XS.')

        self.ch_inout_ctrl = ch_inout_ctrl
        self.ch_inout_base = ch_inout_base

        # 4 - Build connections between base and control model
        self.enc_zero_convs_out = nn.ModuleList([])
        self.enc_zero_convs_in = nn.ModuleList([])

        self.middle_block_out = nn.ModuleList([])
        self.middle_block_in = nn.ModuleList([])

        self.dec_zero_convs_out = nn.ModuleList([])
        self.dec_zero_convs_in = nn.ModuleList([])

        # One channel-preserving zero conv per base encoder entry.
        for ch_io_base in ch_inout_base['enc']:
            self.enc_zero_convs_in.append(self.make_zero_conv(
                in_channels=ch_io_base[1], out_channels=ch_io_base[1])
            )

        # Replaces the placeholder ModuleList above with a single zero conv.
        self.middle_block_out = self.make_zero_conv(ch_inout_ctrl['mid'][-1][1], ch_inout_base['mid'][-1][1])

        # First decoder connection: last control encoder entry -> base mid-block width.
        self.dec_zero_convs_out.append(
            self.make_zero_conv(ch_inout_ctrl['enc'][-1][1], ch_inout_base['mid'][-1][1])
        )
        # Remaining connections walk the control encoder entries backwards and the
        # base decoder entries forwards.
        for i in range(1, len(ch_inout_ctrl['enc'])):
            self.dec_zero_convs_out.append(
                self.make_zero_conv(ch_inout_ctrl['enc'][-(i + 1)][1], ch_inout_base['dec'][i - 1][1])
            )

        # 5 - Input hint block TODO: Understand
        self.input_hint_block = nn.Sequential(
            nn.Conv2d(hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(nn.Conv2d(256, int(model_channels * self.control_model_ratio), 3, padding=1))
        )

        # One scale of 1.0 per encoder-out conv, one for the middle block, one per decoder-out conv.
        scale_list = [1.] * len(self.enc_zero_convs_out) + [1.] + [1.] * len(self.dec_zero_convs_out)
        self.register_buffer('scale_list', torch.tensor(scale_list))

    @staticmethod
    def _encoder_channels(unet):
        """Return the (in, out) channel pairs of every resnet and first downsampler in ``unet.down_blocks``."""
        channels = []
        for module in unet.down_blocks:
            if not isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
                raise ValueError(f'Encountered unknown module of type {type(module)} while creating ControlNet-XS.')
            for r in module.resnets:
                channels.append((r.in_channels, r.out_channels))
            if module.downsamplers:
                channels.append((module.downsamplers[0].channels, module.downsamplers[0].out_channels))
        return channels

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        """Stub forward pass: wraps ``sample`` unchanged in a ``UNet2DConditionOutput``.

        All other arguments (including ``return_dict``) are currently ignored.
        """

        # pass control_input into controlnet

        # # Encoder
        # for zip(control_encoding, base_encoding):
        # ...

        # # Bottleneck
        # control_encoding, base_encoding
        # ...

        # # Decoder
        # for zip(base_decoding):
        # ...

        return UNet2DConditionOutput(sample=sample)

    def make_zero_conv(self, in_channels, out_channels=None):
        """Build a zero-initialised 1x1 ``nn.Conv2d``.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels; falls back to ``in_channels``
                when omitted.

        Returns:
            The convolution with all of its parameters zeroed by ``zero_module``.
        """
        # Resolve the default before building the layer: nn.Conv2d cannot take None.
        # Deliberately does not write self.in_channels / self.out_channels, which hold
        # the constructor arguments.
        out_channels = out_channels or in_channels
        return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))

In [8]:
params = {
    'in_channels': 4,
    'out_channels': 4,
    'hint_channels': 3,
    'model_channels': 320,
    'attention_resolutions': [4, 2],
    'num_res_blocks': 2,
}

In [9]:
ControlledUNet2DConditionModel(**params)

ControlledUNet2DConditionModel(
  (base_model): UNet2DConditionModel(
    (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (time_proj): Timesteps()
    (time_embedding): TimestepEmbedding(
      (linear_1): Linear(in_features=320, out_features=1280, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (down_blocks): ModuleList(
      (0): CrossAttnDownBlock2D(
        (attentions): ModuleList(
          (0-1): 2 x Transformer2DModel(
            (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
            (proj_in): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
            (transformer_blocks): ModuleList(
              (0): BasicTransformerBlock(
                (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
                (attn1): Attention(
                  (to_q): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
                  (to_

___

Okay, let's run the code manually

In [56]:
# Expose every entry of `params` as a notebook-level variable so the constructor body
# can be replayed cell by cell below; update() avoids leaking loop variables.
globals().update(params)

In [57]:
# 1 - Save parameters
# Manual replay of the constructor: `self__x` stands in for `self.x`.
self__control_mode = "canny"
self__learn_embedding = False
self__infusion2control = "cat"
self__infusion2base = "add"
# Derived from the configured infusion mode rather than from a literal-vs-literal comparison.
self__in_ch_factor = 1 if self__infusion2control == 'add' else 2
self__guiding = "encoder"
self__two_stream_mode = "cross"
self__control_model_ratio = 1.0
self__out_channels = out_channels
self__dims = 2
self__model_channels = model_channels
self__no_control = False
self__control_scale = 1.0

self__hint_model = None

In [58]:
# 2 - Create base and control model
self__base_model = base_model = UNet2DConditionModel()
self__control_model = ctrl_model = UNet2DConditionModel()

In [59]:
# 3 - Gather Channel Sizes
ch_inout_ctrl = {'enc': [], 'mid': [], 'dec': []}
ch_inout_base = {'enc': [], 'mid': [], 'dec': []}

In [60]:
# 3.1 - input convolution
ch_inout_ctrl['enc'].append((ctrl_model.conv_in.in_channels, ctrl_model.conv_in.out_channels))
ch_inout_base['enc'].append((base_model.conv_in.in_channels, base_model.conv_in.out_channels))

In [61]:
# 3.2 - encoder blocks
# Record the (in, out) channels of every resnet and first downsampler of the control model.
for module in ctrl_model.down_blocks:
    if not isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
        raise ValueError(f'Encountered unknown module of type {type(module)} while creating ControlNet-XS.')
    ch_inout_ctrl['enc'].extend((resnet.in_channels, resnet.out_channels) for resnet in module.resnets)
    if module.downsamplers:
        first_down = module.downsamplers[0]
        ch_inout_ctrl['enc'].append((first_down.channels, first_down.out_channels))

In [62]:
# Same bookkeeping as for the control model, now for the base model's encoder.
for module in base_model.down_blocks:
    if not isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
        raise ValueError(f'Encountered unknown module of type {type(module)} while creating ControlNet-XS.')
    ch_inout_base['enc'].extend((resnet.in_channels, resnet.out_channels) for resnet in module.resnets)
    if module.downsamplers:
        first_down = module.downsamplers[0]
        ch_inout_base['enc'].append((first_down.channels, first_down.out_channels))

In [63]:
# 3.3 - middle block
ch_inout_ctrl['mid'].append((ctrl_model.mid_block.resnets[0].in_channels, ctrl_model.mid_block.resnets[0].in_channels))
ch_inout_base['mid'].append((base_model.mid_block.resnets[0].in_channels, base_model.mid_block.resnets[0].in_channels))

In [64]:
# 3.4 - decoder blocks
# Only the base model contributes decoder entries: one (in, out) pair per up-block resnet.
for module in base_model.up_blocks:
    if not isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)):
        raise ValueError(f'Encountered unknown module of type {type(module)} while creating ControlNet-XS.')
    ch_inout_base['dec'].extend((resnet.in_channels, resnet.out_channels) for resnet in module.resnets)

In [65]:
self__ch_inout_ctrl = ch_inout_ctrl
self__ch_inout_base = ch_inout_base

In [66]:
self__ch_inout_ctrl

{'enc': [(4, 320),
  (320, 320),
  (320, 320),
  (320, 320),
  (320, 640),
  (640, 640),
  (640, 640),
  (640, 1280),
  (1280, 1280),
  (1280, 1280),
  (1280, 1280),
  (1280, 1280)],
 'mid': [(1280, 1280)],
 'dec': []}

In [67]:
self__ch_inout_base

{'enc': [(4, 320),
  (320, 320),
  (320, 320),
  (320, 320),
  (320, 640),
  (640, 640),
  (640, 640),
  (640, 1280),
  (1280, 1280),
  (1280, 1280),
  (1280, 1280),
  (1280, 1280)],
 'mid': [(1280, 1280)],
 'dec': [(2560, 1280),
  (2560, 1280),
  (2560, 1280),
  (2560, 1280),
  (2560, 1280),
  (1920, 1280),
  (1920, 640),
  (1280, 640),
  (960, 640),
  (960, 320),
  (640, 320),
  (640, 320)]}

In [68]:
# 4 - Build connections between base and control model
self__enc_zero_convs_out = nn.ModuleList([])
self__enc_zero_convs_in = nn.ModuleList([])

self__middle_block_out = nn.ModuleList([])
self__middle_block_in = nn.ModuleList([])

self__dec_zero_convs_out = nn.ModuleList([])
self__dec_zero_convs_in = nn.ModuleList([])

In [69]:
class DummyObject():
    """Stand-in for ``self`` so ``make_zero_conv`` can be called outside the model class."""

    def make_zero_conv(self, in_channels, out_channels=None):
        """Build a zero-initialised 1x1 ``nn.Conv2d``.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels; falls back to ``in_channels``
                when omitted.

        Returns:
            The convolution with all of its parameters zeroed by ``zero_module``.
        """
        # keep running track # todo: better comment
        self.in_channels = in_channels
        self.out_channels = out_channels or in_channels
        # Pass the resolved value: a None out_channels cannot be handed to nn.Conv2d.
        return zero_module(nn.Conv2d(in_channels, self.out_channels, 1, padding=0))

slimshady = DummyObject()

In [70]:
for ch_io_base in ch_inout_base['enc']:
    self__enc_zero_convs_in.append(slimshady.make_zero_conv(
        in_channels=ch_io_base[1], out_channels=ch_io_base[1])
    )

In [71]:
self__middle_block_out = slimshady.make_zero_conv(ch_inout_ctrl['mid'][-1][1], ch_inout_base['mid'][-1][1])

In [73]:
self__dec_zero_convs_out.append(
    slimshady.make_zero_conv(ch_inout_ctrl['enc'][-1][1], ch_inout_base['mid'][-1][1])
)

ModuleList(
  (0): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
)

In [74]:
for i in range(1, len(ch_inout_ctrl['enc'])):
    self__dec_zero_convs_out.append(
        slimshady.make_zero_conv(ch_inout_ctrl['enc'][-(i + 1)][1], ch_inout_base['dec'][i - 1][1])
    )

In [75]:
# 5 - Input hint block TODO: Understand
self__input_hint_block = nn.Sequential(
    nn.Conv2d(hint_channels, 16, 3, padding=1),
    nn.SiLU(),
    nn.Conv2d(16, 16, 3, padding=1),
    nn.SiLU(),
    nn.Conv2d(16, 32, 3, padding=1, stride=2),
    nn.SiLU(),
    nn.Conv2d(32, 32, 3, padding=1),
    nn.SiLU(),
    nn.Conv2d(32, 96, 3, padding=1, stride=2),
    nn.SiLU(),
    nn.Conv2d(96, 96, 3, padding=1),
    nn.SiLU(),
    nn.Conv2d(96, 256, 3, padding=1, stride=2),
    nn.SiLU(),
    zero_module(nn.Conv2d(256, int(model_channels * self__control_model_ratio), 3, padding=1))
)

In [77]:
scale_list = [1.] * len(self__enc_zero_convs_out) + [1.] + [1.] * len(self__dec_zero_convs_out)
#self__register_buffer('scale_list', torch.tensor(scale_list))

Works! 🎉

---

In [82]:
two_stream_model = ControlledUNet2DConditionModel(**params)

Works! 🎉

---

In [89]:
from diffusers.models.unet_2d_condition_control import ControlledUNet2DConditionModel as ControlledUNet2DConditionModelFromCodeBase

In [90]:
ControlledUNet2DConditionModelFromCodeBase

diffusers.models.unet_2d_condition_control.ControlledUNet2DConditionModel

In [91]:
two_stream_model = ControlledUNet2DConditionModelFromCodeBase(**params)

Works! 🎉