In [None]:
from transformers import (
    AutoProcessor,
    SiglipImageProcessor,
    SiglipVisionModel,
    T5EncoderModel,
    BitsAndBytesConfig,
)
from univa.models.qwen2p5vl.modeling_univa_qwen2p5vl import (
    UnivaQwen2p5VLForConditionalGeneration,
)
from univa.utils.flux_pipeline import FluxPipeline
from univa.utils.get_ocr import get_ocr_result
from univa.utils.denoiser_prompt_embedding_flux import encode_prompt
from qwen_vl_utils import process_vision_info
from univa.utils.anyres_util import dynamic_resize, concat_images_adaptive
import torch
from torch import nn
import os
import uuid
import base64
from typing import Dict
from PIL import Image, ImageDraw, ImageFont
from pathlib import Path
from diffusers.utils import make_image_grid


In [2]:
# Local checkpoint directories, all resolved relative to the working directory.
_WEIGHTS_DIR = Path("model_weight")

# FLUX pipeline weights (also the source of the `text_encoder_2` subfolder).
FLUX_PATH = _WEIGHTS_DIR / "FLUX.1-dev"
# SigLIP image processor and vision model.
SIGLIP_PATH = _WEIGHTS_DIR / "siglip2-so400m-patch16-512"
# Main UniWorld model, its processor and `task_head_final.pt`.
MODEL_PATH = _WEIGHTS_DIR / "UniWorld-V1"

In [3]:
def initialize_models(nf4=True, offload=True):
    """Load every model object the notebook needs and return them in one dict.

    Args:
        nf4: When true, the main model and the FLUX ``text_encoder_2`` (T5) are
            loaded with a 4-bit NF4 bitsandbytes config (bfloat16 compute).
            When false, no quantization config is passed and the pipeline
            loads its own ``text_encoder_2``.
        offload: When true, ``enable_model_cpu_offload()`` and
            ``enable_vae_slicing()`` are called on the pipeline.

    Returns:
        dict with the keys ``model``, ``task_head``, ``processor``, ``pipe``,
        ``tokenizers``, ``text_encoders``, ``siglip_processor``,
        ``siglip_model`` and ``device``.

    Side effects:
        Creates a ``tmp`` directory in the current working directory.
    """
    os.makedirs("tmp", exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Only build the bitsandbytes config when it is actually going to be used.
    quantization_config = (
        BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
        )
        if nf4
        else None
    )

    # Load main model and task head
    model = UnivaQwen2p5VLForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        quantization_config=quantization_config,
    ).to(device)

    task_head = nn.Sequential(
        nn.Linear(3584, 10240), nn.SiLU(), nn.Dropout(0.3), nn.Linear(10240, 2)
    ).to(device)
    # map_location makes the checkpoint loadable regardless of the device it
    # was saved from (e.g. saved on cuda:1, loaded on a single-GPU or CPU host).
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints, or consider weights_only=True if the file is a plain
    # state dict.
    task_head.load_state_dict(
        torch.load(MODEL_PATH / "task_head_final.pt", map_location=device)
    )
    task_head.eval()

    processor = AutoProcessor.from_pretrained(
        MODEL_PATH,
        min_pixels=448 * 448,
        max_pixels=448 * 448,
    )

    # The pipeline reuses the denoiser that lives inside the main model.
    pipe_kwargs = {
        "transformer": model.denoise_tower.denoiser,
        "torch_dtype": torch.bfloat16,
    }
    if nf4:
        # Supply a quantized T5 instead of letting the pipeline load its own.
        pipe_kwargs["text_encoder_2"] = T5EncoderModel.from_pretrained(
            FLUX_PATH,
            subfolder="text_encoder_2",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )
    pipe = FluxPipeline.from_pretrained(FLUX_PATH, **pipe_kwargs).to(device)

    if offload:
        pipe.enable_model_cpu_offload()
        pipe.enable_vae_slicing()

    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]

    siglip_processor = SiglipImageProcessor.from_pretrained(SIGLIP_PATH)
    siglip_model = SiglipVisionModel.from_pretrained(
        SIGLIP_PATH,
        torch_dtype=torch.bfloat16,
    ).to(device)

    return {
        "model": model,
        "task_head": task_head,
        "processor": processor,
        "pipe": pipe,
        "tokenizers": tokenizers,
        "text_encoders": text_encoders,
        "siglip_processor": siglip_processor,
        "siglip_model": siglip_model,
        "device": device,
    }

In [None]:
# Load everything once with the defaults (nf4=True, offload=True).
# `run_inference` reads this module-level `state` dict directly.
state = initialize_models()


In [6]:
from PIL import Image
import torch
from univa.utils.denoiser_prompt_embedding_flux import encode_prompt
from qwen_vl_utils import process_vision_info


