In this notebook, I want to map the modules of the cnxs UNet and the diffusers UNet onto each other. When done, I should be able to load diffusers weights into cnxs.

**This is a (partial) copy of `map layer names.ipynb`, so I can compare things side by side**

___

In [1]:
#!pip install -Uqq transformers diffusers 

In [2]:
import torch

In [3]:
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline

In [4]:
# Identifier of the SDXL base pipeline, passed to `from_pretrained` further down.
model = "stabilityai/stable-diffusion-xl-base-1.0"
# Load the VAE separately, in float16, so it can be handed to the pipeline.
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float16)

In [5]:
# Build the SDXL pipeline in float16, passing in the separately loaded `vae`.
# Only `pipe.unet.state_dict()` is used later in this notebook.
pipe = StableDiffusionXLPipeline.from_pretrained(model, vae=vae, torch_dtype=torch.float16)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

___

In [6]:
def nested_dict(d, ignore_bias=True):
    """Turn a flat dict with dot-separated keys into a nested dict.

    Each key is split on '.'; every part but the last becomes a nested
    dict level and the last part holds the value.

    Args:
        d: Flat mapping from dotted names (e.g. 'a.b.weight') to values.
        ignore_bias: If True (default), skip every key containing 'bias'
            and strip '.weight' from the remaining keys, so a module is
            represented by a single leaf. If False, keep bias entries and
            keep the '.weight' suffix, so 'weight' and 'bias' end up as
            sibling leaves under their module instead of colliding.

    Returns:
        A new nested dict; `d` itself is not modified.
    """
    root = {}
    for key, value in d.items():
        if ignore_bias:
            if 'bias' in key: continue
            key = key.replace('.weight', '')
        # Walk (and create on demand) the intermediate levels, then set the leaf.
        *parents, leaf = key.split('.')
        node = root
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return root

def pretty_print_dict(d, lv=2, indent=0, depth=1):
    """Print a nested dict as an indented tree.

    Args:
        d: The (possibly nested) dict to print. A non-dict is printed as is.
        lv: Maximum nesting depth to print; None means no limit.
        indent: Current indentation, in units of two spaces. Each nesting
            level adds 2 units, i.e. four spaces.
        depth: Current nesting depth (1 for the top level).
    """
    beyond_limit = lv is not None and depth > lv
    if beyond_limit:
        return
    if not isinstance(d, dict):
        print(d)
        return
    pad = '  ' * indent
    for name, child in d.items():
        if not isinstance(child, dict):
            # Leaf: show the key together with its value.
            print(f'{pad}{name!s} -> {child!s}')
            continue
        # Sub-tree: print the key, then descend one level deeper.
        print(f'{pad}{name!s}')
        pretty_print_dict(child, lv, indent + 2, depth + 1)

___

In [7]:
def to_shape(o):
    """Recursively replace every leaf of `o` by its shape as a list.

    Dicts are rebuilt with converted values, lists are returned unchanged
    (treated as shapes already), and anything else must expose `.shape`.
    """
    if isinstance(o, list):
        return o
    if isinstance(o, dict):
        return {name: to_shape(child) for name, child in o.items()}
    return list(o.shape)

def remove_bias(o):
    """Drop bias entries from a (nested) dict and strip '.weight' from keys.

    Keys containing 'bias' are skipped; '.weight' is removed from the
    remaining keys and their values are processed recursively. Non-dict
    inputs are returned unchanged.
    """
    if not isinstance(o, dict):
        return o
    cleaned = {}
    for name, child in o.items():
        if 'bias' in name:
            continue
        cleaned[name.replace('.weight', '')] = remove_bias(child)
    return cleaned

Load unet from CNXS

In [8]:
import json

In [9]:
# Read the cnxs state dict from a JSON export.
# NOTE(review): judging by the file name the values are already shapes — confirm.
with open('cnsx_base_state_dict_with_shapes.json', 'r') as infile:
    cn_sdict = json.load(infile)

In [10]:
# Normalise: keep list values as shapes, then drop bias entries and strip '.weight' from the keys.
cn_sdict = remove_bias(to_shape(cn_sdict))

Load unet from diffusers

In [11]:
# Same normalisation for the diffusers UNet: each entry -> list(shape), drop biases, strip '.weight'.
df_sdict = remove_bias(to_shape(pipe.unet.state_dict()))

___


### First goal: Map one resnet

In [12]:
# Short aliases for the two normalised shape dicts (same objects, not copies).
cn = cn_sdict
df = df_sdict

# Shallow copies kept as backups, so entries removed from `cn` / `df` later can be recovered.
cn_bak = cn.copy()
df_bak = df.copy()

In [13]:
from dataclasses import dataclass

@dataclass
class MappedModule:
    """One established correspondence between a cnxs and a diffusers module."""

    # Module name on the cnxs side.
    cn: str
    # Module name on the diffusers side.
    df: str
    # Shape shared by both modules.
    shape: list

    def __repr__(self):
        # Only the diffusers name and the shape are shown.
        return '{} {}'.format(self.df, self.shape)


