In [None]:
from transformers import (
    AutoProcessor,
    SiglipImageProcessor,
    SiglipVisionModel,
    T5EncoderModel,
    BitsAndBytesConfig,
)
from univa.models.qwen2p5vl.modeling_univa_qwen2p5vl import (
    UnivaQwen2p5VLForConditionalGeneration,
)
from univa.utils.flux_pipeline import FluxPipeline
from univa.utils.get_ocr import get_ocr_result
from univa.utils.denoiser_prompt_embedding_flux import encode_prompt
from qwen_vl_utils import process_vision_info
from univa.utils.anyres_util import dynamic_resize, concat_images_adaptive
import torch
from torch import nn
import os
import uuid
import base64
from typing import Dict
from PIL import Image, ImageDraw, ImageFont
from pathlib import Path
from diffusers.utils import make_image_grid

import json

In [None]:


def load_siglip(siglip_path, device="cuda"):
    """Load the SigLIP image processor and vision encoder from one checkpoint.

    Args:
        siglip_path: Location handed to both ``from_pretrained`` calls.
        device: Device the vision model is moved to. The processor is not moved.

    Returns:
        A ``(siglip_processor, siglip_model)`` tuple. The model is requested
        with ``torch_dtype=torch.bfloat16``.
    """
    image_processor = SiglipImageProcessor.from_pretrained(siglip_path)

    vision_encoder = SiglipVisionModel.from_pretrained(
        siglip_path, torch_dtype=torch.bfloat16
    )
    vision_encoder = vision_encoder.to(device)

    return image_processor, vision_encoder


