In [1]:
import os
import sys
from typing import Optional, Sequence, Tuple

import torch

# Make the parent of the notebook's working directory importable so that the
# `app` package used by the following cells can be resolved. The membership
# check keeps sys.path free of duplicates when this cell is re-run.
_parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

# `torch.cuda.is_available` is a function and has to be *called*: the bare
# function object is always truthy, which selected 'cuda' even on CPU-only hosts.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from app.utils import sinusoids
from dataclasses import dataclass
from app.whisper import (
    AudioEncoder, 
    TextDecoder,
    ModelDimensions
)

from app.utils import (
    load_original_whisper_weights,
    get_whisper_encoder_keys, 
    get_whisper_encoder_weigths, 
    get_whisper_decoder_keys, 
    get_whisper_decoder_weigths,
    compute_features
)

In [3]:
class MultiHeadAttention(nn.Module):
    """Unmasked multi-head self-attention, as used by the encoder blocks.

    Queries, keys and values are all projected from the same input; the key
    projection has no bias term.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(self, x: Tensor):
        """Attend over ``x`` (batch, seq, n_state) and apply the output projection."""
        attended = self.qkv_attention(self.query(x), self.key(x), self.value(x))
        return self.out(attended)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape (batch, seq, n_state) into (batch, seq, n_head, head_dim)."""
        return t.view(t.size(0), t.size(1), self.n_head, -1)

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor):
        """Softmax attention over all positions; returns (batch, q_len, n_state)."""
        _, _, width = q.size()
        # Scaling q and k by head_dim ** -0.25 each scales q @ k by head_dim ** -0.5.
        scale = (width // self.n_head) ** -0.25

        q_heads = self._split_heads(q).permute(0, 2, 1, 3) * scale  # (B, H, Tq, D)
        k_heads = self._split_heads(k).permute(0, 2, 3, 1) * scale  # (B, H, D, Tk)
        v_heads = self._split_heads(v).permute(0, 2, 1, 3)          # (B, H, Tk, D)

        weights = F.softmax(q_heads @ k_heads, dim=-1)
        merged = (weights @ v_heads).permute(0, 2, 1, 3)            # (B, Tq, H, D)
        return merged.flatten(start_dim=2)

In [4]:
class ResidualAttentionBlock(nn.Module):
    """Encoder block: LayerNorm -> self-attention and LayerNorm -> MLP, each added back to its input."""

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = nn.LayerNorm(n_state)
        # The feed-forward part widens to four times the model width.
        hidden_width = n_state * 4
        self.mlp = nn.Sequential(
            nn.Linear(n_state, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, n_state),
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(self, x: Tensor):
        """Apply the attention and MLP sub-layers, each with a skip connection."""
        attended = self.attn(self.attn_ln(x))
        x = x + attended
        return x + self.mlp(self.mlp_ln(x))

In [5]:
class CachedMultiHeadAttentionDecoderSelf(nn.Module):
    """Decoder self-attention that reads earlier keys/values from an external cache.

    ``n_layer`` is this layer's index along the second axis of ``kv_cache``.
    The cache is only read here; the freshly computed key/value are returned so
    the caller can store them.
    """

    def __init__(self, n_state: int, n_head: int, n_layer: int):
        super().__init__()
        self.n_layer = n_layer
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        kv_cache: Tensor,
    ):
        """Attend from ``x`` over the cached positions plus the new one(s).

        ``kv_cache`` is indexed as [batch, layer, 0=key / 1=value, position, n_state].
        Returns ``(attention_output, new_key, new_value)``.
        """
        # The query always comes from the current input (output of the layer below).
        q = self.query(x)
        key = self.key(x)
        value = self.value(x)

        # Append the new projections after this layer's cached entries along the
        # position axis.
        layer_cache = kv_cache[:, self.n_layer]
        k = torch.cat([layer_cache[:, 0, ...], key], dim=1)
        v = torch.cat([layer_cache[:, 1, ...], value], dim=1)

        attended = self.masked_qkv_attention(q, k, v)
        return self.out(attended), key, value

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape (batch, seq, n_state) into (batch, seq, n_head, head_dim)."""
        return t.view(t.size(0), t.size(1), self.n_head, -1)

    def masked_qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor,
    ):
        """Softmax attention of ``q`` over every position in ``k``/``v``.

        NOTE(review): despite the name, no mask is applied. The whole cache
        slice is attended over, so cache slots that still hold zeros get a
        score of 0 and therefore a non-zero softmax weight. Earlier experiments
        with a padding mask (``qk == 0`` -> -inf / -65504) were left disabled,
        apparently because of ONNX export warnings — confirm whether this is
        intended.
        """
        _, _, width = q.size()
        # Scaling q and k by head_dim ** -0.25 each scales q @ k by head_dim ** -0.5.
        scale = (width // self.n_head) ** -0.25

        q_heads = self._split_heads(q).permute(0, 2, 1, 3) * scale  # (B, H, Tq, D)
        k_heads = self._split_heads(k).permute(0, 2, 3, 1) * scale  # (B, H, D, Tk)
        v_heads = self._split_heads(v).permute(0, 2, 1, 3)          # (B, H, Tk, D)

        weights = F.softmax(q_heads @ k_heads, dim=-1)
        merged = (weights @ v_heads).permute(0, 2, 1, 3)            # (B, Tq, H, D)
        return merged.flatten(start_dim=2)

In [6]:
class CachedMultiHeadAttentionDecoderCross(nn.Module):
    """Decoder cross-attention over pre-computed per-layer keys and values.

    ``forward`` only applies the query and output projections. The ``key`` and
    ``value`` layers are not called here; the keys/values arrive already
    projected and are selected with this layer's index ``n_layer``.
    """

    def __init__(self, n_state: int, n_head: int, n_layer: int):
        super().__init__()
        self.n_layer = n_layer
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        n_layer_cross_k: Tensor,
        n_layer_cross_v: Tensor,
    ):
        """Attend from ``x`` over this layer's slice of the cross-attention tensors.

        ``n_layer_cross_k`` / ``n_layer_cross_v`` are indexed by layer along
        their first axis.
        """
        # The query always comes from the current input (output of the layer below).
        q = self.query(x)

        # Pick this layer's cross-attention keys and values.
        layer_k = n_layer_cross_k[self.n_layer, ...]
        layer_v = n_layer_cross_v[self.n_layer, ...]

        return self.out(self.masked_qkv_attention(q, layer_k, layer_v))

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape (batch, seq, n_state) into (batch, seq, n_head, head_dim)."""
        return t.view(t.size(0), t.size(1), self.n_head, -1)

    def masked_qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor,
    ):
        """Softmax attention of ``q`` over every position in ``k``/``v`` (no mask is applied)."""
        _, _, width = q.size()
        # Scaling q and k by head_dim ** -0.25 each scales q @ k by head_dim ** -0.5.
        scale = (width // self.n_head) ** -0.25

        q_heads = self._split_heads(q).permute(0, 2, 1, 3) * scale  # (B, H, Tq, D)
        k_heads = self._split_heads(k).permute(0, 2, 3, 1) * scale  # (B, H, D, Tk)
        v_heads = self._split_heads(v).permute(0, 2, 1, 3)          # (B, H, Tk, D)

        weights = F.softmax(q_heads @ k_heads, dim=-1)
        merged = (weights @ v_heads).permute(0, 2, 1, 3)            # (B, Tq, H, D)
        return merged.flatten(start_dim=2)

In [7]:
# NOTE(review): this cell re-defines ResidualAttentionBlock with the same body
# as the earlier cell of the same name; the later definition is the one in
# effect once this cell has run, so one of the two cells can be removed.
class ResidualAttentionBlock(nn.Module):
    """Encoder block with two residual sub-layers: normalised self-attention, then a normalised GELU MLP."""

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = nn.LayerNorm(n_state)
        mlp_width = 4 * n_state
        self.mlp = nn.Sequential(
            nn.Linear(n_state, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, n_state),
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(self, x: Tensor):
        """Run attention and MLP, adding each result back onto its own input."""
        after_attn = x + self.attn(self.attn_ln(x))
        expanded = self.mlp(self.mlp_ln(after_attn))
        return after_attn + expanded

In [8]:
class CachedResidualAttentionBlock(nn.Module):
    """Decoder block: cached self-attention, cross-attention and an MLP, each with a skip connection.

    ``n_layer`` is forwarded to both attention modules, which use it to select
    their slice of the cache / cross-attention tensors.
    """

    def __init__(self, n_state: int, n_head: int, n_layer: int):
        super().__init__()
        self.attn = CachedMultiHeadAttentionDecoderSelf(n_state, n_head, n_layer)
        self.attn_ln = nn.LayerNorm(n_state)
        self.cross_attn = CachedMultiHeadAttentionDecoderCross(n_state, n_head, n_layer)
        self.cross_attn_ln = nn.LayerNorm(n_state)
        # The feed-forward part widens to four times the model width.
        hidden_width = n_state * 4
        self.mlp = nn.Sequential(
            nn.Linear(n_state, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, n_state),
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        kv_cache: Tensor,
        n_layer_cross_k,
        n_layer_cross_v
    ):
        """Return ``(hidden_states, new_key, new_value)``.

        The key/value pair is what the self-attention module computed for the
        current input; it is passed through unchanged for the caller to cache.
        """
        self_out, new_key, new_value = self.attn(self.attn_ln(x), kv_cache)
        x = x + self_out

        cross_out = self.cross_attn(self.cross_attn_ln(x), n_layer_cross_k, n_layer_cross_v)
        x = x + cross_out

        x = x + self.mlp(self.mlp_ln(x))
        return x, new_key, new_value

In [9]:
class AudioEncoder(nn.Module):
    """Two-convolution front-end followed by a stack of attention blocks.

    With ``encoder_x=False`` (the default) ``forward`` goes one step further
    and returns the cross-attention keys and values for every block in
    ``self.decoder`` instead of the encoder hidden states.

    NOTE(review): ``self.decoder`` is built with the encoder's ``n_state``,
    ``n_head`` and ``n_layers``; only its ``cross_attn.key`` / ``cross_attn.value``
    layers are used in ``forward``. This presumably relies on the text decoder
    having the same width and depth as the encoder — confirm for the checkpoint
    being loaded.
    """

    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layers: int, encoder_x: bool = False
    ):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        # The second convolution has stride 2.
        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.encoder_x = encoder_x

        # Encoder stack.
        encoder_blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layers)]
        self.blocks = nn.ModuleList(encoder_blocks)
        self.ln_post = nn.LayerNorm(n_state)

        # Decoder-side blocks, one per layer index.
        decoder_blocks = [
            CachedResidualAttentionBlock(n_state, n_head, layer_idx)
            for layer_idx in range(n_layers)
        ]
        self.decoder = nn.ModuleList(decoder_blocks)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio

        Returns the encoder hidden states when ``self.encoder_x`` is set;
        otherwise a ``(keys, values)`` pair, each stacked over the decoder
        blocks and permuted to (batch, layer, frames, n_state).
        """
        hidden = F.gelu(self.conv2(F.gelu(self.conv1(x))))
        hidden = hidden.permute(0, 2, 1)  # channels last: (batch, frames, n_state)

        # Slice the positional table to the actual number of frames instead of
        # asserting an exact length match.
        n_frames = hidden.shape[1]
        hidden = (hidden + self.positional_embedding[:n_frames]).to(hidden.dtype)

        for block in self.blocks:
            hidden = block(hidden)
        hidden = self.ln_post(hidden)

        if self.encoder_x:
            return hidden

        # Project the encoder output with every decoder block's cross-attention
        # key/value layers.
        cross_keys = torch.stack([blk.cross_attn.key(hidden) for blk in self.decoder])
        cross_values = torch.stack([blk.cross_attn.value(hidden) for blk in self.decoder])
        # (layer, batch, frames, n_state) -> (batch, layer, frames, n_state)
        return (cross_keys.permute(1, 0, 2, 3), cross_values.permute(1, 0, 2, 3))

In [10]:
class TextDecoder(nn.Module):
    """Token decoder that runs against an externally managed key/value cache.

    ``forward`` does not write to the cache; it returns the keys and values
    computed for the current input so the caller can store them.
    No causal mask is registered or applied in this class.
    """

    def __init__(
        self,
        n_vocab: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layers: int,
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        # Allocated uninitialised (torch.empty); no values are assigned in this class.
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        decoder_blocks = [
            CachedResidualAttentionBlock(n_state, n_head, layer_idx)
            for layer_idx in range(n_layers)
        ]
        self.blocks = nn.ModuleList(decoder_blocks)
        self.ln = nn.LayerNorm(n_state)

    def forward(self, x: Tensor, kv_cache: Tensor, n_layer_cross_k: Tensor, n_layer_cross_v: Tensor, offset: Tensor):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens (a single token per call is expected)
        kv_cache : self-attention cache, passed unchanged to every block
        n_layer_cross_k, n_layer_cross_v : cross-attention keys / values,
            laid out as (b_size, n_layers, audio_length, d_model)
        offset : index into the positional embedding for the current token

        Returns ``(logits, keys, values)``; keys and values are stacked over
        the blocks and permuted so the batch axis comes first.
        """
        # Bring the layer axis to the front so each block can index its own layer.
        cross_k = n_layer_cross_k.permute(1, 0, 2, 3)
        cross_v = n_layer_cross_v.permute(1, 0, 2, 3)

        # A single positional row is added to every token in the input.
        hidden = self.token_embedding(x) + self.positional_embedding[offset]

        new_keys = []
        new_values = []
        for block in self.blocks:
            hidden, block_key, block_value = block(hidden, kv_cache, cross_k, cross_v)
            new_keys.append(block_key)
            new_values.append(block_value)

        hidden = self.ln(hidden)
        # The output projection reuses the token embedding matrix.
        logits = hidden @ torch.transpose(self.token_embedding.weight, 0, 1)

        stacked_keys = torch.stack(new_keys, dim=0)
        stacked_values = torch.stack(new_values, dim=0)
        # (layer, batch, ...) -> (batch, layer, ...)
        return logits, stacked_keys.permute(1, 0, 2, 3), stacked_values.permute(1, 0, 2, 3)

