### **MasaCtrl + SDXL**

In [1]:
import torch
import torch.nn.functional as nnf
import numpy as np
from PIL import Image
from tqdm import tqdm
import ast
import copy

from diffusers import StableDiffusionXLPipeline, DDIMScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from datasets import load_dataset

# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision is only selected together with the GPU; the CPU fallback uses
# float32 because float16 inference on CPU is poorly supported.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

MasaCtrl Attention Processor 정의

In [2]:
class MasaCtrlAttnProcessor(torch.nn.Module):
    """Self-attention processor that lets a "target" pass attend to K/V of a "source" pass.

    While ``is_image_generation`` is False, every self-attention call stores a
    detached copy of its keys and values in ``ref_k`` / ``ref_v``, overwriting
    whatever was stored by the previous call. While it is True, the stored
    tensors are prepended to the current keys/values along the token axis, so
    each query attends to the reference tokens as well as its own.
    Cross-attention calls (``encoder_hidden_states`` given) are computed
    normally in both modes.

    Only the q/k/v projections, head reshaping, score computation and output
    layers of ``attn`` are used; ``attention_mask`` and any extra keyword
    arguments are accepted for interface compatibility and ignored.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name
        # False: source pass (record K/V). True: target pass (reuse recorded K/V).
        self.is_image_generation = False
        self.ref_k = None
        self.ref_v = None

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        """Run attention for one layer, recording or injecting reference K/V.

        Args:
            attn: attention module providing ``to_q``/``to_k``/``to_v``,
                ``head_to_batch_dim``, ``batch_to_head_dim``,
                ``get_attention_scores`` and ``to_out``.
            hidden_states: 3-D input tensor.
            encoder_hidden_states: conditioning tensor; when given the call is
                treated as cross-attention and no K/V sharing takes place.
            attention_mask: ignored.

        Returns:
            The attention output after ``attn.to_out[0]`` and ``attn.to_out[1]``.

        Raises:
            ValueError: if ``hidden_states`` is not 3-D.
        """
        if hidden_states.ndim != 3:
            raise ValueError(
                f"expected 3 dimensions for hidden_states, got {hidden_states.ndim}"
            )

        # Build the query from the layer input.
        query = attn.to_q(hidden_states)

        # Self-attention uses the layer input itself as the K/V context.
        is_cross = encoder_hidden_states is not None
        context = encoder_hidden_states if is_cross else hidden_states
        key = attn.to_k(context)
        value = attn.to_v(context)

        # Fold the heads into the batch dimension.
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # K/V sharing only applies to self-attention.
        if not is_cross:
            if not self.is_image_generation:
                # Source pass: remember this call's keys and values.
                self.ref_k = key.detach().clone()
                self.ref_v = value.detach().clone()
            elif self.ref_k is not None and self.ref_v is not None:
                # Target pass: extend the token axis with the reference K/V.
                # In the diffusers Attention API head_to_batch_dim yields
                # (batch * heads, tokens, head_dim), so tokens are dim 1;
                # concatenating on dim 2 would widen head_dim and break the
                # query/key product.
                key = torch.cat([self.ref_k, key], dim=1)
                value = torch.cat([self.ref_v, value], dim=1)

        # Attention itself; the mask is deliberately not forwarded.
        attention_probs = attn.get_attention_scores(query, key, None)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

MasaCtrl Controller

In [3]:
class MasaCtrlController:
    """Installs MasaCtrl processors on a pipeline's UNet and switches their mode.

    Every attention layer listed in ``layer_names_to_replace`` receives a
    ``MasaCtrlAttnProcessor``; every other layer receives ``AttnProcessor2_0``.
    When no list is given, the layers whose name contains ``"attn1"`` and
    either ``"up_blocks"`` or ``"mid_block"`` are selected.

    Args:
        pipe: pipeline exposing ``unet.attn_processors`` and
            ``unet.set_attn_processor``.
        layer_names_to_replace: optional collection of processor names (keys of
            ``pipe.unet.attn_processors``) to replace.
    """

    def __init__(self, pipe, layer_names_to_replace=None):
        self.pipe = pipe
        self.processors = {}

        # Default: self-attention layers of the up blocks and the mid block only.
        if layer_names_to_replace is None:
            layer_names_to_replace = [
                name
                for name in pipe.unet.attn_processors.keys()
                if "attn1" in name and ("up_blocks" in name or "mid_block" in name)
            ]

        # Build the full name -> processor mapping for the UNet.
        for name in pipe.unet.attn_processors.keys():
            if name in layer_names_to_replace:
                self.processors[name] = MasaCtrlAttnProcessor(name)
            else:
                self.processors[name] = AttnProcessor2_0()

        # diffusers' set_attn_processor pops entries out of the dict it receives.
        # Hand it a shallow copy so self.processors keeps referencing the
        # installed processors; otherwise set_mode/clear_store would iterate an
        # empty dict and silently do nothing.
        pipe.unet.set_attn_processor(dict(self.processors))

    def set_mode(self, mode):
        """Switch all MasaCtrl processors to target mode iff ``mode == 'target'``.

        Any other value selects the source (recording) mode.
        """
        is_gen = (mode == 'target')
        for processor in self.processors.values():
            if isinstance(processor, MasaCtrlAttnProcessor):
                processor.is_image_generation = is_gen

    def clear_store(self):
        """Drop the reference keys/values held by every MasaCtrl processor."""
        for processor in self.processors.values():
            if isinstance(processor, MasaCtrlAttnProcessor):
                processor.ref_k = None
                processor.ref_v = None


SDXL 모델 + MasaCtrl

In [4]:
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

print("SDXL + MasaCtrl Î°úÎìú")
# The pipeline is loaded without an explicit device move: model CPU offload
# (below) manages device placement itself, so first copying every component to
# the GPU would only raise peak GPU memory.
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=dtype,
    variant="fp16",
    use_safetensors=True,
)

# Install the MasaCtrl attention processors before any offload hooks are added.
controller = MasaCtrlController(pipe)
pipe.enable_vae_tiling()

if torch.cuda.is_available():
    # Components are moved to the GPU one at a time, only while they run.
    pipe.enable_model_cpu_offload()
else:
    # Offloading targets an accelerator; without one just keep the pipeline
    # on the selected device.
    pipe.to(device)

SDXL + MasaCtrl Î°úÎìú


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


In [5]:
def update_position(dataset, offset, pos):
    """Shift the stored position of every pending edit located at or after ``pos``.

    Args:
        dataset: dict whose ``'editing'`` entry maps each edit key to a dict
            holding a numeric ``"position"``.
        offset: amount added to each affected position (may be negative).
        pos: positions greater than or equal to this value are shifted.

    Returns:
        The same ``dataset`` object, modified in place.
    """
    for entry in dataset['editing'].values():
        if entry["position"] >= pos:
            entry["position"] += offset
    return dataset


def generate_editing_prompt(dataset):
    """Build the edited prompt described by ``dataset['editing']``.

    Square brackets are stripped from ``dataset['source_prompt']`` and the
    edits are then applied in dictionary order. Each edit key is the new (or
    removed) word and its value holds ``"position"`` and ``"action"``:

    * ``"+"``  - insert the word at ``position`` (index into the
      whitespace-split prompt);
    * ``"-"``  - remove the first occurrence of the word;
    * anything else - the action is the text to replace; its first occurrence
      is substituted by the word.

    After each edit the positions of the remaining edits are shifted by the
    change in word count. The caller's ``dataset`` is not modified.

    Args:
        dataset: dict with ``'source_prompt'`` and, optionally, ``'editing'``.

    Returns:
        The edited prompt, or the bracket-free source prompt when there is no
        ``'editing'`` entry.
    """
    original_prompt = dataset['source_prompt'].replace("[", "").replace("]", "")
    if 'editing' not in dataset:
        return original_prompt

    # Work on a private copy: positions are rewritten and entries deleted as
    # edits are consumed, which must not leak into the caller's dictionary.
    pending = {'editing': {key: dict(spec) for key, spec in dataset['editing'].items()}}
    editing_prompts = original_prompt

    for key in list(pending['editing']):
        word = str(key)
        spec = pending['editing'][key]
        position, action_word = spec["position"], spec["action"]
        word_count = word.count(" ") + 1

        if action_word == "+":
            tokens = editing_prompts.split()
            tokens.insert(position, f"{word}")
            editing_prompts = " ".join(tokens)
            update_position(pending, word_count, position)
        elif action_word == "-":
            needle = word + " "
            if needle in editing_prompts:
                editing_prompts = editing_prompts.replace(needle, "", 1)
            elif editing_prompts == word:
                # The prompt consists of exactly the removed word.
                editing_prompts = ""
            elif editing_prompts.endswith(" " + word):
                # A word at the very end has no trailing space to match on.
                editing_prompts = editing_prompts[:-(len(word) + 1)]
            update_position(pending, -word_count, position)
        else:
            editing_prompts = editing_prompts.replace(action_word, word, 1)
            update_position(pending, word_count - action_word.count(" ") - 1, position)
        del pending['editing'][key]
    return editing_prompts

In [None]:
import traceback

print("PIE-Bench++ Îç∞Ïù¥ÌÑ∞ÏÖã Î°úÎìú...")
hf_dataset = load_dataset("UB-CVML-Group/PIE_Bench_pp", "0_random_140", split="V1")

success_count = 0
TARGET_INDICES = [0, 2, 3, 13, 25, 40, 48, 58, 68, 102, 135]

for idx in TARGET_INDICES:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        data = hf_dataset[idx]
        # NOTE(review): the dataset image is loaded but not consumed by the two
        # text-to-image passes below.
        source_image = data['image'].convert("RGB").resize((1024, 1024))
        source_prompt = data['source_prompt']

        # Parse the edit description (stored either as a string or as a dict).
        edit_action = ast.literal_eval(data['edit_action']) if isinstance(data['edit_action'], str) else data['edit_action']
        dataset_info = {'source_prompt': source_prompt, 'editing': copy.deepcopy(edit_action)}
        target_prompt = generate_editing_prompt(dataset_info)

        print(f"ÏõêÎ≥∏: {source_prompt[:60]}...")
        print(f"ÌÉÄÍ≤ü: {target_prompt[:60]}...")

        # A torch.Generator is stateful: once the source pass has drawn its
        # initial noise, reusing the same object gives the target pass
        # different noise. Each pass therefore gets its own generator built
        # from the same seed.
        seed = 42 + idx

        # 1. Source pass (text2img, records the reference K/V).
        controller.set_mode('source')
        controller.clear_store()

        print("Source Pass (Íµ¨Ï°∞ Ï†ÄÏû•)...")
        source_result = pipe(
            prompt=source_prompt,
            negative_prompt="blurry, low quality",
            num_inference_steps=25,
            guidance_scale=5.0,
            width=1024,
            height=1024,
            generator=torch.Generator(device).manual_seed(seed)
        ).images[0]

        source_result.save(f"masactrl_source_{idx}.jpg")
        print(f"Source Ï†ÄÏû•: masactrl_source_{idx}.jpg")

        # 2. Target pass (text2img, applies the edit).
        controller.set_mode('target')

        print("Target Pass (Ìé∏Ïßë ÏÉùÏÑ±)...")
        target_result = pipe(
            prompt=target_prompt,
            negative_prompt="blurry, low quality",
            num_inference_steps=35,
            guidance_scale=7.5,
            width=1024,
            height=1024,
            # Same seed as the source pass so both start from the same noise.
            generator=torch.Generator(device).manual_seed(seed)
        ).images[0]

        target_result.save(f"masactrl_sdxl_{idx}.jpg")
        print(f"masactrl_sdxl_{idx}.jpg ÏÉùÏÑ±")
        success_count += 1

    except Exception as e:
        # Best effort: report the failure and move on to the next sample.
        print(f"Ïò§Î•ò: {str(e)}")
        traceback.print_exc()
        continue

print(f"Done: {success_count}/{len(TARGET_INDICES)} samples generated")

PIE-Bench++ Îç∞Ïù¥ÌÑ∞ÏÖã Î°úÎìú...

üîÑ [1/11] Ïù∏Îç±Ïä§ 0 Ï≤òÎ¶¨...
ÏõêÎ≥∏: a slanted mountain bicycle on the road in front of a buildin...
ÌÉÄÍ≤ü: a slanted rusty mountain motorcycle in front of a fence...
Source Pass (Íµ¨Ï°∞ Ï†ÄÏû•)...


  0%|          | 0/25 [00:00<?, ?it/s]

   ‚úÖ Source Ï†ÄÏû•: masactrl_source_0.jpg
Target Pass (Ìé∏Ïßë ÏÉùÏÑ±)...


  0%|          | 0/35 [00:00<?, ?it/s]

masactrl_sdxl_0.jpg ÏÉùÏÑ±

üîÑ [2/11] Ïù∏Îç±Ïä§ 2 Ï≤òÎ¶¨...
ÏõêÎ≥∏: a cat sitting on a wooden chair...
ÌÉÄÍ≤ü: a red dog with flowers in mouth standing on a metal chair...
Source Pass (Íµ¨Ï°∞ Ï†ÄÏû•)...


  0%|          | 0/25 [00:00<?, ?it/s]

   ‚úÖ Source Ï†ÄÏû•: masactrl_source_2.jpg
Target Pass (Ìé∏Ïßë ÏÉùÏÑ±)...


  0%|          | 0/35 [00:00<?, ?it/s]

masactrl_sdxl_2.jpg ÏÉùÏÑ±

üîÑ [3/11] Ïù∏Îç±Ïä§ 3 Ï≤òÎ¶¨...
ÏõêÎ≥∏: blue light, a black and white cat is playing with a flower...
ÌÉÄÍ≤ü: blue light, a black and white dog is playing with a yellow b...
Source Pass (Íµ¨Ï°∞ Ï†ÄÏû•)...


  0%|          | 0/25 [00:00<?, ?it/s]

   ‚úÖ Source Ï†ÄÏû•: masactrl_source_3.jpg
Target Pass (Ìé∏Ïßë ÏÉùÏÑ±)...


  0%|          | 0/35 [00:00<?, ?it/s]

masactrl_sdxl_3.jpg ÏÉùÏÑ±

üîÑ [4/11] Ïù∏Îç±Ïä§ 13 Ï≤òÎ¶¨...
ÏõêÎ≥∏: three white dumplings on brown wooden bowl...
ÌÉÄÍ≤ü: three white cupcakes on black metal wooden bowl...
Source Pass (Íµ¨Ï°∞ Ï†ÄÏû•)...


  0%|          | 0/25 [00:00<?, ?it/s]

   ‚úÖ Source Ï†ÄÏû•: masactrl_source_13.jpg
Target Pass (Ìé∏Ïßë ÏÉùÏÑ±)...


  0%|          | 0/35 [00:00<?, ?it/s]

masactrl_sdxl_13.jpg ÏÉùÏÑ±

üîÑ [5/11] Ïù∏Îç±Ïä§ 25 Ï≤òÎ¶¨...
ÏõêÎ≥∏: a woman with blue hair wearing a shirt...
ÌÉÄÍ≤ü: a woman with red hair wearing a coat...
Source Pass (Íµ¨Ï°∞ Ï†ÄÏû•)...


  0%|          | 0/25 [00:00<?, ?it/s]

   ‚úÖ Source Ï†ÄÏû•: masactrl_source_25.jpg
Target Pass (Ìé∏Ïßë ÏÉùÏÑ±)...


  0%|          | 0/35 [00:00<?, ?it/s]

masactrl_sdxl_25.jpg ÏÉùÏÑ±

üîÑ [6/11] Ïù∏Îç±Ïä§ 40 Ï≤òÎ¶¨...
ÏõêÎ≥∏: a cat sitting with beads collar...
ÌÉÄÍ≤ü: a orange cat sitting with cotton collar putting hands up...
Source Pass (Íµ¨Ï°∞ Ï†ÄÏû•)...


  0%|          | 0/25 [00:00<?, ?it/s]

   ‚úÖ Source Ï†ÄÏû•: masactrl_source_40.jpg
Target Pass (Ìé∏Ïßë ÏÉùÏÑ±)...


  0%|          | 0/35 [00:00<?, ?it/s]

masactrl_sdxl_40.jpg ÏÉùÏÑ±

üîÑ [7/11] Ïù∏Îç±Ïä§ 48 Ï≤òÎ¶¨...
ÏõêÎ≥∏: sunset over a field with clouds and a bright sky...
ÌÉÄÍ≤ü: sunset over a field with pink stars and a dark sky...
Source Pass (Íµ¨Ï°∞ Ï†ÄÏû•)...


  0%|          | 0/25 [00:00<?, ?it/s]

   ‚úÖ Source Ï†ÄÏû•: masactrl_source_48.jpg
Target Pass (Ìé∏Ïßë ÏÉùÏÑ±)...


  0%|          | 0/35 [00:00<?, ?it/s]

masactrl_sdxl_48.jpg ÏÉùÏÑ±

üîÑ [8/11] Ïù∏Îç±Ïä§ 58 Ï≤òÎ¶¨...
ÏõêÎ≥∏: purple tulips in vase...
ÌÉÄÍ≤ü: yellow roses in blue vase...
Source Pass (Íµ¨Ï°∞ Ï†ÄÏû•)...


  0%|          | 0/25 [00:00<?, ?it/s]

   ‚úÖ Source Ï†ÄÏû•: masactrl_source_58.jpg
Target Pass (Ìé∏Ïßë ÏÉùÏÑ±)...


  0%|          | 0/35 [00:00<?, ?it/s]

masactrl_sdxl_58.jpg ÏÉùÏÑ±

üîÑ [9/11] Ïù∏Îç±Ïä§ 68 Ï≤òÎ¶¨...
ÏõêÎ≥∏: photo of horses running in the woods...
ÌÉÄÍ≤ü: painting of white unicorns standing in the woods...
Source Pass (Íµ¨Ï°∞ Ï†ÄÏû•)...


  0%|          | 0/25 [00:00<?, ?it/s]

   ‚úÖ Source Ï†ÄÏû•: masactrl_source_68.jpg
Target Pass (Ìé∏Ïßë ÏÉùÏÑ±)...


  0%|          | 0/35 [00:00<?, ?it/s]

masactrl_sdxl_68.jpg ÏÉùÏÑ±

üîÑ [10/11] Ïù∏Îç±Ïä§ 102 Ï≤òÎ¶¨...
ÏõêÎ≥∏: a rose on the rail...
ÌÉÄÍ≤ü: a white tulip on the rail...
Source Pass (Íµ¨Ï°∞ Ï†ÄÏû•)...


  0%|          | 0/25 [00:00<?, ?it/s]

   ‚úÖ Source Ï†ÄÏû•: masactrl_source_102.jpg
Target Pass (Ìé∏Ïßë ÏÉùÏÑ±)...


  0%|          | 0/35 [00:00<?, ?it/s]

masactrl_sdxl_102.jpg ÏÉùÏÑ±

üîÑ [11/11] Ïù∏Îç±Ïä§ 135 Ï≤òÎ¶¨...
ÏõêÎ≥∏: the two people are standing on rocks with a fish...
ÌÉÄÍ≤ü: the two old people are sitting on boat with a fish...
Source Pass (Íµ¨Ï°∞ Ï†ÄÏû•)...


  0%|          | 0/25 [00:00<?, ?it/s]

   ‚úÖ Source Ï†ÄÏû•: masactrl_source_135.jpg
Target Pass (Ìé∏Ïßë ÏÉùÏÑ±)...


  0%|          | 0/35 [00:00<?, ?it/s]

masactrl_sdxl_135.jpg ÏÉùÏÑ±
