In [1]:
#V29 changes:
    # overview: selectively turn off label smoothing for contempt,disgust 
    # section #2 - updated CustomTrainer,TargetedSmoothedCrossEntropyLoss
    # section #3 - loading V28 checkpoint

In [2]:
# --------------------------
# 0. Imports
# --------------------------
# Standard Library Imports
import datasets
import csv
import gc
import glob
import multiprocessing as mp
import os
import random
import re
import shutil
import subprocess
import sys
import time

# Third-Party Imports
import accelerate
import dill
import face_recognition
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
#import tensorflow as tf
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import transformers

# From Imports
from collections import Counter
from datasets import ClassLabel, Dataset, Features, Image as DatasetsImage, concatenate_datasets, load_dataset
from datetime import datetime
from functools import partial
from imagehash import phash, hex_to_hash
from io import BytesIO
from pathlib import Path
from PIL import Image, ImageOps, ExifTags, UnidentifiedImageError
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix, log_loss, precision_recall_fscore_support
from torch import nn
from torch.nn import functional as F
from torch.optim import AdamW, LBFGS
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torchvision import transforms
from torchvision.transforms import (
    GaussianBlur,
    RandAugment,
    RandomAffine,
    RandomApply,
    RandomPerspective,
    RandomAdjustSharpness,
    ToPILImage,
    ToTensor
)
from tqdm import tqdm
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    EarlyStoppingCallback,
    TrainingArguments,
    Trainer,
)

In [3]:
# --------------------------
# 1. Global Configurations
# --------------------------
RUN_INFERENCE = True  # Toggle this off to disable running inference
IMAGE_DIR = "/Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalfaga_dataset"
BASE_PATH = IMAGE_DIR
MODEL_ROOT = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training"

LABEL_NAMES = [
    'anger', 'disgust', 'fear', 'happiness', 'neutral',
    'questioning', 'sadness', 'surprise', 'contempt', 'unknown'
]
id2label = dict(enumerate(LABEL_NAMES))
label2id = {v: k for k, v in id2label.items()}

HARD_CLASS_NAMES = ['contempt', 'disgust', 'fear', 'questioning']
hard_class_ids = [label2id[n] for n in HARD_CLASS_NAMES]

VALID_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff")

def is_valid_image(filename):
    return filename.lower().endswith(VALID_EXTENSIONS) and not filename.startswith("._")

label_mapping = {name.lower(): name for name in LABEL_NAMES}

# üî¢ Dynamically determine the next version
def get_next_version(base_dir):

    # Use glob to find all entries matching the pattern
    all_entries = glob.glob(os.path.join(base_dir, "V*_*"))
    
    # Filter to include only directories
    existing = [
        os.path.basename(d) for d in all_entries if os.path.isdir(d)
    ]

    # Extract version numbers from the directory names
    versions = [
        int(d[1:].split("_")[0]) for d in existing
        if d.startswith("V") and "_" in d and d[1:].split("_")[0].isdigit()
    ]
    
    # Determine the next version number
    next_version = max(versions, default=0) + 1
    return f"V{next_version}"

# Automatically create a versioned output folder
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
VERSION = get_next_version("/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training")
VERSION_TAG = VERSION + "_" + timestamp
SAVE_DIR = os.path.join("/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training", VERSION_TAG)
LOGITS_PATH = os.path.join(SAVE_DIR, f"logits_eval_{VERSION}.npy")
LABELS_PATH = os.path.join(SAVE_DIR, f"labels_eval_{VERSION}.npy")
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"üìÅ Output directory created: {SAVE_DIR}")

üìÅ Output directory created: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807


In [4]:
# --------------------------
# 2. Utility Functions (Metrics & Calibration)
# --------------------------

# ------------------------------------------
# Part A: Data Preparation & Augmentation
# ------------------------------------------

# üó∫Ô∏è Injects 'image_path' to dataset BEFORE any map/filter
def add_image_path(example):
    # Handle DatasetsImage and PIL.Image types robustly.
    img_obj = example["image"]
    path = getattr(img_obj, "filename", None)
    if path is None:
        # Fallback for rare cases where the path is not in the image object.
        if "file" in example:
            path = os.path.join(BASE_PATH, example["file"])
        else:
            path = ""
    example["image_path"] = path
    return example

# üè∑Ô∏è Standardizes labels from various sources (int, str, filepath) to a consistent integer ID.
def reconcile_labels(example):
    label = example.get("label", None)
    # Determine the original label string from different possible input formats.
    if isinstance(label, int):
        original_label = dataset.features["label"].int2str(label).strip().lower()
    elif isinstance(label, str):
        original_label = label.strip().lower()
    else:
        file_path = example["image_path"]
        original_label = os.path.basename(os.path.dirname(file_path)).lower() if file_path else None
    
    # Map the string label to its corresponding integer ID.
    pretrain_label = label_mapping.get(original_label)
    example["label"] = label2id[pretrain_label] if pretrain_label is not None else -1
    return example

# üîç Computes a perceptual hash (pHash) for an image to find visually similar duplicates.
def compute_hash(image_path):
    try:
        img = Image.open(image_path).convert("L").resize((64, 64))
        return str(phash(img))
    except Exception:
        return None

# üì¶ Applies augmentations and processes images on-the-fly for each batch.
# This is a more robust approach than pre-processing the entire dataset.
class DataCollatorWithAugmentation:
    def __init__(self, processor, augment_dict):
        self.processor = processor
        self.augment_dict = augment_dict

    def __call__(self, features):
        # Apply augmentations and process images
        processed_images = []
        for x in features:
            label = x["label"]
            # Select the correct augmentation pipeline
            aug_pipeline = self.augment_dict.get(label, data_augment)
            # Ensure image is in RGB format
            rgb_image = x["image"].convert("RGB")
            augmented_image = aug_pipeline(rgb_image)
            processed_images.append(augmented_image)

        # Create the 'pixel_values' batch.
        # Padding and truncation arguments are removed as they are not needed when
        # augmentations already resize the images to a uniform size.
        batch = self.processor(
            images=processed_images,
            return_tensors="pt"
        )
        
        # Add labels to the batch
        batch["labels"] = torch.tensor([x["label"] for x in features], dtype=torch.long)
        return batch

        
# ------------------------------------------
# Part B: Model & Training Components
# ------------------------------------------
      
# üèãÔ∏è Defines a custom Trainer that uses targeted loss function
class CustomLossTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize our new targeted loss function.
        self.loss_fct = TargetedSmoothedCrossEntropyLoss(smoothing=0.05)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss = self.loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss
        

# üîÑ Implements Cross-Entropy Loss with Targeted Label Smoothing for multiple classes
# Smoothing turned OFF for 'contempt' & 'disgust' to encourage confident predictions
class TargetedSmoothedCrossEntropyLoss(nn.Module):
    def __init__(self, smoothing=0.05):
        super().__init__()
        self.smoothing = smoothing
        # Define the list of classes for which smoothing will be turned OFF.
        self.target_class_ids = [label2id['contempt'], label2id['disgust']]

    def forward(self, logits, target):
        num_classes = logits.size(1)
        with torch.no_grad():
            # 1. Start with standard smoothed labels for all samples.
            smooth_labels = torch.full_like(logits, self.smoothing / (num_classes - 1))
            smooth_labels.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
            
            # 2. Create a mask to find all samples belonging to our target classes.
            # torch.isin is an efficient way to check for multiple IDs at once.
            target_mask = torch.isin(target, torch.tensor(self.target_class_ids, device=target.device))
            
            # 3. For the targeted samples, overwrite with a "sharp" one-hot encoding.
            if target_mask.any():
                sharp_labels = F.one_hot(target[target_mask], num_classes=num_classes).float()
                smooth_labels[target_mask] = sharp_labels
        
        log_probs = F.log_softmax(logits, dim=1)
        loss = -(smooth_labels * log_probs).sum(dim=1).mean()
        return loss

# ‚ö†Ô∏è Confidence Penalty to Reduce Overconfidence
def confidence_penalty(logits, beta=0.05):
    probs = F.softmax(logits, dim=1)
    log_probs = F.log_softmax(logits, dim=1)
    # Entropy is a measure of uncertainty; penalizing low entropy encourages less confident predictions.
    entropy = -torch.sum(probs * log_probs, dim=1)
    return beta * entropy.mean()

    
