In [1]:
from pathlib import Path
import json
from collections import Counter

import ipywidgets as widgets
from IPython.display import display
from PIL import Image, ImageOps
import io

DATASET_PATH = Path("dataset_qwen_pe_test.json")
# Root directory that every image / mask path stored in an item is resolved against.
ASSET_ROOT = Path("pico-banana-400k-subject_driven/openimages")
# Pre-filled destination in the "Save to" box of the save cell.
DEFAULT_OUTPUT = Path("dataset_qwen_pe_test_filtered.json")

# Read with an explicit encoding: items carry free-text prompts, and relying on
# the platform default codec (e.g. cp1252 on Windows) can fail on non-ASCII text.
with open(DATASET_PATH, "r", encoding="utf-8") as f:
    DATA = json.load(f)

print(f"Loaded {len(DATA)} items from {DATASET_PATH}")
# Per-edit_type item counts, most frequent first.
counts = Counter(item.get("edit_type", "") for item in DATA)
for etype, count in counts.most_common():
    print(f"{etype}: {count}")

Loaded 1444 items from dataset_qwen_pe_test.json
Remove an existing object: 482
Replace one object category with another: 482
Add a new object to the scene: 480


In [2]:
# Review state shared by the callbacks below: dataset index -> "keep" / "drop".
# Indices absent from the dict count as undecided.
decisions = {}
status = widgets.HTML()      # one-line summary of the current decision and tallies
output = widgets.Output()    # prompt text and image tiles for the current item

# Selects which DATA item is shown; the upper bound is clamped to 0 so the
# slider can still be constructed when the dataset is empty.
index_slider = widgets.IntSlider(
    description="Index",
    min=0,
    max=max(len(DATA) - 1, 0),
    value=0,
    step=1,
    continuous_update=False,
)

# Opacity multiplier applied to the mask overlay drawn on every frame.
alpha_slider = widgets.FloatSlider(
    description="Mask alpha",
    min=0.0,
    max=1.0,
    value=0.5,
    step=0.05,
    readout_format=".2f",
    continuous_update=False,
)

def load_image(rel_path: str):
    """Load an asset image and return it re-encoded as PNG bytes.

    Args:
        rel_path: Path of the image relative to ``ASSET_ROOT``.

    Returns:
        ``(png_bytes, None)`` on success, or ``(None, message)`` when the file
        is missing or cannot be decoded.
    """
    path = ASSET_ROOT / rel_path
    if not path.exists():
        return None, f"Missing: {path}"
    try:
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle as soon as the pixels have been converted,
        # instead of relying on Pillow / the garbage collector to do it.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        buf = io.BytesIO()
        rgb.save(buf, format="PNG")
        return buf.getvalue(), None
    except Exception as exc:  # noqa: BLE001 - report any decode error as a message
        return None, f"Error reading {path}: {exc}"

def load_pil(rel_path: str):
    """Load an asset image as an in-memory RGBA PIL image.

    Args:
        rel_path: Path of the image relative to ``ASSET_ROOT``.

    Returns:
        ``(image, None)`` on success, or ``(None, message)`` when the file is
        missing or cannot be decoded.
    """
    path = ASSET_ROOT / rel_path
    if not path.exists():
        return None, f"Missing: {path}"
    try:
        # convert() materialises a new image, so the source file can be closed
        # immediately when the with-block exits.
        with Image.open(path) as img:
            return img.convert("RGBA"), None
    except Exception as exc:  # noqa: BLE001 - report any decode error as a message
        return None, f"Error reading {path}: {exc}"

