In this notebook, I'm mapping the parameters given by Heidelberg CVL to the parameters required by my diffusers model

In [1]:
import torch

Load Heidelberg weights

In [2]:
file = '../../../../.hf-cache/CVL-Heidelberg/sd21_encD_canny_14m.ckpt'

In [3]:
WEIGHT_SAVE_PATH = 'cnxs-sd-canny'

In [4]:
checkpoint = torch.load(file,map_location=torch.device('cpu'))

In [5]:
weights_tensors = checkpoint['state_dict']

In [6]:
from util import print_as_nested_dict

These are the params (on lv 1) the weights provide

In [7]:
print_as_nested_dict(sorted(weights_tensors), lv=1)
print()
print_as_nested_dict(sorted(weights_tensors), 'control_model', lv=2)

control_model
dec_zero_convs_out
enc_zero_convs_in
enc_zero_convs_out
input_hint_block
middle_block_out
scale_list

control_model
        input_blocks
        middle_block
        time_embed


Create diffusers model (with random weights)

In [8]:
from diffusers import StableDiffusionPipeline

In [9]:
model = "stabilityai/stable-diffusion-2-1-base"

In [10]:
pipe = StableDiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
sd_unet = pipe.unet

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [11]:
from diffusers.models.controlnetxs import ControlNetXSModel

In [12]:
cnxs = ControlNetXSModel.init_original(base_model=sd_unet, is_sdxl=False)

