In this notebook, I want to map the modules of the cnxs UNet and the diffusers UNet onto each other. When done, I should be able to load diffusers weights into the cnxs model.

___

In [1]:
#!pip install -Uqq transformers diffusers 

In [2]:
import torch

In [3]:
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline

In [4]:
# Identifier of the SDXL base pipeline; passed to `from_pretrained` in the next cell.
model = "stabilityai/stable-diffusion-xl-base-1.0"
# The VAE is loaded separately (float16) and handed to the pipeline via `vae=`.
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float16)

In [5]:
# Build the full SDXL pipeline in float16; only `pipe.unet` is used in the rest of this notebook.
pipe = StableDiffusionXLPipeline.from_pretrained(model, vae=vae, torch_dtype=torch.float16)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

___

In [6]:
def nested_dict(d, ignore_bias=True):
    """Turn a flat mapping with dot-separated keys into a nested dict.

    Args:
        d: flat dict such as ``{'a.b.weight': x, 'a.b.bias': y}``.
        ignore_bias: if True (default), every key containing ``'bias'`` is
            skipped and the ``'.weight'`` suffix is stripped, so each module
            becomes a single leaf: ``{'a': {'b': x}}``.
            If False, nothing is skipped or stripped and ``weight`` / ``bias``
            end up as sibling leaves: ``{'a': {'b': {'weight': x, 'bias': y}}}``.

    Returns:
        A new nested dict; `d` itself is not modified.
    """
    root = {}
    for key, value in d.items():
        if ignore_bias:
            if 'bias' in key: continue
            parts = key.replace('.weight', '').split('.')
        else:
            # Keep the suffix: stripping '.weight' here would make the weight
            # leaf collide with the parent node of the matching bias entry.
            parts = key.split('.')
        # Walk (and create on demand) the intermediate levels, then set the leaf.
        node = root
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return root

def pretty_print_dict(d,lv=2,indent=0,depth=1):
    """Print a nested dict as an indented tree.

    Args:
        d: the (possibly nested) dict; a non-dict value is printed as-is.
        lv: maximum depth to descend into, or None for no limit.
        indent: current indentation, in units of two spaces.
        depth: current depth (1 for the top level); used for the `lv` cut-off.
    """
    too_deep = lv is not None and depth > lv
    if too_deep:
        return
    if isinstance(d, dict):
        pad = '  ' * indent
        for name, child in d.items():
            if not isinstance(child, dict):
                # Leaf: show the key together with its value.
                print(pad + str(name) + ' -> ' + str(child))
                continue
            # Sub-tree: the key is printed even when the recursion below is cut off by `lv`.
            print(pad + str(name))
            pretty_print_dict(child, lv, indent + 2, depth + 1)
    else:
        print(d)

___

In [7]:
def to_shape(o):
    """Recursively replace every leaf of a (nested) dict by its shape as a list.

    Lists are returned unchanged (they are taken to be shapes already); any
    other non-dict leaf must expose a `.shape` attribute.
    """
    if isinstance(o, list):
        return o
    if isinstance(o, dict):
        return {name: to_shape(child) for name, child in o.items()}
    return list(o.shape)

def remove_bias(o):
    """Drop all entries whose key contains 'bias' and strip '.weight' from the remaining keys.

    Works recursively on nested dicts; non-dict values are returned unchanged.
    """
    if not isinstance(o, dict):
        return o
    cleaned = {}
    for key, value in o.items():
        if 'bias' in key:
            continue
        cleaned[key.replace('.weight', '')] = remove_bias(value)
    return cleaned

Load unet from CNXS

In [8]:
import json

In [9]:
# Read the cnxs UNet state-dict description from disk.
# NOTE(review): judging by the file name, the values are already shapes (lists), not tensors — confirm.
with open('cnsx_base_state_dict_with_shapes.json', 'r') as infile:
    cn_sdict = json.load(infile)

In [10]:
# Normalise to {module_name: shape}: to_shape leaves list values untouched, remove_bias drops
# the bias entries and strips the '.weight' suffix from the remaining keys.
cn_sdict = remove_bias(to_shape(cn_sdict))

Load unet from diffusers

In [11]:
# Same normalisation for the diffusers UNet: to_shape converts each entry's `.shape` to a list,
# remove_bias drops bias entries and strips '.weight' from the keys.
df_sdict = remove_bias(to_shape(pipe.unet.state_dict()))

