In [None]:
import os
import random
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import FluxKontextInpaintPipeline
from nunchaku import NunchakuFluxTransformer2DModelV2
from nunchaku.utils import get_precision
from tqdm.notebook import tqdm
from ultralytics.utils.ops import xywhn2xyxy, xyxy2xywhn
from ultralytics.utils.plotting import Annotator

In [None]:
# Inpainting prompts; one is drawn at random for every generated image.
# Every entry needs its own trailing comma: adjacent string literals without a
# comma are silently concatenated by Python into a single (nonsensical) prompt.
prompts = [
    "Add a large black suitcase on the ground",
    "Add a red backpack on the ground",
    "Add a big box on the ground",
    "Add a large sack on the ground",
    "Add a large hard-shell suitcase on the ground",
    "Add a battered leather suitcase on the ground",
    "Add a metallic silver briefcase on the ground",
    "Add a blue gym bag on the ground",
    "Add a suspicious duffle bag on the ground",
    "Add a military tactical backpack on the ground",
    "Add a bulky hiking backpack on the ground",
    "Add a dirty canvas bag on the ground",
    "Add a colorful school bag on the ground",
]

In [None]:
# Output layout: a root directory with one sub-directory each for the generated
# images, their YOLO label files, and the annotated previews.
data_dir = '../data/final'
image_dir, label_dir, annotated_dir = (
    f"{data_dir}/{leaf}" for leaf in ("images", "labels", "annotated")
)

In [None]:
# NOTE(review): presumably this cell is skipped by hand when resuming a previous
# run, so that the existing output directory keeps being used -- confirm.
#
# When the base directory already exists, switch to the first free
# "<data_dir>_<n>" sibling (n = 1, 2, ...) and re-derive the sub-directory paths.
if os.path.exists(data_dir):
    i = 1
    while os.path.exists(f"{data_dir}_{i}"):
        i += 1
    data_dir = f"{data_dir}_{i}"
    image_dir, label_dir, annotated_dir = (
        f"{data_dir}/{leaf}" for leaf in ("images", "labels", "annotated")
    )

# exist_ok=False makes os.makedirs raise FileExistsError if a directory is already there.
for new_dir in (data_dir, image_dir, label_dir, annotated_dir):
    os.makedirs(new_dir, exist_ok=False)

In [None]:
def load_flux_kontext_pipeline() -> FluxKontextInpaintPipeline:
    """
    Assemble the FLUX.1 Kontext inpainting pipeline around a Nunchaku transformer.

    The transformer checkpoint name is derived from ``get_precision()`` and its attention
    backend is set to ``'sage'``. The pipeline is loaded in bfloat16 and moved to CUDA; the
    transformer is then added to the pipeline's CPU-offload exclusion list before model CPU
    offload is enabled.

    :return: A FluxKontextInpaintPipeline with Nunchaku FLUX.1-Kontext-Dev transformer
    :rtype: FluxKontextInpaintPipeline
    """
    checkpoint = f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
    nunchaku_transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(checkpoint)
    nunchaku_transformer.set_attention_backend('sage')

    pipeline = FluxKontextInpaintPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev",
        transformer=nunchaku_transformer,
        torch_dtype=torch.bfloat16,
    )
    pipeline = pipeline.to('cuda')

    # Order matters: the exclusion must be registered before offloading is enabled.
    pipeline._exclude_from_cpu_offload.append("transformer")
    pipeline.enable_model_cpu_offload()

    return pipeline


def flux_kontext_inpaint(pipe: FluxKontextInpaintPipeline, img: Image.Image, bbox: np.ndarray,
                         prompt: str) -> Image.Image:
    """
    Inpaints an image region defined by bbox using Flux.1 Kontext.

    A rectangular mask (255 inside the box, 0 elsewhere) is built at the size of ``img`` and
    passed to the pipeline together with the prompt. The box is clamped to the image bounds
    first, so a box that sticks out of the image (e.g. after being enlarged) still yields
    the visible part of the rectangle instead of an empty or wrapped-around mask.

    A fresh random seed is drawn from ``random`` on every call.

    :param pipe: A Flux.1 Kontext inpaint pipeline
    :type pipe: FluxKontextInpaintPipeline
    :param img: Input image
    :type img: Image.Image
    :param bbox: Bounding box [x1, y1, x2, y2] in pixel coordinates of ``img``
    :type bbox: np.ndarray
    :param prompt: Inpainting prompt
    :type prompt: str
    :return: Inpainted image
    :rtype: Image.Image
    """
    w, h = img.size  # PIL reports (width, height)

    x1, y1, x2, y2 = np.asarray(bbox).astype(int)
    # Negative indices would wrap around in NumPy slicing (e.g. mask[:, -5:100] is empty),
    # so clamp every coordinate into the valid pixel range before painting the mask.
    x1, x2 = (int(v) for v in np.clip((x1, x2), 0, w))
    y1, y2 = (int(v) for v in np.clip((y1, y2), 0, h))

    mask = np.zeros((h, w), dtype=np.uint8)
    mask[y1:y2, x1:x2] = 255  # White rectangle for inpainting area

    mask_pil = Image.fromarray(mask)

    seed = random.randint(0, 2 ** 31 - 1)
    generator = torch.Generator(device='cuda').manual_seed(seed)

    inpainted_img = pipe(
        prompt=prompt,
        image=img,
        mask_image=mask_pil,
        guidance_scale=2.5,
        generator=generator,
        strength=1.0).images[0]

    return inpainted_img

