In [None]:
# ============================================
# 0. INSTALLS & IMPORTS
# ============================================
!pip install -q "transformers>=4.46.0" "mistral_common[audio]>=1.8.6" "accelerate>=0.34.0" pandas einops

from google.colab import drive
drive.mount('/content/drive')

import os
import glob
import re

import pandas as pd
import numpy as np
from tqdm.auto import tqdm

import torch
from transformers import VoxtralForConditionalGeneration, AutoProcessor

import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
    confusion_matrix,
    precision_recall_fscore_support,
)

# ============================================
# 1. BUILD DATAFRAME
# ============================================
# ESD layout used here: <ROOT>/<speaker>/<speaker>.txt lists one
# "utt_id  sentence  emotion" entry per line, and the audio is expected at
# <ROOT>/<speaker>/<emotion>/<utt_id>.wav (with a recursive-search fallback).
ROOT_DIR = "/content/drive/MyDrive/adsp/downloads/ESD/Emotion Speech Dataset"
EN_SPEAKERS = [f"{i:04d}" for i in range(11, 21)]  # 0011–0020 are English


def _parse_transcript_line(line: str):
    """Split one stripped transcript line into (utt_id, sentence, emotion).

    The first whitespace-separated token is the utterance id, the last is the
    emotion, and everything in between is the sentence (re-joined with single
    spaces). Returns None when the line has fewer than two tokens: a lone token
    would otherwise be used as both the id and the emotion label.
    """
    parts = line.split()
    if len(parts) < 2:
        return None
    return parts[0], " ".join(parts[1:-1]), parts[-1]


def _find_wav(spk: str, emo: str, utt_id: str):
    """Return the wav path for ``utt_id``, or None if it cannot be found.

    Tries <ROOT>/<spk>/<emo>/<utt_id>.wav first, then searches recursively
    below <ROOT>/<spk> and takes the first hit.
    """
    wav_path = os.path.join(ROOT_DIR, spk, emo, f"{utt_id}.wav")
    if os.path.exists(wav_path):
        return wav_path
    # glob.escape: the directory and file name are literal paths, not patterns.
    pattern = os.path.join(
        glob.escape(os.path.join(ROOT_DIR, spk)), "**", glob.escape(f"{utt_id}.wav")
    )
    cands = glob.glob(pattern, recursive=True)
    return cands[0] if cands else None


rows = []

