In [1]:
import torch
from safetensors import safe_open

In [2]:
# ControlNet-XS checkpoint (SDXL, canny, per the filename); path is relative to the notebook's cwd.
file = "../../../../.hf-cache/CVL-Heidelberg/sdxl_encD_canny_48m.safetensors"

In [3]:
# Read every tensor of the checkpoint onto the CPU, keyed by its state-dict name.
with safe_open(file, framework='pt', device='cpu') as f:
    weights_tensors = {name: f.get_tensor(name) for name in f.keys()}

In [4]:
from util import print_as_nested_dict

In [5]:
# Show the top-level key groups of the checkpoint (listed in the output below);
# compare them with the model's groups further down.
print_as_nested_dict(weights_tensors)

control_model
dec_zero_convs_out
enc_zero_convs_in
enc_zero_convs_out
input_hint_block
middle_block_out
scale_list


In [6]:
from diffusers.models.controlnetxs import ControlNetXSModel

In [7]:
# Instantiate ControlNet-XS through its `create_as_in_paper` factory; its state_dict is
# the target that the checkpoint tensors get mapped into.
cnxs = ControlNetXSModel.create_as_in_paper()

In [8]:
# Name -> tensor dict of the model. Later cells overwrite entries of this dict with
# checkpoint tensors; that only replaces dict values, so the weights reach `cnxs`
# once the dict is loaded back (e.g. via `load_state_dict`) — TODO confirm that happens.
model_tensors = cnxs.state_dict()

In [9]:
# Top-level key groups of the model, for comparison with the checkpoint's groups above.
print_as_nested_dict(model_tensors)

scale_list
base_model
control_model
enc_zero_convs_in
middle_block_out
dec_zero_convs_out
input_hint_block
encoder_hid_proj


In [10]:
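# Top-level key groups, checkpoint vs. model:
# control_model      <-- present in both, but the keys need remapping
# enc_zero_convs_out <-- in the checkpoint, missing in the model
# encoder_hid_proj   <-- in the model, missing in the checkpoint
# base_model         <-- in the model only (not part of this checkpoint)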

In [11]:
# Checkpoint prefix -> model prefix for the modules outside `control_model`.
# For some modules the checkpoint keys carry an extra `.0` path component that the
# model's keys do not have.
available_key_mapping = {
    'dec_zero_convs_out.0': 'dec_zero_convs_out',
    'enc_zero_convs_in.0': 'enc_zero_convs_in',
    'input_hint_block': 'input_hint_block',
    'middle_block_out.0': 'middle_block_out',
    'scale_list': 'scale_list'
}

for k in weights_tensors.keys():
    for old_prefix, new_prefix in available_key_mapping.items():
        # Match whole leading path components only. A plain substring test would also
        # fire when the prefix appears mid-key, and `str.replace` would rewrite every
        # occurrence instead of just the leading one.
        if k != old_prefix and not k.startswith(old_prefix + '.'):
            continue
        modified_k = new_prefix + k[len(old_prefix):]
        if modified_k in model_tensors:
            model_tensors[modified_k] = weights_tensors[k]
        else:
            print(f"Can't find key {modified_k} in model")
        break  # a key can only start with one of these prefixes

In [12]:
import pickle

In [13]:
# Mapping of control_model key suffixes: model-side key (dict key) -> checkpoint-side key
# (dict value). NOTE: pickle.load can execute arbitrary code; only load a trusted file.
with open('state_dict_mapping.pkl', mode='rb') as mapping_file:
    ctrl_model_key_mapping = pickle.load(mapping_file)

Let's check that every `control_model` tensor in the checkpoint can be mapped onto a tensor in the model.

In [14]:
# control_model.* keys on each side that have not been matched by the mapping yet.
remaining_from_weights = [name for name in weights_tensors if name.startswith('control_model')]
remaining_from_model = [name for name in model_tensors if name.startswith('control_model')]

In [15]:
# Number of control_model.* keys in the model vs. in the checkpoint.
n_model_keys, n_weight_keys = len(remaining_from_model), len(remaining_from_weights)
n_model_keys, n_weight_keys

(2106, 818)

In [16]:
def broadcastable(s1, s2):
    """Return True if shapes `s1` and `s2` are compatible after right-padding with 1s.

    Unlike NumPy broadcasting, which aligns shapes at the trailing dimension, the
    shorter shape here is extended with trailing size-1 dims, so e.g. (5, 5) matches
    (5, 5, 1, 1). Every aligned pair of dims must be equal or contain a 1.
    A 0-dim (empty) shape on either side is never considered compatible.
    """
    dims1, dims2 = list(s1), list(s2)
    if not (dims1 and dims2):
        return False
    width = max(len(dims1), len(dims2))
    dims1 += [1] * (width - len(dims1))
    dims2 += [1] * (width - len(dims2))
    return all(x == y or 1 in (x, y) for x, y in zip(dims1, dims2))

assert broadcastable((5,1,10), (5,5))
assert broadcastable((5,1,10), (5,5,1))
assert broadcastable((5,5), (5,5,1,1))
assert not broadcastable((5,1,10), (5,5,2))

In [17]:
# Check every control_model mapping entry (model-side key -> checkpoint-side key) and
# sort it into: missing from the model, missing from the checkpoint, or present in both
# but with shapes that `broadcastable` rejects. Entries found on both sides are then
# dropped from the `remaining_*` lists.
not_in_model, not_in_weights, wrong_shapes = [], [], []
matched_model, matched_weights = set(), set()

for k, v in ctrl_model_key_mapping.items():
    key_model   = 'control_model.' + k
    key_weights = 'control_model.' + v

    if key_model not in model_tensors:
        not_in_model.append(key_model)
        continue
    if key_weights not in weights_tensors:
        not_in_weights.append(key_weights)
        continue

    shape_model   = model_tensors[key_model].shape
    shape_weights = weights_tensors[key_weights].shape
    if not broadcastable(shape_model, shape_weights):
        wrong_shapes.append((k, list(shape_model), list(shape_weights)))

    # One checkpoint tensor feeding two model keys would indicate a broken mapping.
    assert key_weights not in matched_weights, f"{key_weights} is mapped to more than one model key"
    matched_model.add(key_model)
    matched_weights.add(key_weights)

# Filter once instead of calling list.remove() per entry: linear instead of quadratic,
# and re-running this cell no longer raises ValueError for keys removed by a previous run.
remaining_from_weights = [key for key in remaining_from_weights if key not in matched_weights]
remaining_from_model   = [key for key in remaining_from_model   if key not in matched_model]

In [18]:
# How many mapping entries landed in each failure bucket.
tuple(map(len, (not_in_model, not_in_weights, wrong_shapes)))

(4, 1292, 1)

#### Missing in model

In [19]:
# Split the mapping entries whose model-side key is absent from the model into
# transformer (attention) keys and everything else.
missing_in_model_tf   = [key for key in not_in_model if 'transformer_blocks' in key]
missing_in_model_rest = [key for key in not_in_model if 'transformer_blocks' not in key]

len(missing_in_model_tf), len(missing_in_model_rest)

(0, 4)

In [20]:
# Mapping entries under `transformer_blocks` whose model-side key is absent from the model.
missing_in_model_tf

[]

Earlier, the modules missing in the model were almost all attention modules.

Edit: Not anymore 😎 — none of the remaining ones are under `transformer_blocks`.

In [21]:
# The remaining mapping entries (outside `transformer_blocks`) that the model does not have.
missing_in_model_rest

['control_model.add_embedding.linear_1.weight',
 'control_model.add_embedding.linear_1.bias',
 'control_model.add_embedding.linear_2.weight',
 'control_model.add_embedding.linear_2.bias']

The remaining four keys all belong to `add_embedding`, which appears to be the additional conditioning embedding; the model has no such module.

#### Wrong shapes

In [22]:
# Split the shape mismatches into attention keys and everything else (keys only).
wrong_shapes_attn = [key for key, _, _ in wrong_shapes if 'attentions' in key]
wrong_shapes_rest = [key for key, _, _ in wrong_shapes if 'attentions' not in key]
len(wrong_shapes_attn), len(wrong_shapes_rest)

(0, 1)

In [23]:
# Attention keys whose model and checkpoint shapes `broadcastable` rejected.
wrong_shapes_attn

[]

Earlier, the shape mismatches were almost all in the attention modules.

Edit: Not anymore 😁 — no attention keys are left among them.

In [24]:
def print_wrong_shapes_in_attn(b, a, of='down_blocks', shapes=None):
    """Print the recorded shape mismatches that belong to one attention module.

    Args:
        b: block index inside `of`; not used when `of == 'mid_block'`, whose
            attentions sit directly under `mid_block` without a block index.
        a: attention index inside the block.
        of: parent module name, e.g. 'down_blocks' or 'mid_block'. Any other name
            (such as 'up_blocks') is used as-is.
        shapes: iterable of (key, model_shape, weights_shape) tuples; defaults to the
            notebook-level `wrong_shapes`.

    Each matching entry is printed with the attention prefix stripped, as
    `<key>  ->  <weights shape>  not ( <model shape> )`.
    """
    if shapes is None:
        shapes = wrong_shapes
    if of == 'mid_block':
        name = f'mid_block.attentions.{a}.'
    else:
        name = f'{of}.{b}.attentions.{a}.'
    for o, sm, sw in shapes:
        if name in o:
            print(o.replace(name, ''), ' -> ', sw, ' not (', sm, ')')

In [25]:
# Report shape mismatches for each attention in down_blocks 0 and 1.
for d in (0, 1):
    for a in (0, 1):
        print(f'>>> down_blocks.{d}.attentions.{a}.')
        # Use the loop indices so the listing matches the header printed above.
        print_wrong_shapes_in_attn(d, a)

>>> down_blocks.0.attentions.0.
>>> down_blocks.0.attentions.1.
>>> down_blocks.1.attentions.0.
>>> down_blocks.1.attentions.1.


In [26]:
# mid_block has no block index, so `b` does not affect the lookup here.
print_wrong_shapes_in_attn(b=0, a=0, of='mid_block')

In [27]:
# Non-attention keys whose model and checkpoint shapes `broadcastable` rejected.
wrong_shapes_rest

['time_embedding.linear_1.weight']

The only remaining mismatch is in `time_embedding.linear_1.weight`:

In [28]:
# Print every recorded mismatch as `<key>  ->  <weights shape>  not ( <model shape> )`.
for key, model_shape, ckpt_shape in wrong_shapes:
    print(key, ' -> ', ckpt_shape, ' not (', model_shape, ')')

time_embedding.linear_1.weight  ->  [1280, 320]  not ( [1280, 32] )


#### Missing in weights

In [29]:
# Checkpoint-side keys named by the mapping but absent from the checkpoint (1292 of them).
# The visible entries are all attention `to_q`/`to_k`/`to_v` biases — TODO confirm for the
# full list; consider slicing (e.g. `not_in_weights[:10]`) to keep the output short.
not_in_weights

['control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.bias',
 'control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.bias',
 'control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.bias',
 'control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.bias',
 'control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.bias',
 'control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.bias',
 'control_model.input_blocks.4.1.transformer_blocks.1.attn1.to_q.bias',
 'control_model.input_blocks.4.1.transformer_blocks.1.attn1.to_k.bias',
 'control_model.input_blocks.4.1.transformer_blocks.1.attn1.to_v.bias',
 'control_model.input_blocks.4.1.transformer_blocks.1.attn2.to_q.bias',
 'control_model.input_blocks.4.1.transformer_blocks.1.attn2.to_k.bias',
 'control_model.input_blocks.4.1.transformer_blocks.1.attn2.to_v.bias',
 'control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.bias',
 'control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k