In [2]:
import numpy
import torch
from tqdm.notebook import tqdm
from diffusers import FluxKontextInpaintPipeline
from nunchaku import NunchakuFluxTransformer2DModelV2
from nunchaku.utils import get_precision, get_gpu_memory
import numpy as np
from PIL import Image
from pathlib import Path
import random
from ultralytics.utils.ops import xywhn2xyxy, xyxy2xywh
from ultralytics.utils.plotting import Annotator, colors

In [37]:
# Object-insertion instructions; the generation loop picks one at random per image.
prompts = [
    f"Add {item} on the ground"
    for item in ("a large black suitcase", "a red backpack", "a big box", "a large sack")
]

In [21]:
def load_flux_kontext_pipeline(offload_threshold_gb: float = 18):
    """Build the FLUX.1 Kontext inpainting pipeline with the Nunchaku-quantized transformer.

    CPU offloading is enabled in both branches. The pipeline is deliberately NOT moved
    to CUDA first: the offload helpers manage device placement themselves, and a prior
    `.to("cuda")` loads every component onto the GPU at once, which defeats offloading
    on low-VRAM cards.

    Args:
        offload_threshold_gb: if `get_gpu_memory()` reports more than this, use
            model-level CPU offloading; otherwise use sequential (per-module) offloading
            with the transformer excluded from the offload hooks. The unit is whatever
            `get_gpu_memory()` returns (presumably GiB — TODO confirm).

    Returns:
        The configured `FluxKontextInpaintPipeline`.
    """
    transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
        f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
    )

    pipe = FluxKontextInpaintPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )

    if get_gpu_memory() > offload_threshold_gb:
        pipe.enable_model_cpu_offload()
    else:
        # Low VRAM: offload the other components layer by layer, but keep the
        # quantized transformer out of the sequential-offload hooks.
        pipe._exclude_from_cpu_offload.append("transformer")
        pipe.enable_sequential_cpu_offload()

    return pipe


def _rect_mask(h: int, w: int, bbox: numpy.ndarray) -> numpy.ndarray:
    """Return an (h, w) uint8 mask that is 255 inside `bbox` = [x1, y1, x2, y2] and 0 elsewhere."""
    x1, y1, x2, y2 = np.asarray(bbox, dtype=float)[:4]
    # floor/ceil so the mask covers the whole box; clip so coordinates outside the
    # frame don't become negative slice indices (which wrap and can empty the mask).
    x1, x2 = int(np.clip(np.floor(x1), 0, w)), int(np.clip(np.ceil(x2), 0, w))
    y1, y2 = int(np.clip(np.floor(y1), 0, h)), int(np.clip(np.ceil(y2), 0, h))
    mask = np.zeros((h, w), dtype=np.uint8)
    mask[y1:y2, x1:x2] = 255  # white rectangle = area to inpaint
    return mask


def flux_kontext_inpaint(pipe: FluxKontextInpaintPipeline, img: Image.Image, bbox: numpy.ndarray,
                         prompt: str, seed: int | None = None, guidance_scale: float = 2.5,
                         strength: float = 1.0) -> Image.Image:
    """Inpaint the rectangle `bbox` of `img` following `prompt`.

    Args:
        pipe: pipeline returned by `load_flux_kontext_pipeline()`.
        img: input image.
        bbox: pixel-space box [x1, y1, x2, y2]. It is clipped to the image, so boxes
            that slightly exceed the frame are fine.
        prompt: edit instruction passed to the pipeline.
        seed: seed for the CUDA generator; drawn from `random` when None, so calling
            `random.seed(...)` beforehand makes runs reproducible.
        guidance_scale: forwarded to the pipeline (original default 2.5).
        strength: forwarded to the pipeline (original default 1.0).

    Returns:
        The inpainted image. If the pipeline produced a different resolution, the image
        is resized back to `img.size` so pixel coordinates stay valid for follow-up edits.
    """
    w, h = img.size
    mask_pil = Image.fromarray(_rect_mask(h, w, bbox))

    if seed is None:
        seed = random.randint(0, 2 ** 31 - 1)
    generator = torch.Generator(device='cuda').manual_seed(seed)

    inpainted_img = pipe(
        prompt=prompt,
        image=img,
        mask_image=mask_pil,
        guidance_scale=guidance_scale,
        generator=generator,
        strength=strength,
    ).images[0]

    if inpainted_img.size != img.size:
        # NOTE(review): the Kontext pipeline may generate at its own preferred resolution.
        inpainted_img = inpainted_img.resize(img.size, Image.Resampling.LANCZOS)

    return inpainted_img

