In [1]:
# ==============================================================================
#  V39  -  Stability and diagnostics run.
#  Summary: Implemented memory management to prevent crashes. Overhauled hard-negative
#           mining and inference logging to produce accurate curation data.
# ==============================================================================

# V38 to V39 changes:
    # overview: A stability and diagnostics run. Implemented memory management fixes to
    #           prevent OS crashes during training. Overhauled the hard-negative mining
    #           logic and inference logging to produce accurate data for the next
    #           curation cycle.

    # section #5 (TrainingArguments - Stability):
    #   - Reduced `per_device_train_batch_size` for both Stage 1 (16→8) and Stage 2 (8→4).
    #   - Added `gradient_accumulation_steps` to compensate for the smaller batch sizes
    #     and maintain effective batch throughput.
    #   - Rationale: Prevents excessive memory consumption that was causing the
    #     operating system (macOS jetsam) to terminate the script during long training runs.

    # section #7 (Hierarchical Inference):
    #   - Updated the `hierarchical_predict` function to log the model's raw guess as
    #     `top1_label` in addition to the final `prediction` (which is post-thresholding).
    #   - Rationale: The previous version overwrote predictions that fell below the
    #     confidence threshold, making it impossible to find all misclassifications. This
    #     change ensures the `full_inference_log.csv` contains the complete data needed
    #     for accurate mining.

    # section #8 (Post-Training — Mining):
    #   - Overhauled the hard-negative mining logic.
    #   - The script now pre-filters the inference log to only mine for confusions
    #     among truly `RELEVANT_CLASSES`, preventing images that slipped past the S1
    #     filter from contaminating the results.
    #   - The mining logic now correctly uses the new `top1_label` column for analysis.
    #   - Updated the `confusion_pairs_to_mine` list to explicitly target the
    #     `neutral_speech` vs. `speech_action` and `sadness` vs. `speech_action` confusions.

In [2]:
# --------------------------
# 0. Imports
# --------------------------
# WORKAROUND for PyTorch MPS bug
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Standard Library Imports
import datasets
import csv
import gc
import glob
import multiprocessing as mp
import torch
import random
import re
import shutil
import subprocess
import sys
import time
# Alias json to prevent scope conflicts with local variables.
import json as json_mod

# Third-Party Imports
import accelerate
import dill
import face_recognition
import matplotlib.pyplot as plt
import numpy as np, cv2
import pandas as pd
import seaborn as sns
import torch.nn.functional as F
import torchvision.transforms as T
import transformers

# From Imports
from collections import Counter
from dataclasses import dataclass
from datasets import ClassLabel, Dataset, Features, Image as DatasetsImage, concatenate_datasets, load_dataset
from datetime import datetime
from functools import partial
from imagehash import phash, hex_to_hash
from io import BytesIO
from pathlib import Path
from PIL import Image, ImageOps, ImageStat, ExifTags, UnidentifiedImageError
from sklearn.metrics import classification_report, confusion_matrix, log_loss
from sklearn.utils.class_weight import compute_class_weight
from torch import nn
from torch.optim import AdamW, LBFGS
from torch.utils.data import WeightedRandomSampler, DataLoader
from torchvision import transforms
from torchvision.transforms import (
    RandAugment,
)
from tqdm import tqdm
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    EarlyStoppingCallback,
    TrainingArguments,
    Trainer,
    ViTForImageClassification,
)

In [3]:
# --------------------------
# 1. Global Configurations
# --------------------------

# --- üìÇ Core Paths ---
# Root directory of the original labelled dataset. The preparation step reads one
# sub-folder per class name from here (see RELEVANT_CLASSES / IRRELEVANT_CLASSES).
BASE_DATASET_PATH = "/Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalfaga_dataset_14_labels"
# Root directory under which every run creates its own V<n>_<timestamp> output folder.
OUTPUT_ROOT_DIR = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training"

# --- Run Configuration ---
# Presumably gates the full-corpus inference pass later in the notebook (not visible here).
# NOTE(review): the original note described False as the safer everyday default, yet the
# flag is currently True - confirm this is intended before a routine dev run.
RUN_INFERENCE = True
# Presumably gates the dataset re-preparation step; only needed when the dataset layout changes.
PREPARE_DATASETS = False

# Curation/Artifacts policy
USE_EXTERNAL_CURATIONS = True
# NOTE(review): pinned to a specific earlier run folder (V34); update when a newer curation exists.
CURATION_DIR = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V34_20251013_211825"
# Patch CSV produced by the curation notebook. The file is named for V35 but lives in the V34 folder.
EXTERNAL_PATCH = os.path.join(CURATION_DIR, "patch_V35.csv")  # produced by the curation nb

# --- Smoke test mode (no training) ---
# Set to True to load the latest model and test a few images.
# Set to False to run the full training pipeline.
# NOTE(review): currently True, i.e. this configuration does NOT train.
SMOKE_TEST_ONLY = True

# Checkpoint folder used by the smoke test. Left as None it is resolved dynamically;
# set it to a path like "/path/to/V35_..." to test a specific version.
SMOKE_CHECKPOINT_PATH = None

# Finds the most recent V* model directory based on modification time.
# VERSION_TAG is the folder name the script just created for this run, e.g. "V35_20251014_161418"
# Ensure VERSION_TAG is defined where you compose SAVE_DIR / OUTPUT_ROOT_DIR.
# Example: VERSION_TAG = os.path.basename(SAVE_DIR)
def find_latest_checkpoint(root_dir, current_run_basename=None):
    """
    Return the path of the most recent *completed* run directory under ``root_dir``.

    Run folders must be named ``V<version>_<YYYYMMDD>_<HHMMSS>``
    (e.g. ``V34_20251013_211825``). Candidates are ranked by integer version
    first, then by the timestamp string (which sorts chronologically).

    Args:
        root_dir: Directory that contains the run folders.
        current_run_basename: Folder name of the run in progress; it is never
            returned, even if it already contains artifacts.

    Returns:
        The full path of the best candidate, or ``None`` when ``root_dir`` does
        not exist or holds no completed run.
    """
    # A missing root simply means "no previous runs" - do not crash on listdir.
    if not os.path.isdir(root_dir):
        return None

    run_name_pattern = re.compile(r"^V(\d+)_(\d{8}_\d{6})$")  # V<num>_YYYYMMDD_HHMMSS
    # A run counts as "completed" only if at least one of these sub-folders exists.
    artifact_dirs = (
        "emotion_classifier_model",
        "relevance_filter_model",
        "stage_2_emotion_model_training",
    )

    best_key, best_path = None, None
    for entry in os.listdir(root_dir):
        if current_run_basename and entry == current_run_basename:
            continue

        match = run_name_pattern.match(entry)
        if not match:
            continue

        full = os.path.join(root_dir, entry)
        if not os.path.isdir(full):
            continue
        if not any(os.path.isdir(os.path.join(full, sub)) for sub in artifact_dirs):
            continue

        # (version, timestamp): highest version wins, ties broken by latest time.
        key = (int(match.group(1)), match.group(2))
        if best_key is None or key > best_key:
            best_key, best_path = key, full

    return best_path

# --- ü§ñ Model Configuration ---
# Hugging Face identifier of the pretrained Vision Transformer used as the base model.
BASE_MODEL_NAME = "google/vit-base-patch16-224-in21k"

# --- Dataset & Label Definitions ---
# Folders listed in RELEVANT_CLASSES are pooled into the 'relevant' class for
# Stage 1 and are the 11 target classes of the Stage 2 emotion classifier.
RELEVANT_CLASSES = [
    'anger',
    'contempt',
    'disgust',
    'fear',
    'happiness',
    'neutral',
    'questioning',
    'sadness',
    'surprise',
    'neutral_speech',
    'speech_action',
]
# 'unknown' is a sub-folder of 'hard_case', so listing 'hard_case' is enough:
# the preparation step walks it recursively.
IRRELEVANT_CLASSES = ['hard_case']

# Stage 2 (11-class emotion) label maps; ids follow the order of RELEVANT_CLASSES.
label2id_s2 = {name: idx for idx, name in enumerate(RELEVANT_CLASSES)}
id2label_s2 = {idx: name for name, idx in label2id_s2.items()}

# Weakest-label targeting
WEAKEST_LABEL = "sadness"  # change ONLY if a different label is sub-0.80
WEAK_BOOST = 1.8           # lowered to 1.8 to lessen the oversampling effect
SKIP_ERASE_WEAK = True     # True protects fine cues; False allows occlusion

# Stage 1 (binary relevance) label maps.
id2label_s1 = dict(enumerate(('irrelevant', 'relevant')))
label2id_s1 = {name: idx for idx, name in id2label_s1.items()}

# Single source of truth for review gating.
REVIEW_CONF_THRESHOLD = 0.85

# --- üñºÔ∏è File Handling ---
# Defines valid image extensions and provides a function to check them.
VALID_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff")
def is_valid_image(filename):
    """
    Return True if ``filename`` looks like a usable image file.

    Accepts a bare file name, a full path, or an ``os.PathLike`` object. Only
    the final path component is inspected, so macOS AppleDouble sidecar files
    ("._name.jpg") are rejected even when a directory prefix is present.
    """
    base_name = os.path.basename(os.fspath(filename))
    return base_name.lower().endswith(VALID_EXTENSIONS) and not base_name.startswith("._")

# --- üî¢ Versioning and Output Directory Setup ---
# Automatically determines the next version number (e.g., V31) and creates a timestamped output folder.
def get_next_version(base_dir):
    """
    Return the next run label ("V<n>") for ``base_dir``.

    Scans the directories in ``base_dir`` whose names match ``V*_*``, reads the
    digits between the leading "V" and the first underscore, and returns one
    more than the largest number found ("V1" when there is none).
    """
    highest = 0
    for candidate in glob.glob(os.path.join(base_dir, "V*_*")):
        if not os.path.isdir(candidate):
            continue
        name = os.path.basename(candidate)
        if not (name.startswith("V") and "_" in name):
            continue
        number_part = name[1:].split("_")[0]
        if number_part.isdigit():
            highest = max(highest, int(number_part))
    return f"V{highest + 1}"

# Build this run's unique tag: next version number + wall-clock timestamp.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
VERSION = get_next_version(OUTPUT_ROOT_DIR)
VERSION_TAG = VERSION + "_" + timestamp
SAVE_DIR = os.path.join(OUTPUT_ROOT_DIR, VERSION_TAG)
# Side effect: the output folder is created every time this cell runs.
# NOTE(review): this is unconditional, so a SMOKE_TEST_ONLY run also creates a new
# V<n>_<timestamp> folder and therefore advances the version counter.
os.makedirs(SAVE_DIR, exist_ok=True)

# Resolve the checkpoint to start from. This must come after VERSION_TAG is defined so the
# folder just created above can be excluded from the search.
latest_checkpoint = find_latest_checkpoint(OUTPUT_ROOT_DIR, current_run_basename=VERSION_TAG)
if latest_checkpoint:
    # NOTE(review): this is the run folder itself, not a model sub-folder; consumers
    # presumably append the model directory name - confirm against the loading code.
    PRETRAINED_CHECKPOINT_PATH = latest_checkpoint
    print(f"‚úÖ Dynamically loading latest checkpoint: {os.path.basename(PRETRAINED_CHECKPOINT_PATH)}")
else:
    # No completed earlier run: fall back to the Hugging Face base model identifier.
    PRETRAINED_CHECKPOINT_PATH = BASE_MODEL_NAME
    print("‚ö†Ô∏è No previous checkpoint found ‚Äî falling back to base model.")


print(f"üìÅ Output directory created: {SAVE_DIR}")

‚úÖ Dynamically loading latest checkpoint: V38_20251021_123355
üìÅ Output directory created: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V39_20251022_093948


In [4]:
# ----------------------------------------------------
# 2. Hierarchical Dataset Preparation
# ----------------------------------------------------
# This function organizes the original multi-class dataset into two separate
# folder structures required for the two-stage training process. It recursively
# searches through subdirectories (no matter how deep) and is smart enough to
# skip non-image files.
def prepare_hierarchical_datasets(base_path, output_path):
    """
    Build the two flat folder structures used by the two-stage pipeline.

    Stage 1 (``stage_1_relevance_dataset``) gets ``0_irrelevant`` and
    ``1_relevant``; Stage 2 (``stage_2_emotion_dataset``) gets one folder per
    entry of RELEVANT_CLASSES. Source class folders are walked recursively and
    only files accepted by ``is_valid_image`` are copied.

    Every destination folder is emptied before copying (Stage 1 as well as
    Stage 2), so images that were removed or moved to another class since the
    last preparation do not linger with a stale label.

    Because sub-folders are flattened, two source files with the same name land
    on the same destination path and the later one overwrites the earlier one;
    such collisions are counted and reported instead of passing silently.

    Args:
        base_path: Root of the original dataset (one sub-folder per class).
        output_path: Directory in which the two stage folders are created.

    Returns:
        Tuple ``(stage1_path, stage2_path)``.
    """
    stage1_path = os.path.join(output_path, "stage_1_relevance_dataset")
    stage2_path = os.path.join(output_path, "stage_2_emotion_dataset")

    print(f"üóÇÔ∏è Preparing hierarchical datasets at: {output_path}")

    # --- Create Stage 1 Dataset (Relevance Filter) ---
    print("\n--- Creating Stage 1 Dataset ---")
    irrelevant_dest = os.path.join(stage1_path, "0_irrelevant")
    relevant_dest = os.path.join(stage1_path, "1_relevant")
    _reset_dir(irrelevant_dest)
    _reset_dir(relevant_dest)

    # All irrelevant classes share one destination, so they share one name registry.
    print("Processing 'irrelevant' classes...")
    irrelevant_seen = set()
    for class_name in IRRELEVANT_CLASSES:
        _copy_class_images(
            class_name,
            Path(os.path.join(base_path, class_name)),
            irrelevant_dest,
            irrelevant_seen,
            f"  Recursively copying from '{class_name}'...",
        )

    print("Processing 'relevant' classes...")
    relevant_seen = set()
    for class_name in RELEVANT_CLASSES:
        _copy_class_images(
            class_name,
            Path(os.path.join(base_path, class_name)),
            relevant_dest,
            relevant_seen,
            f"  Recursively copying from '{class_name}'...",
        )

    # --- Create Stage 2 Dataset (Emotion Classifier) ---
    print("\n--- Creating Stage 2 Dataset ---")
    for class_name in RELEVANT_CLASSES:
        dest_dir = os.path.join(stage2_path, class_name)
        # The class folder is (re)created even when the source is missing.
        _reset_dir(dest_dir)
        _copy_class_images(
            class_name,
            Path(os.path.join(base_path, class_name)),
            dest_dir,
            set(),
            f"  Copying '{class_name}' to Stage 2 directory...",
        )

    print("\n‚úÖ Hierarchical dataset preparation complete.")
    return stage1_path, stage2_path


