In [None]:
import gc
from typing import Any

import diffusers
import matplotlib.pyplot as plt
import scienceplots  # noqa B401
import seaborn as sns
import torch
from matplotlib.ticker import MaxNLocator
from PIL import Image
from tqdm import trange

from pipeline import (
    DNAEditFluxPipeline,
    FireFlowEditFluxPipeline,
    FTEditFluxPipeline,
    RFInversionEditFluxPipeline,
    RFSolverEditFluxPipeline,
    RFVanillaFluxPipeline,
)
from pipeline.common import RecordInvForCallback

# Limit diffusers logging to errors so the per-pipeline output stays readable.
diffusers.utils.logging.set_verbosity_error()

# Apply the SciencePlots "science" style before sampling the palette:
# `sns.color_palette()` with no arguments returns the *current* matplotlib
# color cycle, so it has to run after the style is applied.
plt.style.use("science")
plt.rcParams.update({"font.family": "Times New Roman"})
colors = sns.color_palette()

In [None]:
# Each entry: (display name, pipeline class, number of inference steps).
target_pipe = [
    ("RF-Vanilla", RFVanillaFluxPipeline, 25),
    ("RF-Inversion", RFInversionEditFluxPipeline, 25),
    ("RF-Solver", RFSolverEditFluxPipeline, 25),
    ("FireFlow", FireFlowEditFluxPipeline, 25),
    ("FTEdit", FTEditFluxPipeline, 25),
    ("DNAEdit", DNAEditFluxPipeline, 25),
]

# Method-specific keyword arguments forwarded to `pipe.reconstruction`.
# Methods not listed here run with no extra arguments.
PIPE_KWARGS = {
    "RF-Solver": {"with_second_order": True},
    "FireFlow": {"with_second_order": True},
    # follow the paper settings in Table 4; stop_timestep is a ratio
    "RF-Inversion": {"start_timestep": 8, "stop_timestep": 1.0},
    "FTEdit": {"fixed_point_steps": 3},
    "DNAEdit": {"start_timestep": 0},
}

# The inputs are the same for every method, so they are defined once.
SOURCE_IMG = "data/PIE-Bench_v1/annotation_images/0_random_140/000000000001.jpg"
SOURCE_PROMPT = "a round cake with orange frosting on a wooden plate"
NUM_RECON = 1  # number of chained reconstructions per method

# One color per method; assumes `colors` has at least len(target_pipe) entries.
stage_colors = {n: colors[n] for n in range(len(target_pipe))}
total_diff_list = []  # list of (method name, per-timestep mean squared difference)

for pipe_name, pipe_cls, num_inference_steps in target_pipe:
    img_list = [Image.open(SOURCE_IMG)]
    input_img = SOURCE_IMG

    print(f"Running {pipe_name}...")
    pipe = pipe_cls.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    pipe.set_progress_bar_config(disable=True)

    callback = RecordInvForCallback()
    # Copy so a callee can never mutate the shared PIPE_KWARGS entry.
    kwargs = dict(PIPE_KWARGS.get(pipe_name, {}))

    # Each pass feeds the previous reconstruction back in as the next input.
    for _ in trange(NUM_RECON, ncols=0, leave=False):
        image = pipe.reconstruction(
            input_img,
            SOURCE_PROMPT,
            guidance_scale=1,
            num_inference_steps=num_inference_steps,
            callback_on_step_end=callback,
            callback_on_step_end_tensor_inputs=callback.tensor_inputs,
            **kwargs,
        ).images[0]
        img_list.append(image)
        input_img = image

    # Compare the "inverse" and "foward" records at every timestep present in
    # both, ordered from the largest timestep to the smallest.
    # NOTE(review): the key is spelled "foward" — presumably to match
    # RecordInvForCallback; do not rename here without checking that class.
    target_timestep = sorted(
        callback.record["inverse"].keys() & callback.record["foward"].keys(), reverse=True
    )
    if not target_timestep:
        print(f"Warning: {pipe_name} recorded no timestep in both 'inverse' and 'foward'.")
    diff_list = [
        (callback.record["inverse"][t] - callback.record["foward"][t]).pow(2).mean().item()
        for t in target_timestep
    ]
    total_diff_list.append((pipe_name, diff_list))

    # Source image followed by each reconstruction; squeeze=False keeps `axes`
    # two-dimensional even when there is only a single panel.
    fig, axes = plt.subplots(1, len(img_list), figsize=(10, 4), squeeze=False)
    for ax, img in zip(axes[0], img_list):
        ax.imshow(img)
        ax.axis("off")
    plt.show()
    plt.close(fig)

    # Release GPU memory before the next pipeline is loaded. gc.collect() frees
    # objects kept alive only by reference cycles, so that empty_cache() can
    # actually hand the freed blocks back to the driver.
    del pipe, callback
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Per method, plot the mean squared difference between the "inverse" and
# "foward" records at each timestep. Each diff list is ordered from the largest
# timestep down (see the reverse sort in the previous cell).
n_steps = max((len(diff_list) for _, diff_list in total_diff_list), default=0)
every_step = 3  # label every 3rd tick, plus the last one

fig, ax = plt.subplots(figsize=(8, 4))
for pipe_idx, (pipe_name, diff_list) in enumerate(total_diff_list):
    steps = range(len(diff_list))
    color = stage_colors[pipe_idx]
    # Pass the color to both calls so the line always matches its markers,
    # instead of relying on the axes color cycle lining up with stage_colors.
    ax.plot(steps, diff_list, color=color)
    ax.scatter(steps, diff_list, color=color, alpha=0.7, label=pipe_name, s=6)

# set_xticks installs a fixed locator (one tick per step), so a separate
# integer locator would be overridden and is not needed.
ax.set_xticks(range(n_steps))
# Labels count down from n_steps to 1; intermediate ticks are left blank.
ax.set_xticklabels(
    [
        n_steps - idx if idx % every_step == 0 or idx == n_steps - 1 else ""
        for idx in range(n_steps)
    ]
)
ax.set_xlabel("Step")
ax.set_ylabel("Mean squared difference (inverse vs. forward)")
ax.legend()
plt.show()