In [1]:
#V16 changes:
    # section #2 -> added stronger augmentation via make_transform_function
        # added is_uncertain function
    # section #3 -> updated to pull the hightest version model rather than
        # the most recently modified model folder
    # section #8 -> updated to have more aggressive augmentation
    # section #11 -> now avoids capping majority and upsamples minority
    #- section #19 -> added explicit prints for clustering, audit, and review
    # section #20 -> merged old section #21 org by class with review conf mining
    # added sections #21 and #22

In [8]:
# --------------------------
# 0. Imports
# --------------------------
# Standard Library Imports
import csv
import gc
import glob
import multiprocessing as mp
import os
import random
import re
import shutil
import subprocess
import sys
import time

# Third-Party Imports
import accelerate
import dill
import face_recognition
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
#import tensorflow as tf
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import transformers

# From Imports
from collections import Counter
from datasets import ClassLabel, Dataset, Features, Image as DatasetsImage, concatenate_datasets, load_dataset
from datetime import datetime
from functools import partial
from imagehash import phash, hex_to_hash
from io import BytesIO
from pathlib import Path
from PIL import Image, ImageOps, ExifTags, UnidentifiedImageError
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix, log_loss, precision_recall_fscore_support
from torch import nn
from torch.nn import functional as F
from torch.optim import AdamW, LBFGS
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torchvision import transforms
from torchvision.transforms import (
    GaussianBlur,
    RandAugment,
    RandomAffine,
    RandomApply,
    ToPILImage,
    ToTensor
)
from tqdm import tqdm
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    EarlyStoppingCallback,
    TrainingArguments,
    Trainer,
)

In [3]:
# --------------------------
# 1. Global Configurations
# --------------------------
RUN_INFERENCE = True  # Toggle this off to disable running inference
IMAGE_DIR = "/Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalfag_dataset"
BASE_PATH = IMAGE_DIR

LABEL_NAMES = [
    'anger', 'disgust', 'fear', 'happiness', 'neutral',
    'questioning', 'sadness', 'surprise', 'contempt', 'unknown'
]
id2label = dict(enumerate(LABEL_NAMES))
label2id = {v: k for k, v in id2label.items()}

HARD_CLASS_NAMES = ['contempt', 'disgust', 'fear', 'questioning']
hard_class_ids = [label2id[n] for n in HARD_CLASS_NAMES]

VALID_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff")

def is_valid_image(filename):
    return filename.lower().endswith(VALID_EXTENSIONS) and not filename.startswith("._")

label_mapping = {name.lower(): name for name in LABEL_NAMES}

# üî¢ Dynamically determine the next version
def get_next_version(base_dir):

    # Use glob to find all entries matching the pattern
    all_entries = glob.glob(os.path.join(base_dir, "V*_*"))
    
    # Filter to include only directories
    existing = [
        os.path.basename(d) for d in all_entries if os.path.isdir(d)
    ]

    # Extract version numbers from the directory names
    versions = [
        int(d[1:].split("_")[0]) for d in existing
        if d.startswith("V") and "_" in d and d[1:].split("_")[0].isdigit()
    ]
    
    # Determine the next version number
    next_version = max(versions, default=0) + 1
    return f"V{next_version}"

# Automatically create a versioned output folder
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
VERSION = get_next_version("/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training")
VERSION_TAG = VERSION + "_" + timestamp
SAVE_DIR = os.path.join("/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training", VERSION_TAG)
LOGITS_PATH = os.path.join(SAVE_DIR, f"logits_eval_{VERSION}.npy")
LABELS_PATH = os.path.join(SAVE_DIR, f"labels_eval_{VERSION}.npy")
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"üìÅ Output directory created: {SAVE_DIR}")

üìÅ Output directory created: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335


In [4]:
# --------------------------
# 2. Utility Functions (Metrics & Calibration)
# --------------------------

# üîç Compute perceptual hash for image similarity clustering (used in REVIEW and Disgust curation)
def compute_hash(image_path):
    try:
        img = Image.open(image_path).convert("L").resize((64, 64))
        return str(phash(img))
    except Exception:
        return None

# üîÑ Smoothed Cross Entropy Loss (Œµ = 0.05)
class SmoothedCrossEntropyLoss(nn.Module):
    def __init__(self, smoothing=0.05):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, target):
        num_classes = logits.size(1)
        with torch.no_grad():
            smooth_labels = torch.full_like(logits, self.smoothing / (num_classes - 1))
            smooth_labels.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        log_probs = F.log_softmax(logits, dim=1)
        loss = -(smooth_labels * log_probs).sum(dim=1).mean()
        return loss

# ‚ö†Ô∏è Confidence Penalty to Reduce Overconfidence
def confidence_penalty(logits, beta=0.05):
    probs = F.softmax(logits, dim=1)
    log_probs = F.log_softmax(logits, dim=1)
    entropy = -torch.sum(probs * log_probs, dim=1)
    return beta * entropy.mean()

# üìä Compute Metrics with Confusion Matrix Logging
def compute_metrics_with_confusion(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    # Print classification report
    print("\nClassification Report:")
    report = classification_report(labels, preds, target_names=LABEL_NAMES, output_dict=True)
    print(classification_report(labels, preds, target_names=LABEL_NAMES))

    # Save raw logits/labels for calibration or further analysis
    np.save(os.path.join(SAVE_DIR, f"logits_eval_{VERSION}.npy"), logits)
    np.save(os.path.join(SAVE_DIR, f"labels_eval_{VERSION}.npy"), labels)

    # Save per-class F1/precision/recall/entropy to CSV (append per epoch)
    f1s = [report[name]["f1-score"] for name in LABEL_NAMES]
    recalls = [report[name]["recall"] for name in LABEL_NAMES]
    precisions = [report[name]["precision"] for name in LABEL_NAMES]

    # Entropy per class (sorted by entropy)
    softmax_probs = F.softmax(torch.tensor(logits), dim=-1)
    entropies = -torch.sum(softmax_probs * torch.log(softmax_probs + 1e-12), dim=-1)
    entropy_per_class = []
    for idx, class_name in enumerate(LABEL_NAMES):
        mask = (np.array(labels) == idx)
        if mask.any():
            class_entropy = entropies[mask].mean().item()
            entropy_per_class.append((class_name, class_entropy))
        else:
            entropy_per_class.append((class_name, 0.0))
    # Sort for display only; CSV row stays in canonical label order
    sorted_entropy = sorted(entropy_per_class, key=lambda x: x[1], reverse=True)

    # CSV logging
    epoch_metrics_path = os.path.join(SAVE_DIR, "per_class_metrics.csv")
    epoch = getattr(trainer.state, "epoch", None) if "trainer" in globals() else None
    df_row = pd.DataFrame({
        "epoch": [epoch],
        **{f"f1_{n}": [f] for n, f in zip(LABEL_NAMES, f1s)},
        **{f"recall_{n}": [r] for n, r in zip(LABEL_NAMES, recalls)},
        **{f"precision_{n}": [p] for n, p in zip(LABEL_NAMES, precisions)},
        **{f"entropy_{n}": [e] for n, e in entropy_per_class}
    })
    if os.path.exists(epoch_metrics_path):
        df_row.to_csv(epoch_metrics_path, mode="a", header=False, index=False)
    else:
        df_row.to_csv(epoch_metrics_path, mode="w", header=True, index=False)

    # Generate and print confusion matrix heatmap
    cm = confusion_matrix(labels, preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm, annot=True, fmt="d", cmap="Blues",
        xticklabels=LABEL_NAMES,
        yticklabels=LABEL_NAMES
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, f"confusion_matrix_epoch_{VERSION}.png"))
    plt.close()

    # Top confused pairs
    confusion_pairs = [
        ((LABEL_NAMES[i], LABEL_NAMES[j]), cm[i][j])
        for i in range(len(LABEL_NAMES))
        for j in range(len(LABEL_NAMES)) if i != j
    ]
    top_confusions = sorted(confusion_pairs, key=lambda x: x[1], reverse=True)[:3]
    print("\nTop 3 confused class pairs:")
    for (true_label, pred_label), count in top_confusions:
        print(f"  - {true_label} ‚Üí {pred_label}: {count} instances")

    # Compute average prediction entropy
    avg_entropy = entropies.mean().item()
    print(f"\nüß† Avg prediction entropy: {avg_entropy:.4f}")

    print("\nüîç Class entropies (sorted):")
    for class_name, entropy in sorted_entropy:
        print(f"  - {class_name}: entropy = {entropy:.4f}")

    accuracy = (preds == labels).mean()
    return {"accuracy": accuracy}


