In [None]:
#V10 changes:
    # renamed variables for clarity
    # section #4 
        # perceptual clustering + relabeling for disgust
        # AU-based filtering
    # section #6 added targeted augmenation 
        # manual questioning and disgust to dynamic augmentation targeting
    # section #9 label distribution audit
    # section #12 per-class confidence entropy tracking

In [1]:
# --------------------------
# 0. Imports
# --------------------------
import accelerate
import dill
import gc
import glob
import torch
import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
import os
import pandas as pd
import random
import seaborn as sns
import shutil
import subprocess
import sys
import tensorflow as tf
import time
import torchvision.transforms as T
import transformers

from collections import Counter
from datasets import ClassLabel, Dataset, Features, Image as DatasetsImage, concatenate_datasets, load_dataset
from datetime import datetime
from functools import partial
from imagehash import phash
from io import BytesIO
from pathlib import Path
from PIL import Image, ImageOps, ExifTags, UnidentifiedImageError
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, classification_report, log_loss
from torch import nn
from torch.nn import functional as F
from torch.optim import LBFGS
from torchvision.transforms import ToPILImage
from tqdm import tqdm
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    EarlyStoppingCallback,
    TrainingArguments,
    Trainer
)

In [2]:
# --------------------------
# 1. Global Configurations
# --------------------------
RUN_INFERENCE = True  # Toggle this off to disable running inference
IMAGE_DIR = "/Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalf_dataset"
BASE_PATH = IMAGE_DIR

LABEL_NAMES = [
    "anger", "disgust", "fear", "happiness",
    "sadness", "surprise", "neutral", "questioning"
]
id2label = dict(enumerate(LABEL_NAMES))
label2id = {v: k for k, v in id2label.items()}

VALID_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff")

def is_valid_image(filename):
    return filename.lower().endswith(VALID_EXTENSIONS) and not filename.startswith("._")

label_mapping = {name.lower(): name for name in LABEL_NAMES}

# üî¢ Dynamically determine the next version
def get_next_version(base_dir):

    # Use glob to find all entries matching the pattern
    all_entries = glob.glob(os.path.join(base_dir, "V*_*"))
    
    # Filter to include only directories
    existing = [
        os.path.basename(d) for d in all_entries if os.path.isdir(d)
    ]

    # Extract version numbers from the directory names
    versions = [
        int(d[1:].split("_")[0]) for d in existing
        if d.startswith("V") and "_" in d and d[1:].split("_")[0].isdigit()
    ]
    
    # Determine the next version number
    next_version = max(versions, default=0) + 1
    return f"V{next_version}"

# Automatically create a versioned output folder
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
VERSION = get_next_version("/Users/natalyagrokh/AI/ml_expressions/img_expressions")
VERSION_TAG = VERSION + "_" + timestamp
SAVE_DIR = os.path.join("/Users/natalyagrokh/AI/ml_expressions/img_expressions", VERSION_TAG)
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"üìÅ Output directory created: {SAVE_DIR}")

üìÅ Output directory created: /Users/natalyagrokh/AI/ml_expressions/img_expressions/V10_20250515_150651


In [3]:
# --------------------------
# 2. Auto-Load Latest Pretrained Model and Processor
# --------------------------

# Automatically load latest model path
MODEL_ROOT = "/Users/natalyagrokh/AI/ml_expressions/img_expressions"
# List all version folders in descending order
model_dirs = sorted(
    [os.path.join(MODEL_ROOT, d) for d in os.listdir(MODEL_ROOT)
     if d.startswith("V") and os.path.isdir(os.path.join(MODEL_ROOT, d))],
    key=lambda x: os.path.getmtime(x),
    reverse=True
)

# Remove the current output version (to avoid loading from empty target)
model_dirs = [d for d in model_dirs if VERSION in d or not d.startswith(VERSION)]
model_dirs = [d for d in model_dirs if os.path.basename(d).startswith("V") and d != SAVE_DIR]

# Pick the most recent complete model (not current output)
if len(model_dirs) < 1:
    raise FileNotFoundError("‚ùå No earlier model folders found.")
model_path = model_dirs[0]
print(f"‚úÖ Auto-loaded model from: {model_path}")

# Load base model and processor
model = AutoModelForImageClassification.from_pretrained(model_path)
processor = AutoImageProcessor.from_pretrained(model_path)