class UnmappedModel:
    """Browsable view over the modules that have not been mapped yet.

    `modules` is a flat dict keyed by dotted module names. The dict is
    stored by reference, not copied, so removals are visible to every
    other holder of the same dict.
    """

    def __init__(self, modules): self.modules = modules

    def print(self, contains='', lv=None):
        """Pretty-print the modules whose name contains `contains`, nested up to `lv` levels (None = all)."""
        flat = filter_dict(self.modules, by=contains)
        pretty_print_dict(nested_dict(flat), lv=lv)

    def remove(self, k):
        """Remove module `k` from the unmapped set; raises KeyError if it is not present."""
        # `modules` is a dict, so entries are deleted by key.
        del self.modules[k]

def filter_dict(d, by):
    """Return a new dict with the entries of `d` whose key contains `by`."""
    selected = {}
    for name, value in d.items():
        if by in name:
            selected[name] = value
    return selected

class MappedModel:
    """Collects cnxs <-> diffusers module pairs that have been matched.

    `unmapped_cn` / `unmapped_df` hold the modules still waiting to be
    mapped. Each may be a plain dict (module name -> shape) or an object
    exposing such a dict as `.modules`. They are stored by reference, so
    `add` removes entries from the caller's data.
    """

    def __init__(self, unmapped_cn, unmapped_df):
        self.unmapped_cn,self.unmapped_df=unmapped_cn,unmapped_df
        # Per-instance list of MappedModule entries, kept sorted by diffusers name.
        self.modules = []

    @staticmethod
    def _entries(unmapped):
        """Return the underlying name -> shape dict of `unmapped`."""
        return getattr(unmapped, 'modules', unmapped)

    def add(self, cn_module, df_module):
        """Record that `cn_module` (cnxs) corresponds to `df_module` (diffusers).

        Raises:
            KeyError: if either name is not among the unmapped modules.
            AssertionError: if the two modules have different shapes.
        """
        cn_entries = self._entries(self.unmapped_cn)
        df_entries = self._entries(self.unmapped_df)
        # check shapes (each name is looked up on its own side)
        cn_shape = cn_entries[cn_module]
        df_shape = df_entries[df_module]
        assert cn_shape==df_shape, f'Mapping don\'t fit: {cn_shape} != {df_shape}'
        # add to mapped
        self.modules = sorted(self.modules + [MappedModule(cn=cn_module,df=df_module,shape=cn_shape)], key=lambda o:o.df)
        # remove from unmapped, only after validation so a failed add changes nothing
        del cn_entries[cn_module]
        del df_entries[df_module]

    def __repr__(self): return '\n'.join(str(m) for m in self.modules)

In [14]:
# Tracker for established mappings; it receives the raw shape dicts `cn` and `df` themselves.
mapped = MappedModel(cn,df)

In [15]:
# Browsable wrappers around the same `cn` / `df` dict objects (no copies are made).
cn_unmapped = UnmappedModel(cn)
df_unmapped = UnmappedModel(df)

In [34]:
# List the diffusers modules whose name contains 'conv_in'.
df_unmapped.print('conv_in', lv=5)

conv_in -> [320, 4, 3, 3]


In [35]:
# List the cnxs modules whose name contains 'input_blocks.0', to compare shapes with the diffusers side.
cn_unmapped.print('input_blocks.0', lv=10)

input_blocks
    0
        0 -> [320, 4, 3, 3]


In [74]:
# Look up the shape of one specific diffusers entry: names containing 'up_blocks.1.resnets.2.norm1'.
df_unmapped.print('up_blocks.1.resnets.2.norm1', lv=5)

up_blocks
    1
        resnets
            2
                norm1 -> [960]


In [69]:
# Show the cnxs modules whose name contains 'output_blocks.0'; anything nested deeper than 5 levels is not printed.
cn_unmapped.print('output_blocks.0', lv=5)

output_blocks
    0
        0
            in_layers
                0 -> [2560]
                2 -> [1280, 2560, 3, 3]
            emb_layers
                1 -> [1280, 1280]
            out_layers
                0 -> [1280]
                3 -> [1280, 1280, 3, 3]
            skip_connection -> [1280, 2560, 1, 1]
        1
            norm -> [1280]
            proj_in -> [1280, 1280]
            transformer_blocks
                0
                1
                2
                3
                4
                5
                6
                7
                8
                9
            proj_out -> [1280, 1280]


In [68]:
# Show the cnxs modules whose name contains 'input_blocks.1', printed up to 5 levels deep.
cn_unmapped.print('input_blocks.1', lv=5)

input_blocks
    1
        0
            in_layers
                0 -> [320]
                2 -> [320, 320, 3, 3]
            emb_layers
                1 -> [320, 1280]
            out_layers
                0 -> [320]
                3 -> [320, 320, 3, 3]
