In [1]:
! pip install git+https://github.com/openai/whisper.git
! pip install jiwer

Collecting git+https://github.com/openai/whisper.git
  Cloning https://github.com/openai/whisper.git to /private/var/folders/h0/lym4t3rj7pgdbfpw64mhfr700000gn/T/pip-req-build-d1s1peqt
  Running command git clone --filter=blob:none --quiet https://github.com/openai/whisper.git /private/var/folders/h0/lym4t3rj7pgdbfpw64mhfr700000gn/T/pip-req-build-d1s1peqt
  Resolved https://github.com/openai/whisper.git to commit c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone


In [2]:
import os
import numpy as np

try:
    import tensorflow  # required in Colab to avoid protobuf compatibility issues
except ImportError:
    pass

import torch
import pandas as pd
import whisper
import torchaudio

from tqdm.notebook import tqdm


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
class LibriSpeech(torch.utils.data.Dataset):
    """
    A simple class to wrap LibriSpeech and trim/pad the audio to 30 seconds.
    It will drop the last few seconds of a very small portion of the utterances.
    """
    def __init__(self, split="test-clean", device=DEVICE):
        self.dataset = torchaudio.datasets.LIBRISPEECH(
            root=os.path.expanduser("~/.cache"),
            url=split,
            download=True,
        )
        self.device = device

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        audio, sample_rate, text, _, _, _ = self.dataset[item]
        assert sample_rate == 16000
        audio = whisper.pad_or_trim(audio.flatten()).to(self.device)
        mel = whisper.log_mel_spectrogram(audio)
        
        return (mel, text)

In [4]:
dataset = LibriSpeech("test-clean")
loader = torch.utils.data.DataLoader(dataset, batch_size=16)

## Running inference on the dataset using a base Whisper model
#### The following will take a few minutes to transcribe all utterances in the dataset.

In [5]:
model = whisper.load_model("base.en")
print(
    f"Model is {'multilingual' if model.is_multilingual else 'English-only'} "
    f"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters."
)

Model is English-only and has 71,825,408 parameters.


In [6]:
# predict without timestamps for short-form transcription
options = whisper.DecodingOptions(language="en", without_timestamps=True)

In [7]:
!pip install ipywidgets



In [8]:
pip install "openai-whisper==20231117" "torch==2.1.*" "torchaudio==2.1.*" soundfile==0.12.1