___


In [12]:
# Keep pristine (shallow) copies: the mapping code below deletes entries from its working
# dicts, so these backups allow starting over without reloading anything.
cn_bak = cn_sdict.copy()
df_bak = df_sdict.copy()

### First goal: Map one resnet

In [348]:
from dataclasses import dataclass

@dataclass
class MappedModule:
    """One established correspondence between a cnxs module and a diffusers module."""
    cn: str      # module name on the cnxs side
    df: str      # module name on the diffusers side
    shape: list  # shape recorded for the pair when the mapping was added

    def __repr__(self):
        # Only the diffusers name and the shape are shown; the cnxs name is omitted.
        return '{} {}'.format(self.df, self.shape)


class UnmappedModel:
    """Collection of modules that have not been mapped yet.

    Wraps (does not copy) a dict of the form
    ``{'down_blocks.0.resnets.0': [4, 64, 64]}``, i.e. module name -> shape.
    `remove` therefore deletes from the very dict that was passed in.
    """
    def __init__(self, modules): self.modules = modules

    def print(self, contains='', lv=None):
        """Pretty-print the modules whose name contains `contains` as a tree, down to depth `lv`."""
        pretty_print_dict(nested_dict(filter_dict(self.modules,by=contains)), lv=lv)

    def remove(self,k):
        """Delete module `k`; raises KeyError if it is not present."""
        del self.modules[k]

    def __getitem__(self,k): return self.modules[k]
    def __len__(self): return len(self.modules)

    def has(self,k):
        """Return True if a module with exactly the name `k` is still unmapped."""
        # Dict membership is a hash lookup; no need to scan every key.
        return k in self.modules

def filter_dict(d,by):
    """Return a new dict with the items of `d` whose key contains the substring `by`."""
    kept = {}
    for key, value in d.items():
        if by in key:
            kept[key] = value
    return kept

class MappedModel:
    """Incrementally built mapping between cnxs and diffusers module names.

    Holds the list of established `MappedModule` pairs (kept sorted by diffusers
    name) plus two `UnmappedModel` wrappers for what is still left on each side.
    """

    def __init__(self, unmapped_cn, unmapped_df):
        """Wrap the two ``{module_name: shape}`` dicts; `add` deletes entries from them."""
        # Created per instance: a class-level `modules = []` is one list object
        # shared by every MappedModel and is easy to corrupt by accident.
        self.modules = []
        self.unmapped_cn=UnmappedModel(unmapped_cn)
        self.unmapped_df=UnmappedModel(unmapped_df)

    def add(self, cn_module, df_module):
        """Record that `cn_module` (cnxs) corresponds to `df_module` (diffusers).

        Raises:
            KeyError: if either name is not (or no longer) among the unmapped modules.
            AssertionError: if the shapes recorded for the two modules differ.
        """
        # check shapes
        cn_shape = self.unmapped_cn[cn_module]
        df_shape = self.unmapped_df[df_module]
        # Raised explicitly instead of via `assert`, so the check is not stripped under `python -O`.
        if not cn_shape==df_shape:
            raise AssertionError(f'Mapping don\'t fit: {cn_shape} != {df_shape}')
        # add to mapped, keeping the list ordered by diffusers name
        self.modules = sorted(self.modules + [MappedModule(cn=cn_module,df=df_module,shape=cn_shape)], key=lambda o:o.df)
        # remove from unmapped
        self.unmapped_cn.remove(cn_module)
        self.unmapped_df.remove(df_module)

    def __repr__(self): return '\n'.join(str(m) for m in self.modules)
    def __len__(self): return len(self.modules)

    def __getitem__(self,k):
        """Return the cnxs name mapped to diffusers name `k`; raises ValueError if there is none."""
        for m in self.modules:
            if m.df==k: return m.cn
        raise ValueError(f"Didn't find  module with name {k}")

___

In [349]:
# Work on fresh copies of the backups: MappedModel.add deletes entries from these dicts.
cn = cn_bak.copy()
df = df_bak.copy()

In [350]:
# Notebook-level mapping object; all `map_*` helpers below read and mutate this global.
mapped = MappedModel(cn,df)

In [351]:
# These wrap the very same `cn` / `df` dicts that were passed to `mapped`, so their
# lengths shrink whenever `mapped.add` removes a module.
cn_unmapped = UnmappedModel(cn)
df_unmapped = UnmappedModel(df)

