In [3]:
import torch
import whisper
import whisper.model
from whisper.model import MultiHeadAttention

# Disable SDPA (required)
whisper.model.MultiHeadAttention.use_sdpa = False


# ---------- Functional Self-Attention ----------
class FunctionalSelfAttention(MultiHeadAttention):

    def forward(self, x, xa=None, mask=None, kv_cache=None):

        if kv_cache is None:
            k = self.key(x)
            v = self.value(x)
        else:
            k_prev = kv_cache[self.key]
            v_prev = kv_cache[self.value]

            k_new = self.key(x)
            v_new = self.value(x)

            k = torch.cat([k_prev, k_new], dim=1)
            v = torch.cat([v_prev, v_new], dim=1)

            kv_cache[self.key] = k
            kv_cache[self.value] = v

        q = self.query(x)
        wv, qk = self.qkv_attention(q, k, v, mask)

        return self.out(wv), qk


# ---------- Patch model ----------
def patch_model(model):
    for block in model.decoder.blocks:
        block.attn.__class__ = FunctionalSelfAttention


# ---------- Functional Decoder ----------
class FunctionalDecoder(torch.nn.Module):

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder

        self.cache_keys = []
        for block in decoder.blocks:
            self.cache_keys.append(block.attn.key)
            self.cache_keys.append(block.attn.value)

    def forward(self, tokens, audio, cache):

        kv_cache = dict(zip(self.cache_keys, cache))

        offset = kv_cache[self.cache_keys[0]].shape[1]

        x = self.decoder.token_embedding(tokens)
        x = x + self.decoder.positional_embedding[offset: offset + tokens.shape[1]]
        x = x.to(audio.dtype)

        mask = self.decoder.mask[
            offset:offset + tokens.shape[1],
            :offset + tokens.shape[1]
        ]

        for block in self.decoder.blocks:
            x = block(x, audio, mask=mask, kv_cache=kv_cache)

        x = self.decoder.ln(x)

        logits = x @ self.decoder.token_embedding.weight.T

        new_cache = torch.cat(
            [kv_cache[k].unsqueeze(0) for k in self.cache_keys],
            dim=0
        )

        return logits, new_cache



# ---------- Main ----------
def main():

    model = whisper.load_model("tiny.en").cpu().eval()
    patch_model(model)

    decoder = FunctionalDecoder(model.decoder)
    tokenizer = whisper.tokenizer.get_tokenizer(True)

    audio = whisper.load_audio("sample.mp3")
    audio = whisper.pad_or_trim(audio)

    mel = whisper.log_mel_spectrogram(audio).unsqueeze(0)
    audio_features = model.encoder(mel)

    # CRITICAL: start with empty cache
    cache = torch.zeros((8, 1, 0, 384))

    # CRITICAL: correct start sequence
    start_tokens = [tokenizer.sot, tokenizer.no_timestamps]

    result = []

    for t in start_tokens:
        token_tensor = torch.tensor([[t]])
        logits, cache = decoder(token_tensor, audio_features, cache)
        result.append(t)

    for _ in range(200):

        next_token = logits[:, -1].argmax(-1).reshape(1, 1)
        token_id = next_token.item()

        if token_id == tokenizer.eot:
            break

        result.append(token_id)

        logits, cache = decoder(next_token, audio_features, cache)

    text = tokenizer.decode(result)

    print("Decoded:", text)


if __name__ == "__main__":
    main()


Decoded: <|startoftranscript|><|notimestamps|>


In [None]:
import torch

import whisper
from whisper.model import MultiHeadAttention


def main():
    model = whisper.load_model("tiny.en")
    model.cpu().eval()
    patch(model)

    encoder = model.encoder.cpu()
    decoder = FunctionalDecoder(model.decoder.cpu())

    x_mel = torch.randn(1, 80, 3000)
    x_tokens = torch.zeros((1, 10), dtype=torch.long).cpu()
    x_audio = encoder(x_mel).cpu()

    cache_self_attn = torch.zeros(
        (len(decoder.keys_self_attn), 1, 1, model.dims.n_text_state),
    )
    cache_cross_attn = torch.zeros(
        (len(decoder.keys_cross_attn), 1, 1, model.dims.n_audio_state),
    )

    torch.onnx.export(
        encoder,
        (x_mel,),
        "encoder.onnx",
        input_names=["mel"],
        output_names=["audio"],
        dynamic_axes={
            "mel": {0: "batch", 2: "time"},
            "audio": {0: "batch", 1: "time"},
        },
        opset_version=17,
    )

    torch.onnx.export(
        decoder,
        (x_tokens, x_audio, cache_self_attn, cache_cross_attn),
        "decoder.onnx",
        input_names=["tokens", "audio", "cache_self_attn", "cache_cross_attn"],
        output_names=["logits", "new_cache_self_attn", "new_cache_cross_attn"],
        dynamic_axes={
            # inputs
            "tokens": {0: "batch", 1: "seq"},
            "audio": {0: "batch", 1: "time"},
            "cache_self_attn": {2: "cached_seq"},
            "cache_cross_attn": {2: "cached_time"},
            # outputs
            "logits": {0: "batch", 1: "seq"},
            "new_cache_self_attn": {1: "batch", 2: "new_cached_seq"},
            "new_cache_cross_attn": {1: "batch", 2: "new_cached_time"},
        },
        opset_version=17,
    )


