In [1]:
# ==============================================================================
#  V34  -  Verified full training + inference pipeline
#  Summary: merged V32 artifact structure, added RandomErasing + calibration RGB normalization
# ==============================================================================

# V33 to V34 changes:
    # overview: Elevate Stage 1 to deployment readiness while preserving Stage 2 gains.
    #   Adds relevant-only focal CE, asymmetric augmentation, calibration + œÑ tuning,
    #   and restores V32 diagnostic artifacts (review folders, mining, shortlist,
    #   curated patch). Introduces optional QE/XAI utilities and ablation scaffolding.

    # section #4 (Stage 1 loss):
    #   - Added RelevantFocalCrossEntropy (Œ≥=2.0) that applies focal scaling only for
    #     the 'relevant' class (uses existing class_weights_s1 internally).
    #   - Rationale: Increase recall on difficult positives without inflating easy FPs.
    #   - Expected Impact: F1(relevant) ‚Üë, Recall(relevant) ‚Üë, Calibration ‚Üë.

    # section #4 (Stage 1 augmentation):
    #   - Applied stronger augmentation ONLY to 'relevant' via augment_map_s1.
    #   - Rationale: Expand boundary coverage for lighting/occlusion/perspective.
    #   - Expected Impact: Recall(relevant) ‚Üë 5‚Äì10pp, stable precision.

    # section #4 (Stage 1 calibration):
    #   - Implemented temperature scaling (fit on eval) + œÑ sweep (0.30‚Äì0.55).
    #   - Persisted stage1_calibration.json {T, œÑ} for hierarchical inference use.
    #   - Expected Impact: Better-calibrated probabilities; optimal œÑ selection -> F1 ‚Üë.
    # ‚Äúhierarchical_predict(...) now uses stage1_calibration.json (T and œÑ) 
        # to gate relevance during inference.‚Äù

    # section #4 (Collator & Aug fix):
    #   - Removed PIL-level RandomErasing from strong_pos_aug; applied RandomErasing
    #     only at the tensor stage inside DataCollatorWithAugmentation to prevent
    #     AttributeError on PIL images and keep occlusion realism.
    #   - Defined collator helpers (ToTensor/ToPILImage) and fixed attribute mismatch:
    #     now uses `self.random_erasing` (configurable) instead of undefined
    #     `self.post_tensor_erase`.
    #   - Expected Impact: Stable training with correct augmentation; no PIL/tensor
    #     shape errors; consistent occlusion regularization.

    # section #4 (Calibration/Eval robustness ‚Äì update):
    #   - Added _ensure_rgb() and applied it inside both S1 calibration loops
    #     to convert all eval images to 3-channel RGB (expand grayscale arrays),
    #     preventing ViTImageProcessor error on 2-D images.
    #   - Harmonized division by temperature to logits / max(T, 1e-3) for numeric stability.
    #   - Expected Impact: Stable S1 calibration/œÑ-sweep on heterogeneous eval data.

    # section #5 (Stage 2 micro-tweaks):
    #   - Lowered focal_gamma from 1.5 ‚Üí 1.2 for ['sadness','speech_action'].
    #   - (Optional) Mild aug added for 'neutral_speech' only.
    #   - Expected Impact: +1‚Äì2pp macro-F1; improved calibration on fragile classes.

    # section #5 (Inference safety):
    #   - In hierarchical_predict, ensured device/dtype-safe `torch.where` for S1
    #     gating (fill tensors now created on `logits_s1.device`).
    #   - Expected Impact: Avoids device mismatch when switching CPU‚ÜîGPU in future runs.

    # section #6 - Aug pipeline wiring:
        #   ‚ÄúWired Stage-2 augment_dict to include neutral_speech mild aug and 
        # pass the merged augment_dict to the collator.‚Äù

    # section #7 (Artifacts restoration & merge):
    #   - Restored V32-style artifacts for V34:
    #       * V34_full_inference_log.csv
    #       * review_candidates_by_predicted_class/ (CONF_THRESHOLD=0.85)
    #       * hard_negatives_*_vs_*.csv (pair mining)
    #       * curation_shortlist_V34.csv + curated_additions_V34.csv
    #   - Merged overlapping files with V32:
    #       * curation_shortlist_merged.csv
    #       * curated_additions_merged.csv
    #   - Rationale: Full diagnostics parity + reproducibility across versions.
    #   - Mining pairs now prefer THIS run‚Äôs full log first (fall back to prior runs only if missing).
    #   - Added curation_shortlist_{VXX}.csv and curated_additions_{VXX}.csv emission in post-training flow.
    #   - Added canonical merges:
    #       * curation_shortlist_merged.csv (V32 + current)
    #       * curated_additions_merged.csv (V32 + current)
    #   - Rationale: Full parity with V32 artifacts while keeping continuity across versions.
    #   - Expected Impact: Faster, targeted curation and stable longitudinal review.


    # section #8 (Optional analysis utilities):
    #   - Added QE (qualitative error bucketing) helper.
    #   - Added S1 ablation summary helper for clean experiment tracking.
    #   - Expected Impact: Data collection guided by dominant failure modes; faster iteration.

    # section #10 - Operational defaults:
    #   Set RUN_INFERENCE=True 

    # Follow-Up Actions:
    #   - Run V34a/b/c ablations (loss-only; aug-only; curated-only), then V34d with best combo.
    #   - Confirm S1 deployment gate: F1(relevant) ‚â• 0.80, Recall ‚â• 0.85, Precision ‚â• 0.70.
    #   - Rebuild merged artifacts and refresh curation stream for next cycle.
# ==============================================================================

In [2]:
# --------------------------
# 0. Imports
# --------------------------
# WORKAROUND for PyTorch MPS bug
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Standard Library Imports
import datasets
import csv
import gc
import glob
import multiprocessing as mp
import os
import torch
import random
import re
import shutil
import subprocess
import sys
import time
import json

# Third-Party Imports
import accelerate
import dill
import face_recognition
import matplotlib.pyplot as plt
import numpy as np, cv2
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import transformers

# From Imports
from collections import Counter
from datasets import ClassLabel, Dataset, Features, Image as DatasetsImage, concatenate_datasets, load_dataset
from datetime import datetime
from functools import partial
from imagehash import phash, hex_to_hash
from io import BytesIO
from pathlib import Path
from PIL import Image, ImageOps, ImageStat, ExifTags, UnidentifiedImageError
from sklearn.metrics import classification_report, confusion_matrix, log_loss
from sklearn.utils.class_weight import compute_class_weight
from torch import nn
from torch.optim import AdamW, LBFGS
from torchvision import transforms
from torchvision.transforms import (
    RandAugment,
)
from tqdm import tqdm
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    EarlyStoppingCallback,
    TrainingArguments,
    Trainer,
    ViTForImageClassification,
)

In [3]:
# --------------------------
# 1. Global Configurations
# --------------------------

# --- üìÇ Core Paths ---
# This is the root directory containing your original 14-class dataset structure.
BASE_DATASET_PATH = "/Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalfaga_dataset_14_labels"
# This is the root directory where all outputs (models, logs, prepared datasets) will be saved.
OUTPUT_ROOT_DIR = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training"

# --- ‚öôÔ∏è Run Configuration ---
# default safer for daily dev runs; flip to True when you want full-corpus inference
RUN_INFERENCE = True
# default safer; run once when dataset layout changes
PREPARE_DATASETS = False

# Finds the most recent V* model directory based on modification time.
def find_latest_checkpoint(root_dir):
    all_run_dirs = [
        os.path.join(root_dir, d)
        for d in os.listdir(root_dir)
        if d.startswith("V") and os.path.isdir(os.path.join(root_dir, d))
    ]
    if not all_run_dirs:
        return None

    # Sort directories by modification time, newest first
    sorted_dirs = sorted(all_run_dirs, key=os.path.getmtime, reverse=True)

    # The newest directory is the current run's empty folder.
    # We need the second newest, which is the latest *completed* run.
    if len(sorted_dirs) > 1:
        return sorted_dirs[1] # <-- Return the second item in the list
    else:
        # If there's only one (or zero), no previous checkpoint exists
        return None

# --- ü§ñ Model Configuration ---
# The pretrained Vision Transformer model from Hugging Face to be used as a base.
BASE_MODEL_NAME = "google/vit-base-patch16-224-in21k"

# Dynamically find the latest checkpoint to train from
latest_checkpoint = find_latest_checkpoint(OUTPUT_ROOT_DIR)

if latest_checkpoint:
    PRETRAINED_CHECKPOINT_PATH = latest_checkpoint
    print(f"‚úÖ Dynamically loading latest checkpoint: {os.path.basename(PRETRAINED_CHECKPOINT_PATH)}")
else:
    # If no checkpoint is found, fall back to the base model from Hugging Face
    PRETRAINED_CHECKPOINT_PATH = BASE_MODEL_NAME
    print(f"‚ö†Ô∏è No local checkpoint found. Starting from base model: {BASE_MODEL_NAME}")
    
# --- üè∑Ô∏è Dataset & Label Definitions ---
# These lists define the structure for the hierarchical pipeline.
# All folders listed here will be grouped into the 'relevant' class for Stage 1
# and used for training the final 11-class classifier in Stage 2.
RELEVANT_CLASSES = [
    'anger', 'contempt', 'disgust', 'fear', 'happiness',
    'neutral', 'questioning', 'sadness', 'surprise',
    'neutral_speech', 'speech_action'
]
# **IMPORTANT**: Since 'unknown' is a subfolder of 'hard_case', we only need to
# list 'hard_case' here. The script will find all images inside it recursively.
IRRELEVANT_CLASSES = ['hard_case']