# üìä Compute Metrics with Confusion Matrix Logging
def compute_metrics_with_confusion(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    # Print classification report
    print("\nClassification Report:")
    report = classification_report(labels, preds, target_names=LABEL_NAMES, output_dict=True)
    print(classification_report(labels, preds, target_names=LABEL_NAMES))

    # Save raw logits/labels for calibration or further analysis
    np.save(os.path.join(SAVE_DIR, f"logits_eval_{VERSION}.npy"), logits)
    np.save(os.path.join(SAVE_DIR, f"labels_eval_{VERSION}.npy"), labels)

    # Save per-class F1/precision/recall/entropy to CSV (append per epoch)
    f1s = [report[name]["f1-score"] for name in LABEL_NAMES]
    recalls = [report[name]["recall"] for name in LABEL_NAMES]
    precisions = [report[name]["precision"] for name in LABEL_NAMES]

    # Entropy per class (sorted by entropy)
    softmax_probs = F.softmax(torch.tensor(logits), dim=-1)
    entropies = -torch.sum(softmax_probs * torch.log(softmax_probs + 1e-12), dim=-1)
    entropy_per_class = []
    for idx, class_name in enumerate(LABEL_NAMES):
        mask = (np.array(labels) == idx)
        if mask.any():
            class_entropy = entropies[mask].mean().item()
            entropy_per_class.append((class_name, class_entropy))
        else:
            entropy_per_class.append((class_name, 0.0))
    # Sort for display only; CSV row stays in canonical label order
    sorted_entropy = sorted(entropy_per_class, key=lambda x: x[1], reverse=True)

    # CSV logging
    epoch_metrics_path = os.path.join(SAVE_DIR, "per_class_metrics.csv")
    epoch = getattr(trainer.state, "epoch", None) if "trainer" in globals() else None
    df_row = pd.DataFrame({
        "epoch": [epoch],
        **{f"f1_{n}": [f] for n, f in zip(LABEL_NAMES, f1s)},
        **{f"recall_{n}": [r] for n, r in zip(LABEL_NAMES, recalls)},
        **{f"precision_{n}": [p] for n, p in zip(LABEL_NAMES, precisions)},
        **{f"entropy_{n}": [e] for n, e in entropy_per_class}
    })
    if os.path.exists(epoch_metrics_path):
        df_row.to_csv(epoch_metrics_path, mode="a", header=False, index=False)
    else:
        df_row.to_csv(epoch_metrics_path, mode="w", header=True, index=False)

    # Generate and print confusion matrix heatmap
    cm = confusion_matrix(labels, preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm, annot=True, fmt="d", cmap="Blues",
        xticklabels=LABEL_NAMES,
        yticklabels=LABEL_NAMES
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, f"confusion_matrix_epoch_{VERSION}.png"))
    plt.close()

    # Top confused pairs
    confusion_pairs = [
        ((LABEL_NAMES[i], LABEL_NAMES[j]), cm[i][j])
        for i in range(len(LABEL_NAMES))
        for j in range(len(LABEL_NAMES)) if i != j
    ]
    top_confusions = sorted(confusion_pairs, key=lambda x: x[1], reverse=True)[:3]
    print("\nTop 3 confused class pairs:")
    for (true_label, pred_label), count in top_confusions:
        print(f"  - {true_label} ‚Üí {pred_label}: {count} instances")

    # Compute average prediction entropy
    avg_entropy = entropies.mean().item()
    print(f"\nüß† Avg prediction entropy: {avg_entropy:.4f}")

    print("\nüîç Class entropies (sorted):")
    for class_name, entropy in sorted_entropy:
        print(f"  - {class_name}: entropy = {entropy:.4f}")

    accuracy = (preds == labels).mean()
    return {"accuracy": accuracy}


# üíæ Saves the model, processor, and trainer state 
def save_model_and_processor(model, processor, save_dir, trainer=None):
    print(f"Saving model and processor to: {save_dir}")
    
    model = model.to("cpu")

    # Save processor
    processor.save_pretrained(save_dir)
    print(f"‚úÖ Processor saved to: {SAVE_DIR}")
    
    # Save full model
    model.save_pretrained(SAVE_DIR, safe_serialization=True)
    print(f"‚úÖ Full model saved to: {SAVE_DIR}")

    # Save state dict
    final_model_path = os.path.join(SAVE_DIR, 'final_model.pth')
    torch.save(model.state_dict(), final_model_path)
    print(f"‚úÖ State dict saved to: {final_model_path}")

    # Save trainer state
    if trainer is not None:
        try:
            trainer.save_model(os.path.join(save_dir, "backup_trainer_model"))
            print("‚úÖ Trainer backup saved.")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to save trainer backup: {e}")

    # Memory cleanup
    del model
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass  # Not all systems have CUDA
    print("‚úÖ Memory cleanup complete after save.")


# ------------------------------------------
# Part C: General & Debugging Utilities
# ------------------------------------------

def analyze_dataset_structure(dataset_to_analyze, id2label, base_path):
    # Print label schema from the dataset being analyzed.
    print("Label schema (from dataset):", dataset_to_analyze.features["label"])

    # Label distribution from the dataset object.
    label_counts = Counter(dataset_to_analyze["label"])
    print("\nüìä Full dataset label distribution (from Dataset object):")
    for label_id, count in sorted(label_counts.items()):
        print(f"  {id2label[label_id]}: {count} examples")

    # Dynamically detect and print minority classes for informational purposes.
    N = 3
    minority_classes_ids = set(
        label for label, _ in sorted(label_counts.items(), key=lambda x: x[1])[:N]
    )
    minority_names = [id2label[i] for i in minority_classes_ids]
    print(f"\n‚ö†Ô∏è  Dynamically identified minority classes (for info): {minority_names}")

    # Count images per directory from the source folder for reference.
    folder_image_counts = {}
    print("\nüìÇ Image count per source folder:")
    for label in sorted(os.listdir(base_path)):
        label_path = os.path.join(base_path, label)
        if os.path.isdir(label_path):
            valid_images = [img for img in os.listdir(label_path) if is_valid_image(img)]
            folder_image_counts[label] = len(valid_images)
            print(f"  {label}: {len(valid_images)} images")

    # The function now only returns the dictionary that is used later.
    return folder_image_counts

# üìÇ Finds the most recent file (e.g., 'audit.csv') from the latest V* run directory.
def find_latest_run_artifact(root_dir, filename):
    # Find all previous version directories.
    all_run_dirs = [
        d for d in os.listdir(root_dir)
        if d.startswith("V") and os.path.isdir(os.path.join(root_dir, d))
    ]
    if not all_run_dirs:
        return None

    # Sort by version number to reliably find the latest run.
    try:
        latest_run_dir_name = sorted(all_run_dirs, key=extract_version_from_path, reverse=True)[0]
        artifact_path = os.path.join(root_dir, latest_run_dir_name, filename)

        if os.path.exists(artifact_path):
            print(f"‚úÖ Found artifact '{filename}' in latest run: {latest_run_dir_name}")
            return artifact_path
        else:
            print(f"‚ö†Ô∏è Artifact '{filename}' not found in {latest_run_dir_name}.")
            return None
    except Exception as e:
        print(f"‚ö†Ô∏è Could not process previous run directories to find '{filename}'. Error: {e}")
        return None

# üîë Creates a unique 'label/filename.jpg' key from a full image path for reliable matching.
def get_relative_path_key(path):
    try:
        return os.path.join(os.path.basename(os.path.dirname(path)), os.path.basename(path))
    except Exception:
        return ""

# üî¢ Extracts version number from a directory name
def extract_version_from_path(path):
    import re
    match = re.search(r"V(\d+)", os.path.basename(path))
    return int(match.group(1)) if match else -1

# üö¶ Prints the label distribution for a given dataset, 
    #useful for debugging and sanity checks.
def check_label_integrity(dataset, LABEL_NAMES, label2id):
    # Count all mapped labels in the dataset.
    label_counts = Counter(dataset['label'])
    print("\nüö® Label distribution after mapping (before split):")
    for label_id in range(len(LABEL_NAMES)):
        label_name = LABEL_NAMES[label_id]
        print(f"  {label_name:12}: {label_counts.get(label_id, 0)}")

    # Specifically highlight 'surprise'
    surprise_id = label2id['surprise']
    if label_counts.get(surprise_id, 0) == 0:
        print("‚ùóWARNING: No 'surprise' images found after mapping!")
    elif label_counts[surprise_id] < 50:  # arbitrary threshold
        print(f"‚ö†Ô∏è Only {label_counts[surprise_id]} 'surprise' images found! Check curation or mapping.")