In [352]:
# Progress check: (mapped, still-unmapped cnxs, still-unmapped diffusers).
len(mapped), len(cn_unmapped), len(df_unmapped)

(0, 1050, 1050)

#### Let's map the downblocks

**Note:** The number of transformer blocks varies between attentions: each attention in down block 1 has 2, while each attention in down block 2 has 10.

In [353]:
# Number of transformer blocks in each of the two attentions of down blocks 1 and 2.
print(len(pipe.unet.down_blocks[1].attentions[0].transformer_blocks))
print(len(pipe.unet.down_blocks[1].attentions[1].transformer_blocks))
print(len(pipe.unet.down_blocks[2].attentions[0].transformer_blocks))
print(len(pipe.unet.down_blocks[2].attentions[1].transformer_blocks))

2
2
10
10


In [354]:
def map_resnet(idx):
    """Map one down-block resnet of cnxs onto its diffusers counterpart.

    Args:
        idx: triple ``(d1, d2, c)`` = diffusers down block index, resnet index
            inside that block, and cnxs `input_blocks` index.

    Operates on the notebook-level `mapped` object.
    """
    block_i, resnet_i, cn_i = idx
    cn_prefix = f'input_blocks.{cn_i}.0.'
    df_prefix = f'down_blocks.{block_i}.resnets.{resnet_i}.'
    # (cnxs sub-module, diffusers sub-module), in the order they are registered
    pairs = (
        ('in_layers.0',  'norm1'),
        ('in_layers.2',  'conv1'),
        ('emb_layers.1', 'time_emb_proj'),
        ('out_layers.0', 'norm2'),
        ('out_layers.3', 'conv2'),
    )
    for cn_name, df_name in pairs:
        mapped.add(cn_module=cn_prefix + cn_name, df_module=df_prefix + df_name)
    # The shortcut conv is mapped only when diffusers still lists one for this resnet.
    shortcut = df_prefix + 'conv_shortcut'
    if mapped.unmapped_df.has(shortcut):
        mapped.add(cn_module=cn_prefix + 'skip_connection', df_module=shortcut)

def map_attn(idx):
    """Map one down-block attention of cnxs onto its diffusers counterpart.

    Args:
        idx: triple ``(d1, d2, c)`` = diffusers down block index, attention
            index inside that block, and cnxs `input_blocks` index.

    Operates on the notebook-level `mapped` object.
    """
    block_i, attn_i, cn_i = idx
    cn_prefix = f'input_blocks.{cn_i}.1.'
    df_prefix = f'down_blocks.{block_i}.attentions.{attn_i}.'
    for name in ('norm', 'proj_in'):
        mapped.add(cn_module=cn_prefix + name, df_module=df_prefix + name)
    # Attentions in down block 2 have 10 transformer blocks each, all others 2.
    if block_i == 2:
        num_tfs = 10
    else:
        num_tfs = 2
    for tf_i in range(num_tfs):
        tf_suffix = f'transformer_blocks.{tf_i}.'
        map_tfmr(cn_prefix + tf_suffix, df_prefix + tf_suffix)
    mapped.add(cn_module=cn_prefix + 'proj_out', df_module=df_prefix + 'proj_out')

def map_tfmr(cn_path,df_path):
    """Map every sub-module of one transformer block.

    The sub-module names are identical in cnxs and diffusers, so the same suffix
    is appended to both path prefixes (each prefix is expected to end with '.').
    """
    attn_parts = ('to_q', 'to_k', 'to_v', 'to_out.0')
    suffixes = ['norm1']
    suffixes += [f'attn1.{part}' for part in attn_parts]
    suffixes += ['norm2']
    suffixes += [f'attn2.{part}' for part in attn_parts]
    suffixes += ['norm3', 'ff.net.0.proj', 'ff.net.2']
    for suffix in suffixes:
        mapped.add(cn_path + suffix, df_path + suffix)

def map_downsample(idx):
    """Map one downsampler conv.

    Args:
        idx: pair ``(d, c)`` = diffusers down block index and cnxs `input_blocks` index.
    """
    block_i, cn_i = idx
    mapped.add(
        cn_module=f'input_blocks.{cn_i}.0.op',
        df_module=f'down_blocks.{block_i}.downsamplers.0.conv',
    )
    
