In [None]:
from transformer_layers import PositionalEncoding, Encoder
import torch.nn as nn
import torch.nn.functional as F

class AudioEncoder(nn.Module):
    """Turn a mel-spectrogram into a sequence of ``d_model``-wide features.

    The input goes through two GELU-activated 1D convolutions (kernel size 3,
    padding 1; the second one uses stride 2 and therefore roughly halves the
    time axis), is transposed to time-major layout, receives a positional
    encoding and is finally processed by the Transformer ``Encoder`` stack.

    Args:
        n_mels: Number of input channels expected by the first convolution.
        n_ctx: Passed to ``PositionalEncoding`` as ``max_len``.
        d_model: Channel width produced by the convolutions and used by the
            positional encoding and the encoder stack.
        n_head: Forwarded to ``Encoder``.
        n_layer: Forwarded to ``Encoder``.
        d_ff: Forwarded to ``Encoder``.
        dropout: Forwarded to ``Encoder``.
    """

    def __init__(self, n_mels=80, n_ctx=1500, d_model=512, n_head=8,
                 n_layer=6, d_ff=2048, dropout=0.1):
        super().__init__()
        self.n_mels, self.n_ctx, self.d_model = n_mels, n_ctx, d_model

        # Convolutional front end: conv1 lifts n_mels channels to d_model,
        # conv2 keeps the width and downsamples time with stride 2.
        self.conv1 = nn.Conv1d(
            in_channels=n_mels, out_channels=d_model, kernel_size=3, padding=1
        )
        self.conv2 = nn.Conv1d(
            in_channels=d_model, out_channels=d_model, kernel_size=3,
            stride=2, padding=1
        )

        # Positional information for up to n_ctx positions.
        self.positional_encoding = PositionalEncoding(d_model, max_len=n_ctx)

        # Stack of Transformer encoder layers.
        self.encoder = Encoder(d_model, n_head, d_ff, n_layer, dropout)

    def forward(self, x):
        """Encode ``x`` and return the encoder stack's output.

        Args:
            x: Tensor fed to ``conv1``; assumed to be laid out as
                (batch, n_mels, time).

        Returns:
            Whatever ``Encoder`` returns for the (batch, time, d_model)
            sequence built from the convolution output.
        """
        hidden = x
        # Feature extraction: GELU after each convolution.
        for conv in (self.conv1, self.conv2):
            hidden = F.gelu(conv(hidden))

        # (batch, d_model, time) -> (batch, time, d_model)
        sequence = hidden.transpose(1, 2)

        # Positional encoding followed by the Transformer encoder.
        return self.encoder(self.positional_encoding(sequence))

In [None]:
from transformer_layers import Decoder
import torch

class TextDecoder(nn.Module):
    """Text decoder producing vocabulary logits conditioned on audio features.

    Tokens are embedded, summed with a learned positional embedding, passed
    through the Transformer ``Decoder`` stack together with the audio
    features, layer-normalised and projected to ``vocab_size`` logits.

    Args:
        vocab_size: Size of the token vocabulary (rows of the embedding table
            and width of the output logits).
        n_ctx: Number of positions in the learned positional embedding, i.e.
            the longest token sequence ``forward`` accepts.
        d_model: Embedding / hidden width.
        n_head: Forwarded to ``Decoder``.
        n_layer: Forwarded to ``Decoder``.
        d_ff: Forwarded to ``Decoder``.
        dropout: Forwarded to ``Decoder``.
    """

    def __init__(self, vocab_size=51865, n_ctx=448, d_model=512, n_head=8,
                 n_layer=6, d_ff=2048, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_ctx = n_ctx
        self.d_model = d_model

        # Token embedding table.
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Learned positional embedding (one row per position), small random init.
        self.positional_embedding = nn.Parameter(torch.randn(n_ctx, d_model) * 0.01)

        # Stack of Transformer decoder layers.
        self.decoder = Decoder(d_model, n_head, d_ff, n_layer, dropout)

        # Final layer norm and bias-free projection to vocabulary logits.
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens, audio_features, mask=None):
        """Compute next-token logits for every position of ``tokens``.

        Args:
            tokens: Integer tensor of token ids, indexed as (batch, seq_len).
            audio_features: Encoder output handed to ``Decoder`` as its second
                positional argument.
            mask: Optional mask forwarded to ``Decoder`` as ``tgt_mask``. When
                ``None``, a boolean (seq_len, seq_len) mask that is ``True``
                strictly above the diagonal (future positions) is created on
                the same device as ``tokens``.

        Returns:
            Logits of shape (batch, seq_len, vocab_size), assuming ``Decoder``
            preserves the shape of its input sequence.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of learned positions.
        """
        seq_len = tokens.size(1)
        max_positions = self.positional_embedding.size(0)
        if seq_len > max_positions:
            # Fail early with a clear message instead of an opaque broadcast
            # error (or a silent broadcast when only one position exists).
            raise ValueError(
                f"sequence length {seq_len} exceeds the decoder context "
                f"size {max_positions}"
            )

        # Token embedding plus learned positional embedding.
        x = self.token_embedding(tokens)
        x = x + self.positional_embedding[:seq_len]

        # Default causal mask, built directly on the tokens' device so it works
        # on any accelerator / device index, not only the current CUDA device.
        if mask is None:
            mask = torch.ones(seq_len, seq_len, device=tokens.device)
            mask = mask.triu(diagonal=1).bool()

        # Transformer decoder stack.
        x = self.decoder(x, audio_features, tgt_mask=mask)

        # Final normalisation and projection to the vocabulary.
        x = self.ln_f(x)
        logits = self.head(x)

        return logits