# üå°Ô∏è Apply Temperature Scaling for Calibration
def apply_temperature_scaling(logits_path, labels_path):
    if not (os.path.exists(logits_path) and os.path.exists(labels_path)):
        print(f"‚ùå Missing files:\n  - {logits_path if not os.path.exists(logits_path) else ''}\n - {labels_path if not os.path.exists(labels_path) else ''}")
        return None

    print(f"üìÇ Loading logits from: {logits_path}")
    print(f"üìÇ Loading labels from: {labels_path}")

    logits = torch.tensor(np.load(logits_path), dtype=torch.float32).to(device)
    labels = torch.tensor(np.load(labels_path), dtype=torch.long).to(device)

    class TemperatureScaler(nn.Module):
        def __init__(self):
            super().__init__()
            self.temperature = nn.Parameter(torch.ones(1) * 1.5)

        def forward(self, logits):
            return logits / self.temperature

    model = TemperatureScaler().to(device)
    optimizer = LBFGS([model.temperature], lr=0.01, max_iter=50)

    def eval_fn():
        optimizer.zero_grad()
        loss = F.cross_entropy(model(logits), labels)
        loss.backward()
        return loss

    optimizer.step(eval_fn)
    calibrated_logits = model(logits)
    probs = F.softmax(calibrated_logits, dim=1).detach().cpu().numpy()
    logloss = log_loss(labels.cpu().numpy(), probs)

    # Save optimal temperature
    temperature_value = model.temperature.item()
    torch.save(
        torch.tensor([temperature_value]),
        os.path.join(SAVE_DIR, f"{VERSION}_calibrated_temperature.pt")
    )
    print(f"‚úÖ Optimal temperature: {temperature_value:.4f}")
    print(f"‚úÖ Calibrated Log Loss: {logloss:.4f}")
    return temperature_value, logits.cpu(), labels.cpu()


# üìà Plot Reliability Diagram (Calibration Curve)
def plot_reliability_diagram(logits, labels, temperature, n_bins=15):
    probs = F.softmax(logits / temperature, dim=1)
    confidences, predictions = torch.max(probs, 1)
    accuracies = predictions.eq(labels)

    bins = torch.linspace(0, 1, n_bins + 1)
    bin_lowers, bin_uppers = bins[:-1], bins[1:]

    bin_accuracies, bin_confidences = [], []
    for lower, upper in zip(bin_lowers, bin_uppers):
        mask = (confidences > lower) & (confidences <= upper)
        if mask.any():
            bin_accuracies.append(accuracies[mask].float().mean())
            bin_confidences.append(confidences[mask].mean())

    plt.figure(figsize=(6, 6))
    plt.plot(bin_confidences, bin_accuracies, marker='o', label='Model')
    plt.plot([0, 1], [0, 1], linestyle='--', label='Perfect Calibration')
    plt.title("Reliability Diagram (After Temperature Scaling)")
    plt.xlabel("Confidence")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    output_path = os.path.join(SAVE_DIR, f"{VERSION}_reliability_diagram_calibrated.png")
    plt.savefig(output_path)
    plt.close()
    print(f"üìä Saved reliability diagram to {output_path}")

# saving model and processor
def save_model_and_processor(model, processor, save_dir, trainer=None):
    print(f"Saving model and processor to: {save_dir}")
    
    model = model.to("cpu")

    # Save processor
    processor.save_pretrained(save_dir)
    print(f"‚úÖ Processor saved to: {SAVE_DIR}")
    
    # Save full model
    model.save_pretrained(SAVE_DIR, safe_serialization=True)
    print(f"‚úÖ Full model saved to: {SAVE_DIR}")

    # Save state dict
    final_model_path = os.path.join(SAVE_DIR, 'final_model.pth')
    torch.save(model.state_dict(), final_model_path)
    print(f"‚úÖ State dict saved to: {final_model_path}")

    # Save trainer state
    if trainer is not None:
        try:
            trainer.save_model(os.path.join(save_dir, "backup_trainer_model"))
            print("‚úÖ Trainer backup saved.")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to save trainer backup: {e}")

    # Memory cleanup
    del model
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass  # Not all systems have CUDA
    print("‚úÖ Memory cleanup complete after save.")

# üö¶ Prints label distribution for a dataset
    #only calling for ad hoc debugging, experiments, sanity checks 
def check_label_integrity(dataset, LABEL_NAMES, label2id):
    # Count all mapped labels
    label_counts = Counter(dataset['label'])
    print("\nüö® Label distribution after mapping (before split):")
    for label_id in range(len(LABEL_NAMES)):
        label_name = LABEL_NAMES[label_id]
        print(f"  {label_name:12}: {label_counts.get(label_id, 0)}")

    # Specifically highlight 'surprise'
    surprise_id = label2id['surprise']
    if label_counts.get(surprise_id, 0) == 0:
        print("‚ùóWARNING: No 'surprise' images found after mapping!")
    elif label_counts[surprise_id] < 50:  # arbitrary threshold
        print(f"‚ö†Ô∏è Only {label_counts[surprise_id]} 'surprise' images found! Check curation or mapping.")

# üö¶ Prints label distribution for multiple datasets
def check_all_label_integrity(datasets_dict, LABEL_NAMES, label2id):
    for name, dataset in datasets_dict.items():
        print(f"\nüö® Label distribution for: {name}")
        label_counts = Counter(dataset['label'])
        for label_id in range(len(LABEL_NAMES)):
            label_name = LABEL_NAMES[label_id]
            print(f"  {label_name:12}: {label_counts.get(label_id, 0)}")
        surprise_id = label2id['surprise']
        if label_counts.get(surprise_id, 0) == 0:
            print("‚ùóWARNING: No 'surprise' images found in this split!")
        elif label_counts[surprise_id] < 50:
            print(f"‚ö†Ô∏è Only {label_counts[surprise_id]} 'surprise' images in {name}! Check curation or mapping.")

# --- Stronger Augmentation Utility ---
def make_transform_function(processor, hard_class_ids):
    def transform_function(example):
        label = example["label"]
        aug_pipeline = strong_aug if label in hard_class_ids else data_augment
        if example["image"].mode != "RGB":
            example["image"] = example["image"].convert("RGB")
        augmented_image = aug_pipeline(example["image"])
        inputs = processor(augmented_image, return_tensors="pt")
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["labels"] = example["label"]
        return inputs
    return transform_function  

# Returns a boolean tensor: True if the prediction is low-confidence
def is_uncertain(probs, threshold=0.85, entropy_thresh=1.3):
    conf, _ = torch.max(probs, dim=-1)
    entropy = -torch.sum(probs * torch.log(probs + 1e-12), dim=-1)
    return (conf < threshold) | (entropy > entropy_thresh)

In [9]:
# --------------------------
# 3. Auto-Load Latest Pretrained Model and Processor
# --------------------------

MODEL_ROOT = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training"