def _reset_dir(path):
    """Delete ``path`` if it exists and recreate it as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def _copy_class_images(class_name, src_dir, dest_dir, seen_names, announce):
    """
    Copy every valid image found recursively under ``src_dir`` into ``dest_dir``.

    ``seen_names`` is the set of file names already copied into ``dest_dir``
    during this run; it is updated in place and used to detect collisions.
    ``announce`` is printed before copying starts.
    """
    if not src_dir.is_dir():
        print(f"  ‚ö†Ô∏è Warning: Source directory not found for '{class_name}'")
        return

    print(announce)
    collisions = 0
    # Sorted traversal keeps the outcome deterministic when names collide.
    for file_path in sorted(src_dir.rglob('*')):
        if not (file_path.is_file() and is_valid_image(file_path.name)):
            continue
        if file_path.name in seen_names:
            collisions += 1
        seen_names.add(file_path.name)
        shutil.copy(file_path, dest_dir)

    if collisions:
        print(
            f"  Warning: {collisions} file name collision(s) while flattening "
            f"'{class_name}' into {dest_dir}; earlier copies were overwritten."
        )

In [5]:
# -----------------------------------------------
# 3. Utility Functions & Custom Classes
# -----------------------------------------------

# --- Part A: Smoke Test ---

#  Normalizes model config maps so we always get:
      # id2label: Dict[int, str]
      # label2id: Dict[str, int]
def _normalize_label_maps_from_config(cfg):
    id2label = {int(k): str(v) for k, v in getattr(cfg, "id2label", {}).items()}
    label2id = {str(k): int(v) for k, v in getattr(cfg, "label2id", {}).items()}
    if not id2label and label2id:
        id2label = {vi: k for k, vi in label2id.items()}
    if not label2id and id2label:
        label2id = {v: k for k, v in id2label.items()}
    return id2label, label2id

# Loads the processor, S1, and S2 models from a completed training run folder.
def _load_exports_for_smoke(checkpoint_path: str, device: torch.device):
    
    s2_dir = os.path.join(checkpoint_path, "emotion_classifier_model")
    s1_dir = os.path.join(checkpoint_path, "relevance_filter_model")

    if not os.path.isdir(s1_dir) or not os.path.isdir(s2_dir):
        raise FileNotFoundError(f"Valid S1 or S2 model not found in: {checkpoint_path}")

    processor = AutoImageProcessor.from_pretrained(s2_dir)
    model_s1 = ViTForImageClassification.from_pretrained(s1_dir).to(device).eval()
    model_s2 = ViTForImageClassification.from_pretrained(s2_dir).to(device).eval()

    # Make label maps globally available for the inference function
    globals()["id2label_s2"], globals()["label2id_s2"] = _normalize_label_maps_from_config(model_s2.config)
    globals()["id2label_s1"], globals()["label2id_s1"] = _normalize_label_maps_from_config(model_s1.config)

    return model_s1, model_s2, processor
    

# --- Part A: Data Augmentation ---

# üì¶ Applies augmentations and processes images on-the-fly for each batch.
# This is a more robust approach than pre-processing the entire dataset.
class DataCollatorWithAugmentation:
    """
    Batch collator that augments images on the fly and then runs the HF processor.

    For each feature the image is converted to RGB, passed through the
    class-specific PIL transform (or the base transform when the label has
    none), optionally hit with tensor-level RandomErasing, and finally handed to
    the processor together with the rest of the batch.
    """

    def __init__(self,
                 processor,
                 augment_dict=None,
                 base_augment=None,
                 # --- tensor-level erasing controls ---
                 random_erasing_prob: float = 0.10,
                 random_erasing_scale = (0.02, 0.08),
                 skip_erasing_label_ids=None):
        """
        Args:
            processor: HF image processor; called with ``images=`` and
                ``return_tensors="pt"``.
            augment_dict: dict[int label_id -> PIL transform], class-specific.
            base_augment: fallback PIL transform when no class-specific one is
                found; defaults to a 224x224 resize.
            random_erasing_prob: probability of applying RandomErasing; a falsy
                or non-positive value disables erasing entirely.
            random_erasing_scale: area range for the erased region.
            skip_erasing_label_ids: iterable of label ids never to erase.
        """
        self.processor = processor
        self.augment_dict = augment_dict or {}
        # Baseline augmentation for classes without a dedicated pipeline.
        self.base_augment = base_augment or T.Compose([T.Resize((224, 224))])

        # None disables erasing. RandomErasing works on CHW tensors, hence the
        # PIL <-> tensor helpers below.
        self.random_erasing = (
            T.RandomErasing(p=random_erasing_prob, scale=random_erasing_scale, value="random")
            if random_erasing_prob and random_erasing_prob > 0.0 else None
        )
        self.to_tensor = T.ToTensor()
        self.to_pil = T.ToPILImage()

        # Labels to skip erasing for.
        self.skip_erasing_label_ids = set(skip_erasing_label_ids or [])

    def __call__(self, features):
        """
        Collate ``features`` (dicts with "image" and "label") into a processor batch.

        Returns whatever the processor returns, with a ``"labels"`` entry added
        as a ``torch.long`` tensor.
        """
        processed_images = []
        for x in features:
            label = x["label"]
            rgb_image = x["image"].convert("RGB")

            # 1) class-specific PIL pipeline if present; else base PIL pipeline
            pil_aug = self.augment_dict.get(label, self.base_augment)
            img = pil_aug(rgb_image)

            # 2) tensor-level RandomErasing. The PIL -> tensor -> PIL round trip
            #    is only paid when erasing can actually apply to this sample.
            if self.random_erasing is not None and label not in self.skip_erasing_label_ids:
                img = self.to_pil(self.random_erasing(self.to_tensor(img)))

            processed_images.append(img)

        batch = self.processor(images=processed_images, return_tensors="pt")
        batch["labels"] = torch.tensor([x["label"] for x in features], dtype=torch.long)
        return batch

# --- normalize any image-like object to 3-channel RGB (PIL) ---
def _ensure_rgb(img):
    """
    Return ``img`` as a 3-channel RGB PIL image.

    PIL images are converted with ``convert("RGB")``. Anything else is coerced
    to a NumPy array, cast to uint8 and wrapped in a PIL image:
      - 2-D (H, W) and single-channel (H, W, 1) arrays are replicated to 3 channels;
      - arrays with another channel count (e.g. RGBA) are converted to RGB, so
        the result always has exactly three channels.
    """
    if isinstance(img, Image.Image):
        return img.convert("RGB")

    arr = np.asarray(img)
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr[..., 0]          # (H, W, 1) -> (H, W)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    # The final convert() guarantees RGB even for 4-channel input.
    return Image.fromarray(arr.astype(np.uint8)).convert("RGB")


# --- Part B: Model & Training Components ---

# üèãÔ∏è Defines a custom Trainer that can use either a targeted loss function or class weights.
class CustomLossTrainer(Trainer):
    """
    Trainer whose loss is either a caller-supplied callable or weighted cross-entropy.

    Args (keyword-only, in addition to the regular Trainer arguments):
        loss_fct: callable ``(logits, labels) -> loss``. When given, it is used
            as-is and ``class_weights`` is ignored.
        class_weights: optional per-class weights for the fallback
            ``CrossEntropyLoss``; ``None`` means unweighted.
    """

    def __init__(self, *args, loss_fct=None, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fct = loss_fct
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Compute the loss for one batch.

        Pops ``"labels"`` from ``inputs`` (so the model's built-in loss is not
        computed), runs the model and applies the configured loss. Returns
        ``(loss, outputs)`` when ``return_outputs`` is true, else ``loss``.
        Extra keyword arguments passed by the Trainer are accepted and ignored.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        # Explicit None check: do not rely on the truthiness of a module/callable.
        if self.loss_fct is not None:
            loss = self.loss_fct(logits, labels)
        else:
            weight = self.class_weights
            if weight is not None:
                # The weights must live on the same device (and dtype) as the
                # logits, otherwise the loss fails on any non-CPU device.
                weight = torch.as_tensor(weight).to(device=logits.device, dtype=logits.dtype)
            loss = nn.CrossEntropyLoss(weight=weight)(logits, labels)

        return (loss, outputs) if return_outputs else loss


# üîÑ Implements Cross-Entropy Loss with *Targeted* Label Smoothing.
# Smoothing is turned OFF for specified classes to encourage confident predictions. This is used for Stage 2.
class TargetedSmoothedCrossEntropyLoss(nn.Module):
    """
    Cross-entropy with label smoothing that is switched off for chosen classes.

    Samples whose target is one of ``target_class_names`` get a hard one-hot
    target; all other samples get ``1 - smoothing`` on the true class and
    ``smoothing / (num_classes - 1)`` on every other class. With a positive
    ``focal_gamma`` each sample's loss is additionally scaled by
    ``(1 - p_t) ** focal_gamma`` (``None`` disables this).
    """

    def __init__(self, smoothing=0.05, target_class_names=None, label2id_map=None, focal_gamma=None):
        super().__init__()
        self.smoothing = smoothing
        self.focal_gamma = focal_gamma
        # Both the names and the map are needed to resolve ids; otherwise no class is exempt.
        can_resolve = bool(target_class_names and label2id_map)
        self.target_class_ids = (
            [label2id_map[name] for name in target_class_names] if can_resolve else []
        )

    def forward(self, logits, target):
        n_classes = logits.size(1)

        # Soft targets are constants: build them without tracking gradients.
        with torch.no_grad():
            soft = logits.new_full(logits.shape, self.smoothing / (n_classes - 1))
            soft.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
            if self.target_class_ids:
                sharp_ids = torch.tensor(self.target_class_ids, device=target.device)
                is_sharp = torch.isin(target, sharp_ids)
                if is_sharp.any():
                    # Exempt classes fall back to plain one-hot targets.
                    soft[is_sharp] = F.one_hot(target[is_sharp], num_classes=n_classes).float()

        per_sample = -(soft * F.log_softmax(logits, dim=1)).sum(dim=1)

        gamma = self.focal_gamma
        if gamma is None or not (gamma > 0):
            return per_sample.mean()

        # Focal weight uses the probability mass on the (soft) target; no gradient flows through it.
        with torch.no_grad():
            p_true = (F.softmax(logits, dim=1) * soft).sum(dim=1).clamp_min(1e-6)
        return (((1 - p_true) ** gamma) * per_sample).mean()

# ------------------------------------------------------------------------------
# Stage 1 loss function: focal-modulated cross-entropy (relevant-only)
#   - We keep class weights for imbalance handling.
#   - We add focal modulation ONLY when the ground truth is "relevant"
#     to emphasize difficult positives without exploding FP on easy negatives.
# ------------------------------------------------------------------------------
class RelevantFocalCrossEntropy(torch.nn.Module):
    """
    Class-weighted cross-entropy with focal modulation on 'relevant' targets only.

    Samples whose target equals ``relevant_id`` have their loss multiplied by
    ``(1 - p_t) ** gamma``; every other sample keeps plain weighted CE. The
    result is the unweighted mean over the batch.
    """

    def __init__(self, class_weights: torch.Tensor, gamma: float = 2.0, relevant_id: int = 1):
        """
        Args:
            class_weights: per-class weight tensor passed to CrossEntropyLoss.
            gamma: focal exponent (larger puts more emphasis on hard examples).
            relevant_id: integer id of the 'relevant' class.
        """
        super().__init__()
        self.ce = torch.nn.CrossEntropyLoss(weight=class_weights, reduction="none")
        self.gamma = gamma
        self.relevant_id = relevant_id

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the mean focal-modulated loss for a batch of logits and integer targets."""
        per_sample_ce = self.ce(logits, targets)  # [B]

        # Probability assigned to the true class; treated as a constant.
        with torch.no_grad():
            p_true = torch.softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)

        # 1.0 where the target is 'relevant', 0.0 elsewhere.
        is_relevant = (targets == self.relevant_id).float()

        # Focal factor for relevant samples, neutral factor 1.0 for the rest.
        modulation = is_relevant * (1.0 - p_true).pow(self.gamma) + (1.0 - is_relevant)

        return (modulation * per_sample_ce).mean()


# --- Part C: Metrics & Evaluation ---

