In [None]:
# ============================================
# Cell 1 – Imports & Basic Config
# ============================================
import os
import glob
import json
from typing import Optional, List, Tuple

import numpy as np
from PIL import Image, ImageOps

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from facenet_pytorch import InceptionResnetV1
import mediapipe as mp
import pandas as pd

# ---- Paths ----
MODEL_PATH = "models/best_facenet_arcface_kfold5.pth"
LABEL_MAP_PATH = "models/label_map.json"
# Directory of test images, one sub-folder per class.
# Put the "Test" folder in the project root, or point TEST_DIR somewhere else.
TEST_DIR = "./Test"
OUTPUT_CSV = "predictions.csv"

# ---- Image / crop config ----
IMG_SIZE = 160              # side length fed to the model by the final Resize
TARGET_SIZE = (384, 384)    # size of the face crop returned by the crop stage
MARGIN_RATIO = 0.15         # extra border added around the detected face box
MAX_DIM = 1600              # long side is capped at this before detection
MIN_DIM = 256               # short side is raised to this before detection
USE_CENTER_FALLBACK = True  # center-crop when no face is detected

# ---- Normalization ----
# NOTE(review): ImageNet statistics; these must match the transform used at
# training time (facenet_pytorch's own pipeline uses a different
# standardization) - confirm against the training notebook.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


In [None]:
# ============================================
# Cell 2 – Load Label Map
# ============================================
with open(LABEL_MAP_PATH, "r", encoding="utf-8") as f:
    # Schema: {"PersonName": class_index}. Indices are coerced to int so they
    # compare equal to the model's argmax (a Python int from .item()) even if
    # they were stored as strings.
    label_map = {name: int(idx) for name, idx in json.load(f).items()}

idx_to_name = {v: k for k, v in label_map.items()}
num_classes = len(label_map)

# The classifier head is sized by num_classes and the metrics cell looks up
# idx_to_name[i] for every i in range(num_classes), so the indices must be
# exactly 0..num_classes-1 with no gaps or duplicates. Fail here, clearly,
# instead of with a KeyError after the whole test set has been processed.
if sorted(idx_to_name) != list(range(num_classes)):
    raise ValueError(
        f"label_map indices must be unique and cover 0..{num_classes - 1}; "
        f"got (first 10, sorted): {sorted(idx_to_name)[:10]}"
    )

print("num_classes:", num_classes)
print("example mapping:", list(label_map.items())[:5])


In [None]:
# ============================================
# Cell 3 – MediaPipe Face Detection + Robust Crop
# ============================================
mp_face_detection = mp.solutions.face_detection

# Detectors are created lazily and reused across images instead of building a
# new MediaPipe graph on every call (several per image otherwise). If this
# cell is re-run, close the detectors from the previous run before dropping
# the cache so their resources are released.
for _stale_detector in globals().get("_FACE_DETECTORS", {}).values():
    _stale_detector.close()
_FACE_DETECTORS = {}


def _get_face_detector(model_selection: int, min_detection_confidence: float):
    """Return a cached FaceDetection instance for the given settings."""
    key = (model_selection, min_detection_confidence)
    detector = _FACE_DETECTORS.get(key)
    if detector is None:
        detector = mp_face_detection.FaceDetection(
            model_selection=model_selection,
            min_detection_confidence=min_detection_confidence,
        )
        _FACE_DETECTORS[key] = detector
    return detector


def run_face_detection(img_rgb: np.ndarray):
    """
    Run MediaPipe face detection on an RGB image array.

    First pass uses model_selection=1 with confidence 0.5; if it finds no
    face, a fallback pass uses model_selection=0 with a lower threshold of
    0.3 (per MediaPipe docs: 1 = full-range model, 0 = short-range model).

    Returns:
        (results, width, height): the MediaPipe result of the last pass run,
        plus the pixel width/height of `img_rgb`. `results.detections` is
        empty/None when no face was found.
    """
    h, w = img_rgb.shape[:2]

    # 1st pass: model_selection=1
    res = _get_face_detector(1, 0.5).process(img_rgb)
    if res.detections:
        return res, w, h

    # 2nd pass: fallback
    res = _get_face_detector(0, 0.3).process(img_rgb)
    return res, w, h


