In [1]:
import os
import sys
from typing import List, Tuple
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import torch
from torchvision.transforms.functional import to_tensor
import accelerate
from pathlib import Path
root_dir = Path().resolve()
sys.path.append(root_dir)
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
from omnigen2.utils.img_util import create_collage

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
from typing import List, Union
from PIL import Image, ImageOps

def preprocess(input_image_path: Union[str, List[str], None] = None) -> List[Image.Image]:
    """
    Preprocess the input images by:
    - Accepting a single path, list of paths, or a directory
    - Loading only common image files
    - Correcting orientation via EXIF
    - Converting to 3‑channel RGB (drops alpha)
    """
    if input_image_path is None:
        return []

    # Normalize to a list of paths
    if isinstance(input_image_path, str):
        paths = [input_image_path]
    else:
        paths = input_image_path

    images: List[Image.Image] = []
    for p in paths:
        if os.path.isdir(p):
            for fname in os.listdir(p):
                if fname.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".gif")):
                    img = Image.open(os.path.join(p, fname))
                    images.append(img)
        else:
            img = Image.open(p)
            images.append(img)

    # EXIF transpose + strip alpha channel
    processed = []
    for img in images:
        img = ImageOps.exif_transpose(img).convert("RGB")
        processed.append(img)

    return processed


**Pipeline Initialization**

In [3]:
accelerator = accelerate.Accelerator()

model_path="OmniGen2/OmniGen2"
pipeline = OmniGen2Pipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    token="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",
)
pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to(accelerator.device, dtype=torch.bfloat16)

Couldn't connect to the Hub: 401 Client Error: Unauthorized for url: https://huggingface.co/api/models/OmniGen2/OmniGen2 (Request ID: Root=1-6884c7d2-4da896402310b7302e2b0e2c;b5d1a93e-1672-4854-ae4a-c0b7c02ad7aa)

