In [1]:
import glob
import os
import shutil
import torch
import json
import matplotlib.pyplot as plt
from facenet_pytorch import MTCNN
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
from collections import Counter
from pathlib import Path
from transformers import AutoModelForImageClassification, AutoImageProcessor
from torchvision import transforms as T
import torch.nn.functional as F
import cv2
import numpy as np

In [2]:
# --------------------------
# 1. MTCNN Face Alignment
# --------------------------
mtcnn = MTCNN(image_size=224, post_process=True)

def align_face(image):
    aligned = mtcnn(image)
    if aligned is None:
        return image
    return T.ToPILImage()(aligned)

In [4]:
# --------------------------
# 2. Load and Apply Temperature Scaling
# --------------------------

mtcnn = MTCNN(image_size=224, post_process=True)

def align_face(image):
    aligned = mtcnn(image)
    if aligned is None:
        return image
    return T.ToPILImage()(aligned)

# Dynamically locate the most recent V*-tagged model directory
MODEL_ROOT = "/Users/natalyagrokh/AI/ml_expressions/img_expressions"
model_dirs = sorted(
    [os.path.join(MODEL_ROOT, d) for d in os.listdir(MODEL_ROOT)
     if d.startswith("V") and os.path.isdir(os.path.join(MODEL_ROOT, d))],
    key=lambda x: os.path.getmtime(x),
    reverse=True
)

if not model_dirs:
    raise FileNotFoundError("‚ùå No model directories found under MODEL_ROOT.")

latest_model_dir = model_dirs[0]
print(f"üìÅ Using model directory: {latest_model_dir}")

# Dynamically find logits and labels files inside latest directory
logits_path = os.path.join(latest_model_dir, "logits_eval_V14.npy")
labels_path = os.path.join(latest_model_dir, "labels_eval_V14.npy")

try:
    with open(os.path.join(latest_model_dir, "temperature_V14.txt")) as f:
        TEMPERATURE = float(f.read().strip())
    print(f"üå°Ô∏è Loaded precomputed temperature: {TEMPERATURE:.4f}")
except Exception as e:
    TEMPERATURE = 1.5  # üëà Recommended fallback for now
    print(f"‚ö†Ô∏è Could not load temperature. Using fallback: {TEMPERATURE}")

üìÅ Using model directory: /Users/natalyagrokh/AI/ml_expressions/img_expressions/V14_20250531_160421
üå°Ô∏è Loaded precomputed temperature: 1.3240


In [5]:
# --------------------------
# 3. Configuration
# --------------------------

model_path = model_dirs[0]
print(f"‚úÖ Auto-loaded model from: {model_path}")

# Load model and processor
model = AutoModelForImageClassification.from_pretrained(model_path)
processor = AutoImageProcessor.from_pretrained(model_path)

BASE_CONFIDENCE_THRESHOLD = 0.35
ENTROPY_THRESHOLD = 1.9
MINORITY_CLASSES = {'disgust', 'fear', 'sadness'}
MINORITY_CLASS_THRESHOLD = 0.38

IMAGE_DIR = "/Volumes/JavaAOT/Documents/AI/ml_expressions/img_datasets/celeba_dataset"
SORTED_OUTPUT_DIR = "/Users/natalyagrokh/AI/ml_expressions/img_datasets/celeba_dataset_sorted"
REVIEW_DIR = "/Users/natalyagrokh/AI/ml_expressions/img_datasets/celeba_dataset_tosort"