# üö¶ Prints the label distribution for a dictionary of datasets (e.g., train, eval).
def check_all_label_integrity(datasets_dict, LABEL_NAMES, label2id):
    for name, dataset in datasets_dict.items():
        print(f"\nüö® Label distribution for: {name}")
        label_counts = Counter(dataset['label'])
        
        for label_id in range(len(LABEL_NAMES)):
            label_name = LABEL_NAMES[label_id]
            print(f"  {label_name:12}: {label_counts.get(label_id, 0)}")
        
        surprise_id = label2id['surprise']
        if label_counts.get(surprise_id, 0) == 0:
            print("‚ùóWARNING: No 'surprise' images found in this split!")
        elif label_counts[surprise_id] < 50:
            print(f"‚ö†Ô∏è Only {label_counts[surprise_id]} 'surprise' images in {name}! Check curation or mapping.")

In [5]:
# --------------------------
# 3. Auto-Load V20 Golden Checkpoint
# --------------------------
# Manually set the path to your best-performing model
model_path = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V28_20250709_153248"
print(f"‚úÖ Explicitly loading V28checkpoint from: {model_path}")

# Load model and processor
model = AutoModelForImageClassification.from_pretrained(model_path)
processor = AutoImageProcessor.from_pretrained(model_path)

# Reset the classifier head for new training
model.classifier = nn.Linear(model.config.hidden_size, len(LABEL_NAMES))
model.config.id2label = id2label
model.config.label2id = label2id
model.config.num_labels = len(LABEL_NAMES)
print("‚úÖ Classifier head reset for new training.")

# Define device and push model to device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("\nüñ•Ô∏è Using device:", device)
model.to(device).eval()

‚úÖ Explicitly loading V28checkpoint from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V28_20250709_153248
‚úÖ Classifier head reset for new training.

üñ•Ô∏è Using device: mps


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [6]:
# ==============================
# 4. Load and Prepare Dataset (with filename preservation)
# ==============================

# --- Dynamic File Count ---
print("üîç Counting valid image files on disk for verification...")
# This will recursively find all valid image files in your dataset directory
expected_file_count = len(
    [p for p in Path(BASE_PATH).rglob("*") if is_valid_image(p.name)]
)
print(f"‚úÖ Found {expected_file_count} image files in {BASE_PATH}")

# Disable caching BEFORE loading
datasets.disable_caching()
print("‚úÖ Datasets caching disabled for this run to ensure fresh data load.")

# Step 1: Load dataset and capture filepaths
dataset = load_dataset(
    "imagefolder",
    data_dir=BASE_PATH,
    split="train" # No need to specify cache_dir when caching is off
)

# Only run ONCE and only here, so "image_path" is never dropped later!
dataset = dataset.map(add_image_path, desc="Add file path to each record")
dataset = dataset.map(reconcile_labels, desc="Re-labeling dataset (preserving image_path)")
dataset = dataset.filter(lambda x: x["label"] != -1)

# ** Robust Verification **
final_count = len(dataset)
print(f"‚úÖ Total examples after filtering: {final_count}")
print("Sample with path:", dataset[0]["image_path"])

# Assertion checks whether loaded count is very close to the disk count
# Small tolerance accounts for any files that fail to load or be filtered
assert abs(final_count - expected_file_count) < 10, \
    f"Dataset size mismatch! Found {expected_file_count} files but loaded {final_count}."

assert dataset[0].get("image_path", None), "image_path missing from first record"

üîç Counting valid image files on disk for verification...
‚úÖ Found 17452 image files in /Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalfaga_dataset
‚úÖ Datasets caching disabled for this run to ensure fresh data load.


Resolving data files:   0%|          | 0/17461 [00:00<?, ?it/s]

Add file path to each record:   0%|          | 0/17451 [00:00<?, ? examples/s]

Re-labeling dataset (preserving image_path):   0%|          | 0/17451 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17451 [00:00<?, ? examples/s]

‚úÖ Total examples after filtering: 17451
Sample with path: /Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalfaga_dataset/anger/Abel_Pacheco_0002.jpg


In [7]:
# ==============================================================================
# 5. V25 Data Curation: Remove Hard Negatives
# ==============================================================================

# 1. Define the path to the hard negatives file generated by the V24 run.
V24_RUN_DIR = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V24_20250704_081028"
hard_negatives_path = os.path.join(V24_RUN_DIR, 'review_hardneg_contempt_questioning.txt')

if not os.path.exists(hard_negatives_path):
    raise FileNotFoundError(f"CRITICAL: The hard negatives file is missing: {hard_negatives_path}")

# 2. Load the list of image paths to be excluded into a set for fast lookup.
with open(hard_negatives_path, 'r') as f:
    # os.path.normpath ensures path formats are consistent (e.g., handles slashes).
    exclusion_list = {os.path.normpath(line.strip()) for line in f if line.strip()}
print(f"‚úÖ Loaded {len(exclusion_list)} hard-negative images to be excluded from training.")

# 3. Filter the main dataset.
#    The 'dataset' variable holds the data loaded in Section #4.
initial_count = len(dataset)
curated_dataset = dataset.filter(
    lambda example: os.path.normpath(example['image_path']) not in exclusion_list,
    desc="Filtering out hard negatives"
)
removed_count = initial_count - len(curated_dataset)
print(f"‚úÖ Curated dataset created for V25. Removed {removed_count} images.")
print(f"   - Original size: {initial_count} -> New size: {len(curated_dataset)}")

# All subsequent sections must now use 'curated_dataset'.

‚úÖ Loaded 15 hard-negative images to be excluded from training.


Filtering out hard negatives:   0%|          | 0/17451 [00:00<?, ? examples/s]

‚úÖ Curated dataset created for V25. Removed 4 images.
   - Original size: 17451 -> New size: 17447


In [8]:
# --------------------------
# 6. Dataset Label Overview and Folder Stats
# --------------------------

# Call the function and assign its single, useful return value.
folder_image_counts = analyze_dataset_structure(curated_dataset, id2label, BASE_PATH)

Label schema (from dataset): ClassLabel(names=['anger', 'contempt', 'disgust', 'fear', 'happiness', 'neutral', 'questioning', 'sadness', 'surprise', 'unknown'], id=None)

üìä Full dataset label distribution (from Dataset object):
  anger: 2302 examples
  disgust: 309 examples
  fear: 1432 examples
  happiness: 2892 examples
  neutral: 3333 examples
  questioning: 1895 examples
  sadness: 1706 examples
  surprise: 2783 examples
  contempt: 409 examples
  unknown: 386 examples

‚ö†Ô∏è  Dynamically identified minority classes (for info): ['contempt', 'disgust', 'unknown']

üìÇ Image count per source folder:
  anger: 2302 images
  contempt: 412 images
  disgust: 309 images
  fear: 1432 images
  happiness: 2892 images
  neutral: 3333 images
  questioning: 1896 images
  sadness: 1706 images
  surprise: 2783 images
  unknown: 386 images


In [9]:
# --------------------------
# 7. Perceptual Clustering for Ambiguous/Confused Classes
# --------------------------

CLUSTER_TARGETS = ["disgust", "sadness", "fear", "questioning", "contempt"]

for class_name in CLUSTER_TARGETS:
    class_dir = os.path.join(BASE_PATH, class_name)
    if not os.path.isdir(class_dir):
        print(f"‚ö†Ô∏è Class dir not found: {class_dir} (skipping)")
        continue

    class_images = [
        os.path.join(class_dir, f) for f in os.listdir(class_dir)
        if is_valid_image(f)
    ]
    hash_map = {}
    for path in class_images:
        h = compute_hash(path)
        if h:
            hash_map.setdefault(h, []).append(path)

    cluster_dir = os.path.join(SAVE_DIR, f"{class_name}_clusters")
    os.makedirs(cluster_dir, exist_ok=True)

    print(f"üîç {class_name.capitalize()} hash clusters with more than 1 image:")
    for h, paths in hash_map.items():
        if len(paths) > 1:
            cluster_path = os.path.join(cluster_dir, h)
            os.makedirs(cluster_path, exist_ok=True)
            for p in paths:
                shutil.copy(p, cluster_path)
            print(f"  - Cluster {h[:8]}: {len(paths)} images copied for review")