# Mappings for the Stage 2 (11-class Emotion) model
id2label_s2 = dict(enumerate(RELEVANT_CLASSES))
label2id_s2 = {v: k for k, v in id2label_s2.items()}

# Mappings for the Stage 1 (binary Relevance) model
id2label_s1 = {0: 'irrelevant', 1: 'relevant'}
label2id_s1 = {v: k for k, v in id2label_s1.items()}

# single source of truth for review gating
REVIEW_CONF_THRESHOLD = 0.85  

# --- üñºÔ∏è File Handling ---
# Defines valid image extensions and provides a function to check them.
VALID_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff")
def is_valid_image(filename):
    return filename.lower().endswith(VALID_EXTENSIONS) and not filename.startswith("._")

# --- üî¢ Versioning and Output Directory Setup ---
# Automatically determines the next version number (e.g., V31) and creates a timestamped output folder.
def get_next_version(base_dir):
    all_entries = glob.glob(os.path.join(base_dir, "V*_*"))
    existing = [os.path.basename(d) for d in all_entries if os.path.isdir(d)]
    versions = [
        int(d[1:].split("_")[0]) for d in existing
        if d.startswith("V") and "_" in d and d[1:].split("_")[0].isdigit()
    ]
    next_version = max(versions, default=0) + 1
    return f"V{next_version}"

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
VERSION = get_next_version(OUTPUT_ROOT_DIR)
VERSION_TAG = VERSION + "_" + timestamp
SAVE_DIR = os.path.join(OUTPUT_ROOT_DIR, VERSION_TAG)
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"üìÅ Output directory created: {SAVE_DIR}")

‚úÖ Dynamically loading latest checkpoint: V33_20251010_122749
üìÅ Output directory created: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V34_20251013_211825


In [4]:
# ----------------------------------------------------
# 2. Hierarchical Dataset Preparation
# ----------------------------------------------------
# This function organizes the original multi-class dataset into two separate
# folder structures required for the two-stage training process. It recursively
# searches through subdirectories (no matter how deep) and is smart enough to
# skip non-image files.
def prepare_hierarchical_datasets(base_path, output_path):
    
    stage1_path = os.path.join(output_path, "stage_1_relevance_dataset")
    stage2_path = os.path.join(output_path, "stage_2_emotion_dataset")

    print(f"üóÇÔ∏è Preparing hierarchical datasets at: {output_path}")

    # --- Create Stage 1 Dataset (Relevance Filter) ---
    print("\n--- Creating Stage 1 Dataset ---")
    irrelevant_dest = os.path.join(stage1_path, "0_irrelevant")
    relevant_dest = os.path.join(stage1_path, "1_relevant")
    os.makedirs(irrelevant_dest, exist_ok=True)
    os.makedirs(relevant_dest, exist_ok=True)

    # Copy irrelevant files recursively
    print("Processing 'irrelevant' classes...")
    for class_name in IRRELEVANT_CLASSES:
        src_dir = Path(os.path.join(base_path, class_name))
        if src_dir.is_dir():
            print(f"  Recursively copying from '{class_name}'...")
            # Here, rglob('*') finds every file in every sub-folder.
            for file_path in src_dir.rglob('*'):
                if file_path.is_file() and is_valid_image(file_path.name):
                    shutil.copy(file_path, irrelevant_dest)
        else:
            print(f"  ‚ö†Ô∏è Warning: Source directory not found for '{class_name}'")

    # Copy relevant files recursively
    print("Processing 'relevant' classes...")
    for class_name in RELEVANT_CLASSES:
        src_dir = Path(os.path.join(base_path, class_name))
        if src_dir.is_dir():
            print(f"  Recursively copying from '{class_name}'...")
            for file_path in src_dir.rglob('*'):
                if file_path.is_file() and is_valid_image(file_path.name):
                    shutil.copy(file_path, relevant_dest)
        else:
            print(f"  ‚ö†Ô∏è Warning: Source directory not found for '{class_name}'")

    # --- Create Stage 2 Dataset (Emotion Classifier) ---
    print("\n--- Creating Stage 2 Dataset ---")
    for class_name in RELEVANT_CLASSES:
        src_dir = Path(os.path.join(base_path, class_name))
        dest_dir = os.path.join(stage2_path, class_name)

        # Ensure destination is clean before copying
        if os.path.exists(dest_dir):
            shutil.rmtree(dest_dir)
        os.makedirs(dest_dir)

        if src_dir.is_dir():
            print(f"  Copying '{class_name}' to Stage 2 directory...")
            for file_path in src_dir.rglob('*'):
                 if file_path.is_file() and is_valid_image(file_path.name):
                    shutil.copy(file_path, dest_dir)
        else:
            print(f"  ‚ö†Ô∏è Warning: Source directory not found for '{class_name}'")

    print("\n‚úÖ Hierarchical dataset preparation complete.")
    return stage1_path, stage2_path

In [5]:
# -----------------------------------------------
# 3. Utility Functions & Custom Classes
# -----------------------------------------------

# --- Part A: Data Augmentation ---

# üì¶ Applies augmentations and processes images on-the-fly for each batch.
# This is a more robust approach than pre-processing the entire dataset.
class DataCollatorWithAugmentation:
    def __init__(self,
                 processor,
                 augment_dict=None,
                 base_augment=None,
                 # --- NEW: tensor-level erasing controls (applied after processor) ---
                 random_erasing_prob: float = 0.10,
                 random_erasing_scale = (0.02, 0.08),
                 skip_erasing_label_ids=None):
        
        """
        Args:
            processor: HF image processor that yields pixel_value tensors
            augment_dict: dict[int label_id -> PIL transform], class-specific
            base_augment: fallback PIL transform when class-specific not found
            random_erasing_prob: probability for applying tensor-level RandomErasing
            random_erasing_scale: area range for erasing region
            skip_erasing_label_ids: iterable of label ids to skip erasing for
        """
        self.processor = processor
        self.augment_dict = augment_dict or {}
        # Baseline augmentation for majority classes.
        self.base_augment = base_augment or T.Compose([T.Resize((224, 224))])

        # --- NEW: tensor-level RandomErasing (applied AFTER processor) ---
        # Keep None to disable; expects CHW tensors in [0,1]
        self.random_erasing = (
            T.RandomErasing(p=random_erasing_prob, scale=random_erasing_scale, value="random")
            if random_erasing_prob and random_erasing_prob > 0.0 else None
        )
                
        # --- NEW: define tensor <-> PIL helpers used in __call__ ---
        self.to_tensor = T.ToTensor()
        self.to_pil = T.ToPILImage()
        
        # Labels to skip erasing for (can be overridden when constructing the collator)
        self.skip_erasing_label_ids = set()  # e.g., {label2id_s2['sadness'], ...}
        
    def __call__(self, features):
        processed_images = []
        for x in features:
            label = x["label"]
            rgb_image = x["image"].convert("RGB")

            # 1) apply class-specific PIL pipeline if present; else base PIL pipeline
            pil_aug = self.augment_dict.get(label, self.base_augment)

            img = pil_aug(rgb_image)

            # ‚¨áÔ∏è INSERT THE NEW LINES HERE
            # --- Tensor-level RandomErasing ---
            img_t = self.to_tensor(img)                 # PIL ‚Üí Tensor [C,H,W]
            if self.random_erasing is not None and label not in self.skip_erasing_label_ids:
                img_t = self.random_erasing(img_t)      # RandomErasing on tensor
            img = self.to_pil(img_t)  
        
            processed_images.append(img)

        batch = self.processor(images=processed_images, return_tensors="pt")
        batch["labels"] = torch.tensor([x["label"] for x in features], dtype=torch.long)
        return batch

# --- normalize any image-like object to 3-channel RGB (PIL) ---
def _ensure_rgb(img):
    # If already PIL, force RGB mode
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    # Else coerce to array and expand grayscale to 3 channels
    arr = np.array(img)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    return Image.fromarray(arr.astype(np.uint8))


# --- Part B: Model & Training Components ---

# üèãÔ∏è Defines a custom Trainer that can use either a targeted loss function or class weights.
class CustomLossTrainer(Trainer):
    def __init__(self, *args, loss_fct=None, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fct = loss_fct
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        
        if self.loss_fct:
            # Stage 2 uses the custom targeted smoothing loss
            loss = self.loss_fct(logits, labels)
        else:
            # Stage 1 uses standard CrossEntropyLoss with class weights (all on CPU)
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits, labels)
            
        return (loss, outputs) if return_outputs else loss


# üîÑ Implements Cross-Entropy Loss with *Targeted* Label Smoothing.
# Smoothing is turned OFF for specified classes to encourage confident predictions. This is used for Stage 2.
class TargetedSmoothedCrossEntropyLoss(nn.Module):
    def __init__(self, smoothing=0.05, target_class_names=None, label2id_map=None, focal_gamma=None):
        super().__init__()
        self.smoothing = smoothing
        self.focal_gamma = focal_gamma  # NEW (None disables focal scaling)
        if target_class_names and label2id_map:
            self.target_class_ids = [label2id_map[name] for name in target_class_names]
        else:
            self.target_class_ids = []

    def forward(self, logits, target):
        num_classes = logits.size(1)
        with torch.no_grad():
            smooth_labels = torch.full_like(logits, self.smoothing / (num_classes - 1))
            smooth_labels.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
            if self.target_class_ids:
                target_mask = torch.isin(target, torch.tensor(self.target_class_ids, device=target.device))
                if target_mask.any():
                    sharp_labels = F.one_hot(target[target_mask], num_classes=num_classes).float()
                    smooth_labels[target_mask] = sharp_labels

        log_probs = F.log_softmax(logits, dim=1)
        ce_per_sample = -(smooth_labels * log_probs).sum(dim=1)

        # NEW: optional focal scaling
        if self.focal_gamma is not None and self.focal_gamma > 0:
            with torch.no_grad():
                probs = torch.softmax(logits, dim=1)
                pt = (probs * smooth_labels).sum(dim=1).clamp_min(1e-6)
            ce_per_sample = ((1 - pt) ** self.focal_gamma) * ce_per_sample

        return ce_per_sample.mean()