def extract_version(dirname):
    # Extracts the version number as an integer for sorting, e.g., V15_... ‚Üí 15
    match = re.match(r"V(\d+)", os.path.basename(dirname))
    return int(match.group(1)) if match else -1

model_dirs = [
    os.path.join(MODEL_ROOT, d)
    for d in os.listdir(MODEL_ROOT)
    if d.startswith("V") and os.path.isdir(os.path.join(MODEL_ROOT, d))
]

# Exclude SAVE_DIR (current output) by absolute path
model_dirs = [d for d in model_dirs if os.path.abspath(d) != os.path.abspath(SAVE_DIR)]

# Sort by version number, descending (highest first)
model_dirs = sorted(model_dirs, key=extract_version, reverse=True)

print("Available model directories (sorted by version):")
for d in model_dirs:
    print(" -", d)
    print("   Files:", os.listdir(d))

if len(model_dirs) < 1:
    raise FileNotFoundError("‚ùå No earlier model folders found.")

model_path = model_dirs[0]
print(f"‚úÖ Auto-loaded model from: {model_path}")

# Load model and processor as before
model = AutoModelForImageClassification.from_pretrained(model_path)
processor = AutoImageProcessor.from_pretrained(model_path)

# Load base model and processor
model = AutoModelForImageClassification.from_pretrained(model_path)
processor = AutoImageProcessor.from_pretrained(model_path)

# Modify classification head with Dropout for regularization
model.classifier = nn.Sequential(
    nn.Dropout(p=0.1),
    nn.Linear(model.classifier.in_features, len(id2label))
)

# Replace classification head to match current label schema
model.config.id2label = id2label
model.config.label2id = label2id
model.config.num_labels = len(LABEL_NAMES)

# Define device and push model to device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("üñ•Ô∏è Using device:", device)
model.to(device).eval()

Some weights of the model checkpoint at /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V15_20250616_154815 were not used when initializing ViTForImageClassification: ['classifier.1.bias', 'classifier.1.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V15_20250616_154815 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model 

Available model directories (sorted by version):
 - /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V15_20250616_154815
   Files: ['model.safetensors', 'V15_distribution_plot_20250616_154815.png', 'review_predictions_clustered', 'confusion_matrix_epoch_V15.png', '.DS_Store', 'V15_calibrated_temperature.pt', 'contempt_clusters', 'label_snapshots', 'logits_eval_V15.npy', 'checkpoint-10245', 'fear_clusters', 'config.json', 'sadness_clusters', 'checkpoint-17075', 'V15_ferckjalf_2025_06_16.ipynb', 'disgust_clusters', 'review_predictions_by_class', 'V15_reliability_diagram_calibrated.png', 'V15_review_candidates.txt', 'logs', 'per_class_metrics.csv', 'labels_eval_V15.npy', 'V15_augmentation_snapshot.csv', 'V15_review_predictions_with_preds.csv', '.ipynb_checkpoints', 'questioning_clusters', 'final_model.pth', 'preprocessor_config.json']
 - /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V14_20250614_190959
   Files: ['model.safetensors', 'V14_review_cand

Some weights of the model checkpoint at /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V15_20250616_154815 were not used when initializing ViTForImageClassification: ['classifier.1.bias', 'classifier.1.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V15_20250616_154815 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model 

üñ•Ô∏è Using device: mps


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [10]:
# --------------------------
# 4. Load and Prepare Dataset
# --------------------------
dataset = load_dataset(
    "imagefolder",
    data_dir="/Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalfag_dataset",
    split="train",
    cache_dir="/tmp/hf_cache"
)

counter = {"n": 0}

def reconcile_labels(example):
    counter["n"] += 1
    if counter["n"] % 1000 == 0:
        print(f"Processed {counter['n']} images...")

    label = example.get("label", None)

    if isinstance(label, int):
        original_label = dataset.features["label"].int2str(label).strip().lower()
    elif isinstance(label, str):
        original_label = label.strip().lower()
    else:
        file_path = getattr(example["image"], "filename", None)
        original_label = os.path.basename(os.path.dirname(file_path)).lower() if file_path else None

    pretrain_label = label_mapping.get(original_label)
    example["label"] = label2id[pretrain_label] if pretrain_label is not None else -1
    return example

# Single-threaded labeling to preserve .filename
dataset = dataset.map(reconcile_labels, desc="Re-labeling dataset")
dataset = dataset.filter(lambda x: x["label"] != -1)

print(f"‚úÖ Total examples after filtering: {len(dataset)}")

Resolving data files:   0%|          | 0/17514 [00:00<?, ?it/s]

‚úÖ Total examples after filtering: 17504


In [11]:
# --------------------------
# 5. Dataset Label Overview and Folder Stats
# --------------------------
def analyze_dataset_structure(dataset, id2label, base_path):
    # Print label schema from the dataset
    print("Label schema (from dataset):", dataset.features["label"])

    # Label distribution from the dataset object
    label_counts = Counter(dataset["label"])
    print("\nüìä Full dataset label distribution (from Dataset object):")
    for label_id, count in sorted(label_counts.items()):
        print(f"  {id2label[label_id]}: {count} examples")

    # Dynamically detect minority classes (lowest 3 frequencies)
    N = 3
    minority_classes = set(
        label for label, _ in sorted(label_counts.items(), key=lambda x: x[1])[:N]
    )
    print(f"\n‚ö†Ô∏è  Dynamically identified minority classes: {[id2label[i] for i in minority_classes]}")

    # Count images per directory, and store for later validation
    folder_image_counts = {}
    print("\nüìÇ Image count per label folder:")
    for label in sorted(os.listdir(base_path)):
        label_path = os.path.join(base_path, label)
        if os.path.isdir(label_path):
            valid_images = [img for img in os.listdir(label_path) if is_valid_image(img)]
            folder_image_counts[label] = len(valid_images)
            print(f"  {label}: {len(valid_images)} images")

    return minority_classes, folder_image_counts

# Example usage right after dataset loading
minority_classes, folder_image_counts = analyze_dataset_structure(dataset, id2label, BASE_PATH)

Label schema (from dataset): ClassLabel(names=['anger', 'contempt', 'disgust', 'fear', 'happiness', 'neutral', 'questioning', 'sadness', 'surprise', 'unknown'], id=None)

üìä Full dataset label distribution (from Dataset object):
  anger: 2302 examples
  disgust: 309 examples
  fear: 1432 examples
  happiness: 2892 examples
  neutral: 3334 examples
  questioning: 1943 examples
  sadness: 1706 examples
  surprise: 2779 examples
  contempt: 421 examples
  unknown: 386 examples

‚ö†Ô∏è  Dynamically identified minority classes: ['contempt', 'disgust', 'unknown']

üìÇ Image count per label folder:
  anger: 2302 images
  contempt: 421 images
  disgust: 309 images
  fear: 1432 images
  happiness: 2892 images
  neutral: 3334 images
  questioning: 1943 images
  sadness: 1706 images
  surprise: 2779 images
  unknown: 386 images


In [12]:
# --------------------------
# 6. Perceptual Clustering for Ambiguous/Confused Classes
# --------------------------

CLUSTER_TARGETS = ["disgust", "sadness", "fear", "questioning", "contempt"]

for class_name in CLUSTER_TARGETS:
    class_dir = os.path.join(BASE_PATH, class_name)
    if not os.path.isdir(class_dir):
        print(f"‚ö†Ô∏è Class dir not found: {class_dir} (skipping)")
        continue

    class_images = [
        os.path.join(class_dir, f) for f in os.listdir(class_dir)
        if is_valid_image(f)
    ]
    hash_map = {}
    for path in class_images:
        h = compute_hash(path)
        if h:
            hash_map.setdefault(h, []).append(path)

    cluster_dir = os.path.join(SAVE_DIR, f"{class_name}_clusters")
    os.makedirs(cluster_dir, exist_ok=True)

    print(f"üîç {class_name.capitalize()} hash clusters with more than 1 image:")
    for h, paths in hash_map.items():
        if len(paths) > 1:
            cluster_path = os.path.join(cluster_dir, h)
            os.makedirs(cluster_path, exist_ok=True)
            for p in paths:
                shutil.copy(p, cluster_path)
            print(f"  - Cluster {h[:8]}: {len(paths)} images copied for review")

üîç Disgust hash clusters with more than 1 image:
üîç Sadness hash clusters with more than 1 image:
  - Cluster 958c52e1: 2 images copied for review
  - Cluster ee9a8d33: 2 images copied for review
  - Cluster d0890396: 2 images copied for review
  - Cluster bb0d06f2: 2 images copied for review
  - Cluster d7f00fa2: 2 images copied for review
üîç Fear hash clusters with more than 1 image:
  - Cluster 9ae56592: 2 images copied for review
  - Cluster 91c8ee81: 2 images copied for review
  - Cluster dae5a596: 2 images copied for review
üîç Questioning hash clusters with more than 1 image:
  - Cluster da014886: 2 images copied for review
  - Cluster 9db42783: 2 images copied for review
üîç Contempt hash clusters with more than 1 image:


In [13]:
# --------------------------
# 7. Class Frequency-Aware Augmentation Targeting
# --------------------------

# Compute label frequencies from train split (post filtering)
label_freqs = Counter(dataset["label"])
label_id2name = {v: k for k, v in label2id.items()}
label_name2id = {v: k for k, v in label_id2name.items()}

# Get lowest-count classes dynamically
minority_by_count = sorted(label_freqs, key=label_freqs.get)[:3]
minority_by_name = [label_id2name[i] for i in minority_by_count]
minority_by_name = [n for n in minority_by_name if n != "unknown"]

# Manually include known confused or underperforming classes
manual_focus_classes = ['disgust', 'questioning', 'contempt']

# Merge and deduplicate
minority_class_names = list(set(minority_by_name + manual_focus_classes))

# Final list as label indices
minority_classes = [label_name2id[name] for name in minority_class_names]

print(f"üéØ Targeted minority augmentation will apply to: {minority_class_names}")

üéØ Targeted minority augmentation will apply to: ['contempt', 'questioning', 'disgust']


In [15]:
# --------------------------
# 8. Define Data Augmentation and Preprocessing Transformation
# --------------------------

# --- Stronger Augmentation for Hard Classes ---
data_augment = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomRotation(10),
    T.ColorJitter(brightness=0.1, contrast=0.1)
])

strong_aug = T.Compose([
    RandAugment(num_ops=3, magnitude=12),
    T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
    T.RandomResizedCrop(224, scale=(0.5, 1.0)),
    T.ColorJitter(0.4, 0.4, 0.4, 0.1),
    T.RandomApply([T.GaussianBlur(5)], p=0.25),
    T.ToTensor(),  # <-- Add this before RandomErasing
    T.RandomApply([T.RandomErasing()], p=0.1),
    T.ToPILImage(),  # <-- Convert back to PIL if your processor needs PIL Images
])

# Augmentation counter tracking
aug_count = Counter()

# After mapping finishes use the new make_transform_function
dataset = dataset.map(make_transform_function(processor, hard_class_ids))
formatted_counts = {LABEL_NAMES[k]: v for k, v in aug_count.items()}
print(f"‚úÖ Augmentation counts: {formatted_counts}")

# Explicitly log dataset snapshots (class distribution) to a 
# CSV or JSON after each run for easy future diffing and tracking
snapshot_path = os.path.join(SAVE_DIR, f"{VERSION}_augmentation_snapshot.csv")
aug_snapshot = pd.DataFrame.from_dict(dict(aug_count), orient='index', columns=['count'])
aug_snapshot.to_csv(snapshot_path)

print(f"‚úÖ Saved augmentation snapshot to {snapshot_path}")

Map:   0%|          | 0/17504 [00:00<?, ? examples/s]

‚úÖ Augmentation counts: {}
‚úÖ Saved augmentation snapshot to /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335/V16_augmentation_snapshot.csv


In [16]:
# --------------------------
# 9. Train-Validation Split
# --------------------------
split_dataset = dataset.train_test_split(test_size=0.2)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]

