In [1]:
import gc
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import diffusers
import matplotlib.pyplot as plt
import torch
from PIL import Image

from pipeline import (
    FireFlowEditFluxPipeline,
    FlowEditFluxPipeline,
    FTEditFluxPipeline,
    MultiTurnEditFluxPipeline,
    RFInversionEditFluxPipeline,
    RFSolverEditFluxPipeline,
)
from processor.ft_editing_attn_processor import FluxAttentionReplace, P2PFlux_JointAttnProcessor2_0

# Raise diffusers' log threshold to ERROR so routine info/warning messages
# from the library do not clutter the cell outputs below.
diffusers.utils.logging.set_verbosity_error()

# Ordered chain of edit targets. Every entry carries the target "prompt" and an
# "editing_type_id" kept as a string; only the "prompt" field is read later in
# this notebook.
# NOTE(review): the ids presumably refer to PIE-Bench editing categories -- confirm.
prompt_sequence = [
    {"prompt": target_text, "editing_type_id": type_id}
    for target_text, type_id in (
        ("a square cake with orange frosting on a wooden plate", "0"),
        ("a square cake with orange frosting on a glass plate", "7"),
        ("a square cake with orange frosting and chocolate sprinkles on a glass plate", "2"),
        ("a square cake with green frosting and chocolate sprinkles on a glass plate", "6"),
        ("a square cake with green frosting and chocolate sprinkles on a marble plate", "7"),
        ("a square cake with green frosting and strawberry slices on a marble plate", "4"),
    )
]

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Editing methods to compare: (display label, pipeline class, number of inference steps).
target_pipe = [
    ("RFInversion", RFInversionEditFluxPipeline, 15),
    ("RFSolver", RFSolverEditFluxPipeline, 15),
    ("FireFlow", FireFlowEditFluxPipeline, 15),
    ("FlowEdit", FlowEditFluxPipeline, 15),
    ("FTEdit", FTEditFluxPipeline, 15),
    ("MultiTurn", MultiTurnEditFluxPipeline, 15),
]

# Derive the plain prompt strings under a NEW name. Rebinding `prompt_sequence`
# itself would turn it into a list of strings, and re-running this cell would
# then fail when indexing a string with "prompt".
target_prompts = [entry["prompt"] for entry in prompt_sequence]
source_img = "data/PIE-Bench_v1/annotation_images/0_random_140/000000000001.jpg"
source_prompt = "a round cake with orange frosting on a wooden plate"

# Load the source image once (it is the same for every pipeline) and close the
# underlying file right away; `copy()` forces the pixel data into memory.
with Image.open(source_img) as source_file:
    source_image = source_file.copy()

for pipe_name, pipe_cls, num_inference_steps in target_pipe:
    print(f"Running {pipe_name}...")
    controller = None  # only FTEdit creates one; bound here so cleanup is uniform
    # NOTE(review): the captured output shows a "`torch_dtype` is deprecated"
    # warning while components load; the argument is kept as-is because the
    # source of that warning is not visible from this notebook.
    pipe = pipe_cls.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    try:
        pipe.to("cuda")
        pipe.set_progress_bar_config(disable=True)

        # --- attention-processor setup (method specific) ---
        if pipe_name in ["RFSolver", "FireFlow"]:
            pipe.add_processor(after_layer=0, before_layer=37, filter_name="single_transformer_blocks")
        elif pipe_name == "MultiTurn":
            pipe.add_processor(
                after_layer=0,
                before_layer=37,
                filter_name=["single_transformer_blocks", "transformer_blocks"],
            )
        elif pipe_name == "FTEdit":
            controller = FluxAttentionReplace(
                prompts=["", ""],  # dummy prompts
                num_steps=num_inference_steps,
                attn_ratio=0.15,
                num_att_layers=37,
            )
            # NOTE(review): after_layer/before_layer are reversed here (37, 0)
            # relative to the other pipelines (0, 37) -- confirm this is the
            # intended range for FTEdit and not a swap.
            pipe.add_processor(
                after_layer=37,
                before_layer=0,
                filter_name="transformer_blocks",
                target_processor=P2PFlux_JointAttnProcessor2_0,
                controller=controller,
            )

        # --- sampling arguments (method specific) ---
        kwargs = {"guidance_scale": 3.5}
        if pipe_name == "RFInversion":
            kwargs["stop_timestep"] = 0.25
        elif pipe_name in ["RFSolver", "FireFlow"]:
            kwargs["with_second_order"] = True
            kwargs["inject_step"] = 2
        elif pipe_name == "MultiTurn":
            kwargs["stop_timestep"] = 0.25
            kwargs["with_second_order"] = True
            kwargs["inject_step"] = 0
            kwargs["attn_guidance_start_block"] = 11
        elif pipe_name == "FlowEdit":
            # FlowEdit is driven by separate source/target guidance scales
            # instead of the single `guidance_scale`.
            del kwargs["guidance_scale"]
            kwargs["interpolate_start_step"] = 0
            # NOTE(review): 15 coincides with num_inference_steps for every
            # entry in target_pipe -- confirm whether it should track it.
            kwargs["interpolate_end_step"] = 15
            kwargs["source_guidance_scale"] = 1.5
            kwargs["target_guidance_scale"] = 5.5
        elif pipe_name == "FTEdit":
            kwargs["fixed_point_steps"] = 3
            kwargs["ly_ratio"] = 1.0

        edited_images = pipe.multiturn(
            source_img,
            source_prompt,
            target_prompts,
            num_inference_steps=num_inference_steps,
            **kwargs,
        )

        # Show the source image followed by every edit turn, left to right.
        image_list = [source_image] + edited_images
        # squeeze=False keeps `axes` two-dimensional regardless of panel count.
        fig, axes = plt.subplots(1, len(image_list), figsize=(10, 4), squeeze=False)
        fig.suptitle(pipe_name)  # label the row so the figure stands on its own
        for ax, image in zip(axes[0], image_list):
            ax.imshow(image)
            ax.axis("off")
        plt.show()
        plt.close(fig)
    finally:
        # Always drop the model (and the FTEdit controller its processors
        # reference) before loading the next one, even if this pipeline raised;
        # collect garbage first so the cached CUDA blocks can actually be freed.
        del pipe, controller
        gc.collect()
        torch.cuda.empty_cache()

Running RFInversion...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 69.66it/s]
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 31.73it/s]it/s]
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00,  9.86it/s]
