In [None]:
#@title ## Cell 1: Setup, Dependencies & Downloads
#@markdown Run this cell to install necessary libraries and download models from Hugging Face.
#@markdown This might take a few minutes, especially the model downloads.

!pip install Pillow torch transformers numpy tqdm accelerate huggingface-hub --quiet

import os
import glob
import json
import time
import logging
from typing import List, Dict, Any, Tuple, Optional
import gc
import sys
from pathlib import Path
import io
import base64
import html as html_parser # Renamed to avoid conflict with IPython.display.HTML
import random

from huggingface_hub import snapshot_download, hf_hub_download
from IPython.display import HTML, display, clear_output

# --- Configure Paths ---
# Base directory for downloaded models and assets
HF_REPO_ID = "laion/Empathic-Insight-Face-Large"
LOCAL_REPO_PATH = Path("/content/Empathic-Insight-Face-Large")
MODEL_DIR_NAME = "." # Models are in the root of the downloaded repo
NEUTRAL_STATS_FILENAME = "neutral_stats_cache-_human-binary-big-mlps.json"
DEMO_IMAGES_DIR = Path("/content/demo_images")

# Create directories
LOCAL_REPO_PATH.mkdir(parents=True, exist_ok=True)
DEMO_IMAGES_DIR.mkdir(parents=True, exist_ok=True)

print(f"Downloading Hugging Face repository '{HF_REPO_ID}' to '{LOCAL_REPO_PATH}'...")
try:
    snapshot_download(
        repo_id=HF_REPO_ID,
        local_dir=LOCAL_REPO_PATH,
        local_dir_use_symlinks=False, # Important for Colab to copy files
        # Explicitly allow .pth and .json files, and any other necessary file types
        allow_patterns=["*.pth", "*.json", "*.jpeg", "*.jpg", "*.png", "*.md", ".gitattributes"],
        # You can also ignore LFS files if they are too large and you don't need them
        # For this model, .pth files are LFS, so we need them.
    )
    print("Repository download complete.")
except Exception as e:
    print(f"Error downloading repository: {e}")
    # Handle error, e.g., by raising it or exiting

# Set global paths based on downloads
MODEL_DIR = LOCAL_REPO_PATH / MODEL_DIR_NAME
NEUTRAL_STATS_CACHE_FILE = LOCAL_REPO_PATH / NEUTRAL_STATS_FILENAME

# This folder is used by get_neutral_statistics if the cache is invalid.
# Since we are relying on the provided cache, and the neutral .npy files are not in the HF repo,
# we'll set this to a placeholder. The script should prioritize the cache.
NEUTRAL_EMBEDDINGS_FOLDER = Path("/content/dummy_neutral_embeddings_folder") # Placeholder
if not NEUTRAL_EMBEDDINGS_FOLDER.exists():
    NEUTRAL_EMBEDDINGS_FOLDER.mkdir(parents=True, exist_ok=True)


print(f"\nDownloading demo images to '{DEMO_IMAGES_DIR}'...")
demo_image_files = ["1.jpeg", "c1.jpeg"] # From the HF repo
for img_file in demo_image_files:
    try:
        hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=img_file,
            local_dir=DEMO_IMAGES_DIR,
            local_dir_use_symlinks=False
        )
        print(f"Downloaded {img_file}")
    except Exception as e:
        print(f"Could not download demo image {img_file}: {e}")

print("\n--- Paths Configuration ---")
print(f"Model Directory: {MODEL_DIR}")
print(f"Neutral Stats Cache: {NEUTRAL_STATS_CACHE_FILE}")
print(f"Demo Images Directory: {DEMO_IMAGES_DIR}")
print(f"Using placeholder for Neutral Embeddings Folder (cache should be used): {NEUTRAL_EMBEDDINGS_FOLDER}")

# Check if critical files exist
if not MODEL_DIR.is_dir():
    print(f"FATAL ERROR: Model directory not found: {MODEL_DIR}")
elif not NEUTRAL_STATS_CACHE_FILE.exists():
    print(f"FATAL ERROR: Neutral stats cache file not found: {NEUTRAL_STATS_CACHE_FILE}")
else:
    print("\nSetup cell completed. Critical files seem to be in place.")

# Configure basic logging for the scripts
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)-7s] %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)] # Output to Colab cell
)
ImageFile = None # Placeholder, will be imported from PIL later

In [None]:
#@title ## Cell 2: Global Definitions, Model & Stats Loading
#@markdown This cell defines the MLP model, helper functions, and loads all necessary models (SigLIP, emotion classifiers) and statistics into memory. This can take a moment.

# --- Dependency Imports (deferred to ensure pip install finishes) ---
from PIL import Image, ImageFile, UnidentifiedImageError
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoProcessor
import numpy as np
from tqdm.notebook import tqdm # Use notebook version for better Colab display

ImageFile.LOAD_TRUNCATED_IMAGES = True # Let PIL decode truncated image files instead of raising an error

# --- Script Parameters (kept identical to the original training/inference scripts) ---
# SigLIP checkpoint whose image embeddings are fed to every emotion MLP.
SIGLIP_MODEL_CKPT: str = "google/siglip2-so400m-patch16-384"
# Embedding width; inference rejects embeddings whose shape is not (1, EMBEDDING_DIM).
EMBEDDING_DIM: int = 1152
# Hidden-layer widths of each MLP; must match the saved checkpoints (load_state_dict is strict).
MLP_HIDDEN_LAYERS: List[int] = [1024, 512, 256]
MLP_DROPOUT_RATES: List[float] = [0.2, 0.2, 0.2] # Used during training; inactive at inference because classifiers are put in eval() mode
MLP_OUTPUT_SIZE: int = 1 # Each MLP outputs a single scalar score

# Visualization parameters (from HTML script)
IMAGE_DISPLAY_WIDTH = 250 # Thumbnail width in pixels in the HTML report (also used as the CSS max-width)
JPEG_QUALITY = 70 # JPEG quality of the base64-embedded thumbnails
# Classifiers whose key contains any of these (case-insensitive) are still scored,
# but are left out of the softmax over emotions.
EXCLUDE_FROM_SOFTMAX_KEYWORDS: List[str] = ["valence", "arousal", "dominance", "vulnerability"]

# Image file suffixes searched for (lower- and upper-case variants are both globbed).
SUPPORTED_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif'}
MODEL_EXTENSION = '.pth' # Classifier checkpoint suffix (matched case-insensitively)
NPY_EXTENSION = '.npy' # For neutral embeddings, only needed if the stats cache cannot be used