# Replace classification head to match current label schema
model.classifier = torch.nn.Linear(model.classifier.in_features, len(id2label))
model.config.id2label = id2label
model.config.label2id = label2id
model.config.num_labels = len(LABEL_NAMES)

# Define device and push model to device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("üñ•Ô∏è Using device:", device)
model.to(device).eval()

‚úÖ Auto-loaded model from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/V9_20250513_152634
üñ•Ô∏è Using device: mps


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [4]:
# --------------------------
# 3. Load and Prepare Dataset
# --------------------------
dataset = load_dataset(
    "imagefolder",
    data_dir="/Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalf_dataset",
    split="train",
    cache_dir="/tmp/hf_cache"
)

counter = {"n": 0}

def reconcile_labels(example):
    counter["n"] += 1
    if counter["n"] % 1000 == 0:
        print(f"Processed {counter['n']} images...")

    label = example.get("label", None)

    if isinstance(label, int):
        original_label = dataset.features["label"].int2str(label).strip().lower()
    elif isinstance(label, str):
        original_label = label.strip().lower()
    else:
        file_path = getattr(example["image"], "filename", None)
        original_label = os.path.basename(os.path.dirname(file_path)).lower() if file_path else None

    pretrain_label = label_mapping.get(original_label)
    example["label"] = label2id[pretrain_label] if pretrain_label is not None else -1
    return example

# Single-threaded labeling to preserve .filename
dataset = dataset.map(reconcile_labels, desc="Re-labeling dataset")
dataset = dataset.filter(lambda x: x["label"] != -1)

print(f"‚úÖ Total examples after filtering: {len(dataset)}")

Resolving data files:   0%|          | 0/22055 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/22055 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

Re-labeling dataset:   0%|          | 0/22055 [00:00<?, ? examples/s]

Processed 1000 images...
Processed 2000 images...
Processed 3000 images...
Processed 4000 images...
Processed 5000 images...
Processed 6000 images...
Processed 7000 images...
Processed 8000 images...
Processed 9000 images...
Processed 10000 images...
Processed 11000 images...
Processed 12000 images...
Processed 13000 images...
Processed 14000 images...
Processed 15000 images...
Processed 16000 images...
Processed 17000 images...
Processed 18000 images...
Processed 19000 images...
Processed 20000 images...
Processed 21000 images...
Processed 22000 images...


Filter:   0%|          | 0/22055 [00:00<?, ? examples/s]

‚úÖ Total examples after filtering: 22055


In [5]:
# --------------------------
# 4. Dataset Label Overview and Folder Stats
# --------------------------
def analyze_dataset_structure(dataset, id2label, base_path):
    # Print label schema from the dataset
    print("Label schema (from dataset):", dataset.features["label"])

    # Label distribution from the dataset object
    label_counts = Counter(dataset["label"])
    print("\nüìä Full dataset label distribution (from Dataset object):")
    for label_id, count in sorted(label_counts.items()):
        print(f"  {id2label[label_id]}: {count} examples")

    # Dynamically detect minority classes (lowest 3 frequencies)
    N = 3
    minority_classes = set(
        label for label, _ in sorted(label_counts.items(), key=lambda x: x[1])[:N]
    )
    print(f"\n‚ö†Ô∏è  Dynamically identified minority classes: {[id2label[i] for i in minority_classes]}")

    # Count images per directory, and store for later validation
    folder_image_counts = {}
    print("\nüìÇ Image count per label folder:")
    for label in sorted(os.listdir(base_path)):
        label_path = os.path.join(base_path, label)
        if os.path.isdir(label_path):
            valid_images = [img for img in os.listdir(label_path) if is_valid_image(img)]
            folder_image_counts[label] = len(valid_images)
            print(f"  {label}: {len(valid_images)} images")

    return minority_classes, folder_image_counts

# Example usage right after dataset loading
minority_classes, folder_image_counts = analyze_dataset_structure(dataset, id2label, BASE_PATH)

Label schema (from dataset): ClassLabel(names=['anger', 'disgust', 'fear', 'happiness', 'neutral', 'questioning', 'sadness', 'surprise'], id=None)

üìä Full dataset label distribution (from Dataset object):
  anger: 2196 examples
  disgust: 251 examples
  fear: 1314 examples
  happiness: 9172 examples
  sadness: 1548 examples
  surprise: 2571 examples
  neutral: 3153 examples
  questioning: 1850 examples

‚ö†Ô∏è  Dynamically identified minority classes: ['disgust', 'fear', 'sadness']