Invalid credentials in Authorization header.
Will try to load from local cache.
Keyword arguments {'trust_remote_code': True} are not expected by OmniGen2Pipeline and will be ignored.
Loading pipeline components...:   0%|                                                                                                                                          | 0/5 [00:00<?, ?it/s]
Loading checkpoint shards:   0%|                                                                                                                                               | 0/2 [00:00<?, ?it/s][A
Loading checkpoint shards:  50%|███████████████████████████████████████████████████████████████████▌                                                                   | 1/2 [00:00<00:00,  2.86it/s][A
Lo

**Editing with instruction**

In [None]:

#!/usr/bin/env python3
import csv, torch
from pathlib import Path

# --- User‑level config ----------------------------------------
COLORS = ["red","green","blue","yellow","orange",
          "purple","pink","brown","black","gray"]

NEG_PROMPT = (
    "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, "
    "mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, "
    "messy drawing, broken legs, censor, censored, censor_bar"
)

# Root of your synthetic_dataset folder
SYNTHETIC_DIR = Path(
    "/home/patrick/ssd/discover-hidden-visual-concepts/PatrickProject/"
    "ImageEditing/third_party/OmniGen2/synthetic_dataset"
)

# CSV listing which objects to process (header: "class")
OBJECTS_CSV = SYNTHETIC_DIR / "objectstorun.csv"

# --- Read list of object names from CSV ----------------------
object_names = []
with OBJECTS_CSV.open("r", newline="") as f:
    reader = csv.DictReader(f)
    for row in reader:
        name = row.get("class", "").strip()
        if not name or name.lower() == "class":
            continue
        # strip literal "_bases" if present
        if name.endswith("_bases"):
            name = name[:-len("_bases")]
        object_names.append(name)
print("Will process objects:", object_names)

# --- Main loop: one pass per object ---------------------------
for OBJECT_NAME in object_names:
    print(f"\n=== Processing '{OBJECT_NAME}' ===")

    # 1) Locate bases
    BASE_DIR = SYNTHETIC_DIR / f"{OBJECT_NAME}_bases"
    pngs = sorted(BASE_DIR.glob("*.png"))
    if not pngs:
        print(f"⚠️ No bases in {BASE_DIR}; skipping.")
        continue

    # 2) Prepare color output folder
    COLOR_DIR = SYNTHETIC_DIR / f"{OBJECT_NAME}_color"
    COLOR_DIR.mkdir(exist_ok=True)

    # 3) Prepare CSV
    CSV_PATH = COLOR_DIR / "labels.csv"
    write_header = not CSV_PATH.exists()
    with CSV_PATH.open("a", newline="") as csv_file:
        writer = csv.writer(csv_file)
        if write_header:
            writer.writerow(["filename","size","texture","variant","colour","class"])

        # 4) Process each base
        for base_png in pngs:
            stem = base_png.stem
            if not stem.startswith("base_"):
                print("⚠️ skipping unexpected file:", stem)
                continue

            # remove prefix and split
            parts = stem[len("base_"):].split("_")
            # drop last segment (class tag)
            parts = parts[:-1]
            # now parts = [size, *(texture parts), variant]
            size = parts[0]
            variant = parts[-1]
            texture = "_".join(parts[1:-1]) if len(parts) > 2 else ""

            # preprocess and generate
            input_imgs = preprocess(str(base_png))
            for colour in COLORS:
                out_name = f"{OBJECT_NAME}_{size}_{texture}_{variant}_{colour}.png"
                out_path = COLOR_DIR / out_name
                if out_path.exists():
                    continue
                # generate
                prompt = f"Change the object to {colour} and make the background white"
                gen = torch.Generator(device=accelerator.device).manual_seed(0)
                result = pipeline(
                    prompt              = prompt,
                    input_images        = input_imgs,
                    num_inference_steps = 50,
                    max_sequence_length = 1024,
                    text_guidance_scale = 5.0,
                    image_guidance_scale= 2.0,
                    negative_prompt     = NEG_PROMPT,
                    num_images_per_prompt=1,
                    generator           = gen,
                    output_type         = "pil",
                )
                result.images[0].save(out_path)
                writer.writerow([out_name, size, texture, variant, colour, OBJECT_NAME])
                print("✔", out_name)

    # 5) Check completeness
    expected = set()
    for base_png in pngs:
        parts = base_png.stem[len("base_"):].split("_")[:-1]
        size, variant = parts[0], parts[-1]
        texture = "_".join(parts[1:-1]) if len(parts) > 2 else ""
        for colour in COLORS:
            expected.add(f"{OBJECT_NAME}_{size}_{texture}_{variant}_{colour}.png")
    existing = {p.name for p in COLOR_DIR.glob("*.png")}
    missing = sorted(expected - existing)
    if missing:
        print("🚨 Missing combos:", missing)
    else:
        print(f"✅ All {len(expected)} combos present for '{OBJECT_NAME}'.")

print("\n🎉 All objects processed!")


Will process objects: ['doll', 'pipe', 'telescope', 'suitcase', 'christmastreeornamentball']

=== Processing 'doll' ===
✅ All 120 combos present for 'doll'.

=== Processing 'pipe' ===
✅ All 120 combos present for 'pipe'.

=== Processing 'telescope' ===
✅ All 120 combos present for 'telescope'.

=== Processing 'suitcase' ===
✅ All 120 combos present for 'suitcase'.

=== Processing 'christmastreeornamentball' ===


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:35<00:00,  1.91s/it]


✔ christmastreeornamentball_large_bumpy_01_red.png


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:35<00:00,  1.92s/it]


✔ christmastreeornamentball_large_bumpy_01_green.png


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:37<00:00,  1.94s/it]


✔ christmastreeornamentball_large_bumpy_01_blue.png


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:35<00:00,  1.91s/it]


✔ christmastreeornamentball_large_bumpy_01_yellow.png


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:35<00:00,  1.91s/it]


✔ christmastreeornamentball_large_bumpy_01_orange.png


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:37<00:00,  1.95s/it]


✔ christmastreeornamentball_large_bumpy_01_purple.png


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:37<00:00,  1.94s/it]


✔ christmastreeornamentball_large_bumpy_01_pink.png


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:35<00:00,  1.91s/it]


✔ christmastreeornamentball_large_bumpy_01_brown.png


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:35<00:00,  1.91s/it]


✔ christmastreeornamentball_large_bumpy_01_black.png


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:36<00:00,  1.94s/it]


✔ christmastreeornamentball_large_bumpy_01_gray.png


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:36<00:00,  1.94s/it]


✔ christmastreeornamentball_large_bumpy_02_red.png


 72%|██████████████████████████████████████████████████████████████████████████████████████████████████████▉                                        | 36/50 [01:09<00:28,  2.03s/it]