In [2]:
import torch

import whisper
from whisper.model import MultiHeadAttention


def main(
    *,
    model_name="tiny.en",
    encoder_path="encoder.onnx",
    decoder_path="decoder.onnx",
    opset_version=17,
):
    """Export a Whisper encoder and a cache-functional decoder to ONNX.

    The decoder's attention layers are swapped by ``patch`` and the decoder is
    wrapped in ``FunctionalDecoder``, so its key/value cache becomes an explicit
    input and output of the exported graph.

    Args:
        model_name: Model name passed to ``whisper.load_model``.
        encoder_path: Output file for the encoder graph.
        decoder_path: Output file for the decoder graph.
        opset_version: ONNX opset requested from ``torch.onnx.export``. The
            exporter log recorded with this cell shows newer torch builds the
            graph at opset 18 and then down-converts when 17 is requested; pass
            18 to skip that conversion if the target runtime supports it.
    """
    model = whisper.load_model(model_name)
    model.cpu().eval()
    patch(model)

    dims = model.dims
    encoder = model.encoder
    decoder = FunctionalDecoder(model.decoder)

    # Example inputs only determine the traced graph; their values do not
    # matter. Axes declared dynamic below get example sizes >= 2 because
    # torch.export may specialize a dimension whose example size is 0 or 1 to
    # a constant, which would silently make that axis static in the export.
    batch = 2
    n_tokens = 10
    n_cached = 2

    x_mel = torch.randn(batch, dims.n_mels, 3000)
    x_tokens = torch.zeros((batch, n_tokens), dtype=torch.long)
    with torch.no_grad():
        # Only needed as an example input, so don't build an autograd graph.
        x_audio = encoder(x_mel)

    # Stacked cache layout: (n_cache_entries, batch, cached_len, n_state);
    # dim 0 is split per projection module by FunctionalDecoder.
    cache_self_attn = torch.zeros(
        (len(decoder.keys_self_attn), batch, n_cached, dims.n_text_state),
    )
    cache_cross_attn = torch.zeros(
        (len(decoder.keys_cross_attn), batch, n_cached, dims.n_audio_state),
    )

    # NOTE(review): if whisper's encoder only accepts a fixed number of mel
    # frames, the "time" axes declared dynamic here cannot really vary — confirm.
    torch.onnx.export(
        encoder,
        (x_mel,),
        encoder_path,
        input_names=["mel"],
        output_names=["audio"],
        dynamic_axes={
            "mel": {0: "batch", 2: "time"},
            "audio": {0: "batch", 1: "time"},
        },
        opset_version=opset_version,
    )

    torch.onnx.export(
        decoder,
        (x_tokens, x_audio, cache_self_attn, cache_cross_attn),
        decoder_path,
        input_names=["tokens", "audio", "cache_self_attn", "cache_cross_attn"],
        output_names=["logits", "new_cache_self_attn", "new_cache_cross_attn"],
        dynamic_axes={
            # inputs
            "tokens": {0: "batch", 1: "seq"},
            "audio": {0: "batch", 1: "time"},
            # Axis 1 is batch, matching tokens/audio and the new_cache_* outputs.
            "cache_self_attn": {1: "batch", 2: "cached_seq"},
            "cache_cross_attn": {1: "batch", 2: "cached_time"},
            # outputs
            "logits": {0: "batch", 1: "seq"},
            "new_cache_self_attn": {1: "batch", 2: "new_cached_seq"},
            "new_cache_cross_attn": {1: "batch", 2: "new_cached_time"},
        },
        opset_version=opset_version,
    )


