In [None]:
import os
import shutil
import torch
from PIL import Image
from pathlib import Path
from collections import Counter
import subprocess
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch.nn.functional as F
from tqdm import tqdm
import json

In [None]:
# --------------------------
# 1. Configuration
# --------------------------
# Root of the project tree. The default is the original author's external
# volume; set the ML_EXPRESSIONS_ROOT environment variable to run the notebook
# from another location without editing this cell.
_PROJECT_ROOT = os.environ.get(
    "ML_EXPRESSIONS_ROOT", "/Volumes/JavaAOT/Documents/AI/ml_expressions"
).rstrip("/")
_DATASETS_ROOT = f"{_PROJECT_ROOT}/img_datasets"

# Directory handed to `from_pretrained` for both the model and the image processor.
model_path = f"{_PROJECT_ROOT}/img_expressions/vit_final_independent_V6"

# Minimum top-class probability for a "disgust" prediction to be sorted
# automatically; anything below it is routed to the review folder.
CONFIDENCE_THRESHOLD = 0.85  # Adjusted for more leniency
# Label (lower-case) that the pipeline is sorting for.
DISGUST_LABEL = 'disgust'

# Source images (scanned recursively).
IMAGE_DIR = f"{_DATASETS_ROOT}/celeba_dataset"
# Destination for confident "disgust" predictions.
SORTED_OUTPUT_DIR = f"{_DATASETS_ROOT}/celeba_dataset_sorted"
# Destination for everything else, grouped by predicted label, plus the manifest.
REVIEW_DIR = f"{_DATASETS_ROOT}/celeba_dataset_tosort"

In [None]:
# --------------------------
# 2. Move Model to GPU and Set Evaluation Mode
# --------------------------
# Pick the fastest available backend: CUDA first, then Apple's Metal (MPS)
# backend, then CPU. The original check only probed CUDA, so on a Mac the
# model always ran on CPU. The getattr guard keeps this working on PyTorch
# builds that predate `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Load the classifier, move it to the chosen device and switch to eval mode
# (both `.to()` and `.eval()` return the module itself).
model = AutoModelForImageClassification.from_pretrained(model_path)
model = model.to(device).eval()

# The processor is loaded from the same directory as the model.
processor = AutoImageProcessor.from_pretrained(model_path)

# Mapping from class index to label name, as stored in the model config.
id2label = model.config.id2label

In [None]:
# --------------------------
# 3. Prediction Function
# --------------------------
def predict_label_with_confidence(image_path):
    """Classify a single image and return its top label with the probability.

    Args:
        image_path: Path (str or path-like) of an image file that PIL can open.

    Returns:
        tuple: ``(label, confidence)`` where ``label`` is the ``id2label``
        entry of the highest-probability class and ``confidence`` is that
        class's softmax probability as a Python float.

    Raises:
        Any exception raised by ``PIL.Image.open`` / ``convert`` for missing or
        undecodable files, or by the processor/model, is propagated unchanged.
    """
    # Use the context manager so the underlying file is closed as soon as the
    # pixels are decoded; multi-frame formats (e.g. TIFF) otherwise keep the
    # handle open until the image object is garbage-collected.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
        # One image was passed in, so take batch row 0 explicitly instead of
        # relying on squeeze() to drop the batch dimension.
        probs = F.softmax(logits, dim=-1)[0]
        conf, pred_idx = torch.max(probs, dim=-1)
    return id2label[pred_idx.item()], conf.item()

In [None]:
# --------------------------
# 4. Disgust-Only Sorting Pipeline
# --------------------------
# File extensions (lower-case, with dot) that are treated as images.
_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}


def _claim_destination(dest_dir, img_path, source_root, claimed):
    """Return a destination path in ``dest_dir`` not yet used during this run.

    The plain file name is kept whenever it is free. If an earlier image with
    the same base name was already sent to ``dest_dir``, the path relative to
    ``source_root`` is flattened into the file name (parts joined by ``__``),
    with a numeric suffix as a last resort. The chosen path is added to
    ``claimed`` before returning.
    """
    candidate = os.path.join(dest_dir, os.path.basename(img_path))
    if candidate in claimed:
        flattened = "__".join(Path(img_path).relative_to(source_root).parts)
        stem, ext = os.path.splitext(flattened)
        candidate = os.path.join(dest_dir, flattened)
        attempt = 1
        while candidate in claimed:
            candidate = os.path.join(dest_dir, f"{stem}_{attempt}{ext}")
            attempt += 1
    claimed.add(candidate)
    return candidate


def sort_disgust_images():
    """Copy every image under ``IMAGE_DIR`` into a sorted or a review folder.

    Each image is classified with ``predict_label_with_confidence``. Images
    predicted as ``DISGUST_LABEL`` with confidence of at least
    ``CONFIDENCE_THRESHOLD`` are copied to ``SORTED_OUTPUT_DIR/<DISGUST_LABEL>``.
    All other images are copied to ``REVIEW_DIR/<predicted label>`` and recorded
    in ``REVIEW_DIR/review_manifest.json`` (source path, predicted label,
    confidence). Source files are copied, never moved.

    Images that fail to process are reported on stdout and skipped; they do
    not appear in the manifest or in the printed per-label counts.

    Returns:
        None. Results are written to disk and a summary is printed.
    """
    os.makedirs(SORTED_OUTPUT_DIR, exist_ok=True)
    os.makedirs(REVIEW_DIR, exist_ok=True)

    source_root = Path(IMAGE_DIR)
    # Sorted so the processing order, and therefore any collision renaming
    # below, is the same on every run.
    image_paths = sorted(
        p for p in source_root.rglob("*")
        if p.suffix.lower() in _IMAGE_SUFFIXES and p.is_file()
    )

    review_manifest = []
    counts = Counter()
    claimed = set()
    disgust_dir = os.path.join(SORTED_OUTPUT_DIR, DISGUST_LABEL)

    for img_path in tqdm(image_paths, desc="Sorting images for Disgust"):
        try:
            label, confidence = predict_label_with_confidence(img_path)
            label = label.lower()  # Normalize case

            # Optional diagnostic
            # print(f"{img_path.name} -> {label} ({confidence:.2f})")

            accepted = label == DISGUST_LABEL and confidence >= CONFIDENCE_THRESHOLD
            dest_dir = disgust_dir if accepted else os.path.join(REVIEW_DIR, label)
            os.makedirs(dest_dir, exist_ok=True)

            # The scan is recursive, so two sub-folders may hold files with the
            # same name; claim a unique target instead of overwriting.
            dest_path = _claim_destination(dest_dir, img_path, source_root, claimed)
            shutil.copy2(img_path, dest_path)

            # Record the image only after the copy succeeded, so the manifest
            # and the counts never mention files that were not written.
            if not accepted:
                review_manifest.append({
                    "file": str(img_path),
                    "pred": label,
                    "confidence": confidence
                })
            counts[label] += 1

        except Exception as e:
            # Deliberate best-effort: one unreadable image must not abort the run.
            print(f"Error processing {img_path}: {e}")

    with open(os.path.join(REVIEW_DIR, "review_manifest.json"), "w", encoding="utf-8") as f:
        json.dump(review_manifest, f, indent=2)

    print("\nSorting complete.")
    print("Image counts:")
    for label, count in counts.items():
        print(f"{label:10s}: {count}")

In [None]:
# --------------------------
# 5. Run
# --------------------------
# Run the pipeline only when this code executes as the top-level module
# (a notebook cell or a script), not when it is imported from elsewhere.
if __name__ == "__main__":
    sort_disgust_images()