In [None]:
# @title Videos with Top 3 Emotion Predictions (Code to make such Videos below in Cell 5)
# Cell 0: YouTube Video Previews

# --- How to Use This Cell ---
# 1. YouTube Video Links or IDs:
#    - Modify the `youtube_video_sources` list below.
#    - You can provide full YouTube video URLs (e.g., "https://www.youtube.com/watch?v=dQw4w9WgXcQ")
#    - Or just the YouTube Video ID (e.g., "dQw4w9WgXcQ")
#    - Or "youtu.be" short links (e.g., "https://youtu.be/dQw4w9WgXcQ")
#    - You can also include timestamps in the URL (e.g., "https://youtu.be/BUnfuiwE_IM?t=90"),
#      though the thumbnail will still be for the video itself, not the specific timestamp.
#
# 2. Thumbnail Quality (Optional):
#    - `THUMBNAIL_QUALITY` can be one of:
#      - "default": Standard quality (120x90)
#      - "mqdefault": Medium quality (320x180)
#      - "hqdefault": High quality (480x360)
#      - "sddefault": Standard definition (640x480) - may not exist for all videos
#      - "maxresdefault": Maximum resolution (1280x720 or 1920x1080) - may not exist for all videos
#
# 3. Run the Cell:
#    - Execute this cell to display the video thumbnails.
# --- End of How to Use ---

import re
from IPython.display import HTML, display
from typing import Optional, List # <<<<<<<<<<<< ADDED IMPORT HERE

# --- Configuration ---
# YouTube URLs or bare video IDs to preview. Every entry needs a trailing comma:
# without it Python silently concatenates adjacent string literals into one entry.
youtube_video_sources: List[str] = [
    "https://youtu.be/TsTVKCmqHhk",
    "https://www.youtube.com/watch?v=sErqFgL4vA8",
    "BUnfuiwE_IM",  # Just the ID
    "https://youtu.be/BUnfuiwE_IM?t=90",  # With timestamp
    "https://www.youtube.com/watch?v=dDrmjcUq8W4&t=74",  # Extra query params follow '&', not a second '?'
]

THUMBNAIL_QUALITY: str = "hqdefault"  # Options: default, mqdefault, hqdefault, sddefault, maxresdefault
# No per-row setting is needed: the flexbox container wraps thumbnails automatically.

# --- Helper Function to Extract Video ID ---
def extract_youtube_id(url_or_id: str) -> Optional[str]:
    """
    Extract the 11-character YouTube video ID from a URL, or accept a bare ID.

    Recognised forms (the scheme and host prefix are optional):
      - ``youtube.com/watch?v=ID``; ``v`` may appear anywhere in the query string,
        e.g. ``watch?feature=shared&v=ID``
      - ``youtu.be/ID``
      - ``youtube.com/embed/ID``, ``/v/ID``, ``/shorts/ID``, ``/live/ID``
    Trailing parameters such as ``?t=90`` are ignored.

    Args:
        url_or_id: A YouTube URL or bare video ID; surrounding whitespace is ignored.

    Returns:
        The video ID, or ``None`` when the input is empty or no ID can be found
        (a warning is printed in the latter case).
    """
    if not url_or_id:
        return None
    candidate = url_or_id.strip()
    if not candidate:
        return None

    # A bare ID: exactly 11 URL-safe characters.
    if re.fullmatch(r"[a-zA-Z0-9_-]{11}", candidate):
        return candidate

    id_group = r"([a-zA-Z0-9_-]{11})"
    patterns = [
        # Watch URL: ``v=`` directly after '?' or after another parameter ('&').
        r"(?:https?://)?(?:www\.)?youtube\.com/watch\?(?:[^#\s]*&)?v=" + id_group,
        # Shortened youtu.be URL.
        r"(?:https?://)?youtu\.be/" + id_group,
        # Path-style URLs: embed, /v/, shorts and live streams.
        r"(?:https?://)?(?:www\.)?youtube\.com/(?:embed|v|shorts|live)/" + id_group,
    ]

    for pattern in patterns:
        match = re.search(pattern, candidate)
        if match:
            return match.group(1)

    print(f"Warning: Could not extract a valid YouTube ID from '{url_or_id}'")
    return None

# --- Generate HTML for Thumbnails ---
def generate_thumbnail_html(video_sources: List[str], quality: str) -> str:
    """
    Render a flex-wrap gallery of clickable YouTube thumbnails.

    Each entry of ``video_sources`` is resolved with ``extract_youtube_id``; entries
    that yield no ID are skipped. Rows wrap automatically, so no per-row setting exists.

    Args:
        video_sources: YouTube URLs and/or bare video IDs.
        quality: Thumbnail variant placed in the ``img.youtube.com`` URL (e.g. "hqdefault").
            It also selects the card width; unrecognised values get a 300px cap.

    Returns:
        An HTML string: the gallery container with one card per resolved ID, a notice
        inside the container when no ID could be resolved, or a short paragraph when
        ``video_sources`` is empty.
    """
    if not video_sources:
        return "<p>No video sources provided.</p>"

    # Card width per thumbnail variant; the style does not depend on the video, so build it once.
    card_width_css = {
        "default": "width: 120px;",
        "mqdefault": "width: 320px;",
        "hqdefault": "width: 480px; max-width: 480px;",
        "sddefault": "width: 640px; max-width: 640px;",
        "maxresdefault": "width: 100%; max-width: 720px;",  # maxres images can be very large
    }.get(quality, "max-width: 300px;")
    card_style = "flex: 0 1 auto; margin-bottom: 15px; text-align: center; " + card_width_css

    cards: List[str] = []
    for source in video_sources:
        vid = extract_youtube_id(source)
        if not vid:
            continue
        thumb_src = f"https://img.youtube.com/vi/{vid}/{quality}.jpg"
        watch_href = f"https://www.youtube.com/watch?v={vid}"
        cards.append(f"""
            <div style='{card_style}'>
                <a href='{watch_href}' target='_blank' title='Watch video {vid}'>
                    <img src='{thumb_src}' alt='YouTube Thumbnail for {vid}' style='width: 100%; height: auto; border: 1px solid #ccc; border-radius: 4px; display: block;'>
                </a>
                <p style='font-size: 0.8em; margin-top: 5px; word-break: break-all;'>ID: {vid}</p>
            </div>
            """)

    gallery = ["<div style='display: flex; flex-wrap: wrap; justify-content: flex-start; gap: 15px;'>", *cards]
    if not cards:
        gallery.append("<p>No valid YouTube video IDs could be extracted from the provided sources.</p>")
    gallery.append("</div>")
    return "".join(gallery)

# --- Display the HTML ---
# Render the gallery, or a hint to fill in the source list when it is empty.
if youtube_video_sources:
    print(f"Generating YouTube thumbnails for {len(youtube_video_sources)} source(s)...")
    html_output = generate_thumbnail_html(youtube_video_sources, THUMBNAIL_QUALITY)
else:
    print("The 'youtube_video_sources' list is empty. Please add YouTube video URLs or IDs.")
    html_output = "<p>Please configure the <code>youtube_video_sources</code> list in this cell.</p>"

display(HTML(html_output))

In [None]:
# Cell 1: Setup and Dependencies
!pip install transformers torch torchaudio torchvision librosa huggingface_hub numpy pydub ipython --quiet
print("Dependencies installed.")

import os
import sys
import json
import gc
import logging
import time
from pathlib import Path
from typing import List, Dict, Tuple, Any, Optional, Set
import base64
import io
import shutil # For cleaning up downloaded models if needed

import torch
import torch.nn as nn
import numpy as np
import librosa
import librosa.display # For plotting waveform in HTML
import matplotlib.pyplot as plt # For plotting waveform

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download, hf_hub_download
from IPython.display import HTML, display, Audio as IPythonAudio

# Optional dependency: pydub is used to embed playable audio in the HTML report (Cell 3).
# PYDUB_AVAILABLE records whether it can be used; when False, the report is produced
# without playable audio (see the warnings below).
try:
    from pydub import AudioSegment
    from pydub.exceptions import CouldntDecodeError, CouldntEncodeError
    PYDUB_AVAILABLE = True
except ImportError:
    # pydub itself is not installed.
    PYDUB_AVAILABLE = False
    print("WARNING: pydub library not found. Audio player in HTML report (Cell 3) will not have playable audio. Install with: !pip install pydub")
except Exception as e:
    # Any other failure while importing pydub; the message attributes it to the ffmpeg/avconv backend.
    PYDUB_AVAILABLE = False
    print(f"WARNING: Error initializing pydub (likely ffmpeg/avconv issue): {e}. Audio player in HTML report (Cell 3) will not have playable audio.")

# Setup basic logging
def setup_notebook_logging(log_level=logging.INFO):
    """
    Reset the root logger so notebook log records go to stdout exactly once.

    Any handlers already attached to the root logger (for example from an earlier run
    of this cell) are removed and closed, then one stdout StreamHandler with a
    timestamped format is installed.

    Args:
        log_level: Level applied to the root logger (default ``logging.INFO``).
    """
    # force=True removes and closes every existing root handler before configuring,
    # so re-running the cell does not duplicate log lines.
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s [%(levelname)-7s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[logging.StreamHandler(sys.stdout)],
        force=True,
    )


setup_notebook_logging()

# --- Global Configuration Toggles ---
# The user-modifiable settings (MLP CPU offloading, half precision, torch.compile,
# model IDs and input/output paths) are defined in Cell 2; edit them there.

In [None]:
# Cell 2: Configuration, Demo File Download, Model Definitions, and MLP Path Discovery
# This demo is not speed-optimized and is therefore fairly slow.
# At the moment it loads only one MLP model at a time,
# because the RAM and GPU VRAM available in Google Colab are not large enough to keep all of them loaded.

import os
import sys
import json
import gc
import logging
import time
from pathlib import Path
from typing import List, Dict, Tuple, Any, Optional
import shutil
import requests # For downloading demo files

import torch
import torch.nn as nn
import numpy as np
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download, hf_hub_download

# --- User-Modifiable Configuration ---
# For MLP model execution
USE_CPU_OFFLOADING_FOR_MLPS: bool = True   # run the MLPs on CPU instead of the Whisper device
USE_HALF_PRECISION_FOR_MLPS: bool = True   # float16 MLP weights (load_single_mlp_model honours this only on CUDA)
USE_TORCH_COMPILE_FOR_MLPS: bool = True    # torch.compile the MLPs (load_single_mlp_model honours this only on CUDA)

# Hugging Face model sources and the local checkpoint cache.
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small"
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# --- Paths for Batch Processing ---
# Colab locations; the folders are created automatically in the "Preparations" step of this cell.
BATCH_INPUT_AUDIO_FOLDER = Path("/content/batch_audio_input")
BATCH_OUTPUT_HTML_REPORT_FILE = Path("/content/batch_annotations_output_html") / "batch_emotion_report.html"
BATCH_OUTPUT_JSON_FOLDER = Path("/content/batch_annotations_output_json")

# Demo clips hosted in the MLP repository; they are downloaded into BATCH_INPUT_AUDIO_FOLDER.
DEMO_AUDIO_FILES_TO_DOWNLOAD = {
    clip_name: "https://huggingface.co/laion/Empathic-Insight-Voice-Small/resolve/main/" + clip_name
    for clip_name in ("1.mp3", "2.mp3", "3.mp3", "4.mp3")
}

# --- Core Model & Audio Configuration ---
SAMPLING_RATE = 16000  # Hz; rate passed to the Whisper processor and used for embedded report audio
MAX_AUDIO_SECONDS = 30.0
# Whisper encoder output consumed by every MLP, padded/trimmed to (WHISPER_SEQ_LEN, WHISPER_EMBED_DIM).
WHISPER_SEQ_LEN: int = 1500
WHISPER_EMBED_DIM: int = 768
PROJECTION_DIM_FOR_FULL_EMBED: int = 64
MLP_HIDDEN_DIMS: List[int] = [64, 32, 16]
# One dropout rate after the projection plus one per hidden layer: len(MLP_HIDDEN_DIMS) + 1 entries.
MLP_DROPOUTS: List[float] = [0.0, 0.1, 0.1, 0.1]

SUPPORTED_AUDIO_EXTENSIONS = ['.mp3', '.wav', '.flac', '.m4a', '.ogg', '.aac']