def _load_projector_weights(lvlm_model, weight_path, device, label):
    """Merge one projector checkpoint into ``lvlm_model`` with a non-strict load."""
    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints from a trusted source (consider weights_only=True if the
    # installed torch version supports it).
    state_dict = torch.load(weight_path, map_location=device)

    # strict=False: presumably the checkpoint only carries the projector's own
    # parameters, so keys for the rest of the model are allowed to be absent.
    _, unexpected_keys = lvlm_model.load_state_dict(state_dict, strict=False)

    # Keys the model does not recognise mean the checkpoint does not match this
    # architecture; fail loudly rather than silently dropping those weights.
    assert len(unexpected_keys) == 0, (
        f"Unexpected keys in {label} projector: {unexpected_keys}"
    )


def load_lvlm(pretrained_lvlm_path, denoise_projector_path=None, siglip_projector_path=None, device="cuda"):
    """Load the Univa Qwen2.5-VL model, optional projector weights, and its processor.

    Args:
        pretrained_lvlm_path: Checkpoint location used for both the model and
            the processor.
        denoise_projector_path: Optional state-dict file merged into the model
            after loading. Skipped when falsy.
        siglip_projector_path: Optional state-dict file merged into the model
            after the denoise projector. Skipped when falsy.
        device: Device the model is moved to and checkpoints are mapped onto.

    Returns:
        A ``(lvlm_processor, lvlm_model)`` tuple.

    Raises:
        AssertionError: If a projector checkpoint contains keys the model does
            not have.
    """
    lvlm_model = UnivaQwen2p5VLForConditionalGeneration.from_pretrained(
        pretrained_lvlm_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    ).to(device)

    # Same order as before: denoise projector first, then the SigLIP projector.
    for label, weight_path in (
        ("denoise", denoise_projector_path),
        ("siglip", siglip_projector_path),
    ):
        if weight_path:
            _load_projector_weights(lvlm_model, weight_path, device, label)

    # min_pixels == max_pixels pins the processor's pixel budget to 1024 * 1024.
    lvlm_processor = AutoProcessor.from_pretrained(
        pretrained_lvlm_path,
        min_pixels=1024 * 1024,
        max_pixels=1024 * 1024,
    )

    return lvlm_processor, lvlm_model


def load_flux(flux_path, lvlm_model, device="cuda"):
    """Create the FLUX pipeline, reusing the LVLM's denoiser as its transformer.

    Args:
        flux_path: Checkpoint location passed to ``FluxPipeline.from_pretrained``.
        lvlm_model: Loaded LVLM; its ``denoise_tower.denoiser`` is handed to the
            pipeline as ``transformer``.
        device: Device the assembled pipeline is moved to.

    Returns:
        The pipeline after ``.to(device)``.
    """
    shared_denoiser = lvlm_model.denoise_tower.denoiser
    flux_pipeline = FluxPipeline.from_pretrained(
        flux_path,
        transformer=shared_denoiser,
        torch_dtype=torch.bfloat16,
    )
    return flux_pipeline.to(device)


def build_conversation(prompt, image, resolution=1024):
    """Build a single-turn user conversation holding a text prompt and an image.

    Args:
        prompt: Text placed in the ``"text"`` content entry.
        image: Image object placed in the ``"image"`` content entry.
        resolution: Side length; ``resolution * resolution`` is used for both
            ``min_pixels`` and ``max_pixels`` of the image entry.

    Returns:
        A one-element list with the user message dict (text entry first, image
        entry second).
    """
    pixel_budget = resolution * resolution

    text_part = {"type": "text", "text": prompt}
    image_part = {
        "type": "image",
        "image": image,
        "min_pixels": pixel_budget,
        "max_pixels": pixel_budget,
    }

    user_turn = {"role": "user", "content": [text_part, image_part]}
    return [user_turn]


def preprocess_lvlm_inputs(lvlm_processor, convo, image, device="cuda"):
    """Render the conversation with the chat template and tokenize it with the image.

    Args:
        lvlm_processor: Processor providing ``apply_chat_template`` and callable
            on ``text`` / ``images``.
        convo: Conversation structure passed to the chat template.
        image: Image passed to the processor alongside the rendered text.
        device: Device the processor output is moved to.

    Returns:
        The processor output after ``.to(device)``.
    """
    turn_separator = "<|im_end|>\n"

    rendered = lvlm_processor.apply_chat_template(
        convo, tokenize=False, add_generation_prompt=True
    )
    # Keep only what follows the first separator, i.e. drop the template's
    # leading segment; an absent separator yields an empty string.
    _, _, chat_text = rendered.partition(turn_separator)

    model_inputs = lvlm_processor(
        text=[chat_text],
        images=[image],
        padding=True,
        return_tensors="pt",
    )
    return model_inputs.to(device)


def get_siglip_hidden_states(siglip_processor, siglip_model, image, device="cuda"):
    """Run one image through SigLIP and return its last hidden state.

    Args:
        siglip_processor: Processor whose ``preprocess`` yields ``pixel_values``.
        siglip_model: Vision model called on the pixel tensor.
        image: Image exposing ``convert``; it is converted to RGB first.
        device: Device the pixel tensor is moved to before the forward pass.

    Returns:
        ``last_hidden_state`` of the model output, computed with gradients
        disabled.
    """
    rgb_image = image.convert("RGB")
    processed = siglip_processor.preprocess(
        rgb_image,
        do_resize=True,
        do_convert_rgb=True,
        return_tensors="pt",
    )
    pixel_values = processed.pixel_values.to(device)

    with torch.no_grad():
        hidden_states = siglip_model(pixel_values).last_hidden_state
    return hidden_states



In [None]:
# def get_img_and_clean_prompt(dataset_path: Path, idx: int):
#     json_path = dataset_path / "annotation.json"


#     with open(json_path, "r") as f:
#         data = json.load(f)

#     img_meta = data[idx]


#     img_path = dataset_path / img_meta["image"][0]

#     src_img = Image.open(img_path)
#     clean_prompt = img_meta["conversations"][0]["value"]
#     return src_img, clean_prompt


def get_img_and_clean_prompt(dataset_path: Path):
    """Pair every ``*.jpg`` in ``dataset_path`` with its prompt from the metadata file.

    Prompts are read from ``clean_prompt_meta.json`` inside ``dataset_path``,
    which maps an image file name to its prompt.

    Args:
        dataset_path: Folder holding the ``.jpg`` images and the metadata JSON.

    Returns:
        A list of ``(image, clean_prompt)`` tuples ordered by image path, so
        the position of each pair is stable across runs.

    Raises:
        FileNotFoundError: If ``clean_prompt_meta.json`` does not exist.
        KeyError: If an image file name has no entry in the metadata.
    """
    json_path = dataset_path / "clean_prompt_meta.json"

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    img_prompt_pairs = []
    # glob() yields entries in arbitrary filesystem order; sorting keeps the
    # index-based output names produced by callers reproducible.
    for img_path in sorted(dataset_path.glob("*.jpg")):
        clean_prompt = data[img_path.name]
        img = Image.open(img_path)
        # Image.open is lazy and holds the file open until pixels are read.
        # Reading them now lets Pillow release the handle instead of keeping
        # one open file per dataset image.
        img.load()
        img_prompt_pairs.append((img, clean_prompt))
    return img_prompt_pairs


In [None]:

# Checkpoint locations and dataset folder used by the rest of the notebook.
# The "checkpoints/..." entries are relative paths, so they resolve against the
# kernel's current working directory; the "/workspace/..." entries are absolute
# and specific to the machine this notebook was written on.

# Pretrained LVLM checkpoint (model and processor are both loaded from here).
LVLM_PATH = "checkpoints/training-dataset.utils.aniweave.ai/uniworld-stage-2-cp-10000-1024_start_with_text_mpl_in_512/univa"
# State-dict file merged into the LVLM as the denoise projector weights.
DENOISE_PROJECTOR_WEIGHT_PATH = "checkpoints/denoise_projector.bin"
# State-dict file merged into the LVLM as the SigLIP projector weights.
SIGLIP_PROJECTOR_WEIGHT_PATH = "checkpoints/training-dataset.utils.aniweave.ai/uniworld-stage-2-cp-10000-1024_start_with_text_mpl_in_512/siglip_projector.bin"

# SigLIP checkpoint (image processor + vision model) and FLUX pipeline checkpoint.
SIGLIP_PATH = "/workspace/UniWorld-V1/model_weight/siglip2-so400m-patch16-512"
FLUX_PATH = "/workspace/UniWorld-V1/model_weight/FLUX.1-dev"

# Folder scanned for *.jpg inputs; must also contain clean_prompt_meta.json.
DATASET_FOLDER  = Path("/workspace/UniWorld-V1/xinyi_test_data")

In [None]:

# ------------------- Part 1: Model Loading (Run Only Once) -------------------

class ModelLoader:
    """Container that loads and holds every model needed for generation.

    Instances start empty; ``load_all_models`` fills in the SigLIP processor and
    model, the LVLM processor and model, the FLUX pipeline, and a cached pooled
    embedding of the empty prompt, so later generation calls can reuse them.
    """

    def __init__(self, device="cuda"):
        """Remember the target device; nothing is loaded until ``load_all_models``."""
        self.device = device
        self.siglip_processor = None
        self.siglip_model = None
        self.lvlm_processor = None
        self.lvlm_model = None
        self.pipe = None
        # Pooled embedding of "" produced by _precompute_empty_prompt.
        self.empty_pooled_prompt_embeds = None

    def load_all_models(self):
        """Loads all the necessary models onto the specified device."""
        print("Loading all models onto the GPU... This may take a moment.")

        # Checkpoint locations come from the module-level path constants.
        self.siglip_processor, self.siglip_model = load_siglip(SIGLIP_PATH, device=self.device)
        self.lvlm_processor, self.lvlm_model = load_lvlm(
            LVLM_PATH,
            DENOISE_PROJECTOR_WEIGHT_PATH,
            SIGLIP_PROJECTOR_WEIGHT_PATH,
            device=self.device,
        )
        # The pipeline is built around the LVLM, so the LVLM must be loaded first.
        self.pipe = load_flux(FLUX_PATH, self.lvlm_model, device=self.device)

        # Pre-compute the empty prompt embeds once, as it's constant
        self._precompute_empty_prompt()
        print("Models loaded successfully.")

    def _precompute_empty_prompt(self):
        """Encode the empty prompt once and cache its pooled embedding.

        Uses the pipeline's two tokenizers and two text encoders. Only the
        second value returned by ``encode_prompt`` is kept, stored as
        ``self.empty_pooled_prompt_embeds``.
        """
        tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
        text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]

        # Inference only: disable autograd so the cached embedding does not keep
        # a computation graph (and the encoder activations behind it) alive.
        with torch.no_grad():
            _, self.empty_pooled_prompt_embeds = encode_prompt(
                text_encoders,
                tokenizers,
                prompt="",
                max_sequence_length=256,
                device=self.device,
                num_images_per_prompt=1,
            )