def overlay_with_mask(
    base_img: Image.Image,
    mask_img: Image.Image,
    alpha: float,
    color: tuple = (255, 0, 0),
):
    """Tint ``base_img`` with ``color`` where ``mask_img`` is bright; return PNG bytes.

    Args:
        base_img: Image to draw on (converted to RGBA for compositing).
        mask_img: Mask whose grayscale intensity drives the per-pixel tint
            opacity. It is resized to ``base_img.size`` when the sizes differ.
        alpha: Global opacity multiplier, clamped to [0, 1] so the scaled
            per-pixel alpha stays inside the 0-255 range of an "L" image.
        color: RGB tint for masked pixels; defaults to red.

    Returns:
        PNG-encoded bytes of the composited image.
    """
    alpha = min(max(float(alpha), 0.0), 1.0)
    if base_img.size != mask_img.size:
        mask_img = mask_img.resize(base_img.size)
    mask_gray = mask_img.convert("L")
    # Each mask intensity (0-255) scaled by alpha becomes the overlay's opacity.
    mask_scaled = mask_gray.point(lambda p: int(p * alpha))
    overlay = Image.new("RGBA", base_img.size, (*color[:3], 0))
    overlay.putalpha(mask_scaled)
    blended = Image.alpha_composite(base_img.convert("RGBA"), overlay)
    buf = io.BytesIO()
    blended.save(buf, format="PNG")
    return buf.getvalue()

def summarize_decisions():
    """Tally review decisions over the whole dataset.

    Returns:
        ``(keep, drop, undecided)`` where ``undecided`` is the number of DATA
        items not marked "keep" or "drop".
    """
    tally = Counter(decisions.values())
    n_keep, n_drop = tally["keep"], tally["drop"]
    return n_keep, n_drop, len(DATA) - n_keep - n_drop

def render(idx: int):
    """Redraw the status line and the image tiles for dataset item ``idx``.

    Each frame (the item's ``image`` plus every ``edit_image`` entry) is shown
    with the item's ``back_mask`` tinted on top when that mask loads. Load and
    render failures are listed as text below the tiles instead of raising.
    """
    # wait=True keeps the previous tiles until new output arrives, avoiding a
    # blank flash every time the slider moves or a decision is recorded.
    output.clear_output(wait=True)

    keep, drop, undecided = summarize_decisions()
    decision = decisions.get(idx, "undecided")
    status.value = (
        f"Current decision: <b>{decision}</b>. "
        f"Keep {keep} / Drop {drop} / Undecided {undecided}."
    )

    if not 0 <= idx < len(DATA):
        # Covers an empty dataset, where the slider still offers index 0.
        with output:
            print(f"No item at index {idx} (dataset has {len(DATA)} items).")
        return
    item = DATA[idx]

    mask_img = None
    mask_err = None
    mask_path = item.get("back_mask")
    if mask_path:
        mask_img, mask_err = load_pil(mask_path)

    # Frames to display as (label, path relative to ASSET_ROOT).
    frames = []
    if item.get("image"):
        frames.append(("image+mask", item["image"]))
    edit_images = item.get("edit_image") or []
    if isinstance(edit_images, str):
        # Tolerate a single path stored as a bare string; iterating a str
        # would otherwise produce one bogus frame per character.
        edit_images = [edit_images]
    for j, rel in enumerate(edit_images):
        frames.append((f"edit_image[{j}]+mask", rel))

    with output:
        print(f"#{idx} edit_type={item.get('edit_type', '')}")
        print(item.get("prompt", ""))
        print()

        tiles = []
        missing = [mask_err] if mask_err else []
        for label, rel in frames:
            # load_pil yields a decoded RGBA image, so each frame is decoded
            # once and PNG-encoded once for the widget.
            base_img, err = load_pil(rel)
            if base_img is None:
                missing.append(err or f"{label}: could not load")
                continue
            try:
                if mask_img is not None:
                    png = overlay_with_mask(base_img, mask_img, alpha_slider.value)
                else:
                    buf = io.BytesIO()
                    base_img.save(buf, format="PNG")
                    png = buf.getvalue()
                img_widget = widgets.Image(value=png, format="png")
                tiles.append(widgets.VBox([widgets.Label(label), img_widget]))
            except Exception as exc:  # noqa: BLE001 - keep rendering the other frames
                missing.append(f"{label}: render failed ({exc})")

        if tiles:
            display(
                widgets.GridBox(
                    tiles,
                    layout=widgets.Layout(
                        grid_template_columns="repeat(2, minmax(320px, 1fr))",
                        grid_gap="12px",
                    ),
                )
            )
        if missing:
            print("\n".join(missing))

