In [None]:
# ============================================
# 0. INSTALLS & BASIC IMPORTS
# ============================================
!pip install -q "transformers>=4.46.0" "mistral_common[audio]>=1.8.6" \
               "accelerate>=0.34.0" pandas einops plotly

from google.colab import drive
drive.mount('/content/drive')

import os
import glob
import re
import json

import pandas as pd
import numpy as np
from tqdm.auto import tqdm

import torch
from transformers import VoxtralForConditionalGeneration, AutoProcessor

import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
    confusion_matrix,
    precision_recall_fscore_support,
    balanced_accuracy_score,
    cohen_kappa_score,
    matthews_corrcoef,
)

import plotly.express as px
import plotly.graph_objects as go


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/6.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/6.5 MB[0m [31m37.1 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━[0m [32m4.2/6.5 MB[0m [31m67.3 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━[0m [32m6.2/6.5 MB[0m [31m56.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.5/6.5 MB[0m [31m46.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.9/40.9 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m85.4 MB/s[0m eta [36m0:00:00[0m
[?25hMounted at /content/drive


In [None]:
# ============================================
# 1. PATHS & DATASET CONFIG (EMOBOX)
# ============================================

# EmoBox checkout: the directory that holds data/, docs/, EmoBox/, examples/, ...
EMOBOX_ROOT = "/content/drive/MyDrive/adsp/EmoBox"

# Original ESD audio tree. Double-check the letter case of every path
# component against the actual folder names on Drive.
ROOT_DIR = "/content/drive/MyDrive/adsp/downloads/esd/Emotion Speech Dataset"

print(f"Using EMOBOX_ROOT: {EMOBOX_ROOT}")
print(f"Using ROOT_DIR   : {ROOT_DIR}")
assert os.path.exists(EMOBOX_ROOT), "EMOBOX_ROOT does not exist. Fix the path."
assert os.path.exists(ROOT_DIR),    "ROOT_DIR does not exist. Fix the path."

# Dataset name and cross-validation fold as used in the EmoBox file layout.
DATASET = "esd"
FOLD    = 1  # EmoBox provides fold_1..fold_5. Use 1 unless TA says otherwise.

# Per-dataset metadata directory inside the EmoBox checkout.
META_DIR = os.path.join(EMOBOX_ROOT, "data", DATASET)
print(f"META_DIR: {META_DIR}")
assert os.path.exists(META_DIR), "META_DIR does not exist. Fix the path."

# Test split of the selected fold.
TEST_JSONL = os.path.join(
    META_DIR,
    "fold_{}".format(FOLD),
    "{}_test_fold_{}.jsonl".format(DATASET, FOLD),
)
print(f"TEST_JSONL: {TEST_JSONL}")
assert os.path.exists(TEST_JSONL), "TEST_JSONL does not exist. Fix the path."

# English speakers in ESD: 0011–0020 (zero-padded, four digits)
EN_SPEAKERS = [format(i, "04d") for i in range(11, 21)]
EN_SET = set(EN_SPEAKERS)  # set for O(1) membership tests
print(f"English speakers: {EN_SPEAKERS}")


Using EMOBOX_ROOT: /content/drive/MyDrive/adsp/EmoBox
Using ROOT_DIR   : /content/drive/MyDrive/adsp/downloads/esd/Emotion Speech Dataset
META_DIR: /content/drive/MyDrive/adsp/EmoBox/data/esd
TEST_JSONL: /content/drive/MyDrive/adsp/EmoBox/data/esd/fold_1/esd_test_fold_1.jsonl
English speakers: ['0011', '0012', '0013', '0014', '0015', '0016', '0017', '0018', '0019', '0020']


In [None]:
import json, os

# Diagnostic cell: print, side by side, the wav location recorded in the
# EmoBox metadata and the location of the same utterance inside the original
# ESD download, so the two directory layouts can be compared by eye.

# Derived from the config cell instead of repeating the absolute path, so the
# check follows EMOBOX_ROOT / DATASET when those are changed.
meta_path = os.path.join(META_DIR, f"{DATASET}.json")
with open(meta_path, "r", encoding="utf-8") as f:
    meta = json.load(f)

key = "esd-0001-000001"
rel = meta[key]["wav"]  # path as stored in the metadata (joined onto EMOBOX_ROOT below)
abs_from_json = os.path.join(EMOBOX_ROOT, rel)
abs_from_root = os.path.join(ROOT_DIR, "0001", "Neutral", "0001_000001.wav")

print(abs_from_json)
print(abs_from_root)


/content/drive/MyDrive/adsp/EmoBox/downloads/esd/0001/Neutral/0001_000001.wav
/content/drive/MyDrive/adsp/downloads/esd/Emotion Speech Dataset/0001/Neutral/0001_000001.wav


In [None]:
# ============================================
# 2. LABEL MAP & KEY PARSING
# ============================================

# EmoBox keeps a label_map.json next to the dataset metadata; it maps the
# labels found in the metadata to canonical emotion names.
label_map_path = os.path.join(META_DIR, "label_map.json")
with open(label_map_path, "r", encoding="utf-8") as f:
    ESD_LABEL_MAP = json.loads(f.read())

print(f"Using label_map: {ESD_LABEL_MAP}")

# Canonical label set used for the Voxtral prompt and for all metrics,
# plus a lowercase copy for case-insensitive matching.
LABELS = ["Angry", "Happy", "Neutral", "Sad", "Surprise"]
LABELS_LOWER = list(map(str.lower, LABELS))

def parse_esd_key(key: str):
    """
    Split an EmoBox ESD key into its speaker and utterance parts.

    A key such as 'esd-0005-000001' yields ('0005', '000001').
    An AssertionError is raised when the key does not consist of exactly
    three dash-separated fields starting with 'esd'.
    """
    fields = key.split("-")
    well_formed = len(fields) == 3 and fields[0] == "esd"
    assert well_formed, f"Unexpected ESD key format: {key}"
    return fields[1], fields[2]

def is_english_speaker(spk: str) -> bool:
    """Return True when `spk` is a member of the module-level EN_SET."""
    return spk in EN_SET


Using label_map: {'Neutral': 'Neutral', 'Angry': 'Angry', 'Happy': 'Happy', 'Sad': 'Sad', 'Surprise': 'Surprise'}


In [None]:
# ============================================
# 3. PRELOAD TRANSCRIPTS FROM ORIGINAL ESD
# ============================================
# Optional but useful if you want audio+text

transcripts = {}  # (spk, utt_id) -> (sentence, emo_from_txt)

for spk in EN_SPEAKERS:
    txt_path = os.path.join(ROOT_DIR, spk, f"{spk}.txt")
    if not os.path.exists(txt_path):
        print(f"[WARN] transcript missing: {txt_path}")
        continue

    n_malformed = 0
    # utf-8-sig reads plain UTF-8 unchanged but drops a leading byte-order
    # mark, which would otherwise end up inside the first utterance id.
    with open(txt_path, encoding="utf-8-sig") as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue  # blank line
            # A usable line has an id, at least one word and a trailing
            # emotion tag. Shorter lines used to be stored with the id (or
            # the emotion) in the wrong slot; skip and report them instead.
            if len(parts) < 3:
                n_malformed += 1
                continue
            utt_id, *words, emo_txt = parts   # e.g. 0011_000001, [...], Neutral
            transcripts[(spk, utt_id)] = (" ".join(words), emo_txt)

    if n_malformed:
        print(f"[WARN] skipped {n_malformed} malformed line(s) in {txt_path}")

print(f"Loaded transcripts for {len(transcripts)} (speaker, utt_id) pairs")


Loaded transcripts for 17500 (speaker, utt_id) pairs


In [None]:
# ============================================
# 4. BUILD DATAFRAME FROM EMOBOX TEST SPLIT
#    (ENGLISH ONLY, NO RESAMPLING)
# ============================================

rows = []

with open(TEST_JSONL, "r", encoding="utf-8") as f:
    for line in f:
        # json.loads raises on empty input, so skip blank lines
        # (e.g. a trailing newline at the end of the file).
        if not line.strip():
            continue
        ex = json.loads(line)

        key = ex["key"]                # 'esd-0005-000001'
        # EmoBox Jsonl uses 'emo'; we also accept 'label' just in case
        raw_label = ex.get("label", ex.get("emo"))
        if raw_label is None:
            print(f"[WARN] no 'label' or 'emo' in record, skipping: {ex}")
            continue

        spk, utt_suffix = parse_esd_key(key)

        # filter out Chinese speakers (keep only English)
        if not is_english_speaker(spk):
            continue

        # normalize label via label_map (unmapped labels pass through unchanged)
        emo = ESD_LABEL_MAP.get(raw_label, raw_label)

        # construct utterance id as in original ESD tree
        utt_id = f"{spk}_{utt_suffix}"

        # reconstruct WAV path in original ESD folder
        wav_path = os.path.join(ROOT_DIR, spk, emo, f"{utt_id}.wav")
        if not os.path.exists(wav_path):
            # Fallback: search the whole speaker folder. The fixed part of
            # the pattern is escaped so glob metacharacters in ROOT_DIR are
            # taken literally, and hits are sorted so the choice is stable
            # across runs.
            speaker_dir = glob.escape(os.path.join(ROOT_DIR, spk))
            cands = sorted(
                glob.glob(
                    os.path.join(speaker_dir, "**", f"{utt_id}.wav"),
                    recursive=True,
                )
            )
            if not cands:
                print(f"[MISS] audio for key={key}, expected={wav_path}")
                continue
            wav_path = cands[0]

        # transcript if we have it (empty string otherwise)
        sent, _ = transcripts.get((spk, utt_id), ("", None))

        rows.append({
            "key":        key,
            "speaker":    spk,
            "utt_id":     utt_id,
            "audio_path": wav_path,
            "transcript": sent,
            "emotion":    emo,  # ground truth (canonical from label_map)
        })

df = pd.DataFrame(rows)

# Check for emptiness BEFORE touching any column: a frame built from an empty
# list has no columns, so df["emotion"] would raise KeyError and hide this
# more helpful message.
if df.empty:
    raise RuntimeError("DataFrame is empty – check JSONL path and filters.")

print(f"\nTotal English utterances from EmoBox test (fold {FOLD}):", len(df))
print(df["emotion"].value_counts())
display(df.head())

# sanity: dataset labels must be subset of our LABELS
uniq = sorted(df["emotion"].unique())
print("\nUnique labels in dataset:", uniq)
assert set(uniq) <= set(LABELS), "Dataset has labels outside LABELS!"



Total English utterances from EmoBox test (fold 1): 5250
emotion
Neutral     1050
Angry       1050
Happy       1050
Sad         1050
Surprise    1050
Name: count, dtype: int64


Unnamed: 0,key,speaker,utt_id,audio_path,transcript,emotion
0,esd-0011-000001,11,0011_000001,/content/drive/MyDrive/adsp/downloads/esd/Emot...,"The nine the eggs, I keep.",Neutral
1,esd-0011-000002,11,0011_000002,/content/drive/MyDrive/adsp/downloads/esd/Emot...,"I did go, and made many prisoners.",Neutral
2,esd-0011-000003,11,0011_000003,/content/drive/MyDrive/adsp/downloads/esd/Emot...,That I owe my thanks to you.,Neutral
3,esd-0011-000004,11,0011_000004,/content/drive/MyDrive/adsp/downloads/esd/Emot...,They went up to the dark mass job had pointed ...,Neutral
4,esd-0011-000005,11,0011_000005,/content/drive/MyDrive/adsp/downloads/esd/Emot...,Clear than clear water!,Neutral



Unique labels in dataset: ['Angry', 'Happy', 'Neutral', 'Sad', 'Surprise']


In [None]:
# !pip install -q huggingface_hub

# from huggingface_hub import snapshot_download
# import os

# LOCAL_MODEL_DIR = "/content/drive/MyDrive/adsp/models/voxtral-mini-3b"  # your choice

# os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)

# snapshot_download(
#     repo_id="mistralai/Voxtral-Mini-3B-2507",
#     local_dir=LOCAL_MODEL_DIR,
#     local_dir_use_symlinks=False,   # simpler on Drive
#     resume_download=True
# )


In [None]:
import torch
from transformers import AutoProcessor, VoxtralForConditionalGeneration

# Directory on Google Drive that holds the downloaded Voxtral checkpoint
LOCAL_MODEL_DIR = "/content/drive/MyDrive/adsp/models/voxtral-mini-3b"

# Use the GPU when one is visible, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# trust_remote_code is intentionally NOT set: the model class is imported
# directly from transformers, so no code from the checkpoint directory needs
# to be executed, and leaving it off prevents that from ever happening.
processor = AutoProcessor.from_pretrained(LOCAL_MODEL_DIR)

# NOTE: recent transformers releases print "`torch_dtype` is deprecated! Use
# `dtype` instead!" for the call below. The old keyword is kept because it is
# the one accepted across the versions allowed by the install cell.
model = VoxtralForConditionalGeneration.from_pretrained(
    LOCAL_MODEL_DIR,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,  # half precision on GPU only
    device_map="auto" if device.type == "cuda" else {"": "cpu"},
)

# Inference only. Assigning the return value keeps the notebook from echoing
# the full module tree as the cell output.
_ = model.eval()
print("Loaded", type(model).__name__, "from", LOCAL_MODEL_DIR)


Using device: cuda


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

VoxtralForConditionalGeneration(
  (audio_tower): VoxtralEncoder(
    (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
    (embed_positions): Embedding(1500, 1280)
    (layers): ModuleList(
      (0-31): 32 x VoxtralEncoderLayer(
        (self_attn): VoxtralAttention(
          (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
          (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=1280, out_features=5120, bias=True)
        (fc2): Linear(in_features=5120, out_features=1280, bias=True)
        (final_layer_norm): LayerNorm((1280,), ep

In [None]:
# # ============================================
# # 5. LOAD VOXTRAL MODEL (ZERO-SHOT INFERENCE ONLY)
# # ============================================
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print("Using device:", device)

# REPO_ID = "mistralai/Voxtral-Mini-3B-2507"

# processor = AutoProcessor.from_pretrained(
#     REPO_ID,
#     trust_remote_code=True,
# )

# model = VoxtralForConditionalGeneration.from_pretrained(
#     REPO_ID,
#     torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
#     device_map="auto" if device.type == "cuda" else {"": "cpu"},
#     trust_remote_code=True,
# )
# model.eval()
# print("Model loaded.")


In [None]:
# ============================================
# 6. LABEL NORMALIZATION (HARDENED, WITH SYNONYMS)
# ============================================

def normalize_label(raw_text: str | None) -> str:
    """
    Map raw Voxtral output to one of LABELS or 'Unknown'.

    Strategy:
    1) Exact match (whole output == label).
    2) Last token match (with synonyms).
    3) Whole-word regex search anywhere.
    4) Fallback: 'Unknown' (for ESD we keep it explicit).
    """
    if raw_text is None:
        return "Unknown"

    t = raw_text.strip().lower()
    if not t:
        return "Unknown"

    # 1) exact match
    for lab in LABELS_LOWER:
        if t == lab:
            return lab.capitalize()

    # tokenize
    tokens = [tok for tok in re.split(r"\W+", t) if tok]

    # 2) last token with simple synonyms
    if tokens:
        last = tokens[-1]
        # direct match
        for lab in LABELS_LOWER:
            if last == lab:
                return lab.capitalize()
        # synonyms
        if last == "angry":
            return "Angry"
        if last == "happy":
            return "Happy"
        if last == "sad":
            return "Sad"
        if last in ("surprised", "astonished"):
            return "Surprise"
        if last in ("neutral", "calm"):
            return "Neutral"

    # 3) whole-word search in text
    for lab in LABELS_LOWER:
        if re.search(rf"\b{lab}\b", t):
            return lab.capitalize()

    if re.search(r"\bangry\b", t):
        return "Angry"
    if re.search(r"\bhappy\b", t):
        return "Happy"
    if re.search(r"\bsad\b", t):
        return "Sad"
    if re.search(r"\bsurprised\b", t):
        return "Surprise"
    if re.search(r"\bneutral\b", t):
        return "Neutral"

    return "Unknown"


In [None]:
# ============================================
# 7. PROMPT BUILDER + SINGLE-SAMPLE PREDICTOR
# ============================================

def build_base_user_content(use_text: bool, labels=None) -> str:
    """
    Build the instruction text: task description plus the valid labels.

    Parameters
    ----------
    use_text : bool
        When True the instruction mentions "audio and its transcript",
        otherwise only "audio".
    labels : sequence of str, optional
        Labels to offer to the model. Defaults to the module-level LABELS.

    Returns
    -------
    str
        The prompt. Both places that enumerate the labels are rendered from
        the same list, so the "Possible emotions" line and the "EXACTLY ONE
        word from this list" line cannot drift apart when the label set
        changes.
    """
    label_list = ", ".join(LABELS if labels is None else labels)
    modality = "audio and its transcript" if use_text else "audio"
    return (
        "You are an emotion classifier.\n"
        f"Possible emotions: {label_list}.\n"
        f"From the given {modality}, classify the SPEAKER'S emotion.\n"
        f"Answer with EXACTLY ONE word from this list: {label_list}.\n"
        "Do not add any extra words, punctuation, or explanations."
    )


def build_conversation_single(
    audio_path: str,
    transcript: str = "",
    use_text: bool = False,
):
    """
    Assemble the chat messages for ONE utterance.

    The result is a two-element list: a fixed system message followed by a
    user message whose content holds the audio reference (a file:// URL of
    the absolute path), the task instruction and, only when `use_text` is set
    and the transcript is not blank, the transcript itself.
    """
    chunks = [
        {"type": "audio_url", "audio_url": "file://" + os.path.abspath(audio_path)},
        {"type": "text", "text": build_base_user_content(use_text)},
    ]

    wants_transcript = use_text and transcript.strip()
    if wants_transcript:
        chunks.append({"type": "text", "text": f"\nTranscript:\n{transcript}"})

    system_turn = {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "You are a careful and concise emotion classification assistant.",
            }
        ],
    }
    user_turn = {"role": "user", "content": chunks}
    return [system_turn, user_turn]


def voxtral_predict(
    audio_path: str,
    transcript: str = "",
    use_text: bool = False,
    max_new_tokens: int = 3,
    do_sample: bool = True,
    temperature: float = 0.2,
    top_p: float = 0.95,
) -> str:
    """
    Zero-shot emotion prediction for a single sample.

    Parameters
    ----------
    audio_path : str
        Path of the audio file to classify.
    transcript : str
        Transcript of the utterance; only used when `use_text` is True.
    use_text : bool
        False -> audio only, True -> audio + transcript.
    max_new_tokens : int
        Upper bound on the number of generated tokens.
    do_sample, temperature, top_p
        Decoding settings forwarded to `model.generate`. The defaults keep
        the original sampling setup. Sampling is unseeded here, so repeated
        calls may differ; pass do_sample=False for repeatable predictions.
        temperature / top_p are only forwarded when sampling is enabled.

    Returns
    -------
    str
        The generated text after `normalize_label`.
    """
    conversation = build_conversation_single(
        audio_path=audio_path,
        transcript=transcript,
        use_text=use_text,
    )

    inputs = processor.apply_chat_template(
        conversation,
        tokenize=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        gen_kwargs.update(temperature=temperature, top_p=top_p)

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; keep the continuation only
    prompt_len = inputs["input_ids"].shape[1]
    decoded = processor.batch_decode(
        outputs[:, prompt_len:], skip_special_tokens=True
    )[0]
    return normalize_label(decoded)


# Smoke test: run both prompt variants on the first row of df and show the
# result next to the gold label.
row = df.iloc[0]
print("\n=== Sanity check on one sample (EmoBox test, English only) ===")
print(f"Audio: {row.audio_path}")
print(f"Transcript: {row.transcript}")
print(f"Gold: {row.emotion}")
print(f"Pred (audio only): {voxtral_predict(row.audio_path, use_text=False)}")
print(f"Pred (audio+text): {voxtral_predict(row.audio_path, row.transcript, use_text=True)}")



=== Sanity check on one sample (EmoBox test, English only) ===
Audio: /content/drive/MyDrive/adsp/downloads/esd/Emotion Speech Dataset/0011/Neutral/0011_000001.wav
Transcript: The nine the eggs, I keep.
Gold: Neutral
Pred (audio only): Neutral
Pred (audio+text): Neutral


In [None]:
# ============================================
# 8. BATCHED PREDICTION HELPERS
# ============================================

def build_conversation_for_row(audio_path: str, transcript: str, use_text: bool):
    """
    Build the chat messages for one dataframe row (used for batching).

    Thin wrapper around `build_conversation_single`, so the single-sample and
    the batched code paths are guaranteed to send the same prompt. The body
    used to be a verbatim copy of that function.
    """
    return build_conversation_single(
        audio_path=audio_path,
        transcript=transcript,
        use_text=use_text,
    )


def voxtral_predict_batch(
    df_batch: pd.DataFrame,
    use_text: bool,
    max_new_tokens: int = 3,
    do_sample: bool = True,
    temperature: float = 0.2,
    top_p: float = 0.95,
):
    """
    Batched zero-shot prediction for a slice of the dataframe.

    Parameters
    ----------
    df_batch : pd.DataFrame
        Rows to classify; the columns "audio_path" and "transcript" are read.
    use_text : bool
        False -> audio only, True -> audio + transcript.
    max_new_tokens : int
        Upper bound on the number of generated tokens per sample.
    do_sample, temperature, top_p
        Decoding settings forwarded to `model.generate`. The defaults keep
        the original sampling setup. Sampling is unseeded here, so repeated
        runs may differ; pass do_sample=False for repeatable predictions.
        temperature / top_p are only forwarded when sampling is enabled.

    Returns
    -------
    list of str
        Normalized labels in the same order as the rows of `df_batch`;
        an empty list for an empty batch.
    """
    # Nothing to do for an empty slice; the processor is not called with an
    # empty conversation list.
    if len(df_batch) == 0:
        return []

    conversations = [
        build_conversation_for_row(
            audio_path=audio_path,
            transcript=transcript,
            use_text=use_text,
        )
        for audio_path, transcript in zip(df_batch["audio_path"], df_batch["transcript"])
    ]

    inputs = processor.apply_chat_template(
        conversations,
        tokenize=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        gen_kwargs.update(temperature=temperature, top_p=top_p)

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; keep the continuation only
    prompt_len = inputs["input_ids"].shape[1]
    decoded_list = processor.batch_decode(
        outputs[:, prompt_len:], skip_special_tokens=True
    )
    return [normalize_label(text) for text in decoded_list]


def run_batched_predictions(df_in: pd.DataFrame, batch_size: int = 8) -> pd.DataFrame:
    """
    Full prediction over df_in using batched Voxtral calls.

    Parameters
    ----------
    df_in : pd.DataFrame
        Utterances to classify. It is not modified; the result is a copy with
        a fresh 0..n-1 index.
    batch_size : int
        Number of rows per model call; must be at least 1.

    Returns
    -------
    pd.DataFrame
        Copy of `df_in` with two added columns:
            pred_audio : audio-only prediction
            pred_both  : audio+transcript prediction

    Raises
    ------
    ValueError
        If `batch_size` is smaller than 1. Previously 0 surfaced as a
        ZeroDivisionError and negative values as a column-length mismatch.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    df_eval = df_in.reset_index(drop=True).copy()

    all_pa = []
    all_pb = []

    n = len(df_eval)
    n_batches = (n + batch_size - 1) // batch_size  # ceil(n / batch_size)

    for start in tqdm(range(0, n, batch_size), total=n_batches, desc="Voxtral batches"):
        batch = df_eval.iloc[start:start + batch_size]  # iloc clips at the end of the frame

        # Same rows, two prompt variants; the audio-only call runs first.
        all_pa.extend(voxtral_predict_batch(batch, use_text=False))
        all_pb.extend(voxtral_predict_batch(batch, use_text=True))

    df_eval["pred_audio"] = all_pa
    df_eval["pred_both"]  = all_pb
    return df_eval


In [None]:
# ============================================
# 9. RUN PREDICTIONS ON FULL EMOBOX TEST (EN ONLY)
# ============================================
# Classify every row of df with both prompt variants (slow: two model calls
# per batch).
df_eval = run_batched_predictions(df, batch_size=8)

print("\n=== Head of df_eval ===")
display(df_eval.head())

# Gold vs. predicted label frequencies, one table per column.
print("\n=== Quick class distributions ===")
print("GT   :", df_eval["emotion"].value_counts(), sep="\n")
print("\nAudio preds:", df_eval["pred_audio"].value_counts(), sep="\n")
print("\nAudio+text preds:", df_eval["pred_both"].value_counts(), sep="\n")


Voxtral batches:   0%|          | 0/657 [00:00<?, ?it/s]


=== Head of df_eval ===


Unnamed: 0,key,speaker,utt_id,audio_path,transcript,emotion,pred_audio,pred_both
0,esd-0011-000001,11,0011_000001,/content/drive/MyDrive/adsp/downloads/esd/Emot...,"The nine the eggs, I keep.",Neutral,Neutral,Neutral
1,esd-0011-000002,11,0011_000002,/content/drive/MyDrive/adsp/downloads/esd/Emot...,"I did go, and made many prisoners.",Neutral,Neutral,Neutral
2,esd-0011-000003,11,0011_000003,/content/drive/MyDrive/adsp/downloads/esd/Emot...,That I owe my thanks to you.,Neutral,Neutral,Neutral
3,esd-0011-000004,11,0011_000004,/content/drive/MyDrive/adsp/downloads/esd/Emot...,They went up to the dark mass job had pointed ...,Neutral,Neutral,Neutral
4,esd-0011-000005,11,0011_000005,/content/drive/MyDrive/adsp/downloads/esd/Emot...,Clear than clear water!,Neutral,Neutral,Neutral



=== Quick class distributions ===
GT   :
emotion
Neutral     1050
Angry       1050
Happy       1050
Sad         1050
Surprise    1050
Name: count, dtype: int64

Audio preds:
pred_audio
Neutral     4404
Angry        264
Unknown      218
Sad          213
Happy        145
Surprise       6
Name: count, dtype: int64

Audio+text preds:
pred_both
Neutral     4416
Sad          319
Angry        314
Happy        182
Surprise      19
Name: count, dtype: int64


In [None]:
# ============================================
# 10. METRICS & BASIC VISUALIZATIONS
# ============================================
# Gold labels and the two prediction columns as plain arrays
y_true = df_eval["emotion"].values
y_pa   = df_eval["pred_audio"].values
y_pb   = df_eval["pred_both"].values

print(f"Label set in ground truth: {sorted(set(y_true))}")
print(f"Label set in audio preds : {sorted(set(y_pa))}")
print(f"Label set in audio+text  : {sorted(set(y_pb))}")

# Each metric is evaluated twice: first for audio-only, then for audio+text.
acc_audio, acc_both = (
    float(accuracy_score(y_true, pred)) for pred in (y_pa, y_pb)
)
f1m_audio, f1m_both = (
    float(f1_score(y_true, pred, labels=LABELS, average="macro", zero_division=0))
    for pred in (y_pa, y_pb)
)
ba_audio, ba_both = (
    float(balanced_accuracy_score(y_true, pred)) for pred in (y_pa, y_pb)
)
kappa_audio, kappa_both = (
    float(cohen_kappa_score(y_true, pred)) for pred in (y_pa, y_pb)
)
mcc_audio, mcc_both = (
    float(matthews_corrcoef(y_true, pred)) for pred in (y_pa, y_pb)
)

print("\n=== Global metrics ===")
print(
    f"Accuracy (audio only)      : {acc_audio:.3f}\n"
    f"Accuracy (audio+text)      : {acc_both:.3f}\n"
    f"Macro-F1 (audio only)      : {f1m_audio:.3f}\n"
    f"Macro-F1 (audio+text)      : {f1m_both:.3f}\n"
    f"Balanced Acc (audio only)  : {ba_audio:.3f}\n"
    f"Balanced Acc (audio+text)  : {ba_both:.3f}\n"
    f"Cohen kappa (audio only)   : {kappa_audio:.3f}\n"
    f"Cohen kappa (audio+text)   : {kappa_both:.3f}\n"
    f"MCC (audio only)           : {mcc_audio:.3f}\n"
    f"MCC (audio+text)           : {mcc_both:.3f}"
)

# Per-class precision / recall / F1, restricted to the labels in LABELS
rep_audio, rep_both = (
    classification_report(
        y_true, pred, labels=LABELS, output_dict=True, zero_division=0
    )
    for pred in (y_pa, y_pb)
)

# One row per emotion; column names get a suffix naming the system
perclass_audio = (
    pd.DataFrame(rep_audio)
      .T.loc[LABELS, ["precision", "recall", "f1-score"]]
      .add_suffix("_audio")
)
perclass_both = (
    pd.DataFrame(rep_both)
      .T.loc[LABELS, ["precision", "recall", "f1-score"]]
      .add_suffix("_both")
)

perclass = perclass_audio.join(perclass_both)
display(perclass.round(3))


Label set in ground truth: ['Angry', 'Happy', 'Neutral', 'Sad', 'Surprise']
Label set in audio preds : ['Angry', 'Happy', 'Neutral', 'Sad', 'Surprise', 'Unknown']
Label set in audio+text  : ['Angry', 'Happy', 'Neutral', 'Sad', 'Surprise']

=== Global metrics ===
Accuracy (audio only)      : 0.195
Accuracy (audio+text)      : 0.209
Macro-F1 (audio only)      : 0.113
Macro-F1 (audio+text)      : 0.128
Balanced Acc (audio only)  : 0.195
Balanced Acc (audio+text)  : 0.209
Cohen kappa (audio only)   : 0.004
Cohen kappa (audio+text)   : 0.011
MCC (audio only)           : 0.007
MCC (audio+text)           : 0.019




Unnamed: 0,precision_audio,recall_audio,f1-score_audio,precision_both,recall_both,f1-score_both
Angry,0.246,0.062,0.099,0.232,0.07,0.107
Happy,0.2,0.028,0.049,0.209,0.036,0.062
Neutral,0.197,0.828,0.319,0.202,0.848,0.326
Sad,0.258,0.052,0.087,0.257,0.078,0.12
Surprise,0.833,0.005,0.009,0.737,0.013,0.026


In [None]:
# ============================================
# 11. PER-EMOTION COMPARISON & ΔF1
# ============================================

# Reshape to one row per (emotion, system) for the grouped bar chart and
# replace the column names by readable system names.
f1_long = (
    perclass.reset_index()
    .melt(id_vars="index",
          value_vars=["f1-score_audio", "f1-score_both"],
          var_name="system", value_name="F1")
    .assign(system=lambda d: d["system"].map(
        {"f1-score_audio": "Audio only", "f1-score_both": "Audio + text"}
    ))
)

fig_f1 = (
    px.bar(
        f1_long,
        x="index",
        y="F1",
        color="system",
        barmode="group",
        title="F1 per emotion: audio vs audio+text (EmoBox ESD test, English only)",
        labels={"index": "Emotion"},
    )
    .update_layout(template="plotly_white")
    .update_yaxes(range=[0, 1])  # F1 lives in [0, 1]
)
fig_f1.show()

# Gain from adding the transcript (audio+text − audio), per emotion.
# Note: this adds two columns to the existing perclass frame.
delta_f1 = perclass["f1-score_both"] - perclass["f1-score_audio"]
delta_recall = perclass["recall_both"] - perclass["recall_audio"]
perclass["ΔF1"] = delta_f1
perclass["ΔRecall"] = delta_recall
display(perclass[["f1-score_audio", "f1-score_both", "ΔF1", "ΔRecall"]].round(3))

fig_df1 = (
    px.scatter(
        perclass.reset_index(),
        x="index",
        y="ΔF1",
        text="index",
        title="ΔF1 (audio+text − audio) per emotion",
        labels={"index": "Emotion", "ΔF1": "Delta F1"},
    )
    .add_hline(y=0, line_dash="dash")  # reference line: no change
    .update_traces(textposition="top center")
    .update_layout(template="plotly_white")
)
fig_df1.show()


Unnamed: 0,f1-score_audio,f1-score_both,ΔF1,ΔRecall
Angry,0.099,0.107,0.008,0.008
Happy,0.049,0.062,0.013,0.009
Neutral,0.319,0.326,0.007,0.02
Sad,0.087,0.12,0.033,0.026
Surprise,0.009,0.026,0.017,0.009


In [None]:
# ============================================
# 12. CONFUSION MATRICES
# ============================================
# Predictions can also be "Unknown" (model output that normalize_label could
# not map). With labels=LABELS alone those predictions were dropped, so rows
# summed to fewer utterances than the class contains and the "row-normalized"
# diagonal no longer matched per-class recall. Keeping an explicit "Unknown"
# column makes every row add up to the true class support.
PRED_LABELS = LABELS + ["Unknown"]


def _confusion_with_unknown(y_pred):
    """
    Confusion counts and row proportions for y_true vs. `y_pred`.

    Both returned arrays have one row per entry of LABELS (true class) and
    one column per entry of PRED_LABELS (predicted class).
    """
    # The matrix over PRED_LABELS is square; the last row (true == "Unknown")
    # is dropped because "Unknown" is not a ground-truth class.
    counts = confusion_matrix(y_true, y_pred, labels=PRED_LABELS)[: len(LABELS)]
    support = counts.sum(axis=1, keepdims=True)
    proportions = np.divide(
        counts,
        support,
        out=np.zeros(counts.shape, dtype=float),
        where=support > 0,  # a class without samples stays 0 instead of NaN
    )
    return counts, proportions


def _confusion_heatmap(proportions, title, color_scale):
    """Heatmap of row proportions (true class on y, predicted class on x)."""
    fig = px.imshow(
        proportions,
        x=PRED_LABELS,
        y=LABELS,
        color_continuous_scale=color_scale,
        labels=dict(x="Predicted", y="True", color="Proportion"),
        text_auto=".2f",
    )
    fig.update_layout(title=title, template="plotly_white")
    return fig


cm_audio_full, cm_audio_full_pct = _confusion_with_unknown(y_pa)
cm_both_full,  cm_both_full_pct  = _confusion_with_unknown(y_pb)

# Square LABELS x LABELS parts, kept under the original names and shapes for
# the cells that consume them. The counts are identical to
# confusion_matrix(..., labels=LABELS); the proportions are now relative to
# the full class support.
cm_audio = cm_audio_full[:, : len(LABELS)].copy()
cm_both  = cm_both_full[:, : len(LABELS)].copy()
cm_audio_pct = cm_audio_full_pct[:, : len(LABELS)].copy()
cm_both_pct  = cm_both_full_pct[:, : len(LABELS)].copy()

print("\nConfusion matrix – audio only (counts)")
display(pd.DataFrame(cm_audio_full, index=LABELS, columns=PRED_LABELS))

print("\nConfusion matrix – audio + text (counts)")
display(pd.DataFrame(cm_both_full, index=LABELS, columns=PRED_LABELS))

fig_cm_audio = _confusion_heatmap(
    cm_audio_full_pct,
    "Confusion matrix (audio only) – row-normalized",
    "Blues",
)
fig_cm_audio.show()

fig_cm_both = _confusion_heatmap(
    cm_both_full_pct,
    "Confusion matrix (audio + text) – row-normalized",
    "Greens",
)
fig_cm_both.show()



Confusion matrix – audio only (counts)


Unnamed: 0,Angry,Happy,Neutral,Sad,Surprise
Angry,65,30,874,43,1
Happy,47,29,878,39,0
Neutral,51,30,869,44,0
Sad,55,31,862,55,0
Surprise,46,25,921,32,5



Confusion matrix – audio + text (counts)


Unnamed: 0,Angry,Happy,Neutral,Sad,Surprise
Angry,73,34,873,67,3
Happy,54,38,899,59,0
Neutral,59,37,890,63,1
Sad,68,38,861,82,1
Surprise,60,35,893,48,14


In [None]:
# ============================================
# 13. PER-SPEAKER ACCURACY
# ============================================

def speaker_accuracy(df_local, pred_col):
    """
    Fraction of correct predictions per speaker.

    Compares the "emotion" column with `pred_col`, groups the hits by the
    "speaker" column and returns their mean as a Series named
    "acc_<pred_col>", indexed by speaker.
    """
    is_hit = df_local["emotion"] == df_local[pred_col]
    per_speaker = is_hit.groupby(df_local["speaker"]).mean()
    per_speaker.name = f"acc_{pred_col}"
    return per_speaker

# Per-speaker accuracy for both systems (audio-only first, then audio+text)
acc_sp_audio, acc_sp_both = (
    speaker_accuracy(df_eval, col) for col in ("pred_audio", "pred_both")
)

# Side-by-side table plus the gain obtained from adding the transcript
sp_acc = pd.concat([acc_sp_audio, acc_sp_both], axis=1)
sp_acc["Δacc"] = sp_acc["acc_pred_both"].sub(sp_acc["acc_pred_audio"])
display(sp_acc.sort_values("Δacc").round(3))

fig_sp = (
    px.bar(
        sp_acc.reset_index(),
        x="speaker",
        y="Δacc",
        title="ΔAccuracy per speaker (audio+text − audio)",
        labels={"speaker": "Speaker ID", "Δacc": "Delta accuracy"},
    )
    .add_hline(y=0, line_dash="dash")  # reference line: no change
    .update_layout(template="plotly_white")
)
fig_sp.show()


Unnamed: 0_level_0,acc_pred_audio,acc_pred_both,Δacc
speaker,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
16,0.194,0.208,0.014
18,0.2,0.214,0.014
11,0.19,0.205,0.014


In [None]:
# ============================================
# 14. DOMINANT CONFUSIONS PER EMOTION
# ============================================

def dominant_confusions(cm, labels):
    """Find, for each true class, the wrong class it is most often predicted as.

    Parameters
    ----------
    cm : array-like, indexable as cm[i][j]
        Confusion matrix whose row ``i`` holds the counts for true class
        ``labels[i]`` and whose column ``j`` corresponds to predicted class
        ``labels[j]``. NumPy arrays, nested lists and DataFrames are accepted;
        the input is never modified.
    labels : sequence
        Class names, positionally aligned with the rows/columns of `cm`.

    Returns
    -------
    pandas.DataFrame
        One row per entry of `labels` with columns "true", "confused_with"
        and "count". When a class has no off-diagonal entries,
        "confused_with" is None and "count" is 0.
    """
    # Coerce once so row slicing, .sum() and .argmax() behave identically for
    # every array-like input (a nested list or DataFrame would otherwise fail).
    cm_arr = np.asarray(cm)

    rows = []
    for i, true_lab in enumerate(labels):
        row = cm_arr[i].copy()
        row[i] = 0  # ignore correct predictions
        if row.sum() == 0:
            rows.append({"true": true_lab, "confused_with": None, "count": 0})
        else:
            j = int(row.argmax())  # first maximum wins on ties
            rows.append({
                "true": true_lab,
                "confused_with": labels[j],
                "count": int(row[j]),
            })
    return pd.DataFrame(rows)

# Dominant confusion per true emotion for each system; the result columns are
# suffixed per system so both fit side by side after the merge on "true".
dom_audio = dominant_confusions(cm_audio, LABELS).rename(
    columns={"confused_with": "confused_with_audio", "count": "count_audio"}
)
dom_both = dominant_confusions(cm_both, LABELS).rename(
    columns={"confused_with": "confused_with_both", "count": "count_both"}
)

dom = pd.merge(dom_audio, dom_both, on="true")
display(dom)

# Long format (one row per emotion/system pair) for the grouped bar chart.
fig_err = px.bar(
    pd.melt(
        dom,
        id_vars="true",
        value_vars=["count_audio", "count_both"],
        var_name="system",
        value_name="count",
    ),
    x="true",
    y="count",
    color="system",
    barmode="group",
    labels={"true": "Emotion"},
    title="Most frequent misclassification count per emotion",
)
fig_err.update_layout(template="plotly_white")
fig_err.show()


Unnamed: 0,true,confused_with_audio,count_audio,confused_with_both,count_both
0,Angry,Neutral,874,Neutral,873
1,Happy,Neutral,878,Neutral,899
2,Neutral,Angry,51,Sad,63
3,Sad,Neutral,862,Neutral,861
4,Surprise,Neutral,921,Neutral,893


In [None]:
# ============================================
# 15. UNKNOWN PREDICTIONS & CLASS BALANCE + NEUTRAL BIAS
# ============================================
# Number of predictions equal to the "Unknown" label, per system.
unknown_audio = df_eval["pred_audio"].eq("Unknown").sum()
unknown_both  = df_eval["pred_both"].eq("Unknown").sum()

print(f"Unknown predictions – audio only : {unknown_audio}")
print(f"Unknown predictions – audio+text : {unknown_both}")

# Reference class counts, ordered as LABELS; labels missing from the data get 0.
emo_counts = (
    df_eval["emotion"]
    .value_counts()
    .reindex(LABELS)
    .fillna(0)
)

fig_dist = px.bar(
    emo_counts,
    x=emo_counts.index,
    y=emo_counts.values,
    title="Class distribution in EmoBox ESD test (English only)",
    labels={"x": "Emotion", "y": "Count"},
    template="plotly_white",
)
fig_dist.show()

# neutral-bias analysis (like in IEMOCAP)
ser_pa = pd.Series(df_eval["pred_audio"])
ser_pb = pd.Series(df_eval["pred_both"])

# Share of all predictions that are "Neutral", per system.
neu_a = ser_pa.eq("Neutral").mean()
neu_b = ser_pb.eq("Neutral").mean()

print(f"\nNeutral prediction rate – audio only : {neu_a:.3f}")
print(f"Neutral prediction rate – audio+text : {neu_b:.3f}")


Unknown predictions – audio only : 218
Unknown predictions – audio+text : 0



Neutral prediction rate – audio only : 0.839
Neutral prediction rate – audio+text : 0.841


In [None]:
# ============================================
# 16. ERROR TABLE FOR MANUAL ANALYSIS
# ============================================
# Rows the audio-only system got wrong, narrowed to the columns used for manual
# inspection; the audio+text prediction is kept alongside for comparison.
errors = (
    df_eval[df_eval["emotion"].ne(df_eval["pred_audio"])]
    .loc[:, ["speaker", "utt_id", "emotion", "pred_audio", "pred_both", "transcript", "audio_path"]]
    .reset_index(drop=True)
)

print(f"Total errors (audio only): {len(errors)}")
display(errors.head(20))


Total errors (audio only): 4227


Unnamed: 0,speaker,utt_id,emotion,pred_audio,pred_both,transcript,audio_path
0,11,0011_000007,Neutral,Angry,Angry,I'm as bad as I can be.,/content/drive/MyDrive/adsp/downloads/esd/Emot...
1,11,0011_000014,Neutral,Sad,Sad,Annie please please don't hurt me!,/content/drive/MyDrive/adsp/downloads/esd/Emot...
2,11,0011_000015,Neutral,Sad,Sad,Poor Tom now is dead!,/content/drive/MyDrive/adsp/downloads/esd/Emot...
3,11,0011_000018,Neutral,Sad,Sad,Then sadly it is much farther.,/content/drive/MyDrive/adsp/downloads/esd/Emot...
4,11,0011_000031,Neutral,Happy,Happy,All smile were real and the happier the more s...,/content/drive/MyDrive/adsp/downloads/esd/Emot...
5,11,0011_000034,Neutral,Angry,Angry,A divine wrath made her blue eyes awful.,/content/drive/MyDrive/adsp/downloads/esd/Emot...
6,11,0011_000036,Neutral,Angry,Angry,How I hate this foul pool!,/content/drive/MyDrive/adsp/downloads/esd/Emot...
7,11,0011_000037,Neutral,Happy,Happy,I think it'll encourage me.,/content/drive/MyDrive/adsp/downloads/esd/Emot...
8,11,0011_000047,Neutral,Unknown,Neutral,please excuse me.,/content/drive/MyDrive/adsp/downloads/esd/Emot...
9,11,0011_000048,Neutral,Unknown,Neutral,"You are not a runaway,who are you?",/content/drive/MyDrive/adsp/downloads/esd/Emot...


In [None]:
# ============================================
# 17. AGREEMENT / DISAGREEMENT REGIMES
# ============================================
# Augmented copy of the evaluation frame: one boolean column per system marking
# whether its prediction equals the reference emotion.
df_ag = df_eval.assign(
    correct_audio=lambda d: d["emotion"] == d["pred_audio"],
    correct_both=lambda d: d["emotion"] == d["pred_both"],
)

def categorize(row):
    """Name the agreement regime of one sample.

    Parameters
    ----------
    row : mapping-like
        Must expose truthy/falsy values under "correct_audio" and
        "correct_both".

    Returns
    -------
    str
        "both_correct", "only_audio_correct", "only_A+T_correct" or
        "both_wrong", depending on which of the two flags are truthy.
    """
    audio_ok = bool(row["correct_audio"])
    both_ok = bool(row["correct_both"])

    # Lookup keyed by (audio-only correct, audio+text correct).
    regime_by_outcome = {
        (True, True): "both_correct",
        (True, False): "only_audio_correct",
        (False, True): "only_A+T_correct",
        (False, False): "both_wrong",
    }
    return regime_by_outcome[(audio_ok, both_ok)]

# Tag every sample with its regime (row-wise, via categorize).
df_ag["regime"] = df_ag.apply(categorize, axis=1)

# Absolute and relative frequency of each regime.
regime_counts = (
    df_ag["regime"]
    .value_counts()
    .rename_axis("regime")
    .reset_index(name="count")
)
regime_counts["pct"] = regime_counts["count"].div(len(df_ag))
display(regime_counts)

fig_reg = px.bar(
    regime_counts,
    x="regime",
    y="pct",
    labels={"pct": "Proportion of samples"},
    title="Agreement / Disagreement regimes",
)
fig_reg.update_layout(template="plotly_white")
fig_reg.show()


Unnamed: 0,regime,count,pct
0,both_wrong,4110,0.782857
1,both_correct,980,0.186667
2,only_A+T_correct,117,0.022286
3,only_audio_correct,43,0.00819


In [None]:
# ============================================
# 18. CONDITIONING BY TRANSCRIPT LENGTH (BUGFIXED)
# ============================================
if "transcript" in df_eval.columns:
    df_tx = df_eval.copy()
    # Transcript length in whitespace-separated tokens; a missing transcript counts as 0.
    df_tx["text_len"] = df_tx["transcript"].fillna("").str.split().str.len()

    # Bucket boundaries: 33rd and 66th percentiles of the token count.
    q1, q2 = df_tx["text_len"].quantile([0.33, 0.66])

    def txt_bucket(n):
        """Map a token count to a length bucket using the q1 / q2 boundaries."""
        if n <= q1:
            return "short_text"
        elif n <= q2:
            return "medium_text"
        else:
            return "long_text"

    df_tx["txt_bucket"] = df_tx["text_len"].apply(txt_bucket)

    # Same metrics for both systems inside every bucket, plus their difference.
    rows = []
    for b, sub in df_tx.groupby("txt_bucket"):
        y_b_true = sub["emotion"].values
        y_pa_b   = sub["pred_audio"].values
        y_pb_b   = sub["pred_both"].values

        acc_a = accuracy_score(y_b_true, y_pa_b)
        acc_b = accuracy_score(y_b_true, y_pb_b)
        f1_a  = f1_score(y_b_true, y_pa_b, labels=LABELS, average="macro", zero_division=0)
        f1_b  = f1_score(y_b_true, y_pb_b, labels=LABELS, average="macro", zero_division=0)

        rows.append({
            "bucket": b,
            "n_samples": len(sub),
            "acc_audio": acc_a,
            "acc_both":  acc_b,
            "f1_audio":  f1_a,
            "f1_both":   f1_b,
            "Δacc": acc_b - acc_a,
            "ΔF1":  f1_b - f1_a,
        })

    # Order the buckets by length (short → medium → long). A plain sort on the
    # bucket name is alphabetical (long, medium, short), which scrambles the
    # ordinal axis in both the table and the bar chart below.
    bucket_rank = {"short_text": 0, "medium_text": 1, "long_text": 2}
    txt_metrics = (
        pd.DataFrame(rows)
        .sort_values("bucket", key=lambda col: col.map(bucket_rank))
        .reset_index(drop=True)
    )
    display(txt_metrics.round(3))

    # Bars follow the row order of txt_metrics, i.e. short → medium → long.
    fig_txt = px.bar(
        txt_metrics,
        x="bucket",
        y="ΔF1",
        title="ΔMacro-F1 per transcript-length bucket (audio+text − audio)",
        labels={"ΔF1": "Delta Macro-F1"},
    )
    fig_txt.add_hline(y=0, line_dash="dash")
    fig_txt.update_layout(template="plotly_white")
    fig_txt.show()
else:
    print("No 'transcript' column found; skipping transcript-length-based analysis.")


Unnamed: 0,bucket,n_samples,acc_audio,acc_both,f1_audio,f1_both,Δacc,ΔF1
0,long_text,1202,0.199,0.206,0.111,0.122,0.007,0.012
1,medium_text,2113,0.195,0.212,0.114,0.131,0.017,0.017
2,short_text,1935,0.192,0.207,0.109,0.125,0.016,0.016


In [None]:
# ============================================
# 19. SUMMARY ROW FOR CROSS-DATASET COMPARISON
# ============================================
# Identifier written into the "dataset" column of every summary row built below.
DATASET_NAME = "ESD_EmoBox_test_fold1_EN"

def global_scores_esd(y_true, y_pred):
    """Compute the headline metrics for one set of predictions.

    Parameters
    ----------
    y_true : array-like of str
        Reference labels.
    y_pred : array-like of str
        Predicted labels; may contain values that never occur in `y_true`
        (for example "Unknown").

    Returns
    -------
    dict
        "Accuracy"     – fraction of exact matches.
        "Macro F1"     – unweighted mean F1 over LABELS.
        "Balanced Acc" – unweighted mean recall over the classes present in
                         `y_true`.
    """
    # Balanced accuracy is the mean per-class recall over the classes that occur
    # in y_true. Deriving it from per-class recall restricted to those classes
    # yields the same number as balanced_accuracy_score, but without its
    # "y_pred contains classes not in y_true" warning when the predictions hold
    # extra labels.
    true_classes = np.unique(y_true)
    _, recall_per_class, _, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=true_classes, average=None, zero_division=0
    )
    return {
        "Accuracy":     accuracy_score(y_true, y_pred),
        "Macro F1":     f1_score(y_true, y_pred, labels=LABELS, average="macro", zero_division=0),
        "Balanced Acc": float(np.mean(recall_per_class)),
    }

def summary_row(dataset_name: str, modality: str, y_true, y_pred) -> dict:
    """Build one flat record for the cross-dataset comparison table.

    Parameters
    ----------
    dataset_name : str
        Value stored under "dataset".
    modality : str
        Value stored under "modality".
    y_true, y_pred : array-like
        Reference and predicted labels, forwarded to global_scores_esd.

    Returns
    -------
    dict
        Keys "dataset", "modality", "accuracy", "macro_f1", "balanced_acc"
        and "neutral_frac" (share of predictions equal to "Neutral").
    """
    share_neutral = (pd.Series(y_pred) == "Neutral").mean()
    scores = global_scores_esd(y_true, y_pred)
    return dict(
        dataset=dataset_name,
        modality=modality,
        accuracy=scores["Accuracy"],
        macro_f1=scores["Macro F1"],
        balanced_acc=scores["Balanced Acc"],
        neutral_frac=share_neutral,
    )

# One summary record per system, in the order audio-only then audio+text.
summary_rows = [
    summary_row(DATASET_NAME, modality, y_true, preds)
    for modality, preds in (("audio_only", y_pa), ("audio_plus_text", y_pb))
]

df_summary_esd = pd.DataFrame(summary_rows)
print("\n=== SUMMARY (for cross-dataset comparison) – ESD ===")
display(df_summary_esd.round(4))



=== SUMMARY (for cross-dataset comparison) – ESD ===



y_pred contains classes not in y_true



Unnamed: 0,dataset,modality,accuracy,macro_f1,balanced_acc,neutral_frac
0,ESD_EmoBox_test_fold1_EN,audio_only,0.1949,0.1125,0.1949,0.8389
1,ESD_EmoBox_test_fold1_EN,audio_plus_text,0.209,0.1281,0.209,0.8411