üìÇ Image count per label folder:
  anger: 2196 images
  disgust: 251 images
  fear: 1314 images
  happiness: 9172 images
  neutral: 3153 images
  questioning: 1850 images
  sadness: 1548 images
  surprise: 2571 images


In [6]:
# --------------------------
# 5. Disgust Curation: Perceptual Clustering & AU Filtering (Optional)
# --------------------------
def compute_hash(image_path):
    try:
        img = Image.open(image_path).convert("L").resize((64, 64))
        return str(phash(img))
    except Exception:
        return None

# Directory to inspect for disgust images
disgust_dir = os.path.join(BASE_PATH, "disgust")
disgust_images = [
    os.path.join(disgust_dir, f) for f in os.listdir(disgust_dir)
    if is_valid_image(f)
]

# Compute perceptual hashes for clustering
hash_map = {}
for path in disgust_images:
    h = compute_hash(path)
    if h:
        hash_map.setdefault(h, []).append(path)

# Identify clusters with >1 similar image (potential duplicates or mislabels)
cluster_dir = os.path.join(SAVE_DIR, "disgust_clusters")
os.makedirs(cluster_dir, exist_ok=True)

print("üîç Disgust hash clusters with more than 1 image:")
for h, paths in hash_map.items():
    if len(paths) > 1:
        cluster_path = os.path.join(cluster_dir, h)
        os.makedirs(cluster_path, exist_ok=True)
        for p in paths:
            shutil.copy(p, cluster_path)
        print(f"  - Cluster {h[:8]}: {len(paths)} images copied for review")

# Optionally filter by AU-related cues in filenames or external labels (if available)
# e.g., filenames tagged with AU9_AU10 or meta file with AU info
AU_keywords = ["au9", "au10", "AU9", "AU10"]
au_filtered_paths = [p for p in disgust_images if any(k in p for k in AU_keywords)]

print(f"\nüß† AU-filtered disgust examples (based on filename cues): {len(au_filtered_paths)}")
# Optionally copy for inspection
au_dir = os.path.join(SAVE_DIR, "disgust_with_AUs")
os.makedirs(au_dir, exist_ok=True)
for p in au_filtered_paths:
    shutil.copy(p, au_dir)

üîç Disgust hash clusters with more than 1 image:

üß† AU-filtered disgust examples (based on filename cues): 0


In [7]:
# --------------------------
# 6. Class Frequency-Aware Augmentation Targeting
# --------------------------

# Compute label frequencies from train split (post filtering)
label_freqs = Counter(dataset["label"])
label_id2name = {v: k for k, v in label2id.items()}
label_name2id = {v: k for k, v in label_id2name.items()}

# Get lowest-count classes dynamically
minority_by_count = sorted(label_freqs, key=label_freqs.get)[:3]
minority_by_name = [label_id2name[i] for i in minority_by_count]

# Manually include known confused or underperforming classes
manual_focus_classes = ['disgust', 'questioning']

# Merge and deduplicate
minority_class_names = list(set(minority_by_name + manual_focus_classes))

# Final list as label indices
minority_classes = [label_name2id[name] for name in minority_class_names]

print(f"üéØ Targeted minority augmentation will apply to: {minority_class_names}")

üéØ Targeted minority augmentation will apply to: ['fear', 'disgust', 'questioning', 'sadness']


In [9]:
# --------------------------
# 7. Define Data Augmentation and Preprocessing Transformation
# --------------------------

# Baseline augmentation
data_augment = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomRotation(10),
    T.ColorJitter(brightness=0.1, contrast=0.1)
])

# Stronger augmentation for minority classes
minority_aug = T.Compose([
    T.RandomResizedCrop(224, scale=(0.6, 1.0)),
    T.RandomHorizontalFlip(p=0.8),
    T.RandomRotation(30),
    T.ColorJitter(0.4, 0.4, 0.4, 0.2),
    T.RandomGrayscale(p=0.3),
    T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)),
])

#factory function that returns another function -> tranform_function
def make_transform_function(processor, minority_classes):
    def transform_function(example):
        label = example["label"]
        aug_pipeline = minority_aug if label in minority_classes else data_augment

        if example["image"].mode != "RGB":
            example["image"] = example["image"].convert("RGB")

        augmented_image = aug_pipeline(example["image"])
        inputs = processor(augmented_image, return_tensors="pt")
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["labels"] = example["label"]
        return inputs
    return transform_function

