# Realtime Webcam Caption + Classification

In [None]:
import contextlib
import time
import traceback

import gradio as gr
import requests
import torch
from PIL import Image
from torchvision import transforms
from transformers import BlipProcessor, BlipForConditionalGeneration

# ----------------------------
# Load models
# ----------------------------
CAPTION_MODEL_ID = "Salesforce/blip-image-captioning-base"

# Use the GPU when one is visible; both models below are moved onto it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# BLIP captioner. The processor is used both to preprocess frames and to
# decode generated token ids back into text.
caption_processor = BlipProcessor.from_pretrained(CAPTION_MODEL_ID)
caption_model = (
    BlipForConditionalGeneration.from_pretrained(CAPTION_MODEL_ID)
    .to(device)
    .eval()
)

# Pretrained ResNet-50 classifier, fetched through torch.hub from the
# pytorch/vision repository pinned at tag v0.6.0.
clf_model = (
    torch.hub.load("pytorch/vision:v0.6.0", "resnet50", pretrained=True)
    .to(device)
    .eval()
)

# ----------------------------
# Labels
# ----------------------------
# NOTE(review): git.io is GitHub's link shortener (no longer accepting new
# links); consider pinning the resolved raw URL instead.
LABELS_URL = "https://git.io/JJkYN"
NUM_IMAGENET_CLASSES = 1000


def _fetch_labels(url=LABELS_URL, expected=NUM_IMAGENET_CLASSES, timeout=30):
    """Download the class-name list (one label per line) and validate it.

    The returned list is indexed by classifier output position, so its length
    must match the number of classes exactly.

    Args:
        url: location of the newline-separated label file.
        expected: exact number of labels required.
        timeout: HTTP timeout in seconds.

    Returns:
        list[str] of label names, whitespace-stripped.

    Raises:
        requests.HTTPError: the server replied with an error status.
        RuntimeError: the file does not contain exactly ``expected`` labels.
    """
    resp = requests.get(url, timeout=timeout)
    # Fail fast rather than "parsing" an HTML error page into labels.
    resp.raise_for_status()
    # splitlines() + strip() tolerate CRLF line endings.
    names = [line.strip() for line in resp.text.strip().splitlines()]
    # Any count other than `expected` (too few OR too many) would shift the
    # label/logit correspondence, so reject both.
    if len(names) != expected:
        raise RuntimeError(f"Expected {expected} ImageNet labels, got {len(names)}")
    return names


labels = _fetch_labels()

# ----------------------------
# Classifier preprocessing (ImageNet standard)
# ----------------------------
# Per-channel RGB normalization statistics (standard ImageNet values).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