def map_down_block():
    """Map all resnets, attentions and downsamplers of the down path and report the counts."""
    # resnets: (diffusers down block, resnet index, cnxs input block)
    resnet_idx = (
        (0,0,1),(0,1,2), # down block 0
        (1,0,4),(1,1,5), # down block 1
        (2,0,7),(2,1,8), # down block 2
    )
    for triple in resnet_idx:
        map_resnet(triple)
    print(f'Mapped {len(resnet_idx)} Resnet blocks')

    # attentions: (diffusers down block, attention index, cnxs input block)
    attn_idx = (
        (1,0,4),(1,1,5), # down block 1
        (2,0,7),(2,1,8), # down block 2
    )
    for triple in attn_idx:
        map_attn(triple)
    print(f'Mapped {len(attn_idx)} Attention blocks')

    # downsamplers: (diffusers down block, cnxs input block)
    down_idx = (
        (0,3), # down block 0
        (1,6), # down block 1
    )
    for pair in down_idx:
        map_downsample(pair)
    print(f'Mapped {len(down_idx)} Downsamplers blocks')

In [355]:
# Not idempotent: a second call raises KeyError, because `mapped.add` has already
# removed these modules from the unmapped dicts.
map_down_block()

Mapped 6 Resnet blocks
Mapped 4 Attention blocks
Mapped 2 Downsamplers blocks


In [356]:
# Progress check after the down blocks: (mapped, unmapped cnxs, unmapped diffusers).
len(mapped), len(cn_unmapped), len(df_unmapped)

(358, 692, 692)

#### Let's map the embeds, and in/outs

In [357]:
def map_conv_in():
    """Map the input convolution (cnxs `input_blocks.0.0` <-> diffusers `conv_in`)."""
    mapped.add('input_blocks.0.0', 'conv_in')

def map_conv_out():
    """Map the output head: cnxs `out.0` / `out.2` <-> diffusers `conv_norm_out` / `conv_out`."""
    out_pairs = (
        ('out.0', 'conv_norm_out'),
        ('out.2', 'conv_out'),
    )
    for cn_name, df_name in out_pairs:
        mapped.add(cn_module=cn_name, df_module=df_name)

def map_embedds():
    """Map the embedding layers: cnxs `time_embed` / `label_emb.0` <-> diffusers `time_embedding` / `add_embedding`."""
    emb_pairs = (
        ('time_embed', 'time_embedding'),
        ('label_emb.0', 'add_embedding'),
    )
    for cn_prefix, df_prefix in emb_pairs:
        # cnxs sub-modules 0 and 2 correspond to diffusers linear_1 and linear_2
        mapped.add(cn_module=f'{cn_prefix}.0', df_module=f'{df_prefix}.linear_1')
        mapped.add(cn_module=f'{cn_prefix}.2', df_module=f'{df_prefix}.linear_2')

In [358]:
# Not idempotent: each call removes its modules from the unmapped dicts, so
# re-running this cell raises KeyError.
map_conv_in()
map_conv_out()
map_embedds()

In [359]:
# Progress check after in/out convs and embeddings: (mapped, unmapped cnxs, unmapped diffusers).
len(mapped), len(cn_unmapped), len(df_unmapped)

(365, 685, 685)

#### Let's map the middle block

In [360]:
def map_resnet(d,c):
    """Map one mid-block resnet (diffusers `mid_block.resnets.{d}` <-> cnxs `middle_block.{c}`).

    NOTE: this rebinds the notebook-level name `map_resnet`, replacing the
    down-block version that takes a single `idx` tuple. Calling
    `map_down_block()` after this cell has run would therefore fail.
    """
    cn_prefix = f'middle_block.{c}.'
    df_prefix = f'mid_block.resnets.{d}.'
    # (cnxs sub-module, diffusers sub-module), in the order they are registered
    pairs = (
        ('in_layers.0',  'norm1'),
        ('in_layers.2',  'conv1'),
        ('emb_layers.1', 'time_emb_proj'),
        ('out_layers.0', 'norm2'),
        ('out_layers.3', 'conv2'),
    )
    for cn_name, df_name in pairs:
        mapped.add(cn_module=cn_prefix + cn_name, df_module=df_prefix + df_name)
    # The shortcut conv is mapped only when diffusers still lists one for this resnet.
    shortcut = df_prefix + 'conv_shortcut'
    if mapped.unmapped_df.has(shortcut):
        mapped.add(cn_module=cn_prefix + 'skip_connection', df_module=shortcut)