üîç Disgust hash clusters with more than 1 image:
üîç Sadness hash clusters with more than 1 image:
  - Cluster 958c52e1: 2 images copied for review
  - Cluster ee9a8d33: 2 images copied for review
  - Cluster d0890396: 2 images copied for review
  - Cluster bb0d06f2: 2 images copied for review
  - Cluster d7f00fa2: 2 images copied for review
üîç Fear hash clusters with more than 1 image:
  - Cluster 9ae56592: 2 images copied for review
  - Cluster 91c8ee81: 2 images copied for review
  - Cluster dae5a596: 2 images copied for review
üîç Questioning hash clusters with more than 1 image:
  - Cluster da014886: 2 images copied for review
  - Cluster 9db42783: 2 images copied for review
üîç Contempt hash clusters with more than 1 image:


In [10]:
# --------------------------
# 8. Class Frequency-Aware Augmentation Targeting
# --------------------------

# Compute label frequencies from train split (post filtering)
label_freqs = Counter(curated_dataset["label"])
label_id2name = {v: k for k, v in label2id.items()}
label_name2id = {v: k for k, v in label_id2name.items()}

# Get lowest-count classes dynamically
minority_by_count = sorted(label_freqs, key=label_freqs.get)[:3]
minority_by_name = [label_id2name[i] for i in minority_by_count]
minority_by_name = [n for n in minority_by_name if n != "unknown"]

# Manually include known confused or underperforming classes
manual_focus_classes = ['disgust', 'questioning', 'contempt']

# Merge and deduplicate
minority_class_names = list(set(minority_by_name + manual_focus_classes))

# Final list as label indices
minority_classes = [label_name2id[name] for name in minority_class_names]

print(f"üéØ Targeted minority augmentation will apply to: {minority_class_names}")

üéØ Targeted minority augmentation will apply to: ['contempt', 'questioning', 'disgust']


In [11]:
# --------------------------
# 9. Define Data Augmentation and Preprocessing Transformation
# --------------------------

# V25 Strategy: We define the augmentation pipelines here, but they will be
# applied on-the-fly by the DataCollator during training. This is a more
# robust method that avoids issues with the data balancing steps.

# Baseline augmentation for majority classes.
data_augment = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomRotation(10),
    T.ColorJitter(brightness=0.1, contrast=0.1)
])

# Stronger RandAugment for targeted minority classes.
minority_aug = T.Compose([
    RandAugment(num_ops=2, magnitude=9),
    T.RandomResizedCrop(224, scale=(0.7, 1.0)),
    T.ColorJitter(0.3, 0.3, 0.3, 0.1),
])

print("‚úÖ Augmentation pipelines defined. They will be applied by the data collator during training.")

‚úÖ Augmentation pipelines defined. They will be applied by the data collator during training.


In [12]:
# --------------------------
# 10. Balance Dataset (with NO oversampling for 'unknown')
# --------------------------
MINORITY_CAP = 2250
balanced_subsets = []
label_counts = Counter(curated_dataset["label"])
print("Original label distribution:", label_counts)

for label, count in label_counts.items():
    subset = curated_dataset.filter(lambda x: x['label'] == label, num_proc=1)
    class_name = LABEL_NAMES[label]
    if class_name == "unknown":
        balanced_subsets.append(subset)
    elif count < MINORITY_CAP:
        multiplier = MINORITY_CAP // len(subset)
        remainder = MINORITY_CAP % len(subset)
        subset = concatenate_datasets([subset] * multiplier + [subset.select(range(remainder))])
        balanced_subsets.append(subset)
    else:
        # Append full set (no downsampling for majority classes)
        balanced_subsets.append(subset)

train_dataset = concatenate_datasets(balanced_subsets).shuffle(seed=42)
print("After balancing:", Counter(train_dataset['label']))

hard_classes = ['contempt', 'disgust', 'questioning', 'surprise', 'fear']
hard_class_ids = [label2id[c] for c in hard_classes]

# Calculate weights: Give hard classes 2x, others 1x
weights = [2.0 if l in hard_class_ids else 1.0 for l in train_dataset["label"]]
weights = torch.DoubleTensor(weights)
sampler = torch.utils.data.WeightedRandomSampler(
    weights=weights,
    num_samples=len(weights),
    replacement=True
)

# üö¶ Check and print label distributions across all important splits
check_all_label_integrity(
    {
        "full curated dataset": curated_dataset,
        "train set (post-balance)": train_dataset,
    },
    LABEL_NAMES, label2id
)

# cleaning memory
del balanced_subsets
del subset
gc.collect()

Original label distribution: Counter({4: 3333, 3: 2892, 7: 2783, 0: 2302, 5: 1895, 6: 1706, 2: 1432, 8: 409, 9: 386, 1: 309})


Filter:   0%|          | 0/17447 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17447 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17447 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17447 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17447 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17447 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17447 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17447 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17447 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17447 [00:00<?, ? examples/s]

After balancing: Counter({4: 3333, 3: 2892, 7: 2783, 0: 2302, 6: 2250, 1: 2250, 5: 2250, 2: 2250, 8: 2250, 9: 386})

üö® Label distribution for: full curated dataset
  anger       : 2302
  disgust     : 309
  fear        : 1432
  happiness   : 2892
  neutral     : 3333
  questioning : 1895
  sadness     : 1706
  surprise    : 2783
  contempt    : 409
  unknown     : 386

üö® Label distribution for: train set (post-balance)
  anger       : 2302
  disgust     : 2250
  fear        : 2250
  happiness   : 2892
  neutral     : 3333
  questioning : 2250
  sadness     : 2250
  surprise    : 2783
  contempt    : 2250
  unknown     : 386


0

In [13]:
# --------------------------
# 11. Optimizer, Scheduler, and Training
# --------------------------

# --- Part A: Train/Validation Split on CURATED Data ---
split_dataset = curated_dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
print(f"‚úÖ Curated data split into {len(train_dataset)} training and {len(eval_dataset)} validation samples.")


# --- Part B: Define Training Arguments ---
# Reusing arguments from your script, ensuring best model is loaded.
training_args = TrainingArguments(
    output_dir=SAVE_DIR,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    learning_rate=4e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    logging_dir=os.path.join(SAVE_DIR, "logs"),
    logging_strategy="epoch",
    remove_unused_columns=False  # <-- Add this line
)

# This part is also correct.
early_stop_callback = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.001
)

# --- Part C: Instantiate the Data Collator ---
# This uses augmentation pipelines defined in Section #8.
minority_augment_map = {label_id: minority_aug for label_id in minority_classes}
data_collator = DataCollatorWithAugmentation(
    processor=processor,
    augment_dict=minority_augment_map
)

# --- Part D: Discriminative Learning Rate Optimizer Setup ---
# Define different learning rates for the head and the backbone
head_lr = 5e-5      # High learning rate for the classifier head
backbone_lr = 2e-7  # Very low learning rate for the fine-tuned backbone layers

# First, ensure all layers are frozen by default
for param in model.parameters():
    param.requires_grad = False
# Unfreeze the classifier head to be trained    
for param in model.classifier.parameters():
    param.requires_grad = True
# This gives the model more capacity to adapt its feature extraction.
for name, param in model.vit.encoder.layer[-4:].named_parameters():
    param.requires_grad = True

# Create parameter groups for the optimizer
optimizer_grouped_parameters = [
    {'params': model.classifier.parameters(), 'lr': head_lr},
    {'params': model.vit.encoder.layer[-4:].parameters(), 'lr': backbone_lr}
]

# Create the AdamW optimizer with the specified parameter groups.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, weight_decay=0.01)

# --- Part E: Trainer Initialization and Execution ---
# Initialize your custom CustomLossTrainer with all components.
trainer = CustomLossTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics_with_confusion,
    optimizers=(optimizer, None),
    data_collator=data_collator,
    callbacks=[early_stop_callback]
)

# --- Part F: Train the Model and Finalize ---
print(f"\n--- Starting {VERSION} Training with on-the-fly processing ---")
trainer.train()
print("--- Training Finished ---")

# Save the final, best-performing model from the run.
print("\n--- Saving Final Model ---")
save_model_and_processor(trainer.model, processor, SAVE_DIR)
print(f"--- Model saved to {SAVE_DIR} ---")

‚úÖ Curated data split into 13957 training and 3490 validation samples.

