In this notebook, I want to map the modules of cnxs' unet and diffusers' onto each other. When done, I should be able to load diffuser weights into cnxs.

___

In [1]:
#!pip install -Uqq transformers diffusers 

In [2]:
import torch

In [3]:
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline

In [4]:
model = "stabilityai/stable-diffusion-xl-base-1.0"
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float16)

In [5]:
pipe = StableDiffusionXLPipeline.from_pretrained(model, vae=vae, torch_dtype=torch.float16)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

___

In [6]:
def nested_dict(d, ignore_bias=True):
    root = {}
    for k,v in d.items():
        if 'bias' in k: continue
        parts = k.replace('.weight','').split(".")
        d = root
        for part in parts[:-1]:
            d = d.setdefault(part, {})
        d[parts[-1]] = v 
    return root

def pretty_print_dict(d,lv=2,indent=0,depth=1):
    if lv is not None and depth > lv: return
    if not isinstance(d,dict):
        print(d)
        return
    for k,v in d.items():
        if isinstance(v, dict):
            print('  ' * indent + str(k))
            pretty_print_dict(v,lv,indent+2,depth+1)
        else: 
            print('  ' * indent + str(k) + ' -> ' + str(v))

___

In [7]:
def to_shape(o):
    if isinstance(o,dict): return {k:to_shape(v) for k,v in o.items()}
    elif isinstance(o,list): return o
    else: return list(o.shape)

def remove_bias(o):
    if isinstance(o,dict): return {k.replace('.weight',''):remove_bias(v) for k,v in o.items() if not 'bias' in k}
    else: return o

Load unet from CNXS

In [8]:
import json

In [9]:
with open('cnsx_base_state_dict_with_shapes.json', 'r') as infile:
    cn_sdict = json.load(infile)

In [10]:
cn_sdict = remove_bias(to_shape(cn_sdict))

Load unet from diffusers

In [11]:
df_sdict = remove_bias(to_shape(pipe.unet.state_dict()))

___


In [104]:
cn_bak = cn_sdict.copy()
df_bak = df_sdict.copy()

### First goal: Map one resnet

In [120]:
from dataclasses import dataclass

@dataclass
class MappedModule:
    cn: str
    df: str
    shape: list
    def __repr__(self): return f'{self.df} {self.shape}'


class UnmappedModel:
    # contains dict of style {'down_blocks.0.resnets.0': [4, 64, 64]}
    def __init__(self, modules): self.modules = modules
    def print(self, contains='', lv=None): pretty_print_dict(nested_dict(filter_dict(self.modules,by=contains)), lv=lv)
    def remove(self,k): del self.modules[k]
    def __getitem__(self,k): return self.modules[k]
    def __len__(self): return len(self.modules)
    def has(self,k): return any(k==k_ for k_ in self.modules.keys())

def filter_dict(d,by): return {k:v for k,v in d.items() if by in k}

class MappedModel:
    modules = []

    def __init__(self, unmapped_cn, unmapped_df):
        self.unmapped_cn=UnmappedModel(unmapped_cn)
        self.unmapped_df=UnmappedModel(unmapped_df)
    
    def add(self, cn_module, df_module):
        # check shapes
        cn_shape = self.unmapped_cn[cn_module]
        df_shape = self.unmapped_df[df_module]
        assert cn_shape==df_shape, f'Mapping don\'t fit: {cn_shape} != {df_shape}'
        # add to mapped
        self.modules = sorted(self.modules + [MappedModule(cn=cn_module,df=df_module,shape=cn_shape)], key=lambda o:o.df)
        # remove from unmapped
        self.unmapped_cn.remove(cn_module)
        self.unmapped_df.remove(df_module)
        print(f'Added {df_module} ✅')
        
    def __repr__(self): return '\n'.join(str(m) for m in self.modules)
    def __len__(self): return len(self.modules)

___

In [121]:
cn = cn_bak.copy()
df = df_bak.copy()

In [122]:
mapped = MappedModel(cn,df)

