# 04 - Inference & API
Run model inference and demonstrate how to call the FastAPI endpoint.

In [None]:
# Placeholder output left over from the notebook scaffold.
print("Notebook scaffold - add inference & API usage here")

# Inference and Visualization
This section embeds the inference function and the drawing utility used to visualize predictions. Both are copied directly from `src/infer.py` and `newspaper_yolo_api/app/draw.py`.

In [None]:
# src/infer.py content
from typing import List, Union
from pathlib import Path
from ultralytics import YOLO

try:
    from src.config import InferenceConfig
except ImportError:
    # Fallback for when the project root is not importable yet. Only import
    # failures are caught, so a genuine error raised while executing
    # src/config.py is reported as-is rather than treated as a path problem.
    # NOTE(review): '.' only helps if the kernel's working directory is the
    # project root — confirm where this notebook is launched from.
    import sys
    sys.path.append('.')
    from src.config import InferenceConfig


def infer(weights: str, sources: Union[str, List[str]], cfg: Union[InferenceConfig, None] = None):
    """Run YOLO prediction with ``weights`` on one or more ``sources``.

    Args:
        weights: Model weights identifier, passed unchanged to ``YOLO(...)``.
        sources: A single source or a list of sources, forwarded as
            ``source=`` to ``YOLO.predict``.
        cfg: Inference settings; ``conf``, ``iou`` and ``imgsz`` are read from
            it. When omitted (or ``None``), a fresh ``InferenceConfig()`` is
            built for this call. A default instance is deliberately not
            created in the signature, because it would be evaluated once and
            shared (and mutable) across every call.

    Returns:
        Whatever ``YOLO.predict`` returns for these sources.

    Note:
        The model is constructed from ``weights`` on every call; reuse a
        ``YOLO`` instance directly if you need to predict in a loop.
    """
    if cfg is None:
        cfg = InferenceConfig()
    model = YOLO(weights)
    return model.predict(source=sources, conf=cfg.conf, iou=cfg.iou, imgsz=cfg.imgsz)

In [None]:
# app/draw.py content (visualization)
from typing import List, Dict, Tuple
import cv2
import numpy as np
import colorsys

# Hand-picked per-class colors, keyed by class name. Tuples are in (B, G, R)
# order: they are passed straight to OpenCV drawing calls, and the text-contrast
# helper unpacks them as b, g, r. Names not listed here fall back to a hashed
# color. NOTE(review): keys presumably match the model's class names — confirm.
FIXED_PALETTE = {
    "Header": (255, 0, 0),  # blue
    "Title": (0, 0, 255),  # red
    "Text": (0, 170, 0),  # dark green
    "Table": (0, 255, 255),  # yellow
    "Image": (255, 0, 255),  # magenta
    "Footer": (255, 255, 0),  # cyan
    "Stamp or Signature": (0, 140, 255),  # orange
    "Caption": (147, 20, 255),  # deep pink
    "Keyvalue": (0, 255, 127),  # chartreuse
    "List-item": (255, 128, 0),  # azure
    "Check-box": (128, 128, 128),  # gray
}

def _hash_color_bgr(name: str):
    """Derive a deterministic, saturated BGR color from a class name.

    The hue comes from a CRC-32 of the UTF-8 encoded name. The builtin
    ``hash()`` is not used because string hashing is salted per interpreter
    process (PYTHONHASHSEED), which gave the same class a different color on
    every run.

    Returns:
        A ``(b, g, r)`` tuple of ints in 0..255 (OpenCV channel order), with
        saturation 0.75 and full value.
    """
    import zlib  # stable across processes, unlike hash(str)

    hue = (zlib.crc32(name.encode("utf-8")) % 360) / 360.0
    r, g, b = colorsys.hsv_to_rgb(hue, 0.75, 1.0)
    return (int(b * 255), int(g * 255), int(r * 255))

def _get_color(cls_name: str):
    """Return the BGR color used for ``cls_name``.

    Names present in ``FIXED_PALETTE`` keep their hand-picked color; any other
    name gets the color derived from it by ``_hash_color_bgr``.
    """
    if cls_name in FIXED_PALETTE:
        return FIXED_PALETTE[cls_name]
    return _hash_color_bgr(cls_name)

def _text_color_for(bgr):
    """Pick black or white text, whichever reads better on a ``bgr`` background.

    Brightness is estimated with the BT.601 luma weights (0.299 R + 0.587 G +
    0.114 B). Backgrounds brighter than 170 on the 0-255 scale get black
    text; everything else gets white.
    """
    blue, green, red = bgr
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    black = (0, 0, 0)
    white = (255, 255, 255)
    if luma > 170:
        return black
    return white