--- Starting V29 Training with on-the-fly processing ---




Epoch,Training Loss,Validation Loss,Accuracy
1,0.5353,0.37402,0.975358
2,0.3703,0.377649,0.974212
3,0.3769,0.376376,0.976504
4,0.3758,0.358693,0.979943
5,0.3656,0.363032,0.978223



Classification Report:
              precision    recall  f1-score   support

       anger       0.98      0.99      0.98       471
     disgust       0.99      0.93      0.96        71
        fear       0.95      0.94      0.94       289
   happiness       0.99      0.99      0.99       542
     neutral       0.99      0.98      0.98       697
 questioning       0.95      0.98      0.96       367
     sadness       0.98      0.98      0.98       350
    surprise       0.98      0.99      0.98       542
    contempt       0.84      0.83      0.84        84
     unknown       1.00      1.00      1.00        77

    accuracy                           0.98      3490
   macro avg       0.96      0.96      0.96      3490
weighted avg       0.98      0.98      0.98      3490


Top 3 confused class pairs:
  - contempt ‚Üí questioning: 14 instances
  - fear ‚Üí surprise: 7 instances
  - disgust ‚Üí contempt: 5 instances

üß† Avg prediction entropy: 0.3440

üîç Class entropies (sorted):
  -




Classification Report:
              precision    recall  f1-score   support

       anger       0.97      0.98      0.98       471
     disgust       0.94      0.92      0.93        71
        fear       0.95      0.93      0.94       289
   happiness       0.99      1.00      0.99       542
     neutral       1.00      0.98      0.99       697
 questioning       0.96      0.96      0.96       367
     sadness       0.98      0.97      0.98       350
    surprise       0.97      0.99      0.98       542
    contempt       0.82      0.81      0.81        84
     unknown       1.00      1.00      1.00        77

    accuracy                           0.97      3490
   macro avg       0.96      0.95      0.96      3490
weighted avg       0.97      0.97      0.97      3490


Top 3 confused class pairs:
  - contempt ‚Üí questioning: 11 instances
  - fear ‚Üí surprise: 9 instances
  - questioning ‚Üí contempt: 7 instances

üß† Avg prediction entropy: 0.3333

üîç Class entropies (sorted):




Classification Report:
              precision    recall  f1-score   support

       anger       0.98      0.99      0.98       471
     disgust       0.93      0.89      0.91        71
        fear       0.97      0.94      0.95       289
   happiness       0.99      1.00      0.99       542
     neutral       0.99      0.99      0.99       697
 questioning       0.95      0.96      0.96       367
     sadness       0.99      0.99      0.99       350
    surprise       0.98      0.99      0.99       542
    contempt       0.79      0.77      0.78        84
     unknown       1.00      1.00      1.00        77

    accuracy                           0.98      3490
   macro avg       0.96      0.95      0.95      3490
weighted avg       0.98      0.98      0.98      3490


Top 3 confused class pairs:
  - contempt ‚Üí questioning: 15 instances
  - disgust ‚Üí contempt: 7 instances
  - fear ‚Üí surprise: 7 instances

üß† Avg prediction entropy: 0.3319

üîç Class entropies (sorted):
  -




Classification Report:
              precision    recall  f1-score   support

       anger       0.98      0.99      0.99       471
     disgust       0.99      0.93      0.96        71
        fear       0.96      0.94      0.95       289
   happiness       0.99      1.00      0.99       542
     neutral       0.99      0.99      0.99       697
 questioning       0.96      0.98      0.97       367
     sadness       0.98      0.98      0.98       350
    surprise       0.98      0.99      0.98       542
    contempt       0.88      0.88      0.88        84
     unknown       1.00      1.00      1.00        77

    accuracy                           0.98      3490
   macro avg       0.97      0.97      0.97      3490
weighted avg       0.98      0.98      0.98      3490


Top 3 confused class pairs:
  - contempt ‚Üí questioning: 9 instances
  - fear ‚Üí surprise: 8 instances
  - surprise ‚Üí fear: 5 instances

üß† Avg prediction entropy: 0.3340

üîç Class entropies (sorted):
  - fea




Classification Report:
              precision    recall  f1-score   support

       anger       0.99      0.98      0.99       471
     disgust       0.96      0.93      0.94        71
        fear       0.97      0.93      0.95       289
   happiness       0.99      1.00      0.99       542
     neutral       0.99      0.99      0.99       697
 questioning       0.95      0.98      0.97       367
     sadness       0.98      0.99      0.98       350
    surprise       0.98      0.99      0.99       542
    contempt       0.89      0.79      0.84        84
     unknown       0.99      1.00      0.99        77

    accuracy                           0.98      3490
   macro avg       0.97      0.96      0.96      3490
weighted avg       0.98      0.98      0.98      3490


Top 3 confused class pairs:
  - contempt ‚Üí questioning: 12 instances
  - fear ‚Üí surprise: 9 instances
  - fear ‚Üí anger: 4 instances

üß† Avg prediction entropy: 0.3346

üîç Class entropies (sorted):
  - fear:

In [14]:
# ===================================
# 12. Inference & Analysis Utilities 
# ===================================

# ------------------------------------------
# Part A: Core Prediction Functions
# ------------------------------------------