@torch.inference_mode()
def run_inference(
    text,
    image1: Image.Image = None,
    image2: Image.Image = None,
    height=1024,
    width=1024,
    steps=30,
    guidance=4.0,
    joint_with_t5=True,
    num_imgs=1,
    seed=-1,
):
    """Generate images from a text instruction and up to two reference images.

    Reads the models from the module-level ``state`` dict.

    Args:
        text: Instruction text; added to the user message only when truthy.
        image1, image2: Optional reference images; falsy values are skipped.
        height, width: Forwarded to the pipeline as ``height`` / ``width``.
        steps: Forwarded as ``num_inference_steps``.
        guidance: Forwarded as ``guidance_scale``.
        joint_with_t5: When true, the ``encode_prompt`` embeddings of ``text``
            are concatenated after the main-model embeddings along dim 1.
            When false, ``encode_prompt`` is still called (with an empty
            prompt) and only its pooled output is used.
        num_imgs: Forwarded as ``num_images_per_prompt``.
        seed: Generator seed; ``-1`` draws a fresh one via ``torch.seed()``.

    Returns:
        The ``.images`` attribute of the pipeline output.
    """
    processor = state["processor"]
    device = state["device"]

    # === Build the single user turn: text first, then the images ===
    image_list: list[Image.Image] = [im for im in (image1, image2) if im]
    content = [{"type": "text", "text": text}] if text else []
    content.extend(
        {
            "type": "image",
            "image": im,
            "min_pixels": 448 * 448,
            "max_pixels": 448 * 448,
        }
        for im in image_list
    )
    convo = [{"role": "user", "content": content}]

    # === Tokenize input ===
    rendered = processor.apply_chat_template(
        convo, tokenize=False, add_generation_prompt=True
    )
    # Drop everything up to and including the first "<|im_end|>\n"
    # (presumably the template's default system turn — TODO confirm).
    turn_end = "<|im_end|>\n"
    chat_text = turn_end.join(rendered.split(turn_end)[1:])
    image_inputs, video_inputs = process_vision_info(convo)

    inputs = processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)

    # === SigLIP embedding (optional)
    siglip_hs = None
    siglip_processor = state["siglip_processor"]
    if siglip_processor and image_list:
        pixel_batches = []
        for im in image_list:
            encoded = siglip_processor.preprocess(
                im.convert("RGB"),
                do_resize=True,
                do_convert_rgb=True,
                return_tensors="pt",
            )
            pixel_batches.append(encoded.pixel_values.to(device))
        siglip_out = state["siglip_model"](torch.cat(pixel_batches))
        siglip_hs = siglip_out.last_hidden_state

    # === Main model forward ===
    # pixel_values / image_grid_thw may be absent (fall back to None).
    lvlm = state["model"](
        inputs.input_ids,
        pixel_values=getattr(inputs, "pixel_values", None),
        attention_mask=inputs.attention_mask,
        image_grid_thw=getattr(inputs, "image_grid_thw", None),
        siglip_hidden_states=siglip_hs,
        output_type="denoise_embeds",
    )

    # === Add T5 if needed ===
    t5_prompt = text if joint_with_t5 else ""
    prm_embeds, pooled = encode_prompt(
        state["text_encoders"],
        state["tokenizers"],
        t5_prompt,
        256,  # presumably the max sequence length — TODO confirm
        device,
        1,  # presumably images per prompt — TODO confirm
    )
    if joint_with_t5:
        prompt_embeds = torch.cat([lvlm, prm_embeds], dim=1)
    else:
        prompt_embeds = lvlm

    # === Generator and seed ===
    if seed == -1:
        seed = torch.seed()
    generator = torch.Generator(device=device).manual_seed(seed)

    # === Generate image ===
    result = state["pipe"](
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled,
        height=height,
        width=width,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=generator,
        num_images_per_prompt=num_imgs,
    )
    return result.images


In [7]:
from gemini.gemini_clean import GeminiImageCleaner

In [12]:
# Never store the API key in the notebook: read it from the environment.
# SECURITY: a key was previously committed in this cell — it must be treated
# as compromised and revoked/rotated.
GEMINI_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_KEY:
    raise RuntimeError(
        "Set the GEMINI_API_KEY environment variable before running this cell."
    )

gemini_cleaner = GeminiImageCleaner(
    api_key=GEMINI_KEY, config_path="gemini/gemini_config.yml"
)

In [13]:
# Folder holding the empty-room test photos. The historical absolute path is
# kept as the default; set the TEST_IMG_FOLDER environment variable to run the
# notebook on another machine without editing this cell.
TEST_IMG_FOLDER = Path(
    os.environ.get("TEST_IMG_FOLDER", "/root/UniWorld-V1/empty_room")
)

In [14]:
# Sorted so the order is stable across runs and machines: later cells pair
# these paths with generated images by position and name the saved files by
# index, and glob() itself yields entries in filesystem order.
img_paths = sorted(TEST_IMG_FOLDER.glob("*.jpg"))