In [None]:
from transformers import WhisperTokenizer

class Whisper(nn.Module):
    """Encoder-decoder speech-to-text model built from ``AudioEncoder`` and
    ``TextDecoder``.

    Args:
        n_mels: Number of mel channels fed to the audio encoder.
        audio_ctx: Context length (``n_ctx``) of the audio encoder.
        vocab_size: Vocabulary size of the text decoder.
        text_ctx: Context length (``n_ctx``) of the text decoder.
        d_model: Hidden width shared by encoder and decoder.
        n_head: Number of attention heads, shared by both halves.
        n_layer: Number of layers, shared by both halves.
        d_ff: Feed-forward width, shared by both halves.
        dropout: Dropout rate, shared by both halves.
    """

    def __init__(self,
                 # Audio encoder parameters
                 n_mels=80, audio_ctx=1500,
                 # Text decoder parameters
                 vocab_size=51865, text_ctx=448,
                 # Shared model parameters
                 d_model=512, n_head=8, n_layer=6, d_ff=2048, dropout=0.1):
        super().__init__()

        # Audio encoder
        self.encoder = AudioEncoder(
            n_mels=n_mels,
            n_ctx=audio_ctx,
            d_model=d_model,
            n_head=n_head,
            n_layer=n_layer,
            d_ff=d_ff,
            dropout=dropout
        )

        # Text decoder
        self.decoder = TextDecoder(
            vocab_size=vocab_size,
            n_ctx=text_ctx,
            d_model=d_model,
            n_head=n_head,
            n_layer=n_layer,
            d_ff=d_ff,
            dropout=dropout
        )

        # Special token ids
        self.sot_token = 50258  # start of transcript
        self.eot_token = 50257  # end of transcript
        self.no_speech_token = 50362

    def forward(self, mel, tokens=None, attention_mask=None):
        """Encode ``mel`` and return decoder logits for ``tokens``.

        Args:
            mel: Input passed to the audio encoder.
            tokens: Token ids for the decoder. Required; the ``None`` default
                exists only to keep the signature stable.
            attention_mask: Forwarded unchanged to the decoder as its ``mask``
                argument.
                NOTE(review): ``TextDecoder`` only builds its causal mask when
                ``mask`` is ``None``, so any value given here replaces that
                causal mask. Confirm a tokenizer padding mask is what the
                ``Decoder`` stack expects as ``tgt_mask``.

        Returns:
            The logits produced by the text decoder.

        Raises:
            ValueError: If ``tokens`` is ``None``.
        """
        if tokens is None:
            raise ValueError(
                "tokens must be provided to forward(); use generate() for inference"
            )

        # Encode the audio, then decode conditioned on it.
        audio_features = self.encoder(mel)
        return self.decoder(tokens, audio_features, mask=attention_mask)

    def generate(self, mel, max_length=448, temperature=1.0, top_p=0.9):
        """Greedily decode a token sequence for every item in ``mel``.

        Puts the module in eval mode (and leaves it there), encodes the audio
        once, then repeatedly appends the arg-max token. A sequence that has
        emitted the end-of-transcript token is padded with that token until
        every sequence in the batch has finished or the length limit is hit.

        Args:
            mel: Input passed to the audio encoder; ``mel.size(0)`` is the
                batch size.
            max_length: Upper bound on the returned sequence length, including
                the start token. Clamped so the decoder is never fed more
                tokens than its context size.
            temperature: Accepted for API compatibility; unused, because
                rescaling logits cannot change a greedy arg-max.
            top_p: Accepted for API compatibility; unused by greedy decoding.

        Returns:
            Long tensor of shape (batch, generated_length) starting with the
            start-of-transcript token.
        """
        self.eval()
        batch_size = mel.size(0)
        device = mel.device

        # At step i the decoder sees i + 1 tokens, so n_ctx + 1 is the longest
        # output that never overruns the decoder's positional table.
        max_length = min(max_length, self.decoder.n_ctx + 1)

        with torch.no_grad():
            # Encode the audio once.
            audio_features = self.encoder(mel)

            # Every sequence starts with the start-of-transcript token.
            tokens = torch.full(
                (batch_size, 1), self.sot_token, dtype=torch.long, device=device
            )
            finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

            for _ in range(max_length - 1):
                # No mask is passed: all tokens are real (no padding) and the
                # decoder builds its own causal mask when none is supplied.
                logits = self.decoder(tokens, audio_features)

                # Greedy choice from the logits of the last position.
                next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)

                # Sequences that already ended keep emitting end-of-transcript.
                next_token = next_token.masked_fill(
                    finished.unsqueeze(1), self.eot_token
                )

                tokens = torch.cat([tokens, next_token], dim=1)

                # Stop once every sequence has produced end-of-transcript.
                finished = finished | (next_token.squeeze(1) == self.eot_token)
                if finished.all():
                    break

        return tokens