# üöÄ Runs inference on a folder of images in batches and assigns labels or 'REVIEW'.
def batch_predict(image_folder, batch_size=64, threshold=0.85):
    all_preds = []
    error_count = 0
    image_paths = [
        p for p in Path(image_folder).rglob("*")
        if is_valid_image(p.name)
    ]

    for i in tqdm(range(0, len(image_paths), batch_size), desc="Running inference in batches"):
        batch_paths = image_paths[i:i + batch_size]
        images, valid_paths = [], []

        for path in batch_paths:
            try:
                img = Image.open(path).convert("RGB")
                images.append(img)
                valid_paths.append(str(path)) # <-- THIS LINE WAS MISSING
            except Exception:
                error_count += 1
                continue

        if not images:
            continue

        inputs = processor(images=images, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = F.softmax(logits, dim=-1)
            confs, preds = torch.max(probs, dim=-1)

        for pred, conf, path in zip(preds.tolist(), confs.tolist(), valid_paths):
            all_preds.append(LABEL_NAMES[pred] if conf >= threshold else "REVIEW")

    print(f"‚úÖ Inference complete. Skipped {error_count} invalid image(s).")
    return all_preds

# ------------------------------------------
# Part B: Post-Inference Analysis & Visualization
# ------------------------------------------

# ‚õèÔ∏è Mines the prediction CSV to find "hard negative" images from specific confusing class pairs.
def parse_review_confusions(csv_path, confusion_pairs):
    import csv
    flagged_imgs = {pair: [] for pair in confusion_pairs}
    with open(csv_path) as f:
        reader = csv.DictReader(f)
        for row in reader:
            pred = row["predicted_label"]
            true = os.path.basename(os.path.dirname(row["image_path"]))
            conf = float(row["confidence"])
            for a, b in confusion_pairs:
                if ((pred == a and true == b) or (pred == b and true == a)) and conf < 0.8:
                    flagged_imgs[(a, b)].append(row["image_path"])
    return flagged_imgs

# üìä Creates distribution bar plot of predicted label.
def plot_distribution(predictions, output_path):
    label_counts = Counter(predictions)
    labels = sorted(label_counts.keys())
    counts = [label_counts[label] for label in labels]
    
    plt.figure(figsize=(10, 5))
    plt.bar(labels, counts)
    plt.title("Predicted Expression Distribution")
    plt.xlabel("Expression")
    plt.ylabel("Count")
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

# üìà Plot Reliability Diagram (Calibration Curve)
    # visualizes how well model confidence matches actual accuracy
def plot_reliability_diagram(logits, labels, temperature, n_bins=15):
    probs = F.softmax(logits / temperature, dim=1)
    confidences, predictions = torch.max(probs, 1)
    accuracies = predictions.eq(labels)

    bins = torch.linspace(0, 1, n_bins + 1)
    bin_lowers, bin_uppers = bins[:-1], bins[1:]

    bin_accuracies, bin_confidences = [], []
    for lower, upper in zip(bin_lowers, bin_uppers):
        mask = (confidences > lower) & (confidences <= upper)
        if mask.any():
            bin_accuracies.append(accuracies[mask].float().mean())
            bin_confidences.append(confidences[mask].mean())

    plt.figure(figsize=(6, 6))
    plt.plot(bin_confidences, bin_accuracies, marker='o', label='Model')
    plt.plot([0, 1], [0, 1], linestyle='--', label='Perfect Calibration')
    plt.title("Reliability Diagram (After Temperature Scaling)")
    plt.xlabel("Confidence")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    output_path = os.path.join(SAVE_DIR, f"{VERSION}_reliability_diagram_calibrated.png")
    plt.savefig(output_path)
    plt.close()
    print(f"üìä Saved reliability diagram to {output_path}")

# ü§ù Takes a list of models and predicts a label for a single image by averaging their softmax probabilities.
def ensemble_predict(models, processor, image_path, device="cpu"):
    image = Image.open(image_path).convert("RGB")
    inputs = processor(image, return_tensors="pt").to(device)
    softmaxes = []
    individual_preds = []

    # Get predictions from each model in the ensemble.
    for m in models:
        with torch.no_grad():
            logits = m(**inputs).logits
            probs = F.softmax(logits, dim=-1)
            individual_preds.append(id2label[torch.argmax(probs, dim=-1).item()])
            softmaxes.append(probs.cpu().numpy())

    # Average the probabilities and return the final prediction.
    avg_probs = np.mean(softmaxes, axis=0)
    ensemble_pred_idx = np.argmax(avg_probs)
    ensemble_conf = avg_probs[0, ensemble_pred_idx]

    return id2label[ensemble_pred_idx], ensemble_conf, individual_preds
    
# ------------------------------------------
# Part C: Model Calibration Utilities
# ------------------------------------------

# üå°Ô∏è Apply Temperature Scaling for Calibration
def apply_temperature_scaling(logits_path, labels_path):
    if not (os.path.exists(logits_path) and os.path.exists(labels_path)):
        print(f"‚ùå Missing files:\n  - {logits_path if not os.path.exists(logits_path) else ''}\n - {labels_path if not os.path.exists(labels_path) else ''}")
        return None

    print(f"üìÇ Loading logits from: {logits_path}")
    print(f"üìÇ Loading labels from: {labels_path}")

    logits = torch.tensor(np.load(logits_path), dtype=torch.float32).to(device)
    labels = torch.tensor(np.load(labels_path), dtype=torch.long).to(device)

    class TemperatureScaler(nn.Module):
        def __init__(self):
            super().__init__()
            self.temperature = nn.Parameter(torch.ones(1) * 1.5)

        def forward(self, logits):
            return logits / self.temperature

    model = TemperatureScaler().to(device)
    optimizer = LBFGS([model.temperature], lr=0.01, max_iter=50)

    def eval_fn():
        optimizer.zero_grad()
        loss = F.cross_entropy(model(logits), labels)
        loss.backward()
        return loss

    optimizer.step(eval_fn)
    calibrated_logits = model(logits)
    probs = F.softmax(calibrated_logits, dim=1).detach().cpu().numpy()
    logloss = log_loss(labels.cpu().numpy(), probs)

    # Save optimal temperature
    temperature_value = model.temperature.item()
    torch.save(
        torch.tensor([temperature_value]),
        os.path.join(SAVE_DIR, f"{VERSION}_calibrated_temperature.pt")
    )
    print(f"‚úÖ Optimal temperature: {temperature_value:.4f}")
    print(f"‚úÖ Calibrated Log Loss: {logloss:.4f}")
    return temperature_value, logits.cpu(), labels.cpu()

In [15]:
# -----------------------------
# 13. Entry Point for Inference
# -----------------------------

# Reload Model for Inference
model = AutoModelForImageClassification.from_pretrained(SAVE_DIR).to(device).eval()
print("‚úÖ Model reloaded for inference.")

if __name__ == "__main__" and RUN_INFERENCE:

    # Auto-locate latest model directory
    OUTPUT_PATH = os.path.join(SAVE_DIR, f"{VERSION}_distribution_plot_{timestamp}.png")

    predictions = batch_predict(IMAGE_DIR)
    reviewed_paths = []
    image_paths = [str(p) for p in Path(IMAGE_DIR).rglob("*") if is_valid_image(p.name)]

    for path, label in zip(image_paths, predictions):
        if label == "REVIEW":
            reviewed_paths.append(path)

    # Save paths to inspect manually
    with open(os.path.join(SAVE_DIR, f"{VERSION}_review_candidates.txt"), "w") as f:
        f.write("\n".join(reviewed_paths))
    print(f"üìù Saved REVIEW file paths to {VERSION}_review_candidates.txt")

    plot_distribution(predictions, OUTPUT_PATH)
    print(f"Distribution plot saved to: {OUTPUT_PATH}")

‚úÖ Model reloaded for inference.


Running inference in batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 273/273 [06:17<00:00,  1.38s/it]

‚úÖ Inference complete. Skipped 0 invalid image(s).
üìù Saved REVIEW file paths to V29_review_candidates.txt
Distribution plot saved to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807/V29_distribution_plot_20250710_082807.png





In [16]:
# --------------------------
# 14. Temperature Scaling Calibration 
# --------------------------

# Dynamically locate the most recent V* folder that contains logits/labels
base_dir = os.path.dirname(SAVE_DIR)
v_folders = sorted([
    d for d in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("V")
], key=lambda d: os.path.getmtime(os.path.join(base_dir, d)), reverse=True)

logits_path, labels_path = None, None
for v in v_folders:
    version_tag = v.split('_')[0]
    folder_path = os.path.join(base_dir, v)
    logits_candidate = os.path.join(folder_path, f"logits_eval_{version_tag}.npy")
    labels_candidate = os.path.join(folder_path, f"labels_eval_{version_tag}.npy")
    if os.path.exists(logits_candidate) and os.path.exists(labels_candidate):
        INFER_SAVE_DIR = folder_path
        INFER_VERSION = version_tag
        print(f"üìÅ Using calibration files from: {SAVE_DIR}")
        logits_path = logits_candidate
        labels_path = labels_candidate
        break

# --------------------------
# Run calibration
# --------------------------
if logits_path and labels_path:
    result = apply_temperature_scaling(logits_path, labels_path)
    if result is not None:
        temperature, logits, labels = result
        plot_reliability_diagram(logits, labels, temperature)
else:
    print(f"‚ö†Ô∏è Skipping temperature scaling and diagram (missing logits or labels in {SAVE_DIR})")

üìÅ Using calibration files from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807
üìÇ Loading logits from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807/logits_eval_V29.npy
üìÇ Loading labels from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807/labels_eval_V29.npy
‚úÖ Optimal temperature: 1.1055
‚úÖ Calibrated Log Loss: 0.1494
üìä Saved reliability diagram to /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807/V29_reliability_diagram_calibrated.png


In [17]:
# --------------------------
# 15. Review & Relabel 'REVIEW' Predictions (with Audit Logging & Clustering)
# --------------------------
MINORITY_LABELS = ["disgust", "contempt", "fear", "questioning"]
MINORITY_ENTROPY_THRESH = 0.6
REVIEW_THRESHOLD = 0.85

REVIEW_BY_CLASS_DIR = os.path.join(SAVE_DIR, "review_predictions_by_class")
REVIEW_CSV_LOG = os.path.join(SAVE_DIR, f"{VERSION}_review_predictions_with_preds.csv")
REVIEW_CLUSTER_DIR = os.path.join(SAVE_DIR, "review_predictions_clustered")

os.makedirs(REVIEW_BY_CLASS_DIR, exist_ok=True)
os.makedirs(REVIEW_CLUSTER_DIR, exist_ok=True)

# ---- If you HAVE NOT already generated review CSV (inference stage) ----
if not os.path.exists(REVIEW_CSV_LOG):
    review_log = []
    image_paths = [
        p for p in Path(IMAGE_DIR).rglob("*")
        if p.is_file() and p.suffix.lower() in [".jpg", ".jpeg", ".png"]
    ]
    
    for img_path in image_paths:
        try:
            image = Image.open(img_path).convert("RGB")
            inputs = processor(image, return_tensors="pt").to(device)
            with torch.no_grad():
                logits = model(**inputs).logits
                probs = F.softmax(logits, dim=-1)
                conf, pred_idx = torch.max(probs, dim=-1)
            conf_val = conf.item()
            pred_label = id2label[pred_idx.item()]
            
            entropy = -torch.sum(probs * torch.log(probs + 1e-12), dim=-1).item()

            if pred_label in MINORITY_LABELS and entropy > MINORITY_ENTROPY_THRESH:
                tag = "unknown"
            elif conf_val < REVIEW_THRESHOLD:
                tag = "REVIEW"
            else:
                tag = pred_label
            
            review_log.append({
                "image_path": str(img_path),
                "predicted_label": pred_label,
                "confidence": round(conf_val, 4),
                "entropy": round(entropy, 4),
                "tag": tag
            })
            
            # For backward compatibility, still copy to REVIEW_BY_CLASS_DIR if tag is not "unknown"
            if tag not in ["unknown"]:
                target_dir = os.path.join(REVIEW_BY_CLASS_DIR, tag)
                os.makedirs(target_dir, exist_ok=True)
                shutil.copy(str(img_path), target_dir)
        except Exception as e:
            print(f"‚ö†Ô∏è Error with image: {img_path} | {e}")
            
    pd.DataFrame(review_log).to_csv(REVIEW_CSV_LOG, index=False)
    print(f"‚úÖ Completed tagging + copying REVIEW predictions to: {REVIEW_BY_CLASS_DIR}")
    print(f"üìÑ CSV log saved to: {REVIEW_CSV_LOG}")

# ---- If you already HAVE a review CSV (assignment/audit stage) ----
df = pd.read_csv(REVIEW_CSV_LOG)
review_assignment_log = []
for _, row in df.iterrows():
    path = row["image_path"]
    pred_label = row["predicted_label"]
    conf = float(row["confidence"])
    true_label = os.path.basename(os.path.dirname(path))
    
    entropy = float(row.get("entropy", 0))  # default to 0 if not present
    if pred_label in MINORITY_LABELS and entropy > MINORITY_ENTROPY_THRESH:
        assigned = "unknown"
    elif conf < REVIEW_THRESHOLD:
        assigned = "REVIEW"
    else:
        assigned = pred_label
    
    dest_dir = os.path.join(REVIEW_BY_CLASS_DIR, assigned)
    os.makedirs(dest_dir, exist_ok=True)
    shutil.copy(path, dest_dir)
    review_assignment_log.append([path, true_label, pred_label, conf, assigned, entropy])

log_df = pd.DataFrame(
    review_assignment_log,
    columns=["image_path", "true_label", "pred_label", "confidence", "assigned_folder", "entropy"]
)

log_df.to_csv(os.path.join(SAVE_DIR, "review_assignment_audit.csv"), index=False)
print("‚úÖ Review assignments (with audit) complete.")

print("Assignment summary:", Counter(log_df["assigned_folder"]))

# ---- Perceptual hash clustering of review pool ----
def phash_distance(hash1, hash2):
    return hash1 - hash2

PHASH_CLUSTER_THRESHOLD = 6
image_paths = [row[0] for row in review_assignment_log if row[4] != "unknown"]  # assigned to a class

hashes = []
for img_path in image_paths:
    try:
        img = Image.open(img_path).convert("L").resize((64, 64))
        hashes.append(hex_to_hash(str(phash(img))))
    except Exception as e:
        print(f"phash error: {img_path} | {e}")

clusters = []
used = set()
for i, h1 in enumerate(hashes):
    if i in used:
        continue
    cluster = [image_paths[i]]
    used.add(i)
    for j, h2 in enumerate(hashes):
        if j <= i or j in used:
            continue
        if phash_distance(h1, h2) <= PHASH_CLUSTER_THRESHOLD:
            cluster.append(image_paths[j])
            used.add(j)
    if len(cluster) > 1:
        clusters.append(cluster)

if not clusters:
    print(f"‚ö†Ô∏è No clusters found for review. {REVIEW_CLUSTER_DIR} will remain empty.")
else:
    for idx, cluster in enumerate(clusters):
        out_dir = os.path.join(REVIEW_CLUSTER_DIR, f"cluster_{idx}")
        os.makedirs(out_dir, exist_ok=True)
        for p in cluster:
            shutil.copy(p, out_dir)
    print(f"‚úÖ Saved {len(clusters)} clusters to {REVIEW_CLUSTER_DIR}")

‚úÖ Completed tagging + copying REVIEW predictions to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807/review_predictions_by_class
üìÑ CSV log saved to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807/V29_review_predictions_with_preds.csv
‚úÖ Review assignments (with audit) complete.
Assignment summary: Counter({'neutral': 3193, 'happiness': 2928, 'surprise': 2839, 'anger': 2344, 'sadness': 1628, 'fear': 1318, 'unknown': 1066, 'questioning': 863, 'REVIEW': 590, 'contempt': 281, 'disgust': 261})
‚úÖ Saved 265 clusters to /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807/review_predictions_clustered


In [18]:
# --------------------------
# 16. REVIEW Pool Diagnostics & Hard Confusion Mining
# --------------------------

# A. Flag hard confusion pairs for manual review
REVIEW_CONFUSION_PAIRS = [("contempt", "questioning"), ("fear", "surprise")]

confusion_candidates = parse_review_confusions(REVIEW_CSV_LOG, REVIEW_CONFUSION_PAIRS)
for pair, imgs in confusion_candidates.items():
    print(f"\nFlagged {len(imgs)} hard negatives for {pair}:")
    out_path = os.path.join(SAVE_DIR, f"review_hardneg_{pair[0]}_{pair[1]}.txt")
    with open(out_path, "w") as f:
        f.write("\n".join(imgs))
    print(f"  Saved list: {out_path}")

# B. Organize REVIEW-tagged images by predicted class (for curation)
REVIEW_SORT_DIR = os.path.join(SAVE_DIR, "review_predictions_by_class")
os.makedirs(REVIEW_SORT_DIR, exist_ok=True)
review_txt_path = os.path.join(SAVE_DIR, f"{VERSION}_review_candidates.txt")
csv_path = os.path.join(SAVE_DIR, f"{VERSION}_review_predictions_with_preds.csv")

if os.path.exists(review_txt_path) and os.path.exists(csv_path):
    with open(review_txt_path, "r") as f:
        review_paths = {line.strip() for line in f.readlines()}

    df = pd.read_csv(csv_path)
    count = 0

    print(f"üîç Found {len(df)} total predictions (CSV) and {len(review_paths)} REVIEW-tagged paths.")

    for _, row in df.iterrows():
        path = row["image_path"]
        label = row["predicted_label"]
        conf = row["confidence"]

        if path in review_paths and label != "REVIEW":
            dest_dir = os.path.join(REVIEW_SORT_DIR, label)
            os.makedirs(dest_dir, exist_ok=True)
            shutil.copy(path, dest_dir)
            count += 1

    print(f"üìÇ Grouped {count} REVIEW images into folders by predicted label in: {REVIEW_SORT_DIR}")
else:
    print("‚ö†Ô∏è Missing review candidates file or prediction CSV.")


Flagged 1 hard negatives for ('contempt', 'questioning'):
  Saved list: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807/review_hardneg_contempt_questioning.txt

Flagged 14 hard negatives for ('fear', 'surprise'):
  Saved list: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807/review_hardneg_fear_surprise.txt
üîç Found 17311 total predictions (CSV) and 1185 REVIEW-tagged paths.
üìÇ Grouped 1174 REVIEW images into folders by predicted label in: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807/review_predictions_by_class


In [19]:
# --------------------------
# 17. Visualization & Error Tracking
# --------------------------

print("Label name/id mapping:")
for idx, name in enumerate(LABEL_NAMES):
    print(f"{idx}: {name}")

# Defensive: Check that metrics file exists before plotting
per_class_csv = os.path.join(SAVE_DIR, "per_class_metrics.csv")
if not os.path.exists(per_class_csv):
    print(f"‚ö†Ô∏è Metrics file {per_class_csv} not found.")
else:
    metrics_df = pd.read_csv(per_class_csv)
    last_row = metrics_df.iloc[-1]

    fig, ax = plt.subplots(figsize=(10,6))
    f1s = [last_row[f"f1_{n}"] for n in LABEL_NAMES]
    ax.bar(LABEL_NAMES, f1s)
    ax.set_title("Per-Class F1 (Last Epoch)")
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, "per_class_f1.png"))
    plt.close()

    # Bar plot of per-class entropy
    entropies = [last_row[f"entropy_{n}"] for n in LABEL_NAMES]
    fig, ax = plt.subplots(figsize=(10,6))
    ax.bar(LABEL_NAMES, entropies)
    ax.set_title("Per-Class Mean Entropy (Last Epoch)")
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, "per_class_entropy.png"))
    plt.close()

    # Histogram for REVIEW pool
    review_counts = Counter()
    if os.path.exists(REVIEW_SORT_DIR):
        for label_dir in os.listdir(REVIEW_SORT_DIR):
            count = len(os.listdir(os.path.join(REVIEW_SORT_DIR, label_dir)))
            review_counts[label_dir] = count
        plt.bar(review_counts.keys(), review_counts.values())
        plt.title("REVIEW Pool Distribution by Predicted Class")
        plt.savefig(os.path.join(SAVE_DIR, "review_pool_distribution.png"))
        plt.close()
        # Flag if >70% in one class
        total = sum(review_counts.values())
        for label, count in review_counts.items():
            if total > 0 and count / total > 0.7:
                print(f"‚ö†Ô∏è REVIEW pool highly imbalanced: {count/total:.1%} in '{label}'")

    # Audit print block (as before)
    print("Sample review predictions (audit):")
    if 'log_df' in locals():
        print(log_df[["image_path", "true_label", "pred_label", "confidence", "assigned_folder"]].head())
    elif 'df' in locals():
        print(df[["image_path", "true_label", "predicted_label", "confidence"]].head())
    else:
        print("No review/audit DataFrame found for printing.")