# üö¶ Check and print label distributions across all important splits
check_all_label_integrity(
    {
        "full dataset (post-aug)": dataset,
        "train set": train_dataset,
        "val set": eval_dataset,
        # "post-balance train": train_dataset_balanced,
    },
    LABEL_NAMES, label2id
)


üö® Label distribution for: full dataset (post-aug)
  anger       : 2302
  disgust     : 309
  fear        : 1432
  happiness   : 2892
  neutral     : 3334
  questioning : 1943
  sadness     : 1706
  surprise    : 2779
  contempt    : 421
  unknown     : 386

üö® Label distribution for: train set
  anger       : 1888
  disgust     : 241
  fear        : 1151
  happiness   : 2314
  neutral     : 2661
  questioning : 1556
  sadness     : 1337
  surprise    : 2237
  contempt    : 327
  unknown     : 291

üö® Label distribution for: val set
  anger       : 414
  disgust     : 68
  fear        : 281
  happiness   : 578
  neutral     : 673
  questioning : 387
  sadness     : 369
  surprise    : 542
  contempt    : 94
  unknown     : 95


In [17]:
# --------------------------
# 10. Label Distribution Snapshot and Drift Monitor
# --------------------------
snapshot_dir = os.path.join(SAVE_DIR, "label_snapshots")
os.makedirs(snapshot_dir, exist_ok=True)

# Count current training labels
train_label_names = [LABEL_NAMES[i] for i in train_dataset['label']]
label_counts = pd.Series(train_label_names).value_counts().sort_index()
label_counts.name = VERSION

# Save snapshot CSV
snapshot_path = os.path.join(snapshot_dir, f"{VERSION}_label_distribution.csv")
label_counts.to_csv(snapshot_path)
print(f"üìä Saved label distribution snapshot: {snapshot_path}")

# Optionally compare to previous version
previous_versions = sorted([
    f for f in os.listdir(snapshot_dir) if f.endswith(".csv") and not f.startswith(VERSION)
])
if previous_versions:
    latest_prev = previous_versions[-1]
    prev_df = pd.read_csv(os.path.join(snapshot_dir, latest_prev), index_col=0)
    diff = label_counts.subtract(prev_df.iloc[:, 0], fill_value=0)
    print("üîç Label count change since last snapshot:")
    print(diff)

# üö¶ Check and print label distributions across all important splits
check_all_label_integrity(
    {
        "full dataset (post-aug)": dataset,
        "train set": train_dataset,
        "val set": eval_dataset,
        # "post-balance train": train_dataset_balanced,
    },
    LABEL_NAMES, label2id
)

üìä Saved label distribution snapshot: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335/label_snapshots/V16_label_distribution.csv

üö® Label distribution for: full dataset (post-aug)
  anger       : 2302
  disgust     : 309
  fear        : 1432
  happiness   : 2892
  neutral     : 3334
  questioning : 1943
  sadness     : 1706
  surprise    : 2779
  contempt    : 421
  unknown     : 386

üö® Label distribution for: train set
  anger       : 1888
  disgust     : 241
  fear        : 1151
  happiness   : 2314
  neutral     : 2661
  questioning : 1556
  sadness     : 1337
  surprise    : 2237
  contempt    : 327
  unknown     : 291

üö® Label distribution for: val set
  anger       : 414
  disgust     : 68
  fear        : 281
  happiness   : 578
  neutral     : 673
  questioning : 387
  sadness     : 369
  surprise    : 542
  contempt    : 94
  unknown     : 95