# üìä Computes metrics and generates a confusion matrix plot for each evaluation step.
def compute_metrics_with_confusion(
    eval_pred,
    label_names,
    stage_name="Stage2",
    s2_temperature: float = 1.0,
):
    """
    Compute accuracy plus detailed diagnostics for one evaluation pass.

    Besides returning ``{"accuracy": ...}`` this function
      - prints a classification report, the top-3 confused pairs and entropies;
      - saves the raw (unscaled) eval logits and the labels as ``.npy`` in SAVE_DIR;
      - appends per-class F1 / recall / precision / entropy to
        ``per_class_metrics_<stage_name>.csv`` in SAVE_DIR;
      - writes a confusion-matrix PNG to SAVE_DIR.

    Args:
        eval_pred: ``(logits, labels)`` pair of array-likes.
        label_names: class names; class id ``i`` corresponds to ``label_names[i]``.
        stage_name: tag used in file names; names starting with "stage2"
            (case-insensitive) have ``s2_temperature`` applied.
        s2_temperature: logits are divided by this before softmax for Stage 2;
            ``None`` or ``1.0`` leaves them untouched.

    The report and confusion matrix are always computed over every id in
    ``range(len(label_names))``, so a class that is absent from the eval split
    yields zero rows instead of a shape mismatch.
    """
    raw_logits, labels = eval_pred
    raw_logits = np.asarray(raw_logits)
    labels = np.asarray(labels)
    class_ids = list(range(len(label_names)))

    # ---- Stage-2 temperature (calibrated probabilities) ----
    logits = raw_logits
    if stage_name.lower().startswith("stage2") and (s2_temperature is not None) and (s2_temperature != 1.0):
        logits = raw_logits / max(1e-6, float(s2_temperature))

    # Softmax once; reused for predictions and for the entropy diagnostics.
    probs_t = torch.softmax(torch.from_numpy(logits), dim=-1)
    preds = probs_t.numpy().argmax(axis=-1)

    print(f"\nüìà Classification Report for {stage_name}:")
    report_kwargs = dict(labels=class_ids, target_names=label_names, zero_division=0)
    report = classification_report(labels, preds, output_dict=True, **report_kwargs)
    print(classification_report(labels, preds, **report_kwargs))

    # Save raw eval tensors for post-hoc analysis (logits BEFORE temperature
    # scaling, so a later calibration step does not scale them twice).
    np.save(os.path.join(SAVE_DIR, f"logits_eval_{stage_name}_{VERSION}.npy"), raw_logits)
    np.save(os.path.join(SAVE_DIR, f"labels_eval_{stage_name}_{VERSION}.npy"), labels)

    # Per-class metrics for CSV
    f1s        = [report[name]["f1-score"]  for name in label_names]
    recalls    = [report[name]["recall"]    for name in label_names]
    precisions = [report[name]["precision"] for name in label_names]

    # Mean prediction entropy per true class (0.0 when the class has no samples).
    entropies = -(probs_t * torch.log(probs_t + 1e-12)).sum(dim=-1).numpy()
    entropy_per_class = []
    for idx, class_name in enumerate(label_names):
        mask = (labels == idx)
        class_entropy = float(entropies[mask].mean()) if mask.any() else 0.0
        entropy_per_class.append((class_name, class_entropy))
    entropy_dict = dict(entropy_per_class)

    # CSV logging (append). The epoch comes from the module-level trainer for
    # this stage; if that trainer is not defined the epoch is recorded as None
    # instead of aborting the evaluation.
    epoch_metrics_path = os.path.join(SAVE_DIR, f"per_class_metrics_{stage_name}.csv")
    trainer_name = "trainer_s1" if stage_name == "Stage1" else "trainer_s2"
    active_trainer = globals().get(trainer_name)
    epoch = getattr(getattr(active_trainer, "state", None), "epoch", None)

    df_row = pd.DataFrame({
        "epoch": [epoch],
        **{f"f1_{n}": [f] for n, f in zip(label_names, f1s)},
        **{f"recall_{n}": [r] for n, r in zip(label_names, recalls)},
        **{f"precision_{n}": [p] for n, p in zip(label_names, precisions)},
        **{f"entropy_{n}": [entropy_dict[n]] for n in label_names},
    })
    file_exists = os.path.exists(epoch_metrics_path)
    df_row.to_csv(
        epoch_metrics_path,
        mode="a" if file_exists else "w",
        header=not file_exists,
        index=False,
    )

    # Confusion matrix figure (fixed label set keeps it len(label_names) square).
    cm = confusion_matrix(labels, preds, labels=class_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=label_names, yticklabels=label_names)
    plt.xlabel("Predicted"); plt.ylabel("True")
    plt.title(f"Confusion Matrix - {stage_name}")
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, f"confusion_matrix_{stage_name}_{VERSION}.png"))
    plt.close()

    # Top confused pairs (off-diagonal cells with at least one instance)
    confusion_pairs = [((label_names[i], label_names[j]), cm[i][j])
                       for i in class_ids for j in class_ids
                       if i != j and cm[i][j] > 0]
    top_confusions = sorted(confusion_pairs, key=lambda x: x[1], reverse=True)[:3]
    if top_confusions:
        print("\nTop 3 confused class pairs:")
        for (true_label, pred_label), count in top_confusions:
            print(f"  - {true_label} ‚Üí {pred_label}: {count} instances")

    avg_entropy = float(entropies.mean())
    print(f"\nüß† Avg prediction entropy: {avg_entropy:.4f}")

    sorted_entropy = sorted(entropy_per_class, key=lambda x: x[1], reverse=True)
    if sorted_entropy:
        print("\nüîç Class entropies (sorted):")
        for class_name, entropy in sorted_entropy:
            print(f"  - {class_name}: entropy = {entropy:.4f}")

    return {"accuracy": float((preds == labels).mean())}


# ------------------------------------------------------------------------------
# Stage 1: Temperature scaling + threshold (τ) sweep
#   - Fit a single scalar T on eval logits (minimize NLL) to calibrate probabilities.
#   - Sweep τ in [0.28, 0.55] (0.01 steps) to pick the value that maximizes F1(relevant).
#   - Persist T and τ for hierarchical inference.
# ------------------------------------------------------------------------------
def fit_temperature(model, eval_ds, processor, device):
    """
    Fit a single temperature scalar T by minimizing NLL on the eval set.

    Every example is run through ``model`` one at a time (image normalized to
    RGB via ``_ensure_rgb``), the logits are collected on the CPU, and LBFGS
    then minimizes cross-entropy of ``logits / T``.

    Args:
        model: classifier whose output exposes ``.logits``.
        eval_ds: iterable of dicts with "image" and "label".
        processor: HF image processor, called with ``images=`` / ``return_tensors="pt"``.
        device: device the processor output is moved to.

    Returns:
        float: learned temperature, never below 1e-3. An empty ``eval_ds``
        returns the identity temperature 1.0.
    """
    model.eval()
    logits_list, labels_list = [], []
    with torch.no_grad():
        for ex in eval_ds:
            # Same RGB normalization as sweep_tau, via the shared helper.
            img = _ensure_rgb(ex["image"])
            inputs = processor(images=img, return_tensors="pt").to(device)
            logits_list.append(model(**inputs).logits.cpu())
            labels_list.append(int(ex["label"]))

    if not logits_list:
        # Nothing to calibrate on: leave the logits unscaled.
        print("Warning: fit_temperature received an empty eval set; returning T=1.0")
        return 1.0

    logits = torch.cat(logits_list, dim=0)  # [N, num_classes]
    labels = torch.tensor(labels_list)

    # Named `temperature` (not `T`) so the torchvision alias `T` is not shadowed.
    temperature = torch.nn.Parameter(torch.ones(1))
    opt = torch.optim.LBFGS([temperature], lr=0.1, max_iter=50)
    ce = torch.nn.CrossEntropyLoss()

    def _closure():
        """LBFGS closure: scale logits by 1/T, compute CE, backprop into T."""
        opt.zero_grad()
        scaled = logits / temperature.clamp(min=1e-3)
        loss = ce(scaled, labels)
        loss.backward()
        return loss

    opt.step(_closure)
    # The loss only ever sees the clamped value, so report that same value.
    return max(float(temperature.item()), 1e-3)

def sweep_tau(model, eval_ds, processor, device, T=1.0):
    """
    Sweep the Stage 1 threshold tau on P(relevant) and return the best one by F1.

    Each example is scored one at a time: logits are divided by ``max(T, 1e-3)``
    and the softmax probability of the 'relevant' class is compared against
    every tau in [0.28, 0.55] (0.01 steps).

    Args:
        model: Stage 1 classifier whose output exposes ``.logits``.
        eval_ds: iterable of dicts with "image" and "label".
        processor: HF image processor.
        device: device the processor output is moved to.
        T: temperature applied to the logits.

    Returns:
        dict with keys "tau", "f1", "prec", "rec" (metrics rounded to 3
        decimals). On ties the smallest tau wins.
    """
    model.eval()
    relevant_id = label2id_s1['relevant']
    y_true, y_prob = [], []

    with torch.no_grad():
        for ex in eval_ds:
            img, lab = ex["image"], int(ex["label"])
            img = _ensure_rgb(img)
            inputs = processor(images=img, return_tensors="pt").to(device)
            logits = model(**inputs).logits / max(T, 1e-3)
            y_prob.append(torch.softmax(logits, dim=-1)[0, relevant_id].item())
            y_true.append(lab == relevant_id)

    y_true = np.asarray(y_true, dtype=bool)
    y_prob = np.asarray(y_prob, dtype=float)

    pos_rate = float(y_true.mean()) if len(y_true) > 0 else 0.0
    print(f"‚ÑπÔ∏è S1 calib eval prevalence (relevant rate): {pos_rate:.3f}")

    taus = np.round(np.arange(0.28, 0.55 + 1e-9, 0.01), 2)
    return _best_tau_by_f1(y_true, y_prob, taus)


def _best_tau_by_f1(y_true, y_prob, taus):
    """
    Return the tau in ``taus`` that maximizes F1 of ``y_prob >= tau`` against ``y_true``.

    The running best is tracked with the *unrounded* F1; only the reported
    values are rounded. Comparing against an already-rounded best could let a
    slightly worse tau replace a better one (or reject a better one).
    """
    best = {"tau": None, "f1": -1.0, "prec": None, "rec": None}
    best_f1_exact = -1.0
    for tau in taus:
        pred = (y_prob >= tau)
        tp = int((pred & y_true).sum())
        fp = int((pred & ~y_true).sum())
        fn = int((~pred & y_true).sum())
        prec = tp / (tp + fp) if (tp + fp) else 0.0
        rec  = tp / (tp + fn) if (tp + fn) else 0.0
        f1   = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
        if f1 > best_f1_exact:
            best_f1_exact = f1
            best = {"tau": float(tau),
                    "f1": round(float(f1), 3),
                    "prec": round(float(prec), 3),
                    "rec": round(float(rec), 3)}
    return best
    

# --- Part D: Model Saving ---

# Minimal file router for ad-hoc runs from the training script.
def _route_image_to_fs(src_path: str, out_dir: Path, review_dir: Path, decision: dict):
    
    from pathlib import Path
    import shutil, os

    if decision.get("route_reason") == "thresholds":
        dest = review_dir / "review_lowconf"
    elif decision.get("final_label") == "irrelevant":
        dest = review_dir / "irrelevant_or_lowS1"
    else:
        dest = out_dir / decision["final_label"]

    os.makedirs(dest, exist_ok=True)
    shutil.copy2(src_path, dest / Path(src_path).name)


# üíæ Saves the model and its associated processor to a specified directory.
def save_model_and_processor(model, processor, save_dir, model_name):
    """Write a model and its processor into ``<save_dir>/<model_name>``.

    The target directory is created if needed. The model is moved to CPU via
    ``model.to("cpu")`` before being saved with ``safe_serialization=True``;
    the processor is saved into the same directory first.

    Args:
        model: Object exposing ``to(device)`` and ``save_pretrained(path, ...)``.
        processor: Object exposing ``save_pretrained(path)``.
        save_dir: Parent directory for all saved models.
        model_name: Sub-directory name (also used in the log messages).
    """
    print(f"üíæ Saving {model_name} and processor to: {save_dir}")

    target_dir = os.path.join(save_dir, model_name)
    os.makedirs(target_dir, exist_ok=True)

    # Serialize from CPU so the saved weights do not depend on the training device.
    cpu_model = model.to("cpu")
    processor.save_pretrained(target_dir)
    cpu_model.save_pretrained(target_dir, safe_serialization=True)

    print(f"‚úÖ {model_name} saved successfully.")


# --- Part E: Post-Training Analysis ---
# ==========================================================================
#   POST-TRAINING ANALYSIS UTILITIES (OFFLINE / OPTIONAL)
#   - Qualitative error bucketing (QE)
#   - Attention rollout (XAI) for S1 inspection
#   - Ablation helpers
# ==========================================================================