def map_attn(d,c):
    """Map the mid-block attention (diffusers `mid_block.attentions.{d}` <-> cnxs `middle_block.{c}`).

    NOTE: this rebinds the notebook-level name `map_attn`, replacing the
    down-block version that takes a single `idx` tuple.
    """
    cn_prefix = f'middle_block.{c}.'
    df_prefix = f'mid_block.attentions.{d}.'
    for name in ('norm', 'proj_in'):
        mapped.add(cn_module=cn_prefix + name, df_module=df_prefix + name)
    num_tfs = 10 # transformer blocks mapped for the mid-block attention
    for tf_i in range(num_tfs):
        tf_suffix = f'transformer_blocks.{tf_i}.'
        map_tfmr(cn_prefix + tf_suffix, df_prefix + tf_suffix)
    mapped.add(cn_module=cn_prefix + 'proj_out', df_module=df_prefix + 'proj_out')

def map_tfmr(cn_path,df_path):
    """Map every sub-module of one transformer block.

    The sub-module names are identical in cnxs and diffusers, so the same suffix
    is appended to both path prefixes (each prefix is expected to end with '.').
    This repeats the `map_tfmr` already defined for the down blocks, with the same behaviour.
    """
    attn_parts = ('to_q', 'to_k', 'to_v', 'to_out.0')
    suffixes = ['norm1']
    suffixes += [f'attn1.{part}' for part in attn_parts]
    suffixes += ['norm2']
    suffixes += [f'attn2.{part}' for part in attn_parts]
    suffixes += ['norm3', 'ff.net.0.proj', 'ff.net.2']
    for suffix in suffixes:
        mapped.add(cn_path + suffix, df_path + suffix)

def map_mid_block():
    """Map the middle block, which is laid out as resnet / attention / resnet."""
    # (helper, diffusers index within its module list, cnxs index within middle_block)
    steps = (
        (map_resnet, 0, 0),
        (map_attn,   0, 1),
        (map_resnet, 1, 2),
    )
    for map_fn, df_i, cn_i in steps:
        map_fn(df_i, cn_i)
    print(f'Mapped 2 Restnets and 1 Attention blocks (R/A/R)')

In [361]:
# Not idempotent: a second call raises KeyError, because the modules have
# already been removed from the unmapped dicts.
map_mid_block()

Mapped 2 Restnets and 1 Attention blocks (R/A/R)


In [362]:
# Progress check after the middle block: (mapped, unmapped cnxs, unmapped diffusers).
len(mapped), len(cn_unmapped), len(df_unmapped)

(508, 542, 542)

#### Let's map the up blocks

In [363]:
# Compare the number of resnets per block in the down path and the up path.
n_res_in_down = [len(pipe.unet.down_blocks[i].resnets) for i in (0,1,2)]
n_res_in_up   = [len(pipe.unet.up_blocks  [i].resnets) for i in (0,1,2)]
print(n_res_in_down, '|', n_res_in_up)

[2, 2, 2] | [3, 3, 3]


In the up part, we have **3** resnets per block instead of 2 as in the down part.

In [364]:
# Compare the number of attentions per block: down blocks 1 and 2 versus up blocks 0 and 1
# (the only blocks queried here).
n_attn_in_down = [len(pipe.unet.down_blocks[i].attentions) for i in (1,2)]
n_attn_in_up   = [len(pipe.unet.up_blocks  [i].attentions) for i in (0,1)]
print(n_attn_in_down, '|', n_attn_in_up)

[2, 2] | [3, 3]


Similarly, the number of attns is 3 instead of 2.

In [365]:
# Number of transformer blocks in each of the three attentions of up blocks 0 and 1.
# The attention index must be the comprehension's own loop variable `i`: indexing with
# an unrelated name only works if that name happens to linger in the kernel, and then
# reports the same attention three times.
n_tf_in_up0 = [len(pipe.unet.up_blocks[0].attentions[i].transformer_blocks) for i in (0,1,2)]
n_tf_in_up1 = [len(pipe.unet.up_blocks[1].attentions[i].transformer_blocks) for i in (0,1,2)]
print(n_tf_in_up0, '|', n_tf_in_up1)

[10, 10, 10] | [2, 2, 2]


As in the down part, the lowest block (here: block 0) has 10 transformers per attention, while the other block has 2 transformers per attention.