# The 40 emotion categories covered by the report. Entries are expected to match values of
# FILENAME_PART_TO_TARGET_KEY_MAP, because get_mlp_model_paths_map keys its result by those
# values; a name that appears there under a different spelling would find no checkpoint.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]
# Guard against accidental additions/removals when editing the list above.
assert len(TARGET_EMOTION_KEYS_FOR_REPORT) == 40

# Maps the dimension part of a checkpoint filename (``model_<part>_best.pth``) to the key used
# in this notebook. Emotion entries translate filename-safe spellings (e.g. "Jealousy_&_Envy")
# into the names used in TARGET_EMOTION_KEYS_FOR_REPORT; the attribute dimensions (Valence,
# Arousal, Age, Gender, ...) map to themselves. Checkpoints whose part is not a key here are
# skipped by get_mlp_model_paths_map.
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# --- MLP Model Definition ---
class FullEmbeddingMLP(nn.Module):
    """
    Regression head that maps a full Whisper encoder embedding to one scalar score.

    Layout: flatten -> Linear(seq_len * embed_dim, projection_dim) -> ReLU -> Dropout,
    then for each hidden size: Linear -> ReLU -> Dropout, and finally Linear(..., 1).
    The submodule names (``proj``, ``mlp.<i>``) must stay exactly as they are, because
    checkpoints are loaded by state-dict key.

    Args:
        seq_len: Number of embedding frames per sample.
        embed_dim: Width of each frame.
        projection_dim: Output width of the first (projection) layer.
        mlp_hidden_dims: Widths of the hidden layers, in order.
        mlp_dropout_rates: One rate after the projection plus one per hidden layer.

    Raises:
        ValueError: if ``len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1``.
    """

    def __init__(self,
                 seq_len: int,
                 embed_dim: int,
                 projection_dim: int,
                 mlp_hidden_dims: List[int],
                 mlp_dropout_rates: List[float]):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError(f"Dropout rates length error. Expected {len(mlp_hidden_dims) + 1}, got {len(mlp_dropout_rates)}")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)

        # Input width of each hidden Linear: the projection output, then the previous hidden size.
        in_widths = [projection_dim, *mlp_hidden_dims]
        stack: List[nn.Module] = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        for in_width, out_width, drop_p in zip(in_widths, mlp_hidden_dims, mlp_dropout_rates[1:]):
            stack += [nn.Linear(in_width, out_width), nn.ReLU(), nn.Dropout(drop_p)]
        stack.append(nn.Linear(in_widths[-1], 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch; a singleton axis 1 on 4-D input is dropped before flattening."""
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        flat = self.flatten(x)
        return self.mlp(self.proj(flat))

# --- Helper Functions (Adapted) ---
def download_hf_mlp_checkpoints(repo_id: str, local_dir: Path, force_redownload: bool = False) -> Path:
    """
    Make sure the ``.pth`` MLP checkpoints of ``repo_id`` are present in ``local_dir``.

    The download is skipped when ``local_dir`` already contains at least one ``.pth``
    file at its top level. With ``force_redownload`` the directory is deleted first.

    Args:
        repo_id: Hugging Face model repository to download from.
        local_dir: Destination folder (created if needed).
        force_redownload: Discard any existing copy and download again.

    Returns:
        ``local_dir``.

    Raises:
        RuntimeError: if the download fails; the original exception is chained as ``__cause__``.
    """
    if local_dir.exists() and force_redownload:
        logging.info(f"Force redownload: Removing existing directory {local_dir}")
        shutil.rmtree(local_dir)

    if local_dir.exists() and any(local_dir.glob("*.pth")):
        logging.info(f"MLP checkpoints found in local cache: {local_dir}")
        return local_dir

    logging.info(f"Downloading MLP checkpoints from {repo_id} to {local_dir}...")
    local_dir.mkdir(parents=True, exist_ok=True)
    try:
        # NOTE(review): recent huggingface_hub releases deprecate local_dir_use_symlinks;
        # it is kept for older versions — confirm against the installed version.
        snapshot_download(repo_id=repo_id, local_dir=local_dir, local_dir_use_symlinks=False, allow_patterns=["*.pth"], repo_type="model")
    except Exception as e:
        logging.error(f"Failed to download MLP checkpoints from {repo_id}: {e}", exc_info=True)
        raise RuntimeError(f"Could not download MLP checkpoints from {repo_id}.") from e
    logging.info(f"MLP checkpoints downloaded to {local_dir}")
    return local_dir

def download_demo_audio_files(target_dir: Path, files_to_download: Dict[str, str], timeout: float = 60.0):
    """
    Download demo audio files into ``target_dir``, skipping files that already exist.

    Each file is streamed into ``<name>.part`` and renamed only after a complete
    download. An interrupted transfer therefore never leaves a truncated file behind,
    which the "already exists" check would otherwise skip on every later run.
    Failures are logged per file and do not stop the remaining downloads.

    Args:
        target_dir: Destination folder (created if needed).
        files_to_download: Mapping of local filename -> URL.
        timeout: Seconds ``requests`` waits on the connection/read before giving up.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    logging.info(f"Downloading demo audio files to {target_dir}...")
    for filename, url in files_to_download.items():
        filepath = target_dir / filename
        if filepath.exists():
            logging.info(f"Demo file {filename} already exists. Skipping download.")
            continue
        partial_path = filepath.with_name(filepath.name + ".part")
        try:
            # The context manager closes the streamed connection even on errors.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            partial_path.replace(filepath)
            logging.info(f"Downloaded {filename} to {filepath}")
        # RequestException derives from IOError, so it must be caught first.
        except requests.exceptions.RequestException as e:
            logging.error(f"Error downloading {filename} from {url}: {e}")
        except IOError as e:
            logging.error(f"Error writing {filename} to {filepath}: {e}")
        finally:
            # Only present if the download did not complete.
            if partial_path.exists():
                partial_path.unlink()

def get_mlp_model_paths_map(mlp_checkpoints_dir: Path, filename_map: Dict[str, str]) -> Dict[str, Path]:
    """
    Map target keys to checkpoint files named ``model_<part>_best.pth``.

    Args:
        mlp_checkpoints_dir: Folder searched (non-recursively) for checkpoints.
        filename_map: ``<part>`` -> target key; files whose part is not listed are skipped.

    Returns:
        ``{target_key: checkpoint_path}``; empty if the folder does not exist.
        If two files map to the same key, the later one in sorted order wins (a warning is logged).
    """
    all_mapped_model_paths: Dict[str, Path] = {}
    if not mlp_checkpoints_dir.is_dir():
        logging.error(f"MLP checkpoints directory not found: {mlp_checkpoints_dir.resolve()}")
        return {}
    logging.info(f"Mapping MLP model files from {mlp_checkpoints_dir.resolve()}...")

    prefix, suffix = "model_", "_best.pth"
    # Sorted so the result (including duplicate resolution) does not depend on filesystem order.
    for pth_file in sorted(mlp_checkpoints_dir.glob(f"{prefix}*{suffix}")):
        # The glob guarantees prefix and suffix, so slicing recovers the middle exactly,
        # even when it contains "model_" itself (str.split would cut at the wrong place).
        filename_part = pth_file.name[len(prefix):-len(suffix)]
        if filename_part not in filename_map:
            logging.debug(f"Filename part '{filename_part}' not in map. Skipping {pth_file.name}")
            continue
        target_key = filename_map[filename_part]
        if target_key in all_mapped_model_paths:
            logging.warning(f"Duplicate mapping for target key '{target_key}'. Overwriting.")
        all_mapped_model_paths[target_key] = pth_file

    logging.info(f"Found {len(all_mapped_model_paths)} MLP model paths based on map.")
    return all_mapped_model_paths

def load_whisper_model(model_id: str, device: torch.device) -> Tuple[Optional[WhisperForConditionalGeneration], Optional[WhisperProcessor]]:
    """
    Load the Whisper processor and model for ``model_id``, ready for inference.

    Args:
        model_id: Identifier passed to ``from_pretrained`` for both processor and model.
        device: Device the model is moved to.

    Returns:
        ``(model, processor)`` with the model on ``device`` in eval mode, or
        ``(None, None)`` if any step fails (the error is logged with its traceback).
    """
    logging.info(f"Loading Whisper model '{model_id}' to {device}...")
    try:
        whisper_processor = WhisperProcessor.from_pretrained(model_id)
        whisper_model = WhisperForConditionalGeneration.from_pretrained(model_id)
        whisper_model = whisper_model.to(device)
        whisper_model.eval()
        logging.info(f"Whisper model '{model_id}' loaded to {device}.")
        return whisper_model, whisper_processor
    except Exception as load_error:
        logging.error(f"Error loading Whisper model '{model_id}': {load_error}", exc_info=True)
        return None, None

# --- Function to load a SINGLE MLP model ---
def load_single_mlp_model(
    model_path: Path,
    target_key: str, # For logging
    mlp_device: torch.device,
    use_half_cfg: bool,
    use_compile_cfg: bool
) -> Optional[nn.Module]:
    """
    Build one ``FullEmbeddingMLP``, load its checkpoint and prepare it for inference.

    Args:
        model_path: ``.pth`` file holding the MLP state dict. Keys saved from a
            ``torch.compile``-wrapped model (``_orig_mod.`` prefix) are accepted, as is a
            dict that nests the weights under ``"state_dict"`` or ``"model_state_dict"``.
        target_key: Dimension name; used only in log messages.
        mlp_device: Device to place the model on.
        use_half_cfg: Cast to float16; applied only when ``mlp_device`` is CUDA.
        use_compile_cfg: Wrap with ``torch.compile``; applied only on CUDA when this
            torch build provides ``torch.compile``.

    Returns:
        The model in eval mode on ``mlp_device``, or ``None`` if loading failed (logged).
    """
    on_cuda = mlp_device.type == 'cuda'
    target_dtype = torch.float16 if use_half_cfg and on_cuda else torch.float32
    compile_mode = "reduce-overhead"
    logging.debug(f"Loading MLP for '{target_key}' from {model_path} to {mlp_device} (Half: {use_half_cfg}, Compile: {use_compile_cfg})")
    try:
        model_instance = FullEmbeddingMLP(
            seq_len=WHISPER_SEQ_LEN, embed_dim=WHISPER_EMBED_DIM,
            projection_dim=PROJECTION_DIM_FOR_FULL_EMBED,
            mlp_hidden_dims=MLP_HIDDEN_DIMS,
            mlp_dropout_rates=MLP_DROPOUTS
        )
        # weights_only=True refuses to unpickle arbitrary objects from the downloaded file.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
        actual_state_dict = checkpoint
        for wrapper_key in ("state_dict", "model_state_dict"):
            if isinstance(checkpoint, dict) and isinstance(checkpoint.get(wrapper_key), dict):
                actual_state_dict = checkpoint[wrapper_key]
                break

        # Checkpoints saved from a torch.compile'd model prefix every key with "_orig_mod.".
        prefix = "_orig_mod."
        if any(k.startswith(prefix) for k in actual_state_dict.keys()):
            actual_state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in actual_state_dict.items()
            }

        model_instance.load_state_dict(actual_state_dict)
        model_instance.eval()

        if on_cuda and use_half_cfg:
            model_instance = model_instance.to(dtype=target_dtype)
        model_instance = model_instance.to(mlp_device)

        # hasattr suffices: torch.compile only exists from PyTorch 2.0. (The previous
        # lexicographic check `torch.__version__ >= "2.0.0"` misorders e.g. "10.0".)
        if use_compile_cfg and on_cuda and hasattr(torch, 'compile'):
            try:
                # NOTE: compilation is lazy, so some failures only surface on the first forward call.
                model_instance = torch.compile(model_instance, mode=compile_mode)
            except Exception as e_compile:
                logging.warning(f"torch.compile failed for MLP '{target_key}': {e_compile}. Using uncompiled.")

        logging.debug(f"Successfully loaded MLP for '{target_key}'.")
        return model_instance
    except Exception as e:
        logging.error(f"Failed to load/prepare MLP for '{target_key}' from {model_path}: {e}", exc_info=True)
        return None

# --- Determine Devices ---
# Whisper uses the GPU when one is available; the MLPs follow it unless CPU offloading is on.
_whisper_device_type = "cuda" if torch.cuda.is_available() else "cpu"
_mlp_device_type = _whisper_device_type
if USE_CPU_OFFLOADING_FOR_MLPS:
    _mlp_device_type = "cpu"
WHISPER_DEVICE, MLP_DEVICE = torch.device(_whisper_device_type), torch.device(_mlp_device_type)

logging.info(f"Whisper will run on: {WHISPER_DEVICE}")
logging.info(f"MLPs will run on: {MLP_DEVICE}")

# --- Preparations ---
# 1. Create output directories
BATCH_INPUT_AUDIO_FOLDER.mkdir(parents=True, exist_ok=True)
BATCH_OUTPUT_HTML_REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
BATCH_OUTPUT_JSON_FOLDER.mkdir(parents=True, exist_ok=True)

# 2. Download demo audio files into the BATCH_INPUT_AUDIO_FOLDER
download_demo_audio_files(BATCH_INPUT_AUDIO_FOLDER, DEMO_AUDIO_FILES_TO_DOWNLOAD)

# 3. Download MLP checkpoints from Hugging Face
downloaded_mlp_checkpoints_dir = download_hf_mlp_checkpoints(HF_MLP_REPO_ID, LOCAL_MLP_MODELS_DOWNLOAD_DIR)

# 4. Load Whisper Model (this stays loaded)
whisper_model_global, whisper_processor_global = load_whisper_model(WHISPER_MODEL_ID, WHISPER_DEVICE)
# load_whisper_model signals failure with (None, None); check for that explicitly instead of
# relying on the truthiness of model/processor objects.
if whisper_model_global is None or whisper_processor_global is None:
    raise RuntimeError("Failed to load Whisper model. Cannot proceed.")

# 5. Get paths and map for all MLP models (these will be loaded one by one later)
# This dictionary {target_key: model_path_object} is crucial for Cells 3 & 4.
all_mlp_model_paths_dict: Dict[str, Path] = get_mlp_model_paths_map(
    downloaded_mlp_checkpoints_dir,
    FILENAME_PART_TO_TARGET_KEY_MAP
)
if not all_mlp_model_paths_dict:
    raise RuntimeError("No MLP model paths could be mapped or found. Cannot proceed.")

# Surface report emotions that have no checkpoint now, rather than discovering it later.
_missing_emotion_keys = [key for key in TARGET_EMOTION_KEYS_FOR_REPORT if key not in all_mlp_model_paths_dict]
if _missing_emotion_keys:
    logging.warning(f"No MLP checkpoint mapped for {len(_missing_emotion_keys)} of the {len(TARGET_EMOTION_KEYS_FOR_REPORT)} report emotions: {_missing_emotion_keys}")

logging.info(f"--- Cell 2 Setup Complete. Whisper model loaded. {len(all_mlp_model_paths_dict)} MLP model paths identified. ---")
logging.info(f"Demo audio files are in: {BATCH_INPUT_AUDIO_FOLDER.resolve()}")
logging.info(f"HTML report will be saved to: {BATCH_OUTPUT_HTML_REPORT_FILE.resolve()}")
logging.info(f"JSON annotations will be saved to: {BATCH_OUTPUT_JSON_FOLDER.resolve()}")

In [None]:
# Cell 3: Batch HTML Report Generation (One MLP at a Time, with new HTML structure)

from IPython.display import HTML, display
import matplotlib.pyplot as plt
import base64 # For embedding images and audio in HTML
import io
import torch # For softmax

# pydub (with an ffmpeg backend) is optional: it is used to embed playable audio in the report.
try:
    from pydub import AudioSegment
    PYDUB_AVAILABLE = True
except ImportError:
    PYDUB_AVAILABLE = False
    print("WARNING: pydub not available; the HTML report will not contain playable audio.")
except Exception as e:
    # Cell 1 documents non-ImportError failures when importing pydub; without this branch
    # such a failure would abort this whole cell instead of just disabling the audio player.
    PYDUB_AVAILABLE = False
    print(f"WARNING: Error initializing pydub: {e}. The HTML report will not contain playable audio.")


# --- Helper functions used by this cell: find_audio_files_in_folder, get_prediction_with_single_mlp,
# --- get_whisper_embedding_for_audio, convert_audio_to_base64_mp3_for_html and generate_waveform_plot_base64.
# --- They are defined here so that this cell only needs the setup from Cells 1 and 2
# --- (logging, loaded models, and configuration constants such as SUPPORTED_AUDIO_EXTENSIONS and SAMPLING_RATE).
def find_audio_files_in_folder(input_dir: Path, extensions: Optional[List[str]] = None) -> List[Path]:
    """
    Recursively collect audio files under ``input_dir``.

    Names are matched on their ending, case-insensitively, so ``clip.MP3`` is found
    as well as ``clip.mp3``. Directories are never returned.

    Args:
        input_dir: Folder to search recursively.
        extensions: Accepted endings including the dot (e.g. ``".wav"``).
            Defaults to ``SUPPORTED_AUDIO_EXTENSIONS`` from Cell 2.

    Returns:
        Sorted list of matching file paths; empty if nothing matches or the folder is missing.
    """
    if extensions is None:
        extensions = SUPPORTED_AUDIO_EXTENSIONS
    wanted_endings = tuple(ext.lower() for ext in extensions)

    if not input_dir.is_dir():
        logging.warning(f"Audio input folder does not exist: {input_dir.resolve()}.")
        return []

    audio_files = sorted(
        path for path in input_dir.rglob('*')
        if path.is_file() and path.name.lower().endswith(wanted_endings)
    )
    if not audio_files:
        logging.warning(f"No audio files found in {input_dir.resolve()}.")
    else:
        logging.info(f"Found {len(audio_files)} audio files in {input_dir.resolve()} for HTML report.")
    return audio_files

@torch.no_grad()
def get_prediction_with_single_mlp(
    whisper_embedding: torch.Tensor,
    mlp_model: nn.Module,
    current_mlp_device: torch.device
) -> float:
    """
    Run one MLP on a Whisper embedding and return its scalar output.

    The embedding is moved to ``current_mlp_device`` and cast to the dtype of the
    model's first parameter (e.g. float16 for half-precision MLPs).

    Args:
        whisper_embedding: Encoder embedding for a single clip; ``.item()`` below
            requires the model output to contain exactly one value.
        mlp_model: Loaded MLP (possibly ``torch.compile``-wrapped).
        current_mlp_device: Device the MLP lives on.

    Returns:
        The predicted score, or NaN if inference fails (a warning is logged) so that
        one bad dimension does not abort a whole batch.
    """
    embedding_for_mlp = whisper_embedding.to(current_mlp_device)
    try:
        current_mlp_dtype = next(mlp_model.parameters()).dtype
        prediction_tensor = mlp_model(embedding_for_mlp.to(current_mlp_dtype))
        return prediction_tensor.item()
    except Exception as e:
        logging.warning(f"Error predicting with a single MLP: {e}")
        return float('nan')

@torch.no_grad()
def get_whisper_embedding_for_audio(
    audio_waveform: np.ndarray,
    loaded_whisper_model: nn.Module,
    loaded_whisper_processor: any,
    current_whisper_device: torch.device
) -> Optional[torch.Tensor]:
    try:
        # Assuming SAMPLING_RATE, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM defined in Cell 2
        input_features = loaded_whisper_processor(
            audio_waveform, sampling_rate=SAMPLING_RATE, return_tensors="pt"
        ).input_features.to(current_whisper_device).to(loaded_whisper_model.dtype)

        encoder_outputs = loaded_whisper_model.get_encoder()(input_features=input_features)
        embedding = encoder_outputs.last_hidden_state

        current_seq_len = embedding.shape[1]
        if current_seq_len < WHISPER_SEQ_LEN:
            padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                                  device=current_whisper_device, dtype=embedding.dtype)
            embedding = torch.cat((embedding, padding), dim=1)
        elif current_seq_len > WHISPER_SEQ_LEN:
            embedding = embedding[:, :WHISPER_SEQ_LEN, :]
        return embedding
    except Exception as e:
        # logging.error(f"Error generating Whisper embedding: {e}", exc_info=True)
        return None

