# Multimodal ASR — Whisper + CLIP Rescoring Pipeline

This notebook implements the **core multimodal transcription pipeline**:

1. **Whisper beam search** generates *n*-best transcription hypotheses with log-prob scores
2. **CLIP** computes cosine similarity between the visual context (image) and each text hypothesis
3. **Fusion** combines both signals: `score = (1-α)·ASR_prob + α·CLIP_prob`
4. **Re-rank** and return the best hypothesis
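
Step 2 relies on the fact that, once both embeddings are L2-normalised, cosine similarity reduces to a dot product. A minimal sketch with made-up 3-dimensional vectors (the helper name `l2_normalise` and the values are illustrative only; real CLIP ViT-B/32 embeddings have 512 dimensions):

```python
import numpy as np

def l2_normalise(v: np.ndarray) -> np.ndarray:
    """Scale each row (or a single vector) to unit length."""
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

# Toy vectors standing in for CLIP embeddings (illustrative values only).
image_vec = l2_normalise(np.array([1.0, 2.0, 2.0]))
text_vecs = l2_normalise(np.array([
    [1.0, 2.0, 2.0],    # same direction as the image vector
    [2.0, -1.0, 0.0],   # orthogonal to the image vector
]))

# For unit vectors, the dot product equals the cosine similarity.
cosine_sims = text_vecs @ image_vec
print(cosine_sims)
```

The first text vector scores 1 (identical direction) and the second scores 0 (orthogonal), which is the range of behaviour the rescoring step exploits.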

## Why multimodal rescoring?

In the ADI/O therapeutic setting, a child is shown an image and asked to describe it.
When the speech is dysarthric, Whisper's 1-best transcription may be wrong — but the
correct answer is often *somewhere* in the beam. CLIP can identify which hypothesis
best matches the image content, resolving ambiguities that pure audio cannot.

For example, if the image shows a **bear**, Whisper might confuse "bear" with "bare".
CLIP's image-text similarity strongly favours "bear" when a bear is visible.
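
To make the bear/bare case concrete, here is a small sketch of the fusion formula with invented probabilities (the numbers are assumptions chosen for illustration, not measured outputs of Whisper or CLIP):

```python
import numpy as np

hypotheses = ["bare", "bear"]
asr_probs = np.array([0.6, 0.4])    # made-up: the audio slightly prefers "bare"
clip_probs = np.array([0.1, 0.9])   # made-up: the image strongly prefers "bear"
alpha = 0.3

# score = (1 - alpha) * ASR_prob + alpha * CLIP_prob
fused = (1 - alpha) * asr_probs + alpha * clip_probs
best = hypotheses[int(np.argmax(fused))]

for text, score in zip(hypotheses, fused):
    print(f"{text}: {score:.2f}")
print("best:", best)
```

With these values the fused scores are 0.45 for "bare" and 0.55 for "bear", so the visual evidence overturns the ASR 1-best.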

## 0) Imports and configuration

This notebook is **self-contained** — all CLIP helper functions, transcript normalization,
and the main pipeline class are defined inline.

If you haven't already, run the `clip_embeddings` notebook first to generate the
cached image embeddings in `cache/`.
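
For orientation, the pipeline later looks embeddings up by image ID, so the `.npz` cache behaves like a dictionary from ID to vector. The sketch below writes and reads a throwaway cache in a temporary directory. The random vectors and the file name `demo_cache.npz` are placeholders, and the real `clip_embeddings` notebook may store its arrays differently.

```python
import tempfile
from pathlib import Path

import numpy as np

rng = np.random.default_rng(0)
# Placeholder 512-d vectors keyed by image ID (512 is the ViT-B/32 projection size).
fake = {"img_001.png": rng.normal(size=512), "img_002.png": rng.normal(size=512)}
fake = {key: vec / np.linalg.norm(vec) for key, vec in fake.items()}

with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo_cache.npz"
    np.savez(demo_path, **fake)
    # Read every array into memory so the file can be closed again.
    with np.load(demo_path) as data:
        loaded = {key: data[key] for key in data.files}

print(sorted(loaded), loaded["img_001.png"].shape)
```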

In [None]:
import re
import json
from pathlib import Path

import numpy as np
import torch
import soundfile as sf
from PIL import Image
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    CLIPModel,
    CLIPProcessor,
)

DEFAULT_CLIP_MODEL = "openai/clip-vit-base-patch32"
DEFAULT_WHISPER_MODEL = "openai/whisper-small"
DEFAULT_CACHE_PATH = Path("cache/clip_image_embeddings.npz")

# ── CLIP helpers (from clip_embeddings notebook) ─────────────────────

def _pick_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def load_clip(model_name=DEFAULT_CLIP_MODEL, device=None):
    if device is None:
        device = _pick_device()
    processor = CLIPProcessor.from_pretrained(model_name)
    model = CLIPModel.from_pretrained(model_name).to(device)
    model.eval()
    return model, processor, device

def _get_image_features(model, pixel_values):
    vision_out = model.vision_model(pixel_values=pixel_values)
    features = model.visual_projection(vision_out.pooler_output)
    return features / features.norm(dim=-1, keepdim=True)

def _get_text_features(model, input_ids, attention_mask):
    text_out = model.text_model(input_ids=input_ids, attention_mask=attention_mask)
    features = model.text_projection(text_out.pooler_output)
    return features / features.norm(dim=-1, keepdim=True)

def encode_texts(model, processor, texts, device):
    inputs = processor(text=texts, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    with torch.no_grad():
        features = _get_text_features(model, input_ids, attention_mask)
    return features.cpu().numpy()

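def load_cached_embeddings(cache_path):
    # Read every array into memory so the .npz file handle can be closed.
    with np.load(str(cache_path)) as data:
        return {key: data[key] for key in data.files}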

# ── Transcript normalization (from transcript_normalization notebook) ─

FILLER_WORDS = frozenset({
    "um", "uh", "uh-huh", "hmm", "hm", "ah", "er", "oh",
    "like", "you know", "i mean", "okay", "ok", "so", "well",
})

def normalize_transcript(text):
    text = text.strip().lower()
    if not text:
        return text
    text = re.sub(r"\b(\w+)( \1\b)+", r"\1", text)
    words = text.split()
    cleaned = []
    skip_next = False
    for i, w in enumerate(words):
        if skip_next:
            skip_next = False
            continue
        bigram = f"{w} {words[i + 1]}" if i + 1 < len(words) else ""
        if bigram in FILLER_WORDS:
            skip_next = True
            continue
        if w not in FILLER_WORDS:
            cleaned.append(w)
    return re.sub(r"\s+", " ", " ".join(cleaned)).strip()

def to_caption_style(text):
    text = normalize_transcript(text)
    if not text:
        return text
    if len(text.split()) <= 2:
        return f"an image showing {text}"
    return text

## 1) Softmax helper

Both ASR log-probabilities and CLIP cosine similarities need to be converted into
probability distributions before fusion. We use the numerically stable softmax:

$$\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$$

Subtracting the max prevents overflow.
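
A quick way to see why the shift matters is to compare a naive softmax with the stable one on large inputs. This is a standalone sketch; the function names `naive_softmax` and `stable_softmax` are chosen here for illustration.

```python
import numpy as np

def naive_softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x)              # overflows to inf for large inputs
    return e / e.sum()

def stable_softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())    # the largest exponent is exactly 0
    return e / e.sum()

x = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(x))    # inf / inf gives nan for every entry
print(stable_softmax(x))       # same result as for [0, 1, 2]
```

Because softmax is invariant to adding a constant to every input, the shifted version returns the same distribution it would for `[0, 1, 2]`.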

In [None]:
def _softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over a non-empty 1-D array."""
    e = np.exp(x - x.max())  # shift so the largest exponent is 0
    return e / e.sum()

## 2) `MultimodalASR` class

This is the main pipeline class. It holds both models (Whisper + CLIP) and
provides three key methods:

### `generate_nbest(audio_array, sr)`
Runs Whisper with **beam search** (`num_beams=5` by default) and returns up to `num_beams`
unique hypotheses, each with a length-normalised log-probability score.
Beams that decode to the same lower-cased text are filtered out, keeping the first occurrence.
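
The duplicate filtering can be sketched without any model. The helper `dedupe_hypotheses` and the beam texts and scores below are invented for illustration:

```python
def dedupe_hypotheses(beams: list[tuple[str, float]]) -> list[dict]:
    """Keep the first occurrence of each normalised text, then sort by score."""
    seen: set[str] = set()
    unique: list[dict] = []
    for raw_text, score in beams:
        text = raw_text.strip().lower()
        if text in seen:
            continue
        seen.add(text)
        unique.append({"text": text, "score": score})
    return sorted(unique, key=lambda h: h["score"], reverse=True)

# Made-up beam outputs: the first two differ only in casing and whitespace.
beams = [("The bear", -0.21), (" the bear ", -0.35), ("The bare", -0.30)]
print(dedupe_hypotheses(beams))
```

Three beams collapse to two hypotheses here, which is why the method may return fewer than `num_beams` entries.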

### `rescore(hypotheses, image_embedding, alpha)`
The fusion step:
1. Normalise ASR log-probs → probabilities via softmax
2. Encode each hypothesis text through CLIP's text encoder
3. Compute cosine similarity with the image embedding
4. Multiply the CLIP similarities by 100 (approximately CLIP's learned logit scale, i.e. a temperature of τ ≈ 0.01), then apply softmax
5. Fuse: `final = (1-α)·asr_probs + α·clip_probs`
6. Re-sort by fused score
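
The factor of 100 in step 4 matters because raw cosine similarities between an image and competing captions usually differ by only a few hundredths. A small sketch with invented similarities shows the effect:

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()

sims = np.array([0.25, 0.27])      # made-up cosine similarities

unscaled = softmax(sims)           # nearly uniform: the gap of 0.02 is lost
scaled = softmax(sims * 100.0)     # the gap becomes 2.0 in logit space

print(unscaled.round(3))
print(scaled.round(3))
```

Without scaling, the two hypotheses get roughly 0.495 and 0.505. With scaling they get roughly 0.119 and 0.881, so the CLIP term can actually change the ranking.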

### `transcribe(audio_array, image_id=, image_path=, sr=, alpha=)`
The top-level entry point that chains `generate_nbest` → `rescore` → return best.
Pass an `image_id` (for cached embeddings) or `image_path` (computed on-the-fly).
Omit both for ASR-only transcription.
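
The class defined below returns a dictionary whose `mode` field is either `"asr_only"` or `"multimodal"`, with an extra `warning` when an image was requested but could not be resolved. A hypothetical helper such as `describe_result` (not part of the pipeline) shows how a caller might branch on it. The dictionaries are hand-written stand-ins for real results.

```python
def describe_result(result: dict) -> str:
    """One-line summary of a `transcribe`-style result dictionary."""
    line = f"[{result['mode']}] {result['transcription']!r}"
    if result["mode"] == "multimodal":
        line += f" (alpha={result['alpha']})"
    if "warning" in result:
        line += f" | warning: {result['warning']}"
    return line

# Hand-written dictionaries mimicking the three possible outcomes.
asr_only = {"transcription": "the bare", "hypotheses": [], "mode": "asr_only"}
fallback = {**asr_only, "warning": "Image not found (id=missing.png, path=None)"}
multimodal = {"transcription": "the bear", "hypotheses": [], "mode": "multimodal", "alpha": 0.3}

for r in (asr_only, fallback, multimodal):
    print(describe_result(r))
```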

In [None]:
class MultimodalASR:
    """Whisper ASR with optional CLIP visual-context rescoring."""

    def __init__(
        self,
        whisper_model_id: str = DEFAULT_WHISPER_MODEL,
        clip_model_id: str = DEFAULT_CLIP_MODEL,
        cache_path: Path | str = DEFAULT_CACHE_PATH,
        alpha: float = 0.3,
        num_beams: int = 5,
        device: str | None = None,
    ):
        if device is None:
            device = _pick_device()
        self.device = device
        self.alpha = alpha
        self.num_beams = num_beams

        # Whisper
        self.whisper_processor = WhisperProcessor.from_pretrained(whisper_model_id)
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            whisper_model_id
        ).to(device)
        self.whisper_model.eval()

        # CLIP
        self.clip_model, self.clip_processor, _ = load_clip(clip_model_id, device)

        # Pre-computed image embeddings
        self.image_embeddings: dict[str, np.ndarray] = {}
        cache_path = Path(cache_path)
        if cache_path.exists():
            self.image_embeddings = load_cached_embeddings(cache_path)
            print(f"Loaded {len(self.image_embeddings)} cached image embeddings")

    # ── n-best generation ────────────────────────────────────────────

    def generate_nbest(
        self, audio_array: np.ndarray, sr: int = 16000
    ) -> list[dict]:
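        """Return n-best Whisper hypotheses sorted by score (descending).

        `audio_array` should be mono audio sampled at the rate the Whisper
        feature extractor expects (16 kHz); resample beforehand otherwise.
        """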
        input_features = self.whisper_processor(
            audio_array, sampling_rate=sr, return_tensors="pt"
        ).input_features.to(self.device)

        with torch.no_grad():
            outputs = self.whisper_model.generate(
                input_features,
                num_beams=self.num_beams,
                num_return_sequences=self.num_beams,
                return_dict_in_generate=True,
                output_scores=True,
            )

        sequences = outputs.sequences
        seq_scores = outputs.sequences_scores.cpu().numpy()

        hypotheses: list[dict] = []
        seen: set[str] = set()
        for i, seq in enumerate(sequences):
            text = self.whisper_processor.decode(
                seq, skip_special_tokens=True
            ).strip().lower()
            if text in seen:
                continue
            seen.add(text)
            hypotheses.append({"text": text, "score": float(seq_scores[i])})

        hypotheses.sort(key=lambda h: h["score"], reverse=True)
        return hypotheses

    # ── CLIP similarity ──────────────────────────────────────────────

    def clip_similarity(
        self, image_embedding: np.ndarray, texts: list[str]
    ) -> np.ndarray:
        """Cosine similarities between one image embedding and N texts."""
        if not texts:
            return np.array([])
        text_embeddings = encode_texts(
            self.clip_model, self.clip_processor, texts, self.device
        )
        sims = (image_embedding.reshape(1, -1) @ text_embeddings.T).squeeze()
        return np.atleast_1d(sims)

    # ── rescoring ────────────────────────────────────────────────────

    def rescore(
        self,
        hypotheses: list[dict],
        image_embedding: np.ndarray,
        alpha: float | None = None,
        caption_style: bool = True,
    ) -> list[dict]:
        """Fuse ASR log-probs with CLIP cosine similarity."""
        if alpha is None:
            alpha = self.alpha
        if not hypotheses:
            return hypotheses

        texts = [h["text"] for h in hypotheses]
        clip_texts = [
            to_caption_style(t) if caption_style else normalize_transcript(t)
            for t in texts
        ]

        clip_scores = self.clip_similarity(image_embedding, clip_texts)

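        # Length-normalised beam log-probs -> distribution over the n-best list
        asr_logits = np.array([h["score"] for h in hypotheses])
        asr_probs = _softmax(asr_logits)

        # Multiply by ~100 (CLIP's learned logit scale) so small cosine gaps matter
        clip_probs = _softmax(clip_scores * 100.0)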

        fused = (1 - alpha) * asr_probs + alpha * clip_probs

        rescored = [
            {
                "text": h["text"],
                "asr_score": float(asr_probs[i]),
                "clip_score": float(clip_scores[i]),
                "fused_score": float(fused[i]),
            }
            for i, h in enumerate(hypotheses)
        ]
        rescored.sort(key=lambda h: h["fused_score"], reverse=True)
        return rescored

    # ── main entry point ─────────────────────────────────────────────

    def transcribe(
        self,
        audio_array: np.ndarray,
        image_id: str | None = None,
        image_path: str | Path | None = None,
        sr: int = 16000,
        alpha: float | None = None,
    ) -> dict:
        """Full transcription pipeline."""
        hypotheses = self.generate_nbest(audio_array, sr)

        if image_id is None and image_path is None:
            return {
                "transcription": hypotheses[0]["text"] if hypotheses else "",
                "hypotheses": hypotheses,
                "mode": "asr_only",
            }

        img_emb = self._resolve_image_embedding(image_id, image_path)
        if img_emb is None:
            return {
                "transcription": hypotheses[0]["text"] if hypotheses else "",
                "hypotheses": hypotheses,
                "mode": "asr_only",
                "warning": f"Image not found (id={image_id}, path={image_path})",
            }

        rescored = self.rescore(hypotheses, img_emb, alpha)
        return {
            "transcription": rescored[0]["text"] if rescored else "",
            "hypotheses": rescored,
            "mode": "multimodal",
            "alpha": alpha if alpha is not None else self.alpha,
        }

    # ── helpers ───────────────────────────────────────────────────────

    def _resolve_image_embedding(
        self, image_id: str | None, image_path: str | Path | None
    ) -> np.ndarray | None:
        if image_id and image_id in self.image_embeddings:
            return self.image_embeddings[image_id]
        if image_path:
            image_path = Path(image_path)
            if not image_path.exists():
                return None
            image = Image.open(image_path).convert("RGB")
            inputs = self.clip_processor(images=image, return_tensors="pt")
            pixel_values = inputs["pixel_values"].to(self.device)
            with torch.no_grad():
                emb = _get_image_features(self.clip_model, pixel_values)
            return emb.cpu().numpy().squeeze()
        return None

## 3) Initialise the pipeline

This loads **both** Whisper Small and CLIP ViT-B/32, plus the cached image embeddings.
The first run takes longer because the model weights are downloaded and cached by Hugging Face `transformers`.

In [None]:
pipeline = MultimodalASR(
    whisper_model_id=DEFAULT_WHISPER_MODEL,
    alpha=0.3,
    num_beams=5,
)

## 4) Demo: ASR-only vs multimodal transcription

Pick a test audio file and an image. We run the pipeline twice — once without an
image (pure ASR) and once with the image (multimodal rescoring) — and compare the
hypotheses and their scores.

If CLIP helps, the multimodal ranking will differ from the ASR-only ranking, and the
top hypothesis will be more semantically aligned with the image.
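
One simple way to check whether CLIP changed anything is to compare the ordered hypothesis texts from the two runs. The helper `ranking_changed` and the example lists are illustrative assumptions. The dictionaries only mimic the shapes returned in each mode.

```python
def ranking_changed(asr_hyps: list[dict], mm_hyps: list[dict]) -> bool:
    """True if the ordered hypothesis texts differ between the two runs."""
    return [h["text"] for h in asr_hyps] != [h["text"] for h in mm_hyps]

# Made-up hypothesis lists in the shapes returned by the two modes.
asr_hyps = [
    {"text": "the bare", "score": -0.21},
    {"text": "the bear", "score": -0.30},
]
mm_hyps = [
    {"text": "the bear", "asr_score": 0.48, "clip_score": 0.27, "fused_score": 0.60},
    {"text": "the bare", "asr_score": 0.52, "clip_score": 0.25, "fused_score": 0.40},
]

print(ranking_changed(asr_hyps, mm_hyps))
```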

In [None]:
audio_path = "../audio/torgo/processed/test/sample_00007.wav"
image_id = "img_001.png"

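audio_array, sr = sf.read(audio_path, dtype="float32")
if audio_array.ndim > 1:  # down-mix multi-channel audio to mono for Whisper
    audio_array = audio_array.mean(axis=1)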
print(f"Audio: {audio_path}  ({len(audio_array)/sr:.1f}s @ {sr} Hz)")

# ASR-only
asr_result = pipeline.transcribe(audio_array, sr=sr)
print(f"\n--- ASR-only (mode: {asr_result['mode']}) ---")
print(f"Transcription: {asr_result['transcription']}")
for h in asr_result["hypotheses"]:
    print(f"  {h['score']:.4f}  {h['text']}")

# Multimodal
mm_result = pipeline.transcribe(audio_array, image_id=image_id, sr=sr)
print(f"\n--- Multimodal (mode: {mm_result['mode']}, \u03b1={mm_result['alpha']}) ---")
print(f"Transcription: {mm_result['transcription']}")
for h in mm_result["hypotheses"]:
    print(f"  fused={h['fused_score']:.4f}  asr={h['asr_score']:.4f}  clip={h['clip_score']:.4f}  {h['text']}")

## 5) Examining the fusion in detail

To understand how α affects the ranking, let's run the same hypotheses through
`rescore` with several different α values and observe which hypothesis wins.
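
Before running the sweep on real audio, it can help to see the same idea on invented distributions. In this sketch (all probabilities are assumptions), the winner switches from the ASR favourite to the CLIP favourite once α passes a crossover point, which for two competing hypotheses is where their fused scores are equal.

```python
import numpy as np

texts = ["the bare", "the pear", "the bear"]
asr_probs = np.array([0.5, 0.3, 0.2])    # made-up ASR distribution
clip_probs = np.array([0.2, 0.1, 0.7])   # made-up CLIP distribution

def best_at(alpha: float) -> str:
    """Top hypothesis under linear fusion with weight alpha."""
    fused = (1 - alpha) * asr_probs + alpha * clip_probs
    return texts[int(np.argmax(fused))]

for alpha in [0.0, 0.1, 0.3, 0.5, 0.7, 1.0]:
    print(f"alpha={alpha:.1f}  best={best_at(alpha)}")

# Crossover between "the bare" (index 0) and "the bear" (index 2):
# (1 - a) * asr_gap = a * clip_gap  =>  a = asr_gap / (asr_gap + clip_gap)
asr_gap = asr_probs[0] - asr_probs[2]
clip_gap = clip_probs[2] - clip_probs[0]
alpha_star = asr_gap / (asr_gap + clip_gap)
print(f"crossover at alpha={alpha_star:.3f}")
```

Here the crossover sits at α = 0.375, so α = 0.3 still returns "the bare" while α = 0.5 returns "the bear".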

In [None]:
hypotheses = pipeline.generate_nbest(audio_array, sr)
img_emb = pipeline.image_embeddings.get(image_id)

if img_emb is not None:
    for alpha in [0.0, 0.1, 0.3, 0.5, 0.7, 1.0]:
        rescored = pipeline.rescore(hypotheses, img_emb, alpha=alpha)
        best = rescored[0]
        print(f"  \u03b1={alpha:.1f}  best=\"{best['text']}\"  fused={best['fused_score']:.4f}")
else:
    print(f"Image {image_id} not found in cache. Run clip_embeddings notebook first.")

## Summary

This notebook built the `MultimodalASR` class, the central piece of Phase 3:

| Method | Purpose |
|--------|--------|
| `generate_nbest` | Whisper beam search → *n* unique hypotheses with log-prob scores |
| `clip_similarity` | Cosine similarity between image embedding and text candidates |
| `rescore` | Fuses ASR + CLIP probabilities with linear interpolation (α) |
| `transcribe` | End-to-end: audio in → best transcription out |

The fusion coefficient **α** controls the balance:
- α = 0 → pure ASR (CLIP ignored)
- α = 1 → pure CLIP (ASR confidence ignored)
- α ≈ 0.2–0.4 is typically optimal (determined by grid search in `fusion_tuning`)
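
As a rough idea of what a grid search over α involves, the sketch below scores each candidate α by top-1 accuracy on a tiny invented development set. It is not the `fusion_tuning` implementation: the data, the `accuracy` helper, and the 0.1 grid step are all assumptions made for illustration.

```python
import numpy as np

# Hypothetical dev set: (ASR probs, CLIP probs, index of the correct hypothesis).
# All numbers are invented for illustration.
dev_set = [
    (np.array([0.65, 0.35]), np.array([0.10, 0.90]), 1),
    (np.array([0.70, 0.30]), np.array([0.40, 0.60]), 0),
    (np.array([0.55, 0.45]), np.array([0.35, 0.65]), 1),
]

def accuracy(alpha: float) -> float:
    """Fraction of dev items whose fused top-1 matches the gold index."""
    hits = sum(
        int(np.argmax((1 - alpha) * asr + alpha * clip) == gold)
        for asr, clip, gold in dev_set
    )
    return hits / len(dev_set)

alphas = np.linspace(0.0, 1.0, 11)
scores = [accuracy(a) for a in alphas]
best_alpha = float(alphas[int(np.argmax(scores))])   # first alpha with the top score
print(f"best alpha = {best_alpha:.1f}, accuracy = {max(scores):.2f}")
```

On a real development set the metric would more likely be word error rate than exact top-1 match, and ties between several α values are common, so the chosen value should be read as a range rather than a single point.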