In [None]:
# ============================================================
# INFERENCIA EN VIDEO:
#   - YOLOv8 detecta personas
#   - Se recorta upper/lower por persona
#   - ResNet101 (tu clasificador) predice color para cada crop
#   - Se genera un video anotado con "Sup:" y "Inf:"
# ============================================================

!pip install -q timm albumentations==1.4.3 ultralytics opencv-python tqdm

import os
import cv2
import time
import torch
import timm
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from ultralytics import YOLO
from google.colab import drive
from pathlib import Path
from tqdm import tqdm

# ------------------------------------------------------------
# 1) Mount Drive and configure paths
# ------------------------------------------------------------
# Mount Google Drive so the checkpoint stored under MyDrive is reachable.
drive.mount('/content/drive')

# Classifier checkpoint, input video and annotated output video.
CKPT_PATH  = "/content/drive/MyDrive/color_classifier/resnet101_color_best_20250725_075421.pth"  # <- adjust
VIDEO_PATH = "/content/test_video_v1.mp4"                                          # <- adjust
OUT_PATH   = "/content/test_video_result.mp4"

# Weights file handed to the YOLO constructor.
YOLO_WEIGHTS = "yolov8s.pt"
# Fractions of the person-box height: the upper crop spans [0, 0.55] of the
# box and the lower crop spans [0.45, 1.0], so the two crops overlap in the
# 0.45-0.55 band.
UPPER_END_RATIO   = 0.55
LOWER_START_RATIO = 0.45
# Confidence threshold passed as `conf` to the person detector.
CONF_PERSON = 0.3

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ------------------------------------------------------------
# 2) Cargar el clasificador
# ------------------------------------------------------------
def load_color_classifier(ckpt_path: str):
    """Rebuild the colour classifier and its preprocessing from a checkpoint.

    The checkpoint is expected to be a mapping containing ``model_name``,
    ``num_classes``, ``class_names``, ``img_size`` and ``state_dict``;
    ``mean`` and ``std`` are optional.

    Args:
        ckpt_path: Path of the checkpoint file to read with ``torch.load``.

    Returns:
        A ``(model, class_names, transform)`` tuple: the timm model in eval
        mode on the module-level ``device``, the class names stored in the
        checkpoint, and an albumentations pipeline (resize, normalize,
        convert to tensor).

    Raises:
        KeyError: If a required checkpoint entry is missing.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)

    # Required entries, looked up in a fixed order.
    required = ("model_name", "num_classes", "class_names", "img_size")
    arch, n_classes, labels, side = (checkpoint[key] for key in required)

    # Optional normalization statistics; the fallbacks look like the usual
    # ImageNet values (not confirmed by anything in this file).
    norm_mean = checkpoint.get("mean", (0.485, 0.456, 0.406))
    norm_std = checkpoint.get("std", (0.229, 0.224, 0.225))

    # Build the bare architecture, then restore the trained weights.
    net = timm.create_model(arch, pretrained=False, num_classes=n_classes)
    net.load_state_dict(checkpoint["state_dict"])
    net.eval()
    net.to(device)

    steps = [
        A.Resize(side, side),
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ]
    return net, labels, A.Compose(steps)

# Load the classifier once when the cell runs. `model`, `CLASS_NAMES` and
# `infer_tf` are module-level globals that classify_color() reads.
model, CLASS_NAMES, infer_tf = load_color_classifier(CKPT_PATH)
print("Clases:", CLASS_NAMES)

def classify_color(img_bgr):
    """Predict the colour class of an image crop.

    Args:
        img_bgr: Image array in BGR channel order; it is converted to RGB
            before the module-level ``infer_tf`` transform is applied.

    Returns:
        A ``(class_name, probability)`` tuple for the highest-scoring class,
        where the probability is the softmax score as a Python float.
    """
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    tensor = infer_tf(image=rgb)['image']
    # Add a batch dimension and move the input to the model's device.
    batch = tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        scores = torch.softmax(model(batch), dim=-1)
        probs = scores[0].cpu().numpy()

    best = int(np.argmax(probs))
    return CLASS_NAMES[best], float(probs[best])

# ------------------------------------------------------------
# 3) YOLO person detector
# ------------------------------------------------------------
# Build the Ultralytics YOLO model from YOLO_WEIGHTS; the video loop calls
# yolo.predict(..., classes=[0]) on every frame.
yolo = YOLO(YOLO_WEIGHTS)

# ------------------------------------------------------------
# 4) Video
# ------------------------------------------------------------
# Read VIDEO_PATH frame by frame, detect people, classify the colour of the
# upper and lower part of each detection, draw the result on the frame and
# write the annotated frames to OUT_PATH.
cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    raise FileNotFoundError(f"No puedo abrir el video: {VIDEO_PATH}")

width  = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps    = cap.get(cv2.CAP_PROP_FPS)
total  = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# Fall back to 25 fps when the container does not report a frame rate.
out    = cv2.VideoWriter(OUT_PATH, fourcc, fps if fps > 0 else 25, (width, height))
if not out.isOpened():
    # VideoWriter does not raise on failure; without this check every
    # out.write() would be silently dropped and the output would be empty.
    cap.release()
    raise OSError(f"No puedo crear el video de salida: {OUT_PATH}")

print(f"Video: {width}x{height} @ {fps:.2f}fps | frames: {total}")

# Vertical spacing in pixels between the two label lines.
_LABEL_LINE_HEIGHT = 22

# A non-positive frame count means the container did not report one; pass
# None so tqdm shows a plain counter instead of a broken bar.
pbar = tqdm(total=total if total > 0 else None, desc="Procesando video", unit="frame")
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # classes=[0] restricts the detector to a single class index
        # (persons, per the header of this script).
        results = yolo.predict(frame, conf=CONF_PERSON, classes=[0], verbose=False)
        r = results[0]
        boxes = r.boxes.xyxy.cpu().numpy().astype(int) if r.boxes is not None else []

        # Pass 1: classify every detection while the frame is still
        # untouched. Drawing is deferred so that the rectangle/text of one
        # person never ends up inside the crop of an overlapping person.
        annotations = []
        for box in boxes:
            x1, y1, x2, y2 = box.tolist()
            w, h = x2 - x1, y2 - y1
            if w <= 0 or h <= 0:
                continue

            upper_end   = y1 + int(h * UPPER_END_RATIO)
            lower_start = y1 + int(h * LOWER_START_RATIO)

            # Clamp the crop windows to the frame bounds.
            upper_crop = frame[max(y1, 0):min(upper_end, height), max(x1, 0):min(x2, width)]
            lower_crop = frame[max(lower_start, 0):min(y2, height), max(x1, 0):min(x2, width)]

            # "NA" is reported when a crop is empty after clamping.
            sup_label, sup_prob = ("NA", 0.0)
            inf_label, inf_prob = ("NA", 0.0)

            if upper_crop.size > 0:
                sup_label, sup_prob = classify_color(upper_crop)
            if lower_crop.size > 0:
                inf_label, inf_prob = classify_color(lower_crop)

            txt_sup = f"Sup: {sup_label} ({sup_prob:.2f})"
            txt_inf = f"Inf: {inf_label} ({inf_prob:.2f})"
            annotations.append((x1, y1, x2, y2, txt_sup, txt_inf))

        # Pass 2: draw boxes and labels.
        for x1, y1, x2, y2, txt_sup, txt_inf in annotations:
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

            # Put both lines above the box when there is room; otherwise
            # place them just inside the top of the box so that neither is
            # clipped by the frame edge. "Sup" is always the upper line.
            y_inf = y1 - 10
            y_sup = y_inf - _LABEL_LINE_HEIGHT
            if y_sup <= 0:
                y_sup = max(y1, 0) + _LABEL_LINE_HEIGHT
                y_inf = y_sup + _LABEL_LINE_HEIGHT

            cv2.putText(frame, txt_sup, (x1, y_sup),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2, cv2.LINE_AA)
            cv2.putText(frame, txt_inf, (x1, y_inf),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2, cv2.LINE_AA)

        out.write(frame)
        pbar.update(1)
finally:
    # Always release the capture and finalize the output file, even if
    # detection or classification fails part-way through the video.
    cap.release()
    out.release()
    pbar.close()

print("✅ Video procesado y guardado en:", OUT_PATH)


Mounted at /content/drive
Device: cuda
Clases: ['amarillo', 'azul', 'beige', 'blanco', 'gris', 'marron', 'morado', 'naranja', 'negro', 'rojo', 'rosa', 'verde']
Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8s.pt to 'yolov8s.pt'...


100%|██████████| 21.5M/21.5M [00:00<00:00, 164MB/s]


Video: 1920x1080 @ 25.00fps | frames: 221


Procesando video: 100%|██████████| 221/221 [01:28<00:00,  2.49frame/s]

✅ Video procesado y guardado en: /content/test_video_result.mp4