In [123]:
cn_unmapped = UnmappedModel(cn)
df_unmapped = UnmappedModel(df)

In [124]:
cn_unmapped.print('input_blocks.1.0', lv=5)

input_blocks
    1
        0
            in_layers
                0 -> [320]
                2 -> [320, 320, 3, 3]
            emb_layers
                1 -> [320, 1280]
            out_layers
                0 -> [320]
                3 -> [320, 320, 3, 3]


In [125]:
len(mapped), len(cn_unmapped), len(df_unmapped)

(0, 1050, 1050)

In [126]:
def map_resnet(idx):
    d1,d2,c=idx
    cn_path = f'input_blocks.{c}.0.'
    df_path = f'down_blocks.{d1}.resnets.{d2}.'
    mapped.add(cn_module=cn_path+'in_layers.0',df_module=df_path+'norm1')
    mapped.add(cn_module=cn_path+'in_layers.2',df_module=df_path+'conv1')
    mapped.add(cn_module=cn_path+'emb_layers.1',df_module=df_path+'time_emb_proj')
    mapped.add(cn_module=cn_path+'out_layers.0',df_module=df_path+'norm2')
    mapped.add(cn_module=cn_path+'out_layers.3',df_module=df_path+'conv2')
    if mapped.unmapped_df.has(df_path+'conv_shortcut'): mapped.add(cn_module=cn_path+'skip_connection',df_module=df_path+'conv_shortcut')

def map_downsample(idx):
    d1,d2,c=idx
    cn_path = f'input_blocks.{c}.0.'
    df_path = f'down_blocks.{d1}.downsamplers.{d2}.'
    mapped.add(cn_module=cn_path+'op',df_module=df_path+'conv')
    
def map_attn(idx):
    d1,d2,c=idx
    cn_path = f'input_blocks.{c}.1.'
    df_path = f'down_blocks.{d1}.attentions.{d2}.'
    mapped.add(cn_module=cn_path+'norm',df_module=df_path+'norm')
    mapped.add(cn_module=cn_path+'proj_in',df_module=df_path+'proj_in')
    # todo tf blocks
    mapped.add(cn_module=cn_path+'proj_out',df_module=df_path+'proj_out')
    

def map_down_block():
    # # resnets
    idx = (
        (0,0,1),
        (0,1,2),
        (1,0,4),
        (1,1,5),
    )
    for i in idx: map_resnet(i)
    # # attentions
    idx = (
        (1,0,4),
        (1,1,5),
    )
    for i in idx: map_resnet(i)
    # # downsamplers

In [127]:
map_down_block()

Added down_blocks.0.resnets.0.norm1 ✅
Added down_blocks.0.resnets.0.conv1 ✅
Added down_blocks.0.resnets.0.time_emb_proj ✅
Added down_blocks.0.resnets.0.norm2 ✅
Added down_blocks.0.resnets.0.conv2 ✅
Added down_blocks.0.resnets.1.norm1 ✅
Added down_blocks.0.resnets.1.conv1 ✅
Added down_blocks.0.resnets.1.time_emb_proj ✅
Added down_blocks.0.resnets.1.norm2 ✅
Added down_blocks.0.resnets.1.conv2 ✅
Added down_blocks.1.resnets.0.norm1 ✅
Added down_blocks.1.resnets.0.conv1 ✅
Added down_blocks.1.resnets.0.time_emb_proj ✅
Added down_blocks.1.resnets.0.norm2 ✅
Added down_blocks.1.resnets.0.conv2 ✅
Added down_blocks.1.resnets.0.conv_shortcut ✅
Added down_blocks.1.resnets.1.norm1 ✅
Added down_blocks.1.resnets.1.conv1 ✅
Added down_blocks.1.resnets.1.time_emb_proj ✅
Added down_blocks.1.resnets.1.norm2 ✅
Added down_blocks.1.resnets.1.conv2 ✅


KeyError: 'input_blocks.4.0.in_layers.0'

In [77]:
len(mapped), len(cn_unmapped), len(df_unmapped)

(15, 1035, 1035)

In [None]:
mapped.d