In [None]:
# =============================
# 0. IMPORTS & DEPENDENCIES
# =============================

# ---- STANDARD IMPORTS ----
import os
import cv2
import shutil
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
from collections import Counter

import torch
from torchvision import transforms as T

# ---- FACE DETECTION (RetinaFace or fallback to MTCNN) ----
try:
    from retinaface import RetinaFace
    USE_RETINAFACE = True
except ImportError:
    from facenet_pytorch import MTCNN
    USE_RETINAFACE = False

# ---- FACE DEDUPLICATION (ArcFace) ----
try:
    from insightface.app import FaceAnalysis
    USE_ARCFACE = True
except ImportError:
    from imagehash import phash
    USE_ARCFACE = False

# ---- LOAD YOUR V18 MODEL ----
from transformers import AutoModelForImageClassification, AutoImageProcessor

In [None]:
# =============================
# 1. CONFIGURATION & PARAMETERS
# =============================
INPUT_DIR = "/path/to/your/input_images"    # Folder of images to process (searched recursively)
SAVE_DIR = "/path/to/output/sorted_faces"   # Output root folder
MODEL_PATH = "/path/to/your/V18_model"      # Your V18 supervised model checkpoint
BATCH_SIZE = 32                             # Adjust for speed/memory

# Pick the fastest available torch backend: CUDA GPU first, then Apple-Silicon
# MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Thresholds
# Per-class minimum softmax confidence; a class missing from this dict falls back
# to GLOBAL_CONF_THRESH. Keys are looked up by exact (case-sensitive) match
# against the model's id2label names.
CONFIDENCE_THRESHOLDS = {
    "disgust": 0.90,
    "contempt": 0.90,
    "fear": 0.90,
    "questioning": 0.90,
    "surprise": 0.90,
    "anger": 0.90,
    "happiness": 0.92,
    "sadness": 0.90,
    "neutral": 0.92,
    "unknown": 0.98,
}
GLOBAL_CONF_THRESH = 0.90
# Maximum allowed prediction entropy (natural log, i.e. nats); anything more
# uncertain is routed to "unknown".
GLOBAL_ENTROPY_THRESH = 0.45

os.makedirs(SAVE_DIR, exist_ok=True)

In [None]:
# =============================
# 2. UTILITY FUNCTIONS
# =============================

# FACE DETECTION & CROPPING
# Lazily-built MTCNN fallback detector, shared by every detect_and_crop call so
# the network is constructed once instead of once per image.
_MTCNN_DETECTOR = None


def _get_mtcnn():
    """Return the shared MTCNN detector, constructing it on first use."""
    global _MTCNN_DETECTOR
    if _MTCNN_DETECTOR is None:
        # NOTE(review): `margin` presumably only affects MTCNN's own face
        # extraction, not the raw boxes returned by detect(); kept as before.
        _MTCNN_DETECTOR = MTCNN(margin=10, keep_all=True, device=DEVICE)
    return _MTCNN_DETECTOR


def _largest_box(boxes, width, height):
    """Return the largest (x1, y1, x2, y2) box after clipping it to the image.

    Boxes are corner coordinates and may be floats or lie partly outside the
    image; unclipped negative values would wrap around in numpy slicing and
    yield a wrong or empty crop. Returns None if every box is empty after
    clipping.
    """
    best_box, best_area = None, 0
    for box in boxes:
        x1, y1, x2, y2 = (int(v) for v in box[:4])
        x1, x2 = min(max(x1, 0), width), min(max(x2, 0), width)
        y1, y2 = min(max(y1, 0), height), min(max(y2, 0), height)
        area = (x2 - x1) * (y2 - y1)
        if x2 > x1 and y2 > y1 and area > best_area:
            best_box, best_area = (x1, y1, x2, y2), area
    return best_box


