# Fusion Coefficient Tuning — Grid Search for Optimal α

The multimodal rescoring pipeline combines ASR confidence with CLIP visual similarity
using a **fusion coefficient α**:

$$\text{score} = (1 - \alpha) \cdot P_{\text{ASR}} + \alpha \cdot P_{\text{CLIP}}$$

| α | Behaviour |
|---|----------|
| 0.0 | Pure ASR — CLIP is ignored entirely |
| 0.5 | Equal weight to both signals |
| 1.0 | Pure CLIP — ASR confidence is ignored |
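
As a quick illustration of the formula and table above, the sketch below fuses two made-up probability vectors over three competing hypotheses. The helper name `fuse` and all numbers are assumptions chosen for the example; they are not taken from the pipeline.

```python
import numpy as np

def fuse(p_asr, p_clip, alpha):
    """Linear interpolation of two probability vectors over the same hypotheses."""
    p_asr = np.asarray(p_asr, dtype=float)
    p_clip = np.asarray(p_clip, dtype=float)
    return (1.0 - alpha) * p_asr + alpha * p_clip

# Made-up probabilities for three competing hypotheses.
p_asr = [0.6, 0.3, 0.1]   # ASR prefers hypothesis 0
p_clip = [0.2, 0.7, 0.1]  # CLIP prefers hypothesis 1

for alpha in (0.0, 0.25, 0.5, 1.0):
    fused = fuse(p_asr, p_clip, alpha)
    print(f"alpha={alpha:.2f}  fused={np.round(fused, 3)}  top={int(fused.argmax())}")
```

For these particular numbers the two leading hypotheses score 0.6 − 0.4α and 0.3 + 0.4α, so the winner flips from hypothesis 0 to hypothesis 1 once α exceeds 0.375.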

The optimal α depends on how informative the visual context is relative to the
acoustic signal. This notebook finds it via **grid search** over a validation set.

## Strategy

Whisper beam search is the expensive part. We run it **once** per sample to
pre-compute all *n*-best hypotheses, then sweep α values by re-running only the
rescoring step (CLIP text encoding plus the fusion arithmetic), which is far
cheaper than decoding.
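
The same idea can be pushed one step further: the softmax-normalised ASR and CLIP probabilities do not depend on α, so they could be cached as well, leaving only arithmetic inside the sweep. The following is a minimal sketch on made-up scores, not the notebook's implementation. The names `precompute` and `sweep` are hypothetical, and it reports a sentence error rate (fraction of samples whose top hypothesis differs from the reference) instead of WER to stay dependency-light.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    x = np.asarray(x, dtype=float)
    e = np.exp(x - x.max())
    return e / e.sum()

def precompute(samples):
    """One-time step: turn raw scores into probabilities that do not depend on alpha."""
    return [
        {
            "texts": s["texts"],
            "asr_probs": softmax(s["asr_scores"]),
            # Same x100 scaling that the pipeline's rescore() applies to CLIP similarities.
            "clip_probs": softmax(np.asarray(s["clip_scores"]) * 100.0),
            "reference": s["reference"],
        }
        for s in samples
    ]

def sweep(cached, alphas):
    """Cheap step: for each alpha, count samples whose top fused hypothesis is wrong."""
    results = []
    for alpha in alphas:
        errors = 0
        for item in cached:
            fused = (1.0 - alpha) * item["asr_probs"] + alpha * item["clip_probs"]
            best = item["texts"][int(fused.argmax())]
            errors += best != item["reference"]
        results.append((float(alpha), errors / len(cached)))
    return results

# Made-up scores standing in for Whisper beam scores and CLIP cosine similarities.
toy_samples = [
    {"texts": ["a cat on a mat", "a cap on a mat"], "asr_scores": [-0.30, -0.25],
     "clip_scores": [0.31, 0.24], "reference": "a cat on a mat"},
    {"texts": ["a dog in a park", "a dock in a park"], "asr_scores": [-0.20, -0.60],
     "clip_scores": [0.29, 0.27], "reference": "a dog in a park"},
]

cached = precompute(toy_samples)
for alpha, err in sweep(cached, np.round(np.linspace(0.0, 1.0, 5), 2)):
    print(f"alpha={alpha:.2f}  sentence error rate={err:.2f}")
```

In this toy set the first sample is misrecognised by ASR alone and corrected once CLIP gets some weight, so the error rate drops after α = 0.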

## 0) Imports and prerequisite code

This notebook is **self-contained** — it includes the full `MultimodalASR` pipeline
(CLIP helpers, transcript normalization, and the pipeline class) inline so it can
run independently. We also import `jiwer` for computing Word Error Rate at each α.

In [None]:
import re
import json
from pathlib import Path

import numpy as np
import torch
import soundfile as sf
from PIL import Image
from jiwer import wer as compute_wer
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    CLIPModel,
    CLIPProcessor,
)

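# 21 values from 0.00 to 1.00 in steps of 0.05; linspace pins the endpoint exactly,
# whereas arange with a float step can include or drop it depending on rounding.
DEFAULT_ALPHA_RANGE = np.round(np.linspace(0.0, 1.0, 21), 2)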

# ── Prerequisite code (defined in clip_embeddings & multimodal_asr notebooks) ──
# Included inline so this notebook is self-contained and runnable on its own.

DEFAULT_CLIP_MODEL = "openai/clip-vit-base-patch32"
DEFAULT_WHISPER_MODEL = "openai/whisper-small"
DEFAULT_CACHE_PATH = Path("cache/clip_image_embeddings.npz")

def _pick_device():
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def load_clip(model_name=DEFAULT_CLIP_MODEL, device=None):
    if device is None:
        device = _pick_device()
    processor = CLIPProcessor.from_pretrained(model_name)
    model = CLIPModel.from_pretrained(model_name).to(device)
    model.eval()
    return model, processor, device

def _get_image_features(model, pixel_values):
    vision_out = model.vision_model(pixel_values=pixel_values)
    features = model.visual_projection(vision_out.pooler_output)
    return features / features.norm(dim=-1, keepdim=True)

def _get_text_features(model, input_ids, attention_mask):
    text_out = model.text_model(input_ids=input_ids, attention_mask=attention_mask)
    features = model.text_projection(text_out.pooler_output)
    return features / features.norm(dim=-1, keepdim=True)

def encode_texts(model, processor, texts, device):
    inputs = processor(text=texts, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        features = _get_text_features(model, inputs["input_ids"].to(device), inputs["attention_mask"].to(device))
    return features.cpu().numpy()

def load_cached_embeddings(cache_path):
    return dict(np.load(str(cache_path)))

FILLER_WORDS = frozenset({
    "um", "uh", "uh-huh", "hmm", "hm", "ah", "er", "oh",
    "like", "you know", "i mean", "okay", "ok", "so", "well",
})

def normalize_transcript(text):
    text = text.strip().lower()
    if not text:
        return text
    text = re.sub(r"\b(\w+)( \1\b)+", r"\1", text)
    words = text.split()
    cleaned, skip_next = [], False
    for i, w in enumerate(words):
        if skip_next:
            skip_next = False
            continue
        bigram = f"{w} {words[i + 1]}" if i + 1 < len(words) else ""
        if bigram in FILLER_WORDS:
            skip_next = True
            continue
        if w not in FILLER_WORDS:
            cleaned.append(w)
    return re.sub(r"\s+", " ", " ".join(cleaned)).strip()

def to_caption_style(text):
    text = normalize_transcript(text)
    if not text:
        return text
    return f"an image showing {text}" if len(text.split()) <= 2 else text

def _softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


class MultimodalASR:
    """Whisper ASR with optional CLIP visual-context rescoring."""

    def __init__(self, whisper_model_id=DEFAULT_WHISPER_MODEL, clip_model_id=DEFAULT_CLIP_MODEL,
                 cache_path=DEFAULT_CACHE_PATH, alpha=0.3, num_beams=5, device=None):
        if device is None:
            device = _pick_device()
        self.device = device
        self.alpha = alpha
        self.num_beams = num_beams
        self.whisper_processor = WhisperProcessor.from_pretrained(whisper_model_id)
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_id).to(device)
        self.whisper_model.eval()
        self.clip_model, self.clip_processor, _ = load_clip(clip_model_id, device)
        self.image_embeddings = {}
        cache_path = Path(cache_path)
        if cache_path.exists():
            self.image_embeddings = load_cached_embeddings(cache_path)
            print(f"Loaded {len(self.image_embeddings)} cached image embeddings")

    def generate_nbest(self, audio_array, sr=16000):
        input_features = self.whisper_processor(audio_array, sampling_rate=sr, return_tensors="pt").input_features.to(self.device)
        with torch.no_grad():
            outputs = self.whisper_model.generate(input_features, num_beams=self.num_beams,
                num_return_sequences=self.num_beams, return_dict_in_generate=True, output_scores=True)
        seq_scores = outputs.sequences_scores.cpu().numpy()
        hypotheses, seen = [], set()
        for i, seq in enumerate(outputs.sequences):
            text = self.whisper_processor.decode(seq, skip_special_tokens=True).strip().lower()
            if text not in seen:
                seen.add(text)
                hypotheses.append({"text": text, "score": float(seq_scores[i])})
        hypotheses.sort(key=lambda h: h["score"], reverse=True)
        return hypotheses

    def clip_similarity(self, image_embedding, texts):
        if not texts:
            return np.array([])
        text_embs = encode_texts(self.clip_model, self.clip_processor, texts, self.device)
        return np.atleast_1d((image_embedding.reshape(1, -1) @ text_embs.T).squeeze())

    def rescore(self, hypotheses, image_embedding, alpha=None, caption_style=True):
        if alpha is None:
            alpha = self.alpha
        if not hypotheses:
            return hypotheses
        texts = [h["text"] for h in hypotheses]
        clip_texts = [to_caption_style(t) if caption_style else normalize_transcript(t) for t in texts]
        clip_scores = self.clip_similarity(image_embedding, clip_texts)
        asr_probs = _softmax(np.array([h["score"] for h in hypotheses]))
        clip_probs = _softmax(clip_scores * 100.0)
        fused = (1 - alpha) * asr_probs + alpha * clip_probs
        rescored = [{"text": h["text"], "asr_score": float(asr_probs[i]),
                      "clip_score": float(clip_scores[i]), "fused_score": float(fused[i])}
                     for i, h in enumerate(hypotheses)]
        rescored.sort(key=lambda h: h["fused_score"], reverse=True)
        return rescored

    def transcribe(self, audio_array, image_id=None, image_path=None, sr=16000, alpha=None):
        hypotheses = self.generate_nbest(audio_array, sr)
        if image_id is None and image_path is None:
            return {"transcription": hypotheses[0]["text"] if hypotheses else "", "hypotheses": hypotheses, "mode": "asr_only"}
        img_emb = self._resolve_image_embedding(image_id, image_path)
        if img_emb is None:
            return {"transcription": hypotheses[0]["text"] if hypotheses else "", "hypotheses": hypotheses, "mode": "asr_only",
                    "warning": f"Image not found (id={image_id}, path={image_path})"}
        rescored = self.rescore(hypotheses, img_emb, alpha)
        return {"transcription": rescored[0]["text"] if rescored else "", "hypotheses": rescored,
                "mode": "multimodal", "alpha": alpha if alpha is not None else self.alpha}

    def _resolve_image_embedding(self, image_id, image_path):
        if image_id and image_id in self.image_embeddings:
            return self.image_embeddings[image_id]
        if image_path:
            image_path = Path(image_path)
            if not image_path.exists():
                return None
            image = Image.open(image_path).convert("RGB")
            inputs = self.clip_processor(images=image, return_tensors="pt")
            with torch.no_grad():
                emb = _get_image_features(self.clip_model, inputs["pixel_values"].to(self.device))
            return emb.cpu().numpy().squeeze()
        return None

## 1) `tune_alpha` — the grid search function

### Input format

Each test sample is a dictionary with:

| Key | Type | Description |
|-----|------|------------|
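| `audio_array` | `np.ndarray` | Raw waveform (1-D, mono) |
| `sr` | `int` | Sampling rate in Hz (optional; defaults to 16 000, the rate Whisper expects) |
| `reference` | `str` | Ground-truth transcription |
| `image_id` | `str` | Cached embedding key, e.g. `"img_001.png"` |
| `image_path` | `str` | Optional path to an image file, used when `image_id` is not in the cache |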
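
Because malformed samples only surface as errors deep inside the Whisper call, it can help to check each dictionary up front. The sketch below is one possible way to do that; `prepare_sample` and `REQUIRED_KEYS` are hypothetical helpers, not part of the notebook, and the choice to reject (and not resample) audio that is not at 16 kHz is an assumption.

```python
import numpy as np

REQUIRED_KEYS = ("audio_array", "reference")  # hypothetical helper constant

def prepare_sample(sample, expected_sr=16000):
    """Return a cleaned copy of one sample dict, or raise ValueError if unusable."""
    missing = [k for k in REQUIRED_KEYS if k not in sample]
    if missing:
        raise ValueError(f"missing keys: {missing}")
    if not sample["reference"].strip():
        raise ValueError("empty reference transcription")

    audio = np.asarray(sample["audio_array"], dtype=np.float32)
    if audio.ndim == 2:
        # soundfile returns (frames, channels); average channels to get mono.
        audio = audio.mean(axis=1)
    elif audio.ndim != 1:
        raise ValueError(f"unexpected audio shape {audio.shape}")

    sr = int(sample.get("sr", expected_sr))
    if sr != expected_sr:
        raise ValueError(f"sampling rate {sr} Hz; resample to {expected_sr} Hz first")

    return {**sample, "audio_array": audio, "sr": sr}

# Usage on a silent one-second stereo clip (synthetic data).
stereo = np.zeros((16000, 2))
clean = prepare_sample({"audio_array": stereo, "sr": 16000,
                        "reference": "a cat sleeping", "image_id": "img_001.png"})
print(clean["audio_array"].shape, clean["audio_array"].dtype)
```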

### Algorithm

1. **Pre-compute** (one-time cost): for each sample, run Whisper beam search to
   get *n*-best hypotheses, and resolve the image embedding (from the cache via
   `image_id`, or computed from `image_path`).
2. **Sweep** α from 0.0 to 1.0 in steps of 0.05 (21 values). For each α:
   - Rescore all pre-computed hypotheses with the fusion formula
   - Pick the top hypothesis for each sample (samples without an image
     embedding fall back to the top ASR hypothesis)
   - Compute corpus-level WER against references
3. **Select** the α with the lowest WER.
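
"Corpus-level" WER means total word errors divided by total reference words, which is not the same as averaging per-utterance WERs. The dependency-free sketch below illustrates the difference on two made-up utterances. It is a stand-in for illustration only (the notebook itself uses `jiwer`), and the function names are hypothetical.

```python
def word_edit_distance(ref_words, hyp_words):
    """Levenshtein distance between two word lists (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp_words) + 1))
    for i, r in enumerate(ref_words, start=1):
        curr = [i]
        for j, h in enumerate(hyp_words, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # match or substitution
        prev = curr
    return prev[-1]

def corpus_wer(references, hypotheses):
    """Total word errors divided by total reference words."""
    errors = sum(word_edit_distance(r.split(), h.split())
                 for r, h in zip(references, hypotheses))
    total = sum(len(r.split()) for r in references)
    return errors / total

def mean_sentence_wer(references, hypotheses):
    """Average of per-utterance WERs; short utterances weigh as much as long ones."""
    rates = [word_edit_distance(r.split(), h.split()) / len(r.split())
             for r, h in zip(references, hypotheses)]
    return sum(rates) / len(rates)

refs = ["a cat", "a dog catching a ball in a park"]
hyps = ["a cap", "a dog catching a ball in a park"]
print(corpus_wer(refs, hyps))         # 1 error / 10 reference words
print(mean_sentence_wer(refs, hyps))  # (1/2 + 0/8) / 2
```

The single substitution in the short utterance counts for far more under per-utterance averaging, which is one reason to prefer the corpus-level figure when utterance lengths vary.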

In [None]:
def tune_alpha(
    pipeline: MultimodalASR,
    test_samples: list[dict],
    alpha_range: np.ndarray = DEFAULT_ALPHA_RANGE,
) -> dict:
    """Grid-search alpha on paired image-audio test samples."""

    # Step 1: pre-compute n-best hypotheses + image embeddings
    print(f"Pre-computing n-best hypotheses for {len(test_samples)} samples \u2026")
    precomputed: list[dict] = []

    for i, sample in enumerate(test_samples):
        hyps = pipeline.generate_nbest(
            sample["audio_array"], sample.get("sr", 16000)
        )
        img_emb = pipeline._resolve_image_embedding(
            sample.get("image_id"), sample.get("image_path")
        )
        precomputed.append({
            "hypotheses": hyps,
            "image_embedding": img_emb,
            "reference": sample["reference"].strip().lower(),
        })
        if (i + 1) % 10 == 0:
            print(f"  {i + 1}/{len(test_samples)} done")

    # Step 2: sweep alpha
    print(f"\nSweeping {len(alpha_range)} alpha values \u2026")
    grid_results: list[dict] = []

    for alpha in alpha_range:
        alpha = float(alpha)
        refs, hyps_texts = [], []

        for item in precomputed:
            ref = item["reference"]
            if item["image_embedding"] is not None:
                rescored = pipeline.rescore(
                    item["hypotheses"], item["image_embedding"], alpha
                )
                best = rescored[0]["text"] if rescored else ""
            else:
                best = (
                    item["hypotheses"][0]["text"] if item["hypotheses"] else ""
                )
            refs.append(ref)
            hyps_texts.append(best)

        alpha_wer = compute_wer(refs, hyps_texts)
        grid_results.append({"alpha": alpha, "wer": alpha_wer})
        print(f"  \u03b1 = {alpha:.2f}   WER = {alpha_wer * 100:.1f}%")

    best = min(grid_results, key=lambda r: r["wer"])
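    # Baseline = top ASR hypothesis per sample (what alpha = 0 selects); computed
    # directly so it is available even if alpha_range does not contain 0.0.
    baseline_wer = compute_wer(
        [item["reference"] for item in precomputed],
        [item["hypotheses"][0]["text"] if item["hypotheses"] else "" for item in precomputed],
    )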

    summary = {
        "grid_results": grid_results,
        "best_alpha": best["alpha"],
        "best_wer": best["wer"],
        "baseline_wer": baseline_wer,
        "wer_reduction": baseline_wer - best["wer"],
        "num_samples": len(test_samples),
    }

    # Backslash escapes inside an f-string expression are a SyntaxError before Python 3.12.
    print("\n" + "\u2500" * 50)
    print(f"Baseline WER (\u03b1=0):  {baseline_wer * 100:.1f}%")
    print(f"Best WER:            {best['wer'] * 100:.1f}%  (\u03b1 = {best['alpha']:.2f})")
    print(f"Absolute reduction:  {(baseline_wer - best['wer']) * 100:.1f}%")

    return summary

## 2) Prepare test data

To run the grid search, you need **paired image-audio test data** — audio samples
where you know both the ground-truth transcription and which image was being described.

Below is a template that you can adapt to your test set. Replace the placeholder
with your actual paired samples.

```python
import soundfile as sf

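# your_test_pairs: iterable of (audio_file, image_id, reference) tuples (placeholder name)
test_samples = []
for audio_file, image_id, reference in your_test_pairs:
    # Files should already be 16 kHz mono, which is what Whisper expects
    audio_array, sr = sf.read(audio_file)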
    test_samples.append({
        "audio_array": audio_array,
        "sr": sr,
        "reference": reference,
        "image_id": image_id,
    })
```

In [None]:
# Placeholder — replace with your actual paired test data
test_samples = []

# Example: uncomment and adapt
# import soundfile as sf
# pairs = [
#     ("../audio/torgo/processed/test/sample_00000.wav", "img_001.png", "a cat sleeping on a windowsill"),
#     ("../audio/torgo/processed/test/sample_00001.wav", "img_002.png", "a dog catching a ball in a park"),
# ]
# for audio_file, image_id, reference in pairs:
#     audio_array, sr = sf.read(audio_file)
#     test_samples.append({
#         "audio_array": audio_array, "sr": sr,
#         "reference": reference, "image_id": image_id,
#     })

## 3) Run the grid search

Initialise the pipeline and run `tune_alpha`. The output shows WER at each α
and highlights the optimum.

> **Note**: This cell requires `test_samples` to be populated above.

In [None]:
if test_samples:
    pipeline = MultimodalASR(alpha=0.0, num_beams=5)
    results = tune_alpha(pipeline, test_samples)

    # Save results
    output_path = Path("cache/fusion_tuning_results.json")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved \u2192 {output_path}")
else:
    print("No test samples provided. Populate test_samples in the cell above to run tuning.")

## 4) Interpreting results

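After running the grid search, you'll see output like the following (the numbers
are illustrative, and the `←` annotations are added here and are not printed):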

```
α = 0.00   WER = 61.8%     ← baseline (pure ASR)
α = 0.05   WER = 60.2%
α = 0.10   WER = 58.5%
  ...
α = 0.25   WER = 55.1%     ← best
  ...
α = 1.00   WER = 89.3%     ← pure CLIP (too aggressive)
```

**What to look for:**
- Expect the best α to be moderate (roughly **0.1–0.4** is a reasonable prior,
  though it depends on the dataset): visual context helps but shouldn't
  override strong ASR confidence.
- If the best α is 0.0, CLIP isn't helping. Consider:
  - Is the image-text alignment strong enough?
  - Does the transcript normalization need tuning?
  - Would fine-tuning CLIP on your domain help?
- Very high optimal α (>0.5) suggests the ASR model is struggling badly and
  visual context is doing most of the disambiguation work.
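
The checks above could be automated with a small helper that also reports the plateau of near-optimal α values, since WER curves are often flat around the minimum. This is a sketch under stated assumptions: `interpret_grid` is a hypothetical name, the 0.005 absolute-WER tolerance is arbitrary, and the grid below contains illustrative numbers, not measured results.

```python
def interpret_grid(grid_results, tolerance=0.005):
    """Summarise a list of {"alpha", "wer"} dicts along the lines discussed above."""
    best = min(grid_results, key=lambda r: r["wer"])
    # All alphas whose WER is within `tolerance` (absolute) of the best one.
    plateau = [r["alpha"] for r in grid_results if r["wer"] - best["wer"] <= tolerance]
    if best["alpha"] == 0.0:
        verdict = "CLIP is not helping on this set"
    elif best["alpha"] > 0.5:
        verdict = "visual context is doing most of the disambiguation"
    else:
        verdict = "visual context helps without overriding ASR"
    return {"best_alpha": best["alpha"], "plateau": plateau, "verdict": verdict}

# Illustrative numbers only, not measured results.
toy_grid = [
    {"alpha": 0.00, "wer": 0.618},
    {"alpha": 0.20, "wer": 0.553},
    {"alpha": 0.25, "wer": 0.551},
    {"alpha": 0.30, "wer": 0.560},
    {"alpha": 1.00, "wer": 0.893},
]
print(interpret_grid(toy_grid))
```

A wide plateau suggests the exact choice of α matters little, in which case a value near the middle of the plateau may be a safer pick than the single best grid point.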

## Summary

This notebook:

1. Pre-computes Whisper beam hypotheses once (expensive)
2. Sweeps 21 α values from 0.0 to 1.0 (cheap)
3. Selects the α that minimises corpus-level WER
4. Saves the full grid results to JSON

Use the optimal α when deploying the `MultimodalASR` pipeline:

```python
pipeline = MultimodalASR(alpha=results["best_alpha"])
```
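
If tuning and deployment happen in separate sessions, the saved JSON can be read back to recover the tuned value. The sketch below assumes the summary layout written in section 3 (a top-level `best_alpha` key); `load_best_alpha` is a hypothetical helper, and falling back to α = 0.0 (pure ASR) when the file is missing is a design choice, not something the notebook prescribes.

```python
import json
import tempfile
from pathlib import Path

def load_best_alpha(results_path, default=0.0):
    """Read `best_alpha` from a saved tuning summary; fall back to `default` if absent."""
    results_path = Path(results_path)
    if not results_path.exists():
        return default
    with open(results_path, encoding="utf-8") as f:
        summary = json.load(f)
    return float(summary.get("best_alpha", default))

# Round-trip demo with a throwaway file and made-up values.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "fusion_tuning_results.json"
    demo_path.write_text(json.dumps({"best_alpha": 0.25, "best_wer": 0.551}), encoding="utf-8")
    alpha = load_best_alpha(demo_path)
    missing = load_best_alpha(Path(tmp) / "does_not_exist.json")

print(alpha, missing)
```

In the notebook's setting this would be called as `load_best_alpha("cache/fusion_tuning_results.json")` and the result passed to `MultimodalASR(alpha=...)`.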