In [None]:
import os
import random

import numpy as np
import torch
from diffusers import (
    StableDiffusionPipeline,
    T2IAdapter,
    StableDiffusionAdapterPipeline,
    MultiAdapter,
)
from matplotlib import pyplot
from PIL import Image

# Negative prompt handed to the adapter pipeline by
# CMCEditing.text_to_image_with_multiple_guidance. Written as adjacent string
# literals (joined at compile time) so the long list stays readable.
NEGATIVE_PROMPT = (
    'extra digit, fewer digits, cropped, worst quality, low quality, '
    'blurry, pixelated, low resolution, overexposed, underexposed, '
    'too dark, too bright, too blurry, too pixelated, '
    'too low resolution, too low quality, too high quality, too high resolution, '
    'too sharp, too clear, too focused, too contrasty, too saturated'
)

In [None]:
class args:
    """Runtime settings read by ``CMCEditing``.

    Attributes are plain instance fields so they can be overridden after
    construction, before the settings object is handed to ``CMCEditing``.
    """

    def __init__(self) -> None:
        # Device for the base Stable Diffusion pipeline and device for the
        # T2I adapters; the adapter device also drives the dtype choice in
        # CMCEditing.init_models.
        self.stable_diffusion_device = self.controlnet_device = "cuda"

        # Forwarded as ``num_inference_steps`` on every generation call.
        self.num_diffusion_steps = 50

        # False -> sketch adapter, True -> openpose adapter as the structure guide.
        self.use_skeleton = False

# replace with your paths
class dirs:
    """Input and output locations used by the notebook cells below.

    Every value is a placeholder string and must be replaced with a real path
    before the later cells are run.
    """

    def __init__(self) -> None:
        # Reference image; its base name also selects the caption file.
        self.image_src = "PATH_TO_SOURCE_IMAGE"
        # Joined with "<reference image name>.txt" to locate the caption.
        self.text_src = "PATHS_TO_TEXT_FILE"
        # One generation row per structure (edge) image ...
        self.structure_src = ["PATHS_TO_EDGE_IMAGES"]
        # ... and one generated image per colour image within that row.
        self.color_src = ["PATHS_TO_COLOR_IMAGES"]
        # Directory that receives the generated PNG files.
        self.out_dir = "PATH_TO_OUTPUT_DIR"