def detect_and_crop(image_path):
    """Detect the largest face in an image and return it as a quality-checked crop.

    Uses RetinaFace when it imported successfully (``USE_RETINAFACE``),
    otherwise the MTCNN fallback. Both detectors see exactly the same pixels
    that are cropped, so box coordinates always match the image.

    Args:
        image_path: Path (str) of the image file.

    Returns:
        An RGB PIL image of the largest face, or ``None`` when no face is found,
        the clipped box is empty, or the crop fails ``is_quality_ok``.
    """
    from PIL import ImageOps

    # `with` closes the file handle; exif_transpose rotates phone photos upright
    # so detectors and the classifier see upright faces.
    with Image.open(image_path) as src:
        pil_img = ImageOps.exif_transpose(src).convert("RGB")
    img = np.array(pil_img)
    height, width = img.shape[:2]

    if USE_RETINAFACE:
        # Pass the decoded pixels instead of the path so RetinaFace cannot load
        # the file with a different orientation than PIL did.
        # NOTE(review): array input is assumed to be BGR (OpenCV convention).
        faces = RetinaFace.detect_faces(np.ascontiguousarray(img[:, :, ::-1]))
        if not isinstance(faces, dict) or not faces:
            return None
        boxes = [face["facial_area"] for face in faces.values()]
    else:
        boxes, _ = _get_mtcnn().detect(pil_img)
        if boxes is None or len(boxes) == 0:
            return None

    box = _largest_box(boxes, width, height)
    if box is None:
        return None
    x1, y1, x2, y2 = box
    pil_face = Image.fromarray(img[y1:y2, x1:x2])
    # Reject tiny or blurry crops before they reach dedup/classification.
    if not is_quality_ok(pil_face):
        return None
    return pil_face


# FACE DEDUPLICATION (ArcFace/phash)
# Build the InsightFace analysis app once, at import time, and only when the
# InsightFace import succeeded; otherwise expose a None placeholder that
# compute_face_embedding checks before use.
if not USE_ARCFACE:
    arcface_app = None
else:
    arcface_app = FaceAnalysis(name="antelopev2", providers=['CPUExecutionProvider'])
    arcface_app.prepare(ctx_id=0)

def compute_face_embedding(face_img):
    """Return the ArcFace (InsightFace) embedding of the main face in a crop.

    ``face_img`` is an RGB image (PIL image or HxWxC array). InsightFace's
    ``FaceAnalysis.get`` works on OpenCV-style BGR pixels, so the colour
    channels are reordered before detection.

    Returns:
        The ``.embedding`` of the largest detected face, or ``None`` when
        InsightFace finds no face in the crop (this can happen with tight crops).

    Raises:
        RuntimeError: if InsightFace/ArcFace is not available.
    """
    if not USE_ARCFACE or arcface_app is None:
        raise RuntimeError("InsightFace/ArcFace not installed")
    # RGB(A) -> BGR; any alpha channel is dropped. Contiguous for OpenCV ops.
    face_bgr = np.ascontiguousarray(np.asarray(face_img)[:, :, [2, 1, 0]])
    detected = arcface_app.get(face_bgr)
    if not detected:
        return None
    # Several faces may be found; the biggest is most likely the one the crop
    # was made for.
    largest = max(
        detected,
        key=lambda face: (face.bbox[2] - face.bbox[0]) * (face.bbox[3] - face.bbox[1]),
    )
    return largest.embedding

def compute_phash(face_img):
    """Return the perceptual hash of a face crop in string form.

    The string is what the hash-based deduplication compares for exact
    equality.
    """
    hash_value = phash(face_img)
    return str(hash_value)

def deduplicate(images, embeddings=None, hashes=None, threshold=0.7):
    """Remove duplicate or near-duplicate faces, keeping the first occurrence.

    Exactly one strategy is used, chosen in this order:
      * ``embeddings`` given: greedy cosine-similarity filter. Item ``i`` is
        dropped if its cosine similarity to any previously *kept* item exceeds
        ``threshold``. Items whose embedding is ``None`` cannot be compared and
        are always kept.
      * ``hashes`` given: exact-match filter on the hash values.
      * neither: ``images`` is returned unchanged.

    Args:
        images: Items to filter (e.g. crops or indices), aligned index-for-index
            with ``embeddings`` / ``hashes``.
        embeddings: Optional sequence of 1-D vectors (entries may be ``None``).
        hashes: Optional sequence of hashable values.
        threshold: Cosine similarity above which two faces count as duplicates.
            NOTE(review): identity embeddings of *different* photos of the same
            person can exceed 0.7; confirm this is the intended dedup level.

    Returns:
        List of surviving items in their original order (empty input -> []).
    """
    if embeddings is not None:
        keep = []
        kept_units = []  # unit-length embeddings of the items kept so far
        for i in range(len(images)):
            emb = embeddings[i]
            if emb is None:
                # Nothing to compare against; keep rather than silently drop.
                keep.append(i)
                continue
            vec = np.asarray(emb, dtype=np.float64).ravel()
            norm = np.linalg.norm(vec)
            unit = vec / norm if norm > 0 else vec  # zero vector -> similarity 0
            if kept_units and np.max(np.stack(kept_units) @ unit) > threshold:
                continue
            keep.append(i)
            kept_units.append(unit)
        return [images[i] for i in keep]
    if hashes is not None:
        seen = set()
        keep = []
        for i, h in enumerate(hashes):
            if h not in seen:
                seen.add(h)
                keep.append(i)
        return [images[i] for i in keep]
    return images