In [18]:
# --------------------------
# 11. Balance Dataset (with NO oversampling for 'unknown')
# --------------------------
MINORITY_CAP = 1500
balanced_subsets = []
label_counts = Counter(dataset["label"])

for label, count in label_counts.items():
    subset = dataset.filter(lambda x: x['label'] == label, num_proc=1)
    class_name = LABEL_NAMES[label]
    if class_name == "unknown":
        balanced_subsets.append(subset)
    elif count < MINORITY_CAP:
        multiplier = MINORITY_CAP // len(subset)
        remainder = MINORITY_CAP % len(subset)
        subset = concatenate_datasets([subset] * multiplier + [subset.select(range(remainder))])
        balanced_subsets.append(subset)
    else:
        # Append full set (no downsampling for majority classes)
        balanced_subsets.append(subset)

train_dataset_balanced = concatenate_datasets(balanced_subsets).shuffle(seed=42)

train_dataset_balanced = concatenate_datasets(balanced_subsets).shuffle(seed=42)
print("After balancing:", Counter(train_dataset_balanced["label"]))

# Calculate weights: Give hard classes 2x, others 1x
weights = [2.0 if l in hard_class_ids else 1.0 for l in train_dataset_balanced["label"]]
weights = torch.DoubleTensor(weights)
sampler = torch.utils.data.WeightedRandomSampler(
    weights=weights, num_samples=len(weights), replacement=True
)

# üö¶ Check and print label distributions across all important splits
check_all_label_integrity(
    {
        "full dataset (post-aug)": dataset,
        "train set": train_dataset,
        "val set": eval_dataset,
        # "post-balance train": train_dataset_balanced,
    },
    LABEL_NAMES, label2id
)

Filter:   0%|          | 0/17504 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17504 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17504 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17504 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17504 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17504 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17504 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17504 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17504 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17504 [00:00<?, ? examples/s]

After balancing: Counter({4: 3334, 3: 2892, 7: 2779, 0: 2302, 5: 1943, 6: 1706, 1: 1500, 8: 1500, 2: 1500, 9: 386})

üö® Label distribution for: full dataset (post-aug)
  anger       : 2302
  disgust     : 309
  fear        : 1432
  happiness   : 2892
  neutral     : 3334
  questioning : 1943
  sadness     : 1706
  surprise    : 2779
  contempt    : 421
  unknown     : 386

üö® Label distribution for: train set
  anger       : 1888
  disgust     : 241
  fear        : 1151
  happiness   : 2314
  neutral     : 2661
  questioning : 1556
  sadness     : 1337
  surprise    : 2237
  contempt    : 327
  unknown     : 291

üö® Label distribution for: val set
  anger       : 414
  disgust     : 68
  fear        : 281
  happiness   : 578
  neutral     : 673
  questioning : 387
  sadness     : 369
  surprise    : 542
  contempt    : 94
  unknown     : 95


In [19]:
# --------------------------
# 12. Define Training Arguments for Robust Fine-Tuning
# --------------------------
training_args = TrainingArguments(
    output_dir=SAVE_DIR,                   # Directory to save checkpoints and the final model
    eval_strategy="epoch",                 # Evaluate at the end of each epoch
    save_strategy="epoch",                 # Save checkpoint at each epoch
    save_total_limit=2,                    # ‚úÖ (optional) Keep only last 2 checkpoints to save space
    learning_rate=4e-5,                    # A conservative learning rate for fine-tuning
    per_device_train_batch_size=8,         # Adjust based on your CPU memory limits
    per_device_eval_batch_size=8,
    num_train_epochs=5,                    # Fine-tune for a few epochs (adjust as needed)
    load_best_model_at_end=True,           # Automatically load the best model when training finishes
    metric_for_best_model="accuracy",      # Monitor accuracy for best model selection
    logging_dir=os.path.join(SAVE_DIR, "logs"),  # ‚úÖ Save logs inside versioned folder
    logging_strategy="epoch",                 # ‚úÖ Log once per epoch
    save_safetensors=True                  # ‚úÖ Optional: saves model weights in `.safetensors` (safe format)
)

In [20]:
# --------------------------
# 13. Define Compute Metrics
# --------------------------
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = (predictions == labels).mean()
    return {"accuracy": accuracy}

In [21]:
# --------------------------
# 14. Trainer with Class-Weighted Loss
# --------------------------

# Define custom Trainer to inject class weights
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Use smoothed CE + confidence penalty
        smooth_ce_loss = SmoothedCrossEntropyLoss(smoothing=0.05)
        loss = smooth_ce_loss(logits, labels) + confidence_penalty(logits, beta=0.05)
        return (loss, outputs) if return_outputs else loss

# Modify training args for learning rate scheduling and early stopping
training_args.load_best_model_at_end = True
training_args.metric_for_best_model = "eval_loss"
training_args.evaluation_strategy = "epoch"
training_args.save_strategy = "epoch"

# Add EarlyStoppingCallback
early_stop_callback = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.001
)

# Initialize WeightedTrainer with focal loss, confidence penalty, and label smoothing
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics_with_confusion,
    optimizers=(
        AdamW(model.parameters(), lr=training_args.learning_rate, weight_decay=0.01),
        None
    ),
    callbacks=[early_stop_callback]
)

# T_0 = epochs before first restart, T_mult = restart multiplier
scheduler = CosineAnnealingWarmRestarts(trainer.optimizer, T_0=2, T_mult=2)

# Add scheduler step logic inside the training loop:
original_train = trainer.train

def modified_train(*args, **kwargs):
    result = original_train(*args, **kwargs)
    scheduler.step(trainer.state.epoch)  # instead of eval_loss
    return result

# Fine-tune model
trainer.train()

# Save model
save_model_and_processor(model, processor, SAVE_DIR)



Epoch,Training Loss,Validation Loss,Accuracy
1,0.5595,0.532953,0.927164
2,0.4049,0.509697,0.943445
3,0.352,0.521151,0.943445
4,0.3308,0.518269,0.946015



Classification Report:
              precision    recall  f1-score   support

       anger       0.96      0.98      0.97       414
     disgust       0.69      0.78      0.73        68
        fear       0.93      0.65      0.77       281
   happiness       0.98      1.00      0.99       578
     neutral       0.99      0.96      0.98       673
 questioning       0.73      0.93      0.82       387
     sadness       0.99      0.93      0.96       369
    surprise       0.99      0.99      0.99       542
    contempt       0.52      0.47      0.49        94
     unknown       0.99      1.00      0.99        95

    accuracy                           0.93      3501
   macro avg       0.88      0.87      0.87      3501
weighted avg       0.93      0.93      0.93      3501


Top 3 confused class pairs:
  - fear ‚Üí questioning: 76 instances
  - contempt ‚Üí questioning: 44 instances
  - questioning ‚Üí fear: 11 instances

üß† Avg prediction entropy: 0.3843

üîç Class entropies (sorted)




Classification Report:
              precision    recall  f1-score   support

       anger       0.96      0.99      0.98       414
     disgust       0.86      0.54      0.67        68
        fear       0.83      0.86      0.84       281
   happiness       0.99      0.99      0.99       578
     neutral       0.99      0.97      0.98       673
 questioning       0.83      0.89      0.86       387
     sadness       0.96      0.98      0.97       369
    surprise       1.00      0.99      0.99       542
    contempt       0.61      0.50      0.55        94
     unknown       0.99      1.00      0.99        95

    accuracy                           0.94      3501
   macro avg       0.90      0.87      0.88      3501
weighted avg       0.94      0.94      0.94      3501


Top 3 confused class pairs:
  - contempt ‚Üí questioning: 34 instances
  - questioning ‚Üí fear: 30 instances
  - fear ‚Üí questioning: 27 instances