In [None]:
class CMCEditing:
    """Generate images from a caption plus a structure image and a colour image.

    Construction loads a Stable Diffusion v1.5 pipeline, one structure
    T2I-Adapter (sketch or openpose, chosen by ``args.use_skeleton``) and one
    colour T2I-Adapter, and combines them in a single
    ``StableDiffusionAdapterPipeline`` that reuses the base pipeline's
    components.

    Args:
        args: settings object exposing ``stable_diffusion_device``,
            ``controlnet_device``, ``num_diffusion_steps`` and ``use_skeleton``.
    """

    def __init__(self, args):
        self.args = args
        self.init_models()
        self.ref_image = None
        self.is_openai_available = False
        # Fixed seed so repeated runs of the notebook start from the same RNG state.
        self.setup_seed(9999)

    def _select_dtype(self):
        """Return float32 if either configured device is ``'cpu'``, else float16.

        The same dtype is used for the base pipeline and for the adapters, so
        half precision is only chosen when neither of them is placed on the CPU.
        """
        devices = (self.args.stable_diffusion_device, self.args.controlnet_device)
        if any(device == 'cpu' for device in devices):
            return torch.float32
        return torch.float16

    def init_models(self):
        """Load the base pipeline and both adapters and build the combined pipeline.

        Sets ``data_type``, ``ldm``, ``tokenizer``, ``text_encoder``,
        ``l2_adapter``, ``color_adapter`` and ``t2i_pipeline_layer3``.
        """
        self.data_type = self._select_dtype()

        print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')

        self.ldm = StableDiffusionPipeline.from_pretrained(
            "sd-legacy/stable-diffusion-v1-5",
            torch_dtype=self.data_type,
            custom_pipeline="lpw_stable_diffusion",
        ).to(self.args.stable_diffusion_device)
        self.tokenizer = self.ldm.tokenizer
        self.text_encoder = self.ldm.text_encoder
        # Swap the safety checker for a pass-through that hands the images back
        # untouched together with a False flag.
        self.ldm.safety_checker = lambda images, clip_input: (images, False)

        # Structure guidance: openpose when a skeleton is used, sketch otherwise.
        if self.args.use_skeleton:
            structure_adapter_id = "TencentARC/t2iadapter_openpose_sd14v1"
        else:
            structure_adapter_id = "TencentARC/t2iadapter_sketch_sd15v2"

        # NOTE(review): "varient" looks like a typo for "variant". It is left
        # unchanged because SOURCE does not show whether these repositories
        # publish fp16 weight files; confirm before correcting the spelling.
        self.l2_adapter = T2IAdapter.from_pretrained(
            structure_adapter_id,
            varient="fp16",
            torch_dtype=self.data_type
        ).to(self.args.controlnet_device)

        self.color_adapter = T2IAdapter.from_pretrained(
            "TencentARC/t2iadapter_color_sd14v1",
            varient="fp16",
            torch_dtype=self.data_type
        ).to(self.args.controlnet_device)

        # Adapter order here fixes the order of the guidance images expected by
        # text_to_image_with_multiple_guidance: [structure, colour].
        self.t2i_pipeline_layer3 = StableDiffusionAdapterPipeline(
            vae=self.ldm.vae,
            text_encoder=self.ldm.text_encoder,
            tokenizer=self.ldm.tokenizer,
            unet=self.ldm.unet,
            adapter=MultiAdapter([self.l2_adapter, self.color_adapter]),
            scheduler=self.ldm.scheduler,
            safety_checker=self.ldm.safety_checker,
            feature_extractor=self.ldm.feature_extractor
        )

        print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')

    def setup_seed(self, seed):
        """Seed torch (CPU and all CUDA devices), NumPy and ``random`` with ``seed``.

        Also switches cuDNN to deterministic mode.
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True

    def text_to_image_with_multiple_guidance(self, text, image):
        """Generate one image from ``text`` conditioned on several guidance images.

        Args:
            text: prompt passed to the adapter pipeline.
            image: sequence of guidance images, one per adapter, in the order
                the adapters were registered (structure first, then colour).

        Returns:
            The first image of the pipeline output.

        Raises:
            AssertionError: if fewer than two guidance images are supplied.
        """
        assert len(image) > 1, 'Need more than one guidance image'
        generated_image = self.t2i_pipeline_layer3(
            text,
            image,
            num_inference_steps=self.args.num_diffusion_steps,
            guidance_scale=7.5,
            # One conditioning weight per adapter in the MultiAdapter.
            adapter_conditioning_scale=[1, 1],
            negative_prompt=NEGATIVE_PROMPT
        ).images[0]
        return generated_image


In [None]:
# `args` and `dirs` name both the settings classes and, after this cell, their
# instances. Instantiate only while the name still refers to the class so that
# re-running this cell does not fail by trying to call an instance.
if isinstance(args, type):
    args = args()
if isinstance(dirs, type):
    dirs = dirs()
processor = CMCEditing(args)

os.makedirs(dirs.out_dir, exist_ok=True)

ref_img = Image.open(dirs.image_src).resize((512, 512))
# Strip whatever extension the reference image has (not only ".jpg"); the stem
# selects the caption file and is embedded in the output file names.
ref_name = os.path.splitext(os.path.basename(dirs.image_src))[0]

caption_path = os.path.join(dirs.text_src, ref_name + '.txt')
try:
    with open(caption_path) as f:
        image_caption = f.read()
except OSError as err:
    # Only file-access failures mean "caption missing"; anything else should
    # surface unchanged instead of being relabelled.
    raise ValueError(
        f"Text not found for: {dirs.image_src} (expected {caption_path}), exiting..."
    ) from err


In [None]:
# One row per structure image; the first column shows the structure image and
# each following column shows the result for one colour image.
assert dirs.structure_src and dirs.color_src, "Need at least one structure image and one color image"

n_rows = len(dirs.structure_src)
n_cols = len(dirs.color_src) + 1  # +1 for the leading structure-image column

# squeeze=False keeps `axes` two-dimensional even for a single row or column,
# so axes[i][j] indexing works for every grid size.
fig, axes = pyplot.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5), squeeze=False)

for i, edge in enumerate(dirs.structure_src):

    edge_img = Image.open(edge).convert("L").resize((512, 512))
    edge_name = os.path.splitext(os.path.basename(edge))[0]

    edge_ax = axes[i][0]
    edge_ax.imshow(edge_img, cmap="gray")
    edge_ax.set_title(edge_name)
    edge_ax.axis("off")

    for j, color in enumerate(dirs.color_src):

        color_img = Image.open(color).resize((512, 512))
        color_name = os.path.splitext(os.path.basename(color))[0]
        save_name = f"layer3_edited_i{ref_name},e{edge_name},C{color_name}.png"

        # Guidance order matches the adapter order: structure first, then colour.
        layer3_recon = processor.text_to_image_with_multiple_guidance(image_caption, [edge_img, color_img])
        layer3_recon.save(os.path.join(dirs.out_dir, save_name))

        result_ax = axes[i][j + 1]
        result_ax.imshow(layer3_recon)
        result_ax.set_title(color_name)
        result_ax.axis("off")

pyplot.show()