`norm_num_groups` was set to `min(block_out_channels)` (=4) so it divides all block_out_channels` ([4, 8, 16, 16]). Set it explicitly to remove this information.


In [13]:
model_tensors = cnxs.state_dict()

These are the params (on lv 1) the model needs

In [14]:
print_as_nested_dict(sorted(model_tensors), lv=1)
print()
print_as_nested_dict(sorted(model_tensors), 'control_model', lv=2)

control_model
controlnet_cond_embedding
down_zero_convs_in
down_zero_convs_out
middle_block_out
up_zero_convs_out

control_model
        conv_in
        down_blocks
        mid_block
        time_embedding


In [15]:
def lv0(k):
    """Return the top-level group of a dotted parameter name (the text before the first '.')."""
    head, _, _ = k.partition('.')
    return head

In [16]:
# Top-level parameter groups on each side (see `lv0`).
model_lv0 = {lv0(name) for name in model_tensors.keys()}
weights_lv0 = {lv0(name) for name in weights_tensors.keys()}

# Set algebra on the group names; `sorted` already returns a list.
missing = sorted(weights_lv0 - model_lv0)      # only in the checkpoint
unexpected = sorted(model_lv0 - weights_lv0)   # only in the diffusers model
expected = sorted(model_lv0 & weights_lv0)     # present on both sides

# Report each group under its headline, in the same order as before.
for headline, groups in (
    ('Provided in weights and expected in model:', expected),
    ('\nProvided by weights, but missing in model:', missing),
    ('\nNot provided by weights, but in model:', unexpected),
):
    print(headline)
    print(groups)

Provided in weights and expected in model:
['control_model', 'middle_block_out']

Provided by weights, but missing in model:
['dec_zero_convs_out', 'enc_zero_convs_in', 'enc_zero_convs_out', 'input_hint_block', 'scale_list']

Not provided by weights, but in model:
['controlnet_cond_embedding', 'down_zero_convs_in', 'down_zero_convs_out', 'up_zero_convs_out']


This is as expected, as
- I've renamed `enc_zero_convs_in`, `enc_zero_convs_out` and `dec_zero_convs_out` to `down_zero_convs_in`, `down_zero_convs_out` and `up_zero_convs_out`, respectively, to be consistent with diffusers terminology
- I've deleted `scale_list`; it's now passed as an argument in the `forward`
- I've changed the `input_hint_block` to `controlnet_cond_embedding` to be more in line with the implementation of the original ControlNet

## Let's load everything except the unet

In [17]:
print_as_nested_dict(model_tensors, 'middle_block_out', lv=3, print_leaf=True)

middle_block_out	[1280, 16, 1, 1]


In [18]:
num_connections = len([k for k in model_tensors.keys() if ('down_zero_convs_in' in k) and ('weight' in k)])
num_connections

12

In [19]:
print(f'The number of connections is {num_connections} and should be blocks*(layers_per_block+1)=4*(2+1)')

The number of connections is 12 and should be blocks*(layers_per_block+1)=4*(2+1)


In [20]:
# Prefix translation table: checkpoint name prefix -> diffusers model name prefix.
# Built step by step, keeping the same insertion order as a single dict literal would give.
# NOTE: I'm renaming enc/dec to down/up to be consistent with diffusers terminology
available_key_mapping = {}
available_key_mapping.update({f'dec_zero_convs_out.{i}.0': f'up_zero_convs_out.{i}' for i in range(num_connections)})
available_key_mapping.update({f'enc_zero_convs_in.{i}.0': f'down_zero_convs_in.{i}' for i in range(num_connections)})
available_key_mapping.update({f'enc_zero_convs_out.{i}.0': f'down_zero_convs_out.{i}' for i in range(num_connections)})
available_key_mapping['input_hint_block.0'] = 'controlnet_cond_embedding.conv_in'
# Only the even indices 2, 4, ..., 12 of `input_hint_block` are mapped to `blocks.0` ... `blocks.5`.
available_key_mapping.update({f'input_hint_block.{2*(i+1)}': f'controlnet_cond_embedding.blocks.{i}' for i in range(6)})
available_key_mapping['input_hint_block.14'] = 'controlnet_cond_embedding.conv_out'
available_key_mapping['middle_block_out.0'] = 'middle_block_out'

# checkpoint key -> model key, for every checkpoint tensor covered by the table above
cnxs_mapping_without_unet = {}
for key_weights in weights_tensors.keys():
    # Find the first table prefix this key starts with and rewrite it;
    # keys that match no prefix (e.g. the unet itself) are skipped here.
    for o, replacement in available_key_mapping.items():
        if key_weights.startswith(o):
            key_model = key_weights.replace(o, replacement)
            break
    else:
        continue

    # Keep the pair only if the model really has a tensor with the translated name.
    if key_model not in model_tensors:
        print(f"Can't find key {key_model} in model")
        continue
    cnxs_mapping_without_unet[key_weights] = key_model

In [21]:
len(cnxs_mapping_without_unet)

90

In [22]:
cnxs_mapping_without_unet

{'enc_zero_convs_out.0.0.weight': 'down_zero_convs_out.0.weight',
 'enc_zero_convs_out.0.0.bias': 'down_zero_convs_out.0.bias',
 'enc_zero_convs_out.1.0.weight': 'down_zero_convs_out.1.weight',
 'enc_zero_convs_out.1.0.bias': 'down_zero_convs_out.1.bias',
 'enc_zero_convs_out.2.0.weight': 'down_zero_convs_out.2.weight',
 'enc_zero_convs_out.2.0.bias': 'down_zero_convs_out.2.bias',
 'enc_zero_convs_out.3.0.weight': 'down_zero_convs_out.3.weight',
 'enc_zero_convs_out.3.0.bias': 'down_zero_convs_out.3.bias',
 'enc_zero_convs_out.4.0.weight': 'down_zero_convs_out.4.weight',
 'enc_zero_convs_out.4.0.bias': 'down_zero_convs_out.4.bias',
 'enc_zero_convs_out.5.0.weight': 'down_zero_convs_out.5.weight',
 'enc_zero_convs_out.5.0.bias': 'down_zero_convs_out.5.bias',
 'enc_zero_convs_out.6.0.weight': 'down_zero_convs_out.6.weight',
 'enc_zero_convs_out.6.0.bias': 'down_zero_convs_out.6.bias',
 'enc_zero_convs_out.7.0.weight': 'down_zero_convs_out.7.weight',
 'enc_zero_convs_out.7.0.bias': 'down_

So far, we have mapped everything except the unet (i.e. `control_model`).

## Let's load the unet

In [23]:
import pickle

In [24]:
# Load the previously saved unet parameter-name mapping (a pickled dict).
# NOTE(review): `pickle.load` can execute arbitrary code embedded in the file —
# only load mapping files you created yourself or otherwise trust.
with open('mappings/sd21_state_dict_mapping.pkl', 'rb') as f:
    unet_key_mapping = pickle.load(f)

The unet-mapping-dict maps from diffusers notation to cnxs notation, but I need the mapping the other way round.

In [25]:
unet_key_mapping = {v:k for k,v in unet_key_mapping.items()}

Let's check that every tensor can be mapped from weights into model

In [26]:
weights_unet_params = [k for k in weights_tensors.keys() if k.startswith('control_model')]
model_unet_params   = [k for k in model_tensors.keys()   if k.startswith('control_model')]

In [27]:
print(f'The weights provide {len(weights_unet_params)} parameters for the unet, while the model expects {len(model_unet_params)}')

The weights provide 312 parameters for the unet, while the model expects 312


In [28]:
print(f'The param-mapping-dict for the SD21 unet has {len(unet_key_mapping)} entries')

The param-mapping-dict for the SD21 unet has 340 entries


Let's first check that all params in weights are present in the unet-mapping-dict

In [29]:
present = [p for p in weights_unet_params if p.replace('control_model.','') in unet_key_mapping.keys()]
not_present = [p for p in weights_unet_params if p.replace('control_model.','') not in unet_key_mapping.keys()]

In [30]:
len(present), len(not_present)

(298, 14)

Cool, almost all params in the weights can be mapped, except the ones listed below (14 in this run).

In [31]:
not_present

['control_model.input_blocks.1.0.skip_connection.weight',
 'control_model.input_blocks.1.0.skip_connection.bias',
 'control_model.input_blocks.2.0.skip_connection.weight',
 'control_model.input_blocks.2.0.skip_connection.bias',
 'control_model.input_blocks.5.0.skip_connection.weight',
 'control_model.input_blocks.5.0.skip_connection.bias',
 'control_model.input_blocks.8.0.skip_connection.weight',
 'control_model.input_blocks.8.0.skip_connection.bias',
 'control_model.input_blocks.10.0.skip_connection.weight',
 'control_model.input_blocks.10.0.skip_connection.bias',
 'control_model.input_blocks.11.0.skip_connection.weight',
 'control_model.input_blocks.11.0.skip_connection.bias',
 'control_model.middle_block.0.skip_connection.weight',
 'control_model.middle_block.0.skip_connection.bias']

These are all resnet skip connections. It makes sense that these are not in the unet-param-mapping, because in a normal unet, the resnets have equal input and output sizes. Therefore the skip-connections are `nn.Identity` and don't require parameters.

In the controller part of controlnet-xs, we have resnets with different input and output sizes, because we're infusing information from the base model into the control model. Therefore, we use convolutions as skip-connections.

In [32]:
def match_by_parent(o, mapping=None):
    """Translate a skip-connection parameter name via the mapping of its parent block.

    Skip-connection parameters have no entry of their own in the unet mapping, so the
    name is derived from any mapped parameter that lives in the same parent block:
    the last two components of that parameter's model-side name are dropped and
    `.conv_shortcut` (plus the original `.weight` / `.bias` suffix) is appended.

    Args:
        o: Checkpoint parameter name containing `skip_connection`, e.g.
            `control_model.input_blocks.1.0.skip_connection.bias`.
        mapping: Dict from checkpoint names (without the `control_model.` prefix) to
            model names. Defaults to the notebook-level `unet_key_mapping`.

    Returns:
        The model-side parameter name (prefixed with `control_model.`), or `None`
        if no mapped parameter belongs to the same parent block.
    """
    assert 'skip_connection' in o, 'Only skip-connections should be matches via the `match_by_parent` function'
    if mapping is None:
        mapping = unet_key_mapping

    # Detect the parameter kind from the END of the name, not from a substring anywhere in it.
    suffix = ''
    for candidate in ('.weight', '.bias'):
        if o.endswith(candidate):
            suffix = candidate
            break

    # Reduce the name to its parent block, e.g. `input_blocks.1.0`.
    parent = o[len('control_model.'):] if o.startswith('control_model.') else o
    if suffix:
        parent = parent[:-len(suffix)]
    if parent.endswith('.skip_connection'):
        parent = parent[:-len('.skip_connection')]

    # Anchor the comparison at the start and at a component boundary, so that a
    # parent is never matched in the middle of an unrelated parameter name.
    parent_prefix = parent + '.'
    for k, v in mapping.items():
        if k.startswith(parent_prefix):
            block_in_model = '.'.join(v.split('.')[:-2])
            return 'control_model.' + block_in_model + '.conv_shortcut' + suffix
    return None

assert match_by_parent('control_model.input_blocks.1.0.skip_connection.bias')=='control_model.down_blocks.0.resnets.0.conv_shortcut.bias'

Shapes don't need to match fully, they need only be identical after broadcasting. E.g., `(4,4,1,1)` and `(4,4)` should be treated equally

In [33]:
def equal_for_broadcasting(s1, s2):
    """Check whether two shapes are compatible once size-1 dimensions are ignored.

    The shorter shape is padded with TRAILING ones (unlike NumPy, which pads on the
    left), so `(4, 4)` is compared against `(4, 4, 1, 1)` as `(4, 4, 1, 1)`.
    Two dimensions are compatible if they are equal or if either one is 1; the
    result therefore does not depend on the order of the arguments.

    Args:
        s1, s2: Shapes as sequences of ints.

    Returns:
        True if every aligned pair of dimensions is compatible, False otherwise.
        An empty shape is never considered compatible.
    """
    l1, l2 = len(s1), len(s2)
    if l1 == 0 or l2 == 0: return False
    # Order by rank so that padding is always applied to the shorter shape.
    longer, shorter = (list(s1), list(s2)) if l1 >= l2 else (list(s2), list(s1))
    shorter = shorter + [1] * (len(longer) - len(shorter))
    # A size-1 dimension on EITHER side is tolerated (previously only on one side,
    # which made the result depend on the argument order for same-rank shapes).
    return all(a == b or a == 1 or b == 1 for a, b in zip(longer, shorter))

assert equal_for_broadcasting((5,5), (5,5,1,1))
assert equal_for_broadcasting((4,1), (4,))
assert equal_for_broadcasting((3,3,3), (3,3,3))
assert equal_for_broadcasting((4,1), (4,4)) == equal_for_broadcasting((4,4), (4,1))
assert not equal_for_broadcasting((3,3), (3,4))

For each unet parameter provided by the weights, let's check (via the unet-mapping-dict) whether it is
- provided by the weights, expected by the model and the shapes fit ✅, or
- provided by the weights, expected by the model, but the shapes mismatch ☑️, or
- provided by the weights, but missing in the model 🤔

In [34]:
# Result buckets for the unet parameters provided by the weights:
#   okay           - (weights key, model key) pairs whose shapes are compatible
#   shape_mismatch - (weights key, weights shape, model key, model shape)
#   missing        - (weights key, model key) where the model has no such tensor
#   not_in_mapping - weights keys that cannot be translated at all
okay, shape_mismatch, missing, not_in_mapping = [], [], [], []

for key_weights in weights_unet_params:
    # The unet mapping is keyed without the `control_model.` prefix.
    short_key = key_weights.replace('control_model.', '')

    if short_key in unet_key_mapping:
        key_model = 'control_model.' + unet_key_mapping[short_key]
    elif 'skip_connection' in key_weights:
        # Skip-connections have no entry of their own; derive the name from the parent block.
        key_model = match_by_parent(key_weights)
    else:
        not_in_mapping.append(key_weights)
        continue

    if key_model not in model_tensors:
        missing.append((key_weights, key_model))
        continue

    shape_model = list(model_tensors[key_model].shape)
    shape_weights = list(weights_tensors[key_weights].shape)

    if equal_for_broadcasting(shape_model, shape_weights):
        okay.append((key_weights, key_model))
    else:
        shape_mismatch.append((key_weights, shape_weights, key_model, shape_model))

In [35]:
len(okay),len(shape_mismatch),len(missing),len(not_in_mapping)

(312, 0, 0, 0)

In [36]:
print(f'Reminder: There are {len(weights_unet_params)} params provided by the weights')

Reminder: There are 312 params provided by the weights


In [37]:
print(f'Of those, {len(okay)} params can be matched correctly ✅')

Of those, 312 params can be matched correctly ✅


In [38]:
print(f'{len(shape_mismatch)} params can be matched but have mismatching shapes ☑️. These are:')
for kw,sw,km,sm in shape_mismatch: print(f'- "{kw}" has shape {sw} in weights and {sm} in model.\n\t It\'s name in model is "{km}"')

0 params can be matched but have mismatching shapes ☑️. These are:


In [39]:
print(f'{len(missing)} params are provided in the weights, but missing in the model 🤔. These are:')
for kw,km in missing: print(f'- "{km}" (called "{kw}" in weights)')

0 params are provided in the weights, but missing in the model 🤔. These are:


In [40]:
print(f'{len(not_in_mapping)} params are not present in the unet-mapping-dict.')

0 params are not present in the unet-mapping-dict.


### Unexpected params

These params are not provided in the weights, but currently (and wrongly) expected by the model.

In [41]:
matched__model_nomenclature = [km for kw,km in okay] + [km for kw,sw,km,sm in shape_mismatch]

In [42]:
matched__model_nomenclature[:3]

['control_model.time_embedding.linear_1.weight',
 'control_model.time_embedding.linear_1.bias',
 'control_model.time_embedding.linear_2.weight']

In [43]:
any(model_unet_params[0]==o for o in matched__model_nomenclature)

True

In [44]:
# Model-side unet parameters for which no checkpoint tensor was matched.
unexpted_params = [p for p in model_unet_params if p not in matched__model_nomenclature]
len(unexpted_params)

0

In [45]:
def containing(l, strs, invert=False):
    """Filter `l` down to the items that contain every string in `strs`.

    `strs` may be a single value or a list of values. With `invert=True` the
    filter keeps the items that contain none of the given strings instead.
    """
    needles = strs if isinstance(strs, list) else [strs]

    # Pick the membership test once instead of branching around the loop.
    if invert:
        keep = lambda item, needle: needle not in item
    else:
        keep = lambda item, needle: needle in item

    # Narrow the list one needle at a time.
    for needle in needles:
        l = [item for item in l if keep(item, needle)]
    return l

assert containing(['aa','ab'], 'a') == ['aa', 'ab']
assert containing(['aa','ab'], ['a']) == ['aa', 'ab']
assert containing(['aa','ab'], ['aa']) == ['aa']
assert containing(['aa','ab'], ['a','c']) == []
assert containing(['aa','ab'], 'b', invert=True) == ['aa']

___

## Map params

We have mapped everything! Let's do the actual mapping and create the mapped cnxs object

In [46]:
cnxs_full_mapping = {
    **cnxs_mapping_without_unet,
    **{kw:km for kw,km in okay}
}
len(cnxs_full_mapping)

402

In [47]:
assert len(cnxs_mapping_without_unet) + len(okay) == len(cnxs_full_mapping)

We need the mapping from diffusers nomenclature to cnxs nomenclature

In [48]:
cnxs_full_mapping = {v:k for k,v in cnxs_full_mapping.items()}

In [49]:
with open('mappings/sd21_cnxs_state_dict_mapping.pkl', 'wb') as f:
    pickle.dump(cnxs_full_mapping, f)

In [50]:
cnxs_state_dict = cnxs.state_dict()

We need to make sure the weights fit 100%, even if they're equal for broadcasting

In [51]:
# Copy every mapped checkpoint tensor into the state dict under its model-side name.
for key_model, key_weights in cnxs_full_mapping.items():
    target = model_tensors[key_model]
    source = weights_tensors[key_weights]
    if target.shape != source.shape:
        # The only tolerated difference: the model tensor has 2 extra trailing unit dims.
        assert list(target.shape)==list(source.shape)+[1,1], 'Unexpected shape mismatch found'
        source = source.unsqueeze(-1).unsqueeze(-1)
    cnxs_state_dict[key_model] = source

In [52]:
assert len(cnxs.down_zero_convs_out)==num_connections
assert len(cnxs.up_zero_convs_out)==num_connections

In [53]:
cnxs.load_state_dict(cnxs_state_dict)

<All keys matched successfully>

**All keys matched successfully** 😍🎉✨

In [54]:
cnxs.save_pretrained(f'weights/{WEIGHT_SAVE_PATH}')

In [55]:
assert cnxs.control_model.down_blocks[1].attentions[0].transformer_blocks[0].attn1.heads==1

In [56]:
assert cnxs.control_model.down_blocks[2].attentions[0].transformer_blocks[0].attn1.heads==2

Test `from_unet`

In [57]:
second_cnxs = ControlNetXSModel.from_unet(
    sd_unet,
    size_ratio=0.0125,
    dim_attention_heads=8,
)

`norm_num_groups` was set to `min(block_out_channels)` (=4) so it divides all block_out_channels` ([4, 8, 16, 16]). Set it explicitly to remove this information.


In [58]:
from datetime import datetime

now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

print(f"Finished running at {now}")

Finished running at 2023-11-30 19:12:35
