In [None]:
# ==============================================================================
#
#   V30_Hierarchical_FER_Model
#
#   Description:
#   This script implements the "Strategy A: Hierarchical Classification" pipeline
#   as specified in the project handover document. It refactors the V29 script
#   to support a two-stage training and inference process to handle extreme
#   class imbalance and improve real-world performance.
#
#   - Stage 1: Trains a binary "Relevance Filter" to classify images as either
#     'relevant' (an emotion/speech action) or 'irrelevant' (hard cases,
#     non-faces).
#
#   - Stage 2: Trains a fine-grained 11-class classifier on *only* the
#     'relevant' images to predict the final emotion label.

#V30 changes:
    # overview: Complete refactor to a two-stage hierarchical classification 
        # pipeline to solve the core class imbalance problem.
    # section #1 - Updated global configurations to define 'RELEVANT_CLASSES' 
        #(11 emotions/actions) and 'IRRELEVANT_CLASSES' ('hard_case') to 
        #drive the new two-stage process.
    # section #2 - Added a new, robust 'prepare_hierarchical_datasets' 
        #function. This function automatically reorganizes the source 
        #dataset into two separate structures for training and is designed 
        #to recursively search all sub-folders (no matter how deep) while 
        #skipping non-image files.
    # section #3 - Replaced the single training block with a full two-stage 
        #training pipeline:
    #   - Stage 1: Trains a binary 'Relevance Filter' model. Implemented 
        #class weighting in the loss function to handle the extreme imbalance 
        #from the 'hard_case' folder.
    #   - Stage 2: Trains the final 11-class 'Emotion Classifier' model only 
        #on the 'relevant' data, isolating it from noisy examples and 
        #allowing it to focus on subtle differences.
    # section #4 - Enhanced the 'CustomLossTrainer' to be more flexible, 
        #supporting either class weights (for Stage 1) or a targeted loss 
        #function (for Stage 2).
    # section #5 - Overhauled the entire inference process. Created a new 
        #'hierarchical_predict' function that loads both models and chains 
        #them: first checking relevance, then classifying the emotion.
    # section #6 - Removed the previous data balancing/oversampling logic 
        #(using MINORITY_CAP), as the hierarchical structure is a superior 
        #and more direct method for handling the class imbalance.
    # overview: This new architecture prevents the final emotion model from 
        #being biased by ambiguous or irrelevant images, aiming for 
        #significantly better real-world performance and generalization.
#
# ==============================================================================

In [1]:
# --------------------------
# 0. Imports
# --------------------------
# WORKAROUND for PyTorch MPS bug
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Standard Library Imports
import datasets
import csv
import gc
import glob
import multiprocessing as mp
import os
import random
import re
import shutil
import subprocess
import sys
import time

# Third-Party Imports
import accelerate
import dill
import face_recognition
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import transformers

# From Imports
from collections import Counter
from datasets import ClassLabel, Dataset, Features, Image as DatasetsImage, concatenate_datasets, load_dataset
from datetime import datetime
from functools import partial
from imagehash import phash, hex_to_hash
from io import BytesIO
from pathlib import Path
from PIL import Image, ImageOps, ExifTags, UnidentifiedImageError
from sklearn.metrics import classification_report, confusion_matrix, log_loss
from sklearn.utils.class_weight import compute_class_weight
from torch import nn
from torch.nn import functional as F
from torch.optim import AdamW, LBFGS
from torchvision import transforms
from torchvision.transforms import (
    RandAugment,
)
from tqdm import tqdm
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    EarlyStoppingCallback,
    TrainingArguments,
    Trainer,
    ViTForImageClassification,
)

In [2]:
# --------------------------
# 1. Global Configurations
# --------------------------

# --- üìÇ Core Paths ---
# This is the root directory containing your original 14-class dataset structure.
BASE_DATASET_PATH = "/Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalfaga_dataset_14_labels"
# This is the root directory where all outputs (models, logs, prepared datasets) will be saved.
OUTPUT_ROOT_DIR = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training"