In [15]:
# Reference photo for the single-image comparison. It is stretched to
# 1024x1024 (aspect ratio is not preserved), which matches the default
# height/width of run_inference.
target_path = TEST_IMG_FOLDER.joinpath("35.jpg")
img = Image.open(target_path).resize(size=(1024, 1024))


In [None]:
# Step 1: ask Gemini to write the furnishing instruction for the reference photo.
decor_prompt = (
    "Generate a prompt to add multiple furniture items to an empty room "
    "in the format: Add [a/an modifier furniture item orientation/location], "
    "[a/an modifier furniture item orientation/location], etc. "
    "Return prompt only."
)
furnishing_prompt = gemini_cleaner.talk_with_gemini([img], decor_prompt)

# Step 2: apply the same instruction with both editors so the results are comparable.
uniworld_img = run_inference(furnishing_prompt, img)[0]  # first generated image
gemini_img = gemini_cleaner.clean_image(img, furnishing_prompt)


In [None]:
from PIL import Image, ImageDraw, ImageFont
from PIL import Image, ImageDraw, ImageFont


def annotate(img, text, position="right"):
    """Return a copy of ``img`` with ``text`` drawn on a black label at the top.

    Args:
        img: Source image. Only ``copy()`` and ``size`` are used; the original
            is never modified.
        text: Label to draw.
        position: ``"right"`` anchors the label near the right edge; any other
            value anchors it at the left edge.

    Returns:
        The annotated copy.
    """
    annotated = img.copy()
    draw = ImageDraw.Draw(annotated)
    font_size = 40
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except IOError:
        # Arial is not available on every system; fall back to the bundled font.
        font = ImageFont.load_default(size=font_size)

    # Bounding box of the rendered text when drawn at the origin. Its left/top
    # are generally NOT zero (font bearing / ascent), so they must be
    # compensated for when placing the text inside the background box.
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_width = right - left
    text_height = bottom - top

    padding = 10
    img_width = img.size[0]

    if position == "right":
        # Clamp so a label wider than the image still starts inside it.
        x = max(padding, img_width - text_width - 2 * padding)
    else:  # default to "left"
        x = padding

    y = padding

    # Draw background rectangle
    # NOTE(review): the alpha component presumably only has an effect when the
    # image itself has an alpha channel — confirm if translucency is wanted.
    draw.rectangle(
        [x - padding, y - padding, x + text_width + padding, y + text_height + padding],
        fill=(0, 0, 0, 180),
    )

    # Draw text, shifted by the bbox origin so the glyphs occupy exactly
    # [x, y, x + text_width, y + text_height] — i.e. centred in the rectangle.
    draw.text((x - left, y - top), text, font=font, fill=(255, 255, 255))
    return annotated


# Quick visual check of annotate(): as the cell's last expression, the
# returned image is displayed (default position is "right").
annotate(img, "restset")

In [None]:
# One row, three columns: input photo | UniWorld result | Gemini result.
# NOTE(review): presumably the grid expects equally sized images — confirm that
# gemini_img comes back at the same 1024x1024 size as the other two.
make_image_grid([img, uniworld_img, gemini_img], cols=3, rows=1)

In [None]:
# Request sent to Gemini for every room; hoisted because it never changes.
BATCH_FURNISH_REQUEST = "give me a image editing prompt for adding furnture to the given empty room, return prompt only, format Add furniture with description"

# One generated image per entry of img_paths, in the same order.
furnishing_img_list = []

for img_path in img_paths:
    # A dedicated name keeps the global `img` (used by earlier cells) intact,
    # and the context manager releases the file handle after each room.
    with Image.open(img_path) as room_img:
        furnishing_prompt = gemini_cleaner.talk_with_gemini(
            [room_img],
            BATCH_FURNISH_REQUEST,
        )
        furnishing_img = run_inference(furnishing_prompt, room_img)[0]
    furnishing_img_list.append(furnishing_img)

In [31]:
# Build one "original | furnished" strip per room photo.
img_grid_list = []

# furnishing_img_list is indexed by position, so it must have been produced
# from the same img_paths (same order, same length) by the previous cell.
for i, image_path in enumerate(img_paths):
    img = Image.open(image_path)
    # In-place downscale bounded by 1024x1024.
    img.thumbnail((1024, 1024))
    # Match the generated image to the (possibly non-square) thumbnail size.
    furnished_img = furnishing_img_list[i].resize(img.size)
    img_grid = make_image_grid([img, furnished_img], cols=2, rows=1)
    img_grid_list.append(img_grid)


In [38]:
# Write every comparison strip as <index>.jpg. The directory is created first
# so the cell also works on a fresh machine instead of failing on save().
COMPARISON_DIR = Path("/root/app/UniWorld-V1/comparison")
COMPARISON_DIR.mkdir(parents=True, exist_ok=True)

for i, img_grid in enumerate(img_grid_list):
    img_grid.save(COMPARISON_DIR / f"{i}.jpg")