In [11]:
# Load the original checkpoint on the CPU and read the model dimensions from
# its 'dims' entry.
medium_weights = load_original_whisper_weights(file_path='../app/model/medium.pt', device='cpu', )
dims = ModelDimensions(**medium_weights['dims'])

# NOTE: AudioEncoder / TextDecoder below are the classes defined in this
# notebook; they shadow the names of the same spelling imported from app.whisper.
encoder = AudioEncoder(
    n_mels=dims.n_mels,
    n_ctx=dims.n_audio_ctx,
    n_state=dims.n_audio_state,
    n_head=dims.n_audio_head,
    n_layers=dims.n_audio_layer,
)

decoder = TextDecoder(
    n_vocab=dims.n_vocab,
    n_ctx=dims.n_text_ctx,
    n_state=dims.n_text_state,
    n_head=dims.n_text_head,
    n_layers=dims.n_text_layer 
)

# Encoder: build a state dict for this notebook's module layout from the
# checkpoint (presumably the helpers map original parameter names onto the
# keys of `encoder.state_dict()` — confirm in app.utils), then load it.
encoder_keys = encoder.state_dict()

encoder_needed_keys = get_whisper_encoder_keys(encoder_keys)
encoder_weights = get_whisper_encoder_weigths(
    encoder_keys,
    encoder_needed_keys, 
    medium_weights
)
encoder.load_state_dict(encoder_weights)
# Move to the selected device, convert to float16 and switch to inference mode.
encoder = encoder.to(device).half()
encoder = encoder.eval()