def check_deployment_readiness(metrics_csv_path, f1_threshold=0.80):
    """Print a per-class production-readiness report from a metrics CSV.

    Reads the CSV, takes its last row as the final epoch, and compares every
    column whose name starts with ``f1_`` against ``f1_threshold``. A class
    passes only when its F1 value is a number at or above the threshold;
    missing or non-numeric values count as failures.

    Args:
        metrics_csv_path: Path to the metrics CSV (one row per evaluation).
        f1_threshold: Minimum acceptable per-class F1 score.

    Returns:
        None. The report is printed. If the file is missing, has no rows, or
        has no ``f1_*`` columns, a warning is printed and no verdict is given.
    """
    print("\n" + "="*60)
    print("  DEPLOYMENT READINESS CHECK")
    print("="*60)

    if not os.path.exists(metrics_csv_path):
        print(f"‚ö†Ô∏è Metrics file not found at: {metrics_csv_path}")
        return

    metrics_df = pd.read_csv(metrics_csv_path)
    if metrics_df.empty:
        # Header-only file: there is no "last epoch" to inspect.
        print(f"‚ö†Ô∏è Metrics file contains no rows: {metrics_csv_path}")
        return

    f1_prefix = "f1_"
    f1_columns = [col for col in metrics_df.columns if col.startswith(f1_prefix)]
    if not f1_columns:
        # Without per-class columns an "all classes pass" verdict would be vacuous.
        print(f"‚ö†Ô∏è No per-class 'f1_*' columns found in: {metrics_csv_path}")
        return

    last_epoch_metrics = metrics_df.iloc[-1]

    print(f"Threshold: F1-Score >= {f1_threshold}\n")

    issues_found = False
    for col in f1_columns:
        # Strip only the leading prefix so labels that contain "f1_" stay intact.
        label = col[len(f1_prefix):]
        f1_score = pd.to_numeric(last_epoch_metrics[col], errors="coerce")
        if pd.isna(f1_score):
            # NaN compares False against any threshold, so flag it explicitly.
            print(f"  - ‚ùå {label:<15} | F1-Score: n/a (Missing Value)")
            issues_found = True
        elif f1_score < f1_threshold:
            print(f"  - ‚ùå {label:<15} | F1-Score: {f1_score:.2f} (Below Threshold)")
            issues_found = True
        else:
            print(f"  - ‚úÖ {label:<15} | F1-Score: {f1_score:.2f}")

    if issues_found:
        print("\n Model is NOT ready for production.")
    else:
        print("\n Model meets the minimum F1-score threshold for all classes.")

# --- Qualitative Error Bucketing (Stage 1) ---
# Scans an inference CSV and tags each row with simple visual heuristics:
# blur/shadow/occlusion/low-res. Outputs a QE report CSV for targeting data fixes.
def variance_of_laplacian(gray):
    """Return the variance of the Laplacian of ``gray`` (computed as CV_64F).

    Callers in this file compare the value against a fixed cutoff and treat
    low values as a blur indicator.
    """
    laplacian = cv2.Laplacian(gray, cv2.CV_64F)
    return laplacian.var()

def is_dark(img_pil, thresh=40):
    """Return True when the image's mean grayscale intensity is below ``thresh``.

    Args:
        img_pil: PIL image; it is converted to mode "L" before measuring.
        thresh: Mean-intensity cutoff; values strictly below it count as dark.
    """
    grayscale = img_pil.convert("L")
    mean_intensity = ImageStat.Stat(grayscale).mean[0]
    return mean_intensity < thresh

def qualitative_buckets_s1(inference_csv, out_csv):
    """Tag every image in an inference log with cheap visual-quality heuristics.

    For each row whose image file exists on disk, four 0/1 flags are computed:
      * blur      - variance of the Laplacian below 60
      * shadow    - mean grayscale intensity below 45
      * occlusion - grayscale histogram entropy below 4.5 bits
      * lowres    - shorter image side below 80 pixels

    Args:
        inference_csv: CSV with a ``filepath`` column (``image_path`` is accepted
            as a fallback) and, optionally, ``true_label``, ``predicted_label``
            (fallback ``prediction``) and ``confidence`` columns.
        out_csv: Destination path for the report CSV.

    Returns:
        ``out_csv``, unchanged, after the report has been written.
    """
    import csv
    import pandas as pd

    df = pd.read_csv(inference_csv)

    # The hierarchical inference log names these columns `image_path` / `prediction`;
    # accept either spelling so both log formats can be bucketed.
    path_col = "filepath"
    if "filepath" not in df.columns and "image_path" in df.columns:
        path_col = "image_path"

    # consider only S1 mistakes if you logged them; otherwise filter low conf or S2 mismatches
    rows = []
    for _, r in df.iterrows():
        path = r[path_col]
        # Blank cells are read as NaN (a float); skip them along with missing files.
        if not isinstance(path, str) or not os.path.exists(path):
            continue
        # Context manager releases the file handle; convert() returns a loaded copy.
        with Image.open(path) as fh:
            img = fh.convert("RGB")
        arr = np.array(img)
        gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
        blur = variance_of_laplacian(gray) < 60         # motion blur proxy
        dark = is_dark(img, thresh=45)                  # shadows proxy
        lowres = min(img.size) < 80
        # Cheap occlusion proxy: without face boxes, use the entropy of the
        # grayscale histogram (low entropy == large uniform regions).
        hist = cv2.calcHist([gray], [0], None, [256], [0, 256]).flatten()
        p = hist / hist.sum()
        ent = -np.sum((p + 1e-9) * np.log2(p + 1e-9))
        occl = ent < 4.5                                 # low entropy proxy
        pred_label = r.get('predicted_label', r.get('prediction', '?'))
        rows.append([path, r.get('true_label', '?'), pred_label, r.get('confidence', np.nan),
                     int(blur), int(dark), int(occl), int(lowres)])

    # newline="" is what the csv module requires to avoid blank lines on some platforms.
    with open(out_csv, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["filepath", "true", "pred", "conf", "blur", "shadow", "occlusion", "lowres"])
        w.writerows(rows)
    return out_csv

# --- Ablation summary utility for Stage 1 ---
# Summarizes precision/recall/F1 for S1 given (T, tau).
def summarize_s1(eval_ds, model, processor, device, T: float, tau: float):
    """Compute precision / recall / F1 of the 'relevant' class for a given (T, tau).

    Each example is run through ``model`` one at a time; logits are divided by
    ``max(T, 1e-3)`` before the softmax, and an example is predicted relevant
    when its 'relevant' probability is at least ``tau``.

    Args:
        eval_ds: Iterable of examples with ``"image"`` and ``"label"`` entries.
        model: Classifier whose output exposes ``.logits``.
        processor: Image processor called with ``images=`` and ``return_tensors="pt"``.
        device: Device the processed inputs are moved to.
        T: Temperature applied to the logits.
        tau: Decision threshold on the 'relevant' probability.

    Returns:
        Dict with keys ``precision``, ``recall``, ``f1`` (rounded to 3 decimals),
        ``tau`` and ``T``.
    """
    import numpy as np

    truth_flags = []
    relevant_probs = []

    model.eval()
    with torch.no_grad():
        for example in eval_ds:
            image, label = example["image"], int(example["label"])

            # Force 3-channel RGB so grayscale inputs do not trip the processor.
            image = _ensure_rgb(image)

            batch = processor(images=image, return_tensors="pt").to(device)
            scaled_logits = model(**batch).logits / max(T, 1e-3)
            relevant_id = label2id_s1['relevant']
            probability = torch.softmax(scaled_logits, dim=-1)[0, relevant_id].item()

            truth_flags.append(label == relevant_id)
            relevant_probs.append(probability)

    truth = np.array(truth_flags, bool)
    probs = np.array(relevant_probs, float)
    predicted = probs >= tau

    # Confusion counts for the positive ('relevant') class.
    tp = (predicted & truth).sum()
    fp = (predicted & ~truth).sum()
    fn = (~predicted & truth).sum()

    if tp + fp > 0:
        prec = tp / (tp + fp)
    else:
        prec = 0.0
    if tp + fn > 0:
        rec = tp / (tp + fn)
    else:
        rec = 0.0
    if prec + rec > 0:
        f1 = 2 * prec * rec / (prec + rec)
    else:
        f1 = 0.0

    return {
        "precision": round(prec, 3),
        "recall": round(rec, 3),
        "f1": round(f1, 3),
        "tau": tau,
        "T": T,
    }


# --- Attention Rollout heatmaps for ViT (offline) ---
def vit_attention_rollout(model, inputs, discard_ratio=0.9):
    """Placeholder for ViT attention rollout; currently does nothing and returns None.

    Intended contract (not implemented here): return an [H, W] mask normalized
    to 0..1 that can be overlaid on the input image. Plug in a standard
    attention-rollout implementation for ViT before relying on this function.
    """
    return None

In [6]:
# --------------------------
# 4. Main Training Script
# --------------------------