In [366]:
def map_resnet(idx):
    """Map one up-block resnet of cnxs onto its diffusers counterpart.

    Args:
        idx: triple ``(d1, d2, c)`` = diffusers up block index, resnet index
            inside that block, and cnxs `output_blocks` index.

    NOTE: this rebinds the notebook-level name `map_resnet`, replacing the
    mid-block version defined earlier.
    """
    block_i, resnet_i, cn_i = idx
    cn_prefix = f'output_blocks.{cn_i}.0.'
    df_prefix = f'up_blocks.{block_i}.resnets.{resnet_i}.'
    # (cnxs sub-module, diffusers sub-module), in the order they are registered
    pairs = (
        ('in_layers.0',  'norm1'),
        ('in_layers.2',  'conv1'),
        ('emb_layers.1', 'time_emb_proj'),
        ('out_layers.0', 'norm2'),
        ('out_layers.3', 'conv2'),
    )
    for cn_name, df_name in pairs:
        mapped.add(cn_module=cn_prefix + cn_name, df_module=df_prefix + df_name)
    # The shortcut conv is mapped only when diffusers still lists one for this resnet.
    shortcut = df_prefix + 'conv_shortcut'
    if mapped.unmapped_df.has(shortcut):
        mapped.add(cn_module=cn_prefix + 'skip_connection', df_module=shortcut)

def map_attn(idx):
    """Map one up-block attention of cnxs onto its diffusers counterpart.

    Args:
        idx: triple ``(d1, d2, c)`` = diffusers up block index, attention
            index inside that block, and cnxs `output_blocks` index.

    NOTE: this rebinds the notebook-level name `map_attn`, replacing the
    mid-block version defined earlier.
    """
    block_i, attn_i, cn_i = idx
    cn_prefix = f'output_blocks.{cn_i}.1.'
    df_prefix = f'up_blocks.{block_i}.attentions.{attn_i}.'
    for name in ('norm', 'proj_in'):
        mapped.add(cn_module=cn_prefix + name, df_module=df_prefix + name)
    # Attentions in up block 0 have 10 transformer blocks each, all others 2.
    if block_i == 0:
        num_tfs = 10
    else:
        num_tfs = 2
    for tf_i in range(num_tfs):
        tf_suffix = f'transformer_blocks.{tf_i}.'
        map_tfmr(cn_prefix + tf_suffix, df_prefix + tf_suffix)
    mapped.add(cn_module=cn_prefix + 'proj_out', df_module=df_prefix + 'proj_out')

def map_tfmr(cn_path,df_path):
    """Map every sub-module of one transformer block.

    The sub-module names are identical in cnxs and diffusers, so the same suffix
    is appended to both path prefixes (each prefix is expected to end with '.').
    This repeats the `map_tfmr` defined in the earlier cells, with the same behaviour.
    """
    attn_parts = ('to_q', 'to_k', 'to_v', 'to_out.0')
    suffixes = ['norm1']
    suffixes += [f'attn1.{part}' for part in attn_parts]
    suffixes += ['norm2']
    suffixes += [f'attn2.{part}' for part in attn_parts]
    suffixes += ['norm3', 'ff.net.0.proj', 'ff.net.2']
    for suffix in suffixes:
        mapped.add(cn_path + suffix, df_path + suffix)

def map_upsample(idx):
    """Map one upsampler conv.

    Args:
        idx: pair ``(d, c)`` = diffusers up block index and cnxs `output_blocks` index.
    """
    block_i, cn_i = idx
    mapped.add(
        cn_module=f'output_blocks.{cn_i}.2.conv',
        df_module=f'up_blocks.{block_i}.upsamplers.0.conv',
    )

def map_up_block():
    """Map all resnets, attentions and upsamplers of the up path and report the counts."""
    # resnets: (diffusers up block, resnet index, cnxs output block).
    # Three resnets per up block, numbered consecutively on the cnxs side:
    # (0,0,0),(0,1,1),(0,2,2), (1,0,3),(1,1,4),(1,2,5), (2,0,6),(2,1,7),(2,2,8)
    resnet_idx = tuple((block, pos, 3 * block + pos) for block in (0, 1, 2) for pos in (0, 1, 2))
    for triple in resnet_idx:
        map_resnet(triple)
    print(f'Mapped {len(resnet_idx)} Resnet blocks')

    # attentions: same layout, but only up blocks 0 and 1 are mapped:
    # (0,0,0),(0,1,1),(0,2,2), (1,0,3),(1,1,4),(1,2,5)
    attn_idx = tuple((block, pos, 3 * block + pos) for block in (0, 1) for pos in (0, 1, 2))
    for triple in attn_idx:
        map_attn(triple)
    print(f'Mapped {len(attn_idx)} Attention blocks')

    # upsamplers: (diffusers up block, cnxs output block)
    up_idx = (
        (0,2), # up block 0
        (1,5), # up block 1
    )
    for pair in up_idx:
        map_upsample(pair)
    print(f'Mapped {len(up_idx)} Upsamplers blocks')