#returned transform_function is applied to each dataset example inside .map()
#each time the dataset runs transform_function(example), it receives:
#inputs = {
#     'pixel_values': tensor,
#     'labels': label_int
# }
dataset = dataset.map(make_transform_function(processor, minority_classes))

Map:   0%|          | 0/22055 [00:00<?, ? examples/s]

In [15]:
# --------------------------
# 8. Train-Validation Split
# --------------------------
split_dataset = dataset.train_test_split(test_size=0.2)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]

In [16]:
# --------------------------
# 9. Label Distribution Snapshot and Drift Monitor
# --------------------------
snapshot_dir = os.path.join(SAVE_DIR, "label_snapshots")
os.makedirs(snapshot_dir, exist_ok=True)

# Count current training labels
train_label_names = [LABEL_NAMES[i] for i in train_dataset['label']]
label_counts = pd.Series(train_label_names).value_counts().sort_index()
label_counts.name = VERSION

# Save snapshot CSV
snapshot_path = os.path.join(snapshot_dir, f"{VERSION}_label_distribution.csv")
label_counts.to_csv(snapshot_path)
print(f"üìä Saved label distribution snapshot: {snapshot_path}")

# Optionally compare to previous version
previous_versions = sorted([
    f for f in os.listdir(snapshot_dir) if f.endswith(".csv") and not f.startswith(VERSION)
])
if previous_versions:
    latest_prev = previous_versions[-1]
    prev_df = pd.read_csv(os.path.join(snapshot_dir, latest_prev), index_col=0)
    diff = label_counts.subtract(prev_df.iloc[:, 0], fill_value=0)
    print("üîç Label count change since last snapshot:")
    print(diff)

üìä Saved label distribution snapshot: /Users/natalyagrokh/AI/ml_expressions/img_expressions/V10_20250515_150651/label_snapshots/V10_label_distribution.csv


In [17]:
# --------------------------
# 10. Balance Dataset
# --------------------------
mp.set_start_method('fork', force=True)

label_target = 3000
balanced_subsets = []

# Dynamically calculate label counts
label_counts = Counter(train_dataset["label"])
print("Original label distribution:", label_counts)

for label, count in label_counts.items():
    subset = train_dataset.filter(lambda x: x['label'] == label, num_proc=1)
    if count > label_target:
        subset = subset.select(random.sample(range(len(subset)), label_target))
    elif count < label_target:
        multiplier = label_target // len(subset)
        remainder = label_target % len(subset)
        subset = concatenate_datasets([subset] * multiplier + [subset.select(range(remainder))])
    balanced_subsets.append(subset)

train_dataset = concatenate_datasets(balanced_subsets).shuffle(seed=42)
print("After balancing:", Counter(train_dataset['label']))

Original label distribution: Counter({3: 7370, 6: 2507, 5: 2061, 0: 1745, 7: 1475, 4: 1246, 2: 1034, 1: 206})


Filter:   0%|          | 0/17644 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17644 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17644 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17644 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17644 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17644 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17644 [00:00<?, ? examples/s]

Filter:   0%|          | 0/17644 [00:00<?, ? examples/s]

After balancing: Counter({7: 3000, 3: 3000, 2: 3000, 6: 3000, 5: 3000, 4: 3000, 1: 3000, 0: 3000})


In [18]:
# --------------------------
# 11. Define Training Arguments for Robust Fine-Tuning
# --------------------------
training_args = TrainingArguments(
    output_dir=SAVE_DIR,                   # Directory to save checkpoints and the final model
    eval_strategy="epoch",                 # Evaluate at the end of each epoch
    save_strategy="epoch",                 # Save checkpoint at each epoch
    save_total_limit=2,                    # ‚úÖ (optional) Keep only last 2 checkpoints to save space
    learning_rate=4e-5,                    # A conservative learning rate for fine-tuning
    per_device_train_batch_size=8,         # Adjust based on your CPU memory limits
    per_device_eval_batch_size=8,
    num_train_epochs=5,                    # Fine-tune for a few epochs (adjust as needed)
    load_best_model_at_end=True,           # Automatically load the best model when training finishes
    metric_for_best_model="accuracy",      # Monitor accuracy for best model selection
    logging_dir=os.path.join(SAVE_DIR, "logs"),  # ‚úÖ Save logs inside versioned folder
    logging_strategy="epoch",                 # ‚úÖ Log once per epoch
    save_safetensors=True                  # ‚úÖ Optional: saves model weights in `.safetensors` (safe format)
)

