In [7]:
import os
import random
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms as T
from facenet_pytorch import MTCNN
from transformers import AutoModelForImageClassification, AutoImageProcessor

In [23]:
# --------------------------
# Config
# --------------------------
model_path = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/V13_20250527_161430"
celeba_sample_dir = "/Volumes/JavaAOT/Documents/AI/ml_expressions/img_datasets/wider_face_dataset_curated"
TEMPERATURE_PATH = os.path.join(model_path, "temperature_V13.txt")
NUM_IMAGES = 25

In [24]:
# --------------------------
# Load model, processor, temp
# --------------------------
model = AutoModelForImageClassification.from_pretrained(model_path)
processor = AutoImageProcessor.from_pretrained(model_path)
id2label = model.config.id2label
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device).eval()

with open(TEMPERATURE_PATH) as f:
    TEMPERATURE = float(f.read().strip())

Some weights of the model checkpoint at /Users/natalyagrokh/AI/ml_expressions/img_expressions/V13_20250527_161430 were not used when initializing ViTForImageClassification: ['classifier.1.bias', 'classifier.1.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at /Users/natalyagrokh/AI/ml_expressions/img_expressions/V13_20250527_161430 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to b

In [25]:
# --------------------------
# Face alignment
# --------------------------
mtcnn = MTCNN(image_size=224, post_process=True)

def align_face(image):
    aligned = mtcnn(image)
    if aligned is None:
        return image
    return T.ToPILImage()(aligned)

In [26]:
# --------------------------
# Inference with TTA
# --------------------------
def predict_with_metadata(image_path, num_aug=3):
    image = align_face(Image.open(image_path).convert("RGB"))
    aug = T.Compose([T.RandomHorizontalFlip(), T.ColorJitter(0.2, 0.2, 0.2)])
    probs_all = []

    for _ in range(num_aug):
        aug_img = aug(image)
        inputs = processor(aug_img, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = F.softmax(logits / TEMPERATURE, dim=-1)
            probs_all.append(probs)

    final_probs = torch.mean(torch.stack(probs_all), dim=0).squeeze()
    entropy = -torch.sum(final_probs * torch.log(final_probs + 1e-8)).item()
    conf, pred_idx = torch.max(final_probs, dim=-1)
    top3 = [(id2label[i.item()], round(final_probs[i].item(), 3)) for i in torch.topk(final_probs, 3).indices]
    return id2label[pred_idx.item()], conf.item(), entropy, top3

In [27]:
# --------------------------
# Run on random sample
# --------------------------
all_images = [os.path.join(dp, f) for dp, _, fn in os.walk(celeba_sample_dir) for f in fn if f.lower().endswith((".jpg", ".png", ".jpeg"))]
sample_images = random.sample(all_images, min(NUM_IMAGES, len(all_images)))

print(f"üß™ Inspecting {len(sample_images)} CelebA images with V13 model:\n")

for img_path in sample_images:
    try:
        label, conf, entropy, top3 = predict_with_metadata(img_path)
        print(f"{img_path}")
        print(f"‚Üí Pred: {label} | Conf: {conf:.3f} | Entropy: {entropy:.3f}")
        print(f"‚Üí Top3: {top3}")
        print("-" * 60)
    except Exception as e:
        print(f"‚ö†Ô∏è Error on {img_path}: {e}")

üß™ Inspecting 25 CelebA images with V13 model:

/Volumes/JavaAOT/Documents/AI/ml_expressions/img_datasets/wider_face_dataset_curated/28_Sports_Fan_Sports_Fan_28_875.jpg_face1.jpg
‚Üí Pred: surprise | Conf: 0.201 | Entropy: 2.048
‚Üí Top3: [('surprise', 0.201), ('questioning', 0.131), ('sadness', 0.13)]
------------------------------------------------------------
/Volumes/JavaAOT/Documents/AI/ml_expressions/img_datasets/wider_face_dataset_curated/51_Dresses_wearingdress_51_204.jpg_face1.jpg
‚Üí Pred: surprise | Conf: 0.196 | Entropy: 2.051
‚Üí Top3: [('surprise', 0.196), ('questioning', 0.147), ('sadness', 0.125)]
------------------------------------------------------------
/Volumes/JavaAOT/Documents/AI/ml_expressions/img_datasets/wider_face_dataset_curated/49_Greeting_peoplegreeting_49_381.jpg_face2.jpg
‚Üí Pred: surprise | Conf: 0.209 | Entropy: 2.038
‚Üí Top3: [('surprise', 0.209), ('sadness', 0.147), ('questioning', 0.126)]
---------------------------------------------------------