In [None]:
import gc
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import diffusers
import matplotlib.pyplot as plt
import scienceplots  # noqa B401
import seaborn as sns
import torch
from matplotlib.ticker import MaxNLocator

from pipeline import FireFlowEditFluxPipeline, RFVanillaFluxPipeline
from pipeline.common import RecordInvForCallback

# Only surface errors from diffusers' logger.
diffusers.utils.logging.set_verbosity_error()

# Apply the "science" style sheet first (presumably made available by the
# `scienceplots` import above), then override the font family on top of it.
plt.style.use("science")
plt.rcParams.update({"font.family": "Times New Roman"})

# Seaborn's current palette; indexed per curve when plotting below.
colors = sns.color_palette()

In [None]:
# For each pipeline: run a reconstruction of one source image while a callback
# records the predicted tensors, then plot three mean-squared-difference curves
# built from the recorded "inverse" and "foward" entries.
#
# Each entry is (display name, pipeline class, number of inference steps).
target_pipe = [
    ("RF-Vanilla", RFVanillaFluxPipeline, 25),
    ("FireFlow", FireFlowEditFluxPipeline, 25),
]

# Loop-invariant inputs: the same image/prompt pair is used for every pipeline.
source_img = "data/PIE-Bench_v1/annotation_images/0_random_140/000000000001.jpg"
source_prompt = "a round cake with orange frosting on a wooden plate"


def _mse(lhs, rhs):
    """Return the mean squared difference between two tensors as a Python float."""
    return (lhs - rhs).pow(2).mean().item()


def _mse_to_first(record):
    """Return the MSE between the entry with the largest key and each other entry.

    Keys are visited in descending order; the entry at the largest key is the
    reference and is not compared against itself. An empty record yields an
    empty list.
    """
    keys = sorted(record, reverse=True)
    if not keys:
        return []
    reference = record[keys[0]]
    return [_mse(reference, record[key]) for key in keys[1:]]


for pipe_name, pipe_cls, num_inference_steps in target_pipe:
    print(f"Running {pipe_name}...")
    is_fireflow = pipe_name == "FireFlow"

    # Instantiate the class listed in `target_pipe` (previously this always
    # built FireFlowEditFluxPipeline, so the RF-Vanilla run was never vanilla).
    pipe = pipe_cls.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    if is_fireflow:
        pipe.add_processor(after_layer=0, before_layer=37, filter_name="single_transformer_blocks")
    pipe.set_progress_bar_config(disable=True)

    # FireFlow records its "mid_noise_pred" tensor; the vanilla pipeline
    # records "noise_pred".
    tensor_name = "mid_noise_pred" if is_fireflow else "noise_pred"
    callback = RecordInvForCallback(target_tensor_name=tensor_name, target_key=[tensor_name, "inverse"])
    kwargs = {"with_second_order": True} if is_fireflow else {}

    image = pipe.reconstruction(
        source_img,
        source_prompt,
        guidance_scale=1,
        num_inference_steps=num_inference_steps,
        callback_on_step_end=callback,
        callback_on_step_end_tensor_inputs=callback.tensor_inputs,
        **kwargs,
    ).images[0]

    # NOTE(review): "foward" (sic) is the key this code reads from
    # `callback.record`; keep the spelling in sync with RecordInvForCallback.
    inverse_record = callback.record["inverse"]
    forward_record = callback.record["foward"]

    # Inverse vs. forward prediction at every timestep present in both records.
    target_timestep = sorted(inverse_record.keys() & forward_record.keys(), reverse=True)
    inv_for_diff_list = [_mse(inverse_record[t], forward_record[t]) for t in target_timestep]

    # Drift of each pass relative to its own entry at the largest timestep key.
    inv_inv_diff_list = _mse_to_first(inverse_record)
    for_for_diff_list = _mse_to_first(forward_record)

    fig = plt.figure(figsize=(8, 4))
    ax = fig.add_subplot(1, 1, 1)

    total_diff_list = [
        (inv_for_diff_list, r"$v^{i}_{t_i} - v^{f}_{t_i}$"),
        (inv_inv_diff_list, r"$v^{i}_{t_0} - v^{i}_{t_i}$"),
        (for_for_diff_list, r"$v^{f}_{t_0} - v^{f}_{t_i}$"),
    ]
    for diff_idx, (diff_list, name_of_diff_list) in enumerate(total_diff_list):
        x_positions = range(len(diff_list))
        curve_color = colors[diff_idx]
        ax.plot(x_positions, diff_list, color=curve_color)
        ax.scatter(
            x_positions,
            diff_list,
            alpha=0.7,
            color=curve_color,
            label=name_of_diff_list,
            s=6,
        )

    # The tick axis is sized from the last plotted series, stated explicitly
    # here instead of relying on the loop variable leaking out of the loop.
    # NOTE(review): the inverse-vs-forward series can be one point longer than
    # this; confirm whether the ticks should cover that extra point.
    num_ticks = len(for_for_diff_list)
    every_step = 3
    ax.set_xticks(list(range(num_ticks)))
    # Label ticks in descending order, showing every `every_step`-th tick plus
    # the last one; the remaining ticks get an empty label.
    ax.set_xticklabels(
        [
            num_ticks - tick_idx if tick_idx % every_step == 0 or tick_idx == num_ticks - 1 else ""
            for tick_idx in range(num_ticks)
        ]
    )
    ax.legend()
    plt.show()

    # Drop every reference to the pipeline and the recorded tensors, collect
    # garbage first, and only then ask the CUDA caching allocator to release
    # its unused blocks so the next pipeline can load.
    del pipe
    del callback
    del inverse_record
    del forward_record
    gc.collect()
    torch.cuda.empty_cache()