In [1]:
from werkzeug.utils import secure_filename
import os
import re
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
import matplotlib.pyplot as plt
import matplotlib.patches as patches

model_id = "google/paligemma-3b-mix-224"
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

# Use half precision only when a GPU is present; CPU inference stays in float32.
_use_cuda = torch.cuda.is_available()
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if _use_cuda else torch.float32,
)
# Move weights to the chosen device and switch to inference mode (disables dropout etc.).
model = model.to("cuda" if _use_cuda else "cpu").eval()


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# PaliGemma detection output: four <locNNNN> tokens in the order
# (y_min, x_min, y_max, x_max) followed by a free-text label; detections are
# separated by " ; ". The label may contain spaces, so capture up to ';' or '<'.
_LOC_DETECTION_RE = re.compile(
    r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*([^;<]+)"
)

# <loc> token values index a 1024-bin grid (<loc0000> .. <loc1023>), so they are
# normalized by 1024, not 1000.
_LOC_BINS = 1024


def _parse_loc_detections(bbox_string, width, height):
    """Convert PaliGemma ``<loc>`` detection text into pixel-space boxes.

    Args:
        bbox_string: Decoded model output; ``None`` or ``""`` yields no detections.
        width: Image width in pixels, used to scale x coordinates.
        height: Image height in pixels, used to scale y coordinates.

    Returns:
        A list of ``(label, (xmin, ymin, xmax, ymax))`` tuples in pixel units.
        Entries whose label is empty after stripping whitespace are skipped.
    """
    detections = []
    for ymin_s, xmin_s, ymax_s, xmax_s, raw_label in _LOC_DETECTION_RE.findall(bbox_string or ""):
        label = raw_label.strip()
        if not label:
            continue
        # Sort each pair so a malformed (inverted) box from the model cannot make
        # ImageDraw.rectangle raise (newer Pillow rejects x1 < x0 / y1 < y0).
        ymin, ymax = sorted((int(ymin_s) / _LOC_BINS * height, int(ymax_s) / _LOC_BINS * height))
        xmin, xmax = sorted((int(xmin_s) / _LOC_BINS * width, int(xmax_s) / _LOC_BINS * width))
        detections.append((label, (xmin, ymin, xmax, ymax)))
    return detections


def process_bbox(image_path, bbox_string, output_path):
    """Draw PaliGemma detection boxes and labels onto an image and save it.

    Args:
        image_path: Path of the source image; it is converted to RGB before drawing.
        bbox_string: Decoded model output containing ``<locNNNN>`` detection tokens.
        output_path: Destination for the annotated image. Missing parent
            directories are created.

    Returns:
        A list of ``(label, (xmin, ymin, xmax, ymax))`` tuples in pixel
        coordinates, one per box drawn. Callers that ignore the return value are
        unaffected.
    """
    # The context manager releases the file handle even for lazily loaded or
    # multi-frame images; convert() returns an independent copy.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    draw = ImageDraw.Draw(image)
    width, height = image.size

    try:
        font = ImageFont.truetype("arial.ttf", size=16)
    except OSError:
        # Arial is not installed on this system; fall back to PIL's built-in font.
        font = ImageFont.load_default()

    detections = _parse_loc_detections(bbox_string, width, height)
    for label, (xmin, ymin, xmax, ymax) in detections:
        draw.rectangle([(xmin, ymin), (xmax, ymax)], outline="red", width=2)
        # Put the label just above the box unless that would clip the top edge.
        text_position = (xmin, ymin - 15 if ymin > 15 else ymin + 5)
        draw.text(text_position, label, fill="yellow", font=font)

    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    image.save(output_path)
    return detections

In [3]:

def process_image(input_path, output_path,
                  labels=("chair", "table", "bed", "sofa", "shelf"),
                  max_new_tokens=200):
    """Run PaliGemma object detection on an image and save an annotated copy.

    Uses the module-level ``processor`` and ``model`` and delegates drawing to
    ``process_bbox``.

    Args:
        input_path: Path of the image to analyse.
        output_path: Where the annotated image is written.
        labels: Object classes to ask the model to detect. A single string is
            treated as one label. The default reproduces the original prompt.
        max_new_tokens: Upper bound on generated tokens (greedy decoding).

    Returns:
        The decoded detection string produced by the model.

    Raises:
        ValueError: If ``labels`` is empty.
    """
    if isinstance(labels, str):
        labels = [labels]
    labels = list(labels)
    if not labels:
        raise ValueError("labels must not be empty")

    with Image.open(input_path) as source:
        image = source.convert("RGB")

    # Detection prompt: classes separated by " ; ". The explicit <image> token is
    # kept from the original prompt (it marks where image tokens go).
    prompt = "<image> detect " + " ; ".join(labels) + "\n"

    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
    # generate() returns prompt + continuation; remember the prompt length to strip it.
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generated = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    new_tokens = generated[0][input_len:]

    result = processor.decode(new_tokens, skip_special_tokens=True)
    print(f"Model detection output: {result}")

    process_bbox(input_path, result, output_path)
    return result


In [7]:

# Annotate the sample room photo; the trailing ';' keeps any return value out of the cell output.
process_image(
    input_path="static/uploads/room.jpg",
    output_path="static/processed/processed_room.jpg",
);

Model detection output: <loc0596><loc0000><loc0911><loc0233> shelf ; <loc0596><loc0350><loc0905><loc0920> sofa ; <loc0720><loc0181><loc0913><loc0350> chair ; <loc0745><loc0397><loc0871><loc0612> table