# IMAGE QUALITY FILTERING
def is_quality_ok(face_img, min_size=48, blur_thresh=80):
    """
    Decide whether a face crop is usable.

    A crop passes when both sides are at least ``min_size`` pixels and its
    sharpness (variance of the Laplacian of the grayscale image) is not below
    ``blur_thresh``.
    """
    width, height = face_img.size
    if width < min_size or height < min_size:
        return False
    gray = np.array(face_img.convert("L"))
    sharpness = cv2.Laplacian(gray, cv2.CV_64F).var()
    return not (sharpness < blur_thresh)


# MODEL LOADING (V18 & PROCESSOR)
def get_v18_model():
    """Load the V18 classifier and its image processor from MODEL_PATH.

    Returns:
        (model, processor): the model is moved to DEVICE and put in eval mode.
    """
    classifier = AutoModelForImageClassification.from_pretrained(MODEL_PATH)
    image_processor = AutoImageProcessor.from_pretrained(MODEL_PATH)
    classifier.to(DEVICE).eval()
    return classifier, image_processor
    

# BATCH INFERENCE
def batch_infer(model, processor, pil_imgs):
    """Classify a batch of face crops.

    Args:
        model: Classifier whose forward pass returns an object with ``.logits``
            (assumed shape (batch, num_classes) — TODO confirm for other heads).
        processor: Image processor turning ``pil_imgs`` into model input tensors.
        pil_imgs: List of PIL images.

    Returns:
        Numpy arrays ``(preds, confs, entropies, probs)``: argmax class index,
        its softmax probability, the prediction entropy in nats, and the full
        softmax matrix — one row per input image.
    """
    # `padding=` is a tokenizer argument; image processors do not use it (recent
    # transformers versions warn about unused kwargs), so it is not passed.
    inputs = processor(pil_imgs, return_tensors="pt")
    inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}
    with torch.inference_mode():
        logits = model(**inputs).logits
        # log_softmax in float32 gives stable log-probs (no epsilon needed) even
        # if the model runs in reduced precision.
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        probs = log_probs.exp()
        entropies = -(probs * log_probs).sum(dim=-1)
        confs, preds = probs.max(dim=-1)
    return preds.cpu().numpy(), confs.cpu().numpy(), entropies.cpu().numpy(), probs.cpu().numpy()


# SOFTMAX, ENTROPY, AND ASSIGNMENT LOGIC
def calculate_entropy(probs):
    """
    Shannon entropy (natural log) of an array of softmax probabilities.

    A tiny epsilon keeps the logarithm finite for zero-probability entries.
    All elements of ``probs`` are summed together.
    """
    weighted_logs = probs * np.log(probs + 1e-12)
    return -np.sum(weighted_logs)
    

# CLASS ASSIGNMENT (confidence + entropy gating)
def assign_class(pred_idx, conf, entropy, id2label):
    """Return the predicted label, or "unknown" when the prediction is not trusted.

    A prediction is rejected when its confidence is below the class-specific
    threshold (CONFIDENCE_THRESHOLDS, falling back to GLOBAL_CONF_THRESH) or its
    entropy exceeds GLOBAL_ENTROPY_THRESH.
    """
    label = id2label[pred_idx]
    min_conf = CONFIDENCE_THRESHOLDS.get(label, GLOBAL_CONF_THRESH)
    rejected = conf < min_conf or entropy > GLOBAL_ENTROPY_THRESH
    return "unknown" if rejected else label


In [None]:
# =============================
# 3. MAIN PIPELINE EXECUTION
# =============================

# Image file extensions picked up from INPUT_DIR (compared case-insensitively).
_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def _discover_images(root):
    """Return all .jpg/.jpeg/.png files under root (any case), sorted for a stable order."""
    return sorted(
        p for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
    )


def _crop_faces(img_paths):
    """Crop the largest face from each image; return parallel lists (crops, source paths)."""
    cropped_imgs, cropped_paths = [], []
    for img_path in tqdm(img_paths, desc="Detect/crop faces"):
        try:
            face = detect_and_crop(str(img_path))
        except Exception as e:  # best effort: one unreadable file must not stop the run
            print(f"Error cropping {img_path}: {e}")
            continue
        if face is not None:
            cropped_imgs.append(face)
            cropped_paths.append(str(img_path))
    return cropped_imgs, cropped_paths