clf_transform = transforms.Compose(
    [
        transforms.Resize(256),      # shorter side -> 256 px
        transforms.CenterCrop(224),  # 224x224 central patch
        transforms.ToTensor(),       # PIL image -> CHW float tensor in [0, 1]
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

def center_zoom_crop(pil_img: Image.Image, zoom: float) -> Image.Image:
    """Return the centered sub-image whose sides are 1/zoom of the original.

    A zoom of 1.0 or less returns the input unchanged. Larger values crop
    tighter around the center, biasing the classifier toward the foreground
    object. Each crop side is at least 1 pixel.
    """
    if zoom <= 1.0:
        return pil_img  # nothing to zoom into
    width, height = pil_img.size
    crop_w, crop_h = (max(1, int(side / zoom)) for side in (width, height))
    x0, y0 = (width - crop_w) // 2, (height - crop_h) // 2
    return pil_img.crop((x0, y0, x0 + crop_w, y0 + crop_h))

# ----------------------------
# Lightweight realtime throttle cache
# ----------------------------
# Module-level state read and written by the inference callback:
#   _frame_count      -- count of non-empty frames seen (drives the stride throttle)
#   _last_caption     -- most recent BLIP caption
#   _last_confidences -- most recent {label: probability} from the classifier
_frame_count, _last_caption, _last_confidences = 0, "", {}

# ----------------------------
# Inference
# ----------------------------
def _should_run_inference(frame_stride):
    """Advance the global frame counter and decide whether to run the models.

    Args:
        frame_stride: run inference every N frames; values below 1 act as 1.

    Returns:
        (run_models, stride): whether this frame runs inference, and the
        clamped stride.
    """
    global _frame_count
    stride = max(1, int(frame_stride))
    # Check before incrementing so the very first frame (count 0) runs the
    # models instead of returning the still-empty cache.
    run_models = _frame_count % stride == 0
    _frame_count += 1
    return run_models, stride


def _autocast_ctx():
    """Return a CUDA mixed-precision autocast context, or a no-op off-GPU."""
    if device == "cuda":
        return torch.autocast(device_type="cuda")
    return contextlib.nullcontext()


def _topk_confidences(logits, topk, class_labels):
    """Map 1-D classifier logits to an ordered {label: probability} dict.

    Args:
        logits: 1-D tensor of raw scores, one per class.
        topk: requested number of classes; clamped to [1, 20] and to the
            number of classes.
        class_labels: sequence of names indexed by class position.
    """
    # Softmax in fp32: logits produced under autocast may be fp16, and this
    # softmax runs outside the autocast region, so it would otherwise stay fp16.
    probs = torch.softmax(logits.float(), dim=0)
    k = max(1, min(int(topk), 20, probs.numel()))  # keep it sane for UI
    vals, idxs = torch.topk(probs, k=k)
    return {class_labels[i]: v for v, i in zip(vals.tolist(), idxs.tolist())}


def _generate_caption(img):
    """Caption a full PIL frame with BLIP and return the decoded text."""
    cap_inputs = caption_processor(images=img, return_tensors="pt")
    cap_inputs = {k: v.to(device) for k, v in cap_inputs.items()}
    with _autocast_ctx():
        cap_out = caption_model.generate(
            **cap_inputs,
            max_new_tokens=25,  # slightly shorter for realtime
            num_beams=3,
            do_sample=False,
        )
    return caption_processor.decode(cap_out[0], skip_special_tokens=True)


def _classify(img, zoom, topk):
    """Classify the center-zoom crop of a PIL frame; return {label: probability}."""
    cls_img = center_zoom_crop(img, float(zoom))
    x = clf_transform(cls_img).unsqueeze(0).to(device)
    with _autocast_ctx():
        logits = clf_model(x)[0]
    return _topk_confidences(logits, topk, labels)


@torch.no_grad()
def caption_and_classify(frame_rgb, zoom, topk, frame_stride, mode):
    """Gradio streaming callback: caption the frame and classify its center.

    Args:
        frame_rgb: numpy array (RGB) from gr.Image(type="numpy"), or None
            before the webcam has delivered a frame.
        zoom: center-zoom factor for the classifier crop (<= 1 means no crop).
        topk: number of classes to report (clamped to [1, 20]).
        frame_stride: run inference every N frames; other frames reuse the
            cached results.
        mode: "Both" | "Caption only" | "Classify only".

    Returns:
        (caption, confidences, status): caption str, {label: probability}
        dict, and a one-line status/timing string. Any exception is reported
        in the status string (with empty caption/confidences) so the stream
        keeps running; the traceback is printed to stderr.

    Notes:
        When ``mode`` disables a model, that model's output keeps showing
        its last cached value.
        NOTE(review): the cache and frame counter are module globals shared
        by every client session, and the app is launched with share=True. A
        viewer may therefore receive results computed from another viewer's
        frame. Per-session gr.State would avoid this.
    """
    global _last_caption, _last_confidences

    try:
        t0 = time.perf_counter()  # monotonic clock for latency measurement

        if frame_rgb is None:
            return "", {}, "No frame received yet."

        # Throttle: only run heavy inference every N frames.
        run_models, stride = _should_run_inference(frame_stride)
        if not run_models:
            ms = (time.perf_counter() - t0) * 1000
            status = f"OK (cached) | device={device} | stride={stride} | {ms:.0f} ms"
            return _last_caption, _last_confidences, status

        # PIL needs uint8. Clip first so out-of-range values saturate instead
        # of wrapping around (assumes pixel values are on a 0-255 scale).
        if frame_rgb.dtype != "uint8":
            frame_rgb = frame_rgb.clip(0, 255).astype("uint8")
        img = Image.fromarray(frame_rgb).convert("RGB")

        caption = _last_caption
        confidences = _last_confidences
        if mode in ("Both", "Caption only"):
            caption = _generate_caption(img)       # full frame
        if mode in ("Both", "Classify only"):
            confidences = _classify(img, zoom, topk)  # center-zoom crop

        _last_caption = caption
        _last_confidences = confidences

        ms = (time.perf_counter() - t0) * 1000
        status = f"OK (inference) | device={device} | stride={stride} | {ms:.0f} ms"
        return caption, confidences, status

    except Exception as e:
        # Keep the live stream alive, but don't lose the traceback.
        traceback.print_exc()
        return "", {}, f"ERROR: {type(e).__name__}: {e}"

# ----------------------------
# Gradio
# ----------------------------
# Upper bound of the "Top-K classes" slider. gr.Label renders at most
# `num_top_classes` entries, so it must match the slider maximum; otherwise
# higher slider settings would have no visible effect.
MAX_TOPK = 10

demo = gr.Interface(
    fn=caption_and_classify,
    inputs=[
        gr.Image(sources=["webcam"], streaming=True, type="numpy", label="Webcam"),
        gr.Slider(1.0, 3.0, value=1.8, step=0.1, label="Classifier center-zoom (higher = more foreground)"),
        gr.Slider(1, MAX_TOPK, value=5, step=1, label="Top-K classes"),
        gr.Slider(1, 6, value=2, step=1, label="Frame stride (run inference every N frames)"),
        gr.Radio(["Both", "Caption only", "Classify only"], value="Both", label="Mode"),
    ],
    outputs=[
        gr.Textbox(label="Caption (BLIP)", lines=3),
        gr.Label(label="Classification (ResNet50)", num_top_classes=MAX_TOPK),
        gr.Textbox(label="Status / Debug", lines=2),
    ],
    title="Realtime Webcam: Caption + Foreground-Biased Classification",
    description=(
        "Caption uses the full frame. Classification uses a center-zoom crop. "
        "Use Frame stride to reduce load and improve realtime stability."
    ),
    live=True,
)

# share=True creates a public link. NOTE(review): the inference cache is
# module-global, so concurrent remote viewers share cached results.
demo.launch(share=True, debug=True)