# --- MODEL KEY TO JSON KEY MAPPING (from infer_save_emotions_json.py) ---
# Maps a classifier key (checkpoint file stem) to a "Category|Emotion" label.
# NOTE(review): not referenced in the cells visible here; presumably used when exporting
# predictions as JSON — confirm before relying on it.
MODEL_KEY_TO_JSON_KEY_MAP: Dict[str, str] = {
    "model_elation_best": "Positive High-Energy Emotions|Elation",
    "model_amusement_best": "Positive High-Energy Emotions|Amusement",
    "model_pleasure_ecstasy_best": "Positive High-Energy Emotions|Pleasure/Ecstasy",
    "model_anger_best": "Negative High-Energy Emotions|Anger",
    "model_fear_best": "Negative High-Energy Emotions|Fear",
    "model_distress_best": "Negative High-Energy Emotions|Distress",
    "model_fatigue_exhaustion_best": "Physical and Exhaustive States|Fatigue/Exhaustion",
    "model_helplessness_best": "Negative Low-Energy Emotions|Helplessness",
    "model_astonishment_surprise_best": "Positive High-Energy Emotions|Astonishment/Surprise",
    "model_hope_enthusiasm_optimism_best": "Positive High-Energy Emotions|Hope/Enthusiasm/Optimism",
    "model_dominance_best": "Extra Dimensions|Dominance",
    "model_concentration_best": "Cognitive States and Processes|Concentration",
    "model_impatience_and_irritability_best": "Negative High-Energy Emotions|Impatience and Irritability",
    "model_sadness_best": "Negative Low-Energy Emotions|Sadness",
    "model_emotional_numbness_best": "Negative Low-Energy Emotions|Emotional Numbness",
    "model_relief_best": "Positive Low-Energy Emotions|Relief",
    "model_triumph_best": "Positive High-Energy Emotions|Triumph",
    "model_awe_best": "Positive High-Energy Emotions|Awe",
    "model_intoxication_altered_states_of_consciousness_best": "Physical and Exhaustive States|Intoxication/Altered States of Consciousness",
    "model_jealousy_&_envy_best": "Negative Low-Energy Emotions|Jealousy & Envy",
    "model_pain_best": "Physical and Exhaustive States|Pain",
    "model_disgust_best": "Negative High-Energy Emotions|Disgust",
    "model_sourness_best": "Physical and Exhaustive States|Sourness",
    "model_valence_best": "Extra Dimensions|Valence",
    "model_embarrassment_best": "Negative Low-Energy Emotions|Embarrassment",
    "model_confusion_best": "Cognitive States and Processes|Confusion",
    "model_teasing_best": "Positive High-Energy Emotions|Teasing",
    "model_emotional_vulnerability_best": "Extra Dimensions|Emotional Vulnerability",
    "model_contentment_best": "Positive Low-Energy Emotions|Contentment",
    "model_arousal_best": "Extra Dimensions|Arousal",
    "model_contemplation_best": "Positive Low-Energy Emotions|Contemplation",
    "model_contempt_best": "Negative Low-Energy Emotions|Contempt",
    "model_pride_best": "Positive Low-Energy Emotions|Pride",
    "model_thankfulness_gratitude_best": "Positive Low-Energy Emotions|Thankfulness/Gratitude",
    "model_malevolence_malice_best": "Negative High-Energy Emotions|Malevolence/Malice",
    "model_shame_best": "Negative Low-Energy Emotions|Shame",
    "model_sexual_lust_best": "Longing & Lust|Sexual Lust",
    "model_disappointment_best": "Negative Low-Energy Emotions|Disappointment",
    "model_interest_best": "Positive High-Energy Emotions|Interest",
    "model_longing_best": "Longing & Lust|Longing",
    "model_affection_best": "Positive Low-Energy Emotions|Affection",
    "model_doubt_best": "Negative Low-Energy Emotions|Doubt",
    "model_infatuation_best": "Longing & Lust|Infatuation",
    "model_bitterness_best": "Negative Low-Energy Emotions|Bitterness"
}

