In [6]:
import os
import sys
from typing import List, Tuple
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import torch
from torchvision.transforms.functional import to_tensor
import accelerate
from pathlib import Path

# Add OmniGen2 to path
root_dir = Path().resolve()
omnigen_path = root_dir / "OmniGen2-main"
sys.path.append(str(omnigen_path))

from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
from omnigen2.utils.img_util import create_collage

In [7]:
import os
from typing import List, Union
from PIL import Image, ImageOps

def preprocess(input_image_path: Union[str, List[str], None] = None) -> List[Image.Image]:
    """
    Preprocess the input images by:
    - Accepting a single path, list of paths, or a directory
    - Loading only common image files
    - Correcting orientation via EXIF
    - Converting to 3‑channel RGB (drops alpha)
    """
    if input_image_path is None:
        return []

    # Normalize to a list of paths
    if isinstance(input_image_path, str):
        paths = [input_image_path]
    else:
        paths = input_image_path

    images: List[Image.Image] = []
    for p in paths:
        if os.path.isdir(p):
            for fname in os.listdir(p):
                if fname.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".gif")):
                    img = Image.open(os.path.join(p, fname))
                    images.append(img)
        else:
            img = Image.open(p)
            images.append(img)

    # EXIF transpose + strip alpha channel
    processed = []
    for img in images:
        img = ImageOps.exif_transpose(img).convert("RGB")
        processed.append(img)

    return processed


**Pipeline Initialization**

In [8]:
accelerator = accelerate.Accelerator()

# Initialize the pipeline from Hugging Face with CPU offloading
model_path = "OmniGen2/OmniGen2"
pipeline = OmniGen2Pipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    enable_model_cpu_offload=True  # Reduces VRAM usage by ~50% with minimal speed impact
)
pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to(accelerator.device, dtype=torch.bfloat16)

Keyword arguments {'trust_remote_code': True, 'enable_model_cpu_offload': True} are not expected by OmniGen2Pipeline and will be ignored.
Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.80s/it]
Loading pipeline components...:  40%|████      | 2/5 [00:11<00:14,  4.92s/it]Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:14<00:00,  7.10s/it]
Loading pipeline components...: 100%|██████████| 5/5 [00:27<00:00,  5.47s/it]
Expected types for transformer: (<class 'omnigen2.models.transformers.transformer_omnigen2.OmniGen2Transformer2DModel'>,), got <class 'diffusers_modules.local.transformer_omnigen2.OmniGen2Transformer2DModel'>.
Fetching 2 files: 100%|████

**Editing with instruction**

In [9]:
# Example of editing an image
def edit_image(image_path: str, prompt: str, negative_prompt: str = None):
    """
    Edit an image using OmniGen2
    Args:
        image_path: Path to the input image
        prompt: Instruction for editing
        negative_prompt: What to avoid in generation
    """
    if negative_prompt is None:
        negative_prompt = "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs"
    
    # Load and preprocess image
    input_imgs = preprocess(image_path)
    
    # Generate
    gen = torch.Generator(device=accelerator.device).manual_seed(0)
    result = pipeline(
        prompt=prompt,
        input_images=input_imgs,
        num_inference_steps=50,
        max_sequence_length=1024,
        text_guidance_scale=5.0,
        image_guidance_scale=2.0,
        negative_prompt=negative_prompt,
        num_images_per_prompt=1,
        generator=gen,
        output_type="pil",
    )
    
    # Display results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    ax1.imshow(input_imgs[0])
    ax1.set_title("Input Image")
    ax1.axis("off")
    
    ax2.imshow(result.images[0])
    ax2.set_title("Edited Image")
    ax2.axis("off")
    
    plt.show()
    return result.images[0]