In [19]:
# --------------------------
# 12. Define Compute and Confusion Metrics
# --------------------------
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = (predictions == labels).mean()
    return {"accuracy": accuracy}

# Define a compute_metrics function w/ confusion matrix logging
def compute_metrics_with_confusion(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    print("\nClassification Report:")
    print(classification_report(labels, preds, target_names=LABEL_NAMES))

    # Save raw values for further use if needed
    np.save(os.path.join(SAVE_DIR, f"logits_eval_{VERSION}.npy"), logits)
    np.save(os.path.join(SAVE_DIR, f"labels_eval_{VERSION}.npy"), labels)    

    cm = confusion_matrix(labels, preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm, annot=True, fmt="d", cmap="Blues",
        xticklabels=LABEL_NAMES,
        yticklabels=LABEL_NAMES
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, f"confusion_matrix_epoch_{VERSION}.png"))
    plt.close()

    # Identify top 3 confused class pairs (excluding diagonal)
    confusion_pairs = [
        ((LABEL_NAMES[i], LABEL_NAMES[j]), cm[i][j])
        for i in range(len(LABEL_NAMES))
        for j in range(len(LABEL_NAMES)) if i != j
    ]
    top_confusions = sorted(confusion_pairs, key=lambda x: x[1], reverse=True)[:3]
    print("\nTop 3 confused class pairs:")
    for (true_label, pred_label), count in top_confusions:
        print(f"  - {true_label} ‚Üí {pred_label}: {count} instances")

    # Compute average prediction entropy
    softmax_probs = F.softmax(torch.tensor(logits), dim=-1)
    entropies = -torch.sum(softmax_probs * torch.log(softmax_probs + 1e-12), dim=-1)
    avg_entropy = entropies.mean().item()
    print(f"\nüß† Avg prediction entropy: {avg_entropy:.4f}")

    class_entropies = {}
    for idx, class_name in enumerate(LABEL_NAMES):
        mask = (np.array(labels) == idx)
        if mask.any():
            class_entropy = entropies[mask].mean().item()
            class_entropies[class_name] = class_entropy
    
    for class_name, entropy_val in sorted(class_entropies.items(), key=lambda x: x[1], reverse=True):
        print(f"  - {class_name}: entropy = {entropy_val:.4f}")


    accuracy = (preds == labels).mean()
    return {"accuracy": accuracy}

In [20]:
# --------------------------
# 13. Trainer with Class-Weighted Loss
# --------------------------

# ‚öñÔ∏è Compute dynamic class weights from training labels/set
label_freqs = Counter(train_dataset['label'])
total = sum(label_freqs.values())
class_weights = torch.tensor([total / label_freqs[i] for i in range(len(label_freqs))], dtype=torch.float).to(device)

# üî• Define Focal Loss to focus on hard-to-classify examples
class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal.mean() if self.reduction == 'mean' else focal.sum()

# ‚ö†Ô∏è Confidence penalty discourages overconfident incorrect predictions
def confidence_penalty(logits, beta=0.05):
    probs = F.softmax(logits, dim=1)
    log_probs = F.log_softmax(logits, dim=1)
    entropy = -torch.sum(probs * log_probs, dim=1)
    return beta * entropy.mean()

# Define custom Trainer to inject class weights
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Use focal loss instead of standard cross-entropy
        focal_loss_fn = FocalLoss(alpha=1.0, gamma=2.0)
        # Combine with confidence penalty (entropy-based)
        loss = focal_loss_fn(logits, labels) + confidence_penalty(logits, beta=0.05)
        return (loss, outputs) if return_outputs else loss

# trainer initialization
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics_with_confusion,
)

# Fine-tune model
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.1434,0.203124,0.927454
2,0.0238,0.300202,0.929268
3,0.0054,0.423543,0.923147
4,0.0035,0.4362,0.926321
5,0.0003,0.448196,0.930401



Classification Report:
              precision    recall  f1-score   support

       anger       0.94      0.89      0.91       451
     disgust       0.97      0.73      0.84        45
        fear       0.81      0.75      0.78       280
   happiness       0.98      0.96      0.97      1802
     sadness       0.92      0.85      0.89       302
    surprise       0.89      0.98      0.93       510
     neutral       0.93      0.95      0.94       646
 questioning       0.81      0.91      0.86       375

    accuracy                           0.93      4411
   macro avg       0.91      0.88      0.89      4411