# --- MLP Architecture (Identical in both scripts) ---
class MLP(nn.Module):
    """Feed-forward head mapping one image embedding to ``output_size`` emotion score(s).

    For each hidden width ``h_i`` with dropout rate ``p_i`` the block is::

        Linear(prev, h_i) -> ReLU -> Dropout(p_i)   # Dropout only added when p_i > 0

    followed by a final ``Linear(last_hidden, output_size)``. The order of modules
    in ``self.layers`` determines the state-dict keys (``layers.<index>.*``), so the
    configuration must match any checkpoint loaded into it.

    Args:
        input_size: Embedding dimensionality. Defaults to ``EMBEDDING_DIM``.
        hidden_layers_config: Hidden layer widths. ``None`` (default) uses
            ``MLP_HIDDEN_LAYERS``, the architecture of the released checkpoints.
        dropout_rates_config: Dropout rate per hidden layer. ``None`` (default) uses
            ``MLP_DROPOUT_RATES``. Hidden layers without a matching entry get no dropout.
        output_size: Number of outputs. Defaults to ``MLP_OUTPUT_SIZE``.
    """

    def __init__(self, input_size: int = EMBEDDING_DIM,
                 hidden_layers_config: Optional[List[int]] = None,
                 dropout_rates_config: Optional[List[float]] = None,
                 output_size: int = MLP_OUTPUT_SIZE):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        # Explicit configs override the module-level defaults; None keeps the fixed architecture.
        hidden_layers = list(MLP_HIDDEN_LAYERS if hidden_layers_config is None else hidden_layers_config)
        dropout_rates = list(MLP_DROPOUT_RATES if dropout_rates_config is None else dropout_rates_config)

        layers = []
        current_dim = input_size
        for i, hidden_dim in enumerate(hidden_layers):
            layers.append(nn.Linear(current_dim, hidden_dim))
            layers.append(nn.ReLU())
            # Zero / missing rates add no module, which keeps state-dict indices stable.
            if i < len(dropout_rates) and dropout_rates[i] > 0.0:
                layers.append(nn.Dropout(dropout_rates[i]))
            current_dim = hidden_dim
        layers.append(nn.Linear(current_dim, output_size))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x`` (last dimension must equal ``input_size``)."""
        return self.layers(x)

# --- Helper Functions ---
def normalized(a: np.ndarray, axis: int = -1, order: int = 2) -> np.ndarray:
    """Scale ``a`` to unit ``order``-norm along ``axis``.

    Slices whose norm is zero are divided by 1, i.e. returned unchanged instead of
    producing NaNs. The norm is re-expanded along ``axis`` before dividing, so a 1-D
    input comes back with shape ``(1, n)``.
    """
    arr = np.asarray(a)
    magnitudes = np.atleast_1d(np.linalg.norm(arr, ord=order, axis=axis))
    magnitudes[magnitudes == 0] = 1  # avoid division by zero for all-zero slices
    divisor = np.expand_dims(magnitudes, axis)
    return arr / divisor

# `get_neutral_statistics` from your infer_visualize_emotions_dynamic_softmax_meansub_v7_mlp.py
# (It's more robust in handling cache and recalculation)
def get_neutral_statistics(
    classifiers_dict: Dict[str, nn.Module], # Renamed from 'classifiers' to avoid conflict with global
    neutral_embeddings_folder_path: str, # Renamed for clarity
    cache_file_path_obj: Path, # Renamed for clarity
    device_obj: torch.device # Renamed for clarity
) -> Dict[str, Dict[str, float]]:
    """Return neutral-baseline statistics ``{model_key: {"mean": float, "std": float}}``.

    Resolution order:
      1. Read ``cache_file_path_obj`` (JSON). If it is well-formed and has an entry for
         every key of ``classifiers_dict``, return it unchanged.
      2. Otherwise score every ``NPY_EXTENSION`` file found (recursively) under
         ``neutral_embeddings_folder_path`` with every classifier, compute the mean and
         sample std (``ddof=1``) of the raw scores, and write the merged result back to
         the cache file.
      3. If nothing can be recalculated, keep the valid cached entries and give the
         remaining classifiers ``{"mean": 0.0, "std": 0.0}``. These placeholders are
         not written to the cache.

    Args:
        classifiers_dict: model key -> classifier, called on a ``(1, EMBEDDING_DIM)`` tensor.
        neutral_embeddings_folder_path: folder searched recursively for neutral embeddings.
        cache_file_path_obj: JSON cache location (read; rewritten after a recalculation).
        device_obj: device the embedding tensors are moved to before scoring.

    Returns:
        A dict containing an entry for every key of ``classifiers_dict``. Extra keys
        present in the cache are kept.
    """
    neutral_stats: Dict[str, Dict[str, float]] = {}
    abs_neutral_folder = Path(neutral_embeddings_folder_path).resolve()

    # --- 1. Try the JSON cache ---
    if cache_file_path_obj.exists():
        try:
            with open(cache_file_path_obj, 'r', encoding='utf-8') as f:
                cached = json.load(f)
        except (ValueError, OSError, TypeError) as e:  # ValueError covers JSONDecodeError and UnicodeDecodeError
            logging.warning(f"Could not load or parse cache file {cache_file_path_obj}: {e}. Will attempt recalculation if folder exists.")
        else:
            logging.info(f"Successfully loaded neutral statistics from cache: {cache_file_path_obj}")
            is_valid = isinstance(cached, dict) and all(
                isinstance(v, dict) and 'mean' in v and 'std' in v for v in cached.values()
            )
            if not is_valid:
                logging.warning(f"Cache file {cache_file_path_obj} has unexpected format. Will attempt recalculation if folder exists.")
            else:
                # Keep the valid entries: they are reused if recalculation turns out to be impossible.
                neutral_stats = cached
                missing_models = set(classifiers_dict) - set(neutral_stats)
                if not missing_models:
                    logging.info("Neutral statistics loaded from cache and cover all currently loaded classifiers.")
                    return neutral_stats
                logging.warning(f"Cache is missing stats for models: {', '.join(sorted(missing_models))}. Will attempt recalculation for all if folder exists.")

    # --- 2. Recalculate from neutral embeddings ---
    logging.info("--- Calculating Neutral Statistics (Cache not used or incomplete/invalid) ---")
    neutral_npy_files = sorted(abs_neutral_folder.rglob(f'*{NPY_EXTENSION}')) if abs_neutral_folder.is_dir() else []
    if not neutral_npy_files:
        logging.error(f"Neutral embeddings folder '{abs_neutral_folder}' not found or is empty of '{NPY_EXTENSION}' files. CANNOT RECALCULATE STATS. RELYING ON CACHE IF IT LOADED PARTIALLY, OR USING ZEROES.")
        zero_filled = [key for key in classifiers_dict if key not in neutral_stats]
        for key in zero_filled:
            neutral_stats[key] = {"mean": 0.0, "std": 0.0}
        if zero_filled:
            logging.warning(f"Using zero mean/std for {len(zero_filled)} model(s) without cached stats: {', '.join(sorted(zero_filled))}")
        return neutral_stats

    logging.info(f"Found {len(neutral_npy_files)} neutral embedding files for recalculation.")

    neutral_scores: Dict[str, List[float]] = {key: [] for key in classifiers_dict}
    npy_load_errors = 0
    npy_processed_count = 0

    for npy_path in tqdm(neutral_npy_files, desc="Processing neutral embeddings for stats", unit="file"):
        try:
            # NOTE(review): scored as stored; presumably already L2-normalized like the
            # inference path's embeddings — confirm.
            embedding_norm_arr = np.load(npy_path)
            # Accept (EMBEDDING_DIM,) or (1, EMBEDDING_DIM); anything else is rejected.
            if embedding_norm_arr.ndim == 2 and embedding_norm_arr.shape[0] == 1:
                embedding_norm_arr = embedding_norm_arr.flatten()
            elif embedding_norm_arr.ndim != 1:
                raise ValueError(f"Unexpected embedding dimension: {embedding_norm_arr.ndim}")
            if embedding_norm_arr.shape[0] != EMBEDDING_DIM:
                raise ValueError(f"Embedding dim mismatch. Expected {EMBEDDING_DIM}, got {embedding_norm_arr.shape[0]}")

            embedding_tensor = torch.from_numpy(embedding_norm_arr).unsqueeze(0).to(device_obj).float()
            with torch.no_grad():
                # Score with every model first, so a failure part-way through cannot leave
                # the per-model score lists with different lengths.
                file_scores = {
                    key: float(model_instance(embedding_tensor).item())
                    for key, model_instance in classifiers_dict.items()
                }
            for key, score in file_scores.items():
                neutral_scores[key].append(score)
            npy_processed_count += 1
        except Exception as e:
            logging.error(f"Error processing neutral file {npy_path.name}: {type(e).__name__} - {e}")
            npy_load_errors += 1

    logging.info(f"Processed {npy_processed_count} neutral embeddings for stats. Encountered {npy_load_errors} errors.")
    if npy_processed_count == 0:
        logging.error("No neutral embeddings processed for stats. Keeping valid cached entries; remaining models use mean/std = 0 (cache not updated).")
        for key in classifiers_dict:
            neutral_stats.setdefault(key, {"mean": 0.0, "std": 0.0})
        return neutral_stats

    for internal_model_key, scores_list in neutral_scores.items():
        if scores_list:
            mean_score = float(np.mean(scores_list))
            std_dev_score = float(np.std(scores_list, ddof=1)) if len(scores_list) > 1 else 0.0
            neutral_stats[internal_model_key] = {"mean": mean_score, "std": std_dev_score}
            logging.info(f"  Recalculated for Model '{internal_model_key}': Mean={mean_score:.4f}, StdDev={std_dev_score:.4f} (from {len(scores_list)} samples)")
        elif internal_model_key not in neutral_stats:  # Not in cache and no scores calculated
            neutral_stats[internal_model_key] = {"mean": 0.0, "std": 0.0}
            logging.warning(f"  Model '{internal_model_key}': No neutral scores for recalc and not in cache. Mean/Std set to 0.")

    try:
        cache_file_path_obj.parent.mkdir(parents=True, exist_ok=True)
        with open(cache_file_path_obj, 'w', encoding='utf-8') as f:
            json.dump(neutral_stats, f, indent=4)
        logging.info(f"Saved/Updated neutral stats to cache: {cache_file_path_obj}")
    except Exception as e:
        logging.error(f"Error saving neutral stats cache: {e}")
    return neutral_stats


# --- Global Model Variables ---
# Module-level state shared by the cells below. Everything starts as None / empty and is
# populated by load_all_models_and_stats().
device: Optional[torch.device] = None  # device for the MLP classifiers and their inputs
siglip_device: Optional[torch.device] = None  # device SigLIP's weights actually landed on (can differ when loaded via accelerate)
siglip_model: Optional[AutoModel] = None
siglip_processor: Optional[AutoProcessor] = None
# classifier key -> MLP. Named so it does not clash with get_neutral_statistics' 'classifiers_dict' parameter.
emotion_classifiers: Dict[str, nn.Module] = {}
# classifier key -> {"mean": ..., "std": ...}
neutral_classifier_stats: Dict[str, Dict[str, float]] = {}

def load_all_models_and_stats():
    """Load SigLIP, every emotion-classifier MLP and their neutral statistics into module globals.

    Sets ``device``, ``siglip_device``, ``siglip_model`` and ``siglip_processor``. Fills
    ``emotion_classifiers`` (checkpoint file stem -> MLP in eval mode) and sets
    ``neutral_classifier_stats``.

    Errors are logged rather than raised so the notebook keeps running:
      * A SigLIP load failure is logged and classifier loading still proceeds.
      * A missing model directory, no checkpoint files, or no loadable checkpoint
        returns early and leaves ``neutral_classifier_stats`` untouched.

    Re-running the function rebuilds ``emotion_classifiers`` from scratch.
    """
    global device, siglip_model, siglip_processor, emotion_classifiers, neutral_classifier_stats, siglip_device

    logging.info("--- Initializing Device ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device} for MLPs and general processing.")

    logging.info("\n--- Loading SigLIP Model ---")
    try:
        try:
            import accelerate  # noqa: F401 -- only probing availability (needed for device_map="auto")
            accelerate_available = True
        except ImportError:
            accelerate_available = False
            logging.info("Accelerate library not found. SigLIP will load to single device.")

        if accelerate_available:
            try:
                siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_CKPT, device_map="auto").eval()
                # Inputs must be sent to wherever the weights actually ended up.
                siglip_device = next(siglip_model.parameters()).device
                logging.info(f"SigLIP loaded with device_map='auto' to {siglip_device}.")
            except Exception as e_accel:
                logging.warning(f"device_map='auto' failed ({e_accel}). Loading SigLIP to {device}.")
                siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_CKPT).to(device).eval()
                siglip_device = device
        else:
            siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_CKPT).to(device).eval()
            siglip_device = device

        if siglip_model is None: raise RuntimeError("SigLIP model loading failed.")
        siglip_processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_CKPT)
        logging.info("SigLIP base model and processor loaded successfully.")
    except Exception as e:
        # In Colab, log and let the user decide instead of exiting the kernel.
        logging.error(f"Fatal Error loading SigLIP model: {e}", exc_info=True)

    logging.info("\n--- Loading Emotion Classifier Models ---")
    # Start from an empty registry so a re-run never keeps classifiers from a previous run.
    emotion_classifiers.clear()
    if not MODEL_DIR.is_dir():
        logging.error(f"Model directory '{MODEL_DIR}' not found. Cannot load classifiers.")
        return

    # Sorted for a deterministic load (and dict) order.
    found_model_files = sorted(
        f for f in os.listdir(MODEL_DIR)
        if f.lower().endswith(MODEL_EXTENSION) and (MODEL_DIR / f).is_file()
    )
    if not found_model_files:
        logging.error(f"No '{MODEL_EXTENSION}' files found in {MODEL_DIR}. Cannot load classifiers.")
        return

    loaded_count = 0
    for model_filename in tqdm(found_model_files, desc="Loading MLP Classifiers"):
        model_path = MODEL_DIR / model_filename
        internal_model_key = model_path.stem # e.g., "model_elation_best"
        try:
            model = MLP().to(device) # Architecture is fixed
            # The checkpoints are plain state_dicts downloaded from the Hub; weights_only=True
            # refuses arbitrary pickled objects instead of executing them.
            state_dict = torch.load(model_path, map_location=device, weights_only=True)
            model.load_state_dict(state_dict)
            model.eval()
            emotion_classifiers[internal_model_key] = model
            loaded_count += 1
        except Exception as e:
            logging.error(f"  Error loading classifier '{internal_model_key}': {e}")

    if loaded_count == 0:
        logging.error("No emotion classifiers loaded successfully.")
        return
    logging.info(f"Loaded {loaded_count}/{len(found_model_files)} emotion classifiers.")

    logging.info("\n--- Loading/Calculating Neutral Statistics ---")
    if not NEUTRAL_STATS_CACHE_FILE.exists():
        logging.warning(f"Neutral stats cache file '{NEUTRAL_STATS_CACHE_FILE}' not found. Attempting recalculation if neutral embeddings folder is valid, otherwise using zeros.")
    # get_neutral_statistics uses the cache when it is valid, else recalculates or falls back to zeros.
    neutral_classifier_stats = get_neutral_statistics(
        emotion_classifiers, str(NEUTRAL_EMBEDDINGS_FOLDER), NEUTRAL_STATS_CACHE_FILE, device
    )

    if not neutral_classifier_stats: # If it's still empty after trying
        logging.warning("Failed to load or calculate neutral stats. Scores will use raw values or zero means.")
        neutral_classifier_stats = {key: {"mean": 0.0, "std": 0.0} for key in emotion_classifiers.keys()}

    logging.info("Global model and stats loading complete.")

# --- Execute Loading ---
load_all_models_and_stats()

In [None]:
#@title ## Cell 3: Generate HTML Visualization for Image Folder
#@markdown 1. Create a folder in Colab (e.g., `/content/my_images`).
#@markdown 2. Upload your images to this folder (or upload a ZIP and unzip it there).
#@markdown 3. Enter the **full path** to your image folder below.
#@markdown 4. Click the Run button for this cell.
#@markdown The HTML report will be displayed below the cell.

# Colab form parameters (keep each `#@param` on the same line as its assignment).
# Presumably passed to generate_html_report (call not shown here), which searches this folder
# recursively for files with a SUPPORTED_EXTENSIONS suffix.
IMAGE_FOLDER_FOR_HTML = "/content/demo_images" #@param {type:"string"}
# Maximum number of samples rendered; if more images are inferred successfully, a random subset
# of this size is shown (presumably passed as generate_html_report's max_samples_html).
MAX_SAMPLES_FOR_HTML = 50 #@param {type:"integer"}

# --- Visualization Helper Function (from infer_visualize_emotions_dynamic_softmax_meansub_v7_mlp.py) ---
def _flatten_to_rgb(img: "Image.Image") -> "Image.Image":
    """Return an RGB copy of ``img``, compositing any transparency onto a white background."""
    has_transparency = img.mode in ("RGBA", "LA", "PA", "RGBa", "La") or "transparency" in img.info
    if not has_transparency:
        # Covers RGB, palette ("P") without transparency, greyscale, CMYK, ...
        return img.convert("RGB")
    try:
        # Converting to RGBA also turns palette / colour-key transparency into an alpha channel.
        rgba = img.convert("RGBA")
        background = Image.new("RGB", rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.getchannel("A"))
        return background
    except Exception:  # Fallback: drop transparency rather than fail the whole thumbnail
        return img.convert("RGB")


def process_image_for_html(image_path: Path) -> str:
    """Render an image as an inline JPEG data URI for the HTML report.

    The image is flattened to RGB (transparency composited onto white), resized to
    ``IMAGE_DISPLAY_WIDTH`` pixels wide with its aspect ratio preserved, and
    JPEG-encoded at ``JPEG_QUALITY``.

    Args:
        image_path: Path of the source image (``None`` or a missing file is reported as an error).

    Returns:
        ``"data:image/jpeg;base64,..."`` on success. On failure, an HTML-escaped message
        starting with ``"Error"``; callers tell the two apart by the ``data:image`` prefix.
    """
    if not image_path or not image_path.exists():
        logging.error(f"Cannot process image for HTML: Original image path not found or invalid: {image_path}")
        return f"Error: Image not found at {html_parser.escape(str(image_path))}"
    try:
        with Image.open(image_path) as img:
            width, height = img.size
            if width <= 0 or height <= 0:
                raise ValueError(f"Image has invalid dimensions ({width}x{height})")

            aspect_ratio = height / width
            new_height = max(1, int(IMAGE_DISPLAY_WIDTH * aspect_ratio))

            # Flatten to RGB *before* resizing: Pillow resizes palette ("P") images with
            # NEAREST regardless of the requested filter, and palette transparency would
            # otherwise be lost by a later plain RGB conversion.
            rgb_img = _flatten_to_rgb(img)
            thumbnail = rgb_img.resize((IMAGE_DISPLAY_WIDTH, new_height), Image.Resampling.LANCZOS)

            buffer = io.BytesIO()
            thumbnail.save(buffer, format="JPEG", quality=JPEG_QUALITY)
        img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return f"data:image/jpeg;base64,{img_base64}"
    except (UnidentifiedImageError, ValueError, SyntaxError) as e:
        logging.error(f"Cannot identify or process image file for HTML: {image_path}. Error: {type(e).__name__} - {e}")
        return f"Error: Cannot process image: {html_parser.escape(image_path.name)} ({type(e).__name__})"
    except Exception as e:
        logging.error(f"Unexpected error processing image {image_path.name} for HTML: {e}", exc_info=True)
        return f"Error processing image {html_parser.escape(image_path.name)}: {html_parser.escape(str(e))}"

def generate_html_report(
    input_image_folder_str: str,
    max_samples_html: int,
    output_html_filename: str = "emotion_visualization.html" # Local temp file
):
    global device, siglip_model, siglip_processor, emotion_classifiers, neutral_classifier_stats, siglip_device # Use globals

    if not all([device, siglip_model, siglip_processor, emotion_classifiers]): # neutral_classifier_stats can be empty
        logging.error("Models not loaded. Please run Cell 2 first.")
        return None

    abs_image_folder = Path(input_image_folder_str).resolve()
    if not abs_image_folder.is_dir():
        logging.error(f"Image folder not found: {abs_image_folder}")
        return None

    logging.info(f"\n--- Finding Target Images in {abs_image_folder} ---")
    image_files_paths = []
    for ext in SUPPORTED_EXTENSIONS:
        image_files_paths.extend(abs_image_folder.rglob(f'*{ext}'))
        image_files_paths.extend(abs_image_folder.rglob(f'*{ext.upper()}')) # case-insensitive
    unique_image_files_paths = sorted(list(set(image_files_paths))) # Path objects
    total_images_found = len(unique_image_files_paths)
    logging.info(f"Found {total_images_found} unique image files.")
    if not unique_image_files_paths:
        logging.warning("No images found in the specified folder.")
        return None

    inference_results: List[Dict[str, Any]] = []
    processed_count, error_count = 0, 0
    logging.info("\n--- Running Inference for HTML Report ---")
    classifier_keys = list(emotion_classifiers.keys())
    exclude_keywords_lower = [kw.lower() for kw in EXCLUDE_FROM_SOFTMAX_KEYWORDS]

    with torch.no_grad():
        for img_path_obj in tqdm(unique_image_files_paths, desc="Inferring for HTML", unit="image"):
            pil_image = None
            try:
                pil_image = Image.open(img_path_obj).convert("RGB")
                inputs = siglip_processor(images=[pil_image], return_tensors="pt", padding="max_length", truncation=True).to(siglip_device)
                image_features = siglip_model.get_image_features(**inputs)
                embedding_norm = normalized(image_features.cpu().numpy())
                embedding_tensor = torch.from_numpy(embedding_norm).to(device).float()

                if embedding_tensor.shape != (1, EMBEDDING_DIM):
                    logging.warning(f"Embedding shape mismatch for '{img_path_obj.name}'. Skip."); error_count+=1; continue

                predictions: Dict[str, Dict[str, Optional[float]]] = {}
                models_for_softmax = []
                mean_subtracted_scores_list = []

                for model_key in classifier_keys:
                    model = emotion_classifiers[model_key]
                    raw_score = float(model(embedding_tensor).item())
                    mean_subtracted_score = raw_score # Default if no stats
                    model_neutral_stats = neutral_classifier_stats.get(model_key)
                    if model_neutral_stats and 'mean' in model_neutral_stats:
                        mean_subtracted_score = raw_score - model_neutral_stats['mean']

                    predictions[model_key] = {"raw": raw_score, "mean_subtracted": mean_subtracted_score, "softmax": None}
                    is_excluded = any(keyword in model_key.lower() for keyword in exclude_keywords_lower)
                    # Add to softmax list ONLY IF NOT EXCLUDED AND mean_subtracted_score is a valid number
                    if not is_excluded and isinstance(mean_subtracted_score, (int, float)):
                        models_for_softmax.append(model_key)
                        mean_subtracted_scores_list.append(mean_subtracted_score)

                if models_for_softmax: # Check if list is not empty
                    scores_tensor = torch.tensor(mean_subtracted_scores_list, dtype=torch.float32, device=device)
                    softmax_scores_tensor = F.softmax(scores_tensor, dim=0)
                    for i, model_key in enumerate(models_for_softmax):
                        predictions[model_key]["softmax"] = float(softmax_scores_tensor[i].item())

                inference_results.append({"image_path": img_path_obj, "predictions": predictions})
                processed_count += 1
            except Exception as e:
                logging.error(f"Error processing {img_path_obj.name} for HTML: {e}", exc_info=False)
                error_count += 1
            finally:
                if pil_image: pil_image.close()
                if (processed_count + error_count) % 50 == 0: # Run gc periodically
                    gc.collect()
                    if torch.cuda.is_available(): torch.cuda.empty_cache()

    logging.info(f"--- HTML Inference Complete. Processed: {processed_count}, Errors: {error_count} ---")
    if not inference_results:
        logging.warning("No successful inference results to visualize.")
        return None

    samples_to_render = inference_results
    if len(inference_results) > max_samples_html:
        logging.info(f"Sampling {max_samples_html} out of {len(inference_results)} for HTML.")
        samples_to_render = random.sample(inference_results, max_samples_html)
    try:
        # Sort by the filename part of the image path
        samples_to_render.sort(key=lambda s: s['image_path'].name if s.get('image_path') else "")
    except Exception: pass # ignore sort error if path is weird

    logging.info(f"Generating HTML content for {len(samples_to_render)} samples...")
    # HTML Structure (condensed for brevity, similar to your script)
    html_start = f"""<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8">
    <title>Emotion Prediction Visualization</title><style>
    body {{ font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif; margin: 20px; background-color: #f9fafb; color: #1f2937; }}
    .sample {{ background-color: #ffffff; border: 1px solid #d1d5db; margin-bottom: 25px; padding: 15px; border-radius: 8px; box-shadow: 0 1px 3px 0 rgba(0,0,0,0.1), 0 1px 2px -1px rgba(0,0,0,0.1); display: flex; gap: 20px; align-items: flex-start;}}
    .image-container {{ flex-shrink: 0; }}
    .image-container img {{ max-width: {IMAGE_DISPLAY_WIDTH}px; height: auto; border: 1px solid #e5e7eb; border-radius: 6px; display:block;}}
    .image-container p {{ font-size: 0.8em; color: #6b7280; word-break: break-all; margin-top: 5px;}}
    .predictions-container {{ flex-grow: 1; min-width: 300px; }}
    .dimension-grid {{ display: grid; grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); gap: 6px 10px; margin-top: 8px;}}
    .dimension-item {{ background-color: #f9fafb; padding: 5px 10px; border-radius: 4px; border: 1px solid #e5e7eb; font-size:0.85rem; display: flex; justify-content: space-between; align-items: center;}}
    .dimension-name {{ color: #374151; padding-right: 8px; }}
    .dimension-score {{ font-weight: 500; color: #111827; background-color: #e5e7eb; padding: 2px 6px; border-radius: 4px; font-family: Menlo, Monaco, Consolas, monospace; min-width: 140px; text-align:right;}}
    .mean-sub-score-paren {{ color: #4b5563; font-size: 0.95em; margin-left: 4px;}}
    .raw-score {{ color: #6b7280; font-size: 0.9em; margin-left: 4px;}}
    .highlight-top1 {{ border-left: 4px solid #DC2626; background-color: #FEE2E2 !important; }}
    .highlight-top1 .dimension-name {{ color: #991B1B; font-weight: 600; }}
    .highlight-top2 {{ border-left: 4px solid #D97706; background-color: #FEF3C7 !important; }}
    .highlight-top2 .dimension-name {{ color: #92400E; font-weight: 600; }}
    .highlight-top3 {{ border-left: 4px solid #059669; background-color: #D1FAE5 !important; }}
    .highlight-top3 .dimension-name {{ color: #047857; font-weight: 500; }}
    .score-na {{ color: #6b7280; font-style: italic; }}
    .error {{ color: #991b1b; background-color: #fee2e2; border: 1px solid #fecaca; padding: 8px; border-radius: 4px; }}
    h1 {{ text-align:center; color: #111827; border-bottom: 2px solid #e5e7eb; padding-bottom:10px; margin-bottom: 20px;}}
    h2, h3 {{ color: #1f2937; margin-top:0; }}
    h3 {{ font-size: 1.1em; margin-bottom: 10px;}}
    </style></head><body><h1>Emotion Prediction Visualization</h1>
    <p>Displaying {len(samples_to_render)} samples (out of {processed_count} successfully inferred). Scores: Softmax (Mean-Subtracted) (Raw). Highlighting: Top 3 Softmax.</p><hr>"""
    html_end = "</body></html>"
    html_body_parts = []

    for i, sample_data in enumerate(tqdm(samples_to_render, desc="Generating HTML parts")):
        image_path = sample_data.get('image_path')
        predictions = sample_data.get('predictions', {})
        img_display_name = image_path.name if image_path else "Unknown"

        img_data_uri_or_error = process_image_for_html(image_path)

        html_body_parts.append(f'<div class="sample">')
        html_body_parts.append('<div class="image-container">')
        if img_data_uri_or_error.startswith('data:image'):
            html_body_parts.append(f'<img src="{img_data_uri_or_error}" alt="{html_parser.escape(img_display_name)}">')
        else:
            html_body_parts.append(f'<div class="error">{img_data_uri_or_error}</div>')

        display_path_str = str(image_path.relative_to(Path("/content")) ) if image_path and image_path.is_relative_to(Path("/content")) else str(image_path)
        html_body_parts.append(f'<p>./{html_parser.escape(display_path_str)}</p></div>') # Display relative path from /content

        html_body_parts.append('<div class="predictions-container">')
        html_body_parts.append(f'<h3>{i+1}. {html_parser.escape(img_display_name)}</h3><div class="dimension-grid">')

        valid_softmax_scores = []
        for dim_key, score_data in predictions.items():
            softmax_val = score_data.get('softmax')
            if isinstance(softmax_val, (int, float)): # Check if it's a number
                valid_softmax_scores.append((dim_key, softmax_val))
        valid_softmax_scores.sort(key=lambda item: item[1], reverse=True)
        top_ranks = {model_key: rank for rank, (model_key, _) in enumerate(valid_softmax_scores[:3], 1)}

        sorted_dimension_keys = sorted(predictions.keys())
        for dim_key in sorted_dimension_keys:
            score_data = predictions.get(dim_key, {})
            raw_s = score_data.get('raw')
            mean_sub_s = score_data.get('mean_subtracted')
            softmax_s = score_data.get('softmax')
            is_excluded = any(keyword in dim_key.lower() for keyword in exclude_keywords_lower)

            score_str_parts = []
            na_class = ""

            # *** CORRECTED F-STRING LOGIC START ***
            if not is_excluded and isinstance(softmax_s, (int, float)):
                score_str_parts.append(f"{softmax_s:.3f}")
                mean_sub_text = f"{mean_sub_s:.2f}" if isinstance(mean_sub_s, (int,float)) else "N/A"
                score_str_parts.append(f'<span class="mean-sub-score-paren">({mean_sub_text})</span>')
                raw_text = f"{raw_s:.1f}" if isinstance(raw_s, (int,float)) else "N/A"
                score_str_parts.append(f'<span class="raw-score">({raw_text})</span>')
            else:
                if isinstance(mean_sub_s, (int, float)):
                    score_str_parts.append(f"{mean_sub_s:.2f}")
                else:
                    score_str_parts.append("N/A"); na_class = "score-na"
                raw_text = f"{raw_s:.1f}" if isinstance(raw_s, (int,float)) else "N/A"
                score_str_parts.append(f'<span class="raw-score">({raw_text})</span>')
            # *** CORRECTED F-STRING LOGIC END ***

            final_score_str = " ".join(score_str_parts)
            rank = top_ranks.get(dim_key)
            hl_class = ""
            if rank == 1: hl_class = "highlight-top1"
            elif rank == 2: hl_class = "highlight-top2"
            elif rank == 3: hl_class = "highlight-top3"

            html_body_parts.append(f'<div class="dimension-item {hl_class}"><span class="dimension-name">{html_parser.escape(dim_key)}</span> <span class="dimension-score {na_class}">{final_score_str}</span></div>')

        html_body_parts.append('</div></div></div>') # End grid, predictions-container, sample

    full_html = html_start + "\n".join(html_body_parts) + html_end
    # Save to a temporary file (optional, could just return string)
    # with open(output_html_filename, 'w', encoding='utf-8') as f:
    #     f.write(full_html)
    # logging.info(f"HTML report saved to {output_html_filename}")
    return full_html


# --- Run HTML Generation ---
if not Path(IMAGE_FOLDER_FOR_HTML).is_dir():
    print(f"ERROR: The specified IMAGE_FOLDER_FOR_HTML '{IMAGE_FOLDER_FOR_HTML}' does not exist or is not a directory.")
    print("Please create the folder, upload images, and re-run this cell with the correct path.")
else:
    print(f"Attempting to generate HTML report for images in: {IMAGE_FOLDER_FOR_HTML}")
    html_content = generate_html_report(IMAGE_FOLDER_FOR_HTML, MAX_SAMPLES_FOR_HTML)
    if html_content:
        clear_output(wait=True) # Clear previous output before displaying new HTML
        display(HTML(html_content))
        print(f"HTML report generated for images in '{IMAGE_FOLDER_FOR_HTML}'. Displayed above.")
    else:
        print(f"Failed to generate HTML report for '{IMAGE_FOLDER_FOR_HTML}'. Check logs above for errors.")

In [None]:
#@title ## Cell 4: Generate JSON Annotations for Image Folder
#@markdown 1. Ensure your images are in a folder in Colab (e.g., `/content/my_images` or `/content/demo_images`).
#@markdown 2. Enter the **full path** to this image folder below.
#@markdown 3. Click the Run button for this cell.
#@markdown `.json` annotation files (one per top-level image) will be saved in a subfolder inside the image folder (see `OUTPUT_JSON_SUBFOLDER_NAME`).

IMAGE_FOLDER_FOR_JSON = "/content/demo_images" #@param {type:"string"}
# Name of the subfolder created inside IMAGE_FOLDER_FOR_JSON that receives the .json files.
OUTPUT_JSON_SUBFOLDER_NAME = "HQ_emotion_inference_results_json_colab_demo" # Subfolder within input folder for JSONs

def _find_top_level_images(folder: Path, extensions) -> List[Path]:
    """Return the files directly inside ``folder`` whose names end with one of ``extensions``.

    Matching is case-insensitive (``.jpg``, ``.JPG`` and ``.Jpg`` all match) and
    subfolders are not descended into. The result is sorted.
    """
    suffixes = tuple(str(ext).lower() for ext in extensions)
    return sorted(
        entry for entry in folder.iterdir()
        if entry.is_file() and entry.name.lower().endswith(suffixes)
    )


def generate_json_annotations(
    input_image_folder_str: str,
    output_json_subfolder_name: str
):
    """Run every emotion classifier on each top-level image in a folder and save one JSON per image.

    Each file is written to ``<input folder>/<output_json_subfolder_name>/<image stem>.json`` with
    the shape ``{"key": [folder name, repo name, image filename], "value": {json_key: score}}``.
    ``score`` is the raw classifier output minus the classifier's neutral ``mean`` when neutral
    stats exist for it, otherwise the raw output. Classifiers with no entry in
    ``MODEL_KEY_TO_JSON_KEY_MAP`` are skipped with a warning.

    Args:
        input_image_folder_str: Folder whose top-level images are annotated (subfolders are not scanned).
        output_json_subfolder_name: Subfolder of the input folder (created if missing) for the JSON files.

    Returns:
        None. Problems are reported through ``logging``; a failing image is counted and skipped
        instead of aborting the run.
    """
    global device, siglip_model, siglip_processor, emotion_classifiers, neutral_classifier_stats, MODEL_KEY_TO_JSON_KEY_MAP, siglip_device # Use globals

    # Reading a global that Cell 2 never defined raises NameError; report it with the same
    # actionable message as "defined but empty" instead of crashing with a traceback.
    try:
        required_state = [device, siglip_model, siglip_processor, emotion_classifiers, MODEL_KEY_TO_JSON_KEY_MAP]
        inference_device = siglip_device
        # A None/empty stats object means "no neutral means": scores stay raw, matching the
        # per-classifier fallback below.
        neutral_stats = neutral_classifier_stats or {}
    except NameError:
        logging.error("Models or mappings not loaded. Please run Cell 2 first.")
        return
    if not all(required_state):
        logging.error("Models or mappings not loaded. Please run Cell 2 first.")
        return

    abs_image_folder = Path(input_image_folder_str).resolve()
    if not abs_image_folder.is_dir():
        logging.error(f"Image folder not found: {abs_image_folder}")
        return

    # JSONs go into a dedicated subfolder of the input folder so they do not mix with the images.
    abs_output_json_folder = abs_image_folder / output_json_subfolder_name
    abs_output_json_folder.mkdir(parents=True, exist_ok=True)
    logging.info(f"JSON annotations will be saved to: {abs_output_json_folder}")

    logging.info(f"\n--- Finding Target Images in {abs_image_folder} for JSON ---")
    # Non-recursive listing: only top-level images are annotated, so walking subfolders is wasted work.
    unique_image_files_paths = _find_top_level_images(abs_image_folder, SUPPORTED_EXTENSIONS)
    total_images_found = len(unique_image_files_paths)
    logging.info(f"Found {total_images_found} unique image files in the top-level of the specified folder.")
    if not unique_image_files_paths:
        logging.warning("No top-level images found in the specified folder for JSON annotation.")
        return
    if not neutral_stats:
        logging.warning("No neutral classifier stats loaded; JSON scores will be raw (not mean-subtracted).")

    processed_count, error_count, json_saved_count = 0, 0, 0
    logging.info("\n--- Running Inference and Saving JSON ---")
    internal_classifier_keys = list(emotion_classifiers.keys())
    # Stand-in for the original script's model_config_name in the output "key".
    model_tag = HF_REPO_ID.split('/')[-1]

    with torch.no_grad():
        for img_path_obj in tqdm(unique_image_files_paths, desc="Inferring/Saving JSON", unit="image"):
            pil_image = None
            try:
                # The context manager closes the source file; convert() returns an independent copy.
                with Image.open(img_path_obj) as source_image:
                    pil_image = source_image.convert("RGB")
                inputs = siglip_processor(images=[pil_image], return_tensors="pt", padding="max_length", truncation=True).to(inference_device)
                image_features = siglip_model.get_image_features(**inputs)
                embedding_norm = normalized(image_features.cpu().numpy())
                embedding_tensor = torch.from_numpy(embedding_norm).to(device).float()

                if embedding_tensor.shape != (1, EMBEDDING_DIM):
                    logging.warning(f"Embedding shape mismatch for '{img_path_obj.name}'. Skip.")
                    error_count += 1
                    continue

                predictions_output_dict: Dict[str, float] = {}
                for internal_model_key in internal_classifier_keys:
                    model_instance = emotion_classifiers[internal_model_key]
                    raw_score = float(model_instance(embedding_tensor).item())
                    final_score = raw_score
                    model_neutral_stats_dict = neutral_stats.get(internal_model_key)
                    if model_neutral_stats_dict and 'mean' in model_neutral_stats_dict:
                        final_score = raw_score - model_neutral_stats_dict['mean']

                    json_output_key = MODEL_KEY_TO_JSON_KEY_MAP.get(internal_model_key)
                    if json_output_key:
                        predictions_output_dict[json_output_key] = final_score
                    else:
                        logging.warning(f"No JSON mapping for '{internal_model_key}'. Score excluded for {img_path_obj.name}.")

                # "key" in output JSON: [image_folder_name, model_config_name, image_filename]
                json_key_list = [abs_image_folder.name, model_tag, img_path_obj.name]
                output_data = {"key": json_key_list, "value": predictions_output_dict}
                output_json_path = abs_output_json_folder / (img_path_obj.stem + ".json")

                try:
                    with open(output_json_path, 'w', encoding='utf-8') as f_json:
                        json.dump(output_data, f_json, indent=2)
                    json_saved_count += 1
                except Exception as e_json:
                    logging.error(f"Error saving JSON for {img_path_obj.name}: {e_json}")
                    error_count += 1
                processed_count += 1

                # Periodic cleanup, triggered exactly once per 100 processed images (not again on
                # every later failure, as happened when this check lived in `finally`).
                if processed_count % 100 == 0:
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

            except Exception as e:
                logging.error(f"Error processing {img_path_obj.name} for JSON: {type(e).__name__} - {e}", exc_info=False)
                error_count += 1
            finally:
                if pil_image is not None:
                    pil_image.close()

    logging.info(f"--- JSON Inference/Save Complete ---")
    logging.info(f"Processed: {processed_count}/{total_images_found} images.")
    logging.info(f"Saved: {json_saved_count} JSON files to '{abs_output_json_folder}'.")
    if error_count > 0:
        logging.warning(f"Errors for: {error_count} images/ops during JSON generation.")


# --- Run JSON Generation ---
# Validate the form value first; on success run the annotator, then print shell
# commands the user can paste into a new cell to inspect the results.
if Path(IMAGE_FOLDER_FOR_JSON).is_dir():
    json_results_dir = Path(IMAGE_FOLDER_FOR_JSON) / OUTPUT_JSON_SUBFOLDER_NAME
    print(f"Attempting to generate JSON annotations for images in: {IMAGE_FOLDER_FOR_JSON}")
    print(f"JSON files will be saved into a subfolder named '{OUTPUT_JSON_SUBFOLDER_NAME}' inside it.")
    generate_json_annotations(IMAGE_FOLDER_FOR_JSON, OUTPUT_JSON_SUBFOLDER_NAME)
    for hint_line in (
        f"\nJSON annotation generation finished for '{IMAGE_FOLDER_FOR_JSON}'.",
        f"Check the folder '{json_results_dir}' for the output files.",
        "\nExample of how to list generated JSONs (run in a new Colab cell):",
        f"!ls -l {json_results_dir}/*.json",
        "\nExample of how to view a JSON file (run in a new Colab cell, replace 'your_image_name.json'):",
        f"!cat {json_results_dir}/1.json",  # Example with demo image
    ):
        print(hint_line)
else:
    print(f"ERROR: The specified IMAGE_FOLDER_FOR_JSON '{IMAGE_FOLDER_FOR_JSON}' does not exist or is not a directory.")
    print("Please create the folder, upload images, and re-run this cell with the correct path.")