In [None]:
def read_yolo_labels(txt_path: str | Path) -> np.ndarray:
    """
    Reads YOLO format labels from a text file.
    :param txt_path: Path to the YOLO label text file.
    :type txt_path: str | Path
    :return: Nx4 array of bounding boxes in normalized xywh format.
    """
    txt_path = Path(txt_path)
    if not txt_path.exists() or txt_path.stat().st_size == 0:
        return np.zeros((0, 5), dtype=np.float32)

    rows = []
    for line in txt_path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line:
            continue
        parts = line.split()
        if len(parts) != 5:
            raise ValueError(f"Bad label line in {txt_path}: {line}")
        rows.append([float(p) for p in parts])
    return np.asarray(rows, dtype=np.float32)[:, 1:]


def top_k_by_area_xywh(xywh: np.ndarray, k: int = 3) -> np.ndarray:
    """
    Keep the ``k`` boxes with the largest area (width * height), largest first.

    :param xywh: Array of bounding boxes in xywh format.
    :type xywh: np.ndarray
    :param k: Number of top boxes to select.
    :type k: int
    :return: Kx4 array of top k bounding boxes by area (0x4 for empty input).
    :rtype: np.ndarray
    """
    boxes = np.asarray(xywh)
    if boxes.size == 0:
        # Nothing to rank; normalise the empty input to a 0x4 array.
        return boxes.reshape(0, 4)

    keep = min(k, len(boxes))
    widths, heights = boxes[:, 2], boxes[:, 3]
    # argsort is ascending, so reverse it to walk from the largest area down.
    descending = np.argsort(widths * heights)[::-1]
    return boxes[descending[:keep]]

In [None]:
def box_from_diff(img1: Image.Image, img2: Image.Image, threshold: int = 25,
                  area_threshold: int = 1000) -> tuple | None:
    """
    Extract bounding box around new object in difference image.

    The absolute per-pixel difference of the two images is converted to grayscale and
    binarised with ``threshold``; the bounding rectangle of the largest external contour
    is returned if that contour's area reaches ``area_threshold``.

    :param img1: First image
    :type img1: Image.Image
    :param img2: Second image (same size and mode as ``img1``)
    :type img2: Image.Image
    :param threshold: Pixel difference threshold
    :type threshold: int
    :param area_threshold: Minimum area of detected object
    :type area_threshold: int
    :return: Bounding box (x1, y1, x2, y2) or None if no object detected
    :rtype: tuple | None
    """

    diff_image = cv2.absdiff(np.array(img1), np.array(img2))

    # Convert to grayscale if needed. Arrays built from PIL images are in RGB channel
    # order (not OpenCV's BGR), so use the RGB conversion to weight the channels correctly.
    if len(diff_image.shape) == 3:
        gray = cv2.cvtColor(diff_image, cv2.COLOR_RGB2GRAY)
    else:
        gray = diff_image

    # Apply threshold
    _, binary = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)

    # Find contours
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if not contours:
        return None

    # Only the largest changed region is considered; smaller ones are ignored.
    largest_contour = max(contours, key=cv2.contourArea)

    if cv2.contourArea(largest_contour) < area_threshold:
        return None

    x, y, w, h = cv2.boundingRect(largest_contour)

    # Convert to xyxy format
    return (x, y, x + w, y + h)

In [None]:
# Build the inpainting pipeline once; the generation loop below reuses this `pipe` object.
pipe = load_flux_kontext_pipeline()