def convert_audio_to_base64_mp3_for_html(audio_path_str: str, bitrate: str = "96k") -> Optional[str]:
    """
    Re-encode an audio file as mono MP3 and return it base64-encoded for the HTML report.

    Args:
        audio_path_str: Path of any file pydub/ffmpeg can decode.
        bitrate: MP3 bitrate passed to the encoder (default ``"96k"``).

    Returns:
        The base64 text, or ``None`` if pydub is unavailable or decoding/encoding
        fails (the failure is logged as a warning).
    """
    if not PYDUB_AVAILABLE:
        return None
    try:
        # Downmix to mono and resample to SAMPLING_RATE to keep the embedded clip small.
        audio = AudioSegment.from_file(audio_path_str).set_channels(1).set_frame_rate(SAMPLING_RATE)
        mp3_buffer = io.BytesIO()
        audio.export(mp3_buffer, format="mp3", bitrate=bitrate)
        return base64.b64encode(mp3_buffer.getvalue()).decode('utf-8')
    except Exception as e:
        logging.warning(f"Pydub/ffmpeg error for {audio_path_str}: {e}. No audio player.")
        return None

def generate_waveform_plot_base64(waveform_data: np.ndarray, sr: int) -> str:
    """
    Plot a waveform and return the PNG as a base64 string for inline HTML embedding.

    The figure is always closed, including when plotting fails, so batch runs do not
    accumulate open matplotlib figures.

    Args:
        waveform_data: Audio samples.
        sr: Sample rate of ``waveform_data`` in Hz (sets the time axis).

    Returns:
        Base64-encoded PNG, or ``""`` if plotting fails (a warning is logged).
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 2.5))
        librosa.display.waveshow(waveform_data, sr=sr, ax=ax, color='royalblue', alpha=0.7)
        ax.set_title("Waveform", fontsize=10)
        ax.set_xlabel("Time (s)", fontsize=8)
        ax.set_ylabel("Amplitude", fontsize=8)
        ax.tick_params(axis='both', labelsize=7)
        fig.tight_layout()
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', bbox_inches='tight')
        return base64.b64encode(img_buffer.getvalue()).decode('utf-8')
    except Exception as e:
        logging.warning(f"Waveform plot failed: {e}")
        return ""
    finally:
        if fig is not None:
            plt.close(fig)
# --- End of re-definitions/imports ---


def _rank_top_keys(scores: Dict[str, float], top_n: int = 3) -> List[str]:
    """Return up to ``top_n`` keys of ``scores`` with the highest values, best first.

    NaN scores (how failed/missing MLP predictions are recorded) are excluded. NaN compares
    False against everything, so leaving NaNs in would make ``sorted`` order undefined and
    could let a failed emotion be highlighted as a top result.
    """
    import math

    ranked = sorted(
        ((key, value) for key, value in scores.items() if not math.isnan(value)),
        key=lambda item: item[1],
        reverse=True,
    )
    return [key for key, _ in ranked[:top_n]]


def _two_column_table_rows(cells: List[tuple]) -> List[str]:
    """Render ``(label, value, css_class)`` cells as column-major, two-column ``<tr>`` rows.

    The first ``ceil(n / 2)`` cells fill the left column top to bottom, and the remaining
    cells fill the right column. A missing right-hand cell is rendered as two empty
    ``<td>`` elements. Labels and values are inserted verbatim, so escape them beforehand.
    """
    num_rows = (len(cells) + 1) // 2
    rows: List[str] = []
    for row_idx in range(num_rows):
        row_parts = ["<tr>"]
        for idx in (row_idx, row_idx + num_rows):
            if idx < len(cells):
                label, value, css_class = cells[idx]
                class_attr = f" class='{css_class}'" if css_class else ""
                row_parts.append(f"<td{class_attr}>{label}</td><td{class_attr}>{value}</td>")
            else:
                row_parts.append("<td></td><td></td>")
        row_parts.append("</tr>")
        rows.append("".join(row_parts))
    return rows


def generate_batch_html_report_updated(
    all_files_scores: Dict[Path, Dict[str, float]], # {filepath: {dim_key: raw_score, ...}}
    target_emotion_keys_list: List[str], # The 40 primary emotion keys
    all_attribute_keys_list: List[str], # All other keys considered attributes
    output_html_path: Path
):
    """Build a self-contained HTML report of per-file predictions, save it, and return it.

    For every audio file the report embeds:
    - a waveform plot (PNG) and an MP3 player, both as base64 data URIs and both best-effort;
    - a two-column table of the emotion categories showing ``softmax probability (raw score)``,
      with the top-3 raw scores highlighted;
    - a two-column table of the attribute dimensions (raw scores, sorted by key).

    Args:
        all_files_scores: ``{audio_path: {dimension_key: raw_score}}``. Missing or failed
            predictions are expected to be NaN.
        target_emotion_keys_list: Emotion keys to show, in table order. The softmax is
            taken over exactly these keys. Missing/NaN scores get probability 0.
        all_attribute_keys_list: Other dimension keys. Only those present in a file's
            scores are shown.
        output_html_path: Where the report is written (UTF-8). Write errors are logged,
            not raised.

    Returns:
        The complete HTML document as a string.
    """
    import html as html_lib
    import math

    # HTML Explanation Section (as provided by user)
    explanation_html = """
        <div class="explanation-section">
            <h3>Interpretation of Scores</h3>
            <p>The models predict raw scores. For the 40 Emotional Categories, these raw scores are also used to calculate a normalized <strong>Softmax Probability</strong>, indicating the relative likelihood of each emotion.
            Higher raw scores (shown in parentheses for emotions, or directly for attributes) generally suggest a stronger presence or intensity, aligning with the original annotation scales used during training.</p>

            <h4>Emotional Categories (40)</h4>
            <p><em>Original Annotation Scale: 0 (Not present at all) to 4 (Extremely present).</em><br>
            The table below displays: <strong>Softmax Probability</strong> (Raw Model Score)</p>

            <h4>Attribute Dimensions</h4>
            <p><em>Original Annotation Scale: Varies per dimension (detailed below).</em><br>
            The table for these dimensions displays: Raw Model Score</p>

            <div class="dimension-details-section">
                <h4>Details for Attribute Dimensions:</h4>
                <p><strong>Valence:</strong> <em>Range: -3 (Ext. Negative) to +3 (Ext. Positive).</em> 0=Neutral.</p>
                <p><strong>Arousal:</strong> <em>Range: 0 (Very Calm) to 4 (Very Excited).</em> 2=Neutral.</p>
                <p><strong>Submissive vs. Dominant:</strong> <em>Range: -3 (Ext. Submissive) to +3 (Ext. Dominant).</em> 0=Neutral.</p>
                <p><strong>Age:</strong> <em>Range: 0 (Infant/Toddler) to 6 (Very Old).</em> (e.g., 2=Teenager, 4=Adult).</p>
                <p><strong>Gender:</strong> <em>Range: -2 (Very Masculine) to +2 (Very Feminine).</em> 0=Neutral/Unsure.</p>
                <p><strong>Serious vs. Humorous:</strong> <em>Range: 0 (Very Serious) to 4 (Very Humorous).</em> 2=Neutral.</p>
                <p><strong>Vulnerable vs. Emotionally Detached:</strong> <em>Range: 0 (Very Vulnerable) to 4 (Very Detached).</em> 2=Neutral.</p>
                <p><strong>Confident vs. Hesitant:</strong> <em>Range: 0 (Very Confident) to 4 (Very Hesitant).</em> 2=Neutral.</p>
                <p><strong>Warm vs. Cold:</strong> <em>Range: -2 (Very Cold) to +2 (Very Warm).</em> 0=Neutral.</p>
                <p><strong>Monotone vs. Expressive:</strong> <em>Range: 0 (Very Monotone) to 4 (Very Expressive).</em> 2=Neutral.</p>
                <p><strong>High-Pitched vs. Low-Pitched:</strong> <em>Range: 0 (Very High-Pitched) to 4 (Very Low-Pitched).</em> 2=Neutral.</p>
                <p><strong>Soft vs. Harsh:</strong> <em>Range: -2 (Very Harsh) to +2 (Very Soft).</em> 0=Neutral.</p>
                <p><strong>Authenticity:</strong> <em>Range: 0 (Very Artificial) to 4 (Very Genuine).</em> 2=Neutral.</p>
                <p><strong>Recording Quality:</strong> <em>Range: 0 (Very Low) to 4 (Very High).</em> 2=Decent.</p>
                <p><strong>Background Noise:</strong> <em>Range: 0 (No Noise) to 3 (Intense Noise).</em></p>
            </div>
        </div>
    """

    html_content = [f"""