def patch(model):
    """Convert every decoder attention layer to FunctionalMultiHeadAttention in place.

    Each converted layer also receives ``n_ctx``, the number of most-recent
    cached key/value positions it keeps. Self-attention gets the text context
    length and cross-attention gets the audio context length.

    Args:
        model: A loaded Whisper model; its decoder blocks are mutated.
    """
    dims = model.dims
    for block in model.decoder.blocks:
        layers_and_windows = (
            (block.attn, dims.n_text_ctx),
            (block.cross_attn, dims.n_audio_ctx),
        )
        for layer, window in layers_and_windows:
            # Re-class the existing instance so its loaded weights are reused.
            layer.__class__ = FunctionalMultiHeadAttention
            layer.n_ctx = window


class FunctionalDecoder(torch.nn.Module):
    """Expose a Whisper text decoder's KV cache as explicit tensor inputs/outputs.

    Rather than a dict the caller keeps between steps, the cache is passed in
    as two stacked tensors (self-attention and cross-attention) and returned
    the same way, which makes the module exportable as a pure tensor function.

    Along dim 0 of each stack, entries follow this order: for each decoder
    block, the entry for its ``key`` projection, then the entry for its
    ``value`` projection.
    """

    def __init__(self, decoder):
        """Wrap ``decoder`` and record the cache-key modules in stacking order.

        Args:
            decoder: Decoder whose ``blocks`` each expose ``attn`` and
                ``cross_attn`` objects with ``key`` and ``value`` attributes.
        """
        super().__init__()
        self.decoder = decoder

        # The projection modules themselves are the kv_cache dict keys; the
        # list order defines the dim-0 layout of the stacked caches.
        self.keys_self_attn = [
            proj
            for block in decoder.blocks
            for proj in (block.attn.key, block.attn.value)
        ]
        self.keys_cross_attn = [
            proj
            for block in decoder.blocks
            for proj in (block.cross_attn.key, block.cross_attn.value)
        ]

    def forward(self, x, xa, cache_self_attn, cache_cross_attn):
        """Run the wrapped decoder once and return logits plus updated caches.

        Args:
            x: Token input forwarded to the wrapped decoder.
            xa: Audio-feature input forwarded to the wrapped decoder.
            cache_self_attn: Stacked self-attention cache; dim 0 must have one
                entry per element of ``keys_self_attn``.
            cache_cross_attn: Stacked cross-attention cache; dim 0 must have
                one entry per element of ``keys_cross_attn``.

        Returns:
            ``(logits, new_cache_self_attn, new_cache_cross_attn)``, with the
            caches re-stacked in the same dim-0 order as the inputs.

        Raises:
            ValueError: if a cache's leading dimension does not match the
                number of entries this decoder expects. ``zip`` would otherwise
                drop or ignore layers silently.
        """
        for name, cache, keys in (
            ("cache_self_attn", cache_self_attn, self.keys_self_attn),
            ("cache_cross_attn", cache_cross_attn, self.keys_cross_attn),
        ):
            if cache.shape[0] != len(keys):
                raise ValueError(
                    f"{name} has {cache.shape[0]} entries along dim 0, "
                    f"expected {len(keys)}"
                )

        # Self-attention entries are inserted first. Presumably the wrapped
        # decoder derives its positional offset from the first cache value, so
        # this order matters — TODO confirm against whisper's TextDecoder.
        kv_cache = {
            **dict(zip(self.keys_self_attn, cache_self_attn)),
            **dict(zip(self.keys_cross_attn, cache_cross_attn)),
        }

        # Attention layers converted by `patch` write their extended entries
        # back into this dict (FunctionalMultiHeadAttention._get_kv).
        logits = self.decoder(x, xa, kv_cache=kv_cache)
        return (
            logits,
            torch.stack([kv_cache[key] for key in self.keys_self_attn]),
            torch.stack([kv_cache[key] for key in self.keys_cross_attn]),
        )


