In [None]:
"""
Rescue partially cropped faces by tracing cropped filenames back to originals,
re-detecting faces, and saving full, padded square crops with a `_rescued` suffix.

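Example (inside this notebook; no command-line entry point is defined here):
    rescued = run_rescue(
        bad_crops_dir="/data/bad_crops",
        originals_dir="/data/originals",
        output_dir="/data/rescued",
        yunet_path="/models/face_detection_yunet_2023mar.onnx",
        score_thresh=0.85,
        margin=0.60,
    )

These keyword arguments correspond to the script-style flags
--bad_crops_dir, --originals_dir, --output_dir, --yunet, --score_thresh and --margin.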
"""

In [None]:
import os, re, glob, argparse, json
from typing import Optional, Tuple, List, Dict

import cv2
import numpy as np

# Image extensions (lowercase, with leading dot) accepted when scanning directories;
# callers compare against the lowercased file extension. Both TIFF spellings are listed.
ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")

In [None]:
# ---------- Helpers: parsing & file lookup ----------

def parse_root_and_index(cropped_filename: str) -> Tuple[str, Optional[int]]:
    """
    Split a face-crop filename into (original-image root, 1-based face index).

    Accepts names like:
      image_66752_face_1.jpg
      image_66752_face_1_padded.jpg
      image_66752_face_1_rescued.jpg
      image_66752_face_1_padded_rescued.jpg   (stacked suffixes)
    Returns: ("image_66752", 1)

    Any directory part and the extension are ignored. If the stem does not end
    in ``_face_<digits>``, the index is None and the root is everything before
    the first "_face_" (or the whole stem when "_face_" does not occur).
    """
    stem = os.path.splitext(os.path.basename(cropped_filename))[0]
    # Strip one or more trailing processing suffixes so the index is still parseable
    stem = re.sub(r'(?:_(?:padded|rescued))+$', '', stem)
    m = re.match(r'^(.*)_face_(\d+)$', stem)
    if m:
        return m.group(1), int(m.group(2))
    # fallback: take everything before the first '_face_'
    return stem.split("_face_")[0], None


def _is_root_variant(name: str, root_id: str) -> bool:
    """True if `name` is `root_id` plus a suffix that starts at a token boundary.

    "image_66752_v2" is a variant of "image_66752"; "image_667520" is not, because
    the suffix merely continues the root's trailing run of digits (same for letters).
    """
    if not root_id or len(name) <= len(root_id) or not name.startswith(root_id):
        return False
    prev_ch, next_ch = root_id[-1], name[len(root_id)]
    continues_digits = prev_ch.isdigit() and next_ch.isdigit()
    continues_letters = prev_ch.isalpha() and next_ch.isalpha()
    return not (continues_digits or continues_letters)


def find_original_by_root(root_id: str, originals_dir: str) -> Optional[str]:
    """
    Locate the original image for ``root_id`` by searching ``originals_dir`` recursively.

    Only files whose lowercased extension is in ALLOWED_EXTS are considered, and
    directories/files are visited in sorted order so the result is deterministic.

    1. An image whose basename (without extension) equals ``root_id`` wins immediately.
    2. Otherwise, variants that extend ``root_id`` at a token boundary
       (e.g. image_66752_v2.jpg for "image_66752", but NOT image_667520.jpg) are
       collected; the one with the shortest basename, then the lexicographically
       smallest path, is returned.

    Returns None when nothing matches or ``root_id`` is empty (an empty root
    would otherwise prefix-match every image).
    """
    if not root_id:
        return None

    candidates: List[str] = []
    # Single walk: exact matches return immediately, prefix variants are collected.
    for dirpath, dirnames, filenames in os.walk(originals_dir):
        dirnames.sort()  # in-place sort fixes os.walk's traversal order
        for fn in sorted(filenames):
            name, ext = os.path.splitext(fn)
            if ext.lower() not in ALLOWED_EXTS:
                continue
            path = os.path.join(dirpath, fn)
            if name == root_id:
                return path
            if _is_root_variant(name, root_id):
                candidates.append(path)

    if not candidates:
        return None
    # prefer the shortest basename (closest to exact), then lexicographic
    return min(candidates, key=lambda p: (len(os.path.splitext(os.path.basename(p))[0]), p))


# ---------- Helpers: geometry & cropping ----------

def _expand_square_bbox_from_landmarks(landmarks, margin_ratio, img_w, img_h):
    # landmarks: [x_r_eye, y_r_eye, x_l_eye, y_l_eye, x_nose, y_nose, x_r_mouth, y_r_mouth, x_l_mouth, y_l_mouth]
    xs = landmarks[0::2]
    ys = landmarks[1::2]
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = min(ys), max(ys)
    cx = (x_min + x_max) / 2.0
    cy = (y_min + y_max) / 2.0
    side = max(x_max - x_min, y_max - y_min)
    side = int(round(side * (1.0 + margin_ratio)))
    half = side // 2
    x0 = int(round(cx - half))
    y0 = int(round(cy - half))
    return x0, y0, side, side


def _safe_crop_with_padding(image, x, y, w, h, pad_mode=cv2.BORDER_REFLECT_101):
    """
    Return the ``h`` x ``w`` region of ``image`` whose top-left corner is ``(x, y)``.

    Any part of the requested box lying outside the image is synthesised with
    ``cv2.copyMakeBorder`` using ``pad_mode``, so the result always has the
    requested size. When the box lies fully inside the image, a plain slice of
    ``image`` is returned and no border is built.
    """
    img_h, img_w = image.shape[:2]
    pad_top = max(0, -y)
    pad_bottom = max(0, y + h - img_h)
    pad_left = max(0, -x)
    pad_right = max(0, x + w - img_w)

    if not (pad_top or pad_bottom or pad_left or pad_right):
        return image[y:y + h, x:x + w]

    bordered = cv2.copyMakeBorder(image, pad_top, pad_bottom, pad_left, pad_right, pad_mode)
    # Shift the box into the bordered image's coordinate frame.
    x_in = x + pad_left
    y_in = y + pad_top
    return bordered[y_in:y_in + h, x_in:x_in + w]

# ---------- Face detection & rescue ----------

def make_yunet(model_path: str, score_thresh: float, nms_thresh: float, top_k: int):
    """
    Create an OpenCV YuNet face detector (``cv2.FaceDetectorYN``) from an ONNX model file.

    The 320x320 input size is only a placeholder. In this file the detector is
    re-sized per image via ``setInputSize((W, H))`` before ``detect`` is called.
    """
    detector_options = dict(
        model=model_path,
        config="",
        input_size=(320, 320),  # placeholder, reset per image
        score_threshold=score_thresh,
        nms_threshold=nms_thresh,
        top_k=top_k,
    )
    return cv2.FaceDetectorYN.create(**detector_options)


def redetect_and_save_all_faces_from_original(
    original_img_path: str,
    face_detector,
    output_dir: str,
    margin_ratio: float = 0.55,
    min_side: int = 64,
) -> List[str]:
    """
    Detect every face in an original image and save a padded square crop of each.

    Args:
        original_img_path: image file read with ``cv2.imread``.
        face_detector: detector exposing ``setInputSize((W, H))`` and
            ``detect(img)``, which returns a pair whose second item is the face array
            or None (e.g. the object built by ``make_yunet``). Each face row is read
            as x, y, w, h followed by 10 landmark coordinates.
        output_dir: created if missing; crops are written as
            ``{original_stem}_face_{k}_rescued.jpg``, where k is the 1-based detection index.
        margin_ratio: fractional enlargement applied to the square's side.
        min_side: faces whose square side is below this are skipped.

    Returns:
        Paths of the crops that were actually written. This is empty when the image
        cannot be read, no face is detected, or every face is filtered out.
    """
    os.makedirs(output_dir, exist_ok=True)
    img = cv2.imread(original_img_path)
    if img is None:
        return []
    H, W = img.shape[:2]
    face_detector.setInputSize((W, H))
    _, faces = face_detector.detect(img)
    if faces is None or len(faces) == 0:
        return []

    saved: List[str] = []
    root = os.path.splitext(os.path.basename(original_img_path))[0]

    for i, f in enumerate(faces):
        box = list(map(int, f[:4]))  # x, y, w, h (truncated to int)
        landmarks = list(map(int, f[4:14])) if len(f) >= 14 else []

        # Landmark-first square; fallback to box-based expansion
        if len(landmarks) >= 10:
            x_sq, y_sq, w_sq, h_sq = _expand_square_bbox_from_landmarks(landmarks, margin_ratio, W, H)
        else:
            side = int(round(max(box[2], box[3]) * (1.0 + margin_ratio)))
            cx = box[0] + box[2] // 2
            cy = box[1] + box[3] // 2
            x_sq = int(cx - side // 2)
            y_sq = int(cy - side // 2)
            w_sq = h_sq = side

        if min(w_sq, h_sq) < min_side:
            continue

        crop = _safe_crop_with_padding(img, x_sq, y_sq, w_sq, h_sq)
        if crop.size == 0:
            # Degenerate square (possible when min_side <= 0); imwrite cannot encode it.
            continue

        out_path = os.path.join(output_dir, f"{root}_face_{i+1}_rescued.jpg")
        # imwrite returns False on many failures (e.g. unwritable path) instead of
        # raising, so only report paths that were really written.
        if cv2.imwrite(out_path, crop):
            saved.append(out_path)
        else:
            print(f"[Rescue] failed to write {out_path}")

    return saved


def rescue_from_bad_crops(
    bad_crops_dir: str,
    originals_dir: str,
    output_dir: str,
    face_detector,
    margin_ratio: float = 0.55,
    min_side: Optional[int] = None,
) -> Dict[str, List[str]]:
    """
    Process each bad crop once by its root, then re-detect and save ALL faces from its original.

    Only the top level of ``bad_crops_dir`` is scanned, and only files whose
    lowercased extension is in ALLOWED_EXTS are considered, in sorted order.
    Crops are grouped by the root from ``parse_root_and_index``. Each root is
    resolved to an original image with ``find_original_by_root``.

    Args:
        min_side: if given, forwarded to ``redetect_and_save_all_faces_from_original``.
            When None, that function's own default applies.

    Returns:
        ``{root_id: [saved crop paths]}``. Roots without a matching original map to [].
        Roots that resolve to the same original reuse that original's result instead
        of running detection and overwriting the same output files again.
    """
    os.makedirs(output_dir, exist_ok=True)
    bads = sorted(
        os.path.join(bad_crops_dir, f)
        for f in os.listdir(bad_crops_dir)
        if os.path.splitext(f.lower())[1] in ALLOWED_EXTS
    )
    extra_kwargs = {} if min_side is None else {"min_side": min_side}

    rescued_index: Dict[str, List[str]] = {}
    saved_by_original: Dict[str, List[str]] = {}

    for path in bads:
        root, _ = parse_root_and_index(path)
        if root in rescued_index:
            continue
        orig = find_original_by_root(root, originals_dir)
        if not orig:
            # No original found; skip gracefully
            rescued_index[root] = []
            continue

        if orig not in saved_by_original:
            saved_by_original[orig] = redetect_and_save_all_faces_from_original(
                original_img_path=orig,
                face_detector=face_detector,
                output_dir=output_dir,
                margin_ratio=margin_ratio,
                **extra_kwargs,
            )
        # Copy so roots sharing an original do not alias one list object.
        rescued_index[root] = list(saved_by_original[orig])

    return rescued_index

In [None]:
# ---------- JUPYTER NOTEBOOK ENTRYPOINT ----------

def run_rescue(
    bad_crops_dir: str,
    originals_dir: str,
    output_dir: str,
    yunet_path: str,
    score_thresh: float = 0.85,
    nms_thresh: float = 0.3,
    top_k: int = 5000,
    margin: float = 0.60,
    min_side: int = 64,
    index_json: Optional[str] = None,
):
    """
    Run the rescue process inside Jupyter (no CLI parsing).

    Builds a YuNet detector and rescues faces for every bad crop. If
    ``index_json`` is given, it writes the result index there as UTF-8 JSON,
    creating the parent directory if needed. Finally it prints a one-line summary.

    Returns a dict {root_id: [saved_paths]}.
    """
    det = make_yunet(yunet_path, score_thresh, nms_thresh, top_k)

    # rescue_from_bad_crops calls redetect_and_save_all_faces_from_original through the
    # module global, so min_side is injected by temporarily wrapping that global.
    # It is restored in `finally`. Leaving the wrapper installed would leak hidden
    # kernel state and stack one more wrapper on every call.
    global redetect_and_save_all_faces_from_original
    orig_func = redetect_and_save_all_faces_from_original

    def _wrapped_redetect(*fa, **fk):
        fk.setdefault("min_side", min_side)
        return orig_func(*fa, **fk)

    redetect_and_save_all_faces_from_original = _wrapped_redetect
    try:
        rescued = rescue_from_bad_crops(
            bad_crops_dir=bad_crops_dir,
            originals_dir=originals_dir,
            output_dir=output_dir,
            face_detector=det,
            margin_ratio=margin,
        )
    finally:
        redetect_and_save_all_faces_from_original = orig_func

    if index_json:
        parent = os.path.dirname(index_json)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(index_json, "w", encoding="utf-8") as f:
            json.dump(rescued, f, indent=2)

    total = sum(len(v) for v in rescued.values())
    print(f"[Rescue] processed {len(rescued)} roots | saved {total} crops")
    return rescued