In [22]:
def read_yolo_labels(txt_path: str | Path) -> np.ndarray:
    """
    Returns Nx5 float array: [cls, xc, yc, w, h] (normalized).
    Empty file -> shape (0, 5).
    """
    txt_path = Path(txt_path)
    if not txt_path.exists() or txt_path.stat().st_size == 0:
        return np.zeros((0, 5), dtype=np.float32)

    rows = []
    for line in txt_path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line:
            continue
        parts = line.split()
        if len(parts) != 5:
            raise ValueError(f"Bad label line in {txt_path}: {line}")
        rows.append([float(p) for p in parts])
    return np.asarray(rows, dtype=np.float32)[:, 1:]


def top_k_by_area_xywh(xywh: np.ndarray, k: int = 3) -> np.ndarray:
    """Select the k boxes with the largest area.

    Args:
        xywh: (N, 4) array of [x, y, w, h] boxes; coordinates may be absolute or
            normalized (the area ranking is the same either way).
        k: maximum number of boxes to return; k <= 0 returns no boxes.

    Returns:
        Up to k rows of `xywh`, largest area first; boxes with equal area keep their
        input order.
    """
    xywh = np.asarray(xywh)
    if xywh.size == 0:
        return xywh.reshape(0, 4)
    if k <= 0:
        # A negative k would otherwise slice as "all but the last |k|" rows.
        return xywh[:0]

    # float64 so negation is safe for any numeric dtype (e.g. unsigned ints)
    areas = xywh[:, 2].astype(np.float64) * xywh[:, 3]
    # stable ascending sort of negated areas == descending with ties in input order
    order = np.argsort(-areas, kind="stable")
    return xywh[order[:k]]

In [None]:
# Build the inpainting pipeline once; the generation loop reuses `pipe` for every image.
pipe = load_flux_kontext_pipeline()

In [None]:
# For each background: erase the people inside one of the three largest person boxes,
# then insert a random object in the lower part of that box and save a YOLO label for it.
N_IMAGES = 194             # presumably backgrounds are 0.png .. 193.png (original range)
BOX_SCALE = 1.05           # grow person boxes slightly so the removal mask covers edges
OBJECT_TOP_FRACTION = 0.4  # the new object fills the bottom 60% of the chosen box
SEED = 0                   # seeds `random`, which also drives the per-call generator seeds

random.seed(SEED)
images_dir = Path("dataset/images")
# Assumed YOLO layout: dataset/images/X.png <-> dataset/labels/X.txt — TODO confirm.
labels_dir = Path("dataset/labels")
images_dir.mkdir(parents=True, exist_ok=True)
labels_dir.mkdir(parents=True, exist_ok=True)

for image_id in tqdm(range(N_IMAGES), desc="augmenting"):
    image = Image.open(f'dataset/backgrounds/{image_id}.png').convert("RGB")
    img_w, img_h = image.size

    labels = read_yolo_labels(f"dataset/people-box/labels/{image_id}.txt")
    if len(labels) == 0:
        continue  # no person boxes -> nothing to replace (random.choice would raise)
    labels[:, 2] *= BOX_SCALE
    labels[:, 3] *= BOX_SCALE

    top3 = top_k_by_area_xywh(labels, k=3)
    xyxy = xywhn2xyxy(top3, w=img_w, h=img_h)
    # Scaled boxes can spill past the frame; keep them inside the image.
    xyxy[:, [0, 2]] = np.clip(xyxy[:, [0, 2]], 0, img_w)
    xyxy[:, [1, 3]] = np.clip(xyxy[:, [1, 3]], 0, img_h)
    box = random.choice(xyxy)

    remove_people = flux_kontext_inpaint(pipe, image, box, "Remove all people from the image")
    if remove_people.size != image.size:
        # The pipeline may generate at its own resolution; keep `box` coordinates valid.
        remove_people = remove_people.resize(image.size, Image.Resampling.LANCZOS)

    x1, y1, x2, y2 = box
    short_box = np.array([x1, y1 + OBJECT_TOP_FRACTION * abs(y2 - y1), x2, y2])
    prompt = random.choice(prompts)
    add_object = flux_kontext_inpaint(pipe, remove_people, short_box, prompt)
    if add_object.size != image.size:
        add_object = add_object.resize(image.size, Image.Resampling.LANCZOS)
    add_object.save(images_dir / f"{image_id}.png")

    # YOLO label line: class 0, then normalized center-x, center-y, width, height.
    sx1, sy1, sx2, sy2 = short_box
    label_line = (
        f"0 {(sx1 + sx2) / 2 / img_w:.6f} {(sy1 + sy2) / 2 / img_h:.6f} "
        f"{(sx2 - sx1) / img_w:.6f} {(sy2 - sy1) / img_h:.6f}\n"
    )
    (labels_dir / f"{image_id}.txt").write_text(label_line, encoding="utf-8")