üß† Avg prediction entropy: 0.3051

üîç Class entropies (sorted)




Classification Report:
              precision    recall  f1-score   support

       anger       0.99      0.98      0.98       414
     disgust       0.78      0.72      0.75        68
        fear       0.84      0.81      0.83       281
   happiness       0.98      1.00      0.99       578
     neutral       0.99      0.99      0.99       673
 questioning       0.80      0.90      0.85       387
     sadness       0.98      0.98      0.98       369
    surprise       0.99      0.99      0.99       542
    contempt       0.70      0.41      0.52        94
     unknown       0.98      1.00      0.99        95

    accuracy                           0.94      3501
   macro avg       0.90      0.88      0.89      3501
weighted avg       0.94      0.94      0.94      3501


Top 3 confused class pairs:
  - contempt ‚Üí questioning: 41 instances
  - fear ‚Üí questioning: 38 instances
  - questioning ‚Üí fear: 25 instances

üß† Avg prediction entropy: 0.2922

üîç Class entropies (sorted)




Classification Report:
              precision    recall  f1-score   support

       anger       0.99      0.99      0.99       414
     disgust       0.82      0.68      0.74        68
        fear       0.78      0.86      0.82       281
   happiness       0.99      1.00      1.00       578
     neutral       0.99      0.99      0.99       673
 questioning       0.83      0.85      0.84       387
     sadness       0.98      0.99      0.98       369
    surprise       1.00      1.00      1.00       542
    contempt       0.68      0.48      0.56        94
     unknown       0.99      1.00      0.99        95

    accuracy                           0.95      3501
   macro avg       0.91      0.88      0.89      3501
weighted avg       0.95      0.95      0.94      3501


Top 3 confused class pairs:
  - questioning ‚Üí fear: 43 instances
  - contempt ‚Üí questioning: 33 instances
  - fear ‚Üí questioning: 27 instances

üß† Avg prediction entropy: 0.2833

üîç Class entropies (sorted)

In [None]:
# # --------------------------
# # 15. Rescue & Save from Last Checkpoint (after training)
# # --------------------------
# #in case model save fails, resume from latest checkpoint
# processor.save_pretrained(SAVE_DIR)
# print("‚úÖ Processor manually re-saved.")

# # Use parent directory of SAVE_DIR to locate latest V* folder
# parent_dir = os.path.dirname(SAVE_DIR)
# v_folders = [
#     d for d in os.listdir(parent_dir)
#     if os.path.isdir(os.path.join(parent_dir, d)) and d.startswith("V")
# ]

# def extract_timestamp(name):
#     try:
#         _, date_str, time_str = name.split("_")
#         return datetime.strptime(f"{date_str}_{time_str}", "%Y%m%d_%H%M%S")
#     except Exception:
#         return datetime.min

# latest_version_folder = max(v_folders, key=extract_timestamp)
# latest_version_path = os.path.join(parent_dir, latest_version_folder)
# print(f"üóÇÔ∏è Using latest version folder: {latest_version_path}")

# # Locate latest checkpoint within that version folder
# checkpoint_dirs = [
#     os.path.join(latest_version_path, d)
#     for d in os.listdir(latest_version_path)
#     if d.startswith("checkpoint-") and os.path.isdir(os.path.join(latest_version_path, d))
# ]
# if not checkpoint_dirs:
#     raise ValueError("‚ùå No checkpoint found in latest version folder.")

# latest_checkpoint = max(checkpoint_dirs, key=os.path.getmtime)
# print(f"‚úÖ Found latest checkpoint: {latest_checkpoint}")

# # Load model and processor from latest checkpoint and save them
# model = AutoModelForImageClassification.from_pretrained(latest_checkpoint)
# processor = AutoImageProcessor.from_pretrained(latest_version_path)
# model = model.to("cpu")

In [22]:
# --------------------------
# 16. Inference Utilities
# --------------------------

# Reload Model for Inference
model = AutoModelForImageClassification.from_pretrained(SAVE_DIR).to(device).eval()
print("‚úÖ Model reloaded for inference.")

# Single image prediction (unbatched)
def predict_label(image_path, threshold=0.85):
    image = Image.open(image_path).convert("RGB")
    inputs = processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = F.softmax(logits, dim=-1)
        conf, pred_idx = torch.max(probs, dim=-1)
    return (id2label[pred_idx.item()], conf.item()) if conf.item() >= threshold else ("REVIEW", conf.item())

