In [1]:
import diffusers
from diffusers import UNet2DConditionModel
from PIL import Image
import sys
from transformers import CLIPVisionModel, CLIPImageProcessor

sys.path.append("/home/aihao/workspace/StableDiffusionReferenceOnly/src")
from stable_diffusion_joint_control.pipelines.stable_diffusion_reference_only_pipeline import (
    StableDiffusionReferenceOnlyPipeline,
)
from stable_diffusion_joint_control.models.dobule_condition_unet import (
    UNet2DDobuleConditionModel,
)

In [2]:
sd_pipe=diffusers.StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1"
)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [3]:
new_config = dict(sd_pipe.unet.config)

In [4]:
new_config

{'sample_size': 96,
 'in_channels': 4,
 'out_channels': 4,
 'center_input_sample': False,
 'flip_sin_to_cos': True,
 'freq_shift': 0,
 'down_block_types': ['CrossAttnDownBlock2D',
  'CrossAttnDownBlock2D',
  'CrossAttnDownBlock2D',
  'DownBlock2D'],
 'mid_block_type': 'UNetMidBlock2DCrossAttn',
 'up_block_types': ['UpBlock2D',
  'CrossAttnUpBlock2D',
  'CrossAttnUpBlock2D',
  'CrossAttnUpBlock2D'],
 'only_cross_attention': False,
 'block_out_channels': [320, 640, 1280, 1280],
 'layers_per_block': 2,
 'downsample_padding': 1,
 'mid_block_scale_factor': 1,
 'act_fn': 'silu',
 'norm_num_groups': 32,
 'norm_eps': 1e-05,
 'cross_attention_dim': 1024,
 'transformer_layers_per_block': 1,
 'encoder_hid_dim': None,
 'encoder_hid_dim_type': None,
 'attention_head_dim': [5, 10, 20, 20],
 'num_attention_heads': None,
 'dual_cross_attention': False,
 'use_linear_projection': True,
 'class_embed_type': None,
 'addition_embed_type': None,
 'addition_time_embed_dim': None,
 'num_class_embeds': None,
 

In [5]:
new_config['cross_attention_dim']=1664

In [6]:
new_config["_name_or_path"] = "unet"

In [7]:
new_config

{'sample_size': 96,
 'in_channels': 4,
 'out_channels': 4,
 'center_input_sample': False,
 'flip_sin_to_cos': True,
 'freq_shift': 0,
 'down_block_types': ['CrossAttnDownBlock2D',
  'CrossAttnDownBlock2D',
  'CrossAttnDownBlock2D',
  'DownBlock2D'],
 'mid_block_type': 'UNetMidBlock2DCrossAttn',
 'up_block_types': ['UpBlock2D',
  'CrossAttnUpBlock2D',
  'CrossAttnUpBlock2D',
  'CrossAttnUpBlock2D'],
 'only_cross_attention': False,
 'block_out_channels': [320, 640, 1280, 1280],
 'layers_per_block': 2,
 'downsample_padding': 1,
 'mid_block_scale_factor': 1,
 'act_fn': 'silu',
 'norm_num_groups': 32,
 'norm_eps': 1e-05,
 'cross_attention_dim': 1664,
 'transformer_layers_per_block': 1,
 'encoder_hid_dim': None,
 'encoder_hid_dim_type': None,
 'attention_head_dim': [5, 10, 20, 20],
 'num_attention_heads': None,
 'dual_cross_attention': False,
 'use_linear_projection': True,
 'class_embed_type': None,
 'addition_embed_type': None,
 'addition_time_embed_dim': None,
 'num_class_embeds': None,
 

In [8]:
unet = UNet2DDobuleConditionModel.from_config(new_config)

In [9]:
unet

UNet2DDobuleConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (controlnet_cond_embedding): ControlNetConditioningEmbedding(
    (conv_in): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (blocks): ModuleList(
      (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): Conv2d(32, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (4): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): Conv2d(96, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (conv_out): Conv2d(25

In [10]:
image_encoder = CLIPVisionModel.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
clip_image_processor = CLIPImageProcessor.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
)

In [12]:
my_pipeline = StableDiffusionReferenceOnlyPipeline(
    sd_pipe.vae, image_encoder, clip_image_processor, unet, sd_pipe.scheduler
)

In [13]:
my_pipeline.save_pretrained(
    "/home/aihao/workspace/DeepLearningContent/models/sd_ro/init"
)

In [19]:
print(my_pipeline.unet.controlnet_cond_embedding.state_dict())

OrderedDict([('conv_in.weight', tensor([[[[-0.1058,  0.1352,  0.1860],
          [-0.1434, -0.0410, -0.0316],
          [-0.1231, -0.0663, -0.1309]],

         [[-0.0760, -0.0124, -0.1657],
          [ 0.1889, -0.0391, -0.0769],
          [ 0.0940,  0.1285,  0.0888]],

         [[-0.1047, -0.0279,  0.0548],
          [ 0.1347, -0.0195, -0.1463],
          [-0.1547,  0.1799,  0.0938]]],


        [[[ 0.0114,  0.1891,  0.1718],
          [-0.1248,  0.1773, -0.1163],
          [-0.1482, -0.0503, -0.0398]],

         [[ 0.1719, -0.0027,  0.0023],
          [ 0.1358, -0.0985, -0.1177],
          [ 0.0454, -0.1468,  0.1322]],

         [[-0.1124,  0.1706,  0.0172],
          [-0.0410,  0.1312, -0.0744],
          [-0.1764, -0.0704,  0.1370]]],


        [[[ 0.0777, -0.0686,  0.1272],
          [-0.0374,  0.0384, -0.0537],
          [ 0.1043, -0.1520, -0.1174]],

         [[-0.0819, -0.0611,  0.0649],
          [-0.1411, -0.0771,  0.0908],
          [-0.1658,  0.1142, -0.0889]],

         [[-

In [None]:
my_pipeline=my_pipeline.to("cuda")

In [None]:
my_pipeline(
    Image.open(
        "/home/aihao/workspace/DeepLearningContent/datasets/images/data/pixiv/狗脸脸dogface/Illustration/_狗脸脸dogface - pixiv__Illustration_五等分的花嫁,Miku Nakano,同人游戏,wedding dress,miko clothing,五等分的抢婚_88734377_p000.jpg"
    ),
    Image.open(
        "/home/aihao/workspace/DeepLearningContent/datasets/images/data/pixiv/狗脸脸dogface/Illustration/_狗脸脸dogface - pixiv__Illustration_Genshin Impact,barefoot,Kokomi,Sangonomiya Kokomi,girl,睡衣,underwater,sleep-wear,bellybutton,Genshin Impact 10000+ bookmarks_96862960_p000.jpg"
    ),
).images[0]