def main(device):
    """Run the full two-stage (hierarchical) training pipeline.

    Steps, in order:
      0. Prepare (or reuse) the Stage 1 / Stage 2 dataset folders.
      1. Train the Stage 1 binary relevance filter, save it, fit a temperature
         and a decision threshold (tau) on its eval split, and write both to
         ``<SAVE_DIR>/stage1_calibration.json``.
      2. Train the Stage 2 emotion classifier (optionally with a curated patch
         injected into the training split only), save it, fit a scalar
         temperature on its eval split, write it to
         ``<SAVE_DIR>/emotion_classifier_model/stage2_calibration.json`` and
         re-run evaluation with that temperature.

    Args:
        device: Torch device the models and calibration inputs are moved to.

    Returns:
        Tuple ``(model_s1, model_s2, processor)`` where ``model_s1`` is
        ``trainer_s1.model`` and ``model_s2`` is ``trainer_s2.model``.

    Raises:
        FileNotFoundError: If ``PRETRAINED_CHECKPOINT_PATH`` does not exist, or
            if a required Stage 1 artifact is missing before Stage 2 training.
        RuntimeError: If the Stage 1 calibration values fail the sanity checks.

    Side effects:
        Rebinds the module-level names ``trainer_s1`` and ``trainer_s2``.
    """
    # Make trainer objects accessible to metrics function
    global trainer_s1, trainer_s2
    
    # --- Sanity Check for Checkpoint Path ---
    if not os.path.exists(PRETRAINED_CHECKPOINT_PATH):
        raise FileNotFoundError(f"Fatal: Pretrained checkpoint not found at {PRETRAINED_CHECKPOINT_PATH}")

    # --- Define specific model paths from the latest checkpoint ---
    s1_checkpoint_path = os.path.join(PRETRAINED_CHECKPOINT_PATH, "relevance_filter_model")
    s2_checkpoint_path = os.path.join(PRETRAINED_CHECKPOINT_PATH, "emotion_classifier_model")

    # The device is passed in by the caller; it is not chosen here.
    print(f"\nüñ•Ô∏è Using device: {device}")

    # --- Step 0: Prepare Datasets ---
    # prepare_hierarchical_datasets builds the two-stage folder structure.
    # It only needs to be run once; set PREPARE_DATASETS to False to reuse the result.
    prepared_data_path = os.path.join(OUTPUT_ROOT_DIR, "prepared_datasets")
    if PREPARE_DATASETS:
        stage1_dataset_path, stage2_dataset_path = prepare_hierarchical_datasets(BASE_DATASET_PATH, prepared_data_path)
    else:
        stage1_dataset_path = os.path.join(prepared_data_path, "stage_1_relevance_dataset")
        stage2_dataset_path = os.path.join(prepared_data_path, "stage_2_emotion_dataset")
        print("‚úÖ Skipping dataset preparation, using existing directories.")
    
    # # --- Set hardware device ---
    # # commented out due to present mps and pytorch incompatibilities
    # device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # print(f"\nUsing device: {device}")

    # ==========================================================================
    #   STAGE 1: TRAIN RELEVANCE FILTER (BINARY CLASSIFIER)
    # ==========================================================================
    print("\n" + "="*60)
    print("  STAGE 1: TRAINING RELEVANCE FILTER (BINARY CLASSIFIER)")
    print("="*60)

    # --- Load Stage 1 data ---
    # Fixed seed keeps the 80/20 train/eval split reproducible across runs.
    stage1_output_dir = os.path.join(SAVE_DIR, "stage_1_relevance_model_training")
    dataset_s1 = load_dataset("imagefolder", data_dir=stage1_dataset_path, split='train').train_test_split(test_size=0.2, seed=42)
    train_dataset_s1 = dataset_s1["train"]
    eval_dataset_s1 = dataset_s1["test"]
    print(f"Stage 1: {len(train_dataset_s1)} training samples, {len(eval_dataset_s1)} validation samples.")

    # --- Configure Stage 1 model ---
    # The base processor is loaded once and shared by both stages.
    processor = AutoImageProcessor.from_pretrained(BASE_MODEL_NAME)
    # Load the pretrained checkpoint but replace the final layer (classifier head)
    # for our binary (2-label) task.
    model_s1 = ViTForImageClassification.from_pretrained(
        s1_checkpoint_path, # <-- Use the specific path for the Stage 1 model
        num_labels=2,
        label2id=label2id_s1,
        id2label=id2label_s1,
        ignore_mismatched_sizes=True
    ).to(device)

    # --- Handle Extreme Class Imbalance in Stage 1 with Class Weights ---
    # This is critical because the 'irrelevant' class is much larger than the 'relevant' class.
    class_weights_s1 = compute_class_weight('balanced', classes=np.unique(train_dataset_s1['label']), y=train_dataset_s1['label'])
    class_weights_s1 = torch.tensor(class_weights_s1, dtype=torch.float).to(device)
    print(f"‚öñÔ∏è Stage 1 Class Weights: {class_weights_s1}")

    # --- Define Early Stopping ---
    # Patience of 2 evaluations with a minimum improvement of 0.001 on the
    # monitored metric (eval_loss for Stage 1, see metric_for_best_model below).
    early_stop_callback = EarlyStoppingCallback(
        early_stopping_patience=2,
        early_stopping_threshold=0.001
    )
    
    # --- Stage 1 TrainingArguments ---
    training_args_s1 = TrainingArguments(
        output_dir=stage1_output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        use_cpu=True,
        per_device_train_batch_size=8,      # Halved from 16 to reduce memory
        per_device_eval_batch_size=8,       # Also reduce eval batch size
        gradient_accumulation_steps=2,      # Compensate for the smaller batch size
        num_train_epochs=5,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_dir=os.path.join(stage1_output_dir, "logs"),
        logging_strategy="steps",
        logging_steps=50,
        remove_unused_columns=False,
    )

    # --- Stage 1 learning rate ---
    # The complex discriminative learning rate and layer freezing strategy in
    # V31 caused a severe performance drop. This change reverts Stage 1 to
    # V30's simpler and more effective approach of using a single, uniform
    # learning rate for the entire model, which is managed by the Hugging
    # Face Trainer's default optimizer.
    training_args_s1.learning_rate = 3e-5 # Set learning rate directly
    
    loss_fct_s1 = RelevantFocalCrossEntropy(
        class_weights=class_weights_s1,   # <-- we KEEP and USE class_weights here
        gamma=2.0,
        relevant_id=label2id_s1['relevant']
    )

    strong_pos_aug = T.Compose([
        T.RandomResizedCrop(224, scale=(0.8, 1.0)),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.2, 0.2, 0.2, 0.05),
        T.RandomPerspective(distortion_scale=0.05, p=0.2),
        # NEW: small pose/tilt tolerance to reduce false "irrelevant" on near-frontal faces
        T.RandomAffine(degrees=6, translate=(0.03, 0.03), scale=(0.97, 1.03)),
    ])


    # Map label id -> transform. Only the 'relevant' class gets the strong pipeline.
    augment_map_s1 = { label2id_s1['relevant']: strong_pos_aug }

    # Use the flexible CustomLossTrainer, passing the custom loss to it.
    # Apply stronger augmentation ONLY to the "relevant" class to expand coverage
    # near the decision boundary (lighting, small occlusions, slight perspective).
    # Keep "irrelevant" mild as before to avoid over-creating near-face artifacts.
    trainer_s1 = CustomLossTrainer(
        model=model_s1,
        args=training_args_s1,
        train_dataset=train_dataset_s1,
        eval_dataset=eval_dataset_s1,
        compute_metrics=partial(compute_metrics_with_confusion, 
                                label_names=list(id2label_s1.values()), 
                                stage_name="Stage1"),
        data_collator=DataCollatorWithAugmentation(
            processor=processor,
            augment_dict=augment_map_s1,                # class id -> PIL transform map
            random_erasing_prob=0.10,                   # enable erasing
            random_erasing_scale=(0.02, 0.08),
            skip_erasing_label_ids=[]                   # or [label2id_s1['relevant']] to skip
        ),
        loss_fct=loss_fct_s1,         # <-- custom loss uses class_weights + focal on relevant
        callbacks=[early_stop_callback]
    )

    # --- Train Stage 1 model ---
    print("üöÄ Starting Stage 1 training...")
    start_time_s1 = time.time() # Record start time
    trainer_s1.train()
    end_time_s1 = time.time()   # Record end time
    
    # Calculate and print the duration
    # NOTE(review): the %H:%M:%S formatting wraps around for runs longer than 24 hours.
    duration_s1 = end_time_s1 - start_time_s1
    print(f"‚åõ Stage 1 training took: {time.strftime('%H:%M:%S', time.gmtime(duration_s1))}")
    # NOTE(review): save_model_and_processor calls model.to("cpu") before saving;
    # presumably that leaves trainer_s1.model on CPU for the calibration below.
    # Harmless while device is CPU; confirm before switching to another device.
    save_model_and_processor(trainer_s1.model, processor, SAVE_DIR, model_name="relevance_filter_model")
    print("\n‚úÖ Stage 1 Training Complete.")
    
    # Ensure the S1 return uses the trained model instance
    model_s1 = trainer_s1.model
 
    # ------------------------------------------------------------------------------
    # Stage 1: Temperature scaling + threshold (tau) sweep
    #   - fit_temperature returns a single scalar T fitted on the eval split.
    #   - sweep_tau scans tau from 0.28 to 0.55 in 0.01 steps and keeps the value
    #     with the best F1 for the 'relevant' class.
    #   - Persist T and tau for hierarchical inference.
    # ------------------------------------------------------------------------------

    print("\nüß™ Calibrating Stage 1...")
    T_s1 = fit_temperature(trainer_s1.model, eval_dataset_s1, processor, device)
    best_s1 = sweep_tau(trainer_s1.model, eval_dataset_s1, processor, device, T=T_s1)
    print(f"‚úÖ S1 calibration done: T={T_s1:.3f} | best œÑ={best_s1['tau']} | F1={best_s1['f1']} (P={best_s1['prec']}, R={best_s1['rec']})")

    # ---- Fail-fast sanity for S1 calibration ----
    # NOTE(review): these checks run after the print above has already indexed
    # best_s1 and formatted T_s1, so a malformed result fails there first.
    if not isinstance(best_s1, dict) or "tau" not in best_s1:
        raise RuntimeError("S1 calibration failed: best_s1 missing 'tau' key.")
    if not (0.0 <= float(best_s1["tau"]) <= 1.0):
        raise RuntimeError(f"S1 calibration produced invalid tau: {best_s1['tau']}")
    
    if not isinstance(T_s1, (float, int)) or not (0.1 <= float(T_s1) <= 100.0):
        raise RuntimeError(f"S1 temperature T looks suspicious: {T_s1}")

    
    # Persist calibration for inference (read back by hierarchical_predict)
    calib_out = os.path.join(SAVE_DIR, "stage1_calibration.json")
    with open(calib_out, "w") as f:
        json_mod.dump({"T": float(T_s1), "tau": float(best_s1["tau"])}, f)
    print(f"‚úÖ Wrote S1 calibration to {calib_out}")


    # ==========================================================================
    #   STAGE 2: TRAIN EMOTION CLASSIFIER (len(RELEVANT_CLASSES) classes)
    # ==========================================================================
    print("\n" + "="*60)
    print(f"  STAGE 2: TRAINING EMOTION CLASSIFIER ({len(RELEVANT_CLASSES)}-CLASS)")
    print("="*60)

    # --- Load Stage 2 data ---
    stage2_output_dir = os.path.join(SAVE_DIR, "stage_2_emotion_model_training")
    dataset_s2 = load_dataset("imagefolder", data_dir=stage2_dataset_path, split='train').train_test_split(test_size=0.2, seed=42)
    train_dataset_s2 = dataset_s2["train"]
    eval_dataset_s2 = dataset_s2["test"]
    print(f"Stage 2: {len(train_dataset_s2)} training samples, {len(eval_dataset_s2)} validation samples.")
    print("Stage 2 Label Distribution (Train):", Counter(train_dataset_s2['label']))

    # --- Optional: inject curated patch into TRAIN ONLY (no eval leak) ---
    if USE_EXTERNAL_CURATIONS and os.path.exists(EXTERNAL_PATCH):
        from datasets import Dataset, Features, ClassLabel, Image as DatasetsImage
        
        patch_df = pd.read_csv(EXTERNAL_PATCH)
        # Keep only rows whose label is one of the Stage 2 classes.
        patch_df = patch_df[patch_df["label"].isin(RELEVANT_CLASSES)]
    
        # Define the features for the patch, REUSING the ClassLabel from the main dataset
        # so that both datasets share an identical schema for concatenation.
        patch_features = Features({
            'image': DatasetsImage(),
            'label': train_dataset_s2.features['label']  # This is the key fix
        })
    
        def _open_img(p):
            # Best-effort loader: unreadable files become None and are filtered out below.
            try:
                return Image.open(p).convert("RGB")
            except Exception:
                return None
                
        patch_hf = Dataset.from_dict({
            "image": [ _open_img(p) for p in patch_df["filepath"] ],
            "label": [ label2id_s2[l] for l in patch_df["label"] ],
        }, features=patch_features).filter(lambda ex: ex["image"] is not None)
    
        # Concatenation relies on the matching features defined above
        train_dataset_s2 = concatenate_datasets([train_dataset_s2, patch_hf]).shuffle(seed=42)
        print(f"üìå Injected curated patch into TRAIN: +{len(patch_hf)} samples")

    # --- Configure Stage 2 model ---
    # Load the pretrained checkpoint again, this time with a classifier head sized
    # to the number of emotion classes (len(RELEVANT_CLASSES)).
    model_s2 = ViTForImageClassification.from_pretrained(
        s2_checkpoint_path, # <-- Use the specific path for the Stage 2 model
        num_labels=len(RELEVANT_CLASSES),
        label2id=label2id_s2,
        id2label=id2label_s2,
        ignore_mismatched_sizes=True
    ).to(device)

    # --- Define Augmentation and Loss for Stage 2 ---
    # Apply stronger augmentation to the minority classes to help the model learn them better.
    minority_aug = T.Compose([
        RandAugment(num_ops=2, magnitude=11),  
        T.RandomResizedCrop(224, scale=(0.7, 1.0)),
        T.ColorJitter(0.3, 0.3, 0.3, 0.1),
    ])
    minority_classes_s2 = [label2id_s2[n] for n in ['disgust','questioning','contempt','fear']]
    minority_augment_map_s2 = {lid: minority_aug for lid in minority_classes_s2}
    
    # very mild, targeted aug ONLY for the weakest classes
    mild_aug = T.Compose([
        T.RandomResizedCrop(224, scale=(0.95, 1.0)),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.05, 0.05, 0.05, 0.02),
        T.RandomAffine(degrees=3, translate=(0.02, 0.02), scale=(0.98, 1.02)),
    ])

    # targeted mild augmentation for fragile classes
    #     - Keep 'sadness' and 'speech_action' on very mild pipeline (no RandAug)
    # NOTE(review): an earlier comment said this also covers 'neutral_speech',
    # but the list below contains only 'sadness' and 'speech_action'; confirm intent.
    targeted_mild_classes = [
        label2id_s2['sadness'],
        label2id_s2['speech_action'],
    ]
    targeted_mild_map_s2 = {label_id: mild_aug for label_id in targeted_mild_classes}

    # MERGE: single mapping passed to the collator (class id -> transform).
    # On a key collision the mild map wins because it is unpacked last.
    augment_dict = {**minority_augment_map_s2, **targeted_mild_map_s2}

    # --- Section E: Tiny loss tweak for the weakest label (minimal & safe) ----
    loss_fct_s2 = TargetedSmoothedCrossEntropyLoss(
        smoothing=0.05,                      # keep global smoothing
        target_class_names=[WEAKEST_LABEL],  # sharpen ONLY the weak class
        label2id_map=label2id_s2,
        focal_gamma=1.6                      # mild focal emphasis
    )
    # --------------------------------------------------------------------------

    # Fresh callback instance for Stage 2 (rebinds the Stage 1 variable name).
    early_stop_callback = EarlyStoppingCallback(
        early_stopping_patience=2,
        early_stopping_threshold=0.001  # tiny but non-zero improvement required
    )

    # --- Stage 2 TrainingArguments ---
    # Adding weight decay, cosine scheduler + warmup, grad accumulation improves stability
    # (especially on CPU/small batch) without altering the high-level flow.
    training_args_s2 = TrainingArguments(
        output_dir=stage2_output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        use_cpu=True,
        per_device_train_batch_size=4,      # Halved from 8 to reduce memory
        per_device_eval_batch_size=4,       # Also reduce eval batch size
        gradient_accumulation_steps=4,      # Increased from 2 to compensate
        num_train_epochs=6,
        load_best_model_at_end=True,
        metric_for_best_model="eval_accuracy",
        greater_is_better=True,
        logging_dir=os.path.join(stage2_output_dir, "logs"),
        logging_strategy="epoch",
        dataloader_num_workers=0,
        overwrite_output_dir=True,
        remove_unused_columns=False,
        learning_rate=4e-5,
        weight_decay=0.05,                         
        lr_scheduler_type="cosine",                
        warmup_ratio=0.10,
    )

    # --- Stage 2 learning rate ---
    # As with Stage 1, the complex fine-tuning strategy implemented in V31 failed.
    # This change reverts the Stage 2 training process to V30's more effective
    # uniform learning rate strategy to restore model performance.
    # (Same value as the learning_rate passed to TrainingArguments above.)
    training_args_s2.learning_rate = 4e-5 # Set learning rate directly

    # Skip random erasing for fragile classes: sadness, speech_action and neutral_speech
    fragile_ids = [
        label2id_s2['sadness'],
        label2id_s2['speech_action'],
        label2id_s2['neutral_speech']
    ]

    # ensure weakest label is included once (idempotent)
    weak_id = label2id_s2[WEAKEST_LABEL]
    if SKIP_ERASE_WEAK and weak_id not in fragile_ids:
        fragile_ids.append(weak_id)
    
    # Single collator instance used by the trainer
    data_collator = DataCollatorWithAugmentation(
        processor=processor,
        augment_dict=augment_dict,           # merged S2 map (class id -> transform)
        random_erasing_prob=0.10,
        random_erasing_scale=(0.02, 0.08),
        skip_erasing_label_ids=fragile_ids
    )

    # NOTE(review): this is the second Stage 2 banner printed by this function
    # (the first one is emitted before the Stage 2 data is loaded).
    print("\n" + "="*60)
    print("  STAGE 2: TRAIN EMOTION CLASSIFIER (11-CLASS)")
    print("="*60)
    
    # Hard guard: refuse to start S2 if S1 artifacts are missing
    required_s1 = [
        os.path.join(SAVE_DIR, "relevance_filter_model"),
        os.path.join(SAVE_DIR, "stage1_calibration.json")
    ]
    for p in required_s1:
        if not os.path.exists(p):
            raise FileNotFoundError(f"Missing required S1 artifact before S2: {p}")
    
    # Optional: timing checkpoints to avoid silent 10h stalls.
    # This timer starts before the sampler/trainer setup, so it covers slightly
    # more than the time.time() pair wrapped around trainer_s2.train() below.
    from time import perf_counter as _t
    _t0_s2 = _t()

    # --- Step A: Define the Sampler (using the final training data) ---
    # Per-sample weight = inverse class frequency, with the weakest class
    # additionally multiplied by WEAK_BOOST; sampling is done with replacement.
    labels_np     = np.array(train_dataset_s2["label"])
    num_classes_s2 = len(label2id_s2)
    class_counts  = np.bincount(labels_np, minlength=num_classes_s2)
    class_weights = 1.0 / np.clip(class_counts, 1, None)   # clip avoids division by zero
    weak_id = label2id_s2[WEAKEST_LABEL]                    # same value as computed above
    class_weights[weak_id] *= WEAK_BOOST
    sample_weights = class_weights[labels_np]
    sampler = WeightedRandomSampler(
        weights=torch.as_tensor(sample_weights, dtype=torch.float),
        num_samples=len(sample_weights),
        replacement=True
    )

    # --- Step B: Define the Trainer and Data Collator ---
    # (data_collator was already defined earlier, this just creates the trainer)
    trainer_s2 = CustomLossTrainer(
        model=model_s2,
        args=training_args_s2,
        train_dataset=train_dataset_s2,
        eval_dataset=eval_dataset_s2,
        data_collator=data_collator,
        compute_metrics=partial(compute_metrics_with_confusion, 
                                label_names=RELEVANT_CLASSES, 
                                stage_name="Stage2"),
        loss_fct=loss_fct_s2,
        callbacks=[early_stop_callback],
    )

    # --- Step C: Define and Override the Trainer's Data Loader ---
    # Defined LAST so the closure can see both `sampler` and `trainer_s2`.
    # The instance attribute replaces the trainer's own get_train_dataloader.
    def _custom_train_loader():
        return DataLoader(
            train_dataset_s2,
            batch_size=training_args_s2.per_device_train_batch_size,
            sampler=sampler,                # Use the sampler defined in Step A
            collate_fn=trainer_s2.data_collator, # Use the collator from the trainer in Step B
            num_workers=0,
            pin_memory=False
        )
    trainer_s2.get_train_dataloader = _custom_train_loader

    # --- Train Stage 2 model ---
    print("üöÄ Starting Stage 2 training...")
    start_time_s2 = time.time()
    trainer_s2.train()
    end_time_s2 = time.time()
    
    _t1_s2 = _t()
    print(f"‚åõ Stage 2 training took: {(_t1_s2 - _t0_s2)/3600:.2f} hours")
    
    # Calculate and print the duration (second report, HH:MM:SS of train() only)
    duration_s2 = end_time_s2 - start_time_s2
    print(f"‚åõ Stage 2 training took: {time.strftime('%H:%M:%S', time.gmtime(duration_s2))}")
    save_model_and_processor(trainer_s2.model, processor, SAVE_DIR, model_name="emotion_classifier_model")
    print("\n‚úÖ Stage 2 Training Complete.")

    # --- Calibrate Stage 2 (scalar temperature on eval) ---
    print("\nüß™ Calibrating Stage 2 (scalar T on eval set)...")
    
    pred_eval = trainer_s2.predict(eval_dataset_s2)
    # Logits and labels from the eval split feed the temperature fit below.
    logits_s2_numpy = pred_eval.predictions
    labels_s2_numpy = pred_eval.label_ids
    
    # Call the existing function to find the optimal temperature
    T_s2 = apply_temperature_scaling(logits_s2_numpy, labels_s2_numpy)
    
    s2_calib_path = os.path.join(SAVE_DIR, "emotion_classifier_model", "stage2_calibration.json")
    os.makedirs(os.path.dirname(s2_calib_path), exist_ok=True)
    with open(s2_calib_path, "w") as f:
        # Use the aliased import 'json_mod' here as well
        json_mod.dump({
            "T": float(max(1e-6, T_s2)),   # floor keeps the stored temperature strictly positive
            "val_size": int(labels_s2_numpy.size),
            "notes": "Scalar temperature via NLL on eval; seed=42"
        }, f, indent=2)
    
    print(f"‚úÖ S2 calibration done: T={T_s2:.3f} ‚Üí {s2_calib_path}")
    
    # (Optional) Re-run final eval with calibrated temperature for more accurate on-screen metrics
    trainer_s2.compute_metrics = partial(
        compute_metrics_with_confusion,
        label_names=RELEVANT_CLASSES,
        stage_name="Stage2_Calibrated",
        s2_temperature=float(T_s2),
    )
    print("\nüìä Re-running evaluation with calibrated temperature:")
    _ = trainer_s2.evaluate(eval_dataset_s2)
    
    print("\nüéâ Hierarchical Training Pipeline Finished Successfully.")
    
    return model_s1, trainer_s2.model, processor