# Batched prediction
def batch_predict(image_folder, batch_size=64, threshold=0.85):
    all_preds = []
    error_count = 0
    image_paths = [
        p for p in Path(image_folder).rglob("*")
        if is_valid_image(p.name)
    ]

    for i in tqdm(range(0, len(image_paths), batch_size), desc="Running inference in batches"):
        batch_paths = image_paths[i:i + batch_size]
        images, valid_paths = [], []

        for path in batch_paths:
            try:
                img = Image.open(path).convert("RGB")
                images.append(img)
                valid_paths.append(str(path))
            except Exception:
                error_count += 1
                continue

        if not images:
            continue

        inputs = processor(images=images, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = F.softmax(logits, dim=-1)
            confs, preds = torch.max(probs, dim=-1)

        for pred, conf, path in zip(preds.tolist(), confs.tolist(), valid_paths):
            all_preds.append(LABEL_NAMES[pred] if conf >= threshold else "REVIEW")

    print(f"‚úÖ Inference complete. Skipped {error_count} invalid image(s).")
    return all_preds

# Distribution plot
def plot_distribution(predictions, output_path):
    label_counts = Counter(predictions)
    labels = sorted(label_counts.keys())
    counts = [label_counts[label] for label in labels]

    plt.figure(figsize=(10, 5))
    plt.bar(labels, counts)
    plt.title("Predicted Expression Distribution")
    plt.xlabel("Expression")
    plt.ylabel("Count")
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

Some weights of the model checkpoint at /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335 were not used when initializing ViTForImageClassification: ['classifier.1.bias', 'classifier.1.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model 

‚úÖ Model reloaded for inference.


In [23]:
# --------------------------
# 17. Entry Point for Inference
# --------------------------
if __name__ == "__main__" and RUN_INFERENCE:

    # Auto-locate latest model directory
    OUTPUT_PATH = os.path.join(SAVE_DIR, f"{VERSION}_distribution_plot_{timestamp}.png")

    predictions = batch_predict(IMAGE_DIR)
    reviewed_paths = []
    image_paths = [str(p) for p in Path(IMAGE_DIR).rglob("*") if is_valid_image(p.name)]

    for path, label in zip(image_paths, predictions):
        if label == "REVIEW":
            reviewed_paths.append(path)

    # Save paths to inspect manually
    with open(os.path.join(SAVE_DIR, f"{VERSION}_review_candidates.txt"), "w") as f:
        f.write("\n".join(reviewed_paths))
    print(f"üìù Saved REVIEW file paths to {VERSION}_review_candidates.txt")

    plot_distribution(predictions, OUTPUT_PATH)
    print(f"Distribution plot saved to: {OUTPUT_PATH}")

Running inference in batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 274/274 [05:53<00:00,  1.29s/it]

‚úÖ Inference complete. Skipped 0 invalid image(s).
üìù Saved REVIEW file paths to V16_review_candidates.txt
Distribution plot saved to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335/V16_distribution_plot_20250618_144335.png





In [24]:
# --------------------------
# 18. Temperature Scaling Calibration 
# --------------------------

# Wrapper model for calibrated inference
class ModelWithTemperature(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, input_ids=None, pixel_values=None, **kwargs):
        logits = self.model(pixel_values=pixel_values).logits
        return logits / self.temperature

    def set_temperature(self, logits, labels):
        nll_criterion = nn.CrossEntropyLoss()
        optimizer = LBFGS([self.temperature], lr=0.01, max_iter=50)

        def eval_fn():
            optimizer.zero_grad()
            loss = nll_criterion(logits / self.temperature, labels)
            loss.backward()
            return loss

        optimizer.step(eval_fn)
        print(f"Optimal temperature (wrapped): {self.temperature.item():.4f}")
        return self

# Dynamically locate the most recent V* folder that contains logits/labels
base_dir = os.path.dirname(SAVE_DIR)
v_folders = sorted([
    d for d in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("V")
], key=lambda d: os.path.getmtime(os.path.join(base_dir, d)), reverse=True)

logits_path, labels_path = None, None
for v in v_folders:
    version_tag = v.split('_')[0]
    folder_path = os.path.join(base_dir, v)
    logits_candidate = os.path.join(folder_path, f"logits_eval_{version_tag}.npy")
    labels_candidate = os.path.join(folder_path, f"labels_eval_{version_tag}.npy")
    if os.path.exists(logits_candidate) and os.path.exists(labels_candidate):
        INFER_SAVE_DIR = folder_path
        INFER_VERSION = version_tag
        print(f"üìÅ Using calibration files from: {SAVE_DIR}")
        logits_path = logits_candidate
        labels_path = labels_candidate
        break

# --------------------------
# Run calibration
# --------------------------
if logits_path and labels_path:
    result = apply_temperature_scaling(logits_path, labels_path)
    if result is not None:
        temperature, logits, labels = result
        plot_reliability_diagram(logits, labels, temperature)
else:
    print(f"‚ö†Ô∏è Skipping temperature scaling and diagram (missing logits or labels in {SAVE_DIR})")

üìÅ Using calibration files from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335
üìÇ Loading logits from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335/logits_eval_V16.npy
üìÇ Loading labels from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335/labels_eval_V16.npy
‚úÖ Optimal temperature: 1.2230
‚úÖ Calibrated Log Loss: 0.2804
üìä Saved reliability diagram to /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335/V16_reliability_diagram_calibrated.png


In [25]:
# --------------------------
# 19. Review & Relabel 'REVIEW' Predictions (with Audit Logging & Clustering)
# --------------------------

REVIEW_THRESHOLD = 0.85
REVIEW_BY_CLASS_DIR = os.path.join(SAVE_DIR, "review_predictions_by_class")
REVIEW_CSV_LOG = os.path.join(SAVE_DIR, f"{VERSION}_review_predictions_with_preds.csv")
REVIEW_CLUSTER_DIR = os.path.join(SAVE_DIR, "review_predictions_clustered")

os.makedirs(REVIEW_BY_CLASS_DIR, exist_ok=True)
os.makedirs(REVIEW_CLUSTER_DIR, exist_ok=True)

# ---- If you HAVE NOT already generated review CSV (inference stage) ----
if not os.path.exists(REVIEW_CSV_LOG):
    review_log = []
    image_paths = [
        p for p in Path(IMAGE_DIR).rglob("*")
        if p.is_file() and p.suffix.lower() in [".jpg", ".jpeg", ".png"]
    ]
    for img_path in image_paths:
        try:
            image = Image.open(img_path).convert("RGB")
            inputs = processor(image, return_tensors="pt").to(device)
            with torch.no_grad():
                logits = model(**inputs).logits
                probs = F.softmax(logits, dim=-1)
                conf, pred_idx = torch.max(probs, dim=-1)
            conf_val = conf.item()
            pred_label = id2label[pred_idx.item()]
            tag = "REVIEW" if conf_val < REVIEW_THRESHOLD else pred_label
            review_log.append({
                "image_path": str(img_path),
                "predicted_label": pred_label,
                "confidence": round(conf_val, 4),
                "tag": tag
            })
            if tag == "REVIEW":
                target_dir = os.path.join(REVIEW_BY_CLASS_DIR, pred_label)
                os.makedirs(target_dir, exist_ok=True)
                shutil.copy(str(img_path), target_dir)
        except Exception as e:
            print(f"‚ö†Ô∏è Error with image: {img_path} | {e}")
    pd.DataFrame(review_log).to_csv(REVIEW_CSV_LOG, index=False)
    print(f"‚úÖ Completed tagging + copying REVIEW predictions to: {REVIEW_BY_CLASS_DIR}")
    print(f"üìÑ CSV log saved to: {REVIEW_CSV_LOG}")

# ---- If you already HAVE a review CSV (assignment/audit stage) ----
df = pd.read_csv(REVIEW_CSV_LOG)
review_assignment_log = []
for _, row in df.iterrows():
    path = row["image_path"]
    pred_label = row["predicted_label"]
    conf = float(row["confidence"])
    true_label = os.path.basename(os.path.dirname(path))
    assigned = "unknown" if conf < REVIEW_THRESHOLD else pred_label
    dest_dir = os.path.join(REVIEW_BY_CLASS_DIR, assigned)
    os.makedirs(dest_dir, exist_ok=True)
    shutil.copy(path, dest_dir)
    review_assignment_log.append([path, true_label, pred_label, conf, assigned])

log_df = pd.DataFrame(
    review_assignment_log,
    columns=["image_path", "true_label", "pred_label", "confidence", "assigned_folder"]
)
log_df.to_csv(os.path.join(SAVE_DIR, "review_assignment_audit.csv"), index=False)
print("‚úÖ Review assignments (with audit) complete.")

# ---- Perceptual hash clustering of review pool ----
def phash_distance(hash1, hash2):
    return hash1 - hash2

PHASH_CLUSTER_THRESHOLD = 6
image_paths = [row[0] for row in review_assignment_log if row[4] != "unknown"]  # assigned to a class

hashes = []
for img_path in image_paths:
    try:
        img = Image.open(img_path).convert("L").resize((64, 64))
        hashes.append(hex_to_hash(str(phash(img))))
    except Exception as e:
        print(f"phash error: {img_path} | {e}")

clusters = []
used = set()
for i, h1 in enumerate(hashes):
    if i in used:
        continue
    cluster = [image_paths[i]]
    used.add(i)
    for j, h2 in enumerate(hashes):
        if j <= i or j in used:
            continue
        if phash_distance(h1, h2) <= PHASH_CLUSTER_THRESHOLD:
            cluster.append(image_paths[j])
            used.add(j)
    if len(cluster) > 1:
        clusters.append(cluster)

if not clusters:
    print(f"‚ö†Ô∏è No clusters found for review. {REVIEW_CLUSTER_DIR} will remain empty.")
else:
    for idx, cluster in enumerate(clusters):
        out_dir = os.path.join(REVIEW_CLUSTER_DIR, f"cluster_{idx}")
        os.makedirs(out_dir, exist_ok=True)
        for p in cluster:
            shutil.copy(p, out_dir)
    print(f"‚úÖ Saved {len(clusters)} clusters to {REVIEW_CLUSTER_DIR}")

‚úÖ Completed tagging + copying REVIEW predictions to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335/review_predictions_by_class
üìÑ CSV log saved to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335/V16_review_predictions_with_preds.csv
‚úÖ Review assignments (with audit) complete.
‚ö†Ô∏è No clusters found for review. /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335/review_predictions_clustered will remain empty.


In [28]:
# --------------------------
# 20. REVIEW Pool Diagnostics & Hard Confusion Mining
# --------------------------

# A. Flag hard confusion pairs for manual review
REVIEW_CONFUSION_PAIRS = [("contempt", "questioning"), ("fear", "surprise")]

def parse_review_confusions(csv_path, confusion_pairs):
    import csv
    flagged_imgs = {pair: [] for pair in confusion_pairs}
    with open(csv_path) as f:
        reader = csv.DictReader(f)
        for row in reader:
            pred = row["predicted_label"]
            true = os.path.basename(os.path.dirname(row["image_path"]))
            conf = float(row["confidence"])
            for a, b in confusion_pairs:
                if ((pred == a and true == b) or (pred == b and true == a)) and conf < 0.8:
                    flagged_imgs[(a, b)].append(row["image_path"])
    return flagged_imgs

confusion_candidates = parse_review_confusions(REVIEW_CSV_LOG, REVIEW_CONFUSION_PAIRS)
for pair, imgs in confusion_candidates.items():
    print(f"\nFlagged {len(imgs)} hard negatives for {pair}:")
    out_path = os.path.join(SAVE_DIR, f"review_hardneg_{pair[0]}_{pair[1]}.txt")
    with open(out_path, "w") as f:
        f.write("\n".join(imgs))
    print(f"  Saved list: {out_path}")

# B. Organize REVIEW-tagged images by predicted class (for curation)
REVIEW_SORT_DIR = os.path.join(SAVE_DIR, "review_predictions_by_class")
os.makedirs(REVIEW_SORT_DIR, exist_ok=True)
review_txt_path = os.path.join(SAVE_DIR, f"{VERSION}_review_candidates.txt")
csv_path = os.path.join(SAVE_DIR, f"{VERSION}_review_predictions_with_preds.csv")

if os.path.exists(review_txt_path) and os.path.exists(csv_path):
    with open(review_txt_path, "r") as f:
        review_paths = {line.strip() for line in f.readlines()}

    df = pd.read_csv(csv_path)
    count = 0

    print(f"üîç Found {len(df)} total predictions (CSV) and {len(review_paths)} REVIEW-tagged paths.")

    for _, row in df.iterrows():
        path = row["image_path"]
        label = row["predicted_label"]
        conf = row["confidence"]

        if path in review_paths and label != "REVIEW":
            dest_dir = os.path.join(REVIEW_SORT_DIR, label)
            os.makedirs(dest_dir, exist_ok=True)
            shutil.copy(path, dest_dir)
            count += 1

    print(f"üìÇ Grouped {count} REVIEW images into folders by predicted label in: {REVIEW_SORT_DIR}")
else:
    print("‚ö†Ô∏è Missing review candidates file or prediction CSV.")


Flagged 181 hard negatives for ('contempt', 'questioning'):
  Saved list: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335/review_hardneg_contempt_questioning.txt

Flagged 40 hard negatives for ('fear', 'surprise'):
  Saved list: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335/review_hardneg_fear_surprise.txt
üîç Found 17364 total predictions (CSV) and 17505 REVIEW-tagged paths.
üìÇ Grouped 17364 REVIEW images into folders by predicted label in: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V16_20250618_144335/review_predictions_by_class


In [29]:
# --------------------------
# 21. Visualization & Error Tracking
# --------------------------

print("Label name/id mapping:")
for idx, name in enumerate(LABEL_NAMES):
    print(f"{idx}: {name}")

# Defensive: Check that metrics file exists before plotting
per_class_csv = os.path.join(SAVE_DIR, "per_class_metrics.csv")
if not os.path.exists(per_class_csv):
    print(f"‚ö†Ô∏è Metrics file {per_class_csv} not found.")
else:
    metrics_df = pd.read_csv(per_class_csv)
    last_row = metrics_df.iloc[-1]

    fig, ax = plt.subplots(figsize=(10,6))
    f1s = [last_row[f"f1_{n}"] for n in LABEL_NAMES]
    ax.bar(LABEL_NAMES, f1s)
    ax.set_title("Per-Class F1 (Last Epoch)")
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, "per_class_f1.png"))
    plt.close()

    # Bar plot of per-class entropy
    entropies = [last_row[f"entropy_{n}"] for n in LABEL_NAMES]
    fig, ax = plt.subplots(figsize=(10,6))
    ax.bar(LABEL_NAMES, entropies)
    ax.set_title("Per-Class Mean Entropy (Last Epoch)")
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, "per_class_entropy.png"))
    plt.close()

    # Histogram for REVIEW pool
    review_counts = Counter()
    if os.path.exists(REVIEW_SORT_DIR):
        for label_dir in os.listdir(REVIEW_SORT_DIR):
            count = len(os.listdir(os.path.join(REVIEW_SORT_DIR, label_dir)))
            review_counts[label_dir] = count
        plt.bar(review_counts.keys(), review_counts.values())
        plt.title("REVIEW Pool Distribution by Predicted Class")
        plt.savefig(os.path.join(SAVE_DIR, "review_pool_distribution.png"))
        plt.close()
        # Flag if >70% in one class
        total = sum(review_counts.values())
        for label, count in review_counts.items():
            if total > 0 and count / total > 0.7:
                print(f"‚ö†Ô∏è REVIEW pool highly imbalanced: {count/total:.1%} in '{label}'")

    # Audit print block (as before)
    print("Sample review predictions (audit):")
    if 'log_df' in locals():
        print(log_df[["image_path", "true_label", "pred_label", "confidence", "assigned_folder"]].head())
    elif 'df' in locals():
        print(df[["image_path", "true_label", "predicted_label", "confidence"]].head())
    else:
        print("No review/audit DataFrame found for printing.")