# Decoder: same procedure as for the encoder.
decoder_keys = decoder.state_dict()

decoder_needed_keys = get_whisper_decoder_keys(decoder_keys)
decoder_weights = get_whisper_decoder_weigths(
    decoder_keys, 
    decoder_needed_keys, 
    medium_weights
)
decoder.load_state_dict(decoder_weights)
decoder = decoder.to(device).half()
decoder = decoder.eval()

### Run

In [12]:
# Token ids that must never be picked during decoding (presumably mirroring the
# SuppressBlank and SuppressTokens logit filters of the reference Whisper
# decoder — confirm against that implementation).
# NOTE(review): `suppress_blanks` is not referenced by `get_token` in this
# notebook; only `suppress_nonspeech` is applied. 50257 is also the id the
# generation loop treats as its stop token.
suppress_blanks = [220, 50257]
suppress_nonspeech = [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 
    93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 
    3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 
    14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 
    32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362]

In [13]:
import pandas as pd
import soundfile as sf

from transformers import WhisperTokenizerFast

# Tokenizer for the "openai/whisper-medium" checkpoint; used in this notebook
# only to turn the generated token ids back into text.
tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-medium")

In [14]:
# Metadata table, read relative to the notebook's working directory. Its
# `absolute_path` column supplies the audio file used below.
df = pd.read_csv('./fleurs_metadata.csv')