weighted avg       0.93      0.93      0.93      4411


Top 3 confused class pairs:
  - fear ‚Üí questioning: 55 instances
  - happiness ‚Üí surprise: 34 instances
  - happiness ‚Üí neutral: 24 instances

üß† Avg prediction entropy: 0.1355
  - fear: entropy = 0.2843
  - questioning: entropy = 0.2411
  - anger: entropy = 0.1973
  - sadness: entropy = 0.1899
  - disgust: entropy

TrainOutput(global_step=15000, training_loss=0.03527554010152817, metrics={'train_runtime': 14831.2143, 'train_samples_per_second': 8.091, 'train_steps_per_second': 1.011, 'total_flos': 9.29953881980928e+18, 'train_loss': 0.03527554010152817, 'epoch': 5.0})

In [None]:
# # --------------------------
# # Rescue & Save from Last Checkpoint (after training)
# # --------------------------
# #in case model save fails, resume from latest checkpoint
# processor.save_pretrained(SAVE_DIR)
# print("‚úÖ Processor manually re-saved.")

# # Use parent directory of SAVE_DIR to locate latest V* folder
# parent_dir = os.path.dirname(SAVE_DIR)
# v_folders = [
#     d for d in os.listdir(parent_dir)
#     if os.path.isdir(os.path.join(parent_dir, d)) and d.startswith("V")
# ]

# def extract_timestamp(name):
#     try:
#         _, date_str, time_str = name.split("_")
#         return datetime.strptime(f"{date_str}_{time_str}", "%Y%m%d_%H%M%S")
#     except Exception:
#         return datetime.min

# latest_version_folder = max(v_folders, key=extract_timestamp)
# latest_version_path = os.path.join(parent_dir, latest_version_folder)
# print(f"üóÇÔ∏è Using latest version folder: {latest_version_path}")

# # Locate latest checkpoint within that version folder
# checkpoint_dirs = [
#     os.path.join(latest_version_path, d)
#     for d in os.listdir(latest_version_path)
#     if d.startswith("checkpoint-") and os.path.isdir(os.path.join(latest_version_path, d))
# ]
# if not checkpoint_dirs:
#     raise ValueError("‚ùå No checkpoint found in latest version folder.")

# latest_checkpoint = max(checkpoint_dirs, key=os.path.getmtime)
# print(f"‚úÖ Found latest checkpoint: {latest_checkpoint}")

# # Load model and processor from latest checkpoint and save them
# model = AutoModelForImageClassification.from_pretrained(latest_checkpoint)
# processor = AutoImageProcessor.from_pretrained(latest_version_path)
# model = model.to("cpu")

In [21]:
# --------------------------
# 14. Save Final Independent Model (Safe Save Mode)
# --------------------------

model = model.to("cpu")  # move to CPU first

# Save processor
processor.save_pretrained(SAVE_DIR)
print(f"‚úÖ Processor saved to: {SAVE_DIR}")

# Save state dict
final_model_path = os.path.join(SAVE_DIR, 'final_model.pth')
torch.save(model.state_dict(), final_model_path)
print(f"‚úÖ State dict saved to: {final_model_path}")

# Save full model
model.save_pretrained(SAVE_DIR, safe_serialization=True)
print(f"‚úÖ Full model saved to: {SAVE_DIR}")

# Save trainer state (if defined)
if 'trainer' in globals():
    try:
        trainer.save_model(os.path.join(SAVE_DIR, "backup_trainer_model"))
        print("‚úÖ Trainer backup saved.")
    except Exception as e:
        print(f"‚ö†Ô∏è Failed to save trainer backup: {e}")
    
# Free memory
del model
gc.collect()
torch.cuda.empty_cache()
print("‚úÖ Memory cleanup complete after save.")

‚úÖ Processor saved to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/V10_20250515_150651
‚úÖ State dict saved to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/V10_20250515_150651/final_model.pth
‚úÖ Full model saved to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/V10_20250515_150651
‚úÖ Trainer backup saved.
‚úÖ Memory cleanup complete after save.


In [24]:
# --------------------------
# 15. Inference Utilities
# --------------------------

# Reload Model for Inference
model = AutoModelForImageClassification.from_pretrained(SAVE_DIR).to(device).eval()
print("‚úÖ Model reloaded for inference.")

