# Multimodal Analysis — Where Does CLIP Help vs Hurt?

This notebook compares **ASR-only** decoding (no image, which ranks hypotheses exactly as α=0 would) against **multimodal rescoring** (α>0)
at the individual sample level, then aggregates the results to answer four key questions:

- In what **percentage of cases** does multimodal rescoring improve WER?
- Is the benefit concentrated in **dysarthric** speech, **healthy** speech, or both?
- Which **specific samples** see the biggest improvement or degradation?
- Does the system meet the **success criterion**: multimodal helps ≥ 50% of test cases?

## Why this matters

CLIP rescoring doesn't always help. If Whisper's top hypothesis is already correct, CLIP
can promote a worse hypothesis above it. Understanding *when* and *why* multimodal
rescoring helps is critical for deciding whether to deploy it and at what α.
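
To make this failure mode concrete, here is a minimal sketch with made-up numbers (not measurements). It applies the same late-fusion rule as `MultimodalASR.rescore` below: a softmax over Whisper's beam scores, a softmax over CLIP cosine similarities scaled by 100, then `(1 - α) · asr + α · clip`. The function name `fused_winner` is hypothetical.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()


# Hypothetical 2-best list: hypothesis 0 is the correct transcript and Whisper ranks it first.
ASR_SCORES = np.array([-0.20, -0.35])  # length-normalized beam log-probabilities (made up)
CLIP_COSINE = np.array([0.24, 0.29])   # image-text cosine similarities (made up)


def fused_winner(alpha: float) -> int:
    """Index of the hypothesis that ranks first after late fusion at this alpha."""
    asr_probs = softmax(ASR_SCORES)            # fairly flat: about 0.54 vs 0.46
    clip_probs = softmax(CLIP_COSINE * 100.0)  # sharply peaked: about 0.007 vs 0.993
    fused = (1 - alpha) * asr_probs + alpha * clip_probs
    return int(fused.argmax())


for alpha in (0.0, 0.05, 0.1, 0.3):
    print(f"alpha={alpha:.2f} -> hypothesis {fused_winner(alpha)}")
```

With these numbers the wrong hypothesis overtakes the correct one at α ≈ 0.07. The softmax over length-normalized log-probabilities is nearly flat, while the ×100-scaled CLIP softmax is sharply peaked, so even a small α can let CLIP decide.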

## 0) Imports and prerequisite code

This notebook is **self-contained** — it includes the full `MultimodalASR` pipeline
inline so it can run independently.

In [None]:
import re
import json
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
import soundfile as sf
from PIL import Image
from jiwer import wer as compute_wer, cer as compute_cer
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    CLIPModel,
    CLIPProcessor,
)

# ── Prerequisite code (defined in clip_embeddings & multimodal_asr notebooks) ──
# Included inline so this notebook is self-contained and runnable on its own.

DEFAULT_CLIP_MODEL = "openai/clip-vit-base-patch32"
DEFAULT_WHISPER_MODEL = "openai/whisper-small"
DEFAULT_CACHE_PATH = Path("cache/clip_image_embeddings.npz")

def _pick_device():
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def load_clip(model_name=DEFAULT_CLIP_MODEL, device=None):
    if device is None:
        device = _pick_device()
    processor = CLIPProcessor.from_pretrained(model_name)
    model = CLIPModel.from_pretrained(model_name).to(device)
    model.eval()
    return model, processor, device

def _get_image_features(model, pixel_values):
    vision_out = model.vision_model(pixel_values=pixel_values)
    features = model.visual_projection(vision_out.pooler_output)
    return features / features.norm(dim=-1, keepdim=True)

def _get_text_features(model, input_ids, attention_mask):
    text_out = model.text_model(input_ids=input_ids, attention_mask=attention_mask)
    features = model.text_projection(text_out.pooler_output)
    return features / features.norm(dim=-1, keepdim=True)

def encode_texts(model, processor, texts, device):
    inputs = processor(text=texts, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        features = _get_text_features(model, inputs["input_ids"].to(device), inputs["attention_mask"].to(device))
    return features.cpu().numpy()

def load_cached_embeddings(cache_path):
    return dict(np.load(str(cache_path)))

FILLER_WORDS = frozenset({
    "um", "uh", "uh-huh", "hmm", "hm", "ah", "er", "oh",
    "like", "you know", "i mean", "okay", "ok", "so", "well",
})

def normalize_transcript(text):
    text = text.strip().lower()
    if not text:
        return text
    text = re.sub(r"\b(\w+)( \1\b)+", r"\1", text)
    words = text.split()
    cleaned, skip_next = [], False
    for i, w in enumerate(words):
        if skip_next:
            skip_next = False
            continue
        bigram = f"{w} {words[i + 1]}" if i + 1 < len(words) else ""
        if bigram in FILLER_WORDS:
            skip_next = True
            continue
        if w not in FILLER_WORDS:
            cleaned.append(w)
    return re.sub(r"\s+", " ", " ".join(cleaned)).strip()

def to_caption_style(text):
    text = normalize_transcript(text)
    if not text:
        return text
    return f"an image showing {text}" if len(text.split()) <= 2 else text

def _softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


class MultimodalASR:
    """Whisper ASR with optional CLIP visual-context rescoring."""

    def __init__(self, whisper_model_id=DEFAULT_WHISPER_MODEL, clip_model_id=DEFAULT_CLIP_MODEL,
                 cache_path=DEFAULT_CACHE_PATH, alpha=0.3, num_beams=5, device=None):
        if device is None:
            device = _pick_device()
        self.device = device
        self.alpha = alpha
        self.num_beams = num_beams
        self.whisper_processor = WhisperProcessor.from_pretrained(whisper_model_id)
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_id).to(device)
        self.whisper_model.eval()
        self.clip_model, self.clip_processor, _ = load_clip(clip_model_id, device)
        self.image_embeddings = {}
        cache_path = Path(cache_path)
        if cache_path.exists():
            self.image_embeddings = load_cached_embeddings(cache_path)
            print(f"Loaded {len(self.image_embeddings)} cached image embeddings")

    def generate_nbest(self, audio_array, sr=16000):
        input_features = self.whisper_processor(audio_array, sampling_rate=sr, return_tensors="pt").input_features.to(self.device)
        with torch.no_grad():
            outputs = self.whisper_model.generate(input_features, num_beams=self.num_beams,
                num_return_sequences=self.num_beams, return_dict_in_generate=True, output_scores=True)
        seq_scores = outputs.sequences_scores.cpu().numpy()
        hypotheses, seen = [], set()
        for i, seq in enumerate(outputs.sequences):
            text = self.whisper_processor.decode(seq, skip_special_tokens=True).strip().lower()
            if text not in seen:
                seen.add(text)
                hypotheses.append({"text": text, "score": float(seq_scores[i])})
        hypotheses.sort(key=lambda h: h["score"], reverse=True)
        return hypotheses

    def clip_similarity(self, image_embedding, texts):
        if not texts:
            return np.array([])
        text_embs = encode_texts(self.clip_model, self.clip_processor, texts, self.device)
        return np.atleast_1d((image_embedding.reshape(1, -1) @ text_embs.T).squeeze())

    def rescore(self, hypotheses, image_embedding, alpha=None, caption_style=True):
        if alpha is None:
            alpha = self.alpha
        if not hypotheses:
            return hypotheses
        texts = [h["text"] for h in hypotheses]
        clip_texts = [to_caption_style(t) if caption_style else normalize_transcript(t) for t in texts]
        clip_scores = self.clip_similarity(image_embedding, clip_texts)
        asr_probs = _softmax(np.array([h["score"] for h in hypotheses]))
        clip_probs = _softmax(clip_scores * 100.0)
        fused = (1 - alpha) * asr_probs + alpha * clip_probs
        rescored = [{"text": h["text"], "asr_score": float(asr_probs[i]),
                      "clip_score": float(clip_scores[i]), "fused_score": float(fused[i])}
                     for i, h in enumerate(hypotheses)]
        rescored.sort(key=lambda h: h["fused_score"], reverse=True)
        return rescored

    def transcribe(self, audio_array, image_id=None, image_path=None, sr=16000, alpha=None):
        hypotheses = self.generate_nbest(audio_array, sr)
        if image_id is None and image_path is None:
            return {"transcription": hypotheses[0]["text"] if hypotheses else "", "hypotheses": hypotheses, "mode": "asr_only"}
        img_emb = self._resolve_image_embedding(image_id, image_path)
        if img_emb is None:
            return {"transcription": hypotheses[0]["text"] if hypotheses else "", "hypotheses": hypotheses, "mode": "asr_only",
                    "warning": f"Image not found (id={image_id}, path={image_path})"}
        rescored = self.rescore(hypotheses, img_emb, alpha)
        return {"transcription": rescored[0]["text"] if rescored else "", "hypotheses": rescored,
                "mode": "multimodal", "alpha": alpha if alpha is not None else self.alpha}

    def _resolve_image_embedding(self, image_id, image_path):
        if image_id and image_id in self.image_embeddings:
            return self.image_embeddings[image_id]
        if image_path:
            image_path = Path(image_path)
            if not image_path.exists():
                return None
            image = Image.open(image_path).convert("RGB")
            inputs = self.clip_processor(images=image, return_tensors="pt")
            with torch.no_grad():
                emb = _get_image_features(self.clip_model, inputs["pixel_values"].to(self.device))
            return emb.cpu().numpy().squeeze()
        return None

## 1) `compare_asr_vs_multimodal` — per-sample comparison

For each test sample, this function:
1. Runs the pipeline **without** an image → ASR-only transcription
2. Runs the pipeline **with** an image → multimodal transcription
3. Computes WER for both against the ground-truth reference
4. Records whether multimodal rescoring **improved** or **degraded** the result, or left it **unchanged**

Each test sample may include an optional `speech_status` field (`"dysarthria"` or
`"healthy"`) for per-group analysis; samples without it are grouped under `"unknown"`.
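
The labeling rule in step 4 can be sketched without any models. In this minimal, self-contained illustration, `word_wer` is a hypothetical stand-in for `jiwer.wer` (word-level Levenshtein distance divided by reference length), and `label` mirrors the strict `<` / `>` comparisons used in the function below.

```python
def word_wer(reference: str, hypothesis: str) -> float:
    """Word error rate via Levenshtein distance over whitespace-split tokens."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j]: edit distance between the reference words processed so far
    # and the first j hypothesis words (row 0 = empty reference prefix)
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion of reference word r
                curr[j - 1] + 1,         # insertion of hypothesis word h
                prev[j - 1] + (r != h),  # substitution, or free match
            )
        prev = curr
    return prev[-1] / len(ref)


def label(reference: str, asr_text: str, mm_text: str) -> str:
    asr_wer = word_wer(reference, asr_text)
    mm_wer = word_wer(reference, mm_text)
    if mm_wer < asr_wer:
        return "improved"
    if mm_wer > asr_wer:
        return "degraded"
    return "unchanged"


ref = "a cat sleeping on a windowsill"
print(label(ref, "a cap sleeping on a windowsill", ref))  # improved
print(label(ref, ref, "a cap sleeping on a windowsill"))  # degraded
print(label(ref, "a cap sleeping on a windowsill", "a cat sleeping in a windowsill"))  # unchanged
```

The last case shows that "unchanged" means *same WER*, not *same text*: the two transcripts differ but each contains one word error.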

In [None]:
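def _normalize_for_wer(text: str) -> str:
    """Lower-case, replace punctuation (except apostrophes) with spaces, collapse whitespace.

    Whisper emits punctuation and capitalization, and beams often differ only in
    punctuation; without this step such beams would be scored as different words.
    """
    text = re.sub(r"[^\w\s']", " ", text.lower())
    return re.sub(r"\s+", " ", text).strip()


def compare_asr_vs_multimodal(
    pipeline: MultimodalASR,
    test_samples: list[dict],
    alpha: float = 0.3,
) -> dict:
    """Run ASR-only and multimodal on every sample, return full analysis."""
    per_sample: list[dict] = []

    for i, sample in enumerate(test_samples):
        ref = _normalize_for_wer(sample["reference"])
        if not ref:
            # jiwer rejects empty references, which would also break corpus-level WER in _aggregate
            print(f"  Skipping sample {i}: empty reference")
            continue
        sr = sample.get("sr", 16000)

        # ASR-only: no image, so Whisper's top beam is returned as-is
        asr_result = pipeline.transcribe(audio_array=sample["audio_array"], sr=sr)
        asr_text = _normalize_for_wer(asr_result["transcription"])

        # Multimodal: same audio, n-best list rescored against the paired image
        mm_result = pipeline.transcribe(
            audio_array=sample["audio_array"],
            image_id=sample.get("image_id"),
            image_path=sample.get("image_path"),
            sr=sr,
            alpha=alpha,
        )
        mm_text = _normalize_for_wer(mm_result["transcription"])

        asr_wer = compute_wer(ref, asr_text)
        mm_wer = compute_wer(ref, mm_text)

        per_sample.append({
            "index": i,
            "reference": ref,
            "asr_hypothesis": asr_text,
            "multimodal_hypothesis": mm_text,
            "asr_wer": asr_wer,
            "multimodal_wer": mm_wer,
            "wer_delta": mm_wer - asr_wer,
            "improved": mm_wer < asr_wer,
            "degraded": mm_wer > asr_wer,
            "unchanged": abs(mm_wer - asr_wer) < 1e-9,
            # "asr_only" here means the image could not be resolved, so no rescoring happened
            "multimodal_mode": mm_result["mode"],
            "speech_status": sample.get("speech_status", "unknown"),
        })

        if (i + 1) % 10 == 0:
            print(f"  Compared {i + 1}/{len(test_samples)} samples \u2026")

    return _aggregate(per_sample)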

## 2) `_aggregate` — compute summary statistics

Takes the per-sample results and computes:
- **Overall WER/CER** for ASR-only and multimodal, pooled over the whole corpus (total errors divided by total reference words)
- **Improvement rate** (% of samples where multimodal had lower WER)
- **Per-group breakdown** by `speech_status` (dysarthria vs healthy)
- **Top movers** — the 5 samples with the biggest improvement and degradation
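
One subtlety: when `jiwer.wer` receives lists, it pools errors over the whole corpus, so the overall and per-group WER are *not* the mean of the per-sample WERs. A small arithmetic sketch with hypothetical error counts:

```python
# Hypothetical (reference_word_count, word_errors) per sample
samples = [
    (1, 1),  # one-word utterance, fully wrong -> per-sample WER 1.00
    (9, 0),  # nine-word utterance, perfect   -> per-sample WER 0.00
]

per_sample_wers = [errors / words for words, errors in samples]
mean_wer = sum(per_sample_wers) / len(per_sample_wers)

# Pooled (corpus-level) WER: total errors over total reference words
corpus_wer = sum(e for _, e in samples) / sum(w for w, _ in samples)

print(f"mean of per-sample WERs: {mean_wer:.2f}")  # 0.50
print(f"pooled corpus WER:       {corpus_wer:.2f}")  # 0.10
```

Short utterances therefore carry less weight in the overall and per-group WER than in the improvement rate, where every sample counts once. The two metrics can disagree for exactly this reason.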

In [None]:
def _aggregate(per_sample: list[dict]) -> dict:
    n = len(per_sample)
    if n == 0:
        return {"error": "No results to analyse"}

    improved = sum(r["improved"] for r in per_sample)
    degraded = sum(r["degraded"] for r in per_sample)
    unchanged = sum(r["unchanged"] for r in per_sample)

    all_refs = [r["reference"] for r in per_sample]
    all_asr = [r["asr_hypothesis"] for r in per_sample]
    all_mm = [r["multimodal_hypothesis"] for r in per_sample]

    overall_asr_wer = compute_wer(all_refs, all_asr)
    overall_mm_wer = compute_wer(all_refs, all_mm)
    overall_asr_cer = compute_cer(all_refs, all_asr)
    overall_mm_cer = compute_cer(all_refs, all_mm)

    # Group by speech status
    groups: dict[str, list[dict]] = defaultdict(list)
    for r in per_sample:
        groups[r["speech_status"]].append(r)

    group_stats: dict[str, dict] = {}
    for status, items in groups.items():
        g_refs = [r["reference"] for r in items]
        g_asr = [r["asr_hypothesis"] for r in items]
        g_mm = [r["multimodal_hypothesis"] for r in items]
        group_stats[status] = {
            "count": len(items),
            "asr_wer": compute_wer(g_refs, g_asr),
            "multimodal_wer": compute_wer(g_refs, g_mm),
            "asr_cer": compute_cer(g_refs, g_asr),
            "multimodal_cer": compute_cer(g_refs, g_mm),
            "improved": sum(r["improved"] for r in items),
            "degraded": sum(r["degraded"] for r in items),
        }

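    # Rank by WER change; keep only samples that actually moved in each direction
    sorted_by_delta = sorted(per_sample, key=lambda r: r["wer_delta"])
    top_improvements = [r for r in sorted_by_delta if r["improved"]][:5]
    top_degradations = [r for r in reversed(sorted_by_delta) if r["degraded"]][:5]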

    return {
        "summary": {
            "total_samples": n,
            "improved": improved,
            "degraded": degraded,
            "unchanged": unchanged,
            "improvement_rate": improved / n,
            "overall_asr_wer": overall_asr_wer,
            "overall_multimodal_wer": overall_mm_wer,
            "overall_asr_cer": overall_asr_cer,
            "overall_multimodal_cer": overall_mm_cer,
            "wer_reduction": overall_asr_wer - overall_mm_wer,
            "relative_improvement": (
                (overall_asr_wer - overall_mm_wer) / overall_asr_wer
                if overall_asr_wer > 0 else 0.0
            ),
            "success_criterion_met": improved / n >= 0.5,
        },
        "group_stats": group_stats,
        "per_sample": per_sample,
        "top_improvements": top_improvements,
        "top_degradations": top_degradations,
    }

## 3) `print_analysis` — human-readable report

Prints a formatted summary including:
- Overall improvement/degradation counts
- WER comparison (ASR-only vs multimodal)
- Per-group breakdown (dysarthria vs healthy)
- Specific examples of top improvements and degradations
- Whether the success criterion (≥50% improved) is met

In [None]:
def print_analysis(analysis: dict):
    """Human-readable report."""
    s = analysis["summary"]

    print("=" * 70)
    print("MULTIMODAL RESCORING ANALYSIS")
    print("=" * 70)

    print(f"\n  Total samples:           {s['total_samples']}")
    print(f"  Improved by multimodal:  {s['improved']} ({s['improvement_rate'] * 100:.1f}%)")
    print(f"  Degraded:                {s['degraded']}")
    print(f"  Unchanged:               {s['unchanged']}")

    print(f"\n  Overall ASR WER:         {s['overall_asr_wer'] * 100:.1f}%")
    print(f"  Overall Multimodal WER:  {s['overall_multimodal_wer'] * 100:.1f}%")
    print(f"  WER reduction:           {s['wer_reduction'] * 100:.1f}% abs "
          f"({s['relative_improvement'] * 100:.1f}% rel)")

    criterion = "PASSED" if s["success_criterion_met"] else "FAILED"
    print(f"\n  Success criterion (\u226550% helped): {criterion}")

    if analysis["group_stats"]:
        print(f"\n{'  Per-Group Breakdown  ':=^70}")
        header = f"  {'Group':<14} {'N':>5} {'ASR WER':>9} {'MM WER':>9} {'Improved':>9} {'Degraded':>9}"
        print(header)
        print("  " + "-" * (len(header) - 2))
        for group, d in sorted(analysis["group_stats"].items()):
            print(
                f"  {group:<14} {d['count']:>5} "
                f"{d['asr_wer'] * 100:>8.1f}% "
                f"{d['multimodal_wer'] * 100:>8.1f}% "
                f"{d['improved']:>9} {d['degraded']:>9}"
            )

    if analysis.get("top_improvements"):
        print(f"\n{'  Top Improvements  ':=^70}")
        for r in analysis["top_improvements"][:5]:
            print(f"  [{r['index']:>4}] WER {r['asr_wer'] * 100:.0f}% \u2192 {r['multimodal_wer'] * 100:.0f}%"
                  f"  ({r['speech_status']})")
            print(f"        ref: {r['reference']}")
            print(f"        asr: {r['asr_hypothesis']}")
            print(f"        mm:  {r['multimodal_hypothesis']}")

    if analysis.get("top_degradations"):
        print(f"\n{'  Top Degradations  ':=^70}")
        for r in analysis["top_degradations"][:5]:
            print(f"  [{r['index']:>4}] WER {r['asr_wer'] * 100:.0f}% \u2192 {r['multimodal_wer'] * 100:.0f}%"
                  f"  ({r['speech_status']})")
            print(f"        ref: {r['reference']}")
            print(f"        asr: {r['asr_hypothesis']}")
            print(f"        mm:  {r['multimodal_hypothesis']}")

## 4) `save_analysis` — persist results to JSON

Saves the full analysis (including per-sample details) to a JSON file. Handles
NumPy types that `json.dump` doesn't support natively.
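
Why a custom encoder is needed: `np.float64` subclasses Python `float` and serializes as-is, but `np.float32`, NumPy integers and `np.bool_` make `json.dumps` raise `TypeError`. A compact alternative sketch (not what `save_analysis` uses) passes a `default=` hook that converts any NumPy scalar with `.item()`; the name `numpy_default` is hypothetical.

```python
import json

import numpy as np


def numpy_default(o):
    """json.dumps fallback: NumPy scalars -> Python scalars, arrays -> lists."""
    if isinstance(o, np.generic):
        return o.item()
    if isinstance(o, np.ndarray):
        return o.tolist()
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")


record = {"count": np.int64(3), "wer": np.float32(0.5), "improved": np.bool_(True)}

try:
    json.dumps(record)
except TypeError as exc:
    print("plain json.dumps fails:", exc)

print(json.dumps(record, default=numpy_default))  # {"count": 3, "wer": 0.5, "improved": true}
```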

In [None]:
def save_analysis(analysis: dict, output_path: Path | str):
    """Write analysis dict to JSON (handles numpy types)."""
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    class _Enc(json.JSONEncoder):
        def default(self, o):
            if isinstance(o, (np.integer,)):
                return int(o)
            if isinstance(o, (np.floating,)):
                return float(o)
            if isinstance(o, np.ndarray):
                return o.tolist()
            return super().default(o)

    with open(output_path, "w") as f:
        json.dump(analysis, f, indent=2, cls=_Enc)
    print(f"Analysis saved \u2192 {output_path}")

## 5) Prepare test data and run analysis

Like the fusion tuning notebook, this requires **paired image-audio test data**.
Each sample needs `audio_array` and `reference`, plus either an `image_id` (a key in the
cached CLIP embeddings) or an `image_path`. `sr` (default 16000) and `speech_status`
are optional.

> Replace the placeholder below with your actual test set.
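
Before running the analysis, it can help to check each sample against the fields the code above reads. This is a minimal sketch; `validate_sample` is a hypothetical helper, not part of the pipeline. It checks the keys `compare_asr_vs_multimodal` uses, downmixes multi-channel arrays (`soundfile.read` returns shape `(frames, channels)` for those) to mono, and rejects sample rates other than 16 kHz, which Whisper's feature extractor requires.

```python
import numpy as np


def validate_sample(sample: dict) -> dict:
    """Return a cleaned copy of one test sample, or raise ValueError."""
    for key in ("audio_array", "reference"):
        if key not in sample:
            raise ValueError(f"missing required key {key!r}")
    if not (sample.get("image_id") or sample.get("image_path")):
        raise ValueError("need an image_id or an image_path for multimodal rescoring")
    if not str(sample["reference"]).strip():
        raise ValueError("empty reference")

    sr = sample.get("sr", 16000)
    if sr != 16000:
        # Resample upstream (e.g. with librosa or torchaudio); this sketch only rejects.
        raise ValueError(f"expected 16 kHz audio, got {sr} Hz")

    audio = np.asarray(sample["audio_array"], dtype=np.float32)
    if audio.ndim == 2:
        # (frames, channels) -> mono by averaging channels
        audio = audio.mean(axis=1)

    cleaned = dict(sample)
    cleaned["audio_array"] = audio
    cleaned["sr"] = sr
    cleaned.setdefault("speech_status", "unknown")
    return cleaned
```

Usage: `test_samples = [validate_sample(s) for s in test_samples]` after populating the list.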

In [None]:
# Placeholder — replace with your actual paired test data
test_samples = []

# Example:
# import soundfile as sf
# pairs = [
#     ("../audio/torgo/processed/test/sample_00000.wav", "img_001.png",
#      "a cat sleeping on a windowsill", "dysarthria"),
# ]
# for audio_file, image_id, reference, status in pairs:
#     audio_array, sr = sf.read(audio_file)
#     test_samples.append({
#         "audio_array": audio_array, "sr": sr,
#         "reference": reference, "image_id": image_id,
#         "speech_status": status,
#     })

In [None]:
if test_samples:
    pipeline = MultimodalASR(alpha=0.3, num_beams=5)
    analysis = compare_asr_vs_multimodal(pipeline, test_samples, alpha=0.3)
    print_analysis(analysis)
    save_analysis(analysis, Path("cache/multimodal_analysis.json"))
else:
    print("No test samples provided. Populate test_samples in the cell above to run analysis.")

## 6) Interpreting the report

### Success criterion

The project's success criterion is: **multimodal helps ≥ 50% of test cases**.
"Helps" means the multimodal WER for that sample is strictly lower than the
ASR-only WER.
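
Because the denominator is *all* samples, unchanged samples count against the criterion. A hypothetical tally (made-up counts) shows a system that is net-positive yet still fails:

```python
# Hypothetical outcome counts for 100 test samples
improved, degraded, unchanged = 30, 10, 60
n = improved + degraded + unchanged

improvement_rate = improved / n  # same definition as in _aggregate
print(f"improvement rate: {improvement_rate:.0%}")  # 30%
print("criterion met:", improvement_rate >= 0.5)    # False
print("net samples helped:", improved - degraded)   # 20
```

A sample whose ASR-only WER is already 0 can only stay unchanged or degrade, so the ≥ 50% bar is hardest to meet on test sets where Whisper already transcribes most utterances perfectly.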

### If the criterion fails

The deliverable specifies fallback strategies:

1. **Improve transcript normalization** — the filler-word list and caption-wrapping
   heuristics in `transcript_normalization.py` may need domain-specific tuning
2. **Fine-tune CLIP** — CLIP's text encoder was trained on web captions, not
   conversational speech. Fine-tuning on (image, therapeutic-description) pairs
   could improve alignment
3. **Condition on confidence** — only apply CLIP rescoring when Whisper's top-beam
   confidence is below a threshold (i.e., the model is uncertain)
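
Strategy 3 could look roughly like the sketch below. It is hypothetical: the function name `gated_fusion`, the threshold value, and the use of the top hypothesis's softmax probability as "confidence" are all assumptions, and it works on precomputed score arrays rather than calling the pipeline.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()


def gated_fusion(asr_scores, clip_cosines, alpha=0.3, confidence_threshold=0.6):
    """Return the index of the chosen hypothesis.

    CLIP is consulted only when Whisper's top hypothesis holds less than
    `confidence_threshold` of the softmax mass over the n-best list.
    """
    asr_probs = softmax(np.asarray(asr_scores, dtype=float))
    top = int(asr_probs.argmax())
    if asr_probs[top] >= confidence_threshold:
        return top  # Whisper is confident: keep its answer, skip CLIP
    clip_probs = softmax(np.asarray(clip_cosines, dtype=float) * 100.0)
    fused = (1 - alpha) * asr_probs + alpha * clip_probs
    return int(fused.argmax())


# Confident case: the top beam dominates, so CLIP is ignored
print(gated_fusion([-0.1, -2.0], [0.20, 0.30]))    # 0
# Uncertain case: beams are close, so CLIP breaks the tie
print(gated_fusion([-0.20, -0.35], [0.24, 0.29]))  # 1
```

Since a softmax over length-normalized beam scores tends to be flat, the threshold would need tuning on held-out data rather than being fixed a priori.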

### What the group breakdown reveals

- If CLIP helps **dysarthric** speech more than healthy speech, multimodal rescoring
  is successfully compensating for ASR weakness on atypical speech
- If CLIP helps **healthy** speech more, it may be doing vocabulary disambiguation
  rather than acoustic disambiguation — still valuable but different
- If it **hurts** a particular group, α may need per-group tuning
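
Per-group α tuning can be prototyped offline if each sample's n-best list is stored with its ASR probabilities, CLIP probabilities, and each hypothesis's WER against the reference. The sketch below assumes exactly that, using hypothetical field names (`group`, `asr_probs`, `clip_probs`, `hyp_wers`), and picks per group the grid α that minimizes the mean per-sample WER of the selected hypothesis. Mean per-sample WER is a simplification (the report above pools errors over words), and the tuning should ideally run on a held-out split rather than the test set being reported.

```python
from collections import defaultdict

import numpy as np


def tune_alpha_per_group(samples, alphas=None):
    """Return {group: best_alpha} by grid search on precomputed n-best scores."""
    if alphas is None:
        alphas = np.linspace(0.0, 1.0, 11)

    by_group = defaultdict(list)
    for s in samples:
        by_group[s["group"]].append(s)

    best = {}
    for group, items in by_group.items():
        mean_wers = []
        for alpha in alphas:
            chosen = []
            for s in items:
                fused = (1 - alpha) * np.asarray(s["asr_probs"]) + alpha * np.asarray(s["clip_probs"])
                chosen.append(s["hyp_wers"][int(fused.argmax())])
            mean_wers.append(np.mean(chosen))
        # argmin returns the first (smallest) alpha among ties
        best[group] = float(alphas[int(np.argmin(mean_wers))])
    return best


# Toy data: CLIP's favorite is right for the dysarthric sample, wrong for the healthy one
samples = [
    {"group": "dysarthria", "asr_probs": [0.55, 0.45], "clip_probs": [0.01, 0.99], "hyp_wers": [0.5, 0.0]},
    {"group": "healthy", "asr_probs": [0.55, 0.45], "clip_probs": [0.01, 0.99], "hyp_wers": [0.0, 0.5]},
]
print(tune_alpha_per_group(samples))
```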

## Summary

This notebook provides the tools to answer:

| Question | Metric |
|----------|--------|
| Does multimodal help overall? | WER reduction (abs & relative) |
| How often does it help? | Improvement rate (% of samples) |
| Does it meet the success bar? | ≥ 50% improvement rate |
| Where does it help most? | Per-group breakdown (dysarthria vs healthy) |
| What are the failure modes? | Top degradation examples |

Together with the fusion tuning results, this analysis determines whether and how
to deploy CLIP rescoring in the final ADI/O system.