<a href="https://colab.research.google.com/github/Aditya-Shandilya1182/SpeculativeWhisper/blob/main/whisper_speculative_decoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Downloads And Imports

In [None]:
!pip install -q -U openai-whisper
!pip install -q -U jiwer
# -c resumes a partial download and skips a finished one; without it, re-running
# this cell fetches the archive again and saves it as dev-clean.tar.gz.1.
!wget -q -c https://www.openslr.org/resources/12/dev-clean.tar.gz
!tar -xf dev-clean.tar.gz

In [None]:
import os
import re
import time
from dataclasses import dataclass

import torch
import torch.nn.functional as F
import whisper
from jiwer import wer
from whisper.normalizers import EnglishTextNormalizer

# Speculative Decoding

In [None]:
class SpeculativeWhisper:
    """Speculative decoding with a small Whisper draft model and a large final model.

    Each round, the draft model proposes ``k`` tokens for every unfinished clip.
    The final model then scores all of them in one forward pass. The longest
    prefix that matches the final model's greedy (argmax) choice is kept. One
    more token is then taken from the final model via ``topp_sample``: the
    correction at the first mismatch, or a bonus token when all ``k`` drafts
    were accepted.

    Each model is prompted with the prompt built by its own tokenizer
    (``sot_sequence_including_notimestamps``). Generation is restricted to ids
    up to ``<|endoftext|>``, i.e. the text vocabulary. Special/timestamp ids
    differ between vocabularies with different language counts (e.g. ``tiny``
    vs ``large-v3``), so only text ids are passed between the two models.
    ``__init__`` rejects mixing multilingual and English-only models, whose
    text vocabularies differ.
    """

    def __init__(self, config):
        self.device = config.device
        self.draft = whisper.load_model(config.draft_model).to(self.device)
        self.final = whisper.load_model(config.final_model).to(self.device)
        if self.draft.is_multilingual != self.final.is_multilingual:
            raise ValueError(
                "draft and final models must both be multilingual or both "
                "English-only so that their text tokens share ids"
            )
        self.k = config.k
        if self.k < 1:
            raise ValueError(f"k must be at least 1, got {self.k}")
        self.max_tokens = config.max_tokens
        self.mel_dim_tiny = config.mel_dim_tiny
        self.mel_dim_large = config.mel_dim_large
        self.beam_search = config.beam_search
        self.beam_size = getattr(config, "beam_size", 5)
        self.top_p = getattr(config, "top_p", None)
        self.language = getattr(config, "language", "en")
        # Special-token ids depend on each model's vocabulary, so each model
        # gets its own tokenizer and prompt. ``self.tokenizer`` (the final
        # model's) is the one used for decoding text and for the EOT id.
        self.draft_tokenizer = self._tokenizer_for(self.draft)
        self.tokenizer = self._tokenizer_for(self.final)
        self.draft_prompt = self._prompt_for(self.draft_tokenizer)
        self.final_prompt = self._prompt_for(self.tokenizer)

    def _tokenizer_for(self, model):
        """Build the transcription tokenizer matching ``model``'s vocabulary."""
        return whisper.tokenizer.get_tokenizer(
            model.is_multilingual,
            num_languages=model.num_languages,
            language=self.language,
            task="transcribe",
        )

    def _prompt_for(self, tokenizer):
        """Return the tokenizer's no-timestamps start sequence as a device tensor."""
        return torch.tensor(
            tokenizer.sot_sequence_including_notimestamps,
            dtype=torch.long,
            device=self.device,
        )

    def topp_sample(self, logits):
        """Pick one token id per row of ``logits`` (shape ``(V,)`` or ``(B, V)``).

        With ``self.top_p`` set, sample from the highest-probability tokens,
        keeping each token while the cumulative probability before it is
        ``<= top_p`` (nucleus sampling). Otherwise return the argmax.
        """
        if self.top_p is not None:
            probs = F.softmax(logits, dim=-1)
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cumsum = torch.cumsum(sorted_probs, dim=-1)
            mask = cumsum > self.top_p
            # Shift right so the token that crosses the threshold is kept.
            mask[..., 1:] = mask[..., :-1].clone()
            mask[..., 0] = False
            sorted_probs[mask] = 0.0
            sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
            idx = torch.multinomial(sorted_probs, 1)
            return sorted_idx.gather(-1, idx).squeeze(-1)
        return torch.argmax(logits, dim=-1)

    def _restrict(self, logits):
        """Return a copy of ``logits`` with every id above ``eot`` set to -inf."""
        logits = logits.clone()
        logits[..., self.tokenizer.eot + 1:] = float("-inf")
        return logits

    @staticmethod
    def _accepted_prefix(draft_row, target_row, eot):
        """Count leading draft tokens equal to the final model's choices.

        Counting stops right after an accepted ``eot`` so that no token past
        the end of the transcript is ever kept.
        """
        accepted = 0
        for proposed, expected in zip(draft_row.tolist(), target_row.tolist()):
            if proposed != expected:
                break
            accepted += 1
            if proposed == eot:
                break
        return accepted

    def _draft_tokens(self, ctx, lengths, audio):
        """Propose ``k`` tokens per row with the draft model via ``topp_sample``.

        ``ctx`` is right-padded and is written in place. Row ``i`` holds
        ``lengths[i]`` real tokens, and each proposal is stored at that row's
        own next position. This relies on Whisper's text decoder being causal,
        so padding to the right of a position does not affect its logits.
        Returns a ``(rows, k)`` tensor.
        """
        rows = torch.arange(ctx.size(0), device=ctx.device)
        proposals = []
        for step in range(self.k):
            width = int(lengths.max()) + step
            logits = self.draft.decoder(ctx[:, :width], audio)
            nxt = self.topp_sample(self._restrict(logits[rows, lengths + step - 1]))
            ctx[rows, lengths + step] = nxt
            proposals.append(nxt)
        return torch.stack(proposals, dim=1)

    def _beam_search(self, prefix, audio):
        """Propose ``k`` tokens for ONE sequence via draft-model beam search.

        ``prefix`` is a single unpadded draft context of shape ``(1, T)`` and
        ``audio`` the matching single-row draft encoder output. Beams are scored
        for this sequence only, never summed across clips. Returns a ``(k,)``
        tensor.
        """
        beams = prefix
        scores = torch.zeros(1, device=prefix.device)
        for _ in range(self.k):
            logits = self.draft.decoder(beams, audio.expand(beams.size(0), -1, -1))[:, -1]
            logp = F.log_softmax(self._restrict(logits), dim=-1)
            vocab = logp.size(-1)
            # Global top-``beam_size`` over every (beam, token) continuation.
            scores, flat = (scores[:, None] + logp).flatten().topk(self.beam_size)
            beams = torch.cat([beams[flat // vocab], (flat % vocab)[:, None]], dim=1)
        return beams[0, -self.k:]

    @torch.no_grad()
    def decode(self, draft_encoder, final_encoder, max_tokens=None):
        """Speculatively decode a batch of encoded clips.

        Args:
            draft_encoder: draft-model encoder output, one row per clip.
            final_encoder: final-model encoder output, same row order.
            max_tokens: cap on generated tokens per clip (defaults to
                ``self.max_tokens``). It is further clamped so every decoder
                input fits inside both models' text context.

        Returns:
            One 1-D ``long`` tensor per clip. Each holds the final model's prompt
            followed by the generated tokens, ending in ``<|endoftext|>`` unless
            the cap was reached first.
        """
        max_tokens = max_tokens or self.max_tokens
        eot = self.tokenizer.eot
        k = self.k
        pd, pf = self.draft_prompt.numel(), self.final_prompt.numel()
        # Largest decoder input is prompt + generated + k; keep it in context.
        room = min(self.draft.dims.n_text_ctx - pd, self.final.dims.n_text_ctx - pf) - k
        max_new = max(0, min(max_tokens, room))
        batch = final_encoder.size(0)
        generated = [[] for _ in range(batch)]
        done = [max_new == 0] * batch

        while not all(done):
            active = [i for i in range(batch) if not done[i]]
            idx = torch.tensor(active, device=self.device)
            d_enc, f_enc = draft_encoder[idx], final_encoder[idx]
            rows = torch.arange(len(active), device=self.device)
            lens = torch.tensor([len(generated[i]) for i in active], device=self.device)
            span = int(lens.max()) + k

            # Right-padded contexts: model-specific prompt + shared generated text.
            ctx_d = torch.full((len(active), pd + span), eot, dtype=torch.long, device=self.device)
            ctx_f = torch.full((len(active), pf + span), eot, dtype=torch.long, device=self.device)
            ctx_d[:, :pd] = self.draft_prompt
            ctx_f[:, :pf] = self.final_prompt
            for r, i in enumerate(active):
                if generated[i]:
                    text = torch.tensor(generated[i], dtype=torch.long, device=self.device)
                    ctx_d[r, pd:pd + text.numel()] = text
                    ctx_f[r, pf:pf + text.numel()] = text

            len_d, len_f = pd + lens, pf + lens
            if self.beam_search:
                drafts = torch.stack([
                    self._beam_search(ctx_d[r:r + 1, :int(len_d[r])], d_enc[r:r + 1])
                    for r in range(len(active))
                ])
            else:
                drafts = self._draft_tokens(ctx_d, len_d, d_enc)

            # Verify all k drafts in one final-model pass: position len_f - 1 + j
            # predicts draft j, and position len_f - 1 + k yields a bonus token.
            offsets = torch.arange(k + 1, device=self.device)
            ctx_f[rows[:, None], len_f[:, None] + offsets[:k]] = drafts
            logits = self.final.decoder(ctx_f[:, :int(len_f.max()) + k], f_enc)
            final_logits = self._restrict(logits[rows[:, None], len_f[:, None] - 1 + offsets])
            targets = final_logits.argmax(dim=-1)

            for r, i in enumerate(active):
                n = self._accepted_prefix(drafts[r], targets[r, :k], eot)
                new = drafts[r, :n].tolist()
                if not new or new[-1] != eot:
                    # Correction at the first mismatch, or bonus if all k matched.
                    new.append(int(self.topp_sample(final_logits[r, n])))
                gen = generated[i]
                gen.extend(new)
                del gen[max_new:]
                done[i] = gen[-1] == eot or len(gen) >= max_new

        return [
            torch.cat([self.final_prompt, torch.tensor(gen, dtype=torch.long, device=self.device)])
            for gen in generated
        ]

    def transcribe(self, audio_files, max_tokens=None):
        """Transcribe audio files (each padded/trimmed to Whisper's window).

        Returns one decoded string per file, containing text tokens only.
        """
        max_tokens = max_tokens or self.max_tokens
        if not audio_files:
            return []

        audios = torch.stack([
            torch.from_numpy(whisper.pad_or_trim(whisper.load_audio(p)))
            for p in audio_files
        ]).to(self.device)
        # Spectrograms are computed one clip at a time, as before.
        mel_tiny = torch.stack([whisper.log_mel_spectrogram(a, self.mel_dim_tiny) for a in audios]).to(self.device)
        mel_large = torch.stack([whisper.log_mel_spectrogram(a, self.mel_dim_large) for a in audios]).to(self.device)

        with torch.no_grad():
            draft_encoder = self.draft.encoder(mel_tiny)
            final_encoder = self.final.encoder(mel_large)

        batch_tokens = self.decode(draft_encoder, final_encoder, max_tokens)
        eot = self.tokenizer.eot
        return [
            self.tokenizer.decode([t for t in seq.tolist() if t < eot])
            for seq in batch_tokens
        ]


# Utils

In [None]:
def load_references(audio_files):
    """Return the lowercase LibriSpeech reference transcript for each audio file.

    Each audio path must sit in a ``<speaker>/<chapter>/`` directory that also
    holds ``<speaker>-<chapter>.trans.txt``. Each line of that file is
    ``<utterance-id> <TRANSCRIPT>``. Each transcript file is read once, and
    utterance ids are matched exactly (not by prefix).

    Returns:
        A list of reference strings, in the same order as ``audio_files``.

    Raises:
        FileNotFoundError: if a chapter transcript file is missing.
        KeyError: if an utterance id is absent from its chapter transcript.
            Skipping it would misalign references with predictions.
    """
    transcripts_by_dir = {}
    refs = []
    for path in audio_files:
        utt_id = os.path.splitext(os.path.basename(path))[0]
        chapter_dir = os.path.dirname(path)
        if chapter_dir not in transcripts_by_dir:
            speaker_id = os.path.basename(os.path.dirname(chapter_dir))
            chapter_id = os.path.basename(chapter_dir)
            txt_path = os.path.join(chapter_dir, f"{speaker_id}-{chapter_id}.trans.txt")
            transcripts = {}
            with open(txt_path, encoding="utf-8") as f:
                for line in f:
                    key, _, text = line.strip().partition(" ")
                    if key:
                        transcripts[key] = text.lower()
            transcripts_by_dir[chapter_dir] = transcripts
        try:
            refs.append(transcripts_by_dir[chapter_dir][utt_id])
        except KeyError:
            raise KeyError(
                f"no reference transcript for {utt_id!r} in {chapter_dir!r}"
            ) from None
    return refs

In [None]:
def clean_whisper_output(text):
    """Remove every ``<|...|>`` special-token marker and trim outer whitespace."""
    special_token = re.compile(r"<\|.*?\|>")
    stripped = special_token.sub("", text)
    return stripped.strip()


# Config

In [None]:
@dataclass
class Config:
    draft_model: str = "tiny"
    final_model: str = "large-v3"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    k: int = 5
    beam_search: bool = False
    beam_size: int = 4
    top_p: float | None = 0.9
    max_tokens: int = 200
    mel_dim_tiny: int = 80
    mel_dim_large: int = 128

# Inference

In [None]:
# First five utterances of LibriSpeech dev-clean, speaker 84, chapter 121123.
audio_files = [
    f"LibriSpeech/dev-clean/84/121123/84-121123-{utt:04d}.flac"
    for utt in range(5)
]

references = load_references(audio_files)

In [None]:
sw = SpeculativeWhisper(Config())
# References are lowercased by load_references while Whisper emits cased,
# punctuated text; normalize both sides so WER counts word errors only.
normalizer = EnglishTextNormalizer()

spec_times = []
spec_outputs = []

for path in audio_files:
    start = time.perf_counter()
    out = sw.transcribe([path], max_tokens=100)[0]
    # torch.cuda.synchronize() raises on CPU-only runtimes.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t = time.perf_counter() - start

    spec_times.append(t)
    spec_outputs.append(out)

spec_preds = [clean_whisper_output(o) for o in spec_outputs]
spec_wer = wer(
    [normalizer(ref) for ref in references],
    [normalizer(pred) for pred in spec_preds],
)

print("\n--- Speculative Transcription Outputs ---")

for i, out in enumerate(spec_preds):
    print(f"\n[{i}] {out}")

print("\nSpeculative Whisper")
print(f"Avg latency/sample: {sum(spec_times)/len(spec_times):.4f}s")
print(f"Min latency: {min(spec_times):.4f}s")
print(f"Max latency: {max(spec_times):.4f}s")
print(f"Total time: {sum(spec_times):.2f}s")
print(f"WER: {spec_wer:.4f}")

In [None]:
# Drop the speculative models before loading the vanilla baseline.
del sw
# torch.cuda.synchronize() raises on CPU-only runtimes, so only touch CUDA when present.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
vanilla_model = whisper.load_model("large-v3").to(device)
# References are lowercased by load_references while Whisper emits cased,
# punctuated text; normalize both sides so WER counts word errors only.
normalizer = EnglishTextNormalizer()

vanilla_times = []
vanilla_outputs = []

for path in audio_files:
    start = time.perf_counter()
    # NOTE(review): whisper's transcribe() may run in a different precision than
    # SpeculativeWhisper's decode loop; confirm before comparing latencies.
    r = vanilla_model.transcribe(path, language="en", fp16=(device == "cuda"))
    if device == "cuda":
        torch.cuda.synchronize()
    t = time.perf_counter() - start

    vanilla_times.append(t)
    vanilla_outputs.append(r["text"])

vanilla_wer = wer(
    [normalizer(ref) for ref in references],
    [normalizer(out) for out in vanilla_outputs],
)

print("\n--- Vanilla Transcription Outputs ---")

for i, out in enumerate(vanilla_outputs):
    print(f"\n[{i}] {out}")

print("\nVanilla Whisper Large-V3")
print(f"Avg latency/sample: {sum(vanilla_times)/len(vanilla_times):.4f}s")
print(f"Min latency: {min(vanilla_times):.4f}s")
print(f"Max latency: {max(vanilla_times):.4f}s")
print(f"Total time: {sum(vanilla_times):.2f}s")
print(f"WER: {vanilla_wer:.4f}")