# Single image prediction (unbatched)
def predict_label(image_path, threshold=0.85):
    image = Image.open(image_path).convert("RGB")
    inputs = processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = F.softmax(logits, dim=-1)
        conf, pred_idx = torch.max(probs, dim=-1)
    return (id2label[pred_idx.item()], conf.item()) if conf.item() >= threshold else ("REVIEW", conf.item())

# Batched prediction
def batch_predict(image_folder, batch_size=64, threshold=0.85):
    all_preds = []
    error_count = 0
    image_paths = [
        p for p in Path(image_folder).rglob("*")
        if is_valid_image(p.name)
    ]

    for i in tqdm(range(0, len(image_paths), batch_size), desc="Running inference in batches"):
        batch_paths = image_paths[i:i + batch_size]
        images, valid_paths = [], []

        for path in batch_paths:
            try:
                img = Image.open(path).convert("RGB")
                images.append(img)
                valid_paths.append(str(path))
            except Exception:
                error_count += 1
                continue

        if not images:
            continue

        inputs = processor(images=images, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = F.softmax(logits, dim=-1)
            confs, preds = torch.max(probs, dim=-1)

        for pred, conf, path in zip(preds.tolist(), confs.tolist(), valid_paths):
            all_preds.append(LABEL_NAMES[pred] if conf >= threshold else "REVIEW")

    print(f"‚úÖ Inference complete. Skipped {error_count} invalid image(s).")
    return all_preds

# Distribution plot
def plot_distribution(predictions, output_path):
    label_counts = Counter(predictions)
    labels = sorted(label_counts.keys())
    counts = [label_counts[label] for label in labels]

    plt.figure(figsize=(10, 5))
    plt.bar(labels, counts)
    plt.title("Predicted Expression Distribution")
    plt.xlabel("Expression")
    plt.ylabel("Count")
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

‚úÖ Model reloaded for inference.


In [25]:
# --------------------------
# 16. Entry Point for Inference
# --------------------------
if __name__ == "__main__" and RUN_INFERENCE:

    # Auto-locate latest model directory
    OUTPUT_PATH = os.path.join(SAVE_DIR, f"{VERSION}_distribution_plot_{timestamp}.png")

    predictions = batch_predict(IMAGE_DIR)
    reviewed_paths = []
    image_paths = [str(p) for p in Path(IMAGE_DIR).rglob("*") if is_valid_image(p.name)]

    for path, label in zip(image_paths, predictions):
        if label == "REVIEW":
            reviewed_paths.append(path)

    # Save paths to inspect manually
    with open(os.path.join(SAVE_DIR, f"{VERSION}_review_candidates.txt"), "w") as f:
        f.write("\n".join(reviewed_paths))
    print(f"üìù Saved REVIEW file paths to {VERSION}_review_candidates.txt")

    plot_distribution(predictions, OUTPUT_PATH)
    print(f"Distribution plot saved to: {OUTPUT_PATH}")

Running inference in batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 345/345 [07:45<00:00,  1.35s/it]


‚úÖ Inference complete. Skipped 0 invalid image(s).
üìù Saved REVIEW file paths to V10_review_candidates.txt
Distribution plot saved to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/V10_20250515_150651/V10_distribution_plot_20250515_150651.png


In [28]:
# --------------------------
# 17. Temperature Scaling Calibration 
# --------------------------

# Wrapper model for calibrated inference
class ModelWithTemperature(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, input_ids=None, pixel_values=None, **kwargs):
        logits = self.model(pixel_values=pixel_values).logits
        return logits / self.temperature

    def set_temperature(self, logits, labels):
        nll_criterion = nn.CrossEntropyLoss()
        optimizer = LBFGS([self.temperature], lr=0.01, max_iter=50)

        def eval_fn():
            optimizer.zero_grad()
            loss = nll_criterion(logits / self.temperature, labels)
            loss.backward()
            return loss

        optimizer.step(eval_fn)
        print(f"Optimal temperature (wrapped): {self.temperature.item():.4f}")
        return self

# Dynamically locate the most recent V* folder that contains logits/labels
base_dir = os.path.dirname(SAVE_DIR)
v_folders = sorted([
    d for d in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("V")
], key=lambda d: os.path.getmtime(os.path.join(base_dir, d)), reverse=True)

logits_path, labels_path = None, None
for v in v_folders:
    version_tag = v.split('_')[0]
    folder_path = os.path.join(base_dir, v)
    logits_candidate = os.path.join(folder_path, f"logits_eval_{version_tag}.npy")
    labels_candidate = os.path.join(folder_path, f"labels_eval_{version_tag}.npy")
    if os.path.exists(logits_candidate) and os.path.exists(labels_candidate):
        INFER_SAVE_DIR = folder_path
        INFER_VERSION = version_tag
        print(f"üìÅ Using calibration files from: {SAVE_DIR}")
        logits_path = logits_candidate
        labels_path = labels_candidate
        break

# Apply calibration from saved logits and labels
def apply_temperature_scaling(logits_path, labels_path):
    if not (os.path.exists(logits_path) and os.path.exists(labels_path)):
        print(f"‚ùå Missing files:\n  - {logits_path if not os.path.exists(logits_path) else ''}\n  - {labels_path if not os.path.exists(labels_path) else ''}")
        return None

    print(f"üìÇ Loading logits from: {logits_path}")
    print(f"üìÇ Loading labels from: {labels_path}")
    
    logits = torch.tensor(np.load(logits_path), dtype=torch.float32).to(device)
    labels = torch.tensor(np.load(labels_path), dtype=torch.long).to(device)

    class TemperatureScaler(nn.Module):
        def __init__(self):
            super().__init__()
            self.temperature = nn.Parameter(torch.ones(1) * 1.5)

        def forward(self, logits):
            return logits / self.temperature

    model = TemperatureScaler().to(device)
    optimizer = LBFGS([model.temperature], lr=0.01, max_iter=50)

    def eval_fn():
        optimizer.zero_grad()
        loss = F.cross_entropy(model(logits), labels)
        loss.backward()
        return loss

    optimizer.step(eval_fn)
    calibrated_logits = model(logits)
    probs = F.softmax(calibrated_logits, dim=1).detach().cpu().numpy()
    logloss = log_loss(labels.cpu().numpy(), probs)

    # Save optimal temperature
    temperature_value = model.temperature.item()
    torch.save(
        torch.tensor([temperature_value]),
        os.path.join(SAVE_DIR, f"{VERSION}_calibrated_temperature.pt")
    )
    print(f"‚úÖ Optimal temperature: {temperature_value:.4f}")
    print(f"‚úÖ Calibrated Log Loss: {logloss:.4f}")
    return temperature_value, logits.cpu(), labels.cpu()

# Reliability diagram
def plot_reliability_diagram(logits, labels, temperature, n_bins=15):
    probs = F.softmax(logits / temperature, dim=1)
    confidences, predictions = torch.max(probs, 1)
    accuracies = predictions.eq(labels)

    bins = torch.linspace(0, 1, n_bins + 1)
    bin_lowers, bin_uppers = bins[:-1], bins[1:]

    bin_accuracies, bin_confidences = [], []
    for lower, upper in zip(bin_lowers, bin_uppers):
        mask = (confidences > lower) & (confidences <= upper)
        if mask.any():
            bin_accuracies.append(accuracies[mask].float().mean())
            bin_confidences.append(confidences[mask].mean())

    plt.figure(figsize=(6, 6))
    plt.plot(bin_confidences, bin_accuracies, marker='o', label='Model')
    plt.plot([0, 1], [0, 1], linestyle='--', label='Perfect Calibration')
    plt.title("Reliability Diagram (After Temperature Scaling)")
    plt.xlabel("Confidence")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    output_path = os.path.join(SAVE_DIR, f"{VERSION}_reliability_diagram_calibrated.png")
    plt.savefig(output_path)
    plt.close()
    print(f"üìä Saved reliability diagram to {output_path}")

# --------------------------
# Run calibration
# --------------------------
if logits_path and labels_path:
    result = apply_temperature_scaling(logits_path, labels_path)
    if result is not None:
        temperature, logits, labels = result
        plot_reliability_diagram(logits, labels, temperature)
else:
    print(f"‚ö†Ô∏è Skipping temperature scaling and diagram (missing logits or labels in {SAVE_DIR})")

üìÅ Using calibration files from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/V10_20250515_150651
üìÇ Loading logits from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/V10_20250515_150651/logits_eval_V10.npy
üìÇ Loading labels from: /Users/natalyagrokh/AI/ml_expressions/img_expressions/V10_20250515_150651/labels_eval_V10.npy
‚úÖ Optimal temperature: 1.7058
‚úÖ Calibrated Log Loss: 0.2930
üìä Saved reliability diagram to /Users/natalyagrokh/AI/ml_expressions/img_expressions/V10_20250515_150651/V10_reliability_diagram_calibrated.png