Collecting openai-whisper==20231117
  Using cached openai-whisper-20231117.tar.gz (798 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[31mERROR: Could not find a version that satisfies the requirement torch==2.1.* (from versions: 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2.4.1, 2.5.0, 2.5.1, 2.6.0, 2.7.0, 2.7.1, 2.8.0)[0m[31m
[0m[31mERROR: No matching distribution found for torch==2.1.*[0m[31m
[0m[?25hNote: you may need to restart the kernel to use updated packages.


In [9]:
hypotheses = []
references = []

for mels, texts in tqdm(loader):
    results = model.decode(mels, options)
    hypotheses.extend([result.text for result in results])
    references.extend(texts)

  0%|          | 0/164 [00:00<?, ?it/s]



In [10]:
data = pd.DataFrame(dict(hypothesis=hypotheses, reference=references))
data

Unnamed: 0,hypothesis,reference
0,"He hoped there would be stew for dinner, turni...",HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP...
1,"Stuffered into you, his belly counseled him.",STUFF IT INTO YOU HIS BELLY COUNSELLED HIM
2,After early nightfall the yellow lamps would l...,AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L...
3,"Hello Bertie, any good in your mind?",HELLO BERTIE ANY GOOD IN YOUR MIND
4,Number 10. Fresh Nelly is waiting on you. Good...,NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ...
...,...,...
2615,"Oh, to shoot my soul's full meaning into futur...",OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE...
2616,"Then I, long tried by natural ills, received t...",THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE...
2617,I love thee freely as men strive for right. I ...,I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L...
2618,"I love thee with the passion put to use, in my...",I LOVE THEE WITH THE PASSION PUT TO USE IN MY ...


# Calculating the word error rate
#### Now, we use our English normalizer implementation to standardize the transcription and calculate the WER.

In [11]:
import jiwer
from whisper.normalizers import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()

In [12]:
data["hypothesis_clean"] = [normalizer(text) for text in data["hypothesis"]]
data["reference_clean"] = [normalizer(text) for text in data["reference"]]
data

Unnamed: 0,hypothesis,reference,hypothesis_clean,reference_clean
0,"He hoped there would be stew for dinner, turni...",HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP...,he hoped there would be stew for dinner turnip...,he hoped there would be stew for dinner turnip...
1,"Stuffered into you, his belly counseled him.",STUFF IT INTO YOU HIS BELLY COUNSELLED HIM,stuffered into you his belly counseled him,stuff it into you his belly counseled him
2,After early nightfall the yellow lamps would l...,AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L...,after early nightfall the yellow lamps would l...,after early nightfall the yellow lamps would l...
3,"Hello Bertie, any good in your mind?",HELLO BERTIE ANY GOOD IN YOUR MIND,hello bertie any good in your mind,hello bertie any good in your mind
4,Number 10. Fresh Nelly is waiting on you. Good...,NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ...,number 10 fresh nelly is waiting on you good n...,number 10 fresh nelly is waiting on you good n...
...,...,...,...,...
2615,"Oh, to shoot my soul's full meaning into futur...",OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE...,0 to shoot my soul is full meaning into future...,0 to shoot my soul is full meaning into future...
2616,"Then I, long tried by natural ills, received t...",THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE...,then i long tried by natural ills received the...,then i long tried by natural ills received the...
2617,I love thee freely as men strive for right. I ...,I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L...,i love thee freely as men strive for right i l...,i love thee freely as men strive for right i l...
2618,"I love thee with the passion put to use, in my...",I LOVE THEE WITH THE PASSION PUT TO USE IN MY ...,i love thee with the passion put to use in my ...,i love thee with the passion put to use in my ...


In [13]:
wer = jiwer.wer(list(data["reference_clean"]), list(data["hypothesis_clean"]))

print(f"WER: {wer * 100:.2f} %")

WER: 4.28 %


# Loading the Fleurs dataset
#### Select the language of the Fleur dataset to download. Please note that the transcription and translation performance varies widely depending on the language. Appendix D.2 in the paper contains the performance breakdown by language.


In [14]:
import io
import os
import numpy as np

try:
    import tensorflow  # required in Colab to avoid protobuf compatibility issues
except ImportError:
    pass

import torch
import pandas as pd
import urllib
import tarfile
import whisper
import torchaudio

from scipy.io import wavfile
from tqdm.notebook import tqdm


pd.options.display.max_rows = 100
pd.options.display.max_colwidth = 1000
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [42]:
import ipywidgets as widgets

languages = {"af_za": "Afrikaans", "am_et": "Amharic", "ar_eg": "Arabic", "as_in": "Assamese", "az_az": "Azerbaijani", "be_by": "Belarusian", "bg_bg": "Bulgarian", "bn_in": "Bengali", "bs_ba": "Bosnian", "ca_es": "Catalan", "cmn_hans_cn": "Chinese", "cs_cz": "Czech", "cy_gb": "Welsh", "da_dk": "Danish", "de_de": "German", "el_gr": "Greek", "en_us": "English", "es_419": "Spanish", "et_ee": "Estonian", "fa_ir": "Persian", "fi_fi": "Finnish", "fil_ph": "Tagalog", "fr_fr": "French", "gl_es": "Galician", "gu_in": "Gujarati", "ha_ng": "Hausa", "he_il": "Hebrew", "hi_in": "Hindi", "hr_hr": "Croatian", "hu_hu": "Hungarian", "hy_am": "Armenian", "id_id": "Indonesian", "is_is": "Icelandic", "it_it": "Italian", "ja_jp": "Japanese", "jv_id": "Javanese", "ka_ge": "Georgian", "kk_kz": "Kazakh", "km_kh": "Khmer", "kn_in": "Kannada", "ko_kr": "Korean", "lb_lu": "Luxembourgish", "ln_cd": "Lingala", "lo_la": "Lao", "lt_lt": "Lithuanian", "lv_lv": "Latvian", "mi_nz": "Maori", "mk_mk": "Macedonian", "ml_in": "Malayalam", "mn_mn": "Mongolian", "mr_in": "Marathi", "ms_my": "Malay", "mt_mt": "Maltese", "my_mm": "Myanmar", "nb_no": "Norwegian", "ne_np": "Nepali", "nl_nl": "Dutch", "oc_fr": "Occitan", "pa_in": "Punjabi", "pl_pl": "Polish", "ps_af": "Pashto", "pt_br": "Portuguese", "ro_ro": "Romanian", "ru_ru": "Russian", "sd_in": "Sindhi", "sk_sk": "Slovak", "sl_si": "Slovenian", "sn_zw": "Shona", "so_so": "Somali", "sr_rs": "Serbian", "sv_se": "Swedish", "sw_ke": "Swahili", "ta_in": "Tamil", "te_in": "Telugu", "tg_tj": "Tajik", "th_th": "Thai", "tr_tr": "Turkish", "uk_ua": "Ukrainian", "ur_pk": "Urdu", "uz_uz": "Uzbek", "vi_vn": "Vietnamese", "yo_ng": "Yoruba"}
selection = widgets.Dropdown(
    options=[("Select language", None), ("----------", None)] + sorted([(f"{v} ({k})", k) for k, v in languages.items()]),
    value="hi_in",
    description='Language:',
    disabled=False,
)

selection

Dropdown(description='Language:', index=29, options=(('Select language', None), ('----------', None), ('Afrika…

In [43]:
lang = selection.value
language = languages[lang]

assert lang is not None, "Please select a language"
print(f"Selected language: {language} ({lang})")

Selected language: Hindi (hi_in)


In [44]:
def download(url: str, target_path: str):
    with urllib.request.urlopen(url) as source, open(target_path, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))


class Fleurs(torch.utils.data.Dataset):
    """
    A simple class to wrap Fleurs and subsample a portion of the dataset as needed.
    """
    def __init__(self, lang, split="test", subsample_rate=1, device=DEVICE):
        url = f"https://storage.googleapis.com/xtreme_translations/FLEURS102/{lang}.tar.gz"
        tar_path = os.path.expanduser(f"~/.cache/fleurs/{lang}.tgz")
        os.makedirs(os.path.dirname(tar_path), exist_ok=True)

        if not os.path.exists(tar_path):
            download(url, tar_path)

        all_audio = {}
        with tarfile.open(tar_path, "r:gz") as tar:
            for member in tar.getmembers():
                name = member.name
                if name.endswith(f"{split}.tsv"):
                    labels = pd.read_table(tar.extractfile(member), names=("id", "file_name", "raw_transcription", "transcription", "_", "num_samples", "gender"))

                if f"/{split}/" in name and name.endswith(".wav"):
                    audio_bytes = tar.extractfile(member).read()
                    all_audio[os.path.basename(name)] = wavfile.read(io.BytesIO(audio_bytes))[1]                    

        self.labels = labels.to_dict("records")[::subsample_rate]
        self.all_audio = all_audio
        self.device = device

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        record = self.labels[item]
        audio = torch.from_numpy(self.all_audio[record["file_name"]].copy())
        text = record["transcription"]
        
        return (audio, text)

In [45]:
dataset = Fleurs(lang, subsample_rate=10)  # subsample 10% of the dataset for a quick demo

## Running inference on the dataset using a medium Whisper model
#### The following will take a few minutes to transcribe and translate utterances in the dataset.

In [46]:
model = whisper.load_model("medium")
print(
    f"Model is {'multilingual' if model.is_multilingual else 'English-only'} "
    f"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters."
)

Model is multilingual and has 762,321,920 parameters.


In [47]:
options = dict(language=language, beam_size=5, best_of=5)
transcribe_options = dict(task="transcribe", **options)
translate_options = dict(task="translate", **options)

In [48]:
references = []
transcriptions = []
translations = []

for audio, text in tqdm(dataset):
    transcription = model.transcribe(audio, **transcribe_options)["text"]
    translation = model.transcribe(audio, **translate_options)["text"]
    
    transcriptions.append(transcription)
    translations.append(translation)
    references.append(text)

  0%|          | 0/42 [00:00<?, ?it/s]



In [49]:
data = pd.DataFrame(dict(reference=references, transcription=transcriptions, translation=translations))
data

Unnamed: 0,reference,transcription,translation
0,स्कीइंग मार्ग को एक हाईकिंग लंबी पैदल यात्रा मार्ग जैसा ही सोचें।,अस्कीन्मार्क को एक हाईकिन् लंबी पैदल यात्रा मार्ग जैसा ही सोचें।,Think of skiing as a long hiking trail.
1,अधिकांश छोटे द्वीप स्वतंत्र राष्ट्र हैं या फ़्रांस से संबंधित हैं और लग्ज़री बीच रिसॉर्ट के रूप में जाने जाते हैं,अधिकाश चोटे दे पस्वतंत्र राष्टर हैं या फ्रांस से सम्बंदित हैं और लक्स्री बीट्स रिसॉर्ट के रूप में जाने जाते हैं.,Most of them are related to the small and independent state or France and are known as Luxury Beach Resort.
2,तूफान और बवंडर की तरह आंधी ओले भारी बारिश और जंगल की आग तीव्र मौसम का हिस्सा और असर हैं,"तूफान और भवंदर की तरा आंधी, आूले, भारी, बारिश और जंगल की आक तीवर मौसम का हिस्सा और असर है।","Like storms and storms, winds, storms, heavy rains and forest fires are part and parcel of severe weather."
3,महिलाएं यह अनुशंसा की जाती है कि कोई भी महिला यात्री वास्तविक वैवाहिक स्थिति के बावजूद कहती है कि वह विवाहित है,महिलाये या अनुश्रणशा की जाती है कि कोई भी महिलायाइादरी वास्तविक विवाई क्षिति के बावजुट कहती है कि वो विवाईत है।,Women are said to be married despite the fact that they are married.
4,वाइल्डलाइफ़ हैबिटेंट्स के रूप में काम करने वाली रेती और तटों को बनाने के लिए गाद ज़रूरी थी,Wildlife habitants के रूप में काम करने वाली रेती और तटों को बनाने के लिए गाद जुरूडी थी।,It was necessary to build a guard to work as wildlife habitants.
5,बाली में इस एजेंडे के अन्य विषयों में दुनिया के बचे हुए जंगलों को बचाने और ऐसी तकनीकों का आदान प्रदान करना का विषय शामिल है जिससे कि विकासशील देशों को कम प्रदूषणकारी तरीकों से आगे बढ़ने में मदद मिले,"बाली में इस अजेंडे के अन्य विष्यों में दुन्या के बचे हुए जंगलों को बचाने और ऐसी तेक्नीकों का आदान पर्दान करने का विष्य शामिल हैं, जिससे कि विकास जीर देशों को कम परदुषन कारी तरीकों से आगे बरने में मदद मेले.","In this agenda, Bali is involved in protecting the remaining forests of the world and providing solutions to such techniques which will help developing countries to move ahead in less polluting ways."
6,1889 में यह बंदरगाह कुख्यात नौसैनिक गतिरोध का ठिकाना था उस समय जर्मनी अमेरिका और ब्रिटेन के सात जहाजों ने इस बंदरगाह से जाने से इनकार कर दिया था,"1889 में यह बंदर्गाख युक्यात नोसेनिक गतिरोत का थिकाना था। उस समय जर्मनी, अमेरिका और बिटन के साथ जहाजों ने इस बंदर्गाख से जाने से इनकार कर दिया था।","In 1889, this was the place of the 9th World War, during which Germany, America and Britain refused to leave this place."
7,सन 1976 तक माचू पिचू के तीस प्रतिशत हिस्से का जीर्णोद्धार कर दिया गया था और जीर्णोद्धार का कार्य आज तक जारी है,सं 1976 तक माचु पिचु के 30% हिस्से का जिर्नोदार कर दिया घया था और जिर्नोदार का कारे आस तक जारी है।,"In 1976, 30% of Machu Picchu's share of Jirnodhar was done and the work of Jirnodhar is still going on."
8,ms बीमारी केंद्रीय तंत्रिका तंत्र पर असर करती है जिसमें दिमाग स्पाइनल कॉर्ड और ऑप्टिक नर्व शामिल हैं,"MS बिमारी केंद्रिय तंत्र का तंत्र पर असर करती है, जिसमें दिमाक, स्पाइनल कौड और अप्टिक नर्व शामिल है।","The MS disease center affects the immune system, which includes the brain, spinal cord and optic nerve."
9,समाजीकरण के महत्व को स्पष्ट करने के लिए इस्तेमाल किए जाने वाले सबसे आम तरीकों में से कुछ बच्चों के दुर्भाग्यपूर्ण मामलों को आकर्षित करना है जो बड़े होने के दौरान वयस्कों द्वारा उपेक्षित नहीं बल्कि उपेक्षा दुर्भाग्य या दुर्व्यवहार के माध्यम से होते थे,"समाजी करन के महत्तो को इस्पष्ट करने के लिए इस्तमाल किया जाने वाले सबसे आम तरीकों में से कुछ बच्चों के दृर्भाग्य पुर्न मामलों को आकर्षित करना हैं, जो बड़े होने के दौरान वैसको दौरा उपेक्छित नहीं बल्कि उपेक्छा दृर्भाग्य या दृर्वेवार के माद्यम से होते थे।","The most common methods used to clarify the importance of socialization are to attract the unfortunate events of some children, which were not avoided by adults, but were avoided through misfortune or misbehavior."


# Word-level timestamps using attention weights
### Below, we use the cross-attention weights to determine more granular, word-level timestamps.It uses a set of heuristics and dynamic time warping (DTW) to find the alignment between the audio and the transcript.

In [50]:
!pip install dtw-python



In [52]:
!pip install matplotlib

Collecting matplotlib
  Downloading matplotlib-3.10.5-cp312-cp312-macosx_11_0_arm64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.3-cp312-cp312-macosx_11_0_arm64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.59.2-cp312-cp312-macosx_10_13_universal2.whl.metadata (109 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.9-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.3 kB)
Collecting pillow>=8 (from matplotlib)
  Downloading pillow-11.3.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (9.0 kB)
Collecting pyparsing>=2.3.1 (from matplotlib)
  Downloading pyparsing-3.2.3-py3-none-any.whl.metadata (5.0 kB)
Downloading matplotlib-3.10.5-cp312-cp312-macosx_11_0_arm64.whl (8.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.1/8.1 MB[0m [31m62.8 MB/s[

In [53]:
import string
import matplotlib as plt
import matplotlib.font_manager as fm
import matplotlib.ticker as ticker

from IPython.display import display, HTML
from whisper.tokenizer import get_tokenizer
from dtw import dtw
from scipy.ndimage import median_filter

%matplotlib inline
%config InlineBackend.figure_format = "retina"

Matplotlib is building the font cache; this may take a moment.


Importing the dtw module. When using in academic works please cite:
  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
  J. Stat. Soft., doi:10.18637/jss.v031.i07.



In [54]:
AUDIO_SAMPLES_PER_TOKEN = whisper.audio.HOP_LENGTH * 2
AUDIO_TIME_PER_TOKEN = AUDIO_SAMPLES_PER_TOKEN / whisper.audio.SAMPLE_RATE

medfilt_width = 7
qk_scale = 1.0

tokenizer = get_tokenizer(model.is_multilingual, language=languages[lang])

In [74]:
# This part downloads a repackaged version of the Noto Sans font (either CJK or non-CJK)
# to render various languages in Matplotlib figures.

if languages[lang] in {"Hindi", "Japanese", "Korean"}:
    font = "GoNotoCJKCore.ttf"
else:
    font = "GoNotoCurrent.ttf"

font_release = "https://github.com/satbyy/go-noto-universal/releases/download/v5.2"
if not os.path.exists(font):
    download(f"{font_release}/{font}", font)

prop = fm.FontProperties(fname=font)
props = {'fontproperties': prop}

  0%|                                              | 0.00/17.0M [00:00<?, ?iB/s]

In [75]:
def split_tokens_on_unicode(tokens: torch.Tensor):
    words = []
    word_tokens = []
    current_tokens = []
    
    for token in tokens.tolist():
        current_tokens.append(token)
        decoded = tokenizer.decode_with_timestamps(current_tokens)
        if "\ufffd" not in decoded:
            words.append(decoded)
            word_tokens.append(current_tokens)
            current_tokens = []
    
    return words, word_tokens

In [76]:
def split_tokens_on_spaces(tokens: torch.Tensor):
    subwords, subword_tokens_list = split_tokens_on_unicode(tokens)
    words = []
    word_tokens = []
    
    for subword, subword_tokens in zip(subwords, subword_tokens_list):
        special = subword_tokens[0] >= tokenizer.eot
        with_space = subword.startswith(" ")
        punctuation = subword.strip() in string.punctuation
        if special or with_space or punctuation:
            words.append(subword)
            word_tokens.append(subword_tokens)
        else:
            words[-1] = words[-1] + subword
            word_tokens[-1].extend(subword_tokens)
    
    return words, word_tokens

In [77]:
if languages[lang] in {"Hindi", "Japanese", "Thai", "Lao", "Myanmar"}:
    # These languages don't typically use spaces, so it is difficult to split words
    # without morpheme analysis. Here, we instead split words at any
    # position where the tokens are decoded as valid unicode points
    split_tokens = split_tokens_on_unicode
else:
    split_tokens = split_tokens_on_spaces

In [78]:
# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer

for i, block in enumerate(model.decoder.blocks):
    block.cross_attn.register_forward_hook(
        lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
    )

In [85]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(audio)).to(DEVICE)
# ...
with torch.no_grad():
    logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))

model = whisper.load_model("medium", device=DEVICE)   # or "small"/"base" for faster CPU runs


In [86]:
import whisper

# 20 ms per timestamp token
AUDIO_TIME_PER_TOKEN = 0.02

# Back-compatible way to get samples per token
try:
    # older/internal builds sometimes expose this or HOP_LENGTH
    from whisper.audio import SAMPLES_PER_TOKEN as AUDIO_SAMPLES_PER_TOKEN
except Exception:
    # 16 kHz * 0.02 s = 320 samples per timestamp token
    AUDIO_SAMPLES_PER_TOKEN = int(whisper.audio.SAMPLE_RATE * AUDIO_TIME_PER_TOKEN)

In [91]:
import torch
import numpy as np
from transformers import WhisperProcessor, WhisperForConditionalGeneration

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
hf_name = "openai/whisper-medium"

processor = WhisperProcessor.from_pretrained(hf_name)
model_hf  = WhisperForConditionalGeneration.from_pretrained(hf_name).to(DEVICE).eval()

# If you already have the log-mel spectrogram from OpenAI whisper as `mel` (np.float32, shape [80, T]):
# otherwise you can get it from the raw audio with: processor(audio, sampling_rate=16000, return_tensors="pt").input_features
input_features = torch.from_numpy(mel).unsqueeze(0).to(DEVICE)   # [1, 80, T]

# Build decoder inputs (prompt + the text we want to align)
forced_prompt = processor.get_decoder_prompt_ids(language="en", task="transcribe")
text_ids      = processor.tokenizer(transcription, add_special_tokens=False)["input_ids"]
decoder_input_ids = torch.tensor([tid for _, tid in forced_prompt] + text_ids)[None, :].to(DEVICE)

with torch.no_grad():
    out = model_hf(
        input_features=input_features,
        decoder_input_ids=decoder_input_ids,
        output_attentions=True,
        return_dict=True,
        use_cache=False,   # avoid caching to save memory
    )

cross = out.cross_attentions  # list of length n_layers; each tensor: [batch=1, heads, tgt_len, src_len]
print(len(cross), cross[0].shape)


TypeError: expected np.ndarray (got Tensor)

In [92]:
import torch, numpy as np
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# ----- config -----
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "openai/whisper-medium"   # same as you loaded
SAMPLE_RATE = 16000                     # whisper expects 16 kHz
AUDIO_HOP_S = 0.02                      # ~20 ms per frame for the HF Whisper frontend

processor = WhisperProcessor.from_pretrained(MODEL_NAME)
model_hf   = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME).to(DEVICE).eval()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

with torch.no_grad():
    out = model(
        input_features=input_features.to(DEVICE),
        decoder_input_ids=decoder_input_ids.to(DEVICE),
        output_attentions=True,
        return_dict=True,
    )

# Collect cross-attentions
cross = out.cross_attentions                # list of [batch, heads, tgt, src]
if cross is None:
    raise RuntimeError("No attentions returned — check output_attentions=True")

# Stack and move to CPU
cross = torch.stack(cross, dim=0)[:, 0]     # [layers, heads, tgt, src]
weights = cross.reshape(-1, cross.size(-2), cross.size(-1)).cpu()

# audio: 1-D float32 np.ndarray at 16 kHz (your variable)
# transcription: the text string you want to align against (from your decode step)

# 1) Build log-mel via the processor
#    (If you already have mel= [80,T] float32, you can pass it via "input_features")
inputs = processor.feature_extractor(
    audio, sampling_rate=SAMPLE_RATE, return_tensors="pt"
)
input_features = inputs["input_features"].to(DEVICE)   # [B=1, 80, T]

# 2) Build decoder input tokens that include timestamp window around this clip.
#    A simple, language-agnostic prompt: <|startoftranscript|><|notimestamps|>
forced_prompt = processor.get_decoder_prompt_ids(language=None, task="transcribe")
# Pack decoder inputs: prompt + the text we want to align
with processor.as_target_tokenizer():
    text_tokens = processor.tokenizer(transcription, add_special_tokens=False)["input_ids"]

decoder_input_ids = torch.tensor(
    [[tid for _, tid in forced_prompt] + text_tokens],
    dtype=torch.long, device=DEVICE
)

# 3) Forward with attentions ON
with torch.no_grad():
    out = model_hf(
        input_features=input_features,
        decoder_input_ids=decoder_input_ids,
        output_attentions=True,
        return_dict=True,
    )

# 4) Cross-attentions: list (n_layers) of [batch, n_heads, tgt_len, src_len]
cross = out.cross_attentions                    # length = n_layers
# stack -> [layers, batch, heads, tgt, src] -> select batch 0 -> [layers, heads, tgt, src]
cross = torch.stack(cross, dim=0)[:, 0]
# merge layers and heads: [layers*heads, tgt, src]
weights = cross.reshape(-1, cross.size(-2), cross.size(-1)).cpu()  # float32

# (Optional) smooth a bit along heads*layers
def median_filter(x, k=1):
    if k <= 0: 
        return x
    # simple 1D median filter over the first dim (layers*heads)
    x_np = x.numpy()
    import scipy.ndimage as ndi
    x_np = ndi.median_filter(x_np, size=(1,1,1))  # change if you want different smoothing
    return torch.from_numpy(x_np)

weights = median_filter(weights, k=1)

# 5) Map token indices -> frame indices with a simple "argmax over source frames"
#    weights: [L*H, T_tgt, T_src]; average over heads/layers:
w = weights.mean(dim=0)                          # [T_tgt, T_src]
# normalize over src (frames)
w = w / (w.sum(dim=-1, keepdim=True) + 1e-8)
# For each token, the most-attended source frame
token_to_frame = torch.argmax(w, dim=-1).numpy()  # [T_tgt] ints

# Convert frames to seconds (roughly 20 ms per frame)
frame_times_s = token_to_frame * AUDIO_HOP_S

# Now you have a per-token time; to get word-level times:
tokens = processor.tokenizer.convert_ids_to_tokens(text_tokens)
# group tokens into words (BPE merges). Below is a naive example:
words, begin_times, end_times = [], [], []
current_word, start_t = [], None
for i, tk in enumerate(tokens):
    piece = tk.lstrip("Ġ")  # 'Ġ' indicates a word boundary in Whisper BPE
    boundary = (tk.startswith("Ġ") or i == 0)
    t = frame_times_s[i]
    if boundary and current_word:
        words.append("".join(current_word))
        begin_times.append(start_t)
        end_times.append(prev_t)
        current_word, start_t = [], t
    if boundary and not current_word:
        start_t = t
    current_word.append(piece)
    prev_t = t

# close last word
if current_word:
    words.append("".join(current_word))
    begin_times.append(start_t)
    end_times.append(prev_t)

import pandas as pd
pd.DataFrame({"word": words, "begin": begin_times, "end": end_times}).head()


TypeError: Whisper.forward() got an unexpected keyword argument 'input_features'