In [15]:
# Read the first audio file listed in the metadata table.
wave, sr = sf.read(df.absolute_path[0])
# Convert the waveform to model input features (presumably a log-mel
# spectrogram; the dtype and device of the result are decided inside
# `compute_features` — confirm they match the float16 model on `device`).
mel = compute_features(wave, sample_rate=sr)
# Add a leading batch axis; the batch size is read back from it.
mel = mel.unsqueeze(0)
b_size = mel.size(0)

In [16]:
def update_kv_cache(
    kv_cache: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    offset: int
) -> Tuple[torch.Tensor, int]:
    """Write one decoding step's keys and values into the cache at ``offset``.

    Args:
        kv_cache: cache indexed as [..., 0=key / 1=value, position, n_state].
            It is modified in place.
        keys: new keys with a singleton axis at dim 2 (the single-token axis),
            which is squeezed away before the write.
        values: new values, laid out like ``keys``.
        offset: position slot to fill.

    Returns:
        ``(kv_cache, offset + 1)`` — the same cache object that was passed in
        and the next free position.

    Raises:
        IndexError: if ``offset`` is outside the cache's position axis.
    """
    kv_cache[..., 0, offset, :] = keys.squeeze(2)
    kv_cache[..., 1, offset, :] = values.squeeze(2)
    # Return a fresh value rather than `offset += 1`, so the caller's object is
    # never mutated should a tensor be passed instead of an int.
    return kv_cache, offset + 1