<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"><title>Audio Inference Report</title>
<style>
    body {{ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; margin: 15px; background-color: #f0f2f5; color: #333; font-size: 14px; }}
    .report-container {{ max-width: 1000px; margin: auto; }}
    .audio-item {{ background-color: #fff; border: 1px solid #e0e0e0; margin-bottom: 20px; padding: 15px; border-radius: 8px; box-shadow: 0 2px 5px rgba(0,0,0,0.07); }}
    h1 {{ text-align: center; color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; margin-bottom: 10px; }} /* Reduced margin-bottom */
    h2 {{ margin-top: 0; color: #34495e; font-size: 1.3em; border-bottom: 1px dashed #bdc3c7; padding-bottom: 6px; margin-bottom: 12px; }}
    h3 {{ color: #555; font-size: 1.1em; margin-top:15px; margin-bottom:8px;}}
    h4 {{ color: #666; font-size: 1.0em; margin-top:10px; margin-bottom:5px;}}
    table {{ width: 100%; border-collapse: collapse; margin-bottom: 15px; font-size: 0.9em;}}
    th, td {{ border: 1px solid #ddd; padding: 7px; text-align: left; }}
    th {{ background-color: #ecf0f1; font-weight: 600; }}
    .top1 {{ background-color: #ffdddd !important; }} .top2 {{ background-color: #ffe8cc !important; }} .top3 {{ background-color: #ddffdd !important; }}
    audio {{ width: 100%; margin-top: 8px; margin-bottom: 8px; }}
    .waveform-img {{ width:100%; max-width:500px; margin-bottom:10px; border:1px solid #eee; border-radius:4px; display:block; margin-left:auto; margin-right:auto;}}
    .explanation-section {{ background-color: #e9ecef; padding: 15px; border-radius: 5px; margin-bottom: 25px; border: 1px solid #ced4da;}}
    .explanation-section p {{font-size: 0.95em; line-height: 1.5;}}
    .dimension-details-section p {{ margin-bottom: 3px; font-size:0.9em; }}
    .dimension-details-section strong {{ color: #34495e; }}
</style></head><body><div class="report-container"><h1>Audio Inference Report</h1>
{explanation_html}
"""] # Insert explanation here

    if not all_files_scores:
        html_content.append("<p>No audio files were processed or no results to display.</p>")

    for audio_file_path, raw_predictions in all_files_scores.items():
        # File names come from disk and may contain '<', '&', quotes, etc.
        html_content.append(
            f"<div class='audio-item'><h2>File: {html_lib.escape(audio_file_path.name)}</h2>"
        )

        try:
            wf_data, wf_sr = librosa.load(str(audio_file_path), sr=SAMPLING_RATE, mono=True)
            waveform_b64 = generate_waveform_plot_base64(wf_data, wf_sr)
            del wf_data
            if waveform_b64:
                html_content.append(f"<img src='data:image/png;base64,{waveform_b64}' alt='Waveform' class='waveform-img'/>")
        except Exception as e_wave:
            # Best-effort: an unreadable file just gets no waveform image.
            logging.warning(f"Could not render waveform for {audio_file_path.name}: {e_wave}")

        base64_audio_mp3 = convert_audio_to_base64_mp3_for_html(str(audio_file_path))
        if base64_audio_mp3:
            # audio/mpeg is the registered MIME type for MP3 (audio/mp3 is non-standard).
            html_content.append(f"<audio controls src='data:audio/mpeg;base64,{base64_audio_mp3}'></audio>")
        else:
            html_content.append("<p><i>Audio player not available.</i></p>")

        emotion_raw_scores_dict = {k: raw_predictions.get(k, float('nan')) for k in target_emotion_keys_list}
        attribute_raw_scores_dict = {k: raw_predictions[k] for k in all_attribute_keys_list if k in raw_predictions}

        # A single NaN logit would make every softmax output NaN, so missing/failed
        # emotions get a -inf logit instead (probability 0). If all are missing the
        # softmax is NaN everywhere, which is displayed as "nan".
        emotion_logits = [
            -math.inf if math.isnan(emotion_raw_scores_dict[k]) else emotion_raw_scores_dict[k]
            for k in target_emotion_keys_list
        ]
        with torch.no_grad():
            softmax_probs = torch.softmax(
                torch.tensor(emotion_logits, dtype=torch.float32), dim=0
            ).tolist()
        softmax_probs_dict = dict(zip(target_emotion_keys_list, softmax_probs))

        # Highlight the top-3 emotions by raw score (NaNs excluded from the ranking).
        top_raw_score_keys = _rank_top_keys(emotion_raw_scores_dict, 3)
        css_by_key = {key: f"top{rank}" for rank, key in enumerate(top_raw_score_keys, start=1)}

        # Emotions are shown in the caller's predefined order.
        emotion_cells = [
            (
                html_lib.escape(key),
                f"{softmax_probs_dict[key]:.4f} ({emotion_raw_scores_dict[key]:.4f})",
                css_by_key.get(key, ""),
            )
            for key in target_emotion_keys_list
        ]
        html_content.append(f"<h3>Emotional Categories ({len(target_emotion_keys_list)})</h3><table>")
        html_content.extend(_two_column_table_rows(emotion_cells))
        html_content.append("</table>")

        if attribute_raw_scores_dict:
            attribute_cells = [
                (html_lib.escape(key), f"{score:.4f}", "")
                for key, score in sorted(attribute_raw_scores_dict.items())
            ]
            html_content.append("<h3>Attribute Dimensions</h3><table>")
            html_content.extend(_two_column_table_rows(attribute_cells))
            html_content.append("</table>")
        html_content.append("</div>")

    html_content.append("</div></body></html>")
    final_html = "".join(html_content)

    try:
        with open(output_html_path, "w", encoding="utf-8") as f:
            f.write(final_html)
        logging.info(f"Batch HTML report saved to: {output_html_path.resolve()}")
    except Exception as e:
        logging.error(f"Error writing HTML report: {e}", exc_info=True)

    return final_html


# --- Main execution for Cell 3 ---
# Loop order: MLPs are the OUTER loop, so each MLP checkpoint is loaded exactly once.
# Trade-off: every audio file is decoded and re-embedded by Whisper once per MLP.
# (Cell 4 uses the opposite order: one embedding per file, with MLPs reloaded per file.)
logging.info("--- Starting Cell 3: Batch HTML Report Generation (Updated) ---")
audio_files_for_html_report = find_audio_files_in_folder(BATCH_INPUT_AUDIO_FOLDER) # BATCH_INPUT_AUDIO_FOLDER from Cell 2
# One (initially empty) score dict per file; failed predictions are stored as NaN.
aggregated_results_for_html: Dict[Path, Dict[str, float]] = {fp: {} for fp in audio_files_for_html_report}

if not audio_files_for_html_report:
    logging.warning("No audio files in input folder for HTML report. Skipping.")
    display(HTML("<p><b>No audio files found in input folder. HTML report generation skipped.</b></p>"))
else:
    if not whisper_model_global or not whisper_processor_global: # from Cell 2
        raise RuntimeError("Whisper model not available for Cell 3.")
    if not all_mlp_model_paths_dict: # from Cell 2
        raise RuntimeError("MLP model paths not available for Cell 3.")

    total_mlps_to_process = len(all_mlp_model_paths_dict)
    # Loop-invariant clip length; SAMPLING_RATE and MAX_AUDIO_SECONDS come from Cell 2.
    max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)

    for mlp_processed_count, (mlp_target_key, mlp_model_path) in enumerate(all_mlp_model_paths_dict.items(), start=1):
        logging.info(f"Processing MLP {mlp_processed_count}/{total_mlps_to_process}: '{mlp_target_key}' for HTML report")

        # MLP_DEVICE, USE_HALF_PRECISION_FOR_MLPS, USE_TORCH_COMPILE_FOR_MLPS from Cell 2
        current_mlp_model = load_single_mlp_model( # load_single_mlp_model from Cell 2
            mlp_model_path, mlp_target_key, MLP_DEVICE,
            USE_HALF_PRECISION_FOR_MLPS, USE_TORCH_COMPILE_FOR_MLPS
        )
        if not current_mlp_model:
            logging.error(f"Skipping MLP '{mlp_target_key}' due to loading error.")
            for audio_f_path in audio_files_for_html_report:
                aggregated_results_for_html[audio_f_path][mlp_target_key] = float('nan')
            continue

        for audio_file_path in audio_files_for_html_report:
            try:
                waveform, sr = librosa.load(str(audio_file_path), sr=SAMPLING_RATE, mono=True)
                if len(waveform) > max_samples:
                    waveform = waveform[:max_samples]

                # WHISPER_DEVICE from Cell 2
                whisper_embedding = get_whisper_embedding_for_audio(
                    waveform, whisper_model_global, whisper_processor_global, WHISPER_DEVICE
                )
                del waveform
                gc.collect()

                if whisper_embedding is not None:
                    prediction = get_prediction_with_single_mlp(
                        whisper_embedding, current_mlp_model, MLP_DEVICE
                    )
                    aggregated_results_for_html[audio_file_path][mlp_target_key] = prediction
                    del whisper_embedding
                else:
                    aggregated_results_for_html[audio_file_path][mlp_target_key] = float('nan')
            except Exception as e_audio:
                logging.error(f"Error processing audio {audio_file_path.name} for MLP {mlp_target_key}: {e_audio}")
                aggregated_results_for_html[audio_file_path][mlp_target_key] = float('nan')

            if WHISPER_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()

        # Unload this MLP before the next one is loaded.
        del current_mlp_model
        gc.collect()
        if MLP_DEVICE.type == 'cuda':
            torch.cuda.empty_cache()

    # Attribute keys = every predictable dimension that is NOT one of the emotion keys.
    # TARGET_EMOTION_KEYS_FOR_REPORT is defined in Cell 2.
    all_predictable_keys = set(all_mlp_model_paths_dict.keys())
    emotion_keys_set = set(TARGET_EMOTION_KEYS_FOR_REPORT)
    attribute_keys_derived = sorted(all_predictable_keys - emotion_keys_set)

    # BATCH_OUTPUT_HTML_REPORT_FILE from Cell 2
    final_html_output = generate_batch_html_report_updated(
        aggregated_results_for_html,
        TARGET_EMOTION_KEYS_FOR_REPORT, # from Cell 2
        attribute_keys_derived,
        BATCH_OUTPUT_HTML_REPORT_FILE
    )
    display(HTML(final_html_output))
    logging.info("--- Cell 3: Batch HTML Report Generation (Updated) Finished ---")

In [None]:
# Cell 4: Batch JSON Annotation (One MLP at a Time)

def _scores_to_json_safe(scores: Dict[str, float]) -> Dict[str, Optional[float]]:
    """Return a copy of ``scores`` that ``json.dump(..., allow_nan=False)`` accepts.

    Each value is coerced to a built-in ``float``, so NumPy/torch scalars serialize.
    NaN, +/-inf, ``None`` and non-numeric values become ``None`` (JSON ``null``).
    """
    import math

    safe_scores: Dict[str, Optional[float]] = {}
    for key, value in scores.items():
        try:
            numeric = float(value)
        except (TypeError, ValueError):
            safe_scores[key] = None
            continue
        safe_scores[key] = numeric if math.isfinite(numeric) else None
    return safe_scores


def batch_annotate_to_json(
    input_audio_folder: Path,
    output_json_folder: Path,
    mlp_model_paths_map: Dict[str, Path], # {target_key: model_path}
    global_whisper_model: WhisperForConditionalGeneration,
    global_whisper_processor: WhisperProcessor
):
    """Score every audio file in ``input_audio_folder`` with all MLPs and write one JSON per file.

    For each file, the Whisper embedding is computed once and fed to every MLP in
    ``mlp_model_paths_map``. Each MLP is loaded, used and freed again for every file, so
    checkpoints are re-read once per audio file. The output mirrors the input folder's
    sub-directory layout (``<output_json_folder>/<sub/dirs>/<stem>.json``). Scores that
    could not be computed are written as JSON ``null``, never as the non-standard ``NaN``
    token.

    Args:
        input_audio_folder: Folder scanned with ``find_audio_files_in_folder``.
        output_json_folder: Root folder for the JSON files (created if missing).
        mlp_model_paths_map: ``{target_key: checkpoint_path}`` for every MLP to run.
        global_whisper_model: Whisper model used to compute embeddings.
        global_whisper_processor: Matching Whisper processor.

    Per-file and per-MLP failures are logged and skipped; they do not abort the batch.
    """
    audio_files_for_json = find_audio_files_in_folder(input_audio_folder)
    output_json_folder.mkdir(parents=True, exist_ok=True)

    if not audio_files_for_json:
        logging.warning("No audio files in input folder for JSON annotation. Skipping.")
        return

    total_files = len(audio_files_for_json)
    logging.info(f"Starting JSON batch annotation for {total_files} files...")
    max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)  # loop-invariant clip length

    for file_number, audio_file_path in enumerate(audio_files_for_json, start=1):
        logging.info(f"--- JSON Processing file {file_number}/{total_files}: {audio_file_path.name} ---")
        current_file_all_scores: Dict[str, float] = {}

        # Compute the Whisper embedding once per file; every MLP below reuses it.
        try:
            waveform, _ = librosa.load(str(audio_file_path), sr=SAMPLING_RATE, mono=True)
            if len(waveform) > max_samples:
                waveform = waveform[:max_samples]

            audio_whisper_embedding = get_whisper_embedding_for_audio(
                waveform, global_whisper_model, global_whisper_processor, WHISPER_DEVICE
            )
            del waveform
            gc.collect()
        except Exception as e_emb:
            logging.error(f"Could not get embedding for {audio_file_path.name}: {e_emb}. Skipping this file for JSON.")
            if WHISPER_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
            continue

        if audio_whisper_embedding is None:
            logging.warning(f"Whisper embedding failed for {audio_file_path.name}. Skipping this file for JSON.")
            if WHISPER_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
            continue

        for mlp_target_key, mlp_model_path in mlp_model_paths_map.items():
            current_mlp_model = load_single_mlp_model(
                mlp_model_path, mlp_target_key, MLP_DEVICE,
                USE_HALF_PRECISION_FOR_MLPS, USE_TORCH_COMPILE_FOR_MLPS
            )
            if not current_mlp_model:
                logging.error(f"Skipping MLP '{mlp_target_key}' for file '{audio_file_path.name}' due to loading error.")
                current_file_all_scores[mlp_target_key] = float('nan')
                continue

            try:
                current_file_all_scores[mlp_target_key] = get_prediction_with_single_mlp(
                    audio_whisper_embedding, current_mlp_model, MLP_DEVICE # Pass the cached embedding
                )
            except Exception as e_pred:
                # A failing MLP yields NaN for its key (as in Cell 3) instead of aborting the batch.
                logging.error(f"Prediction with MLP '{mlp_target_key}' failed for file '{audio_file_path.name}': {e_pred}")
                current_file_all_scores[mlp_target_key] = float('nan')
            finally:
                # Always unload the MLP before the next one is loaded.
                del current_mlp_model
                gc.collect()
                if MLP_DEVICE.type == 'cuda':
                    torch.cuda.empty_cache()

        del audio_whisper_embedding # Free embedding for this audio file
        if WHISPER_DEVICE.type == 'cuda':
            torch.cuda.empty_cache()

        # Mirror the input sub-folder layout. Path.relative_to raises ValueError when the
        # file is not under input_audio_folder; fall back to the output root in that case.
        try:
            relative_parent = audio_file_path.relative_to(input_audio_folder).parent
        except ValueError:
            relative_parent = Path()
        output_json_subfolder = output_json_folder / relative_parent
        output_json_path = output_json_subfolder / (audio_file_path.stem + ".json")

        try:
            output_json_subfolder.mkdir(parents=True, exist_ok=True)
            with open(output_json_path, 'w', encoding='utf-8') as f:
                # json.dump would otherwise emit bare NaN tokens (invalid JSON) for failed scores.
                json.dump(_scores_to_json_safe(current_file_all_scores), f, indent=2, allow_nan=False)
            logging.info(f"Saved annotations to {output_json_path}")
        except (OSError, TypeError, ValueError) as e:
            logging.error(f"Error writing JSON for {audio_file_path.name}: {e}")

    logging.info(f"--- JSON Batch Annotation Finished for {total_files} files. ---")


# --- Main execution for Cell 4 ---
logging.info("--- Starting Cell 4: Batch JSON Annotation ---")

# Fail fast when Cell 2 did not leave the shared Whisper model/processor or the
# MLP checkpoint map in the notebook namespace.
if not (whisper_model_global and whisper_processor_global):
    raise RuntimeError("Whisper model not available for Cell 4.")
if not all_mlp_model_paths_dict:
    raise RuntimeError("MLP model paths not available for Cell 4.")

batch_annotate_to_json(
    input_audio_folder=BATCH_INPUT_AUDIO_FOLDER,
    output_json_folder=BATCH_OUTPUT_JSON_FOLDER,
    mlp_model_paths_map=all_mlp_model_paths_dict,   # From Cell 2
    global_whisper_model=whisper_model_global,       # From Cell 2
    global_whisper_processor=whisper_processor_global,  # From Cell 2
)
display(HTML(f"<p><b>JSON batch annotation complete. Check logs and output folder: '{BATCH_OUTPUT_JSON_FOLDER.resolve()}'</b></p>"))
logging.info("--- Cell 4: Batch JSON Annotation Finished ---")

In [None]:
# Cell 5: Standalone Video Processing Script with Emotion Subtitles

# --- How to Use This Script ---
# 1. Dependencies:
#    This script requires several Python libraries. If you haven't installed them in your
#    current environment, uncomment and run the following install command (or install them manually).
#    `%pip` installs into the running kernel's environment (unlike `!pip`):
#    %pip install transformers torch torchvision torchaudio moviepy librosa numpy --quiet
#
# 2. Configure Paths (IMPORTANT!):
#    - HARDCODED_INPUT_FOLDERS: List of paths to folders containing your input video files.
#      Example: ["/content/my_videos_folder_1", "/content/another_video_collection"]
#    - HARDCODED_OUTPUT_FOLDER: Path where processed videos and SRT files will be saved.
#      Example: "/content/processed_emotion_videos"
#    - HARDCODED_MLP_MODELS_DIR: ABSOLUTE PATH to the directory where the
#      'Empathic-Insight-Voice-Small' MLP model checkpoints (*.pth files) are stored.
#      This is the folder you downloaded from laion/Empathic-Insight-Voice-Small,
#      e.g., the 'empathic_insight_voice_small_models_downloaded' folder from previous cells,
#      or if you downloaded them elsewhere, point to that full path.
#      Example: "/content/empathic_insight_voice_small_models_downloaded" OR
#               "/mnt/nvme/empathic-insights-voice-small/"
#
# 3. FFmpeg (Optional but Recommended for Burned-in Subtitles):
#    - For burning subtitles directly into the video, FFmpeg is required.
#    - If FFmpeg is in your system's PATH, HARDCODED_FFMPEG_PATH can be `None`.
#    - Otherwise, provide the full path to the ffmpeg executable.
#      Example: HARDCODED_FFMPEG_PATH = "/usr/bin/ffmpeg"
#    - If FFmpeg is not found, SRT files will still be generated, but videos with
#      burned-in subtitles will not.
#
# 4. Run the Script:
#    - After configuring, execute this cell.
#
# 5. Output:
#    - For each input video, an SRT file with emotion predictions per chunk will be created.
#    - If FFmpeg is available, a new MP4 video file with these emotions burned in as
#      subtitles will also be created.
#    - All output files will be in the HARDCODED_OUTPUT_FOLDER.
#      Filenames will be like: <InputFolderName>_<VideoBaseName>_<Suffix>.srt/.mp4
# --- End of How to Use ---

# --- Start of Python Script ---
# (Adapted from a standalone subtitle burn-in script for use in this notebook cell.)

# Imports for this standalone script. They run unconditionally on every execution of
# this cell; re-importing already-loaded modules is cheap because Python caches them in
# sys.modules.

# Python standard library imports
import os
import gc
import logging
import time
from pathlib import Path
from typing import List, Dict, Tuple, Any, Optional
import sys
import shutil
import subprocess # For FFmpeg

# Third-party library imports
import numpy as np
import torch
import torch.nn as nn
try:
    from moviepy import VideoFileClip  # moviepy >= 2.0 (the moviepy.editor module was removed)
except ImportError:
    from moviepy.editor import VideoFileClip  # moviepy 1.x
# NOTE(review): moviepy 2.x also renamed several clip methods; confirm the code below
# only uses APIs available in the installed moviepy version.
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration

_video_script_initialized = True
print("Video script environment initialized/re-initialized.")


# --- Configuration Block (Edit these values directly) ---
HARDCODED_INPUT_FOLDERS: List[str] = ["/content/sample_input_videos"] # << USER: EDIT THIS
HARDCODED_OUTPUT_FOLDER: str = "/content/processed_emotion_videos"    # << USER: EDIT THIS
HARDCODED_MLP_MODELS_DIR: str = "/content/empathic_insight_voice_small_models_downloaded" # << USER: EDIT THIS (e.g. result of Cell 2)

HARDCODED_OUTPUT_SUFFIX: str = "emotion_subs"
HARDCODED_CHUNK_DURATION_S: float = 2.5  # Duration of each audio chunk for emotion analysis
HARDCODED_FFMPEG_PATH: Optional[str] = None # e.g., "/usr/bin/ffmpeg" or None to auto-detect
HARDCODED_TEMP_AUDIO_BASE_NAME: str = "temp_video_audio_"

# --- General Configuration ---
DEFAULT_PROCESSING_SAMPLING_RATE: int = 16000 # For audio extraction
DEFAULT_TEXT_TOP_N_EMOTIONS: int = 3 # How many top emotions to show in SRT
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.avi', '.mov', '.webm', '.flv', '.wmv'}

# --- Parameters for the Emotion Prediction MLP part (Must match Empathic-Insight-Voice-Small) ---
# NOTE: this cell shares the notebook namespace with Cells 2-4. USE_HALF_PRECISION_FOR_MLPS,
# USE_TORCH_COMPILE_FOR_MLPS, WHISPER_SEQ_LEN and WHISPER_EMBED_DIM are also read there,
# so re-running Cells 3-4 after this cell uses the values below, not (presumably) Cell 2's.
USE_CPU_OFFLOADING_FOR_MLPS: bool = True
USE_HALF_PRECISION_FOR_MLPS: bool = True
USE_TORCH_COMPILE_FOR_MLPS: bool = True # Requires PyTorch 2.0+
EMOTION_WHISPER_MODEL_ID: str = "mkrausio/EmoWhisper-AnS-Small-v0.1"
MLP_SAMPLING_RATE: int = 16000 # For MLP model input
WHISPER_SEQ_LEN: int = 1500
WHISPER_EMBED_DIM: int = 768
PROJECTION_DIM_FOR_FULL_EMBED: int = 64
MLP_HIDDEN_DIMS: List[int] = [64, 32, 16]
MLP_DROPOUTS: List[float] = [0.0, 0.1, 0.1, 0.1]

TARGET_EMOTION_KEYS: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]
assert len(TARGET_EMOTION_KEYS) == 40

# Maps the <part> of checkpoint filenames "model_<part>_best.pth" to the dimension key
# used in reports (40 emotions plus the attribute dimensions).
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability", "Infatuation": "Infatuation",
    "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice", "Monotone_vs._Expressive": "Monotone_vs._Expressive",
    "Pain": "Pain", "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence", "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# Sanity checks that catch editing mistakes early:
# - every emotion key must be unique and have a checkpoint-filename mapping;
# - the MLP needs one dropout rate after the projection plus one per hidden layer.
assert len(set(TARGET_EMOTION_KEYS)) == len(TARGET_EMOTION_KEYS), "TARGET_EMOTION_KEYS contains duplicates"
assert set(TARGET_EMOTION_KEYS) <= set(FILENAME_PART_TO_TARGET_KEY_MAP.values()), (
    "Emotion keys without a checkpoint filename mapping: "
    f"{sorted(set(TARGET_EMOTION_KEYS) - set(FILENAME_PART_TO_TARGET_KEY_MAP.values()))}"
)
assert len(MLP_DROPOUTS) == len(MLP_HIDDEN_DIMS) + 1, "MLP_DROPOUTS needs len(MLP_HIDDEN_DIMS) + 1 entries"

# --- Global Variables for Models (initialized once per script run) ---
# Each run of this cell resets all of these to None. load_video_emotion_models_once()
# returns early only when the Whisper model is already set, so models load afresh per run.

# Emotion-Whisper backbone and its processor.
_video_emotion_whisper_processor: Optional[WhisperProcessor] = None
_video_emotion_whisper_model: Optional[WhisperForConditionalGeneration] = None

# Devices chosen at load time (the MLP device may be CPU when offloading is enabled).
_video_mlp_device: Optional[torch.device] = None
_video_emotion_whisper_device: Optional[torch.device] = None

# MLP models directory and the loaded MLP heads (presumably keyed by target key).
_video_MLP_MODELS_BASE_DIR_CONFIG: Optional[str] = None
_video_mlp_models_ensemble: Optional[Dict[str, nn.Module]] = None

def setup_video_script_logging(log_level=logging.INFO, stream=None):
    """Configure root logging for the video script.

    ``logging.basicConfig(force=True)`` removes and closes every handler already attached
    to the root logger before installing the new one. This therefore replaces whatever
    configuration earlier notebook cells installed: their later log lines will also use
    the ``[VIDEO_SCRIPT]`` format.

    Args:
        log_level: Root logger level (default ``logging.INFO``).
        stream: Stream for the single ``StreamHandler``. ``None`` (the default) means the
            current ``sys.stdout``, looked up at call time so notebook stdout redirection
            is honoured.
    """
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s [VIDEO_SCRIPT] [%(levelname)-7s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[logging.StreamHandler(sys.stdout if stream is None else stream)],
        force=True,  # Python 3.8+: drop/close existing root handlers (needed in notebooks)
    )

class VideoFullEmbeddingMLP(nn.Module): # Identical to MLP class in previous cells
    """Regression head: flatten the embedding, project it, then run a small ReLU MLP to 1 output.

    Layer order (and therefore ``state_dict`` keys such as ``proj.weight``, ``mlp.2.weight``)
    must stay exactly as built here so saved checkpoints keep loading:
    ``ReLU, Dropout(r0)`` after the projection, then ``Linear, ReLU, Dropout(r_i)`` per hidden
    layer, then a final ``Linear(-> 1)``.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        expected_rate_count = len(mlp_hidden_dims) + 1
        if len(mlp_dropout_rates) != expected_rate_count:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)

        # Input width of each hidden Linear: projection output, then each previous hidden width.
        layer_inputs = [projection_dim, *mlp_hidden_dims]
        stages = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        for in_width, out_width, rate in zip(layer_inputs[:-1], mlp_hidden_dims, mlp_dropout_rates[1:]):
            stages += [nn.Linear(in_width, out_width), nn.ReLU(), nn.Dropout(rate)]
        stages.append(nn.Linear(layer_inputs[-1], 1))
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        # Accept an extra singleton dim at position 1 (4-D input) by squeezing it away.
        has_singleton_dim = x.ndim == 4 and x.shape[1] == 1
        if has_singleton_dim:
            x = x.squeeze(1)
        flat = self.flatten(x)
        return self.mlp(self.proj(flat))

def get_video_mlp_model_paths_map(mlp_models_dir_str, filename_map):
    """Locate MLP checkpoints named ``model_<part>_best.pth`` and key them by target dimension.

    Args:
        mlp_models_dir_str: Directory that holds the checkpoint files.
        filename_map: ``{filename_part: target_key}``. Parts without a matching file are skipped.

    Returns:
        ``{target_key: checkpoint_path}`` in ``filename_map`` order. This is an empty dict
        when the directory does not exist or nothing matched (both cases are logged). If two
        parts map to the same key, the later one wins and a warning is logged.
    """
    models_root = Path(mlp_models_dir_str)
    if not models_root.is_dir():
        logging.error(f"[VIDEO_SCRIPT] MLP models directory not found: {models_root.resolve()}")
        return {}

    resolved_paths: Dict[str, Path] = {}
    for name_fragment, dimension_key in filename_map.items():
        candidate = models_root / f"model_{name_fragment}_best.pth"
        if not candidate.is_file():
            continue
        if dimension_key in resolved_paths:
            logging.warning(f"[VIDEO_SCRIPT] Duplicate mapping for '{dimension_key}'. Overwriting.")
        resolved_paths[dimension_key] = candidate

    if not resolved_paths:
        logging.warning(f"[VIDEO_SCRIPT] No MLP models mapped from {mlp_models_dir_str}.")
    return resolved_paths

def load_video_emotion_models_once():
    """Load the emotion Whisper processor/model and the MLP ensemble into module globals.

    Steps:
      1. Validate HARDCODED_MLP_MODELS_DIR.
      2. Pick devices: Whisper goes on CUDA when available; MLPs go on CPU
         when USE_CPU_OFFLOADING_FOR_MLPS is set.
      3. Load the Whisper processor/model for EMOTION_WHISPER_MODEL_ID.
      4. Build one VideoFullEmbeddingMLP per checkpoint found through
         FILENAME_PART_TO_TARGET_KEY_MAP. On CUDA, MLPs can optionally be cast
         to fp16 and wrapped with ``torch.compile``.

    Calling it again is a no-op once both the Whisper model and a non-empty MLP
    ensemble are loaded. After a failed attempt the next call retries, reusing
    an already-loaded Whisper model instead of loading it again.

    Raises:
        FileNotFoundError: HARDCODED_MLP_MODELS_DIR is unset or not a directory.
        ValueError: No checkpoint was mapped, or every mapped checkpoint failed
            to load (the individual errors are logged).
        Exception: Anything raised while loading the Whisper processor/model is
            logged and re-raised.
    """
    global _video_emotion_whisper_model, _video_emotion_whisper_processor, _video_mlp_models_ensemble
    global _video_emotion_whisper_device, _video_mlp_device, _video_MLP_MODELS_BASE_DIR_CONFIG

    # Short-circuit only when BOTH stages succeeded. Checking the Whisper model
    # alone let a failed MLP stage leave a half-initialised state that later
    # calls silently accepted. (`and` short-circuits, so the ensemble global is
    # only read once the Whisper model has been assigned below.)
    if _video_emotion_whisper_model is not None and _video_mlp_models_ensemble:
        return  # Already loaded

    _video_MLP_MODELS_BASE_DIR_CONFIG = HARDCODED_MLP_MODELS_DIR  # Use hardcoded value
    if not _video_MLP_MODELS_BASE_DIR_CONFIG or not Path(_video_MLP_MODELS_BASE_DIR_CONFIG).is_dir():
        logging.critical(f"[VIDEO_SCRIPT] MLP Models Directory '{_video_MLP_MODELS_BASE_DIR_CONFIG}' is invalid. Cannot load models.")
        raise FileNotFoundError(f"Invalid MLP Models Directory: {_video_MLP_MODELS_BASE_DIR_CONFIG}")

    logging.info("[VIDEO_SCRIPT] Loading emotion prediction models for video script...")
    _video_emotion_whisper_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _video_mlp_device = torch.device("cpu") if USE_CPU_OFFLOADING_FOR_MLPS else _video_emotion_whisper_device

    logging.info(f"[VIDEO_SCRIPT] Emotion Whisper on: {_video_emotion_whisper_device}, MLPs on: {_video_mlp_device}")

    if _video_emotion_whisper_model is None:
        try:
            processor = WhisperProcessor.from_pretrained(EMOTION_WHISPER_MODEL_ID)
            whisper_model = WhisperForConditionalGeneration.from_pretrained(EMOTION_WHISPER_MODEL_ID)
            whisper_model = whisper_model.to(_video_emotion_whisper_device).eval()
        except Exception as e:
            logging.error(f"[VIDEO_SCRIPT] Error loading Emotion Whisper model: {e}", exc_info=True)
            raise
        # Publish processor and model together so they are never half-set.
        _video_emotion_whisper_processor = processor
        _video_emotion_whisper_model = whisper_model

    # Start from an empty ensemble so a failure below never leaves a stale or
    # partially built ensemble behind.
    _video_mlp_models_ensemble = {}

    all_model_paths = get_video_mlp_model_paths_map(_video_MLP_MODELS_BASE_DIR_CONFIG, FILENAME_PART_TO_TARGET_KEY_MAP)
    if not all_model_paths:
        raise ValueError("[VIDEO_SCRIPT] No MLP model paths found/mapped. Emotion prediction cannot proceed.")

    target_dtype = torch.float16 if USE_HALF_PRECISION_FOR_MLPS and _video_mlp_device.type == 'cuda' else torch.float32
    compile_prefix = "_orig_mod."  # key prefix added by torch.compile-wrapped modules

    loaded_models = {}
    for target_key, model_path in all_model_paths.items():
        try:
            model_instance = VideoFullEmbeddingMLP(
                WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
                MLP_HIDDEN_DIMS, MLP_DROPOUTS
            )
            # NOTE(review): torch.load unpickles arbitrary objects. Only point
            # HARDCODED_MLP_MODELS_DIR at trusted checkpoints; consider
            # weights_only=True on torch versions that support it.
            state_dict = torch.load(model_path, map_location='cpu')
            if any(k.startswith(compile_prefix) for k in state_dict.keys()):
                state_dict = {
                    (k[len(compile_prefix):] if k.startswith(compile_prefix) else k): v
                    for k, v in state_dict.items()
                }

            model_instance.load_state_dict(state_dict)
            model_instance.eval()
            if _video_mlp_device.type == 'cuda' and USE_HALF_PRECISION_FOR_MLPS:
                model_instance = model_instance.to(dtype=target_dtype)
            model_instance = model_instance.to(_video_mlp_device)
            if USE_TORCH_COMPILE_FOR_MLPS and hasattr(torch, 'compile') and _video_mlp_device.type == 'cuda' and torch.__version__ >= "2.0.0":
                model_instance = torch.compile(model_instance, mode="reduce-overhead")
            loaded_models[target_key] = model_instance
        except Exception as e:
            logging.error(f"[VIDEO_SCRIPT] Failed to load MLP '{target_key}': {e}", exc_info=True)

    logging.info(f"[VIDEO_SCRIPT] Loaded {len(loaded_models)} MLP models.")
    if not loaded_models:
        # Previously this "succeeded" with an empty ensemble and every SRT cue
        # silently became "N/A".
        raise ValueError("[VIDEO_SCRIPT] None of the mapped MLP checkpoints could be loaded. Emotion prediction cannot proceed.")
    _video_mlp_models_ensemble = loaded_models


@torch.no_grad()
def predict_emotions_for_video_waveform(audio_waveform: np.ndarray) -> Optional[Dict[str, float]]:
    """Score one audio segment with every loaded emotion MLP.

    Pipeline:
      1. The Whisper processor turns the waveform into input features at
         MLP_SAMPLING_RATE.
      2. The Whisper encoder embeds the features.
      3. The embedding is zero-padded or truncated along dim 1 to
         WHISPER_SEQ_LEN frames.
      4. The embedding is moved to the MLP device and fed to each MLP, cast to
         that MLP's parameter dtype.

    Args:
        audio_waveform: Sample array for the segment. Presumably mono and
            already at MLP_SAMPLING_RATE, as the caller loads it; TODO confirm
            for other callers.

    Returns:
        Mapping of target key -> scalar prediction, with NaN for any MLP that
        raised. Returns None if the models are not loaded or embedding
        extraction failed.
    """
    if _video_emotion_whisper_model is None:
        logging.error("[VIDEO_SCRIPT] Video emotion models not loaded. Call load_video_emotion_models_once().")
        return None

    try:
        features = _video_emotion_whisper_processor(
            audio_waveform, sampling_rate=MLP_SAMPLING_RATE, return_tensors="pt"
        ).input_features
        features = features.to(_video_emotion_whisper_device).to(_video_emotion_whisper_model.dtype)

        encoder = _video_emotion_whisper_model.get_encoder()
        embedding = encoder(input_features=features).last_hidden_state

        # Positive gap -> pad with zero frames; negative gap -> truncate.
        frame_gap = WHISPER_SEQ_LEN - embedding.shape[1]
        if frame_gap > 0:
            zero_frames = torch.zeros((1, frame_gap, WHISPER_EMBED_DIM),
                                      device=_video_emotion_whisper_device, dtype=embedding.dtype)
            embedding = torch.cat((embedding, zero_frames), dim=1)
        elif frame_gap < 0:
            embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    except Exception as e:
        logging.error(f"[VIDEO_SCRIPT] Error generating Whisper embedding for video: {e}", exc_info=True)
        return None

    mlp_input = embedding.to(_video_mlp_device)
    del embedding
    # Free the Whisper-side copy only when it actually lived on a different CUDA device.
    whisper_on_cuda = _video_emotion_whisper_device.type == 'cuda'
    if whisper_on_cuda and _video_emotion_whisper_device != _video_mlp_device:
        torch.cuda.empty_cache()

    predictions: Dict[str, float] = {}
    for key, mlp in _video_mlp_models_ensemble.items():
        try:
            param_dtype = next(mlp.parameters()).dtype
            predictions[key] = mlp(mlp_input.to(param_dtype)).item()
        except Exception as e:
            logging.error(f"[VIDEO_SCRIPT] Error predicting with MLP '{key}': {e}")
            predictions[key] = float('nan')

    del mlp_input
    if _video_mlp_device.type == 'cuda':
        torch.cuda.empty_cache()
    gc.collect()
    return predictions

def get_video_top_n_emotions(predictions, emotion_keys, top_n):
    """Return the ``top_n`` highest-scoring emotions as ``(key, score)`` pairs.

    Only keys listed in ``emotion_keys`` that are also present in
    ``predictions`` are ranked. NaN scores are dropped: the prediction step
    records NaN for an MLP that failed, and NaN breaks the ordering ``sorted``
    relies on, so it could otherwise land anywhere in the ranking, including
    first place.

    Args:
        predictions: Mapping of emotion key -> numeric score.
        emotion_keys: Keys eligible for ranking.
        top_n: Maximum number of pairs to return.

    Returns:
        List of ``(key, score)`` tuples sorted by descending score. Ties keep
        the order of ``emotion_keys``. Returns ``[("N/A", 0.0)]`` when no
        eligible, non-NaN score exists.
    """
    import math

    relevant = {
        key: predictions[key]
        for key in emotion_keys
        if key in predictions and not math.isnan(predictions[key])
    }
    if not relevant:
        return [("N/A", 0.0)]
    return sorted(relevant.items(), key=lambda item: item[1], reverse=True)[:top_n]

def extract_audio_from_video_moviepy(video_path, audio_output_path, target_sr):
    """Write the audio track of ``video_path`` to a 16-bit PCM file at ``target_sr`` Hz.

    Args:
        video_path: Source video (str or path-like).
        audio_output_path: Destination audio file (str or path-like).
        target_sr: Sample rate passed to moviepy as ``fps``.

    Returns:
        True on success. False if the video has no audio track or extraction
        raised; the error is logged. Errors while closing the clip are logged
        as warnings rather than raised, so the caller always gets a bool.
    """
    logging.info(f"[VIDEO_SCRIPT] Extracting audio: '{video_path}' -> '{audio_output_path}' @ {target_sr}Hz")
    clip = None
    try:
        clip = VideoFileClip(str(video_path))
        if clip.audio is None:
            logging.error(f"[VIDEO_SCRIPT] No audio in '{video_path}'.")
            return False
        clip.audio.write_audiofile(str(audio_output_path), fps=target_sr, codec='pcm_s16le', logger=None)
        return True
    except Exception as e:
        logging.error(f"[VIDEO_SCRIPT] Audio extraction error for '{video_path}': {e}", exc_info=True)
        return False
    finally:
        if clip is not None:
            # Close the audio reader before its parent clip. Guard cleanup so an
            # exception here cannot replace the boolean result chosen above.
            try:
                if clip.audio is not None:
                    clip.audio.close()
                clip.close()
            except Exception as close_err:
                logging.warning(f"[VIDEO_SCRIPT] Could not close clip for '{video_path}': {close_err}")


def format_srt_time(seconds: float) -> str:
    """Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond *before* being split into
    fields. For example, 1.9996 s becomes ``00:00:02,000``; rounding the
    fractional part on its own produced the invalid ``00:00:01,1000``.
    Negative inputs are clamped to zero because SRT timestamps cannot be
    negative. Hours are not capped.
    """
    total_ms = max(0, int(round(seconds * 1000)))
    total_s, millis = divmod(total_ms, 1000)
    total_m, secs = divmod(total_s, 60)
    hours, minutes = divmod(total_m, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

def write_video_srt_file(filepath: Path, entries: List[Dict[str, Any]]):
    """Serialise subtitle entries to an SRT file (UTF-8), numbering cues from 1.

    Each entry must provide ``'start_str'`` and ``'end_str'`` (already formatted
    SRT timestamps) plus ``'text'``. Each cue is written as its number, the
    timing line, the text, and a blank separator line.

    An ``IOError``/``OSError`` while opening or writing is logged, not raised.
    A missing key raises ``KeyError``.
    """
    logging.info(f"[VIDEO_SCRIPT] Writing {len(entries)} SRT entries to: {filepath}")
    try:
        with open(filepath, 'w', encoding='utf-8') as srt_out:
            for cue_number, cue in enumerate(entries, start=1):
                timing_line = f"{cue['start_str']} --> {cue['end_str']}"
                srt_out.write(f"{cue_number}\n{timing_line}\n{cue['text']}\n\n")
    except IOError as e:
        logging.error(f"[VIDEO_SCRIPT] Failed to write SRT {filepath}: {e}", exc_info=True)


def process_single_video_for_srt(video_path, srt_path, temp_audio_path, chunk_dur, proc_sr, top_n_disp):
    """Generate an emotion SRT for one video.

    Steps:
      1. Extract the audio track to ``temp_audio_path`` at ``proc_sr`` Hz.
      2. Split it into consecutive ``chunk_dur``-second windows; a trailing
         partial window is included.
      3. Predict emotions for each window.
      4. Write one SRT cue per window listing the top ``top_n_disp`` emotions.

    The temporary audio file is removed on every exit path, including early
    failures.

    Args:
        video_path: Input video ``Path`` (``.name`` is used in logs).
        srt_path: Destination ``.srt`` path.
        temp_audio_path: Scratch audio ``Path``.
        chunk_dur: Window length in seconds; must be positive.
        proc_sr: Sample rate used for audio extraction.
        top_n_disp: Number of emotions shown per cue.

    Returns:
        True if an SRT (possibly empty) was written. False if ``chunk_dur`` is
        not positive, or audio extraction or duration probing failed.
    """
    if chunk_dur <= 0:
        logging.error(f"[VIDEO_SCRIPT] Invalid chunk duration {chunk_dur} for {video_path.name}. Skipping.")
        return False

    try:
        if not extract_audio_from_video_moviepy(video_path, temp_audio_path, proc_sr):
            logging.error(f"[VIDEO_SCRIPT] Audio extraction failed for {video_path.name}. Skipping.")
            return False
        if not temp_audio_path.is_file() or temp_audio_path.stat().st_size == 0:
            logging.error(f"[VIDEO_SCRIPT] Temp audio {temp_audio_path} invalid. Skipping {video_path.name}.")
            return False

        try:
            total_dur = librosa.get_duration(path=str(temp_audio_path))
        except Exception as e:
            logging.error(f"[VIDEO_SCRIPT] Cannot get duration of {temp_audio_path}: {e}. Skipping.")
            return False

        # ceil (not floor) so the trailing partial window also gets a cue; the
        # min() clamp and the <0.1 s skip in the loop exist for that last
        # window. Audio of 0.1 s or less yields no chunks.
        num_chunks = int(np.ceil(total_dur / chunk_dur)) if total_dur > 0.1 else 0
        if num_chunks == 0:
            logging.info(f"[VIDEO_SCRIPT] No chunks for {video_path.name}. SRT will be empty.")
            write_video_srt_file(srt_path, [])
            return True

        srt_entries = []
        for i in range(num_chunks):
            start_s = i * chunk_dur
            end_s = min((i + 1) * chunk_dur, total_dur)  # clamp the last window to the audio length
            actual_chunk_dur = end_s - start_s
            if actual_chunk_dur < 0.1:
                continue  # Skip tiny remainder

            logging.info(f"[VIDEO_SCRIPT] Chunk {i+1}/{num_chunks} for {video_path.name}: {start_s:.2f}s-{end_s:.2f}s")
            try:
                # MLP_SAMPLING_RATE is used for loading the segment for MLPs
                segment_wf, _ = librosa.load(str(temp_audio_path), sr=MLP_SAMPLING_RATE, mono=True, offset=start_s, duration=actual_chunk_dur)
            except Exception as e:
                logging.error(f"[VIDEO_SCRIPT] Error loading chunk for {video_path.name}: {e}")
                continue
            if len(segment_wf) == 0:
                logging.warning(f"[VIDEO_SCRIPT] Empty waveform for chunk in {video_path.name}")
                continue

            raw_preds = predict_emotions_for_video_waveform(segment_wf)
            emo_text = "No predictions"
            if raw_preds:
                top_emos = get_video_top_n_emotions(raw_preds, TARGET_EMOTION_KEYS, top_n_disp)
                emo_text = "\n".join([f"{e}: {s:.2f}" for e, s in top_emos if e != "N/A"]) or "N/A"

            srt_entries.append({'start_str': format_srt_time(start_s), 'end_str': format_srt_time(end_s), 'text': emo_text})
            del segment_wf, raw_preds
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        write_video_srt_file(srt_path, srt_entries)
        return True
    finally:
        # Previously the scratch audio was only deleted on the full-success
        # path and leaked on every early return.
        if temp_audio_path.exists():
            try:
                os.remove(temp_audio_path)
            except OSError as e:
                logging.warning(f"[VIDEO_SCRIPT] Could not delete temp audio {temp_audio_path}: {e}")


def find_ffmpeg(user_path=None):
    """Return the first FFmpeg command that runs ``<cmd> -version`` successfully.

    ``user_path`` (if truthy) is probed first, then the bare names ``ffmpeg``
    and ``ffmpeg.exe``. Each probe has a 5-second timeout. Any failure (missing
    binary, non-zero exit, timeout) moves on to the next candidate. If nothing
    works, an error is logged and None is returned.
    """
    probe_order = ([user_path] if user_path else []) + ["ffmpeg", "ffmpeg.exe"]
    for candidate in probe_order:
        try:
            # check=True turns a non-zero exit status into CalledProcessError.
            subprocess.run([candidate, "-version"], capture_output=True, timeout=5, check=True)
        except Exception:
            continue
        logging.info(f"[VIDEO_SCRIPT] Found FFmpeg: {candidate}")
        return candidate
    logging.error("[VIDEO_SCRIPT] FFmpeg not found. Burn-in skipped.")
    return None

def burn_subs_ffmpeg(vid_in, srt_in, vid_out, ffmpeg_exe):
    """Burn the subtitles in ``srt_in`` into ``vid_in``, writing ``vid_out``.

    Video is re-encoded with libx264 (preset ``fast``, CRF 23) and audio with
    AAC at 128 kb/s. Subtitles are styled as white Arial text with a black
    outline, bottom-centred.

    Args:
        vid_in: Input video ``Path``.
        srt_in: Subtitle ``Path``; must exist.
        vid_out: Output video ``Path``; overwritten (``-y``).
        ffmpeg_exe: FFmpeg command name or path.

    Returns:
        True if FFmpeg exited with status 0. False if the SRT is missing,
        FFmpeg exited non-zero (its stderr is logged), or FFmpeg could not be
        launched.
    """
    if not srt_in.is_file():
        logging.error(f"[VIDEO_SCRIPT] SRT {srt_in} not found for burn-in.")
        return False
    logging.info(f"[VIDEO_SCRIPT] Burning subs: '{srt_in.name}' into '{vid_in.name}' -> '{vid_out.name}'")

    # The subtitles filter takes the path inside a filtergraph string: use
    # forward slashes and, on Windows, escape the drive-letter colon.
    subtitle_path = str(srt_in.resolve()).replace('\\', '/')
    if sys.platform == "win32":
        drive, rest = os.path.splitdrive(subtitle_path)
        if drive:
            subtitle_path = drive.replace(":", "\\\\:") + rest

    # White text, black outline, bottom center
    force_style = "FontName=Arial,Fontsize=20,PrimaryColour=&HFFFFFF&,BorderStyle=1,Outline=1,OutlineColour=&H000000&,Shadow=0.5,Alignment=2"
    video_filter = f"subtitles='{subtitle_path}':force_style='{force_style}'"

    cmd = [ffmpeg_exe, "-y", "-i", str(vid_in), "-vf", video_filter]
    cmd += ["-c:v", "libx264", "-preset", "fast", "-crf", "23"]
    cmd += ["-c:a", "aac", "-b:a", "128k", str(vid_out)]
    logging.debug(f"[VIDEO_SCRIPT] FFmpeg cmd: {' '.join(cmd)}")

    try:
        # check=False so a failing run can be reported with its stderr.
        completed = subprocess.run(cmd, capture_output=True, text=True, check=False)
    except Exception as e:
        logging.error(f"[VIDEO_SCRIPT] FFmpeg execution error for {vid_in.name}: {e}", exc_info=True)
        return False

    if completed.returncode == 0:
        return True
    logging.error(f"[VIDEO_SCRIPT] FFmpeg error for {vid_in.name}:\n{completed.stderr}")
    return False


def _collect_video_files(input_folders, output_dir, video_extensions, output_suffix):
    """Recursively gather video files from ``input_folders``, sorted and de-duplicated.

    Missing folders are logged and skipped. Files that resolve to the same path
    (e.g. overlapping or nested input folders) are returned once. Files under
    ``output_dir`` whose stem ends with ``_<output_suffix>`` are skipped: that
    is the naming used for this script's own burned-in outputs, which would
    otherwise be re-processed when the output folder lies inside an input folder.
    """
    output_root = output_dir.resolve()
    generated_tail = f"_{output_suffix}"
    seen_paths = set()
    videos = []
    for folder_str in input_folders:
        in_folder = Path(folder_str)
        if not in_folder.is_dir():
            logging.warning(f"[VIDEO_SCRIPT] Input folder {in_folder} not found. Skipping.")
            continue
        for item in sorted(in_folder.rglob('*')):  # rglob for recursive search
            if not item.is_file() or item.suffix.lower() not in video_extensions:
                continue
            resolved = item.resolve()
            if resolved in seen_paths:
                continue
            if output_root in resolved.parents and item.stem.endswith(generated_tail):
                continue
            seen_paths.add(resolved)
            videos.append(item)
    return videos


def video_script_main_flow():
    """Produce emotion SRTs for every input video and, when possible, burn them in.

    Steps:
      1. Create HARDCODED_OUTPUT_FOLDER.
      2. Load the emotion models; abort if this fails.
      3. Locate FFmpeg.
      4. For each discovered video, write
         ``<parent>_<stem>_<HARDCODED_OUTPUT_SUFFIX>.srt``. When FFmpeg is
         available and the SRT is non-empty, also write a matching ``.mp4``
         with burned-in subtitles.
    """
    setup_video_script_logging()
    logging.info("[VIDEO_SCRIPT] --- Starting Video Processing Script ---")

    # Create output folder if it doesn't exist
    output_dir = Path(HARDCODED_OUTPUT_FOLDER)
    try:
        output_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        logging.critical(f"[VIDEO_SCRIPT] Cannot create output directory {output_dir}: {e}. Aborting.")
        return

    try:
        load_video_emotion_models_once()
    except Exception as e:
        logging.critical(f"[VIDEO_SCRIPT] Failed to load emotion models: {e}. Aborting.", exc_info=True)
        return

    ffmpeg_executable = find_ffmpeg(HARDCODED_FFMPEG_PATH)

    all_videos = _collect_video_files(HARDCODED_INPUT_FOLDERS, output_dir, VIDEO_EXTENSIONS, HARDCODED_OUTPUT_SUFFIX)
    if not all_videos:
        logging.info("[VIDEO_SCRIPT] No video files found in specified input folders. Exiting.")
        return

    logging.info(f"[VIDEO_SCRIPT] Found {len(all_videos)} videos to process.")

    for idx, video_file in enumerate(all_videos, start=1):
        logging.info(f"[VIDEO_SCRIPT] Processing video {idx}/{len(all_videos)}: {video_file.name}")
        start_vid_time = time.time()

        # Prefix with the parent folder name to reduce collisions between input
        # folders. Path('.').name and Path('/').name are '' (never '.'), so the
        # fallback must test for emptiness.
        input_parent_dir_name = video_file.parent.name or "root_input"
        base_name_for_output = f"{input_parent_dir_name}_{video_file.stem}"

        srt_file = output_dir / f"{base_name_for_output}_{HARDCODED_OUTPUT_SUFFIX}.srt"
        temp_audio_file = output_dir / f"{HARDCODED_TEMP_AUDIO_BASE_NAME}{base_name_for_output}.wav"  # Unique temp audio

        srt_ok = process_single_video_for_srt(
            video_file, srt_file, temp_audio_file,
            HARDCODED_CHUNK_DURATION_S, DEFAULT_PROCESSING_SAMPLING_RATE, DEFAULT_TEXT_TOP_N_EMOTIONS
        )

        srt_has_content = srt_ok and srt_file.is_file() and srt_file.stat().st_size > 0
        if srt_has_content and ffmpeg_executable:
            output_video_file = output_dir / f"{base_name_for_output}_{HARDCODED_OUTPUT_SUFFIX}.mp4"
            burn_subs_ffmpeg(video_file, srt_file, output_video_file, ffmpeg_executable)
        elif not ffmpeg_executable:
            logging.info(f"[VIDEO_SCRIPT] SRT file for {video_file.name} is at {srt_file}. FFmpeg burn-in skipped.")
        else:
            logging.warning(f"[VIDEO_SCRIPT] SRT generation failed or empty for {video_file.name}. Burn-in skipped.")

        logging.info(f"[VIDEO_SCRIPT] Finished {video_file.name} in {time.time() - start_vid_time:.2f}s")
        gc.collect()  # Clean up per video
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    logging.info("[VIDEO_SCRIPT] --- Video Processing Script Finished ---")


if __name__ == "__main__":  # In a Jupyter/Colab cell __name__ is normally "__main__", so this runs with the cell.
    # Standalone-run helpers. If the configured MLP model directory is missing,
    # print setup hints. Create the first configured input folder (or
    # ./dummy_videos) so the user knows where to place videos.
    # HARDCODED_MLP_MODELS_DIR must point at real checkpoints for meaningful predictions.

    example_mlp_dir = Path(HARDCODED_MLP_MODELS_DIR)
    if not example_mlp_dir.exists():
        print(f"[INFO_FOR_TESTING] MLP model dir '{example_mlp_dir}' not found. Please set HARDCODED_MLP_MODELS_DIR correctly.")
        print("[INFO_FOR_TESTING] For a quick test, you can create a dummy file like:")
        print(f"  !mkdir -p {example_mlp_dir}")
        print(f"  !touch {example_mlp_dir}/model_Amusement_best.pth")
        # An empty placeholder file is not a loadable checkpoint; it only
        # satisfies the directory/file-existence checks.

    example_input_video_dir = Path(HARDCODED_INPUT_FOLDERS[0] if HARDCODED_INPUT_FOLDERS else "./dummy_videos")
    if not example_input_video_dir.exists():
        example_input_video_dir.mkdir(parents=True, exist_ok=True)
        print(f"[INFO_FOR_TESTING] Created dummy input video directory: {example_input_video_dir}")
        print("[INFO_FOR_TESTING] Please place some .mp4 (or other supported) videos in it to test.")
        # Optionally fetch a sample clip:
        # !wget -q -O {example_input_video_dir}/sample_video.mp4 http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4

    # Execute the main flow
    video_script_main_flow()