In this notebook, I want to map the modules of the CNXS UNet and the diffusers UNet onto each other. When done, I should be able to load diffusers weights into CNXS.

___

In [1]:
import torch

In [2]:
from diffusers import StableDiffusionPipeline, StableDiffusionControlNetXSPipeline, ControlNetXSModel

In [3]:
# Identifier of the base checkpoint passed to `from_pretrained`; reused further down
# when the ControlNet-XS pipeline is created from the same checkpoint.
model = "stabilityai/stable-diffusion-2-1-base"
sd_pipe = StableDiffusionPipeline.from_pretrained(model)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [4]:
# UNet of the plain SD pipeline; handed to `create_as_in_paper` below.
unet = sd_pipe.unet

In [5]:
# sdxl=False because the base checkpoint is SD 2.1, not SDXL.
# NOTE(review): presumably this sizes the control model relative to `unet` -- confirm in the diffusers docs.
controlnet = ControlNetXSModel.create_as_in_paper(unet, sdxl=False)

Set `norm_num_groups` to `min(block_out_channels)` (=4) so it divides all block_out_channels` ([4, 8, 16, 16]). Set it explicitly to remove this information.


In [6]:
# Second pipeline from the same checkpoint, this time with the ControlNet-XS model attached.
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(model, controlnet=controlnet)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [7]:
# The control model inside the ControlNet-XS wrapper; its state dict is the
# diffusers side of the mapping built in this notebook.
control_unet = pipe.controlnet.control_model

___

In [8]:
def nested_dict(d, ignore_bias=True):
    """Turn a flat dict with dot-separated keys into a nested dict.

    Args:
        d: flat mapping, e.g. ``{'a.b.weight': v, 'a.b.bias': w}``.
        ignore_bias: if True (default), entries whose key contains ``'bias'`` are
            dropped and the ``'.weight'`` suffix is stripped, so every module
            collapses to a single leaf (``{'a': {'b': v}}``).
            If False, keys are split verbatim, so ``weight`` and ``bias`` end up as
            sibling leaves below their module (``{'a': {'b': {'weight': v, 'bias': w}}}``).

    Returns:
        A new nested dict; ``d`` itself is not modified.
    """
    root = {}
    for key, value in d.items():
        if ignore_bias:
            if 'bias' in key: continue
            # With biases gone, the module name alone identifies the entry.
            key = key.replace('.weight', '')
        *parents, leaf = key.split('.')
        node = root
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return root

def pretty_print_dict(d,lv=2,indent=0,depth=1):
    """Print a nested dict as an indented tree.

    Args:
        d: the (nested) dict to print; a non-dict value is printed as is.
        lv: maximum depth to descend to; ``None`` means no limit.
        indent: current indentation, in units of two spaces.
        depth: current depth (1 for the top level); used together with ``lv``.
    """
    too_deep = lv is not None and depth > lv
    if too_deep: return
    if not isinstance(d,dict):
        print(d)
        return
    pad = '  ' * indent
    for name, child in d.items():
        if not isinstance(child, dict):
            # leaf: show the value next to its key
            print(pad + f'{name!s} -> {child!s}')
            continue
        print(pad + str(name))
        pretty_print_dict(child,lv,indent+2,depth+1)

___

In [9]:
def to_shape(o):
    """Recursively replace every leaf by its shape (as a list).

    Dicts are rebuilt with converted values, lists are assumed to already be
    shapes and are returned unchanged, anything else must expose ``.shape``.
    """
    if isinstance(o,list): return o
    if isinstance(o,dict):
        return {name: to_shape(child) for name, child in o.items()}
    return list(o.shape)

def remove_bias(o):
    """Drop every entry whose key contains 'bias' and strip '.weight' from the remaining keys.

    Works recursively on nested dicts; non-dict values are returned unchanged.
    """
    if not isinstance(o,dict): return o
    cleaned = {}
    for name, child in o.items():
        if 'bias' in name: continue
        cleaned[name.replace('.weight','')] = remove_bias(child)
    return cleaned

Load unet from CNXS

In [10]:
import json

In [11]:
# Read the CNXS state dict that was exported to JSON.
with open('cnsx_base_state_dict_with_shapes -- sd21.json', 'r') as json_file:
    cn_sdict = json.load(json_file)

In [12]:
from util import print_as_nested_dict
# First the top-level entries of the CNXS state dict, then a closer look at
# what sits below `control_model`.
print_as_nested_dict(cn_sdict, lv=1)
print()
print_as_nested_dict(cn_sdict, 'control_model', lv=2,)

scale_list
control_model
enc_zero_convs_out
enc_zero_convs_in
dec_zero_convs_out
middle_block_out
input_hint_block

control_model
        time_embed
        input_blocks
        middle_block


‼️ This notebook should only map the unet part of the control model. So let's delete the rest (connections, time embedding, input hint)

In [13]:
# Keep only the entries of the control model and strip *only the leading* prefix
# (a plain `str.replace` would also remove later occurrences inside a key).
# Note: this overwrites `cn_sdict`, so re-running the cell without reloading the
# JSON yields an empty dict.
cn_sdict = {
    k[len('control_model.'):]: v for k,v in cn_sdict.items() if k.startswith('control_model.')
}

In [14]:
# Reduce to {module name: shape}: values that are already lists pass through `to_shape`
# unchanged, bias entries are dropped and the '.weight' suffix is removed from the keys.
cn_sdict = remove_bias(to_shape(cn_sdict))

Load unet from diffusers

In [15]:
# Same normalisation for the diffusers side: tensors -> shapes, biases dropped,
# '.weight' suffix stripped from the keys.
df_sdict = remove_bias(to_shape(control_unet.state_dict()))

___


In [16]:
# Shallow backups; the working copies used for the mapping are taken from these,
# so the mapping can be restarted without reloading anything.
cn_bak = cn_sdict.copy()
df_bak = df_sdict.copy()

### First goal: Map one resnet

In [17]:
from dataclasses import dataclass

@dataclass
class MappedModule:
    """One module that has been matched between the two naming schemes."""
    cn: str      # module name on the CNXS side
    df: str      # module name on the diffusers side
    shape: list  # weight shape shared by both sides

    def __repr__(self):
        # Only the diffusers name and the shape are shown.
        return '{} {}'.format(self.df, self.shape)


class UnmappedModel:
    """Thin wrapper around the dict of modules that have not been mapped yet.

    ``modules`` is a dict of style ``{'down_blocks.0.resnets.0': [4, 64, 64]}``
    (module name -> shape). The dict is stored by reference, not copied, so
    ``remove`` mutates the dict that was passed in.
    """
    def __init__(self, modules): self.modules = modules
    def print(self, contains='', lv=None):
        """Print the remaining modules as a tree, optionally only those whose name contains `contains`."""
        pretty_print_dict(nested_dict(filter_dict(self.modules,by=contains)), lv=lv)
    def remove(self,k): del self.modules[k]
    def __getitem__(self,k): return self.modules[k]
    def __len__(self): return len(self.modules)
    # Direct dict membership test instead of scanning all keys.
    def has(self,k): return k in self.modules
    def __iter__(self): return iter(self.modules)

def filter_dict(d,by):
    """Return a new dict with the entries of `d` whose key contains `by`."""
    kept = {}
    for key, value in d.items():
        if by in key: kept[key] = value
    return kept

class MappedModel:
    """Collects CNXS <-> diffusers module pairs and tracks what is still unmapped.

    The two dicts passed to the constructor (module name -> shape) are wrapped,
    not copied: every successful ``add`` deletes the mapped entries from them.
    """

    def __init__(self, unmapped_cn, unmapped_df):
        # Per-instance list; a class-level `modules = []` would be shared state.
        self.modules = []
        self.unmapped_cn=UnmappedModel(unmapped_cn)
        self.unmapped_df=UnmappedModel(unmapped_df)
    
    def add(self, cn_module, df_module):
        """Record that `cn_module` corresponds to `df_module`.

        Raises:
            KeyError: if either module is not (or no longer) in the unmapped dicts.
            AssertionError: if the shapes of the two modules differ.
        """
        # check shapes
        cn_shape = self.unmapped_cn[cn_module]
        df_shape = self.unmapped_df[df_module]
        assert cn_shape==df_shape, f'Mapping don\'t fit: {cn_shape} != {df_shape}'
        # add to mapped, keeping the list sorted by diffusers name
        self.modules = sorted(self.modules + [MappedModule(cn=cn_module,df=df_module,shape=cn_shape)], key=lambda o:o.df)
        # remove from unmapped
        self.unmapped_cn.remove(cn_module)
        self.unmapped_df.remove(df_module)
        
    def __repr__(self): return '\n'.join(str(m) for m in self.modules)
    def __len__(self): return len(self.modules)

    def __getitem__(self,k):
        """Return the CNXS module name mapped to diffusers module `k`; ValueError if there is none."""
        for m in self.modules:
            if m.df==k: return m.cn
        raise ValueError(f"Didn't find  module with name {k}")

___

In [18]:
# Fresh working copies: MappedModel deletes entries from the dicts it is given,
# so the backups themselves must stay untouched.
cn = cn_bak.copy()
df = df_bak.copy()

In [19]:
# Global mapping state used by all `map_*` helpers below; it wraps `cn` and `df` without copying them.
mapped = MappedModel(cn,df)

In [20]:
# These wrap the very same dict objects that `mapped` holds (no copy is made),
# so their lengths shrink whenever `mapped.add` removes an entry.
cn_unmapped = UnmappedModel(cn)
df_unmapped = UnmappedModel(df)

In [21]:
# (# mapped pairs, # unmapped CNXS modules, # unmapped diffusers modules)
len(mapped), len(cn_unmapped), len(df_unmapped)

(0, 177, 177)

#### Let's map the downblocks

**Note:** The number of transformer blocks varies between attentions. In down block 1, each attention has 2; in down block 2, each attention has 10.

**Edit:** No, that's only the case for SDXL, not for SD.

In [22]:
from util import cls_name

In [23]:
# One line per down block: index, class name and, where the block has
# attentions, the number of transformer blocks in each of them.
for i, down in enumerate(control_unet.down_blocks):
    line = f'{i}] {cls_name(down)}: '
    if hasattr(down, 'attentions'):
        line += ' '.join(str(len(a.transformer_blocks)) for a in down.attentions)
    print(line)

0] CrossAttnDownBlock2D: 1 1
1] CrossAttnDownBlock2D: 1 1
2] CrossAttnDownBlock2D: 1 1
3] DownBlock2D: 


In [24]:
from util import print_as_nested_dict

def listy(o):
    """True if `o` is a list or a tuple."""
    return isinstance(o, list) or isinstance(o, tuple)

def listify(o):
    """Return `o` unchanged if it is a list/tuple, otherwise wrap it in a one-element list."""
    if listy(o): return o
    return [o]

def compare_as_nested_dict(cnxs_search, df_search, lvs, print_leaf=True):
    """Print matching parts of the CNXS and the diffusers state dict one after the other.

    Args:
        cnxs_search: query (or list/tuple of queries) passed to `print_as_nested_dict` for `cn_sdict`.
        df_search: query (or list/tuple of queries) passed to `print_as_nested_dict` for `df_sdict`.
        lvs: print depth; either one value for both sides or a pair ``(lv_cnxs, lv_df)``.
        print_leaf: forwarded to `print_as_nested_dict`.
    """
    lv_cnxs, lv_df = lvs if listy(lvs) else (lvs, lvs)

    def _dump(sdict, queries, lv):
        # every query is followed by an empty line
        for query in listify(queries):
            print_as_nested_dict(sdict, query, lv=lv, print_leaf=print_leaf)
            print()

    _dump(cn_sdict, cnxs_search, lv_cnxs)
    print('----')
    _dump(df_sdict, df_search, lv_df)

In [25]:
#compare_as_nested_dict('input_blocks.1.0', 'down_blocks.0.resnets', lvs=6)

In [26]:
def map_resnet(idx):
    """Map one down-block resnet.

    `idx` is ``(d1, d2, c)``: diffusers down-block index, resnet index inside
    that block, and the CNXS `input_blocks` index.
    """
    d1,d2,c=idx
    cn_path = f'input_blocks.{c}.0.'
    df_path = f'down_blocks.{d1}.resnets.{d2}.'
    # (CNXS sub-module, diffusers sub-module), mapped in this order
    pairs = [
        ('in_layers.0', 'norm1'),
        ('in_layers.2', 'conv1'),
        ('emb_layers.1', 'time_emb_proj'),
        ('out_layers.0', 'norm2'),
        ('out_layers.3', 'conv2'),
    ]
    for cn_name, df_name in pairs:
        mapped.add(cn_module=cn_path+cn_name, df_module=df_path+df_name)
    # the shortcut conv is mapped only if the diffusers side still lists one for this resnet
    shortcut = df_path+'conv_shortcut'
    if mapped.unmapped_df.has(shortcut):
        mapped.add(cn_module=cn_path+'skip_connection', df_module=shortcut)
    print(f'-- mapped resnet {d1},{d2} / {c}')

def map_attn(idx, num_tfs=1):
    """Map one down-block attention.

    Args:
        idx: ``(d1, d2, c)`` -- diffusers down-block index, attention index inside
            that block, and the CNXS `input_blocks` index.
        num_tfs: number of transformer blocks inside the attention. Defaults to 1
            (sd21 always has 1 transformer per layer for blocks 0 - 3, and none in block 4).
    """
    d1,d2,c=idx
    cn_path = f'input_blocks.{c}.1.'
    df_path = f'down_blocks.{d1}.attentions.{d2}.'
    mapped.add(cn_module=cn_path+'norm',df_module=df_path+'norm')
    mapped.add(cn_module=cn_path+'proj_in',df_module=df_path+'proj_in')
    for i in range(num_tfs): map_tfmr(f'{cn_path}transformer_blocks.{i}.',f'{df_path}transformer_blocks.{i}.')
    mapped.add(cn_module=cn_path+'proj_out',df_module=df_path+'proj_out')
    print(f'-- mapped attn {d1},{d2} / {c}')


def map_tfmr(cn_path,df_path):
    """Map all sub-modules of one transformer block.

    Both paths must end with a dot. The sub-module names are the same in CNXS
    and diffusers, so the same suffix is appended to both paths.
    """
    attn_parts = ('to_q', 'to_k', 'to_v', 'to_out.0')
    suffixes = (
        ['norm1'] + [f'attn1.{p}' for p in attn_parts]
        + ['norm2'] + [f'attn2.{p}' for p in attn_parts]
        + ['norm3', 'ff.net.0.proj', 'ff.net.2']
    )
    for suffix in suffixes:
        mapped.add(cn_path+suffix, df_path+suffix)

def map_downsample(idx):
    """Map one downsampler; `idx` is ``(d, c)``: diffusers down-block index and CNXS `input_blocks` index."""
    d,c=idx
    mapped.add(
        cn_module=f'input_blocks.{c}.0.op',
        df_module=f'down_blocks.{d}.downsamplers.0.conv',
    )
    print(f'-- mapped downsample {d} / {c}')

def map_down_block(blocks=4, layers_per_block=2):
    """Map all resnets, attentions and downsamplers of the down blocks.

    Args:
        blocks: number of down blocks (default 4).
        layers_per_block: number of resnet(/attention) layers per down block (default 2).

    The last block is assumed to have neither attentions nor a downsampler.
    """
    # CNXS numbers its input blocks consecutively: index 0 is skipped here, then every
    # down block contributes `layers_per_block` layers plus one downsampler slot.
    def subblock_no(b,l): return 1+b*(layers_per_block+1)+l

    # resnets
    resnet_idx = [
        (b,l, subblock_no(b,l))
        for b in range(blocks)
        for l in range(layers_per_block)
    ]
    for i in resnet_idx: map_resnet(i)
    print(f'Mapped {len(resnet_idx)} Resnet blocks')

    # attentions
    attn_idx = [
        (b,l, subblock_no(b,l))
        for b in range(blocks-1) # in sd, last block has no attentions
        for l in range(layers_per_block)
    ]
    for i in attn_idx: map_attn(i)
    print(f'Mapped {len(attn_idx)} Attention blocks')

    # downsamplers
    down_idx = [
        (b,(b+1)*(layers_per_block+1))
        for b in range(blocks-1) # last block has no downsampler
    ]
    for i in down_idx: map_downsample(i)
    print(f'Mapped {len(down_idx)} Downsamplers blocks')

In [27]:
# Fills `mapped` with all down-block modules; prints one line per mapped sub-block.
map_down_block()

-- mapped resnet 0,0 / 1
-- mapped resnet 0,1 / 2
-- mapped resnet 1,0 / 4
-- mapped resnet 1,1 / 5
-- mapped resnet 2,0 / 7
-- mapped resnet 2,1 / 8
-- mapped resnet 3,0 / 10
-- mapped resnet 3,1 / 11
Mapped 8 Resnet blocks
-- mapped attn 0,0 / 1
-- mapped attn 0,1 / 2
-- mapped attn 1,0 / 4
-- mapped attn 1,1 / 5
-- mapped attn 2,0 / 7
-- mapped attn 2,1 / 8
Mapped 6 Attention blocks
-- mapped downsample 0 / 3
-- mapped downsample 1 / 6
-- mapped downsample 2 / 9
Mapped 3 Downsamplers blocks


In [28]:
# (# mapped pairs, # unmapped CNXS modules, # unmapped diffusers modules)
len(mapped), len(cn_unmapped), len(df_unmapped)

(147, 30, 30)

In [29]:
# Top-level groups that still contain unmapped CNXS modules.
cn_unmapped.print(lv=1)

time_embed
input_blocks
middle_block


#### Let's map the embeds, and conv in

In [33]:
def map_conv_in():
    """Map the input convolution (CNXS `input_blocks.0.0` <-> diffusers `conv_in`)."""
    mapped.add('input_blocks.0.0', 'conv_in')

def map_embedds():
    """Map the two linear layers of the time embedding."""
    for cn_no, df_no in ((0, 1), (2, 2)):
        mapped.add(cn_module=f'time_embed.{cn_no}', df_module=f'time_embedding.linear_{df_no}')
    # SD (unlike SDXL) has no micro-conditioning, so there is no add_embedding to map.
    # For SDXL the pairs would be label_emb.0.0 / add_embedding.linear_1
    # and label_emb.0.2 / add_embedding.linear_2.

In [31]:
# Adds three more pairs to `mapped`: conv_in and the two time-embedding linears.
map_conv_in()
map_embedds()

In [32]:
# (# mapped pairs, # unmapped CNXS modules, # unmapped diffusers modules)
len(mapped), len(cn_unmapped), len(df_unmapped)

(150, 27, 27)

#### Let's map the middle block

In [37]:
# Inspect the layout of the mid block before mapping it.
mid = control_unet.mid_block
n_resnets, n_attns = len(mid.resnets), len(mid.attentions)
print(f'Mid block has {n_resnets} resnets and {n_attns} attentions')

tf_counts = ' '.join(str(len(a.transformer_blocks)) for a in mid.attentions)
print('Here are the # of transformers: ' + tf_counts, end='')

Mid block has 2 resnets and 1 attentions
Here are the # of transformers: 1

In [38]:
# The mid-block helpers have their own names so they do not shadow the down-block
# `map_resnet` / `map_attn` (which take a single index tuple); the down-block cells
# therefore stay re-runnable. The transformer mapping is shared: `map_tfmr` from the
# down-block section is reused instead of being defined a second time.

def map_mid_resnet(d,c):
    """Map one mid-block resnet: diffusers `mid_block.resnets.{d}` <-> CNXS `middle_block.{c}`."""
    cn_path = f'middle_block.{c}.'
    df_path = f'mid_block.resnets.{d}.'
    mapped.add(cn_module=cn_path+'in_layers.0',df_module=df_path+'norm1')
    mapped.add(cn_module=cn_path+'in_layers.2',df_module=df_path+'conv1')
    mapped.add(cn_module=cn_path+'emb_layers.1',df_module=df_path+'time_emb_proj')
    mapped.add(cn_module=cn_path+'out_layers.0',df_module=df_path+'norm2')
    mapped.add(cn_module=cn_path+'out_layers.3',df_module=df_path+'conv2')
    # the shortcut conv is mapped only if the diffusers side still lists one
    if mapped.unmapped_df.has(df_path+'conv_shortcut'): mapped.add(cn_module=cn_path+'skip_connection',df_module=df_path+'conv_shortcut')

def map_mid_attn(d,c, num_tfs=1):
    """Map one mid-block attention: diffusers `mid_block.attentions.{d}` <-> CNXS `middle_block.{c}`.

    `num_tfs` is the number of transformer blocks inside the attention (1 by default).
    """
    cn_path = f'middle_block.{c}.'
    df_path = f'mid_block.attentions.{d}.'
    mapped.add(cn_module=cn_path+'norm',df_module=df_path+'norm')
    mapped.add(cn_module=cn_path+'proj_in',df_module=df_path+'proj_in')
    for i in range(num_tfs): map_tfmr(f'{cn_path}transformer_blocks.{i}.',f'{df_path}transformer_blocks.{i}.')
    mapped.add(cn_module=cn_path+'proj_out',df_module=df_path+'proj_out')

def map_mid_block():
    """Map the whole mid block, which is laid out as resnet / attention / resnet."""
    map_mid_resnet(0,0)
    map_mid_attn(0,1)
    map_mid_resnet(1,2)
    print(f'Mapped 2 Restnets and 1 Attention blocks (R/A/R)')

In [39]:
# Adds the mid-block modules (resnet / attention / resnet) to `mapped`.
map_mid_block()

Mapped 2 Restnets and 1 Attention blocks (R/A/R)


In [40]:
# (# mapped pairs, # unmapped CNXS modules, # unmapped diffusers modules)
len(mapped), len(cn_unmapped), len(df_unmapped)

(177, 0, 0)

In [41]:
# Name fragments of modules that are allowed to stay unmapped on each side.
cn_to_skip = []
df_to_skip = ['up_blocks', 'conv_norm_out', 'conv_out']

# Anything still unmapped must match one of the fragments for its side.
for leftover, allowed in ((cn_unmapped, cn_to_skip), (df_unmapped, df_to_skip)):
    for module in leftover:
        assert any(o in str(module) for o in allowed), f"Couldn't map {module}"

**Done!**

___

Now, let's create the final mapping dictionary

Let's quickly check that all weights are actually mapped (if not, the below line will throw a ValueError).

In [42]:
# Look up every non-skipped parameter of the pipeline's UNet;
# `mapped[...]` raises ValueError for a module without a mapping.
for k in pipe.unet.state_dict():
    if not any(o in k for o in df_to_skip):
        mapped[k.replace('.weight','').replace('.bias','')]

Works, cool.

In [43]:
state_dict_mapping = {}
for k in pipe.unet.state_dict().keys():
    if any(o in k for o in df_to_skip): continue
    
    k_df = k.replace('.weight','').replace('.bias','')
    k_cn = mapped[k_df]
    state_dict_mapping[k_df+'.weight'] = k_cn+'.weight'
    state_dict_mapping[k_df+'.bias'] = k_cn+'.bias'

In [44]:
# Display the final mapping: diffusers parameter key -> CNXS parameter key.
state_dict_mapping

{'conv_in.weight': 'input_blocks.0.0.weight',
 'conv_in.bias': 'input_blocks.0.0.bias',
 'time_embedding.linear_1.weight': 'time_embed.0.weight',
 'time_embedding.linear_1.bias': 'time_embed.0.bias',
 'time_embedding.linear_2.weight': 'time_embed.2.weight',
 'time_embedding.linear_2.bias': 'time_embed.2.bias',
 'down_blocks.0.attentions.0.norm.weight': 'input_blocks.1.1.norm.weight',
 'down_blocks.0.attentions.0.norm.bias': 'input_blocks.1.1.norm.bias',
 'down_blocks.0.attentions.0.proj_in.weight': 'input_blocks.1.1.proj_in.weight',
 'down_blocks.0.attentions.0.proj_in.bias': 'input_blocks.1.1.proj_in.bias',
 'down_blocks.0.attentions.0.transformer_blocks.0.norm1.weight': 'input_blocks.1.1.transformer_blocks.0.norm1.weight',
 'down_blocks.0.attentions.0.transformer_blocks.0.norm1.bias': 'input_blocks.1.1.transformer_blocks.0.norm1.bias',
 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight': 'input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight',
 'down_blocks.0.atten

In [45]:
import pickle

# Persist the mapping so it can be loaded outside this notebook.
with open('sd21_state_dict_mapping.pkl', 'wb') as pkl_file:
    pickle.dump(state_dict_mapping, pkl_file)