for spk in EN_SPEAKERS:
    txt_path = os.path.join(ROOT_DIR, spk, f"{spk}.txt")
    if not os.path.exists(txt_path):
        print(f"[WARN] transcript missing: {txt_path}")
        continue

    with open(txt_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            parsed = _parse_transcript_line(line)
            if parsed is None:
                print(f"[WARN] malformed line in {txt_path}: {line!r}")
                continue
            utt_id, sent, emo = parsed   # e.g. 0011_000001, "...", Neutral

            wav_path = _find_wav(spk, emo, utt_id)
            if wav_path is None:
                print(f"[MISS] audio for {utt_id}")
                continue

            rows.append({
                "speaker": spk,
                "utt_id": utt_id,
                "audio_path": wav_path,
                "transcript": sent,
                "emotion": emo,
            })

df = pd.DataFrame(rows)
print("Total English utterances:", len(df))
display(df.head())


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Total English utterances: 17500


Unnamed: 0,speaker,utt_id,audio_path,transcript,emotion
0,11,0011_000001,/content/drive/MyDrive/adsp/downloads/ESD/Emot...,"The nine the eggs, I keep.",Neutral
1,11,0011_000002,/content/drive/MyDrive/adsp/downloads/ESD/Emot...,"I did go, and made many prisoners.",Neutral
2,11,0011_000003,/content/drive/MyDrive/adsp/downloads/ESD/Emot...,That I owe my thanks to you.,Neutral
3,11,0011_000004,/content/drive/MyDrive/adsp/downloads/ESD/Emot...,They went up to the dark mass job had pointed ...,Neutral
4,11,0011_000005,/content/drive/MyDrive/adsp/downloads/ESD/Emot...,Clear than clear water!,Neutral


In [None]:
# ============================================
# 2. LOAD MODEL (SAME, BUT KEEP IT CLEAN)
# ============================================
# Use the GPU when one is visible (bfloat16 weights); otherwise CPU in float32.
if torch.cuda.is_available():
    device, model_dtype = "cuda", torch.bfloat16
else:
    device, model_dtype = "cpu", torch.float32
print("Using device:", device)

REPO_ID = "mistralai/Voxtral-Mini-3B-2507"

processor = AutoProcessor.from_pretrained(REPO_ID, trust_remote_code=True)

# NOTE: the loader output reports that `torch_dtype` is deprecated in favour
# of `dtype`; it still works here, and switching depends on the installed
# transformers version.
model_kwargs = dict(
    torch_dtype=model_dtype,
    device_map=device,
    trust_remote_code=True,
)
model = VoxtralForConditionalGeneration.from_pretrained(REPO_ID, **model_kwargs)
model.eval()
print("Model loaded.")

# Closed label set read by normalize_label and the metric cells below.
LABELS = ["Angry", "Happy", "Neutral", "Sad", "Surprise"]



Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model loaded.


In [None]:
# ============================================
# 3. LABEL NORMALIZATION (UNCHANGED)
# ============================================
def normalize_label(text: str, labels=None) -> str:
    """Map free-form model output to one emotion label.

    A label matches when its lower-cased name starts a word in the lower-cased
    text, so "Surprised" -> "Surprise" and "sadness" -> "Sad", but "unhappy"
    does NOT match "Happy". If several labels appear, the one mentioned first
    in the text wins.

    Args:
        text: raw decoded model output; None or "" is treated as no answer.
        labels: candidate labels; defaults to the module-level ``LABELS``.

    Returns:
        The matching label, or "Unknown" if none matches.
    """
    if labels is None:
        labels = LABELS  # module-level label set, looked up at call time
    if not text:
        return "Unknown"

    lowered = text.lower()
    best_label, best_pos = "Unknown", len(lowered)
    for lab in labels:
        # Leading \b anchors the label at a word start. There is no trailing \b,
        # so inflected forms ("surprised", "sadness") still match.
        match = re.search(r"\b" + re.escape(lab.lower()), lowered)
        if match is not None and match.start() < best_pos:
            best_label, best_pos = lab, match.start()
    return best_label


In [None]:
# ============================================
# 4. SINGLE-SAMPLE PREDICTOR (FOR DEBUGGING)
#    (unchanged semantics)
# ============================================
def voxtral_predict(audio_path: str,
                    transcript: str = "",
                    use_text: bool = False,
                    max_new_tokens: int = 3) -> str:
    """
    Zero-shot emotion prediction for a single utterance (debugging helper).

    Args:
        audio_path: audio file, passed to the processor as a file:// URL of its
            absolute path.
        transcript: sentence text; only sent when ``use_text`` is True and it
            is not blank.
        use_text: False -> audio only; True -> audio + "Transcript: ..." part.
        max_new_tokens: generation budget for the answer.

    Returns:
        The greedy-decoded answer mapped through ``normalize_label``.
    """
    instruction = (
        "Classify the speaker's emotion. "
        "Answer with ONLY one word from [Angry, Happy, Neutral, Sad, Surprise]."
    )

    message_parts = [
        {"type": "audio_url", "audio_url": "file://" + os.path.abspath(audio_path)},
    ]
    if use_text and transcript.strip():
        message_parts.append({"type": "text", "text": f"Transcript: {transcript}"})
    message_parts.append({"type": "text", "text": instruction})

    encoded = processor.apply_chat_template(
        [{"role": "user", "content": message_parts}],
        tokenize=True,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    prompt_len = encoded["input_ids"].shape[1]

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # generate() returns prompt + continuation; decode only the continuation.
    answer = processor.batch_decode(generated[:, prompt_len:], skip_special_tokens=True)[0]
    return normalize_label(answer)


# Quick sanity check on one row (OPTIONAL): compare both prompt variants
# against the gold label of the first utterance.
row = df.iloc[0]
for caption, value in (("Audio:", row.audio_path),
                       ("Transcript:", row.transcript),
                       ("Gold:", row.emotion)):
    print(caption, value)

pred_audio_only = voxtral_predict(row.audio_path, use_text=False)
print("Pred (audio only):", pred_audio_only)
pred_audio_text = voxtral_predict(row.audio_path, row.transcript, use_text=True)
print("Pred (audio + text):", pred_audio_text)


Audio: /content/drive/MyDrive/adsp/downloads/ESD/Emotion Speech Dataset/0011/Neutral/0011_000001.wav
Transcript: The nine the eggs, I keep.
Gold: Neutral
Pred (audio only): Happy
Pred (audio + text): Happy


In [None]:
# ============================================
# 5. BATCHED PREDICTION (FASTER PART)
# ============================================
def build_conversation_for_row(audio_path: str,
                               transcript: str,
                               use_text: bool) -> dict:
    """Build the single user chat turn for one dataframe row.

    Mirrors ``voxtral_predict``. The content is, in order:
      1. an audio_url part (file:// URL of the absolute audio path);
      2. a "Transcript: ..." text part, only when ``use_text`` is True and
         ``transcript`` is not blank;
      3. the fixed classification instruction.
    """
    audio_part = {"type": "audio_url", "audio_url": f"file://{os.path.abspath(audio_path)}"}
    transcript_parts = (
        [{"type": "text", "text": f"Transcript: {transcript}"}]
        if use_text and transcript.strip()
        else []
    )
    instruction_part = {
        "type": "text",
        "text": ("Classify the speaker's emotion. "
                 "Answer with ONLY one word from [Angry, Happy, Neutral, Sad, Surprise]."),
    }
    return {"role": "user", "content": [audio_part, *transcript_parts, instruction_part]}


def voxtral_predict_batch(df_batch: pd.DataFrame,
                          use_text: bool,
                          max_new_tokens: int = 3) -> list[str]:
    """
    Greedy zero-shot emotion labels for every row of ``df_batch`` in one generate() call.

    Reads the ``audio_path`` and ``transcript`` columns and builds one chat per
    row with ``build_conversation_for_row``. Returns ``normalize_label`` of each
    decoded answer, in row order.
    """
    conversations = [
        [build_conversation_for_row(audio_path=audio, transcript=text, use_text=use_text)]
        for audio, text in zip(df_batch["audio_path"], df_batch["transcript"])
    ]

    batch_inputs = processor.apply_chat_template(
        conversations,
        tokenize=True,
        return_tensors="pt",
    )
    batch_inputs = {key: val.to(device) for key, val in batch_inputs.items()}
    # input_ids is one 2-D tensor, so every row shares this prompt width
    # (presumably padded by the processor — verify padding side if batching misbehaves).
    prompt_width = batch_inputs["input_ids"].shape[1]

    with torch.no_grad():
        generated = model.generate(
            **batch_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    answers = processor.batch_decode(generated[:, prompt_width:], skip_special_tokens=True)
    return list(map(normalize_label, answers))


def _load_prediction_cache(cache_path, df_eval):
    """Read cached predictions that belong to the leading rows of ``df_eval``.

    Returns ``(pred_audio, pred_both)`` lists. The cache is trusted only if its
    ``utt_id`` column equals the first rows of ``df_eval`` in order. Otherwise a
    warning is printed and two empty lists are returned, i.e. start over.
    """
    cached = pd.read_csv(cache_path, dtype=str, keep_default_na=False)
    n_cached = len(cached)
    expected_ids = df_eval["utt_id"].astype(str).iloc[:n_cached].tolist()
    if n_cached > len(df_eval) or cached["utt_id"].tolist() != expected_ids:
        print(f"[WARN] cache {cache_path} does not match these rows; starting over.")
        return [], []
    return cached["pred_audio"].tolist(), cached["pred_both"].tolist()


def _save_prediction_cache(cache_path, df_eval, pred_audio, pred_both):
    """Write predictions made so far (one per leading row of ``df_eval``) to CSV."""
    # min(): an interrupt between the two extend() calls can leave one list longer.
    done = min(len(pred_audio), len(pred_both))
    snapshot = pd.DataFrame({
        "utt_id": df_eval["utt_id"].iloc[:done].to_numpy(),
        "pred_audio": pred_audio[:done],
        "pred_both": pred_both[:done],
    })
    tmp_path = f"{cache_path}.tmp"
    snapshot.to_csv(tmp_path, index=False)
    # Replace in one step so an interrupt mid-write keeps the previous checkpoint.
    os.replace(tmp_path, cache_path)


def run_batched_predictions(df: pd.DataFrame,
                            batch_size: int = 4,
                            cache_path=None,
                            save_every: int = 25) -> pd.DataFrame:
    """
    Full fast prediction over df using batched Voxtral calls.

    Works on a re-indexed copy (``df`` itself is not modified) and adds:
        pred_audio : audio-only prediction
        pred_both  : audio+transcript prediction

    Args:
        df: rows to evaluate; needs the columns read by ``voxtral_predict_batch``
            (audio_path, transcript), plus ``utt_id`` when ``cache_path`` is set.
        batch_size: rows per generate() call; must be >= 1.
        cache_path: optional CSV checkpoint path. When given, progress is saved
            every ``save_every`` batches, on interruption/error, and at the end.
            A later call with the same ``df`` resumes after the cached rows.
            ``None`` (default) disables caching.
        save_every: checkpoint frequency in batches; must be >= 1.

    Returns:
        The re-indexed copy of ``df`` with the two prediction columns.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if save_every < 1:
        raise ValueError(f"save_every must be >= 1, got {save_every}")

    df_eval = df.reset_index(drop=True).copy()
    n = len(df_eval)
    if cache_path is not None and "utt_id" not in df_eval.columns:
        raise ValueError("cache_path requires an 'utt_id' column to key the cache")

    all_pa, all_pb = [], []
    if cache_path is not None and os.path.exists(cache_path):
        all_pa, all_pb = _load_prediction_cache(cache_path, df_eval)
        if all_pa:
            print(f"Resuming from {cache_path}: {len(all_pa)}/{n} rows already predicted.")

    batch_starts = range(len(all_pa), n, batch_size)
    try:
        for i, start in enumerate(tqdm(batch_starts, total=len(batch_starts)), start=1):
            batch = df_eval.iloc[start:start + batch_size]  # iloc clips the last batch at n

            all_pa.extend(voxtral_predict_batch(batch, use_text=False))  # audio only
            all_pb.extend(voxtral_predict_batch(batch, use_text=True))   # audio + transcript

            if cache_path is not None and i % save_every == 0:
                _save_prediction_cache(cache_path, df_eval, all_pa, all_pb)
    finally:
        # Also runs on KeyboardInterrupt / errors, so finished batches are kept.
        if cache_path is not None:
            _save_prediction_cache(cache_path, df_eval, all_pa, all_pb)

    df_eval["pred_audio"] = all_pa
    df_eval["pred_both"]  = all_pb

    return df_eval

In [None]:
# ============================================
# 6. RUN FAST PREDICTIONS
# ============================================
# Tune batch_size based on GPU memory; 2–4 is usually safe, go higher if you can.
# Each batch makes two generate() calls (audio only, then audio + transcript).
BATCH_SIZE = 8
df_eval = run_batched_predictions(df, batch_size=BATCH_SIZE)
df_eval.head()


  0%|          | 0/1750 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# ============================================
# 7. METRICS & COMPARISON
# ============================================
# "Unknown" predictions are kept. They are not in LABELS, so they count as
# errors for accuracy and as missed recall for their true class.
y_true = df_eval["emotion"].values
y_pa   = df_eval["pred_audio"].values
y_pb   = df_eval["pred_both"].values

print("Label set in ground truth:", sorted(set(y_true)))
print("Label set in audio preds :", sorted(set(y_pa)))
print("Label set in audio+text  :", sorted(set(y_pb)))

acc_audio = accuracy_score(y_true, y_pa)
acc_both  = accuracy_score(y_true, y_pb)

# zero_division=0: a class the model never predicts scores 0 without a warning,
# consistent with the per-emotion metrics table.
f1m_audio = f1_score(y_true, y_pa, labels=LABELS, average="macro", zero_division=0)
f1m_both  = f1_score(y_true, y_pb, labels=LABELS, average="macro", zero_division=0)

print("\n=== Global metrics ===")
# The captions are equal width. " .3f" reserves a sign column for alignment;
# a two-space spec ("  .3f") is invalid and raises ValueError.
print(f"Accuracy (audio only) : {acc_audio: .3f}")
print(f"Accuracy (audio+text) : {acc_both: .3f}")
print(f"Macro-F1 (audio only) : {f1m_audio: .3f}")
print(f"Macro-F1 (audio+text) : {f1m_both: .3f}")

print("\n=== Audio only – classification report ===")
print(classification_report(y_true, y_pa, labels=LABELS, zero_division=0))

print("\n=== Audio + transcript – classification report ===")
print(classification_report(y_true, y_pb, labels=LABELS, zero_division=0))

cm_audio = confusion_matrix(y_true, y_pa, labels=LABELS)
cm_both  = confusion_matrix(y_true, y_pb, labels=LABELS)

print("\nConfusion matrix – audio only (rows = true, cols = pred):")
print(pd.DataFrame(cm_audio, index=LABELS, columns=LABELS))

print("\nConfusion matrix – audio + text (rows = true, cols = pred):")
print(pd.DataFrame(cm_both, index=LABELS, columns=LABELS))

In [None]:
# ============================================
# 8. PLOTS (ON df_eval METRICS)
# ============================================
systems = ["Audio only", "Audio + text"]
accs = [acc_audio, acc_both]
f1s  = [f1m_audio, f1m_both]

x = np.arange(len(systems))
width = 0.35

# Global accuracy vs macro-F1, side by side per system.
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(x - width/2, accs, width, label="Accuracy")
ax.bar(x + width/2, f1s,  width, label="Macro-F1")
ax.set_xticks(x)
ax.set_xticklabels(systems, rotation=0)
ax.set_ylim(0, 1.0)
ax.set_ylabel("Score")
ax.set_title("Voxtral zero-shot on ESD (English)")
ax.legend()
ax.grid(axis="y", alpha=0.3)
fig.tight_layout()
plt.show()

# Per-class F1. zero_division=0 keeps a never-predicted class at 0 without
# warnings, matching the per-emotion metrics table.
_, _, f1_audio_per_class, _ = precision_recall_fscore_support(
    y_true, y_pa, labels=LABELS, average=None, zero_division=0
)
_, _, f1_both_per_class, _ = precision_recall_fscore_support(
    y_true, y_pb, labels=LABELS, average=None, zero_division=0
)

x = np.arange(len(LABELS))
width = 0.35

fig, ax = plt.subplots(figsize=(8, 4))
ax.bar(x - width/2, f1_audio_per_class, width, label="Audio only")
ax.bar(x + width/2, f1_both_per_class,  width, label="Audio + text")
ax.set_xticks(x)
ax.set_xticklabels(LABELS, rotation=45)
ax.set_ylim(0, 1.0)
ax.set_ylabel("F1-score")
ax.set_title("Per-class F1: audio vs audio+transcript")
ax.legend()
ax.grid(axis="y", alpha=0.3)
fig.tight_layout()
plt.show()

In [None]:
# ============================================
# EXTRA EVALUATION + PLOTLY VISUALIZATIONS
# ============================================

!pip install -q plotly

import plotly.express as px
import plotly.graph_objects as go
from sklearn.metrics import precision_recall_fscore_support

# 1) PER-EMOTION METRICS TABLE
# ----------------------------
# Per-class precision / recall / F1 for each system. "support" is the number
# of ground-truth utterances per emotion, which is identical for both systems.
prec_a, rec_a, f1_a, sup = precision_recall_fscore_support(
    y_true, y_pa, labels=LABELS, zero_division=0
)
prec_b, rec_b, f1_b, _ = precision_recall_fscore_support(
    y_true, y_pb, labels=LABELS, zero_division=0
)

table_columns = {"emotion": LABELS}
for suffix, (precision, recall, fscore) in (("audio", (prec_a, rec_a, f1_a)),
                                            ("both",  (prec_b, rec_b, f1_b))):
    table_columns[f"precision_{suffix}"] = precision
    table_columns[f"recall_{suffix}"] = recall
    table_columns[f"f1_{suffix}"] = fscore
table_columns["support"] = sup

metrics_df = pd.DataFrame(table_columns)
display(metrics_df)


In [None]:
# 2) PER-EMOTION F1 COMPARISON (GROUPED BAR)
# ------------------------------------------
# One bar trace per system, grouped side by side for each emotion.

fig_f1_bar = go.Figure()
for f1_column, trace_name in (("f1_audio", "Audio only"),
                              ("f1_both", "Audio + text")):
    fig_f1_bar.add_bar(
        x=metrics_df["emotion"],
        y=metrics_df[f1_column],
        name=trace_name,
    )

fig_f1_bar.update_layout(
    barmode="group",
    title="Per-emotion F1: audio vs audio+text",
    xaxis_title="Emotion",
    yaxis_title="F1-score",
    yaxis=dict(range=[0, 1]),
    template="plotly_white",
)
fig_f1_bar.show()


In [None]:
# 3) PER-EMOTION F1 SCATTER (WHICH EMOTIONS BENEFIT FROM TEXT?)
# --------------------------------------------------------------
# x = F1 with audio only, y = F1 with audio + transcript. Points above the
# dashed y = x line improved when the transcript was added.

fig_f1_scatter = go.Figure()

fig_f1_scatter.add_trace(go.Scatter(
    x=metrics_df["f1_audio"],
    y=metrics_df["f1_both"],
    mode="markers+text",
    text=metrics_df["emotion"],
    textposition="top center",
    name="Emotions",
))

# The y = x reference line spans the whole fixed [0, 1] axis range. Spanning
# only the data range would shrink it to a single point when all F1 values coincide.
fig_f1_scatter.add_trace(go.Scatter(
    x=[0, 1],
    y=[0, 1],
    mode="lines",
    line=dict(dash="dash"),
    name="y = x (no change)",
))

fig_f1_scatter.update_layout(
    title="F1(audio) vs F1(audio+text) per emotion",
    xaxis_title="F1 (audio only)",
    yaxis_title="F1 (audio + text)",
    xaxis=dict(range=[0, 1]),
    yaxis=dict(range=[0, 1]),
    template="plotly_white",
)
fig_f1_scatter.show()

In [None]:
# 4) CONFUSION MATRICES (HEATMAPS, PERCENTAGES)
# ---------------------------------------------

# Normalise each row by the number of ground-truth utterances of that class,
# not by the row sum. confusion_matrix(labels=LABELS) drops "Unknown"
# predictions, so row sums can be smaller than the class size. With this
# denominator, a row sums to less than 1 by exactly the share of that class
# predicted as "Unknown".
true_support = np.array([np.sum(y_true == lab) for lab in LABELS], dtype=float)


def _row_normalize(cm, support):
    """Divide row i of ``cm`` by ``support[i]``; rows with zero support become all 0."""
    denom = np.asarray(support, dtype=float).reshape(-1, 1)
    return np.divide(
        cm.astype(float), denom,
        out=np.zeros(cm.shape, dtype=float),
        where=denom > 0,
    )


def _plot_confusion(cm_pct, title, colorscale):
    """Build a heatmap of a row-normalised confusion matrix over LABELS."""
    fig = px.imshow(
        cm_pct,
        x=LABELS,
        y=LABELS,
        color_continuous_scale=colorscale,
        labels=dict(x="Predicted", y="True", color="Proportion"),
        text_auto=".2f",
    )
    fig.update_layout(title=title, template="plotly_white")
    return fig


cm_audio_pct = _row_normalize(cm_audio, true_support)
cm_both_pct  = _row_normalize(cm_both, true_support)

fig_cm_audio = _plot_confusion(
    cm_audio_pct, "Confusion matrix (audio only) – row-normalized", "Blues"
)
fig_cm_audio.show()

fig_cm_both = _plot_confusion(
    cm_both_pct, "Confusion matrix (audio + text) – row-normalized", "Greens"
)
fig_cm_both.show()


In [None]:
# 5) PER-SPEAKER ACCURACY (HOW CONSISTENT ACROSS SPEAKERS?)
# ----------------------------------------------------------

def speaker_accuracy(df_eval, pred_col):
    """Per-speaker accuracy of the predictions stored in ``pred_col``.

    Returns a DataFrame with columns ``speaker`` and ``acc_<pred_col>``. The
    accuracy is the fraction of each speaker's rows whose ``emotion`` equals
    ``pred_col``. There is one row per speaker, in sorted speaker order.
    """
    is_correct = df_eval["emotion"].eq(df_eval[pred_col])
    per_speaker = is_correct.groupby(df_eval["speaker"]).mean()
    return per_speaker.rename(f"acc_{pred_col}").reset_index()

acc_spk_audio = speaker_accuracy(df_eval, "pred_audio")
acc_spk_both  = speaker_accuracy(df_eval, "pred_both")

# One row per speaker with both accuracy columns side by side.
acc_spk = acc_spk_audio.merge(acc_spk_both, on="speaker")

fig_spk = go.Figure()
for acc_column, trace_name in (("acc_pred_audio", "Audio only"),
                               ("acc_pred_both", "Audio + text")):
    fig_spk.add_bar(
        x=acc_spk["speaker"],
        y=acc_spk[acc_column],
        name=trace_name,
    )

fig_spk.update_layout(
    barmode="group",
    title="Per-speaker accuracy: audio vs audio+text",
    xaxis_title="Speaker ID",
    yaxis_title="Accuracy",
    yaxis=dict(range=[0, 1]),
    template="plotly_white",
)
fig_spk.show()


In [None]:
# 6) UNKNOWN PREDICTIONS & CLASS BALANCE
# --------------------------------------

# "Unknown" is what normalize_label returns when no label can be matched.
unknown_audio = (df_eval["pred_audio"] == "Unknown").sum()
unknown_both  = (df_eval["pred_both"]  == "Unknown").sum()

print(f"Unknown predictions (audio only): {unknown_audio}")
print(f"Unknown predictions (audio+text): {unknown_both}")

# Class distribution in ground truth. fill_value=0 keeps the counts integer
# for any label absent from this subset; reindex + fillna would make them floats.
emo_counts = df_eval["emotion"].value_counts().reindex(LABELS, fill_value=0)
emo_counts_df = pd.DataFrame({"Emotion": LABELS, "Count": emo_counts.to_numpy()})

# Named columns make the axis titles explicit instead of depending on how
# plotly names bare array arguments.
fig_dist = px.bar(
    emo_counts_df,
    x="Emotion",
    y="Count",
    title="Class distribution in ESD-English subset",
    template="plotly_white",
)
fig_dist.show()