class FunctionalMultiHeadAttention(MultiHeadAttention):
    """Whisper attention that reads and writes its KV cache through an explicit dict.

    Instances are created by re-classing existing ``MultiHeadAttention``
    modules (see ``patch``). ``patch`` must also set ``self.n_ctx``, the number
    of most-recent positions (along dim 1) kept in each cache entry.

    When ``kv_cache`` is supplied, it must already contain entries keyed by
    this layer's ``key`` and ``value`` projection modules. Each entry is
    replaced by the old entry concatenated with the newly projected tensor
    along dim 1, trimmed to its last ``n_ctx`` positions.
    """

    def forward(self, x, xa=None, mask=None, kv_cache=None):
        """Attend from ``x`` to keys/values projected from ``xa`` (or ``x`` if absent).

        Returns:
            ``(output, qk)``, where ``output`` is ``self.out`` applied to the
            attention result and ``qk`` is the second value returned by
            ``qkv_attention``.
        """
        key, value = self._get_kv(x, xa, kv_cache)
        attended, qk = self.qkv_attention(self.query(x), key, value, mask)
        return self.out(attended), qk

    def _get_kv(self, x, xa=None, kv_cache=None):
        """Return ``(key, value)``, extending ``kv_cache`` in place when it is given."""
        source = xa if xa is not None else x
        assert source is not None

        if kv_cache is None:
            return self.key(source), self.value(source)

        # Key entry first, then value entry.
        new_key = self._extend_cache(self.key, source, kv_cache)
        new_value = self._extend_cache(self.value, source, kv_cache)
        return new_key, new_value

    def _extend_cache(self, projection, source, kv_cache):
        """Append ``projection(source)`` to its cache entry, keep the last ``n_ctx`` positions."""
        previous = kv_cache[projection]
        # detach() keeps autograd history out of the cached tensor.
        fresh = projection(source).detach()
        window = torch.concat([previous, fresh], dim=1)[:, -self.n_ctx :, :]
        kv_cache[projection] = window
        return window


if __name__ == "__main__":
    # Run the ONNX export when this is executed directly (script or notebook cell).
    main()

  torch.onnx.export(
W0207 13:52:07.387000 42371 torch/onnx/_internal/exporter/_compat.py:125] Setting ONNX exporter to use operator set version 18 because the requested opset_version 17 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features
W0207 13:52:07.928000 42371 torch/onnx/_internal/exporter/_registration.py:110] torchvision is not installed. Skipping torchvision::nms
W0207 13:52:07.930000 42371 torch/onnx/_internal/exporter/_registration.py:110] torchvision is not installed. Skipping torchvision::roi_align
W0207 13:52:07.933000 42371 torch/onnx/_internal/exporter/_registration.py:110] torchvision is not installed. Skipping torchvision::roi_pool


[torch.onnx] Obtain model graph for `AudioEncoder([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `AudioEncoder([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...


  return cls.__new__(cls, *args)
The model version conversion is not supported by the onnxscript version converter and fallback is enabled. The model will be converted using the onnx C API (target version: 17).


[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 16 of general pattern rewrite rules.


  torch.onnx.export(
  torch.onnx.export(
W0207 13:52:13.737000 42371 torch/onnx/_internal/exporter/_compat.py:125] Setting ONNX exporter to use operator set version 18 because the requested opset_version 17 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features
W0207 13:52:14.263000 42371 torch/onnx/_internal/exporter/_registration.py:110] torchvision is not installed. Skipping torchvision::nms
W0207 13:52:14.265000 42371 torch/onnx/_internal/exporter/_registration.py:110] torchvision is not installed. Skipping torchvision::roi_align
W0207 13:52:14.268000 42371 torch/onnx/_internal/exporter/_registration.py:110] torchvision is not installed. Skipping torchvision::roi_pool


[torch.onnx] Obtain model graph for `FunctionalDecoder([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `FunctionalDecoder([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...


  return cls.__new__(cls, *args)


[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...


The model version conversion is not supported by the onnxscript version converter and fallback is enabled. The model will be converted using the onnx C API (target version: 17).


[torch.onnx] Translate the graph into ONNX... ✅
Applied 37 of general pattern rewrite rules.