def _draw_label(img, x1, y1, text, color, font_scale=0.6, text_thickness=2):
    """Draw ``text`` on a filled chip of ``color`` anchored at box corner (x1, y1).

    The chip sits just above the box's top edge. When there is no room above
    (the chip would start at a negative row), it hangs just inside the box
    from its top edge instead. In both cases the text baseline is placed
    inside the chip. ``img`` is drawn on in place.

    Args:
        img: Image array that OpenCV draws on in place.
        x1, y1: Top-left corner of the detection box, in pixels.
        text: Label string.
        color: Chip fill color (BGR); the text color is chosen for contrast
            against it.
        font_scale: Hershey font scale (default matches the former fixed 0.6).
        text_thickness: Text stroke thickness (default matches the former 2).
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, text_thickness)
    x2 = x1 + text_w + 8
    top = y1 - text_h - 8
    if top >= 0:
        # Room above the box: chip spans [top, y1 - 2].
        chip_top, chip_bottom = top, y1 - 2
    else:
        # No room above: hang the chip from the box's top edge (clamped to row 0).
        anchor = max(y1, 0)
        chip_top, chip_bottom = anchor, anchor + text_h + 8
    cv2.rectangle(img, (x1, chip_top), (x2, chip_bottom), color, thickness=-1)
    # putText's origin is the text baseline; keep it 3 px below chip top + text height
    # so the text always lands on the chip (previously it fell below it when flipped).
    baseline_y = chip_top + text_h + 3
    cv2.putText(img, text, (x1 + 4, baseline_y), font, font_scale,
                _text_color_for(color), text_thickness, cv2.LINE_AA)

def _draw_legend(img, entries, max_items=20):
    """Draw a swatch + name legend for ``(name, bgr)`` pairs in the top-left corner.

    Rows are drawn on a white panel so the label text stays legible whatever
    the image content underneath is. At most ``max_items`` rows are drawn.
    ``img`` is drawn on in place.
    """
    entries = list(entries)[:max_items]
    if not entries:
        return
    font, font_scale, text_thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
    x0, y0, row_h, swatch, gap, pad = 12, 12, 22, 16, 6, 4
    text_w = max(cv2.getTextSize(name, font, font_scale, text_thickness)[0][0]
                 for name, _ in entries)
    panel = (255, 255, 255)
    panel_bottom = y0 + (len(entries) - 1) * row_h + swatch + pad
    cv2.rectangle(img, (x0 - pad, y0 - pad), (x0 + swatch + gap + text_w + pad, panel_bottom),
                  panel, thickness=-1)
    # Contrast is computed against the panel the text actually sits on, not the swatch.
    text_color = _text_color_for(panel)
    for i, (name, color) in enumerate(entries):
        y = y0 + i * row_h
        cv2.rectangle(img, (x0, y), (x0 + swatch, y + swatch), color, thickness=-1)
        cv2.putText(img, name, (x0 + swatch + gap, y + 14), font, font_scale,
                    text_color, text_thickness, cv2.LINE_AA)

def draw_detections(img: np.ndarray, dets: List[dict], class_palette: Dict[str, Tuple[int, int, int]] = None, fill_alpha: float = 0.15, thickness: int = 2, font_scale: float = 0.6, show_legend: bool = True) -> np.ndarray:
    """Render detections onto a new image: translucent fills, outlines, labels, legend.

    ``img`` itself is not modified: fills are drawn on a copy and blended with
    ``cv2.addWeighted``, which returns a new array that everything else is drawn on.

    Args:
        img: Input image. NOTE(review): presumably an HxWx3 uint8 BGR array — confirm.
        dets: Dicts with ``"xyxy"`` (four pixel coordinates x1, y1, x2, y2),
            optional ``"cls_name"`` (default ``"unknown"``) and optional
            ``"conf"`` (default 0.0).
        class_palette: Class name -> BGR color; defaults to ``FIXED_PALETTE``.
            Names missing from it fall back to ``_get_color``.
        fill_alpha: Opacity of the box fill blended over the image.
        thickness: Box outline thickness.
        font_scale: Font scale of the per-box ``"<class> <conf>"`` labels.
        show_legend: Draw a legend of the classes present (if any detections).

    Returns:
        The annotated image.
    """
    if class_palette is None:
        class_palette = FIXED_PALETTE

    def color_of(name):
        # Caller palette wins; unknown names fall back to fixed/hashed colors.
        return class_palette[name] if name in class_palette else _get_color(name)

    # Normalize each detection once: integer box, class name, confidence, color.
    parsed = []
    for det in dets:
        x1, y1, x2, y2 = map(int, det["xyxy"])
        name = det.get("cls_name", "unknown")
        conf = float(det.get("conf", 0.0))
        parsed.append(((x1, y1, x2, y2), name, conf, color_of(name)))

    overlay = img.copy()
    for (x1, y1, x2, y2), _, _, color in parsed:
        cv2.rectangle(overlay, (x1, y1), (x2, y2), color, thickness=-1)
    img = cv2.addWeighted(overlay, fill_alpha, img, 1 - fill_alpha, 0)

    for (x1, y1, x2, y2), name, conf, color in parsed:
        cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness=thickness, lineType=cv2.LINE_AA)
        _draw_label(img, x1, y1, f"{name} {conf:.2f}", color, font_scale=font_scale)

    if show_legend and parsed:
        # dict.fromkeys keeps first-seen order while dropping duplicates.
        names = list(dict.fromkeys(name for _, name, _, _ in parsed))
        _draw_legend(img, [(name, color_of(name)) for name in names])
    return img