In [1]:
import json
import os

import peft
from diffusers import PixArtTransformer2DModel
from peft import LoraConfig, LNTuningConfig
from safetensors.torch import save_file

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the PixArt transformer config (a JSON object) from disk.
# The context manager closes the file handle as soon as parsing finishes,
# rather than leaving an anonymous open() for the garbage collector.
with open(os.path.join('/mnt/ceph_rbd/zbc/pixart', "config.json"), "r") as config_file:
    pixart_dit_config = json.load(config_file)

In [3]:
# Build the transformer from the parsed config, passing each config entry as a
# keyword argument. No checkpoint weights are loaded in this cell.
model = PixArtTransformer2DModel(**pixart_dit_config)

In [4]:
# LoRA adapter settings: rank 32 with lora_alpha equal to the rank, applied only
# to the query and value projections of the `attn1` modules.
lora_config = LoraConfig(
    target_modules=['attn1.to_q', 'attn1.to_v'],
    r=32,
    lora_alpha=32,
)

In [5]:
# Wrap the base transformer in a PeftMixedModel using the LoRA config above.
# NOTE(review): PeftMixedModel only accepts a fixed set of adapter types (see the
# error raised further down when an LN-tuning adapter is added) -- confirm this is
# the wrapper you want before building on it.
peft_model = peft.PeftMixedModel(model, lora_config)

In [6]:
# Print trainable vs. total parameter counts for the wrapped model.
peft_model.print_trainable_parameters()

trainable params: 4,128,768 || all params: 614,984,864 || trainable%: 0.6714


In [7]:
# LN-tuning adapter config whose target module pattern is 'attn2'.
xattn_config = LNTuningConfig(target_modules=['attn2'])

In [8]:
# Try to register the LN-tuning config as a second adapter named 'xattn'.
# PeftMixedModel rejects this adapter type with a ValueError; catch it and print
# the message so the incompatibility stays documented without aborting a
# Restart & Run All of the remaining cells.
try:
    peft_model.add_adapter('xattn', xattn_config)
except ValueError as err:
    print(f"{type(err).__name__}: {err}")

ValueError: The provided `peft_type` 'LN_TUNING' is not compatible with the `PeftMixedModel`. Compatible types are: (<PeftType.LORA: 'LORA'>, <PeftType.LOHA: 'LOHA'>, <PeftType.LOKR: 'LOKR'>, <PeftType.ADALORA: 'ADALORA'>, <PeftType.OFT: 'OFT'>)

In [9]:
import torch

In [10]:
import safetensors

In [11]:
# Materialise every tensor of the step-60000 training checkpoint into a plain
# dict keyed by tensor name.
ckpt = {}
with safetensors.safe_open(
    '/mnt/ceph_rbd/zbc/slot_ar/dit_imagenet100_pixart_slot256/models/step60000/model.safetensors',
    framework='pt',
) as reader:
    ckpt.update((name, reader.get_tensor(name)) for name in reader.keys())

In [16]:
# Show the first 10 values of the VAE decoder's conv_in bias as stored in the
# training checkpoint, for comparison with the standalone VAE file loaded later.
ckpt['vae.decoder.conv_in.bias'][:10]

tensor([ 0.1374, -0.1938, -0.0090, -0.1122,  0.1218,  0.0212,  0.1820,  0.0612,
        -0.1625,  0.0486])

In [18]:
# Materialise every tensor of the standalone PixArt VAE file into a plain dict
# keyed by tensor name.
old_vae_ckpt = {}
with safetensors.safe_open(
    '/mnt/ceph_rbd/zbc/pixart/pixart_vae.safetensors',
    framework='pt',
) as reader:
    old_vae_ckpt.update((name, reader.get_tensor(name)) for name in reader.keys())

In [20]:
# Same slice as inspected in the training checkpoint, read from the standalone
# VAE file (note the key has no 'vae.' prefix here).
old_vae_ckpt['decoder.conv_in.bias'][:10]

tensor([ 0.1374, -0.1938, -0.0090, -0.1122,  0.1218,  0.0212,  0.1820,  0.0612,
        -0.1625,  0.0486])