def _dedup_indices(cropped_imgs):
    """Return the ascending indices of the crops that survive deduplication."""
    if not cropped_imgs:
        return []
    if USE_ARCFACE:
        embeddings = [compute_face_embedding(img) for img in cropped_imgs]
        comparable = [i for i, emb in enumerate(embeddings) if emb is not None]
        # compute_face_embedding returns None when no face is re-detected in the
        # crop; such crops cannot be compared, so they are kept, not dropped.
        unembedded = [i for i, emb in enumerate(embeddings) if emb is None]
        kept = []
        if comparable:
            kept = deduplicate(comparable, embeddings=np.stack([embeddings[i] for i in comparable]))
        return sorted(set(kept) | set(unembedded))
    hashes = [compute_phash(img) for img in cropped_imgs]
    return deduplicate(list(range(len(cropped_imgs))), hashes=hashes)


def _unique_copy_name(out_dir, img_name, taken):
    """Return a name for img_name in out_dir not yet used this run (name, name_1, ...).

    Inputs are found recursively, so two source files may share a basename;
    without this they would overwrite each other in the output folder.
    """
    stem, ext = os.path.splitext(img_name)
    candidate, k = img_name, 1
    while (out_dir, candidate) in taken:
        candidate = f"{stem}_{k}{ext}"
        k += 1
    taken.add((out_dir, candidate))
    return candidate


def main():
    """Run the pipeline: detect -> deduplicate -> classify -> copy into class folders -> audit CSV."""
    # Load model/processor
    model, processor = get_v18_model()
    id2label = model.config.id2label
    print("Loaded V18 model with labels:", id2label)

    # Discover images
    all_img_paths = _discover_images(INPUT_DIR)
    print(f"Found {len(all_img_paths)} images in {INPUT_DIR}")

    # Step 1: Detect/crop faces
    cropped_imgs, cropped_paths = _crop_faces(all_img_paths)
    print(f"Detected {len(cropped_imgs)} faces.")

    # Step 2: Deduplicate
    keep = _dedup_indices(cropped_imgs)
    cropped_imgs = [cropped_imgs[i] for i in keep]
    cropped_paths = [cropped_paths[i] for i in keep]
    print(f"Deduplicated to {len(cropped_imgs)} unique faces.")

    # Step 3: Batch inference and sort
    audit_rows = []
    taken_names = set()
    os.makedirs(os.path.join(SAVE_DIR, "unknown"), exist_ok=True)
    for start in tqdm(range(0, len(cropped_imgs), BATCH_SIZE), desc="Batch inference"):
        batch_imgs = cropped_imgs[start:start + BATCH_SIZE]
        batch_paths = cropped_paths[start:start + BATCH_SIZE]
        preds, confs, ents, probs = batch_infer(model, processor, batch_imgs)
        for pidx, conf, ent, path, probs_vec in zip(preds, confs, ents, batch_paths, probs):
            pidx = int(pidx)
            assigned = assign_class(pidx, conf, ent, id2label)
            out_dir = os.path.join(SAVE_DIR, assigned)
            os.makedirs(out_dir, exist_ok=True)
            out_name = _unique_copy_name(out_dir, os.path.basename(path), taken_names)
            out_path = os.path.join(out_dir, out_name)
            shutil.copy(path, out_path)
            audit_rows.append({
                "image_path": path,
                "assigned_class": assigned,
                "predicted_label": id2label[pidx],
                "confidence": float(conf),
                "entropy": float(ent),
                "output_path": out_path,
                **{f"prob_{id2label[k]}": float(p) for k, p in enumerate(probs_vec)},
            })

    # Step 4: Save audit CSV and print stats
    audit_df = pd.DataFrame(audit_rows)
    audit_csv = os.path.join(SAVE_DIR, "sort_audit.csv")
    audit_df.to_csv(audit_csv, index=False)
    print(f"Audit CSV saved to {audit_csv}")

    if audit_df.empty:
        # No rows means no "assigned_class" column; skip the stats.
        print("No faces were sorted; nothing to summarise.")
        return

    print("Final sorted class counts:", Counter(audit_df["assigned_class"]))
    print("Mean/median confidence per class:")
    for c in sorted(audit_df["assigned_class"].unique()):
        subset = audit_df[audit_df["assigned_class"] == c]
        print(f"  {c:12} | n={len(subset)} | mean_conf={subset['confidence'].mean():.3f} | median_conf={subset['confidence'].median():.3f}")

if __name__ == "__main__":
    main()