# In [None]:
# ==============================
# Kosmos-2 en COCO val2017 (vehículos)
# ==============================

import json
import os
import re

import matplotlib.pyplot as plt
import torch
from PIL import Image, ImageDraw
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForVision2Seq

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ----- 1) COCO dataset: ground-truth annotations -----
dataDir = "./COCO"
dataType = "val2017"
annFile = os.path.join(dataDir, "annotations", f"instances_{dataType}.json")
coco_gt = COCO(annFile)

# COCO category names treated as "vehicles" in this experiment.
vehicle_classes = ['car', 'bus', 'truck', 'bicycle', 'motorcycle']
cat_ids = coco_gt.getCatIds(catNms=vehicle_classes)

# Union of the image ids annotated with any vehicle category, i.e. every image
# that contains at least one vehicle.
img_ids_set = {
    image_id
    for vehicle_cat in cat_ids
    for image_id in coco_gt.getImgIds(catIds=[vehicle_cat])
}

img_ids = list(img_ids_set)
images = coco_gt.loadImgs(img_ids)
num_images = len(images)
print(f"{num_images} imágenes con al menos un vehículo encontradas.")

# ----- 2) Load the Kosmos-2 processor and model -----
model_id = "microsoft/kosmos-2-patch14-224"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id)
model.to(device)  # nn.Module.to moves the parameters in place
model.eval()  # inference only: switch layers such as dropout to eval mode

# ----- 3) Inference and conversion to COCO results format -----
coco_results = []

# The grounding prompt is the same for every image, so build it once.
prompt = "<grounding> An image of " + ", ".join(vehicle_classes)

# COCO category name -> id, built once instead of querying the COCO index for
# every predicted box.
cat_name_to_id = {cat["name"]: cat["id"] for cat in coco_gt.loadCats(cat_ids)}

# Whole-word patterns (optionally plural "s"/"es"). Plain substring tests made
# "car" match "carrot"/"card"/"scarf" and "bus" match "bush"/"business".
# Classes are tried in vehicle_classes order; the first match wins.
vehicle_patterns = [
    (cls, re.compile(rf"\b{re.escape(cls)}(?:e?s)?\b"))
    for cls in vehicle_classes
]

for img in tqdm(images, desc="Procesando imágenes"):
    img_path = os.path.join(dataDir, "images", dataType, img['file_name'])
    # Release the file handle as soon as the RGB copy has been made.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert("RGB")
    width, height = image.size

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            image_embeds=None,
            image_embeds_position_mask=inputs["image_embeds_position_mask"],
            use_cache=True,
            max_new_tokens=128,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Post-process into cleaned text + entities; each entity is unpacked below
    # as (phrase, text span, boxes).
    _, entities = processor.post_process_generation(generated_text)

    # Draw detections on the image and collect them in COCO results format.
    draw = ImageDraw.Draw(image)
    for phrase, _, boxes in entities:
        phrase = phrase.lower()
        matched_cls = next(
            (cls for cls, pattern in vehicle_patterns if pattern.search(phrase)),
            None,
        )
        if matched_cls is None:
            continue  # phrase does not name a vehicle class
        category_id = cat_name_to_id[matched_cls]
        for x0, y0, x1, y1 in boxes:
            # Boxes come back normalized (x0, y0, x1, y1) in [0, 1]; scale to pixels.
            x0, y0, x1, y1 = x0 * width, y0 * height, x1 * width, y1 * height
            coco_results.append({
                "image_id": img['id'],
                "category_id": category_id,
                # COCO bbox format: [x, y, width, height] in pixels.
                "bbox": [x0, y0, x1 - x0, y1 - y0],
                # Kosmos-2 returns no explicit confidence, so every box gets 1.0.
                "score": 1.0,
            })

            draw.rectangle([x0, y0, x1, y1], outline="red", width=2)
            # Keep the label inside the image for boxes touching the top edge.
            draw.text((x0, max(y0 - 10, 0)), f"{matched_cls}", fill="red")

# ----- 4) Save the detections as a COCO results JSON file -----
results_path = "kosmos2_vehicles.json"
with open(results_path, "w") as results_file:
    json.dump(coco_results, results_file)
num_detections = len(coco_results)
print(f"Resultados guardados: {num_detections} detecciones")

# ----- 5) COCO bbox evaluation -----
# loadRes cannot load an empty results list, hence the guard.
if coco_results:
    coco_dt = coco_gt.loadRes("kosmos2_vehicles.json")
    coco_eval = COCOeval(coco_gt, coco_dt, iouType='bbox')
    # Evaluate only the vehicle categories, and only on the images that were
    # actually run through the model. Otherwise any vehicle image left out of
    # `images` (e.g. when running on a subset) would be counted as missed GT.
    coco_eval.params.catIds = cat_ids
    coco_eval.params.imgIds = sorted({img['id'] for img in images})
    # NOTE: every detection has score 1.0, so there is no meaningful confidence
    # ranking; treat the resulting AP/AR as rough indicators only.
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
else:
    print("No se detectaron objetos, ajusta prompt o revisa post-procesado.")

# Show the last processed image with its drawn detections. `image` is only
# bound once the inference loop has processed at least one image, so guard on
# `images` to avoid a NameError when the selection is empty.
if images:
    plt.figure(figsize=(10, 8))
    plt.imshow(image)
    plt.axis("off")
    plt.show()