# Usage example: build a model, run one forward pass and one greedy
# generation on random input, then print the decoded text and parameter count.
if __name__ == "__main__":
    # Load the pretrained Whisper tokenizer from the "openai/whisper-base" checkpoint.
    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base")
    
    # Build a Whisper model with explicit hyperparameters.
    # NOTE(review): the original comment described this as a "base"-scale model,
    # but d_model=1280 / n_head=20 / n_layer=32 / d_ff=5120 are far larger than
    # the class defaults (512 / 8 / 6 / 2048) and look like a "large"
    # configuration; only the tokenizer checkpoint is "base". Confirm intent.
    model =  Whisper(
        d_model=1280, n_head=20, n_layer=32, d_ff=5120,
        audio_ctx=1500, text_ctx=448
    )
    # Dummy input dimensions
    batch_size = 2
    n_mels = 80
    time_steps = 3000
    text_list = ["hello world", "this is whisper"]
    
    # Random tensor standing in for a mel-spectrogram: (batch_size, n_mels, time_steps)
    mel = torch.randn(batch_size, n_mels, time_steps)
    
    # Tokenize the text, padding to a common length and returning PyTorch tensors
    encoded = tokenizer(text_list, return_tensors='pt', padding=True)
    tokens = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    
    # Forward pass.
    # NOTE(review): Whisper.forward hands attention_mask to the decoder as its
    # `mask` argument, which replaces the decoder's default causal mask —
    # verify this is the intended use of the tokenizer's padding mask.
    logits = model(mel, tokens, attention_mask=attention_mask)
    print(f"Logits shape: {logits.shape}")  # (batch_size, seq_len, vocab_size)
    
    # Inference: generate up to 100 tokens per sequence without tracking gradients
    with torch.no_grad():
        generated_tokens = model.generate(mel, max_length=100)
        print(f"Generated tokens shape: {generated_tokens.shape}")
    
    # Decode the generated token ids back to text, dropping special tokens
    decoded_texts = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    for text in decoded_texts:
        print(text)
    
    # Total number of model parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")    