In [7]:
# ----------------------------------
# 5. Hierarchical Inference
# ----------------------------------
# This function defines the two-step prediction pipeline for images.
# It first checks for relevance (Stage 1) and then classifies the emotion (Stage 2).

def hierarchical_predict(image_paths, model_s1, model_s2, processor, device, batch_size=32):
    """Run the two-stage pipeline (relevance gate, then emotion classifier) on images.

    Stage 1 logits are divided by the calibrated temperature and an image is
    passed on to Stage 2 only when its 'relevant' probability is at least tau.
    T and tau are read once from ``<SAVE_DIR>/stage1_calibration.json``; if the
    file is missing or unreadable the defaults ``T=1.0, tau=0.30`` are used.
    Stage 2 predictions whose confidence is below the per-class threshold, or
    whose entropy is above ``entropy_max``, are routed to ``"review_lowconf"``.

    NOTE(review): Stage 2 logits are used without a temperature here, although
    the training script also writes a Stage 2 calibration file - confirm intent.

    Args:
        image_paths: Sequence of image file paths.
        model_s1: Stage 1 (relevance) classifier.
        model_s2: Stage 2 (emotion) classifier.
        processor: Image processor shared by both models.
        device: Device the processed inputs are moved to.
        batch_size: Number of paths processed per forward pass.

    Returns:
        List of dicts, one per readable image, with keys ``image_path``,
        ``prediction`` (post-threshold label), ``top1_label`` (raw Stage 2
        guess, ``None`` when gated out by Stage 1), ``confidence``,
        ``route_reason`` and ``entropy``. Unreadable images are skipped.
    """
    results = []

    @dataclass
    class _Thresh:
        base_conf: float = 0.65
        entropy_max: float = 1.60
        minority_classes: tuple = ("sadness", "speech_action")
        minority_conf: float = 0.90

    thr_cfg = _Thresh()

    # --- Load Stage 1 calibration ONCE (it does not change between batches) ---
    calib_path = os.path.join(SAVE_DIR, "stage1_calibration.json")
    T_s1, tau = 1.0, 0.30
    try:
        with open(calib_path, "r") as f:
            _c = json_mod.load(f)
        # Parse both values before assigning so a half-valid file cannot leave
        # a calibrated T paired with the default tau.
        _T_loaded, _tau_loaded = float(_c["T"]), float(_c["tau"])
        T_s1, tau = _T_loaded, _tau_loaded
    except FileNotFoundError:
        print(f"‚ö†Ô∏è Missing {calib_path}; using defaults T={T_s1}, œÑ={tau}.")
    except Exception as e:
        print(f"‚ö†Ô∏è Warning: Could not read {calib_path} ({e!s}); using defaults.")

    # Inference must not depend on dropout or other train-mode behavior.
    model_s1.eval()
    model_s2.eval()

    skipped = 0
    for i in tqdm(range(0, len(image_paths), batch_size), desc="üî¨ Running Hierarchical Inference"):
        batch_paths = image_paths[i:i+batch_size]
        images = []
        valid_paths = []
        for path in batch_paths:
            try:
                # Context manager releases the file handle; convert() returns a loaded copy.
                with Image.open(path) as fh:
                    img = fh.convert("RGB")
            except Exception:
                skipped += 1
                continue
            images.append(img)
            valid_paths.append(path)

        if not images:
            continue

        inputs = processor(images=images, return_tensors="pt").to(device)

        with torch.no_grad():
            logits_s1 = model_s1(**inputs).logits / max(T_s1, 1e-3)
            probs_s1 = F.softmax(logits_s1, dim=-1)

        relevant_mask = (probs_s1[:, label2id_s1['relevant']] >= tau)
        dev = logits_s1.device
        preds_s1 = torch.where(
            relevant_mask,
            torch.tensor(label2id_s1['relevant'], device=dev, dtype=torch.long),
            torch.tensor(label2id_s1['irrelevant'], device=dev, dtype=torch.long)
        )

        if relevant_mask.any():
            # Only images that pass the Stage 1 gate are sent through Stage 2.
            relevant_inputs = {k: v[relevant_mask] for k, v in inputs.items()}
            with torch.no_grad():
                logits_s2 = model_s2(**relevant_inputs).logits
                probs_s2 = F.softmax(logits_s2, dim=-1)
                confs_s2, preds_s2 = torch.max(probs_s2, dim=-1)
            _eps = 1e-12
            entropies_s2 = (-probs_s2 * torch.log(probs_s2 + _eps)).sum(dim=1)

        # s2_idx walks the compacted Stage 2 outputs in the same order as the
        # True entries of relevant_mask.
        s2_idx = 0
        for j in range(len(valid_paths)):
            original_prediction = None
            if relevant_mask[j]:
                label_idx   = preds_s2[s2_idx].item()
                label_name  = id2label_s2[label_idx]
                original_prediction = label_name
                confidence  = float(confs_s2[s2_idx].item())
                entropy_val = float(entropies_s2[s2_idx].item())
                s2_idx += 1

                thr = thr_cfg.minority_conf if label_name in thr_cfg.minority_classes else thr_cfg.base_conf
                if (confidence < thr) or (entropy_val > thr_cfg.entropy_max):
                    final_label  = "review_lowconf"
                    route_reason = "thresholds"
                else:
                    final_label  = label_name
                    route_reason = "passed"
            else:
                final_label  = "irrelevant"
                confidence   = float(torch.softmax(logits_s1[j], dim=-1)[preds_s1[j]].item())
                entropy_val  = float('nan')
                route_reason = "stage1_gate"

            results.append({
                "image_path": valid_paths[j],
                "prediction": final_label,
                "top1_label": original_prediction, # The model's raw guess
                "confidence": confidence,
                "route_reason": route_reason,
                "entropy": entropy_val
            })

    if skipped:
        # Surface dropped files instead of letting them vanish silently.
        print(f"Skipped {skipped} unreadable image(s) during hierarchical inference.")
    return results

In [8]:
# ==============================================================================
# 6. Post-Training Analysis, Review, and Curation
# ==============================================================================

