In [1]:
import diffusers
from diffusers import UNet2DConditionModel
from PIL import Image
import sys
from transformers import CLIPVisionModel, CLIPImageProcessor
import torch

sys.path.append("/home/aihao/workspace/StableDiffusionReferenceOnly/src")
from stable_diffusion_reference_only.pipelines.stable_diffusion_reference_only_pipeline import (
    StableDiffusionReferenceOnlyPipeline,
)
from stable_diffusion_reference_only.models.unet_2d_dobule_condition import (
    UNet2DDobuleConditionModel,
)
from diffusers.schedulers import DDPMScheduler
import json

In [2]:
unet_config_path = "/home/aihao/workspace/StableDiffusionReferenceOnly/src/stable_diffusion_reference_only/models/unet-2-1.json"
pretrained_unet_path = "stabilityai/stable-diffusion-2-1"
pretrained_image_encoder_path = "openai/clip-vit-large-patch14"

# unet_config_path = "/home/aihao/workspace/StableDiffusionReferenceOnly/src/stable_diffusion_reference_only/models/unet_xl-base-1.0.json"
# pretrained_unet_path = "stabilityai/stable-diffusion-xl-base-1.0"
# pretrained_image_encoder_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

In [3]:
with open(unet_config_path) as f:
    unet_config = json.load(f)

In [4]:
unet_config

{'_class_name': 'UNet2DDobuleConditionModel',
 '_diffusers_version': '0.10.0.dev0',
 'act_fn': 'silu',
 'attention_head_dim': [5, 10, 20, 20],
 'block_out_channels': [320, 640, 1280, 1280],
 'center_input_sample': False,
 'cross_attention_dim': 1024,
 'down_block_types': ['CrossAttnDownBlock2D',
  'CrossAttnDownBlock2D',
  'CrossAttnDownBlock2D',
  'DownBlock2D'],
 'downsample_padding': 1,
 'dual_cross_attention': False,
 'flip_sin_to_cos': True,
 'freq_shift': 0,
 'in_channels': 4,
 'layers_per_block': 2,
 'mid_block_scale_factor': 1,
 'norm_eps': 1e-05,
 'norm_num_groups': 32,
 'num_class_embeds': None,
 'only_cross_attention': False,
 'out_channels': 4,
 'sample_size': 96,
 'up_block_types': ['UpBlock2D',
  'CrossAttnUpBlock2D',
  'CrossAttnUpBlock2D',
  'CrossAttnUpBlock2D'],
 'use_linear_projection': True,
 'upcast_attention': True}

In [3]:
pretrained_unet = UNet2DConditionModel.from_pretrained(
    pretrained_unet_path, subfolder="unet"
)

In [4]:
unet = UNet2DDobuleConditionModel.from_config(pretrained_unet.config)

In [5]:
unet

UNet2DDobuleConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (controlnet_cond_embedding): ControlNetConditioningEmbedding(
    (conv_in): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (blocks): ModuleList(
      (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): Conv2d(32, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (4): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): Conv2d(96, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (conv_out): Conv2d(25

In [6]:
unet_parameters = unet.state_dict()
pretrained_unet_parameters = pretrained_unet.state_dict()
for key in unet_parameters:
    if key in pretrained_unet_parameters:
        if unet_parameters[key].shape == pretrained_unet_parameters[key].shape:
            unet_parameters[key] = pretrained_unet_parameters[key]
        elif unet_parameters[key].shape < pretrained_unet_parameters[key].shape:
            print(key)
            unet_parameters[key] = pretrained_unet_parameters[key][
                0 : unet_parameters[key].shape[0],
                0 : unet_parameters[key].shape[1],
            ]
        else:
            print(key)
            unet_parameters[key] = torch.nn.functional.pad(
                pretrained_unet_parameters[key],
                (
                    0,
                    unet_parameters[key].shape[1]
                    - pretrained_unet_parameters[key].shape[1],
                    0,
                    0,
                ),
            )

In [7]:
unet.load_state_dict(unet_parameters)

<All keys matched successfully>

In [9]:
for key in unet.state_dict():
    if key in pretrained_unet.state_dict():
        print(unet.state_dict()[key] == pretrained_unet.state_dict()[key])
    else:
        print(key)

tensor([[[[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]]],


        [[[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]]],


        [[[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True,

In [10]:
# vae = diffusers.AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
vae = diffusers.AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="vae"
)

In [11]:
image_encoder = CLIPVisionModel.from_pretrained(pretrained_image_encoder_path)

In [12]:
clip_image_processor = CLIPImageProcessor.from_pretrained(pretrained_image_encoder_path)

In [13]:
scheduler = DDPMScheduler.from_pretrained(pretrained_unet_path, subfolder="scheduler")

In [14]:
pipe = StableDiffusionReferenceOnlyPipeline(
    vae, image_encoder, clip_image_processor, unet, scheduler
)

In [16]:
pipe.save_pretrained(
    "/home/aihao/workspace/DeepLearningContent/models/sd_reference_only/init_0.1"
)