In [367]:
# Not idempotent: a second call raises KeyError, because the modules have
# already been removed from the unmapped dicts.
map_up_block()

Mapped 9 Resnet blocks
Mapped 6 Attention blocks
Mapped 2 Upsamplers blocks


In [368]:
# Final progress check: everything should be mapped and both unmapped sides empty.
len(mapped), len(cn_unmapped), len(df_unmapped)

(1050, 0, 0)

**We are done!!**

In [369]:
# Guard for "Restart & Run All": fail loudly if any module on either side was left unmapped.
assert len(df_unmapped)==0
assert len(cn_unmapped)==0

___

Now, let's create the final mapping dictionary

Let's quickly check that all weights are actually mapped (if not, the below line will throw a ValueError).

In [370]:
# Look up every diffusers parameter (with its '.weight' / '.bias' suffix stripped) in the
# mapping; MappedModel.__getitem__ raises ValueError for any name that was not mapped.
for k in pipe.unet.state_dict().keys(): mapped[k.replace('.weight','').replace('.bias','')]

Works, cool.

In [371]:
# Build the final {diffusers parameter name: cnxs parameter name} mapping.
# Each state-dict key is translated on its own. Writing both a '.weight' and a '.bias'
# entry for every key would invent '.bias' entries for parameters that only have a
# weight, i.e. keys that exist in neither state dict.
state_dict_mapping = {}
for k in pipe.unet.state_dict().keys():
    # Same suffix stripping as was used to build the module names in `mapped`.
    k_df = k.replace('.weight','').replace('.bias','')
    k_cn = mapped[k_df]
    # Keys are assumed to end in '.weight' or '.bias', as in the lookup above.
    suffix = '.bias' if k.endswith('.bias') else '.weight'
    state_dict_mapping[k_df+suffix] = k_cn+suffix

In [372]:
# Display the mapping (diffusers parameter name -> cnxs parameter name).
state_dict_mapping

{'conv_in.weight': 'input_blocks.0.0.weight',
 'conv_in.bias': 'input_blocks.0.0.bias',
 'time_embedding.linear_1.weight': 'time_embed.0.weight',
 'time_embedding.linear_1.bias': 'time_embed.0.bias',
 'time_embedding.linear_2.weight': 'time_embed.2.weight',
 'time_embedding.linear_2.bias': 'time_embed.2.bias',
 'add_embedding.linear_1.weight': 'label_emb.0.0.weight',
 'add_embedding.linear_1.bias': 'label_emb.0.0.bias',
 'add_embedding.linear_2.weight': 'label_emb.0.2.weight',
 'add_embedding.linear_2.bias': 'label_emb.0.2.bias',
 'down_blocks.0.resnets.0.norm1.weight': 'input_blocks.1.0.in_layers.0.weight',
 'down_blocks.0.resnets.0.norm1.bias': 'input_blocks.1.0.in_layers.0.bias',
 'down_blocks.0.resnets.0.conv1.weight': 'input_blocks.1.0.in_layers.2.weight',
 'down_blocks.0.resnets.0.conv1.bias': 'input_blocks.1.0.in_layers.2.bias',
 'down_blocks.0.resnets.0.time_emb_proj.weight': 'input_blocks.1.0.emb_layers.1.weight',
 'down_blocks.0.resnets.0.time_emb_proj.bias': 'input_blocks.1.

In [373]:
import pickle

# Persist the mapping so it can be reused when loading diffusers weights into cnxs.
# NOTE(review): the mapping is a plain str -> str dict, so JSON would work too and avoids
# having to unpickle a file later — consider switching if the file is shared.
with open('state_dict_mapping.pkl', 'wb') as f:
    pickle.dump(state_dict_mapping, f)

Now I should be able to load SDXL weights into the CNXS version of SDXL!