In [1]:
import torch
from typing import *
from diffusers.utils import load_image
from diffusers import StableDiffusionAdapterPipeline, T2IAdapter

model_name = "runwayml/stable-diffusion-v1-5"
adapter_ckpt = "./models/t2iadapter_seg_sd14v1.pth"
# Fall back to CPU so the notebook can still be executed without a GPU (slowly).
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = StableDiffusionAdapterPipeline.from_pretrained(model_name, torch_dtype=torch.float32).to(device)


def _remap_adapter_keys(state_dict, expected_keys):
    """Return a copy of ``state_dict`` whose ``.in_conv.`` keys are renamed to ``.conv1.``.

    A key is renamed only when it is absent from ``expected_keys`` and the renamed
    form is present, so checkpoints that already match the model pass through unchanged.
    Any remaining mismatch is left for ``load_state_dict`` (strict) to report.
    """
    remapped = {}
    for key, value in state_dict.items():
        if key not in expected_keys:
            renamed = key.replace(".in_conv.", ".conv1.")
            if renamed in expected_keys:
                key = renamed
        remapped[key] = value
    return remapped


pipe.adapter = T2IAdapter(
    channels_in=3,
    block_out_channels=[320, 640, 1280, 1280],
    num_res_blocks=2,
    kernel_size=1,
    res_block_skip=True,
    use_conv=False,
)
# Load to CPU first; the module is moved to `device` after the weights are in place.
adapter_state = torch.load(adapter_ckpt, map_location="cpu")
# The checkpoint names its downsampling convs `in_conv` while this T2IAdapter
# build names them `conv1`, which makes a strict load fail.
adapter_state = _remap_adapter_keys(adapter_state, set(pipe.adapter.state_dict()))
pipe.adapter.load_state_dict(adapter_state)
pipe.adapter = pipe.adapter.to(device)

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [2]:
@torch.no_grad()
def get_color_masks(image: torch.Tensor, min_pixels: Optional[int] = None) -> Dict[Tuple[int, ...], torch.Tensor]:
    """Split a color-coded segmentation map into one binary mask per color.

    Args:
        image: ``(H, W, 3)`` tensor whose last axis holds the three channel values.
        min_pixels: colors that occur in fewer pixels than this are dropped.
            Defaults to ``H`` (the image height), the previous fixed threshold.

    Returns:
        Dict mapping each kept color (as a tuple of its channel values) to a float
        ``(H, W)`` mask that is 1.0 where the pixel has exactly that color and 0.0 elsewhere.

    Raises:
        AssertionError: if the last axis does not have size 3.
    """
    h, _w, c = image.shape
    assert c == 3
    if min_pixels is None:
        min_pixels = h

    # reshape (not view) so non-contiguous inputs such as permuted tensors also work.
    flat = image.reshape(-1, 3)
    colors, counts = torch.unique(flat, return_counts=True, dim=0)

    color2mask = {}
    for color in colors[counts >= min_pixels]:
        # A pixel belongs to `color` only if ALL three channels match; reducing with
        # max/any would also mark pixels that share just one channel value.
        mask = (image == color).all(dim=-1).float()
        color2mask[tuple(color.tolist())] = mask
    return color2mask
    
mask = load_image("./diffusers-t2i-adapter/motor.png")

prompt = ["A black Honda motorcycle parked in front of a garage"]

SEED = 0
# Seed the sampler so re-running this cell reproduces the same image.
generator = torch.Generator(device="cpu").manual_seed(SEED)

# NOTE(review): one prompt but two condition images are passed here — confirm this
# StableDiffusionAdapterPipeline build expects `[mask, mask]` rather than a single image.
image = pipe(prompt, [mask, mask], generator=generator).images[0]
image.save('test.jpg')

In [26]:
# Parameter names the freshly constructed adapter expects; compare against the checkpoint's keys.
pipe.adapter.state_dict().keys()

odict_keys(['body.0.block1.weight', 'body.0.block1.bias', 'body.0.block2.weight', 'body.0.block2.bias', 'body.1.block1.weight', 'body.1.block1.bias', 'body.1.block2.weight', 'body.1.block2.bias', 'body.2.conv1.weight', 'body.2.conv1.bias', 'body.2.block1.weight', 'body.2.block1.bias', 'body.2.block2.weight', 'body.2.block2.bias', 'body.3.block1.weight', 'body.3.block1.bias', 'body.3.block2.weight', 'body.3.block2.bias', 'body.4.conv1.weight', 'body.4.conv1.bias', 'body.4.block1.weight', 'body.4.block1.bias', 'body.4.block2.weight', 'body.4.block2.bias', 'body.5.block1.weight', 'body.5.block1.bias', 'body.5.block2.weight', 'body.5.block2.bias', 'body.6.block1.weight', 'body.6.block1.bias', 'body.6.block2.weight', 'body.6.block2.bias', 'body.7.block1.weight', 'body.7.block1.bias', 'body.7.block2.weight', 'body.7.block2.bias', 'conv_in.weight', 'conv_in.bias'])

In [18]:
# Only the key names are needed, so load on CPU instead of the device the tensors were saved from.
torch.load(adapter_ckpt, map_location="cpu").keys()

dict_keys(['body.0.block1.weight', 'body.0.block1.bias', 'body.0.block2.weight', 'body.0.block2.bias', 'body.1.block1.weight', 'body.1.block1.bias', 'body.1.block2.weight', 'body.1.block2.bias', 'body.2.in_conv.weight', 'body.2.in_conv.bias', 'body.2.block1.weight', 'body.2.block1.bias', 'body.2.block2.weight', 'body.2.block2.bias', 'body.3.block1.weight', 'body.3.block1.bias', 'body.3.block2.weight', 'body.3.block2.bias', 'body.4.in_conv.weight', 'body.4.in_conv.bias', 'body.4.block1.weight', 'body.4.block1.bias', 'body.4.block2.weight', 'body.4.block2.bias', 'body.5.block1.weight', 'body.5.block1.bias', 'body.5.block2.weight', 'body.5.block2.bias', 'body.6.block1.weight', 'body.6.block1.bias', 'body.6.block2.weight', 'body.6.block2.bias', 'body.7.block1.weight', 'body.7.block1.bias', 'body.7.block2.weight', 'body.7.block2.bias', 'conv_in.weight', 'conv_in.bias'])

In [27]:
!pip install .

Defaulting to user installation because normal site-packages is not writeable
Processing /n/home07/adamaraju/fasrc/diffusers-t2i-adapter
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Building wheels for collected packages: diffusers
  Building wheel for diffusers (PEP 517) ... done
  Created wheel for diffusers: filename=diffusers-0.16.0.dev0-py3-none-any.whl size=877380 sha256=d788a3470d51ba6b3ed65cda3f1195ec3a9cdb4ab531597c5722e55083c06a44
  Stored in directory: /n/home07/adamaraju/.cache/pip/wheels/3c/8b/ec/3a7ad4250255f19881dec9cb1fb244751e1c61a1c20f76a2d9
Successfully built diffusers
Installing collected packages: diffusers
  Attempting uninstall: diffusers
    Found existing installation: diffusers 0.16.0.dev0
    Uninstalling diffusers-0.16.0.dev0:
      Successfully uninstalled diffusers-0.16.0.dev0
Successfully installed diffusers-0.16.0.dev0