In [None]:

# --- Initial Setup ---
# Build the shared ModelLoader and load every model onto `device`.
# Loading is the expensive step: run this cell once per kernel session and
# reuse `model_manager` for all later generate_image calls.
device = "cuda"
model_manager = ModelLoader(device=device)
model_manager.load_all_models()


# ------------------- Part 2: Image Generation (Run as many times as you want) -------------------

def generate_image(models, prompt, input_image, num_images_per_prompt=4, num_steps=30, guidance_scale=4.0, seed=None):
    """
    Generates images using the pre-loaded models.
    This function can be called repeatedly without reloading models.

    Args:
        models: Object exposing ``device``, ``lvlm_processor``, ``lvlm_model``,
            ``siglip_processor``, ``siglip_model``, ``pipe`` and
            ``empty_pooled_prompt_embeds`` (e.g. a loaded ``ModelLoader``).
        prompt: Text instruction placed in the conversation with the image.
        input_image: Source image; it is resized to 1024x1024 before encoding.
        num_images_per_prompt: Number of images requested from the pipeline.
        num_steps: Passed to the pipeline as ``num_inference_steps``.
        guidance_scale: Passed to the pipeline unchanged.
        seed: Optional integer used to seed the pipeline's random generator.
            ``None`` leaves the freshly created generator untouched, which is
            the previous behaviour.

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    print(f"Generating image for prompt: '{prompt}'")

    # Working resolution shared by the resize and the pipeline's output size.
    side = 1024
    processed_image = input_image.resize((side, side))

    # Build conversation
    convo = build_conversation(prompt, processed_image)

    # Encode inputs
    lvlm_inputs = preprocess_lvlm_inputs(models.lvlm_processor, convo, processed_image, device=models.device)
    siglip_hs = get_siglip_hidden_states(models.siglip_processor, models.siglip_model, processed_image, device=models.device)

    # Forward pass to get embeddings
    with torch.no_grad():
        lvlm_embeds = models.lvlm_model(
            lvlm_inputs.input_ids,
            # The processor output may lack image fields; fall back to None.
            pixel_values=getattr(lvlm_inputs, "pixel_values", None),
            attention_mask=lvlm_inputs.attention_mask,
            image_grid_thw=getattr(lvlm_inputs, "image_grid_thw", None),
            siglip_hidden_states=siglip_hs,
            output_type="denoise_embeds",
        )

    # NOTE(review): a newly constructed torch.Generator appears to start from a
    # fixed default seed, so unseeded calls likely all begin from the same
    # noise; pass `seed` to control or vary the sampling explicitly.
    generator = torch.Generator(device=models.device)
    if seed is not None:
        generator.manual_seed(seed)

    # Generate image using the pre-loaded diffusion pipe
    images = models.pipe(
        prompt_embeds=lvlm_embeds,
        pooled_prompt_embeds=models.empty_pooled_prompt_embeds, # Use the pre-computed empty embeds
        height=side,
        width=side,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_images_per_prompt,
    ).images

    # It's good practice to clear cache after a generation step
    torch.cuda.empty_cache()

    print("Image generation complete.")
    return images


In [None]:

# ------------------- Example Usage -------------------

# --- First Generation ---
# In this cell, you can now call the generation function.



img_clean_prompt_pairs = get_img_and_clean_prompt(DATASET_FOLDER)

# Comparison grids are written next to the dataset. Create the folder up front
# because saving an image does not create missing parent directories.
comparison_dir = DATASET_FOLDER / "1024_comparison"
comparison_dir.mkdir(parents=True, exist_ok=True)

for i, (ori_img, prompt) in enumerate(img_clean_prompt_pairs):
    cleaned_imgs = generate_image(model_manager, prompt, ori_img)
    # Generation runs at a fixed square size; scale the results back to the
    # source size so every tile in the grid has the same dimensions.
    cleaned_imgs = [img.resize(ori_img.size) for img in cleaned_imgs]

    # One row: the original first, then each generated variant.
    img_list = [ori_img] + cleaned_imgs
    # Derive the column count from the tiles so the grid stays valid if the
    # number of generated images per prompt changes.
    img_grid = make_image_grid(img_list, rows=1, cols=len(img_list))
    img_grid.save(comparison_dir / f"{i}.jpg")