def patch(model):
    for block in model.decoder.blocks:
        block.attn.__class__ = FunctionalMultiHeadAttention
        block.attn.n_ctx = model.dims.n_text_ctx

        block.cross_attn.__class__ = FunctionalMultiHeadAttention
        block.cross_attn.n_ctx = model.dims.n_audio_ctx


class FunctionalDecoder(torch.nn.Module):
    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder

        self.keys_self_attn = []
        self.keys_cross_attn = []

        for block in decoder.blocks:
            self.keys_self_attn += (block.attn.key, block.attn.value)
            self.keys_cross_attn += (block.cross_attn.key, block.cross_attn.value)

    def forward(self, x, xa, cache_self_attn, cache_cross_attn):
        kv_cache = {
            **dict(zip(self.keys_self_attn, cache_self_attn)),
            **dict(zip(self.keys_cross_attn, cache_cross_attn)),
        }

        logits = self.decoder(x, xa, kv_cache=kv_cache)
        return (
            logits,
            torch.cat([kv_cache[key].unsqueeze(0) for key in self.keys_self_attn], dim=0),
            torch.cat([kv_cache[key].unsqueeze(0) for key in self.keys_cross_attn], dim=0),
        )


class FunctionalMultiHeadAttention(MultiHeadAttention):
    def forward(self, x, xa=None, mask=None, kv_cache=None):
        k, v = self._get_kv(x, xa, kv_cache)

        q = self.query(x)
        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def _get_kv(self, x, xa=None, kv_cache=None):
        xx = x if xa is None else xa
        assert xx is not None

        if kv_cache is None:
            return self.key(xx), self.value(xx)

        key = torch.concat([kv_cache[self.key], self.key(xx).detach()], dim=1)
        key = key[:, -self.n_ctx :, :]
        kv_cache[self.key] = key

        value = torch.concat([kv_cache[self.value], self.value(xx).detach()], dim=1)
        value = value[:, -self.n_ctx :, :]
        kv_cache[self.value] = value

        return kv_cache[self.key], kv_cache[self.value]


if __name__ == "__main__":
    main()

corrected enhanced

In [None]:
from gtts import gTTS

text = "Hello, this is a Whisper ONNX test."
tts = gTTS(text=text, lang="en")

tts.save("sample.mp3")


In [None]:
import whisper
import whisper.model
whisper.model.MultiHeadAttention.use_sdpa = False

model = whisper.load_model("tiny")

result = model.transcribe("sample.mp3")


print(result["text"])


In [None]:
import numpy as np
import whisper
import onnxruntime as ort

# tokenizer
tokenizer = whisper.tokenizer.get_tokenizer(True)

# load audio â†’ mel
audio = whisper.load_audio("sample.mp3")
audio = whisper.pad_or_trim(audio)
mel = whisper.log_mel_spectrogram(audio).unsqueeze(0)
mel = mel.cpu().numpy().astype(np.float32)

# load ONNX
enc = ort.InferenceSession("encoder.onnx")
dec = ort.InferenceSession("decoder.onnx")

# encoder forward
audio_features = enc.run(None, {"mel": mel})[0].astype(np.float32)

# KV cache init
cache_self = np.zeros((8,1,1,384), np.float32)
cache_cross = np.zeros((8,1,1,384), np.float32)

# PREFILL cross-attn cache (CRITICAL)
dummy = np.array([[tokenizer.sot]], dtype=np.int64)
_, _, cache_cross = dec.run(
    None,
    {
        "tokens": dummy,
        "audio": audio_features,
        "cache_self_attn": cache_self,
        "cache_cross_attn": cache_cross,
    },
)

# reset self cache only
cache_self[:] = 0

# start decoding
tokens_list = [tokenizer.sot]
tokens = np.array([[tokenizer.sot]], dtype=np.int64)

logits, cache_self, cache_cross = dec.run(
    None,
    {
        "tokens": tokens,
        "audio": audio_features,
        "cache_self_attn": cache_self,
        "cache_cross_attn": cache_cross,
    },
)

# incremental decode
for _ in range(200):
    next_token = logits[:, -1, :].argmax(-1).astype(np.int64).reshape(1,1)

    if next_token[0,0] == tokenizer.eot:
        break

    tokens_list.append(int(next_token[0,0]))

    logits, cache_self, cache_cross = dec.run(
        None,
        {
            "tokens": next_token,
            "audio": audio_features,
            "cache_self_attn": cache_self,
            "cache_cross_attn": cache_cross,
        },
    )

# decode text
print(tokenizer.decode(tokens_list))