def run_post_training_analysis(model_s1, model_s2, processor, device, base_dataset_path, save_dir, version):
    """
    Run a full hierarchical inference pass and write review/curation artifacts.

    Steps, in order:
      A. Run `hierarchical_predict` over every valid image found recursively
         under `base_dataset_path`, derive `true_label` from each image's
         parent folder name, and save `<version>_full_inference_log.csv`.
      B. Copy images whose confidence is below `REVIEW_CONF_THRESHOLD` into
         one folder per predicted label for manual review.
      -  Write a focus-class shortlist CSV and an empty curated-patch template,
         then merge both with the V32 artifacts if those exist in `save_dir`.
      C. Mine hard-negative confusion pairs among truly relevant images, using
         the raw `top1_label` column (the guess before thresholding).
      D. Print an end-to-end (Stage 1 + Stage 2) classification report.

    Args:
        model_s1, model_s2, processor, device: forwarded unchanged to
            `hierarchical_predict`.
        base_dataset_path: root folder scanned recursively for images.
        save_dir: folder that receives every CSV and the review folders.
        version: version tag embedded in the output file names.

    Returns:
        None. All results are written to disk or printed.

    Raises:
        RuntimeError: if the inference log has no `top1_label` column.
    """
    import pandas as pd   # ensure pd is local; prevents UnboundLocalError in notebooks

    print("\n" + "="*60)
    print("  RUNNING POST-TRAINING ANALYSIS & CURATION WORKFLOW")
    print("="*60)

    # --- Part A: Run Hierarchical Inference on the Entire Dataset ---
    all_image_paths = [str(p) for p in Path(base_dataset_path).rglob("*") if is_valid_image(p.name)]
    print(f"Found {len(all_image_paths)} images to process for inference.")

    predictions = hierarchical_predict(all_image_paths, model_s1, model_s2, processor, device)
    df = pd.DataFrame(predictions)

    # An empty result has no columns at all, so every step below would fail
    # with a KeyError; stop early with an explicit message instead.
    if df.empty:
        print("No predictions were produced; skipping post-training analysis.")
        return

    # Derive true label from path for analysis
    df['true_label'] = df['image_path'].apply(lambda p: Path(p).parent.name)

    # Save the full log
    full_log_path = os.path.join(save_dir, f"{version}_full_inference_log.csv")
    df.to_csv(full_log_path, index=False)
    print(f"\n‚úÖ Full inference log saved to: {full_log_path}")

    # NOTE(review): both flags are False and only select which message is
    # printed. The shortlist/patch CSVs (below) and the mining in Part C still
    # run regardless, so these "Skipping" messages do not reflect what happens.
    GENERATE_TRAINING_SHORTLISTS = False   # training script should not rebuild these
    GENERATE_MINING_PAIRS       = False    # keep mining in the curation notebook

    if GENERATE_TRAINING_SHORTLISTS:
        # ... (your existing shortlist + curated_additions code)
        pass
    else:
        print("‚ÑπÔ∏è Skipping shortlist/curated_additions creation here (use curation notebook artifacts).")

    if GENERATE_MINING_PAIRS:
        # ... (your existing hard-negative mining code)
        pass
    else:
        print("‚ÑπÔ∏è Skipping hard-negative mining here (handled in curation notebook).")

    # --- Part B: Identify and Organize Images for Manual Review ---
    # Tag images with low confidence as "REVIEW"
    review_threshold = REVIEW_CONF_THRESHOLD
    review_df = df[df['confidence'] < review_threshold]

    review_sort_dir = os.path.join(save_dir, "review_candidates_by_predicted_class")
    os.makedirs(review_sort_dir, exist_ok=True)

    print(f"\nFound {len(review_df)} images below {review_threshold} confidence for review.")
    for _, row in tqdm(review_df.iterrows(), total=len(review_df), desc="Sorting review images"):
        dest_dir = os.path.join(review_sort_dir, row['prediction'])
        os.makedirs(dest_dir, exist_ok=True)
        # The copy keeps the source basename, so two images with the same file
        # name and the same prediction end up as a single file here.
        shutil.copy(row['image_path'], dest_dir)
    print(f"üìÇ Sorted review images into folders at: {review_sort_dir}")

    # --- : Generate shortlist and curated patch CSVs for THIS run ---
    #     - Shortlist: low-confidence items in focus classes (for targeted manual review)
    #     - Curated patch: template CSV for corrected labels to be fed back into VNext
    focus_classes = ['sadness','speech_action','neutral','neutral_speech','happiness']

    # Output paths for this run; reused by the merge step further down.
    short_csv = os.path.join(save_dir, f"curation_shortlist_{version}.csv")
    patch_csv  = os.path.join(save_dir, f"curated_additions_{version}.csv")

    # Defensive: accept either name for the predicted-label column
    pred_col = next((c for c in ('prediction', 'predicted_label') if c in df.columns), None)
    if pred_col is not None:
        # Sort by confidence ascending (uncertain first)
        df_focus = df[df[pred_col].isin(focus_classes)].copy()
        if 'confidence' in df_focus.columns:
            df_focus = df_focus.sort_values('confidence', ascending=True)

        # Write shortlist with a stable set of columns
        keep_cols = [c for c in ['image_path','filepath','true_label',pred_col,'confidence'] if c in df_focus.columns]
        df_focus[keep_cols].to_csv(short_csv, index=False)
        print(f"‚úÖ Shortlist written: {short_csv}")

        # Create empty curated patch template
        src_path_col = 'image_path' if 'image_path' in df_focus.columns else 'filepath'
        patch_df = pd.DataFrame({
            "filepath": df_focus[src_path_col],
            "correct_label": "",
            "notes": ""
        })
        patch_df.to_csv(patch_csv, index=False)
        print(f"‚úÖ Curated patch template written: {patch_csv}")
    else:
        print("‚ÑπÔ∏è Skipped shortlist/patch CSVs: no predicted label column found in full log.")

    # --- : Merge this run's shortlist/patch with V32 to create canonical merged artifacts ---
    def _merge_csvs(csv_list, key_cols, out_csv):
        """Concatenate the CSVs that exist, drop duplicates on `key_cols`, write `out_csv`."""
        # Column-name variants, in lookup order, and the canonical name each maps to.
        path_aliases = ("filepath", "path")
        pred_aliases = ("prediction", "predicted")
        canonical = {alias: "image_path" for alias in path_aliases}
        canonical.update({alias: "predicted_label" for alias in pred_aliases})

        # Normalize common column name variants so we can dedupe safely
        def _normalize_cols(frame: pd.DataFrame) -> pd.DataFrame:
            colmap = {}
            if "image_path" not in frame.columns:
                for alias in path_aliases:
                    if alias in frame.columns:
                        colmap[alias] = "image_path"
                        break
            if "predicted_label" not in frame.columns:
                for alias in pred_aliases:
                    if alias in frame.columns:
                        colmap[alias] = "predicted_label"
                        break
            return frame.rename(columns=colmap)

        frames = []
        for p in csv_list:
            if os.path.exists(p):
                try:
                    frames.append(_normalize_cols(pd.read_csv(p)))
                except Exception as exc:
                    # Best-effort merge: skip an unreadable CSV, but say so
                    # instead of dropping it silently.
                    print(f"Skipping unreadable CSV during merge: {p} ({exc})")

        if not frames:
            return

        merged = pd.concat(frames, ignore_index=True)

        # Resolve each requested key against the normalized columns. A key such
        # as 'prediction' has been renamed to 'predicted_label' by now; without
        # this mapping it would be dropped and dedup would use the path alone.
        available_keys = []
        for key in key_cols:
            resolved = key if key in merged.columns else canonical.get(key)
            if resolved is not None and resolved in merged.columns and resolved not in available_keys:
                available_keys.append(resolved)
        if not available_keys:
            print(f"‚ÑπÔ∏è Skipped merge for {out_csv}: none of the key columns {key_cols} exist in merged data.")
            return

        merged = merged.drop_duplicates(subset=available_keys, keep="first")
        merged.to_csv(out_csv, index=False)
        print(f"‚úÖ Merged: {out_csv} ({len(merged)} rows)")

    # V32 paths (if present)
    v32_short = os.path.join(save_dir, "curation_shortlist_V32.csv")
    v32_patch = os.path.join(save_dir, "curated_additions_V32.csv")

    # Canonical merged outputs
    short_merged = os.path.join(save_dir, "curation_shortlist_merged.csv")
    patch_merged = os.path.join(save_dir, "curated_additions_merged.csv")

    # Merge (shortlist merges on [filepath, predicted_label]; patch merges on [filepath])
    if pred_col is not None:
        # Figure out the filepath column available
        avail_path_cols = ['image_path','filepath']
        path_col = next((c for c in avail_path_cols if c in df.columns), None)

        if path_col is not None:
            _merge_csvs([v32_short, short_csv], key_cols=[path_col, pred_col], out_csv=short_merged)
            _merge_csvs([v32_patch, patch_csv], key_cols=[path_col], out_csv=patch_merged)
        else:
            print("‚ÑπÔ∏è Skipped merge: no filepath column present in full log.")
    else:
        print("‚ÑπÔ∏è Skipped merge: no predicted label column present in full log.")


    # --- Part C: Mine for "Hard Negative" Confusion Pairs ---
    MINING_HARD_NEGATIVES = True

    if MINING_HARD_NEGATIVES:
        print("\n‚õèÔ∏è  Mining for hard negative confusion pairs...")

        # Only rows whose folder-derived true label is a relevant class are
        # mined, so images outside RELEVANT_CLASSES cannot contaminate the pairs.
        relevant_df = df[df['true_label'].isin(RELEVANT_CLASSES)].copy()
        print(f"   - Analyzing confusions within {len(relevant_df)} truly relevant images.")

        col_true = 'true_label'
        col_pred = 'top1_label' # Use the raw model prediction before thresholding

        if col_pred not in relevant_df.columns:
             raise RuntimeError(f"Could not find '{col_pred}' column. Ensure hierarchical_predict is logging it.")

        # Class pairs to inspect; each pair is checked in both directions.
        confusion_pairs_to_mine = [
            ('contempt', 'questioning'),
            ('contempt', 'neutral'),
            ('fear', 'surprise'),
            ('neutral_speech', 'speech_action'),
            ('sadness', 'speech_action')
        ]

        for c1, c2 in confusion_pairs_to_mine:
            # The mask operates on the pre-filtered 'relevant_df'
            mask = ((relevant_df[col_true] == c1) & (relevant_df[col_pred] == c2)) | \
                   ((relevant_df[col_true] == c2) & (relevant_df[col_pred] == c1))
            hard_negatives = relevant_df.loc[mask]

            if not hard_negatives.empty:
                out_path = os.path.join(save_dir, f"hard_negatives_{c1}_vs_{c2}.csv")
                hard_negatives.to_csv(out_path, index=False)
                print(f"  - Found {len(hard_negatives)} hard negatives for ({c1} ‚Üî {c2}). Saved: {out_path}")
            else:
                print(f"  - No hard negatives found for ({c1} ‚Üî {c2}).")
    else:
        print("‚è© Hard-negative mining disabled.")


    # --- Part D: Generate Final End-to-End Performance Report ---
    print("\n" + "="*60)
    print("  END-TO-END PIPELINE PERFORMANCE REPORT (S1+S2)")
    print("="*60)

    # For a clean report, we'll consider 'review_lowconf' as a misclassification
    # and filter the dataframe to only include the original, known labels.
    report_df = df[df['true_label'].isin(RELEVANT_CLASSES + IRRELEVANT_CLASSES)].copy()

    # We also need to map the S1 'irrelevant' ground truth to its own class
    report_df.loc[report_df['true_label'].isin(IRRELEVANT_CLASSES), 'true_label'] = 'irrelevant'

    # Get all labels that appear in either the true or predicted columns for the report
    all_labels = sorted(list(set(report_df['true_label'].unique()) | set(report_df['prediction'].unique())))

    # Generate and print the final report
    end_to_end_report = classification_report(
        report_df['true_label'],
        report_df['prediction'],
        labels=all_labels,
        zero_division=0
    )

    print("This report reflects the true performance of the entire two-stage system.")
    print("It accounts for errors made by both the Stage 1 filter and the Stage 2 classifier.\n")
    print(end_to_end_report)

In [9]:
# ==============================================================================
# 7. Model Calibration
# ==============================================================================

def apply_temperature_scaling(logits, labels):
    """
    Fit one scalar softmax temperature on the given logits and return it.

    Args:
        logits: per-sample class scores, convertible to a float32 tensor.
        labels: integer class indices, convertible to a long tensor.

    Returns:
        The fitted temperature as a Python float, after a single L-BFGS
        `step` (at most 50 inner iterations) minimising cross-entropy.
    """
    scores = torch.tensor(logits, dtype=torch.float32)
    targets = torch.tensor(labels, dtype=torch.long)

    # The only learnable quantity: a one-element temperature starting at 1.5.
    temperature = nn.Parameter(torch.ones(1) * 1.5)
    optimizer = LBFGS([temperature], lr=0.01, max_iter=50)

    def _closure():
        # L-BFGS re-evaluates the loss several times per step, so it needs a
        # closure that clears gradients, recomputes the loss and backpropagates.
        optimizer.zero_grad()
        nll = F.cross_entropy(scores / temperature, targets)
        nll.backward()
        return nll

    optimizer.step(_closure)
    return temperature.item()

def plot_reliability_diagram(logits, labels, temperature, save_dir, version, stage_name):
    """
    Compute top-1 confidences before and after temperature scaling.

    The plotting itself is still a placeholder: the confidences are computed
    but not used, nothing is written to disk, and `save_dir`, `version` and
    `stage_name` are currently unused.

    Args:
        logits: NumPy array of class scores (softmax is taken over dim 1).
        labels: NumPy array of labels; converted to a tensor but not used yet.
        temperature: divisor applied to the logits for the "after" confidences.

    Returns:
        None.
    """
    logits_t = torch.from_numpy(logits)
    labels_t = torch.from_numpy(labels)  # kept for the future plotting code

    def _top1_confidence(scores):
        # Highest softmax probability per row.
        return F.softmax(scores, dim=1).max(dim=1).values

    confs_before = _top1_confidence(logits_t)               # uncalibrated
    confs_after = _top1_confidence(logits_t / temperature)  # temperature-scaled

    # Plotting logic remains the same...
    # (For brevity, the detailed plotting code from your old script goes here)
    print("üìä Reliability diagram generation logic would go here.")

In [10]:
# ==============================================================================
# 8. Hierarchical Model Ensembling
# ==============================================================================

def hierarchical_ensemble_predict(image_path, processor, s1_models, s2_models, device):
    """
    Classify one image with an ensemble of two-stage (relevance -> emotion) models.

    Stage 1 takes a vote over the argmax predictions of `s1_models`; Stage 2
    averages the softmax probabilities of `s2_models` and takes the argmax.

    Args:
        image_path: path of the image to classify.
        processor: callable invoked as `processor(images=..., return_tensors="pt")`.
        s1_models: Stage 1 models; each returns an object with `.logits`.
        s2_models: Stage 2 models; each returns an object with `.logits`.
        device: device the processed inputs are moved to.

    Returns:
        (None, None) if the image cannot be opened or preprocessed;
        ("irrelevant", None) if the Stage 1 vote is not `relevant`;
        otherwise (label, confidence) with the Stage 2 label name and the
        averaged top-1 probability as a float.
    """
    try:
        # Close the file handle deterministically; `convert` returns a new,
        # fully loaded image, so nothing is needed from the file afterwards.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(device)

    except Exception:
        # Best-effort: an unreadable image yields "no prediction", not a crash.
        return None, None

    # --- Stage 1 Ensemble (Majority Vote) ---
    s1_votes = []
    with torch.no_grad():
        for model in s1_models:
            logits = model(**inputs).logits
            # `.item()` relies on the batch holding exactly this one image.
            pred = torch.argmax(logits, dim=-1).item()
            s1_votes.append(pred)

    # Decide relevance based on majority vote. With an even number of models a
    # tie is possible; `most_common` then returns the vote seen first, i.e.
    # the first model in `s1_models` decides.
    is_relevant = Counter(s1_votes).most_common(1)[0][0] == label2id_s1['relevant']

    if not is_relevant:
        return "irrelevant", None

    # --- Stage 2 Ensemble (Average Probabilities) ---
    s2_probs = []
    with torch.no_grad():
        for model in s2_models:
            logits = model(**inputs).logits
            probs = F.softmax(logits, dim=-1)
            s2_probs.append(probs)

    # Average the probabilities across all models
    avg_probs = torch.mean(torch.stack(s2_probs), dim=0)
    confidence, pred_idx = torch.max(avg_probs, dim=-1)

    final_prediction = id2label_s2[pred_idx.item()]
    final_confidence = confidence.item()

    return final_prediction, final_confidence

