In [4]:
from __future__ import annotations
from typing import Any
import numpy as np
import cv2

from facefusion.types import VisionFrame

# ---------------------------
# Internal helpers
# ---------------------------

def _as_numpy(arr: Any) -> np.ndarray:
    """
    Accepts a NumPy array or a PyTorch tensor. If it's a tensor, detach+cpu+numpy.
    """
    if isinstance(arr, np.ndarray):
        return arr
    # PyTorch-like tensor
    if hasattr(arr, "detach") and hasattr(arr, "cpu") and hasattr(arr, "numpy"):
        return arr.detach().cpu().numpy()
    raise TypeError("Expected a NumPy array or a PyTorch tensor.")

def _to_uint8_0_255(x: np.ndarray) -> np.ndarray:
    """
    Normalize to [0,255] uint8. Handles common cases:
    - [-1, 1] float -> scale
    - [0, 1] float  -> scale
    - otherwise assumes already in [0,255]
    """
    x = x.astype(np.float32, copy=False)
    vmin, vmax = float(x.min()), float(x.max())
    if vmin >= -1.01 and vmax <= 1.01:
        x = x * 127.5 + 127.5
    elif vmin >= 0.0 and vmax <= 1.0:
        x = x * 255.0
    x = np.clip(x, 0.0, 255.0)
    return x.astype(np.uint8)

def _hwc3_bgr(img: np.ndarray) -> VisionFrame:
    """
    Ensure output is HWC, 3 channels, BGR, uint8, contiguous.
    """
    if img.ndim != 3:
        raise ValueError(f"Expected 3D array (HWC or CHW), got shape {img.shape}")

    # If CHW, convert to HWC
    if img.shape[0] in (1, 3, 4) and (img.shape[2] not in (1, 3, 4)):
        img = np.transpose(img, (1, 2, 0))

    # Drop alpha if present
    if img.shape[2] == 4:
        img = img[:, :, :3]

    # If grayscale, expand to 3 channels
    if img.shape[2] == 1:
        img = cv2.cvtColor(_to_uint8_0_255(img), cv2.COLOR_GRAY2BGR)
        return np.ascontiguousarray(img, dtype=np.uint8)

    # Assume RGB â†’ convert to BGR
    img = _to_uint8_0_255(img)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return np.ascontiguousarray(img, dtype=np.uint8)

# ---------------------------
# 1) From a StyleGAN generator output tensor
# ---------------------------

def visionframe_from_stylegan(tensor: Any, select: int = 0) -> VisionFrame:
    """
    Convert a StyleGAN(2) output to VisionFrame.
    Accepts:
      - torch.Tensor or np.ndarray
      - shape (C,H,W) or (N,C,H,W) or (H,W,C) or (N,H,W,C)
    Values can be in [-1,1], [0,1], or [0,255].
    """
    arr = _as_numpy(tensor)

    # Handle batch dimension
    if arr.ndim == 4:
        arr = arr[select]  # pick one sample

    if arr.ndim != 3:
        raise ValueError(f"Expected 3D tensor after batching, got shape {arr.shape}")

    return _hwc3_bgr(arr)

# ---------------------------
# 2) From a PIL.Image.Image
# ---------------------------

def visionframe_from_pil(pil_image: "Image.Image") -> VisionFrame:
    """
    Convert a PIL image to VisionFrame, respecting EXIF orientation.
    """
    try:
        from PIL import ImageOps
    except ImportError as e:
        raise ImportError("Pillow is required for visionframe_from_pil.") from e

    pil_image = ImageOps.exif_transpose(pil_image).convert("RGB")
    rgb = np.array(pil_image, dtype=np.uint8)  # HWC, RGB
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    return np.ascontiguousarray(bgr, dtype=np.uint8)

# ---------------------------
# 3) From a path to an image
# ---------------------------

def visionframe_from_path(path: str) -> VisionFrame:
    """
    Load an image from disk and return a VisionFrame (BGR, uint8).
    Uses PIL to handle EXIF orientation correctly.
    """
    try:
        from PIL import Image, ImageOps
    except ImportError as e:
        raise ImportError("Pillow is required for visionframe_from_path.") from e

    with Image.open(path) as im:
        im = ImageOps.exif_transpose(im).convert("RGB")
        rgb = np.array(im, dtype=np.uint8)
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    return np.ascontiguousarray(bgr, dtype=np.uint8)

In [5]:
from quality_metrics import *

In [6]:
from PIL import Image
from facefusion_wrapper import init_facefusion_state, ensure_facefusion_models
from quality_metrics import *          # the metric we wrote

# 1) Initialize state like the CLI would
init_facefusion_state(
    detector_model="retinaface",
    detector_size="640x640",
    detector_score=0.75,
    detector_angles=(0,),         # add 90/-90 if you need rotated detection
    use_landmarker_68=True,      # set True if you want real 68-pt landmarks
    landmarker_score=0.5,
    download_scope="full",
)

# 2) Ensure models are downloaded & inference pools are ready
ensure_facefusion_models(use_landmarker_68=True)


In [7]:
# # 3) Run your pipeline
# import os
# for im in os.listdir("/Users/adamsobieszek/PycharmProjects/_manipy/age_oems_img"):
#     if im.endswith(".pt") or im.startswith("."):
#         continue

    
#     print(im)
#     vf = visionframe_from_pil(Image.open(f"/Users/adamsobieszek/PycharmProjects/_manipy/age_oems_img/{im}"))
#     reject, score, subs, reasons = should_reject(vf)
#     if reject:
#         os.remove(f"/Users/adamsobieszek/PycharmProjects/_manipy/age_oems_img/{im}")
#     print(reject, score, reasons)
#     print(subs)



import os
rej = 0
to_remove = []
all_reasons = []
for im in os.listdir("/Users/adamsobieszek/PycharmProjects/_manipy/content/out_filtered5"):
    if im.endswith(".pt") or im.startswith("."):
        continue


    vf = visionframe_from_pil(Image.open(f"/Users/adamsobieszek/PycharmProjects/_manipy/content/out_filtered5/{im}"))
    reject, score, subs, reasons = should_reject(vf)
    if reject:
        if rej%100==0:
            print(rej)
        rej +=1
        to_remove.append(im)
        all_reasons.append(reasons)
        os.remove(f"/Users/adamsobieszek/PycharmProjects/_manipy/content/out_filtered5/{im}")
        # if reasons == ["landmark_geometry_implausible"]:
print(reject, score, reasons)
print(subs)
print(all_reasons)

