In [None]:
import gc
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import diffusers
import matplotlib.pyplot as plt
import torch
from PIL import Image

from pipeline import (
    DNAEditFluxPipeline,
    FireFlowEditFluxPipeline,
    FlowEditFluxPipeline,
    FTEditFluxPipeline,
    MultiTurnEditFluxPipeline,
    RFInversionEditFluxPipeline,
    RFSolverEditFluxPipeline,
)
from processor.ft_editing_attn_processor import FluxAttentionReplace, P2PFlux_JointAttnProcessor2_0

# Keep notebook output readable: only ERROR-level (and above) diffusers log records are emitted.
diffusers.utils.logging.set_verbosity(diffusers.utils.logging.ERROR)

# Target prompts for the multi-turn edit, one per turn. The prompt text
# accumulates edits turn by turn (round -> square -> + chocolate sprinkles ->
# + strawberry slices). "editing_type_id" is kept only as metadata; the
# pipelines below receive just the prompt strings.
prompt_records = [
    {
        "prompt": "a round cake with frosting on a wooden plate",
        "editing_type_id": "0",
    },
    {
        "prompt": "a square cake with frosting on a wooden plate",
        "editing_type_id": "6",
    },
    {
        "prompt": "a square cake with frosting and chocolate sprinkles on a wooden plate",
        "editing_type_id": "6",
    },
    {
        "prompt": "a square cake with frosting, chocolate sprinkles, and strawberry slices on a wooden plate",
        "editing_type_id": "4",
    },
]
prompt_sequence = [record["prompt"] for record in prompt_records]

In [None]:
def h_concat_pil_images(images, background=0):
    """Concatenate PIL images side by side, left to right, into one RGB image.

    The canvas is as wide as the sum of the image widths and as tall as the
    tallest image. Each image is pasted top-aligned; any area not covered by a
    shorter image keeps the ``background`` fill.

    Args:
        images: Iterable of PIL images. It is materialized into a list first,
            so one-shot iterators such as generators work correctly (they are
            iterated more than once below).
        background: Canvas fill color, in any form ``PIL.Image.new`` accepts
            for mode "RGB". Defaults to 0 (black), the previous behavior.

    Returns:
        A new ``PIL.Image.Image`` in RGB mode.

    Raises:
        ValueError: If ``images`` is empty.
    """
    images = list(images)
    if not images:
        raise ValueError("h_concat_pil_images requires at least one image")

    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)

    canvas = Image.new("RGB", (total_width, max_height), background)

    x_offset = 0
    for img in images:
        canvas.paste(img, (x_offset, 0))
        x_offset += img.width
    return canvas

In [None]:
# Editing methods to compare: (display name, pipeline class, number of inference steps).
target_pipe = [
    ("RFInversion", RFInversionEditFluxPipeline, 28),
    ("RFSolver", RFSolverEditFluxPipeline, 28),
    ("FireFlow", FireFlowEditFluxPipeline, 28),
    ("FlowEdit", FlowEditFluxPipeline, 28),
    ("FTEdit", FTEditFluxPipeline, 28),
    ("DNAEdit", DNAEditFluxPipeline, 28),
    ("MultiTurn", MultiTurnEditFluxPipeline, 28),
]

source_img = "assets/sources/cake.jpg"
source_prompt = "a round cake with frosting on a wooden plate besides a cup"
output_dir = "assets/images/multi_turn_for_each_step_28_better"

# Per-method keyword arguments forwarded to ``pipe.multiturn``; methods without
# an entry get no extra kwargs.
multiturn_kwargs = {
    "RFInversion": {"stop_timestep": 0.25, "guidance_scale": 3.5},
    "RFSolver": {"with_second_order": True, "inject_step": 3, "guidance_scale": 2},
    "FireFlow": {"inject_step": 3, "with_second_order": True, "guidance_scale": 2},
    "MultiTurn": {
        "stop_timestep": 0.25,
        "with_second_order": True,
        "inject_step": 0,
        "attn_guidance_start_block": 11,
        "guidance_scale": 3.5,
    },
    "FlowEdit": {
        "interpolate_start_step": 0,
        "interpolate_end_step": 24,
        "source_guidance_scale": 1.5,
        "target_guidance_scale": 5.5,
    },
    "FTEdit": {"fixed_point_steps": 3, "ly_ratio": 1.0, "guidance_scale": 2},
    "DNAEdit": {"start_timestep": 4, "source_guidance_scale": 1, "target_guidance_scale": 2.5},
}


def _attach_attention_processors(pipe, pipe_name, num_inference_steps):
    """Install the method-specific attention processors on ``pipe``.

    Methods not handled here (RFInversion, FlowEdit, DNAEdit) keep the
    pipeline's default processors.
    """
    if pipe_name in ("RFSolver", "FireFlow"):
        pipe.add_processor(after_layer=0, filter_name="single_transformer_blocks")
    elif pipe_name == "MultiTurn":
        pipe.add_processor(
            after_layer=0, before_layer=37, filter_name=["single_transformer_blocks", "transformer_blocks"]
        )
    elif pipe_name == "FTEdit":
        controller = FluxAttentionReplace(
            prompts=["", ""],  # dummy prompts
            num_steps=num_inference_steps,
            attn_ratio=0.15,
            num_att_layers=37,
        )
        pipe.add_processor(
            after_layer=0,
            before_layer=36,
            filter_name="transformer_blocks",
            target_processor=P2PFlux_JointAttnProcessor2_0,
            controller=controller,
        )


def _show_row(images, title):
    """Display ``images`` in one row of axis-free subplots, then release the figure."""
    fig, axes = plt.subplots(1, len(images), figsize=(10, 4), squeeze=False)
    for ax, img in zip(axes[0], images, strict=True):
        ax.imshow(img)
        ax.axis("off")
    fig.suptitle(title)
    plt.show()
    # Close explicitly so figures don't pile up across the runs in this loop.
    plt.close(fig)


# Create the output directory up front so saving can't fail after an expensive run.
os.makedirs(output_dir, exist_ok=True)

# Read the source image once, fully into memory, and release the file handle.
with Image.open(source_img) as src:
    source_image = src.copy()

for pipe_name, pipe_cls, num_inference_steps in target_pipe:
    print(f"Running {pipe_name}...")
    pipe = pipe_cls.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    try:
        pipe.to("cuda")
        pipe.set_progress_bar_config(disable=True)
        _attach_attention_processors(pipe, pipe_name, num_inference_steps)

        edited_images = pipe.multiturn(
            source_img,
            source_prompt,
            prompt_sequence,
            num_inference_steps=num_inference_steps,
            **multiturn_kwargs.get(pipe_name, {}),
        )
        _show_row([source_image, *edited_images], title=pipe_name)
        # Only the edited turns go into the saved strip (the source image is display-only).
        h_concat_pil_images(edited_images).save(os.path.join(output_dir, f"{pipe_name}.jpg"))
    finally:
        # Drop the pipeline and collect garbage *before* emptying the CUDA cache,
        # so memory freed by gc is actually released before the next model loads.
        # Running this in ``finally`` also frees memory when a run raises.
        del pipe
        gc.collect()
        torch.cuda.empty_cache()