# ‚úÖ AUDIT BLOCK
print("Sample review predictions (audit):")
if 'log_df' in locals():
    print(log_df[["image_path", "true_label", "pred_label", "confidence", "assigned_folder"]].head())
elif 'df' in locals():
    print(df[["image_path", "true_label", "predicted_label", "confidence"]].head())
else:
    print("No review/audit DataFrame found for printing.")

Label name/id mapping:
0: anger
1: disgust
2: fear
3: happiness
4: neutral
5: questioning
6: sadness
7: surprise
8: contempt
9: unknown
Sample review predictions (audit):
                                          image_path true_label pred_label  \
0  /Users/natalyagrokh/AI/ml_expressions/img_data...   contempt       fear   
1  /Users/natalyagrokh/AI/ml_expressions/img_data...   contempt   surprise   
2  /Users/natalyagrokh/AI/ml_expressions/img_data...   contempt    unknown   
3  /Users/natalyagrokh/AI/ml_expressions/img_data...   contempt   surprise   
4  /Users/natalyagrokh/AI/ml_expressions/img_data...   contempt   surprise   

   confidence assigned_folder  
0      0.1819         unknown  
1      0.1790         unknown  
2      0.1499         unknown  
3      0.1673         unknown  
4      0.1715         unknown  
Sample review predictions (audit):
                                          image_path true_label pred_label  \
0  /Users/natalyagrokh/AI/ml_expressions/img_data...   

In [30]:
# --------------------------
# 22. Deployment Readiness Assertions and Flags
# --------------------------

# Load metrics
metrics_df = pd.read_csv(os.path.join(SAVE_DIR, "per_class_metrics.csv"))
last = metrics_df.iloc[-1]
warn = False

for cname in LABEL_NAMES:
    f1 = last[f"f1_{cname}"]
    entropy = last[f"entropy_{cname}"]
    if f1 < 0.8:
        print(f"üö® F1 < 0.8 for class '{cname}': {f1:.2f}")
        warn = True
    if entropy > 0.4:
        print(f"üö® Entropy > 0.4 for class '{cname}': {entropy:.2f}")
        warn = True

if not warn:
    print("‚úÖ All classes ready for deployment: F1 >= 0.8 and entropy <= 0.4")
else:
    print("‚ö†Ô∏è Some classes not deployment-ready! Address above issues before production.")


üö® F1 < 0.8 for class 'disgust': 0.74
üö® Entropy > 0.4 for class 'disgust': 0.45
üö® F1 < 0.8 for class 'contempt': 0.56
üö® Entropy > 0.4 for class 'contempt': 0.46
‚ö†Ô∏è Some classes not deployment-ready! Address above issues before production.