{'have_68': True, 'len_5': 5, 'lap_var': 1133.169704551504, 'mean': 95.28914027149321, 'std': 47.944034300971396, 'yaw_pitch_roll': (4.702521492927184, 27.438954382612998, -1.2347140312194824)}
{'have_68': True, 'len_5': 5, 'lap_var': 1365.3906652795317, 'mean': 123.90750988142292, 'std': 47.94432971879368, 'yaw_pitch_roll': (-4.755723286303828, 0.0680892477033477, -0.3775702118873596)}
{'have_68': True, 'len_5': 5, 'lap_var': 1680.1370016521253, 'mean': 98.69434989788972, 'std': 44.527370240644295, 'yaw_pitch_roll': (5.045251006138251, 16.99693822017773, 0.17360253632068634)}
{'have_68': True, 'len_5': 5, 'lap_var': 1381.96360081259, 'mean': 107.91742654508612, 'std': 69.37871466738997, 'yaw_pitch_roll': (-0.9415193882452082, -12.173628296698004, -0.5343043208122253)}
{'have_68': True, 'len_5': 5, 'lap_var': 1941.8792182339355, 'mean': 102.66990182954038, 'std': 63.570641715501516, 'yaw_pitch_roll': (-0.978396924475629, 11.902220704232212, 0.10300638526678085)}
{'have_68': True, 'len_

In [8]:
# ---------- Edge partial-face detection ----------

def _skin_mask_bgr(img_bgr: np.ndarray) -> np.ndarray:
    """Race-agnostic skin mask using YCrCb + HSV union, then cleaned."""
    ycrcb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2YCrCb)
    Y, Cr, Cb = cv2.split(ycrcb)
    # classic YCrCb box (wide)
    m1 = (Cr >= 133) & (Cr <= 180) & (Cb >= 77) & (Cb <= 135) & (Y >= 40)

    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    H, S, V = cv2.split(hsv)
    # wide HSV skin zone
    m2 = ((H <= 25) | (H >= 160)) & (S >= 30) & (V >= 40)

    m = (m1 | m2).astype(np.uint8) * 255
    # clean
    m = cv2.medianBlur(m, 5)
    m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, np.ones((5,5), np.uint8))
    m = cv2.morphologyEx(m, cv2.MORPH_OPEN,  np.ones((3,3), np.uint8))
    return m
# ---------- replace in your code ----------

def _strip_boxes(h: int, w: int, frac: float = 0.18):
    t = max(1, int(round(frac * min(h, w))))  # ensure at least 1 px
    return {
        "left":   (slice(0, h), slice(0, t)),
        "right":  (slice(0, h), slice(w - t, w)),
        "top":    (slice(0, t), slice(0, w)),
        "bottom": (slice(h - t, h), slice(0, w)),
    }