# ‚úÖ AUDIT BLOCK
print("Sample review predictions (audit):")
if 'log_df' in locals():
    print(log_df[["image_path", "true_label", "pred_label", "confidence", "assigned_folder"]].head())
elif 'df' in locals():
    print(df[["image_path", "true_label", "predicted_label", "confidence"]].head())
else:
    print("No review/audit DataFrame found for printing.")

Label name/id mapping:
0: anger
1: disgust
2: fear
3: happiness
4: neutral
5: questioning
6: sadness
7: surprise
8: contempt
9: unknown
Sample review predictions (audit):
                                          image_path true_label pred_label  \
0  /Users/natalyagrokh/AI/ml_expressions/img_data...   contempt   contempt   
1  /Users/natalyagrokh/AI/ml_expressions/img_data...   contempt   contempt   
2  /Users/natalyagrokh/AI/ml_expressions/img_data...   contempt    neutral   
3  /Users/natalyagrokh/AI/ml_expressions/img_data...   contempt   contempt   
4  /Users/natalyagrokh/AI/ml_expressions/img_data...   contempt   contempt   

   confidence assigned_folder  
0      0.9415        contempt  
1      0.9549        contempt  
2      0.9265         neutral  
3      0.9661        contempt  
4      0.9330        contempt  
Sample review predictions (audit):
                                          image_path true_label pred_label  \
