In [32]:
import whisper
import torch

from whisper.audio import log_mel_spectrogram, load_audio, pad_or_trim
from whisper.decoding import DecodingOptions

# Load the smallest checkpoint in fp32 (decoding below uses fp16=False); `.to` returns the
# same module, and leaving `model` as the last expression displays its architecture.
model = whisper.load_model("tiny").to(torch.float32)
model

Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-3): 4 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=384, out_features=384, bias=True)
          (key): Linear(in_features=384, out_features=384, bias=False)
          (value): Linear(in_features=384, out_features=384, bias=True)
          (out): Linear(in_features=384, out_features=384, bias=True)
        )
        (attn_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=1536, out_features=384, bias=True)
        )
        (mlp_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm((384,), eps=1e-05,

In [38]:
import torch.nn.functional as F

AUDIO_PATH = './audio.mp3'
SAMPLE_RATE = 16000  # audio samples per second expected by whisper
HOP_LENGTH = 160  # samples between successive mel frames
CHUNK_LENGTH = 30  # seconds of audio per model input window
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in one 30 s window

# Use the model's own mel-bin count rather than the function default, so this cell keeps
# working with checkpoints whose encoder expects a different number of bins.
mel = log_mel_spectrogram(load_audio(AUDIO_PATH), n_mels=model.dims.n_mels, padding=N_SAMPLES)

N_FRAMES = N_SAMPLES // HOP_LENGTH  # 3000 frames in a mel spectrogram input

# Crop/pad to exactly one window and add a leading batch dimension.
mel_segment = pad_or_trim(mel, N_FRAMES).unsqueeze(0)

In [39]:
# Weights were cast to float32 when the model was loaded, so decode without fp16.
model.decode(mel_segment, options=DecodingOptions(fp16=False))

[DecodingResult(audio_features=tensor([[ 0.1039,  0.0522,  0.2074,  ..., -0.1066,  0.1678,  0.0533],
         [ 0.5284,  2.0875, -0.3823,  ..., -1.8898, -0.4868,  0.3108],
         [-0.9325,  1.3111, -1.5403,  ..., -1.6143,  1.0681, -0.4109],
         ...,
         [ 0.7557, -1.7804,  0.2187,  ..., -0.1072, -0.5025,  0.5058],
         [-0.0753, -0.4694,  0.1547,  ...,  0.6741,  0.0419,  0.3270],
         [ 0.1310, -0.0847, -1.4486,  ...,  0.0707, -0.5199, -0.2022]]), language='en', language_probs=None, tokens=[50364, 34439, 278, 11, 294, 264, 787, 2020, 365, 597, 321, 366, 412, 1974, 5922, 11, 37761, 490, 881, 50644, 50644, 498, 406, 490, 439, 264, 8609, 293, 27831, 10379, 294, 264, 14414, 13, 50844], text='Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the exhibition.', avg_logprob=-0.182177291976081, no_speech_prob=0.0037185135297477245, temperature=0.0, compression_ratio=1.345132743362832)]

In [54]:
# Smoke-test the encoder and decoder forward passes.
# NOTE: `encoder_logits` holds the encoder's output features (the encoder ends in ln_post),
# not vocabulary logits; the name is kept because other cells may refer to it.
encoder_logits = model.encoder(mel_segment)

# Draw probe token ids from the model's full vocabulary (instead of a hard-coded 50000)
# with a local seeded generator, so the probe is reproducible without touching global RNG.
token_generator = torch.Generator().manual_seed(0)
tokens = torch.randint(0, model.dims.n_vocab, (1, 15), generator=token_generator)

decoder_logits = model.decoder(tokens, encoder_logits)

In [55]:
# Let's write the model ourselves.
from dataclasses import dataclass


@dataclass
class GPTConfig:
    """Hyperparameters for the Whisper re-implementation in this notebook.

    NOTE(review): despite the "GPT" name, the fields describe a Whisper-style model with an
    audio encoder (``n_audio_*``) and a text decoder (``n_text_*``). The ``n_audio_*`` fields
    presumably feed ``AudioEncoder(n_mels, n_ctx, n_state, n_head, n_layer)``; confirm once
    this config is wired into the model. Field order is the positional constructor order.
    """

    n_mels: int  # mel bins per input frame (presumably conv1 input channels)
    n_audio_ctx: int  # presumably the encoder sequence length / positional table size
    n_audio_state: int  # presumably the encoder hidden width
    n_audio_head: int  # presumably the number of encoder attention heads
    n_audio_layer: int  # presumably the number of encoder residual blocks
    n_vocab: int  # presumably the text vocabulary size — TODO confirm
    n_text_ctx: int  # presumably the maximum number of decoder tokens — TODO confirm
    n_text_state: int  # presumably the decoder hidden width
    n_text_head: int  # presumably the number of decoder attention heads
    n_text_layer: int  # presumably the number of decoder residual blocks


In [56]:
import torch.nn as nn
import numpy as np
# nn.Conv1d, nn.Linear, nn.LayerNorm

def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding.

    Builds a ``(length, channels)`` table. The first ``channels // 2`` columns are
    ``sin(t * inv_timescale_i)`` and the rest ``cos(t * inv_timescale_i)``. The inverse
    timescales are geometrically spaced from 1 down to ``1 / max_timescale``.

    Args:
        length: number of positions (rows).
        channels: embedding width; must be even.
        max_timescale: largest timescale in the geometric progression.

    Raises:
        AssertionError: if ``channels`` is odd.
    """
    assert channels % 2 == 0
    half = channels // 2
    # With channels == 2 the step denominator would be 0, giving an inf increment and
    # inf * 0 = NaN embeddings; clamp it so a single timescale of 1 is used instead.
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


In [63]:
from whisper.model import MultiHeadAttention

# Reference module: whisper's attention with n_state=128 split over n_head=8 heads.
mha = MultiHeadAttention(128, 8)
# Random input laid out as (batch=2, seq_len=10, n_state=128).
logits = torch.randn(2, 10, 128)
# whisper's forward returns a pair; only the first element (the output) is compared later.
(o1, o2) = mha(logits)

In [99]:
import math
class MultiHeadAttentionAlex(nn.Module):
    """Multi-head attention re-implemented on top of ``scaled_dot_product_attention``.

    Parameter names (``query``, ``key``, ``value``, ``out``) match
    ``whisper.model.MultiHeadAttention``, so its state dict loads directly. Unlike that
    module, ``forward`` returns only the attended output, not an ``(out, qk)`` pair.

    Args:
        n_state: model width; must be divisible by ``n_head``.
        n_head: number of attention heads.

    Raises:
        ValueError: if ``n_state`` is not divisible by ``n_head``.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        if n_state % n_head != 0:
            raise ValueError(f"n_state ({n_state}) must be divisible by n_head ({n_head})")
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)  # matches whisper: key has no bias
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(self, x, xa=None, mask=None):
        """Attend from ``x`` to itself, or to ``xa`` when given (cross-attention).

        Args:
            x: queries, shape (batch, seq_len, n_state).
            xa: optional key/value source, shape (batch, kv_len, n_state).
            mask: optional ``attn_mask`` for SDPA. It may be sized for a longer context
                (e.g. a full causal mask); its last two dims are cropped to (seq_len, kv_len).

        Returns:
            Tensor of shape (batch, seq_len, n_state).
        """
        kv_source = x if xa is None else xa
        q = self._split_heads(self.query(x))
        k = self._split_heads(self.key(kv_source))
        v = self._split_heads(self.value(kv_source))

        if mask is not None:
            # Crop a mask built for the maximum context to the actual query/key lengths;
            # otherwise SDPA rejects it whenever the sequence is shorter.
            mask = mask[..., : q.shape[-2], : k.shape[-2]]

        y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        # (batch, n_head, seq_len, head_size) -> (batch, seq_len, n_state)
        return self.out(y.permute(0, 2, 1, 3).flatten(start_dim=2))

    def _split_heads(self, t):
        """Reshape (batch, seq_len, n_state) to (batch, n_head, seq_len, head_size)."""
        return t.view(*t.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

# Build the re-implementation with the reference module's dimensions, load whisper's
# weights (parameter names match), and check both produce the same output.
mhaalex = MultiHeadAttentionAlex(mha.query.in_features, mha.n_head)
mhaalex.load_state_dict(mha.state_dict())

o1_alex = mhaalex(logits)

outputs_match = torch.allclose(o1, o1_alex)
assert outputs_match, "MultiHeadAttentionAlex diverges from whisper's MultiHeadAttention"
outputs_match

True

In [101]:
class ResidualAttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, optional cross-attention, then an MLP.

    Each sub-layer is applied residually as ``x + sublayer(layer_norm(x))``.

    Args:
        n_state: model width.
        n_head: number of attention heads.
        cross_attention: if True, add a cross-attention sub-layer that attends to ``xa``.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = nn.LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None

        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_state * 4), nn.GELU(), nn.Linear(n_state * 4, n_state)
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(self, x, xa = None, mask = None):
        """Apply the block.

        Args:
            x: input sequence, shape (batch, seq_len, n_state).
            xa: key/value source for cross-attention; only used when the block has one.
            mask: optional attention mask for the self-attention sub-layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # whisper.model.MultiHeadAttention returns an (output, qk) pair, so keep only the
        # output; adding the tuple itself to `x` raises a TypeError.
        x = x + self.attn(self.attn_ln(x), mask=mask)[0]

        if self.cross_attn is not None:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa)[0]

        return x + self.mlp(self.mlp_ln(x))

class AudioEncoder(nn.Module):
    """Whisper-style audio encoder: two conv layers, sinusoidal positions, transformer blocks.

    Args:
        n_mels: mel bins per input frame (input channels of ``conv1``).
        n_ctx: sequence length after ``conv2``'s stride-2 downsampling; an input with
            ``L`` frames yields ``(L - 1) // 2 + 1`` positions, which must equal ``n_ctx``.
        n_state: model width.
        n_head: attention heads per block.
        n_layer: number of residual attention blocks.
    """

    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks = nn.ModuleList([ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])
        self.ln_post = nn.LayerNorm(n_state)

    def forward(self, x):
        """Encode a mel spectrogram.

        Args:
            x: tensor of shape (batch, n_mels, n_frames).

        Returns:
            Tensor of shape (batch, n_ctx, n_state) after the final LayerNorm.

        Raises:
            ValueError: if the post-convolution sequence does not match the positional
                embedding (wrong number of frames or mel bins).
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # (batch, n_state, n_ctx) -> (batch, n_ctx, n_state)

        # Without this check a mismatch either fails with an opaque broadcasting error or,
        # for a length-1 sequence, silently broadcasts to the wrong shape.
        if x.shape[1:] != self.positional_embedding.shape:
            raise ValueError(
                f"incorrect audio shape: expected {tuple(self.positional_embedding.shape)} "
                f"per item after the convolutions, got {tuple(x.shape[1:])}"
            )
        x = x + self.positional_embedding

        for block in self.blocks:
            x = block(x)

        return self.ln_post(x)