def _eye_cue(gray: np.ndarray, skin_mask: np.ndarray) -> float:
    """Return [0,1] eye-likeliness in a region (dark circle in skin)."""
    # Skip tiny/empty ROIs
    if gray.size == 0 or gray.shape[0] < 8 or gray.shape[1] < 8:
        return 0.0

    # emphasize dark blobs on skin
    g = cv2.GaussianBlur(gray, (0,0), 1.2)
    dog = cv2.Laplacian(g, cv2.CV_32F) * -1.0  # dark-on-light positive
    if skin_mask is not None and skin_mask.size == gray.size:
        dog = np.where(skin_mask.astype(bool), dog, 0.0)

    # HoughCircles needs CV_8UC1
    dog8 = cv2.normalize(dog, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    dog8 = cv2.equalizeHist(dog8)

    h, w = gray.shape
    rmin = max(2, int(0.01 * min(h, w)))
    rmax = max(rmin + 1, int(0.04 * min(h, w)))

    circles = cv2.HoughCircles(
        dog8, cv2.HOUGH_GRADIENT, dp=1.2,
        minDist=max(4, int(0.08 * min(h, w))),
        param1=80, param2=8,
        minRadius=rmin, maxRadius=rmax
    )
    if circles is None:
        return 0.0
    c = min(3, circles.shape[1])
    return float(min(1.0, c / 2.0))

def _mouth_cue(img_bgr: np.ndarray, skin_mask: np.ndarray) -> float:
    """Return [0,1] mouth-likeliness (reddish elongated blob in skin)."""
    b,g,r = cv2.split(img_bgr.astype(np.int16))
    lipness = np.clip((r - (g+ b)//2), 0, 255).astype(np.uint8)
    lipness[skin_mask == 0] = 0
    lipness = cv2.GaussianBlur(lipness, (0,0), 1.0)
    _, th = cv2.threshold(lipness, 0, 255, cv2.THRESH_OTSU)
    th = cv2.morphologyEx(th, cv2.MORPH_OPEN, np.ones((3,3), np.uint8))
    cnts,_ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    best = 0.0
    for c in cnts:
        area = cv2.contourArea(c)
        if area < 25: 
            continue
        x,y,w,h = cv2.boundingRect(c)
        ar = w/(h+1e-6)
        if ar >= 1.4:  # elongated horizontally
            best = max(best, min(1.0, (area / (img_bgr.shape[0]*img_bgr.shape[1]))*80.0))
    return best

def _cheek_arc_cue(skin_mask: np.ndarray, border_side: str) -> float:
    """
    Look for a large skin component touching the border that fits an inward-facing ellipse arc.
    Returns [0,1].
    """
    h,w = skin_mask.shape
    # components
    cnts,_ = cv2.findContours(skin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    best = 0.0
    for c in cnts:
        if len(c) < 10: 
            continue
        area = cv2.contourArea(c)
        if area < 0.003 * h * w:
            continue
        x,y,bw,bh = cv2.boundingRect(c)
        touches = {
            "left":   x <= 1,
            "right":  x + bw >= w-2,
            "top":    y <= 1,
            "bottom": y + bh >= h-2
        }[border_side]
        if not touches:
            continue
        # fit ellipse for curvature check
        if len(c) >= 5:
            (cx, cy), (ma, mi), angle = cv2.fitEllipse(c)
            major = max(ma, mi)
            # center must be inside image and offset inward from that border
            inward = {
                "left":   cx > 0.25*major,
                "right":  (w - cx) > 0.25*major,
                "top":    cy > 0.25*major,
                "bottom": (h - cy) > 0.25*major
            }[border_side]
            if inward:
                # strength grows with area and curvature smoothness
                perim = cv2.arcLength(c, True) + 1e-6
                smooth = min(1.0, (area/perim) / 4.0)  # crude convex smoothness
                score = min(1.0, (area/(h*w))*150.0) * 0.6 + 0.4*smooth
                best = max(best, score)
    return best


def detect_partial_face_edges(img_bgr: np.ndarray, main_bbox: Optional[np.ndarray]=None,
                              debug: bool=False) -> Tuple[bool, dict]:
    """
    Return (has_partial_face, details_dict).
    Heavily biased toward catching *any* face fragment touching the image border,
    while trying to ignore neck/shoulder skin.
    """
    h, w = img_bgr.shape[:2]
    # Optionally blank out the central main face bbox so it's not counted as 'edge'
    mask_exclude = np.zeros((h,w), np.uint8)
    if main_bbox is not None:
        x1,y1,x2,y2 = np.array(main_bbox, int)
        cv2.rectangle(mask_exclude, (x1,y1), (x2,y2), 255, -1)

    strips = _strip_boxes(h, w, frac=0.18)
    skin = _skin_mask_bgr(img_bgr)
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    details = {}
    fired_any = False
    for side, (ys, xs) in strips.items():
        roi_skin = skin[ys, xs].copy()
        roi_skin[mask_exclude[ys, xs] > 0] = 0  # exclude main face region if provided
        if roi_skin.sum() < 100:  # almost no skin on this edge
            details[side] = {"skin_frac": 0.0, "eye": 0.0, "mouth": 0.0, "cheek": 0.0, "score": 0.0}
            continue

        roi_bgr = img_bgr[ys, xs]
        roi_gray = gray[ys, xs]

        skin_frac = float(roi_skin.mean()/255.0)

        eye_s   = _eye_cue(roi_gray, roi_skin)
        mouth_s = _mouth_cue(roi_bgr, roi_skin)
        cheek_s = _cheek_arc_cue(roi_skin, side)

        # Combine: require skin + (eye OR mouth OR cheek_arc)
        score = (0.4*eye_s + 0.3*mouth_s + 0.3*cheek_s) * (0.5 + 0.5*skin_frac)

        # Heavy-handed thresholds:
        # - fire if strong cue OR moderate cue with lots of skin on the edge
        fire = (score >= 0.35) or (skin_frac >= 0.25 and (eye_s >= 0.25 or mouth_s >= 0.25 or cheek_s >= 0.25))

        details[side] = {
            "skin_frac": skin_frac,
            "eye": eye_s,
            "mouth": mouth_s,
            "cheek": cheek_s,
            "score": score,
            "fire": bool(fire)
        }
        fired_any |= fire

    if debug:
        print(details)
    return fired_any, details



In [9]:
from __future__ import annotations
from typing import Any, Callable, Dict, List, Mapping, Tuple, Optional
import numpy as np
import cv2

# type for metric functions: receive a read-only context, return either a float
# or a (name, value) pair. If only a float is returned, the metric's __name__ is used.
MetricFn = Callable[[Mapping[str, Any]], float | Tuple[str, float]]
# ---- helpers ----
def _coalesce_ndarray(*vals):
    """
    Return the first value that is not None and, if ndarray-like, has size > 0.
    """
    for v in vals:
        if v is None:
            continue
        if isinstance(v, np.ndarray):
            if v.size == 0:
                continue
        return v
    return None

# ---- replace your _build_metric_context with this version ----
def _build_metric_context(vision_frame: VisionFrame) -> Dict[str, Any]:
    ctx: Dict[str, Any] = {}
    h, w = vision_frame.shape[:2]
    ctx["vision_frame"] = vision_frame
    ctx["H"], ctx["W"] = h, w

    faces = get_many_faces([vision_frame])
    ctx["faces"] = faces
    ctx["num_faces"] = len(faces)

    if not faces:
        ctx.update({
            "face": None, "bbox": None, "bbox_w": None,
            "lmk68": None, "lmk5": None,
            "crop": None, "gray": None,
            "full_gray": cv2.cvtColor(vision_frame, cv2.COLOR_BGR2GRAY),
            "det_raw": None, "emb_norm": None
        })
        return ctx

    face: Face = max(faces, key=lambda f: f.score_set.get('detector', 0.0))
    x1, y1, x2, y2 = face.bounding_box
    bbox = np.array([x1, y1, x2, y2], dtype=float)
    bbox_w = max(1.0, x2 - x1)

    lmk68 = face.landmark_set.get('68')
    # SAFE coalesce instead of "or"
    lmk5  = _coalesce_ndarray(face.landmark_set.get('5/68'),
                              face.landmark_set.get('5'))

    crop = _crop_from_bbox(vision_frame, bbox)
    if crop.size == 0:
        crop = vision_frame
    gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
    full_gray = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2GRAY)

    ctx.update({
        "face": face,
        "bbox": bbox,
        "bbox_w": bbox_w,
        "lmk68": lmk68,
        "lmk5":  lmk5,
        "crop": crop,
        "gray": gray,
        "full_gray": full_gray,
        "det_raw": float(face.score_set.get('detector', 0.0)),
        "emb_norm": getattr(face, "embedding_norm", None),
    })
    return ctx

def evaluate_metrics(
    vision_frame: VisionFrame,
    metric_fns: List[MetricFn],
    *,
    include_meta: bool = True
) -> Dict[str, float]:
    """
    Run a list of metric functions on a single image.
    Each metric_fn(ctx) may return either a float or ('name', float).
    Returns {name: value}. Adds a few useful meta entries when include_meta=True.
    """
    ctx = _build_metric_context(vision_frame)
    out: Dict[str, float] = {}

    # Optional meta so you can filter in tests
    if include_meta:
        out["meta_num_faces"] = float(ctx["num_faces"])
        out["meta_has_faces"] = 1.0 if ctx["num_faces"] > 0 else 0.0

    for fn in metric_fns:
        name = getattr(fn, "__name__", "metric")
        val = fn(ctx)
        if isinstance(val, tuple) and len(val) == 2:
            name, score = val  # ('custom_name', value)
        else:
            score = float(val)  # use fn.__name__ as key
        out[name] = float(score)
        # keep testing even if a metric fails
    return out

def m_det(ctx):  # detector confidence â†’ [0,1]
    det_raw = ctx["det_raw"]
    return ("det", _scale01(det_raw, DET_SCORE_MIN, DET_SCORE_MAX)) if det_raw is not None else ("det", 0.0)

def m_geom(ctx):
    l68, bw = ctx["lmk68"], ctx["bbox_w"]
    return ("geom", _geom_symmetry_score(l68, bw) if _has_68(l68) else 0.7)

def m_pose(ctx):
    l68, l5 = ctx["lmk68"], ctx["lmk5"]
    if _has_68(l68):
        yaw, pitch, roll = _estimate_pose_from_68(l68)
    elif l5 is not None:
        yaw, pitch, roll = _pose_from_5(l5)
    else:
        return ("pose", 0.5)
    yaw_p = _clamp01(1.0 - abs(yaw)/POSE_MAX_YAW)
    pit_p = _clamp01(1.0 - abs(pitch)/POSE_MAX_PITCH)
    rol_p = _clamp01(1.0 - abs(roll)/POSE_MAX_ROLL)
    return ("pose", 0.4*yaw_p + 0.3*pit_p + 0.3*rol_p)

def m_sharp(ctx):
    if ctx["gray"] is None: return ("sharp", 0.0)
    lapv = _laplacian_var(ctx["gray"])
    return ("sharp", _clamp01((lapv - LAPLACIAN_BAD) / (LAPLACIAN_GOOD - LAPLACIAN_BAD)))

def m_exposure(ctx):
    if ctx["gray"] is None: return ("exposure", 0.0)
    g = ctx["gray"]; mean, std = float(g.mean()), float(g.std())
    if mean < EXPO_LOW: band = _scale01(mean, 0.0, EXPO_LOW)
    elif mean > EXPO_HIGH: band = _scale01(255.0 - mean, 0.0, 255.0-EXPO_HIGH)
    else: band = 1.0
    spread = _scale01(std, EXPO_STD_MIN, 90.0)
    return ("exposure", 0.7*band + 0.3*spread)

def m_center(ctx):
    if ctx["bbox"] is None: return ("center", 0.0)
    return ("center", _centering_score(ctx["bbox"], ctx["W"], ctx["H"]))

def m_occl(ctx):
    if not _has_68(ctx["lmk68"]): return ("occl", 0.5)
    return ("occl", _occlusion_score(ctx["full_gray"], ctx["lmk68"], ctx["bbox"]))
# === Edge-partial metrics (uses detect_partial_face_edges) ===

def _edge_info(ctx):
    """
    Compute (or reuse cached) edge-partial result for this image.
    Returns (has_partial: bool, details: dict).
    """
    key = "_edge_info_cached"
    if key in ctx:
        return ctx[key]
    bbox = ctx.get("bbox")
    main_bbox = None if bbox is None else np.array(bbox, dtype=int)
    has_partial, details = detect_partial_face_edges(
        ctx["vision_frame"], main_bbox=main_bbox, debug=False
    )
    ctx[key] = (has_partial, details)  # cache for sibling metrics
    return ctx[key]

def m_edge_partial(ctx):
    has, _ = _edge_info(ctx)
    return ("edge_partial", 1.0 if has else 0.0)

def make_edge_fire(side: str):
    def _m(ctx):
        _, det = _edge_info(ctx)
        d = det.get(side, {})
        return (f"edge_{side}_fire", 1.0 if d.get("fire", False) else 0.0)
    _m.__name__ = f"m_edge_{side}_fire"
    return _m

def make_edge_score(side: str):
    def _m(ctx):
        _, det = _edge_info(ctx)
        d = det.get(side, {})
        return (f"edge_{side}_score", float(d.get("score", 0.0)))
    _m.__name__ = f"m_edge_{side}_score"
    return _m

def make_edge_skin_frac(side: str):
    def _m(ctx):
        _, det = _edge_info(ctx)
        d = det.get(side, {})
        return (f"edge_{side}_skin_frac", float(d.get("skin_frac", 0.0)))
    _m.__name__ = f"m_edge_{side}_skin"
    return _m

In [12]:
import numpy as np
import cv2

# if you already defined these earlier, keep your versions:
def _strip_boxes(h: int, w: int, frac: float = 0.18):
    t = max(1, int(round(frac * min(h, w))))
    # widen at very low resolutions
    if max(h, w) <= 256:
        t = max(t, int(0.25 * min(h, w)))
    return {
        "left":   (slice(0, h), slice(0, t)),
        "right":  (slice(0, h), slice(w - t, w)),
        "top":    (slice(0, t), slice(0, w)),
        "bottom": (slice(h - t, h), slice(0, w)),
    }

def _skin_mask_bgr(img_bgr: np.ndarray) -> np.ndarray:
    ycrcb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2YCrCb)
    Y, Cr, Cb = cv2.split(ycrcb)
    m1 = (Cr >= 133) & (Cr <= 180) & (Cb >= 77) & (Cb <= 135) & (Y >= 40)
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    H, S, V = cv2.split(hsv)
    m2 = ((H <= 25) | (H >= 160)) & (S >= 30) & (V >= 40)
    m = (m1 | m2).astype(np.uint8) * 255
    m = cv2.medianBlur(m, 5)
    m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, np.ones((5,5), np.uint8))
    m = cv2.morphologyEx(m, cv2.MORPH_OPEN,  np.ones((3,3), np.uint8))
    return m

def _bbox_iou(a, b) -> float:
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    ix1, iy1 = max(ax1, bx1), max(ay1, by1)
    ix2, iy2 = min(ax2, bx2), min(ay2, by2)
    iw, ih = max(0, ix2 - ix1), max(0, iy2 - iy1)
    inter = iw * ih
    area_a = max(0, ax2-ax1) * max(0, ay2-ay1)
    area_b = max(0, bx2-bx1) * max(0, by2-by1)
    union = area_a + area_b - inter + 1e-6
    return inter / union

def detect_similar_skin_blob_outside_bbox(
    img_bgr: np.ndarray,
    main_bbox: np.ndarray,
    *,
    border_frac: float = 0.18,
    chi2_thresh: float = 9.0,           # ~99% for 2D Gaussian (Cr,Cb)
    min_area_frac: float = 0.0015,      # of full image; auto clamps below
    exclude_neck: bool = True
) -> tuple[bool, dict]:
    """
    Detects border-touching blobs outside main_bbox whose chroma (Cr,Cb) matches
    the main face's skin statistics. Returns (has_blob, details).
    """
    h, w = img_bgr.shape[:2]
    x1, y1, x2, y2 = np.array(main_bbox, int)
    x1 = np.clip(x1, 0, w-1); x2 = np.clip(x2, 0, w-1)
    y1 = np.clip(y1, 0, h-1); y2 = np.clip(y2, 0, h-1)

    ycrcb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2YCrCb)
    Y, Cr, Cb = cv2.split(ycrcb)

    # Skin mask inside main face (fallback to whole box if empty)
    main_skin = _skin_mask_bgr(img_bgr)[y1:y2, x1:x2]
    sel = main_skin > 0
    Cr_box = Cr[y1:y2, x1:x2][sel]
    Cb_box = Cb[y1:y2, x1:x2][sel]
    if Cr_box.size < 50:
        Cr_box = Cr[y1:y2, x1:x2].ravel()
        Cb_box = Cb[y1:y2, x1:x2].ravel()

    # robust chroma center & scale (median + MAD)
    mu = np.array([np.median(Cr_box), np.median(Cb_box)], dtype=np.float32)
    mad = 1.4826 * np.array([
        np.median(np.abs(Cr_box - mu[0])) + 1e-6,
        np.median(np.abs(Cb_box - mu[1])) + 1e-6
    ], dtype=np.float32)
    inv_var = 1.0 / (mad ** 2)

    # Mahalanobis distance (diag) for all pixels
    dcr = (Cr.astype(np.float32) - mu[0]) ** 2 * inv_var[0]
    dcb = (Cb.astype(np.float32) - mu[1]) ** 2 * inv_var[1]
    d2  = dcr + dcb

    # Candidate mask: similar chroma & skin-like & outside main bbox
    sim = (d2 < chi2_thresh).astype(np.uint8) * 255
    skin = _skin_mask_bgr(img_bgr)
    cand = cv2.bitwise_and(sim, skin)

    # Blank a slightly shrunken version of main bbox so neighbors remain
    shrink_x = int(0.08 * (x2 - x1))
    shrink_y = int(0.08 * (y2 - y1))
    ex1 = max(0, x1 + shrink_x); ex2 = min(w-1, x2 - shrink_x)
    ey1 = max(0, y1 + shrink_y); ey2 = min(h-1, y2 - shrink_y)
    cand[ey1:ey2, ex1:ex2] = 0

    # Keep only border strips
    strips = _strip_boxes(h, w, border_frac)
    mask_border = np.zeros((h, w), np.uint8)
    for s in strips.values():
        mask_border[s] = 255
    cand = cv2.bitwise_and(cand, mask_border)

    # Clean and find components
    cand = cv2.morphologyEx(cand, cv2.MORPH_CLOSE, np.ones((5,5), np.uint8))
    cand = cv2.morphologyEx(cand, cv2.MORPH_OPEN,  np.ones((3,3), np.uint8))

    min_area = max(40, int(min_area_frac * h * w))
    cnts, _ = cv2.findContours(cand, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    details = {"blobs": []}
    fired = False
    for c in cnts:
        area = cv2.contourArea(c)
        if area < min_area:
            continue
        x, y, bw, bh = cv2.boundingRect(c)

        # which side?
        touches_left   = x <= 1
        touches_right  = x + bw >= w - 2
        touches_top    = y <= 1
        touches_bottom = y + bh >= h - 2
        side = None
        if touches_left:   side = "left"
        if touches_right:  side = "right" if side is None else side + "+right"
        if touches_top:    side = "top" if side is None else side + "+top"
        if touches_bottom: side = "bottom" if side is None else side + "+bottom"

        # optional: ignore neck (tall, below face, overlaps horizontally with face)
        if exclude_neck:
            overlap_x = not (x2 < x or x + bw < x1)  # horizontal overlap with face
            is_below  = y > y2 - int(0.05 * (y2 - y1))
            tall      = bh / (bw + 1e-6) > 1.5
            if overlap_x and is_below and tall:
                continue

        # average chroma distance inside blob (stronger means more face-like)
        mask = np.zeros((h, w), np.uint8)
        cv2.drawContours(mask, [c], -1, 255, -1)
        mean_d2 = float(np.mean(d2[mask > 0])) 
        sim_score = float(np.clip(1.0 - (mean_d2 / (chi2_thresh + 1e-6)), 0.0, 1.0))

        details["blobs"].append({
            "bbox": [int(x), int(y), int(x + bw), int(y + bh)],
            "area": float(area),
            "side": side or "unknown",
            "mean_chi2": mean_d2,
            "similarity": sim_score
        })

        # Fire if it's a confident, sizable, border-touching blob
        if sim_score >= 0.35:
            fired = True

    details["fired"] = fired
    return fired, details
def _edge_color_info(ctx):
    key = "_edge_color_cached"
    if key in ctx:
        return ctx[key]
    bbox = ctx.get("bbox")
    if bbox is None:
        ctx[key] = (False, {"blobs": [], "fired": False})
    else:
        has, info = detect_similar_skin_blob_outside_bbox(ctx["vision_frame"], np.array(bbox, int))
        ctx[key] = (has, info)
    return ctx[key]

def m_edge_color_partial(ctx):
    has, _ = _edge_color_info(ctx)
    return ("edge_color_partial", 1.0 if has else 0.0)

def m_edge_color_max_sim(ctx):
    _, info = _edge_color_info(ctx)
    mx = max((b.get("similarity", 0.0) for b in info.get("blobs", [])), default=0.0)
    return ("edge_color_max_sim", float(mx))

In [None]:
# 3) Run your pipeline
from quality_metrics import _crop_from_bbox, _scale01, _has_68, _geom_symmetry_score, _estimate_pose_from_68, _clamp01, _laplacian_var, _centering_score, _occlusion_score
import os
import shutil


target_dir = "/Users/adamsobieszek/PycharmProjects/_manipy/content/out_filtered5"
out_dir = "/Users/adamsobieszek/PycharmProjects/_manipy/content/out_filtered6"
os.makedirs(out_dir, exist_ok=True)
for im in os.listdir(target_dir):
    if im.endswith(".pt") or im.startswith("."):
        continue

    metrics = [
    # new edge-partial metrics
    m_edge_partial,
    make_edge_fire("left"), make_edge_fire("right"),
    make_edge_fire("top"),  make_edge_fire("bottom"),
      m_edge_color_partial, m_edge_color_max_sim 
    # (optional debug)
    # make_edge_score("left"), make_edge_skin_frac("left"), ...
]

    vf = visionframe_from_pil(Image.open(f"{target_dir}/{im}"))
    values = evaluate_metrics(vf, metrics)
    print(values)
    # Example hard reject logic:
    if values.get("edge_partial", 0.0) >= 0.5:
        print("HARD REJECT: partial face at image edge")
        print(im)
        metrics = [m_det, m_geom, m_pose, m_sharp, m_exposure, m_center, m_occl,     m_edge_color_partial, m_edge_color_max_sim ] # new color-similarity check]x
        values = evaluate_metrics(vf, metrics)
        # values -> {'meta_num_faces': 1.0, 'meta_has_faces': 1.0, 'det': 0.83, ...}
        print(values)
    else:
        # Load the .webp image and save as .jpg in out_dir
        from PIL import Image
        in_path = f"{target_dir}/{im}"
        out_path = f"{out_dir}/{os.path.splitext(im)[0]}.jpg"
        img = Image.open(in_path).convert("RGB")
        img.save(out_path, "JPEG", quality=95)

{'meta_num_faces': 1.0, 'meta_has_faces': 1.0, 'edge_partial': 0.0, 'edge_left_fire': 0.0, 'edge_right_fire': 0.0, 'edge_top_fire': 0.0, 'edge_bottom_fire': 0.0, 'edge_color_partial': 1.0, 'edge_color_max_sim': 0.87662267731248}
{'meta_num_faces': 1.0, 'meta_has_faces': 1.0, 'edge_partial': 0.0, 'edge_left_fire': 0.0, 'edge_right_fire': 0.0, 'edge_top_fire': 0.0, 'edge_bottom_fire': 0.0, 'edge_color_partial': 1.0, 'edge_color_max_sim': 0.8290139891615791}
{'meta_num_faces': 1.0, 'meta_has_faces': 1.0, 'edge_partial': 0.0, 'edge_left_fire': 0.0, 'edge_right_fire': 0.0, 'edge_top_fire': 0.0, 'edge_bottom_fire': 0.0, 'edge_color_partial': 1.0, 'edge_color_max_sim': 0.7313054701218749}
{'meta_num_faces': 1.0, 'meta_has_faces': 1.0, 'edge_partial': 0.0, 'edge_left_fire': 0.0, 'edge_right_fire': 0.0, 'edge_top_fire': 0.0, 'edge_bottom_fire': 0.0, 'edge_color_partial': 1.0, 'edge_color_max_sim': 0.8325940953930552}
{'meta_num_faces': 1.0, 'meta_has_faces': 1.0, 'edge_partial': 0.0, 'edge_left

In [3]:
import os
print([m.split('.')[0] for m in os.listdir("/Users/adamsobieszek/PycharmProjects/_manipy/content/out_filtered6")])

['134_f_36', '69_f_36', '210_m_44', '123_m_11', '135_f_36', '33_f_12', '195_f_45', '41_f_19', '48_m_54', '146_f_15', '224_m_52', '128_f_27', '204_f_30', '153_m_49', '94_f_10', '4_m_45', '163_m_55', '203_f_55', '208_f_42', '14_f_34', '28_m_53', '187_f_49', '81_f_45', '7_f_23', '6_f_37', '199_f_7', '163_f_48', '110_f_57', '139_f_50', '162_f_48', '223_f_16', '154_f_25', '138_f_50', '178_f_28', '48_f_49', '167_f_56', '182_f_57', '153_f_40', '26_f_53', '180_f_10', '62_f_21', '1_f_52', '0_f_46', '86_f_20', '181_f_39', '152_f_41', '26_f_52', '166_f_43', '57_f_22', '206_m_57', '48_f_48', '115_f_48', '110_f_42', '64_f_45', '142_m_17', '21_f_23', '186_f_48', '124_f_54', '70_m_19', '208_f_43', '216_f_29', '106_f_44', '97_f_56', '44_f_12', '96_f_56', '192_f_21', '172_f_16', '101_f_35', '180_m_30', '100_f_35', '101_f_21', '30_f_54', '40_f_18', '195_f_50', '194_f_50', '61_m_52', '134_f_23', '170_f_51', '122_m_8', '159_f_40', '158_f_40', '218_f_36', '194_f_52', '144_f_45', '101_f_37', '173_f_28', '12

In [None]:
import torch
from sg_output_analysis import feature_sensitivity_at
import sys
import pickle
sys.path.append('/Users/adamsobieszek/PycharmProjects/_manipy/')
from manipy.stylegan.utils import sample_w
device = torch.device("mps")
# --- 2. Load Models ---
sys.path.append('/Users/adamsobieszek/PycharmProjects/psychGAN/content/psychGAN/stylegan3')
with open('/Users/adamsobieszek/PycharmProjects/psychGAN/stylegan2-ffhq-1024x1024.pkl', 'rb') as fp:
    G = pickle.load(fp)['G_ema'].to(device)
    G.eval()
# img_bchw: (1,3,H,W), RGB in [0,1] from your StyleGAN pipeline (no need to convert to BGR)
img_bchw = G(sample_w(1, G=G), None)  # or your already-rendered tensor
print(img_bchw.shape)
res = feature_sensitivity_at(
    image_bchw=img_bchw,
    center_xy=(256, 256),            # pixels; set coords_mode="normaliz    ed" if using [-1,1]
    extractor_name="resnet50",       # or "vgg16", "lpips", "clip-vitb32" if installed
    coords_mode="pixels",
    patch_px=128,
    out_res=224,
    reduction="l2_half"
)

print("grad (âˆ‚g/âˆ‚x, âˆ‚g/âˆ‚y):", res.grad_xy.cpu().numpy())
print("unit dir (fastest change):", res.unit_dir.cpu().numpy())
print("Hessian:\n", res.hess_xy.cpu().numpy())
print("eigvals:", res.eigvals.cpu().numpy())

torch.Size([1, 3, 1024, 1024])


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 3, 1, 224, 224]

In [None]:
# sg_output_analysis.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Tuple, Optional, Literal, Callable

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# ============================
# Feature extractors
# ============================

class _IdentityPool(nn.Module):
    def forward(self, x): return x

class FeatureExtractor(nn.Module):
    """
    Wraps a backbone and produces a single flat feature vector per image.
    Expect input: float tensor in [0,1], shape (N,3,H,W), RGB.
    """
    def __init__(self, trunk: nn.Module, pool: nn.Module | None = None, proj: nn.Module | None = None):
        super().__init__()
        self.trunk = trunk
        self.pool = pool if pool is not None else _IdentityPool()
        self.proj = proj if proj is not None else _IdentityPool()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.trunk(x)
        if isinstance(y, (list, tuple)):
            y = y[-1]
        # if feature map, global-average pool
        if y.ndim == 4:
            y = torch.flatten(F.adaptive_avg_pool2d(y, 1), 1)
        y = self.pool(y)
        y = self.proj(y)
        return y


def build_feature_extractor(name: str, device: torch.device | str = "cpu") -> FeatureExtractor:
    """
    name âˆˆ {"vgg16","resnet50","lpips","clip-vitb32", ...}
    (LPIPS and CLIP are optional; we fall back gracefully if not installed.)
    """
    n = name.lower()

    if n == "vgg16":
        from torchvision.models import vgg16, VGG16_Weights
        m = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:23]  # conv3_3
        return FeatureExtractor(nn.Sequential(m)).to(device).eval()

    if n == "resnet50":
        from torchvision.models import resnet50, ResNet50_Weights
        m = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        trunk = nn.Sequential(*(list(m.children())[:-1]))  # â†’ (N,2048,1,1)
        proj = nn.Flatten(1)
        return FeatureExtractor(trunk, proj=proj).to(device).eval()

    if n == "lpips":
        try:
            import lpips  # pip install lpips
        except Exception as e:
            raise RuntimeError("LPIPS not available; pip install lpips") from e
        net = lpips.LPIPS(net='vgg')  # returns (N,1,1,1) if called on image pairs
        # Wrap to behave like a feature extractor on single images:
        class LPIPSFeat(nn.Module):
            def __init__(self, net):
                super().__init__()
                self.net = net
            def forward(self, x):
                # Compare to black image to get a (pseudo) embedding; not a true embedding but useful for sensitivity
                z = torch.zeros_like(x)
                d = self.net(x*2-1, z*2-1)  # LPIPS expects [-1,1]
                return d.view(x.shape[0], -1)
        return FeatureExtractor(LPIPSFeat(net)).to(device).eval()

    if n in {"clip", "clip-vitb32"}:
        try:
            import clip  # pip install git+https://github.com/openai/CLIP.git
        except Exception as e:
            raise RuntimeError("CLIP not available; install openai-clip") from e
        model, _ = clip.load("ViT-B/32", device=device, jit=False)
        class CLIPImageFeat(nn.Module):
            def __init__(self, m): super().__init__(); self.m = m
            def forward(self, x):
                # CLIP expects normalized [-1,1] w/ mean/std; we accept [0,1] and convert
                mean = torch.tensor([0.48145466,0.4578275,0.40821073], device=x.device)[None,:,None,None]
                std  = torch.tensor([0.26862954,0.26130258,0.27577711], device=x.device)[None,:,None,None]
                x_n = (x - mean) / std
                return self.m.encode_image(x_n).float()
        return FeatureExtractor(CLIPImageFeat(model)).to(device).eval()

    raise ValueError(f"Unknown feature extractor: {name}")


# ============================
# Differentiable coords â†’ patch â†’ features
# ============================

@dataclass
class CoordsToFeaturesConfig:
    extractor_name: str = "resnet50"
    out_res: int = 224            # feature extractor input resolution (square)
    patch_px: int = 128           # physical patch width/height in pixels (on the source image)
    coords_mode: Literal["pixels","normalized"] = "pixels"  # input coords convention
    clamp: bool = True            # clamp sampling grid to [-1,1]


def _to_norm_xy(xy: torch.Tensor, H: int, W: int, mode: str) -> torch.Tensor:
    """
    Convert (x,y) in pixels or normalized to normalized coords in [-1,1] (align_corners=True).
    xy: (...,2)
    """
    if mode == "normalized":
        return xy
    # pixels â†’ normalized
    x, y = xy[..., 0], xy[..., 1]
    xn = 2.0 * x / max(W-1, 1) - 1.0
    yn = 2.0 * y / max(H-1, 1) - 1.0
    return torch.stack([xn, yn], dim=-1)


def _make_patch_grid(
    center_xy_norm: torch.Tensor,  # (2,), requires_grad=True
    H: int, W: int,
    out_res: int,
    patch_px: int,
    clamp: bool = True,
) -> torch.Tensor:
    """
    Create a sampling grid (1, out_res, out_res, 2) in normalized coords so that
    grid_sample(image, grid) returns a patch centered at center_xy_norm with physical
    size ~ patch_px Ã— patch_px (in source pixel units).
    """
    device = center_xy_norm.device
    # Base grid in [-1,1]
    yy, xx = torch.meshgrid(
        torch.linspace(-1, 1, out_res, device=device),
        torch.linspace(-1, 1, out_res, device=device),
        indexing="ij"
    )
    base = torch.stack([xx, yy], dim=-1)  # (H,W,2)

    # Half-size of patch in normalized coords: (patch_px/2) * (2/W or 2/H) = patch_px/W (or /H)
    sx = (patch_px / float(W))
    sy = (patch_px / float(H))
    scaled = torch.stack([sx * base[..., 0], sy * base[..., 1]], dim=-1)

    # Shift to center
    grid = scaled + center_xy_norm[None, None, :]
    if clamp:
        grid = torch.clamp(grid, -1.0, 1.0)
    return grid.unsqueeze(0)  # (1, out_res, out_res, 2)


def _bilinear_sample_bchw(
    image_bchw: torch.Tensor,         # (1,C,H,W)
    grid: torch.Tensor,               # (1, out_h, out_w, 2) in normalized [-1,1]
    padding_mode: Literal["border","zeros"] = "border",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Differentiable bilinear sampler equivalent to grid_sample for N=1 using
    primitive ops that are supported on MPS backward. Grad flows w.r.t. image and grid.
    """
    assert image_bchw.ndim == 4 and image_bchw.shape[0] == 1
    assert grid.ndim == 4 and grid.shape[0] == 1
    _, C, H, W = image_bchw.shape
    _, out_h, out_w, _ = grid.shape

    gx = grid[..., 0]
    gy = grid[..., 1]

    if align_corners:
        x = (gx + 1) * (W - 1) * 0.5
        y = (gy + 1) * (H - 1) * 0.5
    else:
        x = ((gx + 1) * W - 1) * 0.5
        y = ((gy + 1) * H - 1) * 0.5

    x0 = torch.floor(x)
    y0 = torch.floor(y)
    x1 = x0 + 1
    y1 = y0 + 1

    wa = (x1 - x) * (y1 - y)
    wb = (x - x0) * (y1 - y)
    wc = (x1 - x) * (y - y0)
    wd = (x - x0) * (y - y0)

    def sample_at(ix: torch.Tensor, iy: torch.Tensor) -> torch.Tensor:
        ix_l = ix.long()
        iy_l = iy.long()
        if padding_mode == "border":
            ix_l = ix_l.clamp(0, W - 1)
            iy_l = iy_l.clamp(0, H - 1)
            vals = image_bchw[:, :, iy_l, ix_l]  # (1,C,out_h,out_w)
            return vals
        elif padding_mode == "zeros":
            in_x = (ix_l >= 0) & (ix_l < W)
            in_y = (iy_l >= 0) & (iy_l < H)
            inb = (in_x & in_y).unsqueeze(1)  # (1,1,out_h,out_w)
            ix_c = ix_l.clamp(0, W - 1)
            iy_c = iy_l.clamp(0, H - 1)
            vals = image_bchw[:, :, iy_c, ix_c]
            return vals * inb
        else:
            raise ValueError("Unsupported padding_mode")

    Ia = sample_at(x0, y0)
    Ib = sample_at(x1, y0)
    Ic = sample_at(x0, y1)
    Id = sample_at(x1, y1)

    wa = wa.unsqueeze(1)
    wb = wb.unsqueeze(1)
    wc = wc.unsqueeze(1)
    wd = wd.unsqueeze(1)

    out = Ia * wa + Ib * wb + Ic * wc + Id * wd  # (1,C,out_h,out_w)
    return out


def coords_to_features(
    image_bchw: torch.Tensor,        # (1,3,H,W), RGB, float in [0,1]
    center_xy: torch.Tensor,         # (2,) x,y in pixels or normalized
    cfg: CoordsToFeaturesConfig,
    extractor: Optional[FeatureExtractor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns (features, patch_bchw). Differentiable w.r.t. center_xy (and image if you want).
    """
    assert image_bchw.ndim == 4 and image_bchw.shape[0] == 1 and image_bchw.shape[1] == 3
    _, _, H, W = image_bchw.shape
    device = image_bchw.device

    center_xy = center_xy.to(device).float()
    center_xy_norm = _to_norm_xy(center_xy, H, W, cfg.coords_mode)
    center_xy_norm.requires_grad_(True)

    # Build sampling grid
    is_mps = (device.type == "mps")
    effective_clamp = (cfg.clamp or is_mps)
    grid = _make_patch_grid(center_xy_norm, H, W, cfg.out_res, cfg.patch_px, effective_clamp)

    # Use a custom bilinear sampler on MPS to avoid grid_sample backward limitations
    if is_mps:
        patch = _bilinear_sample_bchw(image_bchw, grid, padding_mode="border", align_corners=True).squeeze(2)
    else:
        patch = F.grid_sample(image_bchw, grid, mode="bilinear", padding_mode="border", align_corners=True)

    if extractor is None:
        extractor = build_feature_extractor(cfg.extractor_name, device=device)
    
    with torch.set_grad_enabled(True):
        feats = extractor(patch)  # (1,D)
    return feats, patch


# ============================
# Sensitivity (grad/Hessian) at a coordinate
# ============================

@dataclass
class SensitivityResult:
    feats: torch.Tensor             # (D,)
    energy: torch.Tensor            # scalar g = 0.5||f||^2 or other reduction
    grad_xy: torch.Tensor           # (2,) âˆ‚g/âˆ‚(x,y)   in chosen coords_mode
    unit_dir: torch.Tensor          # (2,) normalized gradient direction (steepest ascent)
    hess_xy: torch.Tensor           # (2,2) Hessian of g wrt (x,y)
    eigvals: torch.Tensor           # (2,) principal curvatures
    eigvecs: torch.Tensor           # (2,2) columns = principal directions
    patch: torch.Tensor             # (1,3,out_res,out_res)

def feature_sensitivity_at(
    image_bchw: torch.Tensor,            # (1,3,H,W), RGB, [0,1]
    center_xy: Tuple[float,float] | torch.Tensor,  # (x,y) in pixels by default
    extractor_name: str = "resnet50",
    coords_mode: Literal["pixels","normalized"] = "pixels",
    patch_px: int = 128,
    out_res: int = 224,
    reduction: Literal["l2","l2_half","l1"] = "l2_half",
) -> SensitivityResult:
    """
    Compute âˆ‚g/âˆ‚(x,y) and âˆ‚Â²g/âˆ‚(x,y)Â² where g is a scalar energy of features f(patch(x,y)).
    - Default g = 0.5*||f||^2 (smooth and convenient): 'l2_half'
    - 'l2' uses ||f||^2, 'l1' uses ||f||_1 (subgradient-friendly).
    Returns gradient direction (steepest ascent), Hessian, and eigendecomposition.
    """
    if isinstance(center_xy, tuple):
        center_xy = torch.tensor(center_xy, dtype=torch.float32, device=image_bchw.device)

    cfg = CoordsToFeaturesConfig(
        extractor_name=extractor_name,
        out_res=out_res,
        patch_px=patch_px,
        coords_mode=coords_mode,
        clamp=True,
    )

    orig_device = image_bchw.device

    # Build extractor once on the original device
    extractor = build_feature_extractor(extractor_name, device=orig_device)

    # Prepare coordinates tensor on original device
    xy = center_xy.detach().clone().to(orig_device)
    xy.requires_grad_(True)

    # Keep image on original device
    image_in = image_bchw

    # Scalar energy helper
    def energy_from_xy(xy_in: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        feats, patch = coords_to_features(image_in, xy_in, cfg, extractor)
        feats = feats.view(-1)
        if reduction == "l2_half":
            g = 0.5 * torch.dot(feats, feats)
        elif reduction == "l2":
            g = torch.dot(feats, feats)
        elif reduction == "l1":
            g = feats.abs().sum()
        else:
            raise ValueError("Unknown reduction")
        return g, feats, patch

    # Autograd path (now MPS-safe via custom bilinear sampler)
    g, fvec, patch = energy_from_xy(xy)
    grad = torch.autograd.grad(g, xy, create_graph=True, retain_graph=True)[0]

    def _scalar_fn(xy_in: torch.Tensor) -> torch.Tensor:
        g2, _, _ = energy_from_xy(xy_in)
        return g2
    H = torch.autograd.functional.hessian(_scalar_fn, xy, create_graph=False)
    evals, evecs = torch.linalg.eigh(H)
    grad_dir = grad / (grad.norm() + 1e-8)
    return SensitivityResult(
        feats=fvec.detach(),
        energy=g.detach(),
        grad_xy=grad.detach(),
        unit_dir=grad_dir.detach(),
        hess_xy=H.detach(),
        eigvals=evals.detach(),
        eigvecs=evecs.detach(),
        patch=patch.detach()
    )
# img_bchw: (1,3,H,W), RGB in [0,1] from your StyleGAN pipeline (no need to convert to BGR)
img_bchw = G(sample_w(1, G=G), None)  # or your already-rendered tensor
print(img_bchw.shape)
res = feature_sensitivity_at(
    image_bchw=img_bchw,
    center_xy=(256, 256),            # pixels; set coords_mode="normalized" if using [-1,1]
    extractor_name="resnet50",       # or "vgg16", "lpips", "clip-vitb32" if installed
    coords_mode="pixels",
    patch_px=128,
    out_res=224,
    reduction="l2_half"
)

print("grad (âˆ‚g/âˆ‚x, âˆ‚g/âˆ‚y):", res.grad_xy.cpu().numpy())
print("unit dir (fastest change):", res.unit_dir.cpu().numpy())
print("Hessian:\n", res.hess_xy.cpu().numpy())
print("eigvals:", res.eigvals.cpu().numpy())

torch.Size([1, 3, 1024, 1024])


RuntimeError: max_pool2d with `return_indices=False` is not infinitely differentiable. If you want to calculate higher order derivatives, e.g. second order, set `return_indices=True`.