def get_token(
    logits: torch.Tensor,
    suppress_tokens: Optional[Sequence[int]] = None,
) -> torch.Tensor:
    """Greedily pick the next token id from the last position of ``logits``.

    Args:
        logits: tensor indexed as [batch, position, vocab]. It is left untouched.
        suppress_tokens: token ids that may not be selected. When omitted, the
            notebook-level ``suppress_nonspeech`` list is used.

    Returns:
        Tensor of shape (batch, 1) holding the arg-max token id per batch row.
    """
    if suppress_tokens is None:
        suppress_tokens = suppress_nonspeech

    # `logits[:, -1]` is a view: clone it so that masking does not write -inf
    # into the caller's logits tensor.
    last = logits[:, -1].clone()
    banned = list(suppress_tokens)
    if banned:
        last[:, banned] = -torch.inf
    return last.argmax(-1, keepdim=True)

In [19]:
# Id the loop stops on; it decodes to <|endoftext|> with the tokenizer used here.
EOT_TOKEN = 50257

offset = 0
max_token_sequence = 50
# Prompt: <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
tokens = torch.tensor([[50258, 50259, 50359, 50363]] * b_size, dtype=torch.int32).to(mel.device)
# Cache layout: [batch, layer, 0=key / 1=value, position, n_state]. The sizes
# come from the loaded model dimensions instead of hard-coded 24 / 1024.
kv_cache = torch.zeros(
    (b_size, dims.n_text_layer, 2, max_token_sequence, dims.n_text_state),
    dtype=torch.half,
    device=mel.device,
)


with torch.no_grad():
    n_layer_cross_k, n_layer_cross_v = encoder(mel)

    # Special tokens: feed the prompt one position at a time to fill the cache.
    # Slicing `tokens[:, i:i + 1]` keeps the batch axis, so every row is fed
    # (the previous `tokens[0]` loop only ever used the first row).
    for position in range(tokens.size(-1)):
        logits, keys, values = decoder(
            tokens[:, position:position + 1],
            kv_cache,
            n_layer_cross_k,
            n_layer_cross_v,
            offset
        )
        kv_cache, offset = update_kv_cache(kv_cache, keys, values, offset)
    # Only the logits of the final prompt position are needed for the first
    # generated token.
    last = get_token(logits)
    tokens = torch.cat([tokens, last], dim=-1)

    # Start of auto-regressive decoding.
    # The budget is counted on the sequence axis: `len(tokens)` is the batch
    # size, which let the loop run past the `max_token_sequence` cache slots.
    finished = last.squeeze(-1) == EOT_TOKEN
    for _ in range(max_token_sequence - tokens.size(-1)):
        # Stop once every row has produced the end-of-text token. Rows that
        # finish early keep receiving tokens until then.
        if bool(finished.all()):
            break
        logits, keys, values = decoder(
            last,
            kv_cache,
            n_layer_cross_k,
            n_layer_cross_v,
            offset
        )
        kv_cache, offset = update_kv_cache(kv_cache, keys, values, offset)
        last = get_token(logits)
        tokens = torch.cat([tokens, last], dim=-1)
        finished = finished | (last.squeeze(-1) == EOT_TOKEN)
    

In [20]:
# Turn the generated ids of the first batch row back into text; special tokens
# are kept in the decoded string.
tokenizer.decode(tokens[0])

'<|startoftranscript|><|en|><|transcribe|><|notimestamps|> However, due to the slow communication channels styles styles in the West could lag behind by 25 to 30 years.<|endoftext|>'