In [None]:
# ==============================
# Qwen-VL en COCO val2017 (vehículos)
# ==============================

import os
import re
import json
import torch
from tqdm import tqdm
from PIL import Image, ImageDraw
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from transformers import AutoModelForCausalLM, AutoTokenizer

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ----- 1️⃣ COCO dataset -----
# Ground-truth annotations are read from
# <dataDir>/annotations/instances_<dataType>.json.
dataDir = "./COCO"
dataType = "val2017"
annFile = os.path.join(dataDir, "annotations", f"instances_{dataType}.json")
coco_gt = COCO(annFile)

# Vehicle categories to evaluate, and their COCO category ids.
vehicle_classes = ['car', 'bus', 'truck', 'bicycle', 'motorcycle']
cat_ids = coco_gt.getCatIds(catNms=vehicle_classes)

# Gather every image that shows at least one vehicle. Each category is
# queried separately and the ids are merged into a set, so an image is kept
# once no matter how many vehicle classes it contains.
# NOTE(review): pycocotools' getImgIds appears to intersect when given several
# catIds at once, which would explain the per-category queries — confirm.
img_ids_set = set()
for cat_id in cat_ids:
    img_ids_set.update(coco_gt.getImgIds(catIds=[cat_id]))
img_ids = list(img_ids_set)
images = coco_gt.loadImgs(img_ids)
print(f"{len(images)} imágenes con al menos un vehículo encontradas.")

# ----- 2️⃣ Load Qwen-VL -----
# trust_remote_code=True is passed to both loaders; the checkpoint ships its
# own tokenizer/model code on the Hub, which transformers must be allowed to
# execute.
model_id = "Qwen/Qwen-VL"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
    fp16=True,  # presumably selects half-precision weights — confirm in the model card
)
# Switch to evaluation mode; the script only runs inference.
model = model.eval()

# ----- 3️⃣ Inference and parsing -----
# Qwen-VL writes box corners on a fixed 0-1000 grid that is independent of the
# image resolution; they must be rescaled to pixels before COCO can use them.
QWEN_BOX_SCALE = 1000.0

# A <ref>label</ref> span followed by one or more <box>(x1,y1),(x2,y2)</box>
# spans: the model may attach several boxes to a single label.
_REF_GROUP_RE = re.compile(
    r"<ref>([^<]*)</ref>((?:\s*<box>\(\d+,\d+\),\(\d+,\d+\)</box>)+)"
)
_BOX_RE = re.compile(r"<box>\((\d+),(\d+)\),\((\d+),(\d+)\)</box>")


def _parse_vehicle_boxes(response, width, height, class_names):
    """Extract vehicle detections from a decoded Qwen-VL response.

    Args:
        response: Decoded model output containing <ref>/<box> markup.
        width: Image width in pixels, used to rescale x coordinates.
        height: Image height in pixels, used to rescale y coordinates.
        class_names: Class names to keep. A label matches the first name
            that occurs in it as a substring (so "cars" matches "car").

    Returns:
        A list of ``(class_name, [x, y, w, h])`` tuples with the box in pixel
        units, top-left origin, as the COCO results format expects.
        Zero-area boxes are dropped.
    """
    detections = []
    for label, box_group in _REF_GROUP_RE.findall(response):
        label = label.lower()
        matched_cls = next((cls for cls in class_names if cls in label), None)
        if matched_cls is None:
            continue
        for coords in _BOX_RE.findall(box_group):
            nx1, ny1, nx2, ny2 = map(int, coords)
            # min/max tolerates corners emitted in the wrong order.
            x1 = min(nx1, nx2) * width / QWEN_BOX_SCALE
            y1 = min(ny1, ny2) * height / QWEN_BOX_SCALE
            x2 = max(nx1, nx2) * width / QWEN_BOX_SCALE
            y2 = max(ny1, ny2) * height / QWEN_BOX_SCALE
            box_w, box_h = x2 - x1, y2 - y1
            if box_w <= 0 or box_h <= 0:
                continue
            detections.append((matched_cls, [x1, y1, box_w, box_h]))
    return detections


# Resolve each class name to its COCO category id once, not per detection.
_class_to_cat_id = {
    cls: coco_gt.getCatIds(catNms=[cls])[0] for cls in vehicle_classes
}

coco_results = []

for img in tqdm(images, desc="Procesando imágenes"):
    img_path = os.path.join(dataDir, "images", dataType, img['file_name'])
    # Only the size is needed here; close the file handle right away.
    with Image.open(img_path) as pil_img:
        width, height = pil_img.size

    # Prompt asking the model to ground the vehicle classes.
    query = tokenizer.from_list_format([
        {'image': img_path},
        {'text': 'Detect and ground cars, buses, trucks, bicycles, and motorcycles in this image:'},
    ])
    inputs = tokenizer(query, return_tensors='pt').to(model.device)

    with torch.no_grad():
        pred = model.generate(**inputs, max_new_tokens=128)

    # Keep special tokens: the <ref>/<box> markup is what gets parsed below.
    response = tokenizer.decode(pred[0].cpu(), skip_special_tokens=False)

    for matched_cls, bbox in _parse_vehicle_boxes(
        response, width, height, vehicle_classes
    ):
        coco_results.append({
            "image_id": img['id'],
            "category_id": _class_to_cat_id[matched_cls],
            "bbox": bbox,
            "score": 1.0  # Qwen-VL does not return an explicit confidence
        })

# ----- 4️⃣ Save results -----
# Serialise the detections to a JSON file in the working directory.
serialized_results = json.dumps(coco_results)
with open("qwen_vl_vehicles.json", "w") as f:
    f.write(serialized_results)
num_detections = len(coco_results)
print(f"Resultados guardados: {num_detections} detecciones")

# ----- 5️⃣ COCO evaluation -----
# With no detections there is nothing to score, so only a hint is printed.
if not coco_results:
    print("No se detectaron objetos, revisa el prompt o parsing.")
else:
    # Load the saved detections against the ground truth and run the bbox
    # evaluation, restricted to the vehicle categories.
    coco_dt = coco_gt.loadRes("qwen_vl_vehicles.json")
    coco_eval = COCOeval(coco_gt, coco_dt, iouType='bbox')
    coco_eval.params.catIds = cat_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
