In [16]:
import os, io
import torch
from whisper.model import AudioEncoder
from whisper import _MODELS, _ALIGNMENT_HEADS, _download, available_models
from whisper import ModelDimensions
from typing import Optional, Union

def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
) -> AudioEncoder:
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = AudioEncoder(
        dims.n_mels,
        dims.n_audio_ctx,
        dims.n_audio_state,
        dims.n_audio_head,
        dims.n_audio_layer,
    )

    print("\n".join([f for f in checkpoint["model_state_dict"].keys() if "encoder" in f]))
    encoder_keys = [f for f in checkpoint["model_state_dict"].keys() if "encoder" in f]
    model.load_state_dict({f.split("encoder.")[-1]: checkpoint["model_state_dict"][f] for f in encoder_keys})

    # if alignment_heads is not None:
    #     model.set_alignment_heads(alignment_heads)

    return model.to(device)

In [17]:
load_model("tiny")

  checkpoint = torch.load(fp, map_location=device)


encoder.positional_embedding
encoder.conv1.weight
encoder.conv1.bias
encoder.conv2.weight
encoder.conv2.bias
encoder.blocks.0.mlp_ln.weight
encoder.blocks.0.mlp_ln.bias
encoder.blocks.0.mlp.0.weight
encoder.blocks.0.mlp.0.bias
encoder.blocks.0.mlp.2.weight
encoder.blocks.0.mlp.2.bias
encoder.blocks.0.attn_ln.weight
encoder.blocks.0.attn_ln.bias
encoder.blocks.0.attn.query.weight
encoder.blocks.0.attn.query.bias
encoder.blocks.0.attn.key.weight
encoder.blocks.0.attn.value.weight
encoder.blocks.0.attn.value.bias
encoder.blocks.0.attn.out.weight
encoder.blocks.0.attn.out.bias
encoder.blocks.1.mlp_ln.weight
encoder.blocks.1.mlp_ln.bias
encoder.blocks.1.mlp.0.weight
encoder.blocks.1.mlp.0.bias
encoder.blocks.1.mlp.2.weight
encoder.blocks.1.mlp.2.bias
encoder.blocks.1.attn_ln.weight
encoder.blocks.1.attn_ln.bias
encoder.blocks.1.attn.query.weight
encoder.blocks.1.attn.query.bias
encoder.blocks.1.attn.key.weight
encoder.blocks.1.attn.value.weight
encoder.blocks.1.attn.value.bias
encoder.block

AudioEncoder(
  (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
  (blocks): ModuleList(
    (0-3): 4 x ResidualAttentionBlock(
      (attn): MultiHeadAttention(
        (query): Linear(in_features=384, out_features=384, bias=True)
        (key): Linear(in_features=384, out_features=384, bias=False)
        (value): Linear(in_features=384, out_features=384, bias=True)
        (out): Linear(in_features=384, out_features=384, bias=True)
      )
      (attn_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=384, out_features=1536, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1536, out_features=384, bias=True)
      )
      (mlp_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    )
  )
  (ln_post): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
)