# ------------------------------------------------------------------------------
# Stage 1 loss function: focal-modulated cross-entropy (relevant-only)
#   - We keep class weights for imbalance handling.
#   - We add focal modulation ONLY when the ground truth is "relevant"
#     to emphasize difficult positives without exploding FP on easy negatives.
# ------------------------------------------------------------------------------
class RelevantFocalCrossEntropy(torch.nn.Module):
    def __init__(self, class_weights: torch.Tensor, gamma: float = 2.0, relevant_id: int = 1):
        """
        Args:
            class_weights: Tensor of per-class weights (size 2 for S1)
            gamma: focal exponent (higher -> more emphasis on hard examples)
            relevant_id: integer id for the 'relevant' class
        """
        super().__init__()
        self.ce = torch.nn.CrossEntropyLoss(weight=class_weights, reduction="none")
        self.gamma = gamma
        self.relevant_id = relevant_id

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Computes cross-entropy per-sample, then applies focal scaling only
        for samples whose target == 'relevant'. Non-relevant samples keep vanilla CE.
        """
        # base cross-entropy (per-sample)
        ce = self.ce(logits, targets)  # shape: [B]

        # compute p_t = softmax(logits)[range(B), targets]
        with torch.no_grad():
            probs = torch.softmax(logits, dim=-1)
            p_t = probs[torch.arange(probs.size(0)), targets]  # [B]

        # mask: 1 for relevant targets, 0 otherwise
        mask = (targets == self.relevant_id).float()

        # focal factor: (1 - p_t)^gamma for relevant samples; 1.0 for others
        focal = (1.0 - p_t).pow(self.gamma) * mask + (1.0 - mask)

        # mean reduced loss
        return (focal * ce).mean()


# --- Part C: Metrics & Evaluation ---

# üìä Computes metrics and generates a confusion matrix plot for each evaluation step.
def compute_metrics_with_confusion(eval_pred, label_names, stage_name=""):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    print(f"\nüìà Classification Report for {stage_name}:")
    report = classification_report(labels, preds, target_names=label_names, output_dict=True, zero_division=0)
    print(classification_report(labels, preds, target_names=label_names, zero_division=0))

    # Save raw logits/labels for later analysis like temperature scaling
    np.save(os.path.join(SAVE_DIR, f"logits_eval_{stage_name}_{VERSION}.npy"), logits)
    np.save(os.path.join(SAVE_DIR, f"labels_eval_{stage_name}_{VERSION}.npy"), labels)

    # --- Re-integrated from V28 ---
    # Save per-class F1/precision/recall/entropy to CSV (append per epoch)
    f1s = [report[name]["f1-score"] for name in label_names]
    recalls = [report[name]["recall"] for name in label_names]
    precisions = [report[name]["precision"] for name in label_names]

    # Entropy per class (sorted by entropy)
    softmax_probs = F.softmax(torch.tensor(logits), dim=-1)
    entropies = -torch.sum(softmax_probs * torch.log(softmax_probs + 1e-12), dim=-1)
    entropy_per_class = []
    for idx, class_name in enumerate(label_names):
        mask = (np.array(labels) == idx)
        if mask.any():
            class_entropy = entropies[mask].mean().item()
            entropy_per_class.append((class_name, class_entropy))
        else:
            entropy_per_class.append((class_name, 0.0))
    
    # Create a dictionary for entropies in the correct order for the CSV
    entropy_dict = dict(entropy_per_class)

    # CSV logging
    epoch_metrics_path = os.path.join(SAVE_DIR, f"per_class_metrics_{stage_name}.csv")
    # Access the trainer instance through its global-like availability during compute_metrics call
    active_trainer = trainer_s1 if stage_name == "Stage1" else trainer_s2
    epoch = getattr(active_trainer.state, "epoch", None)

    df_row = pd.DataFrame({
        "epoch": [epoch],
        **{f"f1_{n}": [f] for n, f in zip(label_names, f1s)},
        **{f"recall_{n}": [r] for n, r in zip(label_names, recalls)},
        **{f"precision_{n}": [p] for n, p in zip(label_names, precisions)},
        **{f"entropy_{n}": [entropy_dict[n]] for n in label_names}
    })
    
    if os.path.exists(epoch_metrics_path):
        df_row.to_csv(epoch_metrics_path, mode="a", header=False, index=False)
    else:
        df_row.to_csv(epoch_metrics_path, mode="w", header=True, index=False)
    # --- End Re-integration ---

    # Generate and save a heatmap of the confusion matrix
    cm = confusion_matrix(labels, preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=label_names, yticklabels=label_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix - {stage_name}")
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, f"confusion_matrix_{stage_name}_{VERSION}.png"))
    plt.close()

    # --- Re-integrated from V28 ---
    # Top confused pairs
    confusion_pairs = [
        ((label_names[i], label_names[j]), cm[i][j])
        for i in range(len(label_names))
        for j in range(len(label_names)) if i != j and cm[i][j] > 0
    ]
    top_confusions = sorted(confusion_pairs, key=lambda x: x[1], reverse=True)[:3]
    if top_confusions:
        print("\nTop 3 confused class pairs:")
        for (true_label, pred_label), count in top_confusions:
            print(f"  - {true_label} ‚Üí {pred_label}: {count} instances")

    # Compute and print entropy metrics
    avg_entropy = entropies.mean().item()
    print(f"\nüß† Avg prediction entropy: {avg_entropy:.4f}")

    sorted_entropy = sorted(entropy_per_class, key=lambda x: x[1], reverse=True)
    if sorted_entropy:
        print("\nüîç Class entropies (sorted):")
        for class_name, entropy in sorted_entropy:
            print(f"  - {class_name}: entropy = {entropy:.4f}")
    # --- End Re-integration ---
    
    accuracy = (preds == labels).mean()
    return {"accuracy": accuracy}


# ------------------------------------------------------------------------------
# Stage 1: Temperature scaling + threshold (œÑ) sweep
#   - Fit a single scalar T on eval logits (minimize NLL) to calibrate probabilities.
#   - Sweep œÑ in [0.30, 0.55] to pick the value that maximizes F1(relevant).
#   - Persist T and œÑ for hierarchical inference.
# ------------------------------------------------------------------------------
def fit_temperature(model, eval_ds, processor, device):
    """
    Fits a single temperature scalar T by minimizing NLL on eval set.
    Returns:
        float: learned temperature T (>= ~1e-3)
    """
    model.eval()
    logits_list, labels_list = [], []
    with torch.no_grad():
        #Normalize every eval image to 3-channel RGB in fit_temperature
        for ex in eval_ds:
            img, lab = ex["image"], int(ex["label"])
        
            # --- Ensure 3-channel RGB for the processor ---
            # If PIL: convert directly; if numpy/other: coerce to array and expand gray to 3-channels
            if isinstance(img, Image.Image):
                img = img.convert("RGB")
            else:
                arr = np.array(img)
                if arr.ndim == 2:                      # grayscale -> stack to RGB
                    arr = np.stack([arr, arr, arr], axis=-1)
                img = Image.fromarray(arr.astype(np.uint8))  # ensure PIL RGB
        
            inputs = processor(images=img, return_tensors="pt").to(device)
            logits = model(**inputs).logits
            logits_list.append(logits.cpu())
            labels_list.append(lab)

    logits = torch.cat(logits_list, dim=0)  # [N, 2]
    labels = torch.tensor(labels_list)

    T = torch.nn.Parameter(torch.ones(1))
    opt = torch.optim.LBFGS([T], lr=0.1, max_iter=50)
    ce = torch.nn.CrossEntropyLoss()

    def _closure():
        """
        LBFGS closure for temperature scaling:
        Scales logits by 1/T, computes CE loss, backprops to adjust T.
        """
        opt.zero_grad()
        scaled = logits / T.clamp(min=1e-3)
        loss = ce(scaled, labels)
        loss.backward()
        return loss

    opt.step(_closure)
    return float(T.data.item())

def sweep_tau(model, eval_ds, processor, device, T=1.0):
    """
    Sweeps œÑ (threshold on P(relevant)) over [0.30, 0.55] to maximize F1(relevant).
    Returns:
        dict: {'tau', 'f1', 'prec', 'rec'} with 3-decimal rounding for logging.
    """
    import numpy as np
    model.eval()
    y_true, y_prob = [], []
    with torch.no_grad():
        for ex in eval_ds:
            img, lab = ex["image"], int(ex["label"])
    
            # Normalize to 3-channel RGB to avoid ndim==2 errors
            img = _ensure_rgb(img)
    
            inputs = processor(images=img, return_tensors="pt").to(device)
            logits = model(**inputs).logits / max(T, 1e-3)  # robustness vs tiny T
            prob_rel = torch.softmax(logits, dim=-1)[0, label2id_s1['relevant']].item()
            y_true.append(lab == label2id_s1['relevant'])
            y_prob.append(prob_rel)
    
    y_true = np.array(y_true, dtype=bool)
    y_prob = np.array(y_prob, dtype=float)


    best = {"tau": None, "f1": -1.0, "prec": None, "rec": None}
    for tau in np.linspace(0.30, 0.55, 26):
        pred = (y_prob >= tau)
        tp = ((pred == 1) & (y_true == 1)).sum()
        fp = ((pred == 1) & (y_true == 0)).sum()
        fn = ((pred == 0) & (y_true == 1)).sum()
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1   = 2*prec*rec/(prec+rec) if (prec+rec) > 0 else 0.0
        if f1 > best["f1"]:
            best = {"tau": round(float(tau), 3),
                    "f1": round(float(f1), 3),
                    "prec": round(float(prec), 3),
                    "rec": round(float(rec), 3)}
    return best
    

# --- Part D: Model Saving ---

# üíæ Saves the model and its associated processor to a specified directory.
def save_model_and_processor(model, processor, save_dir, model_name):
    print(f"üíæ Saving {model_name} and processor to: {save_dir}")
    model_path = os.path.join(save_dir, model_name)
    os.makedirs(model_path, exist_ok=True)
    model = model.to("cpu")
    processor.save_pretrained(model_path)
    model.save_pretrained(model_path, safe_serialization=True)
    print(f"‚úÖ {model_name} saved successfully.")


# --- Part E: Post-Training Analysis ---
# ==========================================================================
#   POST-TRAINING ANALYSIS UTILITIES (OFFLINE / OPTIONAL)
#   - Qualitative error bucketing (QE)
#   - Attention rollout (XAI) for S1 inspection
#   - Ablation helpers
# ==========================================================================

def check_deployment_readiness(metrics_csv_path, f1_threshold=0.80):
    """Analyzes the final metrics CSV to check for production readiness."""
    print("\n" + "="*60)
    print("  DEPLOYMENT READINESS CHECK")
    print("="*60)
    
    if not os.path.exists(metrics_csv_path):
        print(f"‚ö†Ô∏è Metrics file not found at: {metrics_csv_path}")
        return

    metrics_df = pd.read_csv(metrics_csv_path)
    last_epoch_metrics = metrics_df.iloc[-1]
    
    label_names = [col.replace("f1_", "") for col in metrics_df.columns if col.startswith("f1_")]
    
    print(f"Threshold: F1-Score >= {f1_threshold}\n")
    
    issues_found = False
    for label in label_names:
        f1_score = last_epoch_metrics.get(f"f1_{label}", 0)
        if f1_score < f1_threshold:
            print(f"  - ‚ùå {label:<15} | F1-Score: {f1_score:.2f} (Below Threshold)")
            issues_found = True
        else:
            print(f"  - ‚úÖ {label:<15} | F1-Score: {f1_score:.2f}")
            
    if issues_found:
        print("\n Model is NOT ready for production.")
    else:
        print("\n Model meets the minimum F1-score threshold for all classes.")

# --- Qualitative Error Bucketing (Stage 1) ---
# Scans an inference CSV and tags each row with simple visual heuristics:
# blur/shadow/occlusion/low-res. Outputs a QE report CSV for targeting data fixes.
def variance_of_laplacian(gray):
    return cv2.Laplacian(gray, cv2.CV_64F).var()

def is_dark(img_pil, thresh=40):
    stat = ImageStat.Stat(img_pil.convert("L"))
    return stat.mean[0] < thresh

def qualitative_buckets_s1(inference_csv, out_csv):
    import pandas as pd
    df = pd.read_csv(inference_csv)
    # consider only S1 mistakes if you logged them; otherwise filter low conf or S2 mismatches
    rows = []
    for _, r in df.iterrows():
        path = r['filepath']
        if not os.path.exists(path): continue
        img = Image.open(path).convert("RGB")
        arr = np.array(img)
        gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
        blur = variance_of_laplacian(gray) < 60         # motion blur proxy
        dark = is_dark(img, thresh=45)                  # shadows proxy
        lowres = min(img.size) < 80
        # Cheap occlusion proxy: large random erasing candidate on face area would help, but without faces we use entropy
        ent = cv2.calcHist([gray],[0],None,[256],[0,256]).flatten()
        ent = -np.sum((ent/ent.sum()+1e-9)*np.log2(ent/ent.sum()+1e-9))
        occl = ent < 4.5                                 # low entropy proxy
        rows.append([path, r.get('true_label','?'), r.get('predicted_label','?'), r.get('confidence',np.nan),
                     int(blur), int(dark), int(occl), int(lowres)])
    with open(out_csv, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["filepath","true","pred","conf","blur","shadow","occlusion","lowres"])
        w.writerows(rows)
    return out_csv

# --- Ablation summary utility for Stage 1 ---
# Summarizes precision/recall/F1 for S1 given (T, tau).
def summarize_s1(eval_ds, model, processor, device, T: float, tau: float):
    import numpy as np
    y_true, y_prob = [], []
    model.eval()
    with torch.no_grad():
        for ex in eval_ds:
            img, lab = ex["image"], int(ex["label"])
    
            # Normalize to 3-channel RGB to avoid ndim==2 errors
            img = _ensure_rgb(img)
    
            logits = model(**processor(images=img, return_tensors="pt").to(device)).logits
            logits = logits / max(T, 1e-3)
            p = torch.softmax(logits, dim=-1)[0, label2id_s1['relevant']].item()
            y_true.append(lab == label2id_s1['relevant'])
            y_prob.append(p)

    y_true = np.array(y_true, bool); y_prob = np.array(y_prob, float)
    pred = (y_prob >= tau)
    tp = ((pred==1)&(y_true==1)).sum(); fp=((pred==1)&(y_true==0)).sum(); fn=((pred==0)&(y_true==1)).sum()
    prec = tp/(tp+fp) if tp+fp>0 else 0.0
    rec  = tp/(tp+fn) if tp+fn>0 else 0.0
    f1   = 2*prec*rec/(prec+rec) if prec+rec>0 else 0.0
    return {"precision":round(prec,3), "recall":round(rec,3), "f1":round(f1,3), "tau":tau, "T":T}


# --- Attention Rollout heatmaps for ViT (offline) ---
def vit_attention_rollout(model, inputs, discard_ratio=0.9):
    # returns a [H,W] mask normalized 0..1; you can overlay it
    # (Implementation omitted for brevity; use a standard attention-rollout snippet for ViT)
    pass

In [6]:
# --------------------------
# 4. Main Training Script
# --------------------------

def main(device):
    # Make trainer objects accessible to metrics function
    global trainer_s1, trainer_s2
    
    # --- Sanity Check for Checkpoint Path ---
    if not os.path.exists(PRETRAINED_CHECKPOINT_PATH):
        raise FileNotFoundError(f"Fatal: Pretrained checkpoint not found at {PRETRAINED_CHECKPOINT_PATH}")

    # --- Define specific model paths from the latest checkpoint ---
    s1_checkpoint_path = os.path.join(PRETRAINED_CHECKPOINT_PATH, "relevance_filter_model")
    s2_checkpoint_path = os.path.join(PRETRAINED_CHECKPOINT_PATH, "emotion_classifier_model")

    # The device is now passed in, so the local definition is removed.
    print(f"\nüñ•Ô∏è Using device: {device}")

    # --- Step 0: Prepare Datasets ---
    # This function copies files into the required two-stage structure.
    # It only needs to be run once.
    prepared_data_path = os.path.join(OUTPUT_ROOT_DIR, "prepared_datasets")
    if PREPARE_DATASETS:
        stage1_dataset_path, stage2_dataset_path = prepare_hierarchical_datasets(BASE_DATASET_PATH, prepared_data_path)
    else:
        stage1_dataset_path = os.path.join(prepared_data_path, "stage_1_relevance_dataset")
        stage2_dataset_path = os.path.join(prepared_data_path, "stage_2_emotion_dataset")
        print("‚úÖ Skipping dataset preparation, using existing directories.")
    
    # # --- Set hardware device ---
    # # commented out due to present mps and pytorch incompatibilities
    # device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # print(f"\nüñ•Ô∏è Using device: {device}")

    # ==========================================================================
    #   STAGE 1: TRAIN RELEVANCE FILTER (BINARY CLASSIFIER)
    # ==========================================================================
    print("\n" + "="*60)
    print("  STAGE 1: TRAINING RELEVANCE FILTER (BINARY CLASSIFIER)")
    print("="*60)

    # --- Load Stage 1 data ---
    stage1_output_dir = os.path.join(SAVE_DIR, "stage_1_relevance_model_training")
    dataset_s1 = load_dataset("imagefolder", data_dir=stage1_dataset_path, split='train').train_test_split(test_size=0.2, seed=42)
    train_dataset_s1 = dataset_s1["train"]
    eval_dataset_s1 = dataset_s1["test"]
    print(f"Stage 1: {len(train_dataset_s1)} training samples, {len(eval_dataset_s1)} validation samples.")

    # --- Configure Stage 1 model ---
    # We load the base processor once.
    processor = AutoImageProcessor.from_pretrained(BASE_MODEL_NAME)
    # Load the pretrained checkpoint but replace the final layer (classifier head)
    # for our binary (2-label) task.
    model_s1 = ViTForImageClassification.from_pretrained(
        s1_checkpoint_path, # <-- Use the specific path for the Stage 1 model
        num_labels=2,
        label2id=label2id_s1,
        id2label=id2label_s1,
        ignore_mismatched_sizes=True
    ).to(device)

    # --- Handle Extreme Class Imbalance in Stage 1 with Class Weights ---
    # This is critical because the 'irrelevant' class is much larger than the 'relevant' class.
    class_weights_s1 = compute_class_weight('balanced', classes=np.unique(train_dataset_s1['label']), y=train_dataset_s1['label'])
    class_weights_s1 = torch.tensor(class_weights_s1, dtype=torch.float).to(device)
    print(f"‚öñÔ∏è Stage 1 Class Weights: {class_weights_s1}")

    # --- Define Early Stopping ---
    # Stops training if validation loss doesn't improve for 2 consecutive epochs
    early_stop_callback = EarlyStoppingCallback(
        early_stopping_patience=2,
        early_stopping_threshold=0.001
    )
    
    # --- Set up Stage 1 Trainer ---
    training_args_s1 = TrainingArguments(
        output_dir=stage1_output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        use_cpu=True,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=5,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_dir=os.path.join(stage1_output_dir, "logs"),
        logging_strategy="steps",
        logging_steps=50,
        remove_unused_columns=False,
    )

    # --- Set up Stage 1 Trainer ---
    # The complex discriminative learning rate and layer freezing strategy in 
        # V31 caused a severe performance drop. This change reverts Stage 1 to 
        # V30's simpler and more effective approach of using a single, uniform 
        # learning rate for the entire model, which is managed by the Hugging 
        # Face Trainer's default optimizer.
    training_args_s1.learning_rate = 3e-5 # Set learning rate directly
    
    loss_fct_s1 = RelevantFocalCrossEntropy(
        class_weights=class_weights_s1,   # <-- we KEEP and USE class_weights here
        gamma=2.0,
        relevant_id=label2id_s1['relevant']
    )

    strong_pos_aug = T.Compose([
        T.RandomResizedCrop(224, scale=(0.8, 1.0)),     # encourage scale & crop robustness
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.2, 0.2, 0.2, 0.05),             # moderate color/lighting jitter
        T.RandomPerspective(distortion_scale=0.05, p=0.2),
    ])

    # Map label id -> transform. 1 == relevant
    augment_map_s1 = { label2id_s1['relevant']: strong_pos_aug }

    # Use the flexible CustomLossTrainer, passing the class weights to it.
    # Apply stronger augmentation ONLY to the "relevant" class to expand coverage
        # near the decision boundary (lighting, small occlusions, slight perspective).
        # Keep "irrelevant" mild as before to avoid over-creating near-face artifacts.
    trainer_s1 = CustomLossTrainer(
        model=model_s1,
        args=training_args_s1,
        train_dataset=train_dataset_s1,
        eval_dataset=eval_dataset_s1,
        compute_metrics=partial(compute_metrics_with_confusion, 
                                label_names=list(id2label_s1.values()), 
                                stage_name="Stage1"),
        data_collator=DataCollatorWithAugmentation(
            processor=processor,
            augment_dict=augment_map_s1,                # your existing class‚ÜíPIL map
            random_erasing_prob=0.10,                   # enable erasing
            random_erasing_scale=(0.02, 0.08),
            skip_erasing_label_ids=[]                   # or [label2id_s1['relevant']] to skip
        ),
        loss_fct=loss_fct_s1,         # <-- NEW: custom loss uses class_weights + focal on relevant
        callbacks=[early_stop_callback]
    )

    # --- Train Stage 1 model ---
    print("üöÄ Starting Stage 1 training...")
    start_time_s1 = time.time() # Record start time
    trainer_s1.train()
    end_time_s1 = time.time()   # Record end time
    
    # Calculate and print the duration
    duration_s1 = end_time_s1 - start_time_s1
    print(f"‚åõ Stage 1 training took: {time.strftime('%H:%M:%S', time.gmtime(duration_s1))}")
    save_model_and_processor(trainer_s1.model, processor, SAVE_DIR, model_name="relevance_filter_model")
    print("\n‚úÖ Stage 1 Training Complete.")

    
    # ------------------------------------------------------------------------------
    # Stage 1: Temperature scaling + threshold (œÑ) sweep
    #   - Fit a single scalar T on eval logits (minimize NLL) to calibrate probabilities.
    #   - Sweep œÑ in [0.30, 0.55] to pick the value that maximizes F1(relevant).
    #   - Persist T and œÑ for hierarchical inference.
    # ------------------------------------------------------------------------------

    print("\nüß™ Calibrating Stage 1...")
    T_s1 = fit_temperature(trainer_s1.model, eval_dataset_s1, processor, device)
    best_s1 = sweep_tau(trainer_s1.model, eval_dataset_s1, processor, device, T=T_s1)
    print(f"‚úÖ S1 calibration done: T={T_s1:.3f} | best œÑ={best_s1['tau']} | F1={best_s1['f1']} (P={best_s1['prec']}, R={best_s1['rec']})")
    
    # Persist calibration for inference
    with open(os.path.join(SAVE_DIR, "stage1_calibration.json"), "w") as f:
        json.dump({"T": T_s1, "tau": best_s1["tau"]}, f)


    # ==========================================================================
    #   STAGE 2: TRAIN EMOTION CLASSIFIER (11-CLASS)
    # ==========================================================================
    print("\n" + "="*60)
    print(f"  STAGE 2: TRAINING EMOTION CLASSIFIER ({len(RELEVANT_CLASSES)}-CLASS)")
    print("="*60)

    # --- Load Stage 2 data ---
    stage2_output_dir = os.path.join(SAVE_DIR, "stage_2_emotion_model_training")
    dataset_s2 = load_dataset("imagefolder", data_dir=stage2_dataset_path, split='train').train_test_split(test_size=0.2, seed=42)
    train_dataset_s2 = dataset_s2["train"]
    eval_dataset_s2 = dataset_s2["test"]
    print(f"Stage 2: {len(train_dataset_s2)} training samples, {len(eval_dataset_s2)} validation samples.")
    print("Stage 2 Label Distribution (Train):", Counter(train_dataset_s2['label']))


    # --- Configure Stage 2 model ---
    # Load the pretrained checkpoint again, this time with a classifier head for our 11 emotion classes.
    model_s2 = ViTForImageClassification.from_pretrained(
        s2_checkpoint_path, # <-- Use the specific path for the Stage 2 model
        num_labels=len(RELEVANT_CLASSES),
        label2id=label2id_s2,
        id2label=id2label_s2,
        ignore_mismatched_sizes=True
    ).to(device)

    # --- Define Augmentation and Loss for Stage 2 ---
    # Apply stronger augmentation to the minority classes to help the model learn them better.
    minority_aug = T.Compose([
        RandAugment(num_ops=2, magnitude=11),  # was 9
        T.RandomResizedCrop(224, scale=(0.7, 1.0)),
        T.ColorJitter(0.3, 0.3, 0.3, 0.1),
    ])

    # The addition of 'sadness' and 'speech_action' to the heavy augmentation pipeline in V31 
        # was counterproductive, causing the F1-scores for these classes to collapse. 
        # This change reverts the list to the V30 definition, removing the aggressive 
        # augmentation from the classes it harmed.
    minority_classes_s2 = [label2id_s2[name] for name in ['disgust', 'questioning', 'contempt', 'fear']]
    minority_augment_map_s2 = {label_id: minority_aug for label_id in minority_classes_s2}

    # NEW: very mild, targeted aug ONLY for the weakest classes (no RandAugment)
    mild_aug = T.Compose([
        T.RandomResizedCrop(224, scale=(0.9, 1.0)),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.05, 0.05, 0.05, 0.02),
        T.RandomPerspective(distortion_scale=0.05, p=0.3),
    ])

    # --- NEW: targeted mild augmentation for fragile classes
    #     - Keep 'sadness' and 'speech_action' on very mild pipeline (no RandAug)
    #     - Extend to 'neutral_speech' to preserve subtle mouth/phoneme cues
    targeted_mild_classes = [
        label2id_s2['sadness'],
        label2id_s2['speech_action'],
        label2id_s2['neutral_speech']   # <-- NEW
    ]
    targeted_mild_map_s2 = {label_id: mild_aug for label_id in targeted_mild_classes}

    # MERGE: single mapping passed to the collator (class id -> transform)
    augment_dict = {**minority_augment_map_s2, **targeted_mild_map_s2}

    # Use the custom loss function to turn off label smoothing for historically difficult classes.
        # Turn OFF smoothing for the hardest classes (sharper targets) and apply mild focal emphasis
        # Stage 2 loss: slightly softer focal gamma for fragile classes
        # Reduces over-focus; improves probability calibration a bit.
    loss_fct_s2 = TargetedSmoothedCrossEntropyLoss(
        smoothing=0.05,
        target_class_names=['sadness', 'speech_action'],
        label2id_map=label2id_s2,
        focal_gamma=1.2   # <-- slightly softer focal
    )

    # --- Set up Stage 2 Trainer ---
    # Adding weight decay, cosine scheduler + warmup, grad accumulation improves stability 
        # (especially on CPU/small batch) without altering your high-level flow.
    training_args_s2 = TrainingArguments(
        output_dir=stage2_output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        use_cpu=True,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=6,                       # +1 epoch for minorities
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_dir=os.path.join(stage2_output_dir, "logs"),
        logging_strategy="epoch",
        remove_unused_columns=False,
        weight_decay=0.05,                        # NEW
        lr_scheduler_type="cosine",               # NEW
        warmup_ratio=0.10,                        # NEW
        gradient_accumulation_steps=2,            # NEW
    )

    # --- Set up Stage 2 Trainer ---
    # As with Stage 1, the complex fine-tuning strategy implemented in V31 failed. 
        # This change reverts the Stage 2 training process to V30's more effective 
        # uniform learning rate strategy to restore model performance.
    training_args_s2.learning_rate = 4e-5 # Set learning rate directly

    # skip erasing for fragile classes like sadness and neutral_speech
    fragile_ids = [
        label2id_s2['sadness'],
        label2id_s2['neutral_speech']
    ]
    # Use the CustomLossTrainer again, passing the targeted loss function.
    trainer_s2 = CustomLossTrainer(
        model=model_s2,
        args=training_args_s2,
        train_dataset=train_dataset_s2,
        eval_dataset=eval_dataset_s2,
        compute_metrics=partial(compute_metrics_with_confusion, 
                                label_names=RELEVANT_CLASSES, 
                                stage_name="Stage2"),
        data_collator=DataCollatorWithAugmentation(
            processor=processor,
            augment_dict=augment_dict,                  # your merged S2 map
            random_erasing_prob=0.10,
            random_erasing_scale=(0.02, 0.08),
            skip_erasing_label_ids=fragile_ids          # <-- skip erasing for fragile classes
        ),
        loss_fct=loss_fct_s2, # Pass custom loss function
        callbacks=[early_stop_callback] # Keep early stopping
    )

    # --- Train Stage 2 model ---
    print("üöÄ Starting Stage 2 training...")
    start_time_s2 = time.time() # Record start time
    trainer_s2.train()
    end_time_s2 = time.time()   # Record end time
    
    # Calculate and print the duration
    duration_s2 = end_time_s2 - start_time_s2
    print(f"‚åõ Stage 2 training took: {time.strftime('%H:%M:%S', time.gmtime(duration_s2))}")
    save_model_and_processor(trainer_s2.model, processor, SAVE_DIR, model_name="emotion_classifier_model")
    print("\n‚úÖ Stage 2 Training Complete.")
    print("\nüéâ Hierarchical Training Pipeline Finished Successfully.")

    
    # Return the trained models and processor to be used by analysis functions
    return trainer_s1.model, trainer_s2.model, processor

In [7]:
# ----------------------------------
# 5. Hierarchical Inference
# ----------------------------------
# This function defines the two-step prediction pipeline for new images.
# It first checks for relevance (Stage 1) and then classifies the emotion (Stage 2).

def hierarchical_predict(image_paths, model_s1, model_s2, processor, device, batch_size=32):
    results = []
    for i in tqdm(range(0, len(image_paths), batch_size), desc="üî¨ Running Hierarchical Inference"):
        batch_paths = image_paths[i:i+batch_size]
        images = []
        valid_paths = []
        for path in batch_paths:
            try:
                img = Image.open(path).convert("RGB")
                images.append(img)
                valid_paths.append(path)
            except Exception:
                continue

        if not images:
            continue

        inputs = processor(images=images, return_tensors="pt").to(device)
        
        # --- Stage 1 Prediction: 
        # apply learned T and œÑ; fall back safely if file missing)
        # --- Apply Stage-1 temperature scaling + threshold from saved calibration ---
        calib_path = os.path.join(SAVE_DIR, "stage1_calibration.json")
        T_s1, tau = 1.0, 0.45  # safe defaults
        if os.path.exists(calib_path):
            try:
                with open(calib_path, "r") as f:
                    _c = json.load(f)
                    T_s1 = float(_c.get("T", 1.0))
                    tau  = float(_c.get("tau", 0.45))
            except Exception:
                pass
        
        with torch.no_grad():
            logits_s1 = model_s1(**inputs).logits / max(T_s1, 1e-3)  # temperature scaling
            probs_s1 = F.softmax(logits_s1, dim=-1)
        
        # Create a mask of images that were classified as 'relevant'
            # Gate on calibrated œÑ
        relevant_mask = (probs_s1[:, label2id_s1['relevant']] >= tau)
        dev = logits_s1.device
        preds_s1 = torch.where(
            relevant_mask,
            torch.tensor(label2id_s1['relevant'], device=dev, dtype=torch.long),
            torch.tensor(label2id_s1['irrelevant'], device=dev, dtype=torch.long)
        )
        
        # --- Stage 2 Prediction (only on relevant images) ---
        if relevant_mask.any():
            # Filter the input tensors to only include the relevant images
            relevant_inputs = {k: v[relevant_mask] for k, v in inputs.items()}

            with torch.no_grad():
                logits_s2 = model_s2(**relevant_inputs).logits
                probs_s2 = F.softmax(logits_s2, dim=-1)
                confs_s2, preds_s2 = torch.max(probs_s2, dim=-1)

        # --- Aggregate Results ---
        # Loop through the original batch and assign the correct prediction
        s2_idx = 0
        for j in range(len(valid_paths)):
            if relevant_mask[j]:
                # If relevant, get the prediction from the Stage 2 model
                pred_label = id2label_s2[preds_s2[s2_idx].item()]
                confidence = confs_s2[s2_idx].item()
                s2_idx += 1
            else:
                # If not relevant, label it and stop
                pred_label = "irrelevant"
                confidence = torch.softmax(logits_s1[j], dim=-1)[preds_s1[j]].item()

            results.append({
                "image_path": valid_paths[j],
                "prediction": pred_label,
                "confidence": confidence
            })
    return results

In [16]:
# ==============================================================================
# 6. Post-Training Analysis, Review, and Curation
# ==============================================================================

def run_post_training_analysis(model_s1, model_s2, processor, device, base_dataset_path, save_dir, version):
    """
    Runs a full inference pass and generates logs for review, curation, and analysis.
    Combines logic from old sections 15 and 16.
    """
    import pandas as pd   # ensure pd is local; prevents UnboundLocalError in notebooks
    
    print("\n" + "="*60)
    print("  RUNNING POST-TRAINING ANALYSIS & CURATION WORKFLOW")
    print("="*60)

    # --- Part A: Run Hierarchical Inference on the Entire Dataset ---
    all_image_paths = [str(p) for p in Path(base_dataset_path).rglob("*") if is_valid_image(p.name)]
    print(f"Found {len(all_image_paths)} images to process for inference.")
    
    predictions = hierarchical_predict(all_image_paths, model_s1, model_s2, processor, device)
    df = pd.DataFrame(predictions)
    
    # Derive true label from path for analysis
    df['true_label'] = df['image_path'].apply(lambda p: Path(p).parent.name)

    # Save the full log
    full_log_path = os.path.join(save_dir, f"{version}_full_inference_log.csv")
    df.to_csv(full_log_path, index=False)
    print(f"\n‚úÖ Full inference log saved to: {full_log_path}")

    # --- Part B: Identify and Organize Images for Manual Review ---
    # Tag images with low confidence as "REVIEW"
    review_threshold = REVIEW_CONF_THRESHOLD
    review_df = df[df['confidence'] < review_threshold]
    
    review_sort_dir = os.path.join(save_dir, "review_candidates_by_predicted_class")
    os.makedirs(review_sort_dir, exist_ok=True)
    
    print(f"\nFound {len(review_df)} images below {review_threshold} confidence for review.")
    for _, row in tqdm(review_df.iterrows(), total=len(review_df), desc="Sorting review images"):
        dest_dir = os.path.join(review_sort_dir, row['prediction'])
        os.makedirs(dest_dir, exist_ok=True)
        shutil.copy(row['image_path'], dest_dir)
    print(f"üìÇ Sorted review images into folders at: {review_sort_dir}")

    # --- NEW: Generate shortlist and curated patch CSVs for THIS run ---
    #     - Shortlist: low-confidence items in focus classes (for targeted manual review)
    #     - Curated patch: template CSV for corrected labels to be fed back into VNext
    focus_classes = ['sadness','speech_action','neutral','neutral_speech','happiness']
    
    # Defensive: ensure the expected columns exist
    has_pred = 'prediction' in df.columns or 'predicted_label' in df.columns
    pred_col = 'prediction' if 'prediction' in df.columns else ('predicted_label' if 'predicted_label' in df.columns else None)
    if pred_col is not None:
        # Sort by confidence ascending (uncertain first)
        df_focus = df[df[pred_col].isin(focus_classes)].copy()
        if 'confidence' in df_focus.columns:
            df_focus = df_focus.sort_values('confidence', ascending=True)
    
        short_csv = os.path.join(save_dir, f"curation_shortlist_{version}.csv")
        patch_csv  = os.path.join(save_dir, f"curated_additions_{version}.csv")
    
        # Write shortlist with a stable set of columns
        keep_cols = [c for c in ['image_path','filepath','true_label',pred_col,'confidence'] if c in df_focus.columns]
        df_focus[keep_cols].to_csv(short_csv, index=False)
        print(f"‚úÖ Shortlist written: {short_csv}")
    
        # Create empty curated patch template
        src_path_col = 'image_path' if 'image_path' in df_focus.columns else 'filepath'
        patch_df = pd.DataFrame({
            "filepath": df_focus[src_path_col],
            "correct_label": "",
            "notes": ""
        })
        patch_df.to_csv(patch_csv, index=False)
        print(f"‚úÖ Curated patch template written: {patch_csv}")
    else:
        print("‚ÑπÔ∏è Skipped shortlist/patch CSVs: no predicted label column found in full log.")

    # --- NEW: Merge this run's shortlist/patch with V32 to create canonical merged artifacts ---
    def _merge_csvs(csv_list, key_cols, out_csv):
        import pandas as pd
        import os
    
        # Normalize common column name variants so we can dedupe safely
        def _normalize_cols(df: pd.DataFrame) -> pd.DataFrame:
            colmap = {}
            # path columns
            if "image_path" not in df.columns:
                if "filepath" in df.columns:
                    colmap["filepath"] = "image_path"
                elif "path" in df.columns:
                    colmap["path"] = "image_path"
            # predicted label columns
            if "predicted_label" not in df.columns:
                if "prediction" in df.columns:
                    colmap["prediction"] = "predicted_label"
                elif "predicted" in df.columns:
                    colmap["predicted"] = "predicted_label"
            return df.rename(columns=colmap)
    
        frames = []
        for p in csv_list:
            if os.path.exists(p):
                try:
                    df = pd.read_csv(p)
                    df = _normalize_cols(df)
                    frames.append(df)
                except Exception:
                    pass
    
        if not frames:
            return
    
        merged = pd.concat(frames, ignore_index=True)
    
        # Keep only keys that actually exist after normalization
        available_keys = [k for k in key_cols if k in merged.columns]
        if not available_keys:
            print(f"‚ÑπÔ∏è Skipped merge for {out_csv}: none of the key columns {key_cols} exist in merged data.")
            return
    
        merged = merged.drop_duplicates(subset=available_keys, keep="first")
        merged.to_csv(out_csv, index=False)
        print(f"‚úÖ Merged: {out_csv} ({len(merged)} rows)")

    
    # Paths for this run (already defined above)
    short_csv = os.path.join(save_dir, f"curation_shortlist_{version}.csv")
    patch_csv  = os.path.join(save_dir, f"curated_additions_{version}.csv")
    
    # V32 paths (if present)
    v32_short = os.path.join(save_dir, "curation_shortlist_V32.csv")
    v32_patch = os.path.join(save_dir, "curated_additions_V32.csv")
    
    # Canonical merged outputs
    short_merged = os.path.join(save_dir, "curation_shortlist_merged.csv")
    patch_merged = os.path.join(save_dir, "curated_additions_merged.csv")
    
    # Merge (shortlist merges on [filepath, predicted_label]; patch merges on [filepath])
    if pred_col is not None:
        # Figure out the filepath column available
        avail_path_cols = ['image_path','filepath']
        path_col = next((c for c in avail_path_cols if c in df.columns), None)
    
        if path_col is not None:
            _merge_csvs([v32_short, short_csv], key_cols=[path_col, pred_col], out_csv=short_merged)
            _merge_csvs([v32_patch, patch_csv], key_cols=[path_col], out_csv=patch_merged)
        else:
            print("‚ÑπÔ∏è Skipped merge: no filepath column present in full log.")
    else:
        print("‚ÑπÔ∏è Skipped merge: no predicted label column present in full log.")


    # --- Part C: Mine for "Hard Negative" Confusion Pairs (toggleable & robust) ---
    MINING_HARD_NEGATIVES = True  # ‚Üê set False for deployment runs

    if MINING_HARD_NEGATIVES:
        import pandas as pd
         
        # Prefer the freshly generated full log from THIS run; fallback to prior runs only if missing.
        inference_log_path = full_log_path
        if not os.path.exists(inference_log_path):
            v33_log = os.path.join(SAVE_DIR, "V33_full_inference_log.csv")
            v32_log = os.path.join(SAVE_DIR, "V32_full_inference_log.csv")
            inference_log_path = v33_log if os.path.exists(v33_log) else (v32_log if os.path.exists(v32_log) else None)

    
        if not os.path.exists(inference_log_path):
            print("‚è© Skipping hard-negative mining: no full inference log found.")
        else:
            print("\n‚õèÔ∏è  Mining for hard negative confusion pairs...")
            print(f"   using: {inference_log_path}")
            df = pd.read_csv(inference_log_path)
    
            # Normalize column names between runs (V32 used 'prediction', V33 uses 'predicted_label')
            cols = {c.lower(): c for c in df.columns}
            col_true = cols.get("true_label", "true_label")
            col_pred = cols.get("predicted_label") or cols.get("prediction")
            if col_pred is None:
                raise RuntimeError(f"Could not find predicted label column in {df.columns.tolist()}")
    
            # (Optional) keep a stable sort by confidence descending if available
            col_conf = cols.get("confidence")
            if col_conf:
                df = df.sort_values(col_conf, ascending=False)
    
            # Which pairs to mine
            confusion_pairs_to_mine = [
                ('contempt', 'questioning'),
                ('contempt', 'neutral'),
                ('fear', 'surprise')
            ]
    
            # Save to the current run folder
            save_dir = SAVE_DIR
    
            for c1, c2 in confusion_pairs_to_mine:
                mask = ((df[col_true] == c1) & (df[col_pred] == c2)) | \
                       ((df[col_true] == c2) & (df[col_pred] == c1))
                hard_negatives = df.loc[mask]
    
                if not hard_negatives.empty:
                    out_path = os.path.join(save_dir, f"hard_negatives_{c1}_vs_{c2}.csv")
                    hard_negatives.to_csv(out_path, index=False)
                    print(f"  - Found {len(hard_negatives)} hard negatives for ({c1} ‚Üî {c2}). Saved: {out_path}")
                else:
                    print(f"  - No hard negatives found for ({c1} ‚Üî {c2}).")
    else:
        print("‚è© Hard-negative mining disabled (set MINING_HARD_NEGATIVES=True to enable).")

In [9]:
# ==============================================================================
# 7. Model Calibration
# ==============================================================================

def apply_temperature_scaling(logits, labels):
    """Finds the optimal temperature for calibrating model confidence."""
    logits_tensor = torch.tensor(logits, dtype=torch.float32)
    labels_tensor = torch.tensor(labels, dtype=torch.long)

    class TemperatureScaler(nn.Module):
        def __init__(self):
            super().__init__()
            self.temperature = nn.Parameter(torch.ones(1) * 1.5)

        def forward(self, logits):
            return logits / self.temperature

    model = TemperatureScaler()
    optimizer = LBFGS([model.temperature], lr=0.01, max_iter=50)

    def eval_fn():
        optimizer.zero_grad()
        loss = F.cross_entropy(model(logits_tensor), labels_tensor)
        loss.backward()
        return loss

    optimizer.step(eval_fn)
    return model.temperature.item()

def plot_reliability_diagram(logits, labels, temperature, save_dir, version, stage_name):
    """Visualizes model calibration before and after temperature scaling."""
    logits = torch.from_numpy(logits)
    labels = torch.from_numpy(labels)
    
    # Calculate before
    probs_before = F.softmax(logits, dim=1)
    confs_before, _ = torch.max(probs_before, 1)
    
    # Calculate after
    probs_after = F.softmax(logits / temperature, dim=1)
    confs_after, _ = torch.max(probs_after, 1)

    # Plotting logic remains the same...
    # (For brevity, the detailed plotting code from your old script goes here)
    print(f"üìä Reliability diagram generation logic would go here.")

In [10]:
# ==============================================================================
# 8. Hierarchical Model Ensembling
# ==============================================================================

def hierarchical_ensemble_predict(image_path, processor, s1_models, s2_models, device):
    """Performs an ensembled prediction using multiple hierarchical models."""
    try:
        image = Image.open(image_path).convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(device)

    except Exception:
        return None, None

    # --- Stage 1 Ensemble (Majority Vote) ---
    s1_votes = []
    with torch.no_grad():
        for model in s1_models:
            logits = model(**inputs).logits
            pred = torch.argmax(logits, dim=-1).item()
            s1_votes.append(pred)
    
    # Decide relevance based on majority vote (1 = relevant)
    is_relevant = Counter(s1_votes).most_common(1)[0][0] == label2id_s1['relevant']

    if not is_relevant:
        return "irrelevant", None

    # --- Stage 2 Ensemble (Average Probabilities) ---
    s2_probs = []
    with torch.no_grad():
        for model in s2_models:
            logits = model(**inputs).logits
            probs = F.softmax(logits, dim=-1)
            s2_probs.append(probs)
            
    # Average the probabilities across all models
    avg_probs = torch.mean(torch.stack(s2_probs), dim=0)
    confidence, pred_idx = torch.max(avg_probs, dim=-1)
    
    final_prediction = id2label_s2[pred_idx.item()]
    final_confidence = confidence.item()
    
    return final_prediction, final_confidence

In [11]:
# # --- DEBUG sanity: verify that `main` is defined and callable ---
# try:
#     import inspect
#     print("[DEBUG] __name__:", __name__)
#     print("[DEBUG] has `main` in globals():", 'main' in globals())
#     if 'main' in globals():
#         print("[DEBUG] `main` callable:", callable(main))
#         print("[DEBUG] `main` defined at line:", inspect.getsourcelines(main)[1])
# except Exception as _e:
#     print("[DEBUG] sanity check error:", repr(_e))


# from transformers import AutoImageProcessor
# from datasets import load_dataset
# import os

# # 1) Recreate processor as in Stage-1
# processor = AutoImageProcessor.from_pretrained(BASE_MODEL_NAME)

# # 2) Point to prepared Stage-1 dataset (your script uses this when PREPARE_DATASETS=False)
# prepared_data_path = os.path.join(OUTPUT_ROOT_DIR, "prepared_datasets")
# stage1_dataset_path = os.path.join(prepared_data_path, "stage_1_relevance_dataset")

# ds_s1 = load_dataset("imagefolder", data_dir=stage1_dataset_path, split="train")
# sample = [ds_s1[i] for i in range(4)]

# # 3) Minimal Stage-1 augment map (or use your full augment_map_s1 if it‚Äôs available)
# try:
#     _ = augment_map_s1  # see if your map exists
# except NameError:
#     augment_map_s1 = {}  # fall back to empty (base_augment only)

# # 4) Collator and test
# coll = DataCollatorWithAugmentation(
#     processor=processor,
#     augment_dict=augment_map_s1,
#     random_erasing_prob=0.10,
#     random_erasing_scale=(0.02, 0.08),
#     skip_erasing_label_ids=[]
# )

# batch = coll(sample)
# print("pixel_values shape:", batch["pixel_values"].shape)  # (4, 3, 224, 224)
# print("labels shape:", batch["labels"].shape)              # (4,)

# from transformers import ViTForImageClassification, AutoImageProcessor
# from datasets import load_dataset
# import os, torch

# device = torch.device("cpu")

# # Paths mirror your script‚Äôs logic
# prepared_data_path = os.path.join(OUTPUT_ROOT_DIR, "prepared_datasets")
# stage1_dataset_path = os.path.join(prepared_data_path, "stage_1_relevance_dataset")
# s1_checkpoint_path = os.path.join(PRETRAINED_CHECKPOINT_PATH, "relevance_filter_model")

# # Recreate processor & model as Stage 1 does
# processor = AutoImageProcessor.from_pretrained(BASE_MODEL_NAME)
# model_s1 = ViTForImageClassification.from_pretrained(
#     s1_checkpoint_path,
#     num_labels=2,
#     label2id=label2id_s1,
#     id2label=id2label_s1,
#     ignore_mismatched_sizes=True
# ).to(device).eval()

# # Eval split
# dataset_s1 = load_dataset("imagefolder", data_dir=stage1_dataset_path, split="train").train_test_split(test_size=0.2, seed=42)
# eval_dataset_s1 = dataset_s1["test"]

# print("\nüß™ Calibrating Stage 1 (standalone)‚Ä¶")
# T_hat = fit_temperature(model_s1, eval_dataset_s1, processor, device)
# _ = sweep_tau(model_s1, eval_dataset_s1, processor, device, T=T_hat)
# print("‚úÖ S1 calibration helpers completed.")

In [12]:
# ==============================================================================
# 9. Script Execution Entry Point
# ==============================================================================
if __name__ == "__main__":

    # Define the device once for the entire script run.
    device = torch.device("cpu")
    
    # --- Step 1: Execute Training Pipeline ---
    # The main function now returns the trained models and processor
    model_s1, model_s2, processor = main(device)
    
    # --- Step 2: Run Post-Training Analysis & Curation ---
    if RUN_INFERENCE:
        # This function runs the full inference pass and generates logs for review.
        # It uses the in-memory models returned from main().
        run_post_training_analysis(model_s1, model_s2, processor, device, BASE_DATASET_PATH, SAVE_DIR, VERSION)
    
    # --- Step 3: Run Final Model Checks ---
    # Check if the model is ready for "deployment" based on F1 scores
    stage2_metrics_path = os.path.join(SAVE_DIR, "per_class_metrics_Stage2.csv")
    check_deployment_readiness(stage2_metrics_path, f1_threshold=0.80)
    
    # --- Step 4: Calibrate the Stage 2 Model ---
    logits_s2_path = os.path.join(SAVE_DIR, f"logits_eval_Stage2_{VERSION}.npy")
    labels_s2_path = os.path.join(SAVE_DIR, f"labels_eval_Stage2_{VERSION}.npy")
    
    if os.path.exists(logits_s2_path) and os.path.exists(labels_s2_path):
        print("\n" + "="*60)
        print("  CALIBRATING STAGE 2 MODEL")
        print("="*60)
        logits_s2 = np.load(logits_s2_path)
        labels_s2 = np.load(labels_s2_path)
        
        optimal_temp = apply_temperature_scaling(logits_s2, labels_s2)
        print(f"‚úÖ Optimal temperature for Stage 2 model: {optimal_temp:.4f}")
        # plot_reliability_diagram(logits_s2, labels_s2, optimal_temp, SAVE_DIR, VERSION, "Stage2")
    else:
        print("‚ö†Ô∏è Skipping calibration, logits/labels files for Stage 2 not found.")

    # COME BACK LATER TO MAKE DYNAMIC AND AUTOMATED LOADING OF PATH
    # --- Step 5: (Hypothetical) Run Ensemble Analysis ---
    # Use the saved V32 artifacts as the "previous" models for ensembling
    v_prev_path = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V32_20251008_115114"
    
    if os.path.exists(v_prev_path):
        print("\n" + "="*60)
        print("  RUNNING HIERARCHICAL ENSEMBLE ANALYSIS (current + V32)")
        print("="*60)
        
        # Load the older V32 models for the ensemble
        s1_model_prev = AutoModelForImageClassification.from_pretrained(
            os.path.join(v_prev_path, "relevance_filter_model")
        ).to(device).eval()
        s2_model_prev = AutoModelForImageClassification.from_pretrained(
            os.path.join(v_prev_path, "emotion_classifier_model")
        ).to(device).eval()
        
        # Use the in-memory models from THIS run (e.g., V33 when you launch it)
        # Assumes you have model_s1 and model_s2 already defined in memory
        s1_models_ensemble = [model_s1, s1_model_prev]
        s2_models_ensemble = [model_s2, s2_model_prev]

        # NEW: auto-pick a real image from ANY non-empty predicted-class folder
        review_root = os.path.join(v_prev_path, "review_candidates_by_predicted_class")
        example_image_path = None
        if os.path.isdir(review_root):
            for cls in os.listdir(review_root):
                cls_dir = os.path.join(review_root, cls)
                if os.path.isdir(cls_dir):
                    imgs = [f for f in os.listdir(cls_dir) if f.lower().endswith((".jpg",".jpeg",".png",".tif",".tiff"))]
                    if imgs:
                        example_image_path = os.path.join(cls_dir, imgs[0])
                        break
    
        if example_image_path and os.path.exists(example_image_path):
            prediction, confidence = hierarchical_ensemble_predict(
                example_image_path, processor, s1_models_ensemble, s2_models_ensemble, device
            )
            print(f"Ensemble prediction for {Path(example_image_path).name}: {prediction} (Confidence: {confidence:.2f})")
        else:
            print("‚ÑπÔ∏è Skipping ensemble demo: no example image found under 'review_candidates_by_predicted_class'.")


üñ•Ô∏è Using device: cpu
‚úÖ Skipping dataset preparation, using existing directories.

  STAGE 1: TRAINING RELEVANCE FILTER (BINARY CLASSIFIER)


Resolving data files:   0%|          | 0/26881 [00:00<?, ?it/s]

Stage 1: 21504 training samples, 5377 validation samples.




‚öñÔ∏è Stage 1 Class Weights: tensor([0.6492, 2.1761])
üöÄ Starting Stage 1 training...


Epoch,Training Loss,Validation Loss,Accuracy
1,0.2138,0.230733,0.930816
2,0.0737,0.226137,0.953134
3,0.117,0.117083,0.96392
4,0.039,0.115449,0.966524
5,0.0589,0.119504,0.962247



üìà Classification Report for Stage1:
              precision    recall  f1-score   support

  irrelevant       0.93      0.99      0.96      4132
    relevant       0.94      0.75      0.83      1245

    accuracy                           0.93      5377
   macro avg       0.94      0.87      0.89      5377
weighted avg       0.93      0.93      0.93      5377


Top 3 confused class pairs:
  - relevant ‚Üí irrelevant: 316 instances
  - irrelevant ‚Üí relevant: 56 instances

üß† Avg prediction entropy: 0.1981

üîç Class entropies (sorted):
  - relevant: entropy = 0.4571
  - irrelevant: entropy = 0.1201

üìà Classification Report for Stage1:
              precision    recall  f1-score   support

  irrelevant       0.95      0.99      0.97      4132
    relevant       0.96      0.84      0.89      1245

    accuracy                           0.95      5377
   macro avg       0.95      0.91      0.93      5377
weighted avg       0.95      0.95      0.95      5377


Top 3 confused cla

Resolving data files:   0%|          | 0/6175 [00:00<?, ?it/s]

Stage 2: 4940 training samples, 1235 validation samples.
Stage 2 Label Distribution (Train): Counter({9: 1608, 4: 651, 8: 554, 5: 530, 0: 388, 6: 382, 1: 251, 3: 240, 10: 135, 7: 101, 2: 100})
üöÄ Starting Stage 2 training...


Epoch,Training Loss,Validation Loss,Accuracy
1,0.2097,0.262585,0.889069
2,0.1339,0.243608,0.897166
3,0.1025,0.232779,0.897166
4,0.0792,0.234519,0.902834
5,0.0558,0.22718,0.910931
6,0.0522,0.210232,0.918219



üìà Classification Report for Stage2:
                precision    recall  f1-score   support

         anger       0.81      1.00      0.89        85
      contempt       0.83      0.80      0.81        60
       disgust       1.00      0.69      0.82        26
          fear       0.88      0.93      0.90        71
     happiness       0.96      0.94      0.95       167
       neutral       0.96      0.96      0.96       135
   questioning       0.76      0.84      0.80        92
       sadness       0.76      0.47      0.58        40
      surprise       0.98      0.97      0.98       147
neutral_speech       0.90      0.86      0.88       381
 speech_action       0.63      0.84      0.72        31

      accuracy                           0.89      1235
     macro avg       0.86      0.85      0.85      1235
  weighted avg       0.89      0.89      0.89      1235


Top 3 confused class pairs:
  - sadness ‚Üí neutral_speech: 14 instances
  - neutral_speech ‚Üí questioning: 13 inst

üî¨ Running Hierarchical Inference: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 841/841 [24:47<00:00,  1.77s/it]


UnboundLocalError: local variable 'pd' referenced before assignment