In [21]:
# Materialise every tensor of the original PixArt-alpha 256 XL/2 weights file
# into a plain dict keyed by tensor name.
old_pixart_ckpt = {}
with safetensors.safe_open(
    '/mnt/ceph_rbd/zbc/pixart/pixart_alpha256XL_2.safetensors',
    framework='pt',
) as reader:
    old_pixart_ckpt.update((name, reader.get_tensor(name)) for name in reader.keys())

In [24]:
# Print the name of every tensor in the original PixArt weights whose key
# mentions 'attn2'.
for name in filter(lambda key: 'attn2' in key, old_pixart_ckpt):
    print(name)

transformer_blocks.0.attn2.to_k.bias
transformer_blocks.0.attn2.to_k.weight
transformer_blocks.0.attn2.to_out.0.bias
transformer_blocks.0.attn2.to_out.0.weight
transformer_blocks.0.attn2.to_q.bias
transformer_blocks.0.attn2.to_q.weight
transformer_blocks.0.attn2.to_v.bias
transformer_blocks.0.attn2.to_v.weight
transformer_blocks.1.attn2.to_k.bias
transformer_blocks.1.attn2.to_k.weight
transformer_blocks.1.attn2.to_out.0.bias
transformer_blocks.1.attn2.to_out.0.weight
transformer_blocks.1.attn2.to_q.bias
transformer_blocks.1.attn2.to_q.weight
transformer_blocks.1.attn2.to_v.bias
transformer_blocks.1.attn2.to_v.weight
transformer_blocks.10.attn2.to_k.bias
transformer_blocks.10.attn2.to_k.weight
transformer_blocks.10.attn2.to_out.0.bias
transformer_blocks.10.attn2.to_out.0.weight
transformer_blocks.10.attn2.to_q.bias
transformer_blocks.10.attn2.to_q.weight
transformer_blocks.10.attn2.to_v.bias
transformer_blocks.10.attn2.to_v.weight
transformer_blocks.11.attn2.to_k.bias
transformer_blocks

In [25]:
# Print the name of every tensor in the training checkpoint whose key mentions
# 'attn2'.
for name in filter(lambda key: 'attn2' in key, ckpt):
    print(name)

pixart_dit.transformer_blocks.0.attn2.to_k.bias
pixart_dit.transformer_blocks.0.attn2.to_k.weight
pixart_dit.transformer_blocks.0.attn2.to_out.0.bias
pixart_dit.transformer_blocks.0.attn2.to_out.0.weight
pixart_dit.transformer_blocks.0.attn2.to_q.bias
pixart_dit.transformer_blocks.0.attn2.to_q.weight
pixart_dit.transformer_blocks.0.attn2.to_v.bias
pixart_dit.transformer_blocks.0.attn2.to_v.weight
pixart_dit.transformer_blocks.1.attn2.to_k.bias
pixart_dit.transformer_blocks.1.attn2.to_k.weight
pixart_dit.transformer_blocks.1.attn2.to_out.0.bias
pixart_dit.transformer_blocks.1.attn2.to_out.0.weight
pixart_dit.transformer_blocks.1.attn2.to_q.bias
pixart_dit.transformer_blocks.1.attn2.to_q.weight
pixart_dit.transformer_blocks.1.attn2.to_v.bias
pixart_dit.transformer_blocks.1.attn2.to_v.weight
pixart_dit.transformer_blocks.10.attn2.to_k.bias
pixart_dit.transformer_blocks.10.attn2.to_k.weight
pixart_dit.transformer_blocks.10.attn2.to_out.0.bias
pixart_dit.transformer_blocks.10.attn2.to_out.0

In [37]:
# Build a "hybrid" state dict: every tensor is taken from the training
# checkpoint (`ckpt`) except the `attn2` entries, which are replaced by the
# corresponding tensors from the original PixArt weights (`old_pixart_ckpt`).
# The training checkpoint names those tensors with a 'pixart_dit.' prefix that
# the original weights file does not use, so the prefix is stripped for lookup.
DIT_PREFIX = 'pixart_dit.'