0  /Users/natalyagrokh/AI/ml_expressions/img_data...   

In [20]:
# --------------------------
# 18. Deployment Readiness Assertions and Flags
# --------------------------

# Load metrics
metrics_df = pd.read_csv(os.path.join(SAVE_DIR, "per_class_metrics.csv"))
last = metrics_df.iloc[-1]
warn = False

for cname in LABEL_NAMES:
    f1 = last[f"f1_{cname}"]
    entropy = last[f"entropy_{cname}"]
    if f1 < 0.8:
        print(f"üö® F1 < 0.8 for class '{cname}': {f1:.2f}")
        warn = True
    if entropy > 0.4:
        print(f"üö® Entropy > 0.4 for class '{cname}': {entropy:.2f}")
        warn = True

if not warn:
    print("‚úÖ All classes ready for deployment: F1 >= 0.8 and entropy <= 0.4")
else:
    print("‚ö†Ô∏è Some classes not deployment-ready! Address above issues before production.")


‚úÖ All classes ready for deployment: F1 >= 0.8 and entropy <= 0.4


In [21]:
# --------------------------
# 19. Model Ensembling Analysis (Fully Dynamic)
# --------------------------

print("--- Setting up Model Ensembling Analysis ---")

# Find all valid model directories that contain saved model weights
all_model_dirs = [
    os.path.join(MODEL_ROOT, d)
    for d in os.listdir(MODEL_ROOT)
    if d.startswith("V") and os.path.isdir(os.path.join(MODEL_ROOT, d))
       and (os.path.exists(os.path.join(MODEL_ROOT, d, "model.safetensors")) or 
            os.path.exists(os.path.join(MODEL_ROOT, d, "pytorch_model.bin")))
]

# Sort by version number and select the two most recent, including the one just trained
# This now correctly includes the current SAVE_DIR as a candidate
sorted_models = sorted(all_model_dirs, key=extract_version_from_path, reverse=True)

if len(sorted_models) < 2:
    print("‚ö†Ô∏è Found fewer than two model versions (including current). Skipping ensembling.")
else:
    model_path_1 = sorted_models[0] # The model just trained
    model_path_2 = sorted_models[1] # The previous model

    print(f"‚úÖ Dynamically selected models for ensembling:")
    print(f"   - Model 1 (New):  {os.path.basename(model_path_1)}")
    print(f"   - Model 2 (Prev): {os.path.basename(model_path_2)}")

    # Load the models
    model_1 = AutoModelForImageClassification.from_pretrained(model_path_1).to(device).eval()
    model_2 = AutoModelForImageClassification.from_pretrained(model_path_2).to(device).eval()
    ensemble_models = [model_1, model_2]
    
    # --- Part B: Run Analysis on Misclassified Images ---
    audit_csv_path = os.path.join(SAVE_DIR, "review_assignment_audit.csv")
    if os.path.exists(audit_csv_path):
        print("\n--- Running Ensemble Analysis on Hard Cases ---")
        audit_df = pd.read_csv(audit_csv_path)
        
        # Filter for images that were misclassified by the newest model
        misclassified_df = audit_df[audit_df['true_label'] != audit_df['pred_label']]
        
        if misclassified_df.empty:
            print("‚úÖ No misclassified images found in the audit file. Nothing to analyze.")
        else:
            for _, row in misclassified_df.head(5).iterrows():
                image_path = row['image_path']
                true_label = row['true_label']
                
                # This now calls the utility function defined in Cell In[4]
                ensemble_pred, ensemble_conf, individual_preds = ensemble_predict(
                    ensemble_models, processor, image_path, device=device
                )
                
                print(f"\nImage: {os.path.basename(image_path)}")
                print(f"  - True Label:      {true_label}")
                print(f"  - Model 1 (New) Pred: {individual_preds[0]}")
                print(f"  - Model 2 (Prev) Pred: {individual_preds[1]}")
                print(f"  - ENSEMBLE Pred:   {ensemble_pred} (Confidence: {ensemble_conf:.2f})")
                
                if ensemble_pred == true_label:
                    print("  - ‚úÖ SUCCESS: Ensemble corrected the misclassification!")
                else:
                    print(f"  - ‚ùå FAILURE: Ensemble also misclassified as '{ensemble_pred}'.")
    else:
        print(f"‚ö†Ô∏è Could not find audit CSV at {audit_csv_path}. Skipping ensemble analysis.")

--- Setting up Model Ensembling Analysis ---
‚úÖ Dynamically selected models for ensembling:
   - Model 1 (New):  V29_20250710_082807
   - Model 2 (Prev): V28_20250709_153248

--- Running Ensemble Analysis on Hard Cases ---

Image: img_2617.jpg
  - True Label:      contempt
  - Model 1 (New) Pred: neutral
  - Model 2 (Prev) Pred: neutral
  - ENSEMBLE Pred:   neutral (Confidence: 0.93)
  - ‚ùå FAILURE: Ensemble also misclassified as 'neutral'.

Image: Iain_Richmond_0001.jpg
  - True Label:      contempt
  - Model 1 (New) Pred: neutral
  - Model 2 (Prev) Pred: neutral
  - ENSEMBLE Pred:   neutral (Confidence: 0.94)
  - ‚ùå FAILURE: Ensemble also misclassified as 'neutral'.

Image: Richard_Gephardt_0001.jpg
  - True Label:      contempt
  - Model 1 (New) Pred: happiness
  - Model 2 (Prev) Pred: happiness
  - ENSEMBLE Pred:   happiness (Confidence: 0.80)
  - ‚ùå FAILURE: Ensemble also misclassified as 'happiness'.

Image: img_1685.jpg
  - True Label:      contempt
  - Model 1 (New) Pred: a