In [None]:
# Generation loop: for every background image, remove one of the largest labelled people,
# inpaint an object in the lower part of the freed region, derive the object's bounding box
# from the image difference, and save image + YOLO label + annotated preview.
# NOTE(review): inputs are read from 'data/...' relative to the working directory while
# outputs go under data_dir ('../data/final' by default) -- confirm this is intended.
num_images = 374
failed_removals = 0
failed_adds = 0
with tqdm(total=num_images, position=1, leave=True) as pbar:
    for i in range(num_images):

        # Resume support: the saved image acts as the "done" marker for index i.
        if os.path.exists(f"{image_dir}/{i}.png"):
            pbar.update(1)
            print(f"Image {i} already processed, skipping.")
            continue

        # Load image and label (close the file handle as soon as the pixels are converted)
        with Image.open(f'data/backgrounds/{i}.png') as src:
            image = src.convert("RGB")
        img_w, img_h = image.size
        labels = read_yolo_labels(f"data/people-box/labels/{i}.txt")

        # Without any labelled person there is nothing to remove; random.choice would
        # raise on the empty candidate list, so skip explicitly instead.
        if len(labels) == 0:
            print(f"No person labels for image {i}, skipping.")
            failed_removals += 1
            pbar.update(1)
            continue

        # Grow boxes by 5% and select a random person from the top 3 largest
        labels[:, 2] *= 1.05
        labels[:, 3] *= 1.05
        top3 = top_k_by_area_xywh(labels, k=3)
        # Labels are normalized to the image, so denormalize with the image's own size.
        xyxy = xywhn2xyxy(top3, w=img_w, h=img_h)
        chosen = random.randrange(len(xyxy))
        box = xyxy[chosen]

        # Remove the selected person from the image
        remove_people = flux_kontext_inpaint(pipe, image, box, "Remove all people")

        # Validate person removal
        # According to EDA, 423 px is the smallest person box in all the data
        removed = box_from_diff(image, remove_people, area_threshold=420) is not None

        # Retry once with a different person, if available. Picking by index (instead of
        # re-drawing until the box differs) cannot loop forever on duplicate boxes.
        if not removed and len(xyxy) > 1:
            alternatives = [j for j in range(len(xyxy)) if j != chosen]
            box = xyxy[random.choice(alternatives)]
            remove_people = flux_kontext_inpaint(pipe, image, box, "Remove all people")
            removed = box_from_diff(image, remove_people, area_threshold=420) is not None

        if not removed:
            print(f"Person removal failed in image {i}, skipping.")
            failed_removals += 1
            pbar.update(1)
            continue

        # Add an object in place of the removed person. Unpack the box only now so that
        # a successful retry places the object where the person was actually removed.
        x1, y1, x2, y2 = box
        short_box = np.array([x1, y1 + 0.4 * abs(y2 - y1), x2, y2])  # Lower the box by 40%
        prompt = random.choice(prompts)
        add_object = flux_kontext_inpaint(pipe, remove_people, short_box, prompt)
        pbar.refresh()

        # Extract bounding box of the added object
        object_box = box_from_diff(remove_people, add_object)  # Calculate difference between inpainted images

        # Validate the output
        if object_box is None:
            print(f"No object detected in image {i}, skipping.")
            failed_adds += 1
            pbar.update(1)
            continue

        # The box is in pixel coordinates of the generated image, so normalize by its size.
        out_w, out_h = add_object.size
        xywhn = xyxy2xywhn(np.asarray(object_box, dtype=np.float32), w=out_w, h=out_h)
        label_line = f'0 {xywhn[0]} {xywhn[1]} {xywhn[2]} {xywhn[3]}'  # YOLO format

        # Save label
        with open(f"{label_dir}/{i}.txt", 'w', encoding="utf-8") as f:
            f.write(label_line)

        # Save annotated_dataset image for visualization
        ann = Annotator(add_object.copy())
        ann.box_label(object_box, color=(255, 0, 0))
        final = Image.fromarray(ann.result())
        final.save(f"{annotated_dir}/{i}.png")

        # Save the image last: its existence marks index i as done on resume, so an
        # interrupted run never leaves an image without its label.
        add_object.save(f"{image_dir}/{i}.png")

        pbar.update(1)
        pbar.refresh()
        print(f"Processed image {i}")


In [None]:
# Summarise the run: every image that did not fail at removal or at object insertion
# counts as successful (this includes images skipped as already processed).
failed = sum((failed_removals, failed_adds))
successful = num_images - failed
failure_rate = failed / num_images
for summary_line in (
    f"Successful images: {successful}, Failed images: {failed}",
    f"Processing completed with failure rate: {failure_rate:.2%}",
    f"Failed person removals: {failed_removals}, Failed object additions: {failed_adds}",
):
    print(summary_line)