hybrid_ckpt = {}
for name, tensor in ckpt.items():
    if 'attn2' not in name:
        hybrid_ckpt[name] = tensor
        continue

    # Guard the prefix explicitly: slicing at a fixed offset would otherwise
    # turn an unexpected key into an unrelated (or confusingly missing) name.
    if not name.startswith(DIT_PREFIX):
        raise KeyError(
            f"attn2 tensor {name!r} does not start with {DIT_PREFIX!r}; "
            "cannot map it onto the original PixArt weights"
        )

    original = old_pixart_ckpt[name[len(DIT_PREFIX):]]

    # A shape mismatch would produce a checkpoint that cannot be loaded into the
    # model, so stop here rather than after the file has been written.
    if original.shape != tensor.shape:
        raise ValueError(
            f"shape mismatch for {name!r}: checkpoint has {tuple(tensor.shape)}, "
            f"original weights have {tuple(original.shape)}"
        )

    hybrid_ckpt[name] = original

In [38]:
# Display the tensor names present in the merged state dict.
hybrid_ckpt.keys()

dict_keys(['encoder.blocks.0.attn.proj.bias', 'encoder.blocks.0.attn.proj.weight', 'encoder.blocks.0.attn.qkv.bias', 'encoder.blocks.0.attn.qkv.weight', 'encoder.blocks.0.mlp.fc1.bias', 'encoder.blocks.0.mlp.fc1.weight', 'encoder.blocks.0.mlp.fc2.bias', 'encoder.blocks.0.mlp.fc2.weight', 'encoder.blocks.0.norm1.bias', 'encoder.blocks.0.norm1.weight', 'encoder.blocks.0.norm2.bias', 'encoder.blocks.0.norm2.weight', 'encoder.blocks.1.attn.proj.bias', 'encoder.blocks.1.attn.proj.weight', 'encoder.blocks.1.attn.qkv.bias', 'encoder.blocks.1.attn.qkv.weight', 'encoder.blocks.1.mlp.fc1.bias', 'encoder.blocks.1.mlp.fc1.weight', 'encoder.blocks.1.mlp.fc2.bias', 'encoder.blocks.1.mlp.fc2.weight', 'encoder.blocks.1.norm1.bias', 'encoder.blocks.1.norm1.weight', 'encoder.blocks.1.norm2.bias', 'encoder.blocks.1.norm2.weight', 'encoder.blocks.10.attn.proj.bias', 'encoder.blocks.10.attn.proj.weight', 'encoder.blocks.10.attn.qkv.bias', 'encoder.blocks.10.attn.qkv.weight', 'encoder.blocks.10.mlp.fc1.bias

In [39]:
# Write the merged state dict to disk in safetensors format.
# `save_file` is called through its explicit import instead of the
# `safetensors.torch` attribute: a bare `import safetensors` does not load that
# submodule, so the attribute form only works when some other library happens
# to have imported it first.
save_file(
    hybrid_ckpt,
    '/mnt/ceph_rbd/zbc/pixart/pixart_alpha256XL_2_resetx_hybrid.safetensors',
)

In [40]:
# Print every tensor name in the merged state dict, one per line.
for name in hybrid_ckpt:
    print(name)

encoder.blocks.0.attn.proj.bias
encoder.blocks.0.attn.proj.weight
encoder.blocks.0.attn.qkv.bias
encoder.blocks.0.attn.qkv.weight
encoder.blocks.0.mlp.fc1.bias
encoder.blocks.0.mlp.fc1.weight
encoder.blocks.0.mlp.fc2.bias
encoder.blocks.0.mlp.fc2.weight
encoder.blocks.0.norm1.bias
encoder.blocks.0.norm1.weight
encoder.blocks.0.norm2.bias
encoder.blocks.0.norm2.weight
encoder.blocks.1.attn.proj.bias
encoder.blocks.1.attn.proj.weight
encoder.blocks.1.attn.qkv.bias
encoder.blocks.1.attn.qkv.weight
encoder.blocks.1.mlp.fc1.bias
encoder.blocks.1.mlp.fc1.weight
encoder.blocks.1.mlp.fc2.bias
encoder.blocks.1.mlp.fc2.weight
encoder.blocks.1.norm1.bias
encoder.blocks.1.norm1.weight
encoder.blocks.1.norm2.bias
encoder.blocks.1.norm2.weight
encoder.blocks.10.attn.proj.bias
encoder.blocks.10.attn.proj.weight
encoder.blocks.10.attn.qkv.bias
encoder.blocks.10.attn.qkv.weight
encoder.blocks.10.mlp.fc1.bias
encoder.blocks.10.mlp.fc1.weight
encoder.blocks.10.mlp.fc2.bias
encoder.blocks.10.mlp.fc2.weigh