Some weights of the model checkpoint at /Users/natalyagrokh/AI/ml_expressions/img_expressions/V14_20250531_160421 were not used when initializing ViTForImageClassification: ['classifier.1.bias', 'classifier.1.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at /Users/natalyagrokh/AI/ml_expressions/img_expressions/V14_20250531_160421 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to b

‚úÖ Auto-loaded model from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/V14_20250531_160421


In [6]:
# # --------------------------
# # 1.5. GPU Environment Setup
# # --------------------------
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Change this per parallel job (e.g., "1", "2", ...)
# print("Process restricted to GPUs:", os.environ["CUDA_VISIBLE_DEVICES"])

# # Optional: Monitor GPU usage
# gpu_usage = subprocess.check_output(["nvidia-smi"]).decode("utf-8")
# print("Current GPU usage:\n", gpu_usage)

In [7]:
# --------------------------
# 4. Device Setup
# --------------------------
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("‚úÖ Using Apple M-series GPU (MPS backend).")
else:
    device = torch.device("cpu")
    print("‚ö†Ô∏è MPS not available. Using CPU.")

model.to(device).eval()

# Get label mapping
id2label = model.config.id2label

‚úÖ Using Apple M-series GPU (MPS backend).


In [8]:
# --------------------------
# 5. Prediction Function
# --------------------------
# def predict_label_with_confidence(image_path, topk=3):
#     image = align_face(Image.open(image_path).convert("RGB"))
#     inputs = processor(image, return_tensors="pt").to(model.device)
#     with torch.no_grad():
#         logits = model(**inputs).logits
#         probs = F.softmax(logits / TEMPERATURE, dim=-1).squeeze()
#         entropy = -torch.sum(probs * torch.log(probs + 1e-8)).item()
#         top_probs, top_idxs = torch.topk(probs, topk)
#     return [(id2label[i.item()], top_probs[i].item()) for i in range(topk)], entropy
def predict_label_with_confidence(image_path, num_aug=3):
    image = align_face(Image.open(image_path).convert("RGB"))
    aug = T.Compose([T.RandomHorizontalFlip(), T.ColorJitter(0.2, 0.2, 0.2)])

    probs_all = []
    for _ in range(num_aug):
        aug_img = aug(image)
        inputs = processor(aug_img, return_tensors="pt").to(model.device)
        logits = model(**inputs).logits
        probs_all.append(F.softmax(logits / TEMPERATURE, dim=-1))

    probs = torch.mean(torch.stack(probs_all), dim=0).squeeze()
    entropy = -torch.sum(probs * torch.log(probs + 1e-8)).item()
    conf, pred_idx = torch.max(probs, dim=-1)

    return id2label[pred_idx.item()], conf.item(), entropy, probs

# function aggregates predictions over multiple augmentations for each input
def tta_inference(image, model, processor, device, N=5):
    import torch.nn.functional as F
    from torchvision import transforms as T

    tta_transforms = [
        T.Compose([]),  # original
        T.Compose([T.RandomHorizontalFlip(p=1.0)]),
        T.Compose([T.ColorJitter(0.3, 0.3, 0.3)]),
        T.Compose([T.RandomRotation(10)]),
        T.Compose([T.RandomResizedCrop(224, scale=(0.9, 1.0))]),
    ]

    probs_list = []
    for t in tta_transforms:
        img_aug = t(image)
        inputs = processor(img_aug, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = F.softmax(logits / TEMPERATURE, dim=-1).cpu()
        probs_list.append(probs)
    probs_stack = torch.stack(probs_list, dim=0)
    probs_mean = probs_stack.mean(dim=0).squeeze()
    conf, pred_idx = torch.max(probs_mean, dim=-1)
    entropy = -torch.sum(probs_mean * torch.log(probs_mean + 1e-8)).item()
    return conf.item(), pred_idx.item(), entropy, probs_mean

In [9]:
# --------------------------
# 6. Sorting Pipeline
# --------------------------
total_conf = 0.0
total_entropy = 0.0
num_samples = 0
log_lines = []

def sort_images():
    global total_conf, total_entropy, num_samples  # Ensure global scope for tracking

    os.makedirs(SORTED_OUTPUT_DIR, exist_ok=True)
    os.makedirs(REVIEW_DIR, exist_ok=True)

    valid_exts = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}
    image_paths = [p for p in Path(IMAGE_DIR).rglob("*") if p.suffix.lower() in valid_exts]

    review_manifest, counts = [], Counter()

    for img_path in tqdm(image_paths, desc="üîç Sorting images"):
        try:
            label, conf, entropy, probs = predict_label_with_confidence(img_path)
            top3 = [(id2label[i.item()], round(probs[i].item(), 3)) for i in torch.topk(probs, 3).indices]
            log_line = f"{str(img_path)} | pred={label} | conf={conf:.3f} | entropy={entropy:.3f} | top3={top3}\n"
            log_lines.append(log_line)

            num_samples += 1
            total_conf += conf
            total_entropy += entropy

            threshold = MINORITY_CLASS_THRESHOLD if label in MINORITY_CLASSES else BASE_CONFIDENCE_THRESHOLD

            if conf < 0.20 and entropy > 2.0:
                reason = "ood"
                label = "unknown"
                dest_dir = os.path.join(REVIEW_DIR, "unknown")
            elif conf < threshold or entropy > ENTROPY_THRESHOLD:
                reason = "thresholds"
                label = "unknown"
                dest_dir = os.path.join(REVIEW_DIR, "unknown")
            else:
                reason = "passed"
                dest_dir = os.path.join(SORTED_OUTPUT_DIR, label)

            if reason != "passed":
                review_manifest.append({
                    "file": str(img_path),
                    "top3": top3,
                    "confidence": round(conf, 4),
                    "entropy": round(entropy, 4),
                    "reason": reason
                })

            os.makedirs(dest_dir, exist_ok=True)
            shutil.copy2(img_path, os.path.join(dest_dir, os.path.basename(img_path)))
            counts[label] += 1

        except (UnidentifiedImageError, Exception) as e:
            print(f"‚ö†Ô∏è Error: {img_path} | {e}")

    with open(os.path.join(REVIEW_DIR, "review_manifest.json"), "w") as f:
        json.dump(review_manifest, f, indent=2)

    if num_samples > 0:
        print(f"\nüìä Mean confidence: {total_conf / num_samples:.4f}")
        print(f"üìä Mean entropy   : {total_entropy / num_samples:.4f}")
    else:
        print("\n‚ö†Ô∏è No samples processed.")

    print("\n‚úÖ Sorting complete.")
    print("üìä Image counts per class:")
    for label in sorted(counts):
        print(f"  {label:10s} : {counts[label]}")

    log_path = os.path.join(REVIEW_DIR, "sorting_log.txt")
    with open(log_path, "w") as f:
        f.writelines(log_lines)
    print(f"üìù Sorting log saved to: {log_path}")

In [10]:
# --------------------------
# 7. Run
# --------------------------
if __name__ == "__main__":
    sort_images()

üîç Sorting images:  71%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè   | 72223/101819 [3:48:00<1:50:36,  4.46it/s]

‚ö†Ô∏è Error: /Volumes/JavaAOT/Documents/AI/ml_expressions/img_datasets/celeba_dataset/172544.png | [Errno 28] No space left on device: '/Volumes/JavaAOT/Documents/AI/ml_expressions/img_datasets/celeba_dataset/172544.png' -> '/Users/natalyagrokh/AI/ml_expressions/img_datasets/celeba_dataset_tosort/unknown/172544.png'


üîç Sorting images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 101819/101819 [5:28:12<00:00,  5.17it/s]



üìä Mean confidence: 0.1534
üìä Mean entropy   : 2.1809

‚úÖ Sorting complete.
üìä Image counts per class:
  unknown    : 101818
üìù Sorting log saved to: /Users/natalyagrokh/AI/ml_expressions/img_datasets/celeba_dataset_tosort/sorting_log.txt


In [13]:
# --------------------------
# 8. Visualization of Rejected Images (Post-Sorting)
# --------------------------
def visualize_rejected_review_images(manifest_path=None, max_images=25):
    import matplotlib.pyplot as plt

    if manifest_path is None:
        manifest_path = os.path.join(REVIEW_DIR, "review_manifest.json")

    if not os.path.exists(manifest_path):
        print(f"‚ö†Ô∏è No review manifest found at: {manifest_path}")
        return

    with open(manifest_path) as f:
        manifest = json.load(f)

    print(f"üñºÔ∏è Visualizing {min(len(manifest), max_images)} of {len(manifest)} rejected samples")

    for idx, entry in enumerate(manifest[:max_images]):
        try:
            img = Image.open(entry["file"]).convert("RGB")
            plt.subplot(5, 5, idx + 1)
            plt.imshow(img)
            plt.axis("off")
            plt.title(f"{entry['top3'][0][0]}\n{entry['confidence']:.2f}, {entry['entropy']:.2f}")
        except Exception as e:
            print(f"‚ö†Ô∏è Could not open image {entry['file']}: {e}")
    plt.tight_layout()
    save_path = os.path.join(REVIEW_DIR, "rejected_grid.png")
    plt.savefig(save_path, dpi=150)
    print(f"üñºÔ∏è Saved rejected image grid to: {save_path}")
    plt.close()

In [14]:
# --------------------------
# 9. Auto-Pseudo-Labeling Export
# --------------------------
def export_pseudo_labeled_images():
    print("\nüìÅ Exporting high-confidence CelebA images as pseudo-labels")
    pseudo_root = os.path.join(SORTED_OUTPUT_DIR, "celeba_pseudo_labels")
    os.makedirs(pseudo_root, exist_ok=True)

    for label in os.listdir(SORTED_OUTPUT_DIR):
        label_dir = os.path.join(SORTED_OUTPUT_DIR, label)
        if not os.path.isdir(label_dir) or label == "celeba_pseudo_labels":
            continue
        dest_dir = os.path.join(pseudo_root, label)
        os.makedirs(dest_dir, exist_ok=True)
        for img_file in os.listdir(label_dir):
            if img_file.lower().endswith((".jpg", ".jpeg", ".png", ".tif")):
                src = os.path.join(label_dir, img_file)
                dst = os.path.join(dest_dir, img_file)
                try:
                    shutil.copy2(src, dst)
                except Exception as e:
                    print(f"‚ö†Ô∏è Skipping {src}: {e}")
    print("‚úÖ Pseudo-labeled CelebA export complete at:", pseudo_root)

In [15]:
# --------------------------
# 10. Post-run Hook
# --------------------------
if __name__ == "__main__":
    sort_images()
    visualize_rejected_review_images()
    export_pseudo_labeled_images()

üîç Sorting images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 101819/101819 [5:46:17<00:00,  4.90it/s]



üìä Mean confidence: 0.1534
üìä Mean entropy   : 2.1809

‚úÖ Sorting complete.
üìä Image counts per class:
  unknown    : 101819
üìù Sorting log saved to: /Users/natalyagrokh/AI/ml_expressions/img_datasets/celeba_dataset_tosort/sorting_log.txt
üñºÔ∏è Visualizing 25 of 101819 rejected samples
üñºÔ∏è Saved rejected image grid to: /Users/natalyagrokh/AI/ml_expressions/img_datasets/celeba_dataset_tosort/rejected_grid.png

üìÅ Exporting high-confidence CelebA images as pseudo-labels
‚úÖ Pseudo-labeled CelebA export complete at: /Users/natalyagrokh/AI/ml_expressions/img_datasets/celeba_dataset_sorted/celeba_pseudo_labels