def pick_best_detection(detections, w: int, h: int):
    """
    Choose the detection with the largest (score * relative box area).

    The area is taken from the relative bounding box, so `w` and `h` are not
    used; they are kept for call compatibility. On ties the earliest
    detection wins. Returns None for an empty sequence.
    """
    def _rank(candidate):
        rel_box = candidate.location_data.relative_bounding_box
        # Floor the area so a zero-size box still ranks by its score.
        return candidate.score[0] * max(rel_box.width * rel_box.height, 1e-6)

    winner, winner_rank = None, -1.0
    for candidate in detections:
        rank = _rank(candidate)
        if rank > winner_rank:
            winner, winner_rank = candidate, rank
    return winner


def safe_center_crop(img: np.ndarray) -> Optional[Image.Image]:
    """
    Fallback crop that never yields a 0x0 image.

    Takes a centered square whose side is 80% of the shorter image side
    (at least 1 px) and fits it to TARGET_SIZE with bicubic resampling.
    Returns None when the array has zero height/width or the crop is empty.
    """
    height, width = img.shape[:2]
    if not (height and width):
        return None

    side = max(1, int(0.8 * min(height, width)))
    left = max(0, width // 2 - side // 2)
    top = max(0, height // 2 - side // 2)
    right = min(width, left + side)
    bottom = min(height, top + side)

    # An inverted/empty window slices to an array of size 0.
    patch = img[top:bottom, left:right]
    if patch.size == 0:
        return None
    return ImageOps.fit(Image.fromarray(patch), TARGET_SIZE, Image.BICUBIC)


def _resize_for_detection(pil_img: Image.Image) -> Image.Image:
    """
    Rescale `pil_img` once so its short side is at least MIN_DIM and its long
    side at most MAX_DIM (the MAX_DIM cap wins if both cannot hold).

    Returns `pil_img` itself when no resizing is needed. The caller must
    ensure both sides are > 0.
    """
    ow, oh = pil_img.size
    short_side, long_side = min(ow, oh), max(ow, oh)

    scale = 1.0
    if short_side < MIN_DIM:
        scale = MIN_DIM / float(short_side)
    # Apply the long-side cap in the same step: upscaling first and then
    # downscaling would allocate a huge intermediate image for extreme aspect
    # ratios (e.g. a 1 x 10000 strip) and resample twice.
    if int(long_side * scale) > MAX_DIM:
        scale = MAX_DIM / float(long_side)

    if scale == 1.0:
        return pil_img
    new_size = (max(1, int(ow * scale)), max(1, int(oh * scale)))
    return pil_img.resize(new_size, Image.BICUBIC)


def _expanded_face_box(bbox, w: int, h: int) -> Optional[Tuple[int, int, int, int]]:
    """
    Convert a MediaPipe relative bbox to pixel (x1, y1, x2, y2), grown by
    MARGIN_RATIO on every side and clipped to a w x h image.

    Returns None when the clipped box is empty (e.g. bbox fully off-frame).
    """
    x = int(bbox.xmin * w)
    y = int(bbox.ymin * h)
    bw = max(1, int(bbox.width * w))
    bh = max(1, int(bbox.height * h))

    margin_x = int(bw * MARGIN_RATIO)
    margin_y = int(bh * MARGIN_RATIO)

    # Clip only the final edges. Clamping x/y to 0 *before* adding bw/bh would
    # shift the whole box right/down when the detector reports a face that
    # extends past the left/top border (negative xmin/ymin).
    x1 = max(0, x - margin_x)
    y1 = max(0, y - margin_y)
    x2 = min(w, x + bw + margin_x)
    y2 = min(h, y + bh + margin_y)

    if x2 <= x1 or y2 <= y1:
        return None
    return x1, y1, x2, y2


def detect_and_crop_from_pil(pil_img: Image.Image) -> Optional[Image.Image]:
    """
    Crop the main face from `pil_img` and return it fitted to TARGET_SIZE.

    Steps:
      1. Apply the EXIF orientation and convert to RGB.
      2. Resize for detection (short side >= MIN_DIM, long side <= MAX_DIM).
      3. Detect faces; if none are found on the resized image, retry on the
         original-resolution image.
      4. Pick the best detection, grow its box by MARGIN_RATIO, clip it to the
         image and crop.

    Returns:
        The face crop as a PIL image. If no face is detected, a center crop
        when USE_CENTER_FALLBACK is set, otherwise None. A degenerate box
        always falls back to the center crop. A zero-sized input returns None
        (no ZeroDivisionError).
    """
    # Fix orientation, convert to RGB
    pil_img = ImageOps.exif_transpose(pil_img).convert("RGB")
    orig = np.array(pil_img)
    if orig.size == 0:
        return None

    resized = _resize_for_detection(pil_img)
    img = orig if resized is pil_img else np.array(resized)

    # Detection on the (possibly) resized image
    results, w, h = run_face_detection(img)

    # Retry at the original resolution if that failed. Skipped when no resize
    # happened, because the original image is then the very array just searched.
    if (not results or not results.detections) and img is not orig:
        img = orig
        results, w, h = run_face_detection(img)

    if not results or not results.detections:
        return safe_center_crop(img) if USE_CENTER_FALLBACK else None

    best_det = pick_best_detection(results.detections, w, h)
    box = _expanded_face_box(best_det.location_data.relative_bounding_box, w, h)
    if box is None:
        # Fallback when the bounding box is degenerate / off-frame
        return safe_center_crop(img)

    x1, y1, x2, y2 = box
    crop = img[y1:y2, x1:x2]
    if crop.size == 0:
        return safe_center_crop(img)

    face = Image.fromarray(crop)
    return ImageOps.fit(face, TARGET_SIZE, Image.BICUBIC)


In [None]:
# ============================================
# Cell 4 – ArcMarginProduct & FaceNetArcFace
# ============================================
import math

class ArcMarginProduct(nn.Module):
    """
    ArcFace head: s * cos(theta + m) for the target class, s * cos(theta)
    for every other class.

    input: (B, in_features) embeddings
    label: (B,) class indices
    output: (B, out_features) logits for CrossEntropy
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        s: float = 25.0,
        m: float = 0.30,
        easy_margin: bool = False,
    ):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.s, self.m = s, m

        # Uninitialized CPU float32 storage, filled by Xavier init right after.
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=torch.float32, device="cpu")
        )
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin

        # Trig constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # cos(theta) > th  <=>  theta + m < pi. Past that point the margin term
        # is replaced by the linear penalty cos(theta) - mm.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
        sin_theta = torch.sqrt(1.0 - torch.clamp(cos_theta.pow(2), 0.0, 1.0))
        cos_theta_m = cos_theta * self.cos_m - sin_theta * self.sin_m

        if self.easy_margin:
            use_margin, fallback = cos_theta > 0, cos_theta
        else:
            use_margin, fallback = cos_theta > self.th, cos_theta - self.mm
        cos_theta_m = torch.where(use_margin, cos_theta_m, fallback)

        # 1.0 at each row's target column, 0.0 elsewhere.
        target_mask = torch.zeros_like(cos_theta)
        target_mask.scatter_(1, label.view(-1, 1), 1.0)

        logits = (target_mask * cos_theta_m) + ((1.0 - target_mask) * cos_theta)
        return logits * self.s


class FaceNetArcFace(nn.Module):
    """
    InceptionResnetV1 backbone + projection head + ArcFace classifier.

    forward(x, labels=None) -> (logits, emb), with emb the L2-normalized
    embedding in both modes:
      * labels is None: logits are plain cosine similarities between emb and
        each normalized class weight (no margin, no scale `s`); argmax of
        these gives the predicted class.
      * labels given: logits come from ArcMarginProduct (margin + scale).
    """

    def __init__(self, num_classes: int, embedding_dim: int = 512,
                 s: float = 25.0, m: float = 0.30):
        super().__init__()
        # Attribute names and Sequential layout define the state_dict keys the
        # checkpoint is loaded against; keep them unchanged.
        # NOTE(review): pretrained='vggface2' loads VGGFace2 weights at
        # construction (presumably downloaded on first use) even though a
        # checkpoint load may overwrite them afterwards.
        self.backbone = InceptionResnetV1(pretrained='vggface2', classify=False)
        backbone_dim = 512  # backbone output size, (B, 512)
        self.embedding = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(backbone_dim, embedding_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )
        self.arc_margin = ArcMarginProduct(
            in_features=embedding_dim,
            out_features=num_classes,
            s=s,
            m=m,
            easy_margin=False,
        )

    def forward(self, x, labels=None):
        emb = F.normalize(self.embedding(self.backbone(x)), dim=1)

        if labels is not None:
            # Training: ArcFace margin logits
            return self.arc_margin(emb, labels), emb

        # Inference: pure cosine logits (no margin)
        class_weights = F.normalize(self.arc_margin.weight)
        cosine_logits = F.linear(F.normalize(emb), class_weights)
        return cosine_logits, emb


In [None]:
# ============================================
# Cell 5 – Load Model & Preprocess
# ============================================
def load_model():
    """
    Build FaceNetArcFace and load its weights from MODEL_PATH.

    The hyperparameters below are hard-coded (presumably those used at
    training time - confirm). The checkpoint may be either a dict holding the
    state_dict under "model" or a bare state_dict; loading is strict. Returns
    the model on `device`, in eval mode.

    NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    """
    net = FaceNetArcFace(
        num_classes=num_classes,
        embedding_dim=512,
        s=25.0,
        m=0.30,
    )

    checkpoint = torch.load(MODEL_PATH, map_location=device)
    wrapped = isinstance(checkpoint, dict) and "model" in checkpoint
    net.load_state_dict(checkpoint["model"] if wrapped else checkpoint)

    # nn.Module.to() and .eval() both return the module itself.
    return net.to(device).eval()


model = load_model()
print("Model loaded.")

# PIL face crop -> float tensor (C, IMG_SIZE, IMG_SIZE), channel-normalized
# with the statistics from the config cell.
preprocess = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)


def prepare_image(pil_img: Image.Image) -> torch.Tensor:
    """
    Turn a PIL face crop into a model-ready batch of one.

    The image is forced to RGB, run through `preprocess`, given a leading
    batch dimension and moved to `device`.
    """
    tensor = preprocess(pil_img.convert("RGB"))
    return tensor[None].to(device)


In [None]:
# ============================================
# Cell 6 – Prediction Function
# ============================================
def predict(pil_img: Image.Image):
    """
    Classify the main face in `pil_img`.

    Returns:
        (pred_idx, face): the argmax class index and the crop fed to the
        model, or (None, None) when detect_and_crop_from_pil() gave no crop.
    """
    face = detect_and_crop_from_pil(pil_img)
    if face is None:
        return None, None

    batch = prepare_image(face)
    with torch.no_grad():
        # labels=None -> cosine logits without the ArcFace margin.
        logits, _embedding = model(batch)
        pred_idx = logits.argmax(dim=1).item()
    return pred_idx, face


In [None]:
# ============================================
# Cell 7 – Scan Test Dataset (subdir per class)
# ============================================
def scan_test_dataset(
    root_dir: str,
    extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp"),
):
    """
    Collect test images laid out as ``root_dir/<class_name>/<image>``.

    Only direct sub-directories of `root_dir` count as classes. Only regular
    files whose extension (case-insensitive) is in `extensions` are kept.
    Classes and files are visited in sorted order, so the returned lists -
    and therefore the prediction CSV - have a reproducible order.

    Returns:
        (image_paths, class_names): parallel lists; class_names[i] is the
        folder name of image_paths[i].
    """
    wanted_exts = tuple(ext.lower() for ext in extensions)
    image_paths = []
    class_names = []

    classes = sorted(
        c for c in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, c))
    )

    print("Found classes in Test:", len(classes))

    for cls in classes:
        cls_dir = os.path.join(root_dir, cls)
        # os.listdir order is filesystem-dependent; sort for determinism.
        for fname in sorted(os.listdir(cls_dir)):
            fpath = os.path.join(cls_dir, fname)
            if fname.lower().endswith(wanted_exts) and os.path.isfile(fpath):
                image_paths.append(fpath)
                class_names.append(cls)

    print("Total test images:", len(image_paths))
    return image_paths, class_names


# Collect (image path, class-folder name) pairs from TEST_DIR.
test_paths, test_folder_classes = scan_test_dataset(root_dir=TEST_DIR)

In [None]:
# ============================================
# Cell 8 – Inference + CSV + Metrics
# ============================================
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from tqdm import tqdm

results = []    # CSV rows: {"filename": ..., "label": ...}
all_true = []   # ground-truth class indices (from the folder name)
all_pred = []   # predicted class indices

skipped_io = 0              # files that could not be opened/decoded
skipped_noface = 0          # images for which no face crop was produced
skipped_unknown_class = 0   # predicted, but folder name not in label_map

for path, true_class in tqdm(
    list(zip(test_paths, test_folder_classes)),
    total=len(test_paths),
    desc="Predicting on Test",
):
    # 1) Load image. Image.open() is lazy (it only parses the header), so
    #    force a full decode here: otherwise a truncated/corrupt file raises
    #    later inside predict() and aborts the whole loop instead of being
    #    skipped.
    try:
        img = Image.open(path)
        img.load()
    except Exception as e:
        print("Error loading:", path, "|", e)
        skipped_io += 1
        continue

    # 2) Run prediction (crop pipeline + model)
    pred_idx, _ = predict(img)

    # No face crop -> excluded from both the CSV and the metrics
    if pred_idx is None:
        skipped_noface += 1
        continue

    pred_name = idx_to_name.get(pred_idx, "UNKNOWN_IDX")

    # 3) CSV row (filename, predicted class name)
    results.append({
        "filename": os.path.basename(path),
        "label": pred_name,
    })

    # 4) y_true / y_pred for the metrics. A class folder missing from the
    #    training label map cannot be scored; count it instead of crashing
    #    with a KeyError partway through the run.
    true_idx = label_map.get(true_class)
    if true_idx is None:
        skipped_unknown_class += 1
        continue
    all_true.append(true_idx)
    all_pred.append(pred_idx)

# ---- Save CSV ----
# Explicit columns so the header is written even when nothing was predicted.
df = pd.DataFrame(results, columns=["filename", "label"])
df.to_csv(OUTPUT_CSV, index=False)

print(
    f"Predicted: {len(results)} | load errors: {skipped_io} | "
    f"no face: {skipped_noface} | unknown class folder: {skipped_unknown_class}"
)


In [None]:
# ---- Compute metrics ----
print("Saved CSV to:", OUTPUT_CSV)

if not all_true:
    print("No scored predictions - skipping metrics.")
else:
    # Pin labels to the full class range: the confusion matrix is then always
    # (num_classes, num_classes), and classification_report does not raise a
    # target_names size mismatch when some classes never occur in
    # y_true / y_pred. zero_division=0 keeps the same 0.0 values without
    # flooding the output with warnings for those classes.
    class_ids = list(range(num_classes))
    acc = accuracy_score(all_true, all_pred)
    cm = confusion_matrix(all_true, all_pred, labels=class_ids)
    report = classification_report(
        all_true,
        all_pred,
        labels=class_ids,
        target_names=[idx_to_name[i] for i in class_ids],
        zero_division=0,
    )

    print("Test Accuracy :", acc)
    print("Confusion matrix shape:", cm.shape)
    print(report)