def mark(decision: str):
    """Record ``decision`` for the item under the index slider, then redraw it."""
    current = index_slider.value
    decisions[current] = decision
    render(current)

def clear_mark():
    """Forget any decision for the item under the index slider, then redraw it."""
    current = index_slider.value
    if current in decisions:
        del decisions[current]
    render(current)

# Decision buttons: record / clear the decision for the currently shown index.
keep_btn = widgets.Button(description="Keep", button_style="success")
drop_btn = widgets.Button(description="Drop", button_style="danger")
clear_btn = widgets.Button(description="Clear", button_style="warning")

# on_click passes the clicked Button to the handler; it is not needed here.
keep_btn.on_click(lambda _: mark("keep"))
drop_btn.on_click(lambda _: mark("drop"))
clear_btn.on_click(lambda _: clear_mark())

def on_index_change(change):
    """Index-slider observer: redraw the newly selected item."""
    if change.get("name") != "value":
        return
    render(change["new"])

def on_alpha_change(change):
    """Alpha-slider observer: redraw the current item with the new mask opacity."""
    if change.get("name") != "value":
        return
    render(index_slider.value)

# Lay out the controls, hook up the slider observers, then draw the first item.
controls = widgets.HBox([keep_btn, drop_btn, clear_btn, alpha_slider])
display(index_slider, controls, status, output)

index_slider.observe(on_index_change, names="value")
alpha_slider.observe(on_alpha_change, names="value")

render(index_slider.value)

IntSlider(value=0, continuous_update=False, description='Index', max=1443)

HBox(children=(Button(button_style='success', description='Keep', style=ButtonStyle()), Button(button_style='d…

HTML(value='')

Output()

In [None]:
save_path = widgets.Text(value=str(DEFAULT_OUTPUT), description="Save to")
save_btn = widgets.Button(description="Save filtered", button_style="success")
save_output = widgets.Output()

# Filtering policy applied by on_save:
#   True  -> save every item that is NOT explicitly marked "drop"
#            (undecided items are kept).
#   False -> save only items explicitly marked "keep"
#            (undecided items are discarded).
KEEP_BY_DEFAULT = True

def on_save(_btn):
    """Write the filtered dataset to the path in the "Save to" box.

    Which items survive is decided by ``KEEP_BY_DEFAULT`` and ``decisions``.
    All feedback, including failures, is printed into ``save_output``. Errors
    raised inside a button callback would otherwise never reach the notebook.
    """
    if KEEP_BY_DEFAULT:
        filtered = [item for i, item in enumerate(DATA) if decisions.get(i) != "drop"]
    else:
        filtered = [item for i, item in enumerate(DATA) if decisions.get(i) == "keep"]

    with save_output:
        save_output.clear_output()

        raw_path = save_path.value.strip()
        if not raw_path:
            print("Enter an output path before saving.")
            return
        path = Path(raw_path).expanduser()

        # Never clobber the unfiltered input: it is the only copy of the
        # items that are being dropped.
        if path.resolve() == Path(DATASET_PATH).resolve():
            print(f"Refusing to overwrite the source dataset {DATASET_PATH}; choose another path.")
            return

        try:
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(json.dumps(filtered, indent=2), encoding="utf-8")
        except OSError as exc:
            print(f"Failed to save to {path}: {exc}")
            return

        print(f"Saved {len(filtered)} items to {path}")
        print(f"Dropped {len(DATA) - len(filtered)} items")

# Show the save controls and route clicks on the button to on_save.
display(widgets.HBox(children=[save_path, save_btn]), save_output)
save_btn.on_click(on_save)