In this notebook, I'm mapping the parameters given by Heidelberg CVL to the parameters required by my model

In [1]:
import torch
from safetensors import safe_open

In [2]:
file = '../../../../.hf-cache/CVL-Heidelberg/sdxl_encD_canny_48m.safetensors'

In [3]:
weights_tensors = {}
with safe_open(file, framework='pt', device='cpu') as f:
   for key in f.keys():
       weights_tensors[key] = f.get_tensor(key)

In [4]:
from util import print_as_nested_dict

These are the top-level parameter groups the weights provide

In [5]:
print_as_nested_dict(sorted(weights_tensors))

control_model
dec_zero_convs_out
enc_zero_convs_in
enc_zero_convs_out
input_hint_block
middle_block_out
scale_list


In [6]:
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline

In [7]:
model = "stabilityai/stable-diffusion-xl-base-1.0"
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float16)

In [8]:
pipe = StableDiffusionXLPipeline.from_pretrained(model, vae=vae, torch_dtype=torch.float16)
sdxl_unet = pipe.unet

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

At the end of __init__, the sigmas are tensor([14.6146, 14.5263, 14.4386, 14.3515, 14.2651]) ...


In [9]:
from diffusers.models.controlnetxs import ControlNetXSModel

In [10]:
cnxs = ControlNetXSModel.create_as_in_paper(base_model=sdxl_unet)

In [11]:
model_tensors = cnxs.state_dict()

These are the top-level parameter groups the model needs

In [12]:
print_as_nested_dict(sorted(model_tensors))

base_model
control_model
dec_zero_convs_out
enc_zero_convs_in
enc_zero_convs_out
input_hint_block
middle_block_out
scale_list


In [13]:
def lv0(k): return k.split('.')[0]

In [14]:
model_lv0 = set(map(lv0,model_tensors.keys()))
weights_lv0 = set(map(lv0,weights_tensors.keys()))

missing   = sorted(list(weights_lv0 - model_lv0))
unexpected= sorted(list(model_lv0 - weights_lv0))
expected  = sorted(list(model_lv0.intersection(weights_lv0)))


print('Provided in weights and expected in model:')
print(expected)
print('\nProvided by weights, but missing in model:')
print(missing)
print('\nNot provided by weights, but in model:')
print(unexpected)

Provided in weights and expected in model:
['control_model', 'dec_zero_convs_out', 'enc_zero_convs_in', 'enc_zero_convs_out', 'input_hint_block', 'middle_block_out', 'scale_list']

Provided by weights, but missing in model:
[]

Not provided by weights, but in model:
['base_model']


Every group except `base_model` is provided by the weights and can be mapped; `base_model` is the base diffusion model, which will be loaded externally.

## Let's load everything except the unet

In [15]:
available_key_mapping = {
    'dec_zero_convs_out.0': 'dec_zero_convs_out',
    'enc_zero_convs_in.0': 'enc_zero_convs_in',
    'enc_zero_convs_out.0': 'enc_zero_convs_out',
    'input_hint_block': 'input_hint_block',
    'middle_block_out.0': 'middle_block_out',
    'scale_list': 'scale_list'
}

cnxs_mapping_without_unet = {}
for key_weights in weights_tensors.keys():
    # only consider params starting with one of the above keys 
    if not any(key_weights.startswith(k) for k in available_key_mapping.keys()): continue

    # replace their beginning according to the mapping above
    key_model = key_weights
    for o, replacement in available_key_mapping.items():
        if key_weights.startswith(o):
            key_model = key_weights.replace(o, replacement)
            break

    if key_model in model_tensors:
        cnxs_mapping_without_unet[key_weights] = key_model
    else:
        print(f"Can't find key {key_model} in model")
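
The loop above relies on `str.replace`, which would also rewrite a later occurrence of the prefix inside a key. A minimal sketch of a stricter variant that only swaps the leading prefix follows; `remap_prefix` and `toy_mapping` are hypothetical names used for illustration, not part of the notebook.

```python
def remap_prefix(key, prefix_mapping):
    """Swap the leading prefix of `key` according to `prefix_mapping`.

    Returns None if no prefix in the mapping matches.
    """
    for old, new in prefix_mapping.items():
        if key.startswith(old):
            # Only touch the start of the key; the remainder is kept verbatim
            return new + key[len(old):]
    return None


# Toy mapping in the style of `available_key_mapping` above
toy_mapping = {
    'enc_zero_convs_in.0': 'enc_zero_convs_in',
    'input_hint_block': 'input_hint_block',
}

print(remap_prefix('enc_zero_convs_in.0.0.weight', toy_mapping))
print(remap_prefix('control_model.conv_in.weight', toy_mapping))
```

Keys that match no prefix come back as `None`, which plays the role of the `continue` in the loop above.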

In [16]:
len(cnxs_mapping_without_unet)

25

In [17]:
cnxs_mapping_without_unet

{'dec_zero_convs_out.0.0.bias': 'dec_zero_convs_out.0.bias',
 'dec_zero_convs_out.0.0.weight': 'dec_zero_convs_out.0.weight',
 'enc_zero_convs_in.0.0.bias': 'enc_zero_convs_in.0.bias',
 'enc_zero_convs_in.0.0.weight': 'enc_zero_convs_in.0.weight',
 'enc_zero_convs_out.0.0.bias': 'enc_zero_convs_out.0.bias',
 'enc_zero_convs_out.0.0.weight': 'enc_zero_convs_out.0.weight',
 'input_hint_block.0.bias': 'input_hint_block.0.bias',
 'input_hint_block.0.weight': 'input_hint_block.0.weight',
 'input_hint_block.10.bias': 'input_hint_block.10.bias',
 'input_hint_block.10.weight': 'input_hint_block.10.weight',
 'input_hint_block.12.bias': 'input_hint_block.12.bias',
 'input_hint_block.12.weight': 'input_hint_block.12.weight',
 'input_hint_block.14.bias': 'input_hint_block.14.bias',
 'input_hint_block.14.weight': 'input_hint_block.14.weight',
 'input_hint_block.2.bias': 'input_hint_block.2.bias',
 'input_hint_block.2.weight': 'input_hint_block.2.weight',
 'input_hint_block.4.bias': 'input_hint_bloc

So far, we have mapped everything except the unet (i.e. `control_model`).

## Let's load the unet

In [18]:
import pickle

In [19]:
with open('mappings/sdxl_state_dict_mapping.pkl', 'rb') as f:
    unet_key_mapping = pickle.load(f)

The unet-mapping-dict maps from diffusers notation to cnxs notation, but I need the mapping the other way round

In [20]:
unet_key_mapping = {v:k for k,v in unet_key_mapping.items()}
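
Inverting a dict with a comprehension silently drops entries if two keys share the same value. A minimal sketch of an inversion that fails loudly in that case follows; `invert_mapping` is a hypothetical helper and the parameter names in `toy` are made up.

```python
def invert_mapping(mapping):
    """Return {value: key}, raising if two keys map to the same value."""
    inverted = {}
    for key, value in mapping.items():
        if value in inverted:
            raise ValueError(
                f"Value {value!r} is mapped from both {inverted[value]!r} and {key!r}"
            )
        inverted[value] = key
    return inverted


# Toy example with made-up parameter names
toy = {
    'down_blocks.0.weight': 'input_blocks.1.weight',
    'mid_block.weight': 'middle_block.weight',
}
inverted_toy = invert_mapping(toy)
print(inverted_toy)
```

If the mapping is one-to-one, the result is identical to the comprehension used above.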

Let's check that every tensor can be mapped from weights into model

In [21]:
weights_unet_params = [k for k in weights_tensors.keys() if k.startswith('control_model')]
model_unet_params   = [k for k in model_tensors.keys()   if k.startswith('control_model')]

In [22]:
print(f'The weights provide {len(weights_unet_params)} parameters for the unet, while the model expects {len(model_unet_params)}')

The weights provide 818 parameters for the unet, while the model expects 814


In [23]:
print(f'The param-mapping-dict for the SDXL unet has {len(unet_key_mapping)} entries')

The param-mapping-dict for the SDXL unet has 2100 entries


Let's first check that all params in weights are present in the unet-mapping-dict

In [24]:
present = [p for p in weights_unet_params if p.replace('control_model.','') in unet_key_mapping.keys()]
not_present = [p for p in weights_unet_params if p.replace('control_model.','') not in unet_key_mapping.keys()]

In [25]:
len(present), len(not_present)

(808, 10)

Cool, almost all params in the weights can be mapped, except these 10 below.

In [26]:
not_present

['control_model.input_blocks.1.0.skip_connection.bias',
 'control_model.input_blocks.1.0.skip_connection.weight',
 'control_model.input_blocks.2.0.skip_connection.bias',
 'control_model.input_blocks.2.0.skip_connection.weight',
 'control_model.input_blocks.5.0.skip_connection.bias',
 'control_model.input_blocks.5.0.skip_connection.weight',
 'control_model.input_blocks.8.0.skip_connection.bias',
 'control_model.input_blocks.8.0.skip_connection.weight',
 'control_model.middle_block.0.skip_connection.bias',
 'control_model.middle_block.0.skip_connection.weight']

These are all resnet skip connections. It makes sense that these are not in the unet-param-mapping, because in a normal unet, these particular resnets have equal input and output sizes. Therefore the skip-connections are `nn.Identity` and don't require parameters.

In the controller part of controlnet-xs, we have resnets with different input and output sizes, because we're infusing information from the base model into the control model. Therefore, we use convolutions as skip-connections.
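
To make the difference concrete, here is a minimal sketch that counts the parameters of a skip connection. It assumes the skip connection is an identity when channel counts match and a 1x1 convolution with bias otherwise; the kernel size and the channel numbers are assumptions for illustration, not values read from the checkpoint.

```python
def skip_connection_param_count(in_channels, out_channels, kernel_size=1):
    """Number of parameters in a resnet skip connection.

    Assumption: identity when the channel counts match, otherwise a
    kernel_size x kernel_size convolution with a bias term.
    """
    if in_channels == out_channels:
        return 0  # identity: nothing to store, nothing to load
    weight = out_channels * in_channels * kernel_size * kernel_size
    bias = out_channels
    return weight + bias


# Hypothetical channel counts, for illustration only
print(skip_connection_param_count(320, 320))  # same size: no parameters
print(skip_connection_param_count(352, 32))   # different size: weight + bias
```

A non-zero count corresponds to the `.weight` and `.bias` entries listed above, which is why those keys have no counterpart in a mapping built from a plain unet.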

In [27]:
def match_by_parent(o):
    assert 'skip_connection' in o, 'Only skip-connections should be matched via the `match_by_parent` function'
    w,b = 'weight' in o, 'bias' in o
    o = o.replace('control_model.','').replace('.skip_connection','').replace('.weight','').replace('.bias','')
    for k,v in unet_key_mapping.items():
        if o in k:
            o = 'control_model.' + '.'.join(v.split('.')[:-2]) + '.conv_shortcut'
            if w: o+= '.weight'
            if b: o+= '.bias'
            return o
    return None

assert match_by_parent('control_model.input_blocks.1.0.skip_connection.bias')=='control_model.down_blocks.0.resnets.0.conv_shortcut.bias'

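Shapes don't need to match exactly; they only need to be equal up to trailing unit dimensions. E.g., `(4,4,1,1)` and `(4,4)` should be treated as equal. (This is not broadcasting in the NumPy/PyTorch sense, which aligns shapes from the trailing dimension.)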

In [28]:
def equal_for_broadcasting(s1, s2):
    l1, l2 = len(s1), len(s2)
    if l1==0 or l2==0: return False
    if l1<l2: s1, s2 = s2, s1 # Make s1 the longer list
    s1 = list(s1)
    s2 = list(s2) + [1] * (len(s1) - len(s2))
    return all(d1 == d2 or d2 == 1 for d1, d2 in zip(s1, s2))

assert equal_for_broadcasting((5,5), (5,5,1,1))
assert equal_for_broadcasting((4,1), (4,))
assert equal_for_broadcasting((3,3,3), (3,3,3))
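
Note that `equal_for_broadcasting` is slightly more permissive than "equal up to trailing unit dimensions": for two shapes of the same length it accepts `(4,4)` vs `(4,1)` but rejects the swapped order. A minimal sketch of a stricter, symmetric check follows; both function names are hypothetical and not used elsewhere in this notebook.

```python
def strip_trailing_ones(shape):
    """Drop unit dimensions from the end of a shape."""
    dims = list(shape)
    while dims and dims[-1] == 1:
        dims.pop()
    return tuple(dims)


def same_up_to_trailing_ones(s1, s2):
    """True if the shapes differ only by trailing unit dimensions."""
    return strip_trailing_ones(s1) == strip_trailing_ones(s2)


assert same_up_to_trailing_ones((5, 5), (5, 5, 1, 1))
assert same_up_to_trailing_ones((4, 1), (4,))
assert same_up_to_trailing_ones((3, 3, 3), (3, 3, 3))
assert not same_up_to_trailing_ones((4, 4), (4, 1))
```

Unlike the function above, this variant treats two empty shapes as equal; whether that matters depends on whether scalar parameters occur in the checkpoint.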

For each unet parameter provided by the weights, let's check whether it is
- provided by the weights, expected by the model and the shapes fit ✅, or
- provided by the weights, expected by the model, but the shapes mismatch ☑️, or
- provided by the weights, but missing in the model 🤔, or
- not covered by the unet-mapping-dict at all

In [29]:
okay, shape_mismatch, missing, not_in_mapping = [],[],[],[]

for k in weights_unet_params:
    key_weights,key_model = None,None
    
    key_weights = k

    if not k.replace('control_model.','') in unet_key_mapping.keys():
        if 'skip_connection' in k:
            key_model = match_by_parent(k)
        else:            
            not_in_mapping.append(k)
            continue
    else:
        key_model = 'control_model.'+unet_key_mapping[k.replace('control_model.','')]
    

    if not key_model in model_tensors:
        missing.append((key_weights, key_model))
        continue
    
    shape_model   = list(model_tensors[key_model].shape)
    shape_weights = list(weights_tensors[key_weights].shape)
    
    if not equal_for_broadcasting(shape_model,shape_weights):
        shape_mismatch.append((key_weights,shape_weights,key_model,shape_model))
        continue

    okay.append((key_weights, key_model))

In [30]:
len(okay),len(shape_mismatch),len(missing),len(not_in_mapping)

(814, 0, 4, 0)

In [31]:
print(f'Reminder: There are {len(weights_unet_params)} params provided by the weights')

Reminder: There are 818 params provided by the weights


In [32]:
print(f'Of those, {len(okay)} params can be matched correctly ✅')

Of those, 814 params can be matched correctly ✅


In [33]:
print(f'{len(shape_mismatch)} params can be matched but have mismatching shapes ☑️. These are:')
for kw,sw,km,sm in shape_mismatch: print(f'- "{kw}" has shape {sw} in weights and {sm} in model.\n\t Its name in model is "{km}"')

0 params can be matched but have mismatching shapes ☑️. These are:


In [34]:
print(f'{len(missing)} params are provided in the weights, but missing in the model 🤔. These are:')
for kw,km in missing: print(f'- "{km}" (called "{kw}" in weights)')

4 params are provided in the weights, but missing in the model 🤔. These are:
- "control_model.add_embedding.linear_1.bias" (called "control_model.label_emb.0.0.bias" in weights)
- "control_model.add_embedding.linear_1.weight" (called "control_model.label_emb.0.0.weight" in weights)
- "control_model.add_embedding.linear_2.bias" (called "control_model.label_emb.0.2.bias" in weights)
- "control_model.add_embedding.linear_2.weight" (called "control_model.label_emb.0.2.weight" in weights)


These all belong to the label embedding of the control model, which is not used at all (only the label embedding of the base model is used). So we can safely ignore these 4 params. 

In [35]:
print(f'{len(not_in_mapping)} params are not present in the unet-mapping-dict.')

0 params are not present in the unet-mapping-dict.


### Unexpected params

These params are not provided in the weights, but currently (and wrongly) expected by the model.

In [36]:
matched__model_nomenclature = [km for kw,km in okay] + [km for kw,sw,km,sm in shape_mismatch]

In [37]:
matched__model_nomenclature[:3]

['control_model.conv_in.bias',
 'control_model.conv_in.weight',
 'control_model.down_blocks.0.resnets.0.time_emb_proj.bias']

In [38]:
any(model_unet_params[0]==o for o in matched__model_nomenclature)

True

In [39]:
unexpted_params = [
    p
    for p in model_unet_params
    if not any(p==o for o in matched__model_nomenclature)
]
len(unexpted_params)

0

Originally, there were ~900 parameters not provided in the weights that the model wrongly expected.

**Edit:** These were all from the part after the mid block (i.e. up blocks or conv out). I have removed them, which is why the count above is now 0.

In [40]:
unexpted_params[:3]

[]

In [41]:
def containing(l, strs, invert=False):
    if not isinstance(strs,list): strs=[strs]
    if invert:
        for s in strs: l = list(filter(lambda o:s not in o, l))
    else:
        for s in strs: l = list(filter(lambda o:s in o, l))
    return l

assert containing(['aa','ab'], 'a') == ['aa', 'ab']
assert containing(['aa','ab'], ['a']) == ['aa', 'ab']
assert containing(['aa','ab'], ['aa']) == ['aa']
assert containing(['aa','ab'], ['a','c']) == []
assert containing(['aa','ab'], 'b', invert=True) == ['aa']

**Hypothesis (old):** These are mostly transformer_blocks biases

In [42]:
tf_bias = containing(unexpted_params, ['transformer_blocks', 'bias'])
len(tf_bias)

0

Let's verify that the correctly mapped params don't have transformer_blocks biases

In [43]:
len(containing(matched__model_nomenclature, ['transformer_blocks', 'bias']))

238

**Result: No**, there are biases for transformer_blocks in the correctly mapped params.

**Hypothesis:** These are mostly up blocks, as our control model doesn't need the up part.

In [44]:
len(containing(unexpted_params, 'up_block')), len(containing(unexpted_params, 'up_block', invert=True))

(0, 0)

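**Result: Yes!** 868 of 873 unexpected params were from the up blocks. (These numbers are from before the up blocks were removed from the model; the cell above now returns zeros.)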

We can just ignore these params in the forward, so it doesn't matter that they exist.

___

What are the remaining params?

In [45]:
containing(unexpted_params, 'up_block', invert=True)

[]

`conv_out` and `conv_norm_out` are from the output part (after the up part), so can be ignored as well

**==> All unexpected extra params can be safely ignored**

___

## Map params

We have mapped everything! Let's do the actual mapping and create the mapped cnxs object

In [46]:
cnxs_full_mapping = {
    **cnxs_mapping_without_unet,
    **{kw:km for kw,km in okay}
}
len(cnxs_full_mapping)

839

In [47]:
assert len(cnxs_mapping_without_unet) + len(okay) == len(cnxs_full_mapping)

We need the mapping from diffusers nomenclature to cnxs nomenclature

In [48]:
cnxs_full_mapping = {v:k for k,v in cnxs_full_mapping.items()}

In [49]:
with open('mappings/cnxs_state_dict_mapping.pkl', 'wb') as f:
    pickle.dump(cnxs_full_mapping, f)

In [50]:
cnxs_state_dict = cnxs.state_dict()

When loading, the shapes must match exactly, so tensors that are only equal up to trailing unit dimensions need to be reshaped first

In [51]:
for k,v in cnxs_full_mapping.items():
    mt,wt = model_tensors[k],weights_tensors[v]
    if mt.shape==wt.shape:
        # Load tensor
        cnxs_state_dict[k] = weights_tensors[v]
    else:
        # Load tensor with 2 trailing unit dims added
        assert list(mt.shape)==list(wt.shape)+[1,1], 'Unexpected shape mismatch found'
        cnxs_state_dict[k] = wt.unsqueeze(-1).unsqueeze(-1) 

In [52]:
len(cnxs.enc_zero_convs_out), len(cnxs.dec_zero_convs_out)

(9, 9)

In [53]:
cnxs.load_state_dict(cnxs_state_dict)

<All keys matched successfully>

**All keys matched successfully** 😍🎉✨

I don't want to save the base unet with the control net xs, so let's delete it first

In [54]:
del cnxs.base_model

In [55]:
cnxs.save_pretrained('weights/cnxs')