In [11]:
# ==============================================================================
# 9. Script Execution Entry Point (with Integrated Smoke Test and Full Analysis)
# ==============================================================================
if __name__ == "__main__":
    # Entry point: optional smoke test on the latest saved model, then
    # training, post-training analysis, readiness check, Stage 2 calibration
    # and a one-image ensemble comparison against an earlier checkpoint.
    import gc  # local import: only needed to release the smoke-test models below

    device = torch.device("cpu")

    # --- Configuration for Pre-flight Smoke Test ---
    RUN_SMOKE_TEST = True

    if RUN_SMOKE_TEST:
        print("="*60)
        print("üß™ RUNNING PRE-FLIGHT SMOKE TEST...")
        print("="*60)
        try:
            checkpoint_path = find_latest_checkpoint(OUTPUT_ROOT_DIR)
            if not checkpoint_path or not os.path.isdir(checkpoint_path):
                raise FileNotFoundError(
                    "No previous model found for smoke test. "
                    "To train from scratch, set RUN_SMOKE_TEST = False."
                )

            print(f"   - Loading latest model from: {os.path.basename(checkpoint_path)}")
            model_s1, model_s2, processor = _load_exports_for_smoke(checkpoint_path, device)

            image_paths = [str(p) for p in Path(BASE_DATASET_PATH).rglob("*") if is_valid_image(p.name)]
            if not image_paths:
                raise FileNotFoundError(f"No images found in {BASE_DATASET_PATH} for smoke test.")

            # At most five random images keep the check fast.
            test_images = random.sample(image_paths, min(len(image_paths), 5))
            print(f"   - Running inference on {len(test_images)} random images...")

            predictions = hierarchical_predict(
                image_paths=test_images, model_s1=model_s1, model_s2=model_s2,
                processor=processor, device=device, batch_size=4
            )

            # Every sampled image must produce exactly one result row.
            if len(predictions) != len(test_images):
                raise RuntimeError(f"Inference failed. Expected {len(test_images)} results, got {len(predictions)}.")

            for result in predictions:
                print(f"     - OK: '{os.path.basename(result['image_path'])}' -> '{result['prediction']}'")

        except Exception as e:
            # Report, then re-raise so a failed smoke test stops the run.
            print(f"\n‚ùå SMOKE TEST FAILED: {e}")
            print("   - Halting script. Please resolve the issue or set RUN_SMOKE_TEST = False.")
            raise

        print("\n‚úÖ Smoke test passed successfully.")

        # Drop the smoke-test models before training; otherwise they stay
        # bound, and therefore in memory, for the whole training run until
        # `main` returns and the names are rebound.
        del model_s1, model_s2, processor, predictions
        gc.collect()

    # --- Step 1: Execute Training Pipeline ---
    print("\n" + "="*60)
    print("üöÄ PROCEEDING TO FULL TRAINING PIPELINE...")
    print("="*60)
    model_s1, model_s2, processor = main(device)

    # --- Step 2: Run Post-Training Analysis & Curation ---
    if RUN_INFERENCE:
        run_post_training_analysis(model_s1, model_s2, processor, device, BASE_DATASET_PATH, SAVE_DIR, VERSION)

    # --- Step 3: Run Final Model Checks ---
    stage2_metrics_path = os.path.join(SAVE_DIR, "per_class_metrics_Stage2.csv")
    check_deployment_readiness(stage2_metrics_path, f1_threshold=0.80)

    # --- Step 4: Calibrate the Stage 2 Model (Restored) ---
    logits_s2_path = os.path.join(SAVE_DIR, f"logits_eval_Stage2_{VERSION}.npy")
    labels_s2_path = os.path.join(SAVE_DIR, f"labels_eval_Stage2_{VERSION}.npy")

    # Calibration needs both saved arrays; skip it if either file is missing.
    if os.path.exists(logits_s2_path) and os.path.exists(labels_s2_path):
        print("\n" + "="*60)
        print("  CALIBRATING STAGE 2 MODEL")
        print("="*60)
        logits_s2 = np.load(logits_s2_path)
        labels_s2 = np.load(labels_s2_path)

        optimal_temp = apply_temperature_scaling(logits_s2, labels_s2)
        print(f"‚úÖ Optimal temperature for Stage 2 model: {optimal_temp:.4f}")
    else:
        print("‚ö†Ô∏è Skipping Stage 2 calibration: log files not found.")

    # --- Step 5: Run Ensemble Analysis (Restored) ---
    # The comparison checkpoint is looked up with `find_latest_checkpoint`,
    # not a hardcoded path. Presumably `current_run_basename` excludes the
    # current run from the search - confirm against the helper.
    v_prev_path = find_latest_checkpoint(OUTPUT_ROOT_DIR, current_run_basename=VERSION_TAG)

    if v_prev_path:
        print("\n" + "="*60)
        print(f"  RUNNING ENSEMBLE ANALYSIS (Current: {VERSION_TAG} vs. Previous: {os.path.basename(v_prev_path)})")
        print("="*60)

        s1_model_prev = AutoModelForImageClassification.from_pretrained(
            os.path.join(v_prev_path, "relevance_filter_model")
        ).to(device).eval()
        s2_model_prev = AutoModelForImageClassification.from_pretrained(
            os.path.join(v_prev_path, "emotion_classifier_model")
        ).to(device).eval()

        # Current model first, previous model second.
        s1_models_ensemble = [model_s1, s1_model_prev]
        s2_models_ensemble = [model_s2, s2_model_prev]

        # Find a random image to test the ensemble
        all_images = [str(p) for p in Path(BASE_DATASET_PATH).rglob("*") if is_valid_image(p.name)]
        if all_images:
            example_image_path = random.choice(all_images)
            prediction, confidence = hierarchical_ensemble_predict(
                example_image_path, processor, s1_models_ensemble, s2_models_ensemble, device
            )
            # Confidence is None for "irrelevant" or when the image failed to load.
            if confidence is not None:
                print(f"Ensemble prediction for '{os.path.basename(example_image_path)}': {prediction} (Confidence: {confidence:.2f})")
            else:
                print(f"Ensemble prediction for '{os.path.basename(example_image_path)}': {prediction} (Confidence: N/A)")
    else:
        print("\n‚ÑπÔ∏è Skipping ensemble analysis: no previous model version found to compare against.")

üß™ RUNNING PRE-FLIGHT SMOKE TEST...
   - Loading latest model from: V38_20251021_123355
   - Running inference on 5 random images...


üî¨ Running Hierarchical Inference:   0%|                  | 0/2 [00:00<?, ?it/s]

‚ö†Ô∏è Missing /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V39_20251022_093948/stage1_calibration.json; using defaults T=1.0, œÑ=0.3.


üî¨ Running Hierarchical Inference: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:00<00:00,  5.17it/s]


‚ö†Ô∏è Missing /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V39_20251022_093948/stage1_calibration.json; using defaults T=1.0, œÑ=0.3.
     - OK: 'Training_15203635.jpg' -> 'irrelevant'
     - OK: 'PublicTest_96387819.jpg' -> 'irrelevant'
     - OK: 'Tom_Welch_0001.jpg' -> 'irrelevant'
     - OK: 'Colin_Powell_0179.jpg' -> 'irrelevant'
     - OK: 'Atal_Bihari_Vajpayee_0009.jpg_face1.jpg' -> 'irrelevant'

‚úÖ Smoke test passed successfully.

üöÄ PROCEEDING TO FULL TRAINING PIPELINE...

üñ•Ô∏è Using device: cpu
‚úÖ Skipping dataset preparation, using existing directories.

  STAGE 1: TRAINING RELEVANCE FILTER (BINARY CLASSIFIER)


Resolving data files:   0%|          | 0/26881 [00:00<?, ?it/s]

Stage 1: 21504 training samples, 5377 validation samples.




‚öñÔ∏è Stage 1 Class Weights: tensor([0.6492, 2.1761])
üöÄ Starting Stage 1 training...


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0003,0.050759,0.992003
2,0.0026,0.043872,0.987725
3,0.0189,0.019275,0.990701
4,0.0052,0.03542,0.990887
5,0.0,0.029885,0.984564



üìà Classification Report for Stage1:
              precision    recall  f1-score   support

  irrelevant       0.99      1.00      0.99      4132
    relevant       0.99      0.98      0.98      1245

    accuracy                           0.99      5377
   macro avg       0.99      0.99      0.99      5377
weighted avg       0.99      0.99      0.99      5377


Top 3 confused class pairs:
  - relevant ‚Üí irrelevant: 30 instances
  - irrelevant ‚Üí relevant: 13 instances

üß† Avg prediction entropy: 0.0241

üîç Class entropies (sorted):
  - relevant: entropy = 0.0926
  - irrelevant: entropy = 0.0035

üìà Classification Report for Stage1:
              precision    recall  f1-score   support

  irrelevant       1.00      0.99      0.99      4132
    relevant       0.96      0.99      0.97      1245

    accuracy                           0.99      5377
   macro avg       0.98      0.99      0.98      5377
weighted avg       0.99      0.99      0.99      5377


Top 3 confused clas

Resolving data files:   0%|          | 0/6175 [00:00<?, ?it/s]

Stage 2: 4940 training samples, 1235 validation samples.
Stage 2 Label Distribution (Train): Counter({9: 1608, 4: 651, 8: 554, 5: 530, 0: 388, 6: 382, 1: 251, 3: 240, 10: 135, 7: 101, 2: 100})

  STAGE 2: TRAIN EMOTION CLASSIFIER (11-CLASS)
üöÄ Starting Stage 2 training...


Epoch,Training Loss,Validation Loss,Accuracy
0,0.0252,0.237258,0.917409
1,0.0312,0.203355,0.910931
2,0.0256,0.206216,0.917409



üìà Classification Report for Stage2:
                precision    recall  f1-score   support

         anger       0.94      0.87      0.90        85
      contempt       0.80      0.68      0.74        60
       disgust       0.86      0.69      0.77        26
          fear       0.87      0.96      0.91        71
     happiness       0.91      1.00      0.95       167
       neutral       0.97      0.99      0.98       135
   questioning       0.81      0.95      0.87        92
       sadness       0.91      0.78      0.84        40
      surprise       0.98      0.95      0.96       147
neutral_speech       0.96      0.90      0.93       381
 speech_action       0.69      1.00      0.82        31

      accuracy                           0.92      1235
     macro avg       0.88      0.89      0.88      1235
  weighted avg       0.92      0.92      0.92      1235


Top 3 confused class pairs:
  - neutral_speech ‚Üí happiness: 14 instances
  - contempt ‚Üí questioning: 10 instance


üìà Classification Report for Stage2:
                precision    recall  f1-score   support

         anger       0.94      0.86      0.90        85
      contempt       0.85      0.73      0.79        60
       disgust       0.78      0.69      0.73        26
          fear       0.90      0.93      0.92        71
     happiness       0.91      1.00      0.95       167
       neutral       0.97      0.98      0.97       135
   questioning       0.84      0.91      0.88        92
       sadness       0.90      0.65      0.75        40
      surprise       0.97      0.95      0.96       147
neutral_speech       0.93      0.90      0.92       381
 speech_action       0.65      0.97      0.78        31

      accuracy                           0.91      1235
     macro avg       0.88      0.87      0.87      1235
  weighted avg       0.91      0.91      0.91      1235


Top 3 confused class pairs:
  - neutral_speech ‚Üí happiness: 14 instances
  - neutral_speech ‚Üí speech_action: 10 


üìà Classification Report for Stage2_Calibrated:
                precision    recall  f1-score   support

         anger       0.94      0.88      0.91        85
      contempt       0.80      0.72      0.75        60
       disgust       0.86      0.69      0.77        26
          fear       0.92      0.96      0.94        71
     happiness       0.91      1.00      0.95       167
       neutral       0.96      0.98      0.97       135
   questioning       0.83      0.96      0.89        92
       sadness       0.91      0.78      0.84        40
      surprise       0.97      0.93      0.95       147
neutral_speech       0.96      0.90      0.93       381
 speech_action       0.68      0.97      0.80        31

      accuracy                           0.92      1235
     macro avg       0.88      0.89      0.88      1235
  weighted avg       0.92      0.92      0.92      1235


Top 3 confused class pairs:
  - neutral_speech ‚Üí happiness: 14 instances
  - contempt ‚Üí questioning: 

üî¨ Running Hierarchical Inference: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 837/837 [23:22<00:00,  1.68s/it]



‚úÖ Full inference log saved to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V39_20251022_093948/V39_full_inference_log.csv
‚ÑπÔ∏è Skipping shortlist/curated_additions creation here (use curation notebook artifacts).
‚ÑπÔ∏è Skipping hard-negative mining here (handled in curation notebook).

Found 3375 images below 0.85 confidence for review.


Sorting review images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3375/3375 [00:01<00:00, 2797.53it/s]


üìÇ Sorted review images into folders at: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V39_20251022_093948/review_candidates_by_predicted_class
‚úÖ Shortlist written: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V39_20251022_093948/curation_shortlist_V39.csv
‚úÖ Curated patch template written: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V39_20251022_093948/curated_additions_V39.csv
‚úÖ Merged: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V39_20251022_093948/curation_shortlist_merged.csv (1364 rows)
‚úÖ Merged: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V39_20251022_093948/curated_additions_merged.csv (1364 rows)

‚õèÔ∏è  Mining for hard negative confusion pairs...
   - Analyzing confusions within 6174 truly relevant images.
  - No hard negatives found for (contempt ‚Üî questioning).
  - Found 2 hard negatives for (contempt ‚Üî neutral). Saved: /Users/natalyagrokh/AI/ml_expr