# --- ‚öôÔ∏è Run Configuration ---
# Set to True to run the hierarchical inference pipeline on the full dataset after training is complete.
RUN_INFERENCE = True
# Set to True on the first run to copy and organize files. Set to False on subsequent runs to save time.
PREPARE_DATASETS = True

# --- ü§ñ Model Configuration ---
# The pretrained Vision Transformer model from Hugging Face to be used as a base.
BASE_MODEL_NAME = "google/vit-base-patch16-224-in21k"
# Path to a previous model checkpoint to start from (e.g., your V29 model).
# This allows the new models to leverage prior learning.
PRETRAINED_CHECKPOINT_PATH = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807"

# --- üè∑Ô∏è Dataset & Label Definitions ---
# These lists define the structure for the hierarchical pipeline.
# All folders listed here will be grouped into the 'relevant' class for Stage 1
# and used for training the final 11-class classifier in Stage 2.
RELEVANT_CLASSES = [
    'anger', 'contempt', 'disgust', 'fear', 'happiness',
    'neutral', 'questioning', 'sadness', 'surprise',
    'neutral_speech', 'speech_action'
]
# **IMPORTANT**: Since 'unknown' is a subfolder of 'hard_case', we only need to
# list 'hard_case' here. The script will find all images inside it recursively.
IRRELEVANT_CLASSES = ['hard_case']

# Mappings for the Stage 2 (11-class Emotion) model
id2label_s2 = dict(enumerate(RELEVANT_CLASSES))
label2id_s2 = {v: k for k, v in id2label_s2.items()}

# Mappings for the Stage 1 (binary Relevance) model
id2label_s1 = {0: 'irrelevant', 1: 'relevant'}
label2id_s1 = {v: k for k, v in id2label_s1.items()}

# --- üñºÔ∏è File Handling ---
# Defines valid image extensions and provides a function to check them.
VALID_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff")
def is_valid_image(filename):
    return filename.lower().endswith(VALID_EXTENSIONS) and not filename.startswith("._")

# --- üî¢ Versioning and Output Directory Setup ---
# Automatically determines the next version number (e.g., V31) and creates a timestamped output folder.
def get_next_version(base_dir):
    all_entries = glob.glob(os.path.join(base_dir, "V*_*"))
    existing = [os.path.basename(d) for d in all_entries if os.path.isdir(d)]
    versions = [
        int(d[1:].split("_")[0]) for d in existing
        if d.startswith("V") and "_" in d and d[1:].split("_")[0].isdigit()
    ]
    next_version = max(versions, default=0) + 1
    return f"V{next_version}"

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
VERSION = get_next_version(OUTPUT_ROOT_DIR)
VERSION_TAG = VERSION + "_" + timestamp
SAVE_DIR = os.path.join(OUTPUT_ROOT_DIR, VERSION_TAG)
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"üìÅ Output directory created: {SAVE_DIR}")

üìÅ Output directory created: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V30_20251007_075715


In [3]:
# ----------------------------------------------------
# 2. Hierarchical Dataset Preparation
# ----------------------------------------------------
# This function organizes the original multi-class dataset into two separate
# folder structures required for the two-stage training process. It recursively
# searches through subdirectories (no matter how deep) and is smart enough to
# skip non-image files.
def prepare_hierarchical_datasets(base_path, output_path):
    
    stage1_path = os.path.join(output_path, "stage_1_relevance_dataset")
    stage2_path = os.path.join(output_path, "stage_2_emotion_dataset")

    print(f"üóÇÔ∏è Preparing hierarchical datasets at: {output_path}")

    # --- Create Stage 1 Dataset (Relevance Filter) ---
    print("\n--- Creating Stage 1 Dataset ---")
    irrelevant_dest = os.path.join(stage1_path, "0_irrelevant")
    relevant_dest = os.path.join(stage1_path, "1_relevant")
    os.makedirs(irrelevant_dest, exist_ok=True)
    os.makedirs(relevant_dest, exist_ok=True)

    # Copy irrelevant files recursively
    print("Processing 'irrelevant' classes...")
    for class_name in IRRELEVANT_CLASSES:
        src_dir = Path(os.path.join(base_path, class_name))
        if src_dir.is_dir():
            print(f"  Recursively copying from '{class_name}'...")
            # Here, rglob('*') finds every file in every sub-folder.
            for file_path in src_dir.rglob('*'):
                if file_path.is_file() and is_valid_image(file_path.name):
                    shutil.copy(file_path, irrelevant_dest)
        else:
            print(f"  ‚ö†Ô∏è Warning: Source directory not found for '{class_name}'")

    # Copy relevant files recursively
    print("Processing 'relevant' classes...")
    for class_name in RELEVANT_CLASSES:
        src_dir = Path(os.path.join(base_path, class_name))
        if src_dir.is_dir():
            print(f"  Recursively copying from '{class_name}'...")
            for file_path in src_dir.rglob('*'):
                if file_path.is_file() and is_valid_image(file_path.name):
                    shutil.copy(file_path, relevant_dest)
        else:
            print(f"  ‚ö†Ô∏è Warning: Source directory not found for '{class_name}'")

    # --- Create Stage 2 Dataset (Emotion Classifier) ---
    print("\n--- Creating Stage 2 Dataset ---")
    for class_name in RELEVANT_CLASSES:
        src_dir = Path(os.path.join(base_path, class_name))
        dest_dir = os.path.join(stage2_path, class_name)

        # Ensure destination is clean before copying
        if os.path.exists(dest_dir):
            shutil.rmtree(dest_dir)
        os.makedirs(dest_dir)

        if src_dir.is_dir():
            print(f"  Copying '{class_name}' to Stage 2 directory...")
            for file_path in src_dir.rglob('*'):
                 if file_path.is_file() and is_valid_image(file_path.name):
                    shutil.copy(file_path, dest_dir)
        else:
            print(f"  ‚ö†Ô∏è Warning: Source directory not found for '{class_name}'")

    print("\n‚úÖ Hierarchical dataset preparation complete.")
    return stage1_path, stage2_path

In [4]:
# -----------------------------------------------
# 3. Utility Functions & Custom Classes
# -----------------------------------------------

# --- Part A: Data Augmentation ---

# üì¶ Applies augmentations and processes images on-the-fly for each batch.
# This is a more robust approach than pre-processing the entire dataset.
class DataCollatorWithAugmentation:
    def __init__(self, processor, augment_dict):
        self.processor = processor
        self.augment_dict = augment_dict
        # Baseline augmentation for majority classes.
        self.base_augment = T.Compose([
            T.RandomResizedCrop(size=(224, 224)), # <-- Use this instead of T.Resize
            T.RandomHorizontalFlip(),
            T.RandomRotation(10),
            T.ColorJitter(brightness=0.1, contrast=0.1)
        ])
    def __call__(self, features):
        processed_images = []
        for x in features:
            label = x["label"]
            # Select the correct augmentation pipeline, default to base_augment
            aug_pipeline = self.augment_dict.get(label, self.base_augment)
            rgb_image = x["image"].convert("RGB")
            augmented_image = aug_pipeline(rgb_image)
            processed_images.append(augmented_image)

        batch = self.processor(
            images=processed_images,
            return_tensors="pt"
        )
        batch["labels"] = torch.tensor([x["label"] for x in features], dtype=torch.long)
        return batch

# --- Part B: Model & Training Components ---

# üèãÔ∏è Defines a custom Trainer that can use either a targeted loss function or class weights.
class CustomLossTrainer(Trainer):
    def __init__(self, *args, loss_fct=None, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fct = loss_fct
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        
        if self.loss_fct:
            # Stage 2 uses the custom targeted smoothing loss
            loss = self.loss_fct(logits, labels)
        else:
            # Stage 1 uses standard CrossEntropyLoss with class weights (all on CPU)
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits, labels)
            
        return (loss, outputs) if return_outputs else loss


# üîÑ Implements Cross-Entropy Loss with *Targeted* Label Smoothing.
# Smoothing is turned OFF for specified classes to encourage confident predictions. This is used for Stage 2.
class TargetedSmoothedCrossEntropyLoss(nn.Module):
    def __init__(self, smoothing=0.05, target_class_names=None, label2id_map=None):
        super().__init__()
        self.smoothing = smoothing
        if target_class_names and label2id_map:
            self.target_class_ids = [label2id_map[name] for name in target_class_names]
        else:
            self.target_class_ids = []

    def forward(self, logits, target):
        num_classes = logits.size(1)
        with torch.no_grad():
            smooth_labels = torch.full_like(logits, self.smoothing / (num_classes - 1))
            smooth_labels.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)

            if self.target_class_ids:
                target_mask = torch.isin(target, torch.tensor(self.target_class_ids, device=target.device))
                if target_mask.any():
                    sharp_labels = F.one_hot(target[target_mask], num_classes=num_classes).float()
                    smooth_labels[target_mask] = sharp_labels

        log_probs = F.log_softmax(logits, dim=1)
        loss = -(smooth_labels * log_probs).sum(dim=1).mean()
        return loss

# --- Part C: Metrics & Evaluation ---

# üìä Computes metrics and generates a confusion matrix plot for each evaluation step.
def compute_metrics_with_confusion(eval_pred, label_names, stage_name=""):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    print(f"\nüìà Classification Report for {stage_name}:")
    report = classification_report(labels, preds, target_names=label_names, output_dict=True, zero_division=0)
    print(classification_report(labels, preds, target_names=label_names, zero_division=0))

    # Save raw logits/labels for later analysis like temperature scaling
    np.save(os.path.join(SAVE_DIR, f"logits_eval_{stage_name}_{VERSION}.npy"), logits)
    np.save(os.path.join(SAVE_DIR, f"labels_eval_{stage_name}_{VERSION}.npy"), labels)

    # Generate and save a heatmap of the confusion matrix
    cm = confusion_matrix(labels, preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=label_names, yticklabels=label_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix - {stage_name}")
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, f"confusion_matrix_{stage_name}_{VERSION}.png"))
    plt.close()

    accuracy = (preds == labels).mean()
    return {"accuracy": accuracy}

# --- Part D: Model Saving ---

# üíæ Saves the model and its associated processor to a specified directory.
def save_model_and_processor(model, processor, save_dir, model_name):
    print(f"üíæ Saving {model_name} and processor to: {save_dir}")
    model_path = os.path.join(save_dir, model_name)
    os.makedirs(model_path, exist_ok=True)
    model = model.to("cpu")
    processor.save_pretrained(model_path)
    model.save_pretrained(model_path, safe_serialization=True)
    print(f"‚úÖ {model_name} saved successfully.")

In [5]:
# --------------------------
# 4. Main Training Script
# --------------------------

def main():
    # --- Step 0: Prepare Datasets ---
    # This function copies files into the required two-stage structure.
    # It only needs to be run once.
    prepared_data_path = os.path.join(OUTPUT_ROOT_DIR, "prepared_datasets")
    if PREPARE_DATASETS:
        stage1_dataset_path, stage2_dataset_path = prepare_hierarchical_datasets(BASE_DATASET_PATH, prepared_data_path)
    else:
        stage1_dataset_path = os.path.join(prepared_data_path, "stage_1_relevance_dataset")
        stage2_dataset_path = os.path.join(prepared_data_path, "stage_2_emotion_dataset")
        print("‚úÖ Skipping dataset preparation, using existing directories.")

    # --- Set hardware device ---
    # WORKAROUND: Forcing CPU to bypass the persistent MPS backend bug.
    device = torch.device("cpu")
    print(f"\nüñ•Ô∏è Using device: {device} (Forced to bypass MPS bug)")
    
    # # --- Set hardware device ---
    # # commented out due to present mps and pytorch incompatibilities
    # device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # print(f"\nüñ•Ô∏è Using device: {device}")

    # ==========================================================================
    #   STAGE 1: TRAIN RELEVANCE FILTER (BINARY CLASSIFIER)
    # ==========================================================================
    print("\n" + "="*60)
    print("  STAGE 1: TRAINING RELEVANCE FILTER (BINARY CLASSIFIER)")
    print("="*60)

    # --- Load Stage 1 data ---
    stage1_output_dir = os.path.join(SAVE_DIR, "stage_1_relevance_model_training")
    dataset_s1 = load_dataset("imagefolder", data_dir=stage1_dataset_path, split='train').train_test_split(test_size=0.2, seed=42)
    train_dataset_s1 = dataset_s1["train"]
    eval_dataset_s1 = dataset_s1["test"]
    print(f"Stage 1: {len(train_dataset_s1)} training samples, {len(eval_dataset_s1)} validation samples.")

    # --- Configure Stage 1 model ---
    # We load the base processor once.
    processor = AutoImageProcessor.from_pretrained(BASE_MODEL_NAME)
    # Load the pretrained checkpoint but replace the final layer (classifier head)
    # for our binary (2-label) task.
    model_s1 = ViTForImageClassification.from_pretrained(
        PRETRAINED_CHECKPOINT_PATH,
        num_labels=2,
        label2id=label2id_s1,
        id2label=id2label_s1,
        ignore_mismatched_sizes=True
    ).to(device)

    # --- Handle Extreme Class Imbalance in Stage 1 with Class Weights ---
    # This is critical because the 'irrelevant' class is much larger than the 'relevant' class.
    class_weights_s1 = compute_class_weight('balanced', classes=np.unique(train_dataset_s1['label']), y=train_dataset_s1['label'])
    class_weights_s1 = torch.tensor(class_weights_s1, dtype=torch.float).to(device)
    print(f"‚öñÔ∏è Stage 1 Class Weights: {class_weights_s1}")
    
    # --- Set up Stage 1 Trainer ---
    training_args_s1 = TrainingArguments(
        output_dir=stage1_output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        use_cpu=True,
        learning_rate=3e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_dir=os.path.join(stage1_output_dir, "logs"),
        logging_strategy="steps",
        logging_steps=50,
        remove_unused_columns=False,
    )

    # Use the flexible CustomLossTrainer, passing the class weights to it.
    trainer_s1 = CustomLossTrainer(
        model=model_s1,
        args=training_args_s1,
        train_dataset=train_dataset_s1,
        eval_dataset=eval_dataset_s1,
        compute_metrics=partial(compute_metrics_with_confusion, label_names=list(id2label_s1.values()), stage_name="Stage1"),
        data_collator=DataCollatorWithAugmentation(processor=processor, augment_dict={}), # Use base augmentation for all
        class_weights=class_weights_s1 # Pass weights to the trainer
    )

    # --- Train Stage 1 model ---
    print("üöÄ Starting Stage 1 training...")
    start_time_s1 = time.time() # Record start time
    trainer_s1.train()
    end_time_s1 = time.time()   # Record end time
    
    # Calculate and print the duration
    duration_s1 = end_time_s1 - start_time_s1
    print(f"‚åõ Stage 1 training took: {time.strftime('%H:%M:%S', time.gmtime(duration_s1))}")
    save_model_and_processor(trainer_s1.model, processor, SAVE_DIR, model_name="relevance_filter_model")
    print("\n‚úÖ Stage 1 Training Complete.")

    # ==========================================================================
    #   STAGE 2: TRAIN EMOTION CLASSIFIER (11-CLASS)
    # ==========================================================================
    print("\n" + "="*60)
    print(f"  STAGE 2: TRAINING EMOTION CLASSIFIER ({len(RELEVANT_CLASSES)}-CLASS)")
    print("="*60)

    # --- Load Stage 2 data ---
    stage2_output_dir = os.path.join(SAVE_DIR, "stage_2_emotion_model_training")
    dataset_s2 = load_dataset("imagefolder", data_dir=stage2_dataset_path, split='train').train_test_split(test_size=0.2, seed=42)
    train_dataset_s2 = dataset_s2["train"]
    eval_dataset_s2 = dataset_s2["test"]
    print(f"Stage 2: {len(train_dataset_s2)} training samples, {len(eval_dataset_s2)} validation samples.")
    print("Stage 2 Label Distribution (Train):", Counter(sorted(train_dataset_s2['label'])))

    # --- Configure Stage 2 model ---
    # Load the pretrained checkpoint again, this time with a classifier head for our 11 emotion classes.
    model_s2 = ViTForImageClassification.from_pretrained(
        PRETRAINED_CHECKPOINT_PATH,
        num_labels=len(RELEVANT_CLASSES),
        label2id=label2id_s2,
        id2label=id2label_s2,
        ignore_mismatched_sizes=True
    ).to(device)

    # --- Define Augmentation and Loss for Stage 2 ---
    # Apply stronger augmentation to the minority classes to help the model learn them better.
    minority_aug = T.Compose([
        RandAugment(num_ops=2, magnitude=9),
        T.RandomResizedCrop(224, scale=(0.7, 1.0)),
        T.ColorJitter(0.3, 0.3, 0.3, 0.1),
    ])
    minority_classes_s2 = [label2id_s2[name] for name in ['disgust', 'questioning', 'contempt', 'fear']]
    minority_augment_map_s2 = {label_id: minority_aug for label_id in minority_classes_s2}

    # Use the custom loss function to turn off label smoothing for historically difficult classes.
    loss_fct_s2 = TargetedSmoothedCrossEntropyLoss(
        smoothing=0.05,
        target_class_names=['contempt', 'disgust'],
        label2id_map=label2id_s2
    )

    # --- Set up Stage 2 Trainer ---
    training_args_s2 = TrainingArguments(
        output_dir=stage2_output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        use_cpu=True, 
        learning_rate=4e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=5,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_dir=os.path.join(stage2_output_dir, "logs"),
        logging_strategy="epoch",
        remove_unused_columns=False,
    )

    # Use the CustomLossTrainer again, this time passing the targeted loss function.
    trainer_s2 = CustomLossTrainer(
        model=model_s2,
        args=training_args_s2,
        train_dataset=train_dataset_s2,
        eval_dataset=eval_dataset_s2,
        compute_metrics=partial(compute_metrics_with_confusion, label_names=RELEVANT_CLASSES, stage_name="Stage2"),
        data_collator=DataCollatorWithAugmentation(processor=processor, augment_dict=minority_augment_map_s2),
        loss_fct=loss_fct_s2 # Pass custom loss function
    )

    # --- Train Stage 2 model ---
    print("üöÄ Starting Stage 2 training...")
    start_time_s2 = time.time() # Record start time
    trainer_s2.train()
    end_time_s2 = time.time()   # Record end time
    
    # Calculate and print the duration
    duration_s2 = end_time_s2 - start_time_s2
    print(f"‚åõ Stage 2 training took: {time.strftime('%H:%M:%S', time.gmtime(duration_s2))}")
    save_model_and_processor(trainer_s2.model, processor, SAVE_DIR, model_name="emotion_classifier_model")
    print("\n‚úÖ Stage 2 Training Complete.")
    print("\nüéâ Hierarchical Training Pipeline Finished Successfully.")

In [6]:
# ----------------------------------
# 5. Hierarchical Inference
# ----------------------------------
# This function defines the two-step prediction pipeline for new images.
# It first checks for relevance (Stage 1) and then classifies the emotion (Stage 2).
def hierarchical_predict(image_paths, model_s1, model_s2, processor, device, batch_size=32):
    results = []
    for i in tqdm(range(0, len(image_paths), batch_size), desc="üî¨ Running Hierarchical Inference"):
        batch_paths = image_paths[i:i+batch_size]
        images = []
        valid_paths = []
        for path in batch_paths:
            try:
                img = Image.open(path).convert("RGB")
                images.append(img)
                valid_paths.append(path)
            except Exception:
                continue

        if not images:
            continue

        inputs = processor(images=images, return_tensors="pt").to(device)

        # --- Stage 1 Prediction: Is the image relevant? ---
        with torch.no_grad():
            logits_s1 = model_s1(**inputs).logits
            preds_s1 = torch.argmax(logits_s1, dim=-1)

        # Create a mask of images that were classified as 'relevant'
        relevant_mask = (preds_s1 == label2id_s1['relevant'])

        # --- Stage 2 Prediction (only on relevant images) ---
        if relevant_mask.any():
            # Filter the input tensors to only include the relevant images
            relevant_inputs = {k: v[relevant_mask] for k, v in inputs.items()}

            with torch.no_grad():
                logits_s2 = model_s2(**relevant_inputs).logits
                probs_s2 = F.softmax(logits_s2, dim=-1)
                confs_s2, preds_s2 = torch.max(probs_s2, dim=-1)

        # --- Aggregate Results ---
        # Loop through the original batch and assign the correct prediction
        s2_idx = 0
        for j in range(len(valid_paths)):
            if relevant_mask[j]:
                # If relevant, get the prediction from the Stage 2 model
                pred_label = id2label_s2[preds_s2[s2_idx].item()]
                confidence = confs_s2[s2_idx].item()
                s2_idx += 1
            else:
                # If not relevant, label it and stop
                pred_label = "irrelevant"
                confidence = torch.softmax(logits_s1[j], dim=-1)[preds_s1[j]].item()

            results.append({
                "image_path": valid_paths[j],
                "prediction": pred_label,
                "confidence": confidence
            })
    return results

In [7]:
# ----------------------------------
# 6. Script Execution Entry Point
# ----------------------------------
if __name__ == "__main__":
    # --- Execute Training ---
    main()

    # --- Execute Inference (if enabled) ---
    if RUN_INFERENCE:
        print("\n" + "="*60)
        print("  EXECUTING HIERARCHICAL INFERENCE ON FULL DATASET")
        print("="*60)

        # --- Load Both Trained Models ---
        print("Reloading trained models for inference...")
        # WORKAROUND: Forcing CPU to bypass the persistent MPS backend bug.
        device_inf = torch.device("cpu")

        # # commented out due to present mps and pytorch incompatibilities
        # print("Reloading trained models for inference...")
        # device_inf = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        model_s1_inf = AutoModelForImageClassification.from_pretrained(os.path.join(SAVE_DIR, "relevance_filter_model")).to(device_inf).eval()
        model_s2_inf = AutoModelForImageClassification.from_pretrained(os.path.join(SAVE_DIR, "emotion_classifier_model")).to(device_inf).eval()
        processor_inf = AutoImageProcessor.from_pretrained(os.path.join(SAVE_DIR, "relevance_filter_model"))
        print("‚úÖ Models loaded.")

        # --- Run Inference on the entire original dataset to test the pipeline ---
        all_image_paths = [str(p) for p in Path(BASE_DATASET_PATH).rglob("*") if is_valid_image(p.name)]
        print(f"Found {len(all_image_paths)} images to process for inference.")

        predictions = hierarchical_predict(all_image_paths, model_s1_inf, model_s2_inf, processor_inf, device_inf)

        # --- Save results to CSV for analysis ---
        df_preds = pd.DataFrame(predictions)
        output_csv_path = os.path.join(SAVE_DIR, f"{VERSION}_hierarchical_predictions.csv")
        df_preds.to_csv(output_csv_path, index=False)
        print(f"\n‚úÖ Inference complete. Results saved to: {output_csv_path}")
        print("\nPrediction distribution:")
        print(df_preds['prediction'].value_counts())

üóÇÔ∏è Preparing hierarchical datasets at: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/prepared_datasets

--- Creating Stage 1 Dataset ---
Processing 'irrelevant' classes...
  Recursively copying from 'hard_case'...
Processing 'relevant' classes...
  Recursively copying from 'anger'...
  Recursively copying from 'contempt'...
  Recursively copying from 'disgust'...
  Recursively copying from 'fear'...
  Recursively copying from 'happiness'...
  Recursively copying from 'neutral'...
  Recursively copying from 'questioning'...
  Recursively copying from 'sadness'...
  Recursively copying from 'surprise'...
  Recursively copying from 'neutral_speech'...
  Recursively copying from 'speech_action'...

--- Creating Stage 2 Dataset ---
  Copying 'anger' to Stage 2 directory...
  Copying 'contempt' to Stage 2 directory...
  Copying 'disgust' to Stage 2 directory...
  Copying 'fear' to Stage 2 directory...
  Copying 'happiness' to Stage 2 directory...
  Copying 'neutral'

Resolving data files:   0%|          | 0/26881 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Stage 1: 21504 training samples, 5377 validation samples.


Some weights of ViTForImageClassification were not initialized from the model checkpoint at /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([10]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([10, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


‚öñÔ∏è Stage 1 Class Weights: tensor([0.6492, 2.1761])
üöÄ Starting Stage 1 training...


Epoch,Training Loss,Validation Loss,Accuracy
1,0.5408,0.494553,0.743537
2,0.4731,0.479724,0.789102



üìà Classification Report for Stage1:
              precision    recall  f1-score   support

  irrelevant       0.91      0.74      0.82      4132
    relevant       0.47      0.77      0.58      1245

    accuracy                           0.74      5377
   macro avg       0.69      0.75      0.70      5377
weighted avg       0.81      0.74      0.76      5377


üìà Classification Report for Stage1:
              precision    recall  f1-score   support

  irrelevant       0.90      0.82      0.86      4132
    relevant       0.53      0.70      0.61      1245

    accuracy                           0.79      5377
   macro avg       0.72      0.76      0.73      5377
weighted avg       0.82      0.79      0.80      5377

‚åõ Stage 1 training took: 02:14:38
üíæ Saving relevance_filter_model and processor to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V30_20251007_075715
‚úÖ relevance_filter_model saved successfully.

‚úÖ Stage 1 Training Complete.

  STAGE 2:

Resolving data files:   0%|          | 0/6175 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Stage 2: 4940 training samples, 1235 validation samples.
Stage 2 Label Distribution (Train): Counter({9: 1608, 4: 651, 8: 554, 5: 530, 0: 388, 6: 382, 1: 251, 3: 240, 10: 135, 7: 101, 2: 100})


Some weights of ViTForImageClassification were not initialized from the model checkpoint at /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([10]) in the checkpoint and torch.Size([11]) in the model instantiated
- classifier.weight: found shape torch.Size([10, 768]) in the checkpoint and torch.Size([11, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


üöÄ Starting Stage 2 training...


Epoch,Training Loss,Validation Loss,Accuracy
1,1.0742,0.814775,0.808097
2,0.7787,0.846999,0.804858
3,0.7275,0.820377,0.79919
4,0.6739,0.763455,0.825911
5,0.6249,0.760402,0.832389



üìà Classification Report for Stage2:
                precision    recall  f1-score   support

         anger       0.83      0.75      0.79        85
      contempt       0.89      0.85      0.87        60
       disgust       0.91      0.77      0.83        26
          fear       0.93      0.92      0.92        71
     happiness       0.90      0.78      0.84       167
       neutral       0.70      0.79      0.74       135
   questioning       0.78      0.82      0.80        92
       sadness       0.67      0.05      0.09        40
      surprise       0.77      0.84      0.81       147
neutral_speech       0.79      0.91      0.85       381
 speech_action       0.81      0.42      0.55        31

      accuracy                           0.81      1235
     macro avg       0.82      0.72      0.74      1235
  weighted avg       0.81      0.81      0.80      1235


üìà Classification Report for Stage2:
                precision    recall  f1-score   support

         anger      

üî¨ Running Hierarchical Inference: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 841/841 [30:33<00:00,  2.18s/it]


‚úÖ Inference complete. Results saved to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V30_20251007_075715/V30_hierarchical_predictions.csv

Prediction distribution:
prediction
irrelevant        16820
neutral_speech     2455
neutral            1504
contempt           1374
anger              1240
surprise           1167
questioning         880
fear                756
happiness           268
disgust             227
speech_action       112
sadness              99
Name: count, dtype: int64



