In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Audio processing
import torchaudio

# Data handling
import numpy as np
from datasets import load_dataset, DownloadConfig

# Progress bar
from tqdm import tqdm

class LibriSpeechDataset(Dataset):
    def __init__(self, max_duration=30, max_samples=10):
        """
        max_duration: maximum audio duration in seconds
        max_samples: maximum number of samples to load
        """
        print("Preparing to load LibriSpeech dataset...")
        
        try:
            # Explicitly download tiny subset first
            print("Step 1: Downloading a tiny subset for verification...")
            
            # Configure download settings
            download_config = DownloadConfig(
                delete_extracted=True,
                force_download=True,
                force_extract=True,
                num_proc=1
            )
            
            # Try downloading a very small subset first
            print("Downloading initial test set...")
            self.dataset = load_dataset(
                "patrickvonplaten/librispeech_asr_dummy",
                "clean",
                split="validation",
                download_config=download_config
            )
            
            print("Test download successful!")
            print(f"Found {len(self.dataset)} samples in test set")
            
        except Exception as e:
            print(f"Error during dataset loading: {str(e)}")
            raise
            
        self.max_duration = max_duration
        self.sample_rate = 16000
        self.tokenizer = SimpleTokenizer()
        
        # Setup mel spectrogram transform
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=80,
            n_fft=2048,
            hop_length=160,
            win_length=400
        )
        print("Audio processing setup complete")
        print(f"Dataset ready with {len(self.dataset)} samples")
    
    def __len__(self):
        return len(self.dataset)
    
    def process_audio(self, audio: np.ndarray) -> torch.Tensor:
        waveform = torch.from_numpy(audio).float()
        if len(waveform.shape) > 1:
            waveform = waveform.mean(dim=0)
        mel_spec = self.mel_transform(waveform)
        mel_spec = torch.log(mel_spec + 1e-9)
        return mel_spec
    
    def __getitem__(self, idx: int):
        try:
            sample = self.dataset[idx]
            audio_array = sample['audio']['array']
            audio_features = self.process_audio(audio_array)
            transcript = sample['text']
            text_tokens = self.tokenizer.encode(transcript)
            language_id = self.tokenizer.encode_language('en')
            has_speech = torch.tensor(1.0)
            
            return {
                'audio_features': audio_features,
                'text_tokens': text_tokens,
                'language_id': language_id,
                'has_speech': has_speech,
                'transcript': transcript
            }
        except Exception as e:
            print(f"Error processing sample {idx}: {str(e)}")
            raise

# Add this before your main function
def verify_dataset_access():
    """
    Function to verify dataset can be accessed and downloaded
    """
    print("Verifying dataset access...")
    try:
        from datasets import load_dataset
        
        # Try to load just 2 examples from the dummy dataset
        print("Attempting to download test dataset...")
        dataset = load_dataset(
            "patrickvonplaten/librispeech_asr_dummy",
            "clean",
            split="validation"
        )
        
        print(f"Successfully accessed dataset! Found {len(dataset)} samples")
        return True
        
    except Exception as e:
        print(f"Error accessing dataset: {str(e)}")
        return False

def main():
    print("Initializing Whisper model training...")
    
    # First verify dataset access
    if not verify_dataset_access():
        print("Failed to access dataset. Please check your internet connection and try again.")
        return
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize model and tokenizer
    model = WhisperModel(n_mels=80, n_vocab=51865, n_state=512)
    model = model.to(device)
    
    # Load dataset
    try:
        print("Loading main dataset...")
        dataset = LibriSpeechDataset(max_samples=10)  # Start with just 10 samples
        
        dataloader = DataLoader(
            dataset, 
            batch_size=2,  # Smaller batch size for testing
            shuffle=True,
            collate_fn=collate_batch,
            num_workers=0
        )
        print(f"Successfully created DataLoader with {len(dataset)} samples")
        
        # Rest of your training code...
        
    except Exception as e:
        print(f"Failed to load dataset: {str(e)}")
        raise

if __name__ == "__main__":
    # Make sure you have the required packages
    try:
        import datasets
        import torch
        import torchaudio
        
        print("Required packages found!")
        print(f"datasets version: {datasets.__version__}")
        print(f"torch version: {torch.__version__}")
        print(f"torchaudio version: {torchaudio.__version__}")
        
        main()
    except ImportError as e:
        print(f"Missing required package: {str(e)}")
        print("Please install required packages:")
        print("pip install torch torchaudio datasets transformers")

NameError: name 'Dataset' is not defined

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
from datasets import load_dataset
from tqdm import tqdm

class SimpleTokenizer:
    def __init__(self, vocab_size=51865):
        self.pad_token = 0
        self.eos_token = 1
        self.bos_token = 2
        
        # Simple character-level tokenization
        self.char_to_idx = {chr(i): i+10 for i in range(ord('a'), ord('z')+1)}
        self.char_to_idx.update({
            ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, "'": 8, '"': 9
        })
        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
        self.language_codes = {'en': 0}
        
    def encode(self, text: str, max_length: int = 448) -> torch.Tensor:
        tokens = [self.char_to_idx.get(c, self.char_to_idx[' ']) for c in text.lower()]
        tokens = [self.bos_token] + tokens + [self.eos_token]
        
        if len(tokens) < max_length:
            tokens.extend([self.pad_token] * (max_length - len(tokens)))
        else:
            tokens = tokens[:max_length-1] + [self.eos_token]
        return torch.tensor(tokens)
    
    def encode_language(self, language: str) -> torch.Tensor:
        return torch.tensor([self.language_codes.get(language, 0)])

class LibriSpeechDataset(Dataset):
    def __init__(self):
        print("Loading LibriSpeech dataset...")
        
        # Load a small subset of train-clean-100
        self.dataset = load_dataset(
            "librispeech_asr",
            "clean",
            split="train[:10]"  # Only load 10 samples
        )
        
        self.sample_rate = 16000
        self.tokenizer = SimpleTokenizer()
        
        # Setup mel spectrogram transform
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=80,
            n_fft=2048,
            hop_length=160,
            win_length=400
        )
        print(f"Dataset loaded with {len(self.dataset)} samples")
    
    def __len__(self):
        return len(self.dataset)
    
    def process_audio(self, audio: np.ndarray) -> torch.Tensor:
        waveform = torch.from_numpy(audio).float()
        if len(waveform.shape) > 1:
            waveform = waveform.mean(dim=0)
        mel_spec = self.mel_transform(waveform)
        mel_spec = torch.log(mel_spec + 1e-9)
        return mel_spec
    
    def __getitem__(self, idx):
        sample = self.dataset[idx]
        
        # Process audio
        audio_array = sample['audio']['array']
        audio_features = self.process_audio(audio_array)
        
        # Get transcript
        transcript = sample['text']
        
        # Convert to tokens
        text_tokens = self.tokenizer.encode(transcript)
        
        return {
            'audio_features': audio_features,
            'text_tokens': text_tokens,
            'transcript': transcript
        }

def main():
    print("Testing LibriSpeech dataset loading...")
    
    try:
        # Create dataset
        dataset = LibriSpeechDataset()
        
        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=2,
            shuffle=True
        )
        
        # Test by accessing a few samples
        print("\nTesting access to samples:")
        for idx in range(min(3, len(dataset))):
            sample = dataset[idx]
            print(f"\nSample {idx}:")
            print(f"Audio shape: {sample['audio_features'].shape}")
            print(f"Text tokens shape: {sample['text_tokens'].shape}")
            print(f"Transcript: {sample['transcript']}")
        
        print("\nTesting batch loading:")
        for batch_idx, batch in enumerate(dataloader):
            print(f"\nBatch {batch_idx}:")
            print(f"Audio batch shape: {batch['audio_features'].shape}")
            print(f"Text tokens batch shape: {batch['text_tokens'].shape}")
            if batch_idx >= 2:  # Only test first few batches
                break
                
        print("\nDataset test completed successfully!")
        
    except Exception as e:
        print(f"Error during testing: {str(e)}")
        raise

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


Testing LibriSpeech dataset loading...
Loading LibriSpeech dataset...


Downloading data:  69%|██████▉   | 234M/338M [04:59<01:36, 1.08MB/s] 

Error during testing: 


FSTimeoutError: 

Downloading data:  69%|██████▉   | 234M/338M [05:19<01:36, 1.08MB/s]

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import time

class SimpleTokenizer:
    def __init__(self, vocab_size=51865):
        self.pad_token = 0
        self.eos_token = 1
        self.bos_token = 2
        self.char_to_idx = {chr(i): i+10 for i in range(ord('a'), ord('z')+1)}
        self.char_to_idx.update({' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, "'": 8, '"': 9})
        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
        self.language_codes = {'en': 0}
        
    def encode(self, text: str, max_length: int = 448) -> torch.Tensor:
        tokens = [self.char_to_idx.get(c, self.char_to_idx[' ']) for c in text.lower()]
        tokens = [self.bos_token] + tokens + [self.eos_token]
        if len(tokens) < max_length:
            tokens.extend([self.pad_token] * (max_length - len(tokens)))
        else:
            tokens = tokens[:max_length-1] + [self.eos_token]
        return torch.tensor(tokens)
    
    def encode_language(self, language: str) -> torch.Tensor:
        return torch.tensor([self.language_codes.get(language, 0)])

class TinyLibriSpeechDataset(Dataset):
    def __init__(self):
        print("Loading Tiny LibriSpeech dataset...")
        
        # Use the tiny dummy dataset with 'clean' config
        self.dataset = load_dataset(
            "patrickvonplaten/librispeech_asr_dummy",
            "clean",  # Specify the config
            split="validation"
        )
        
        self.sample_rate = 16000
        self.tokenizer = SimpleTokenizer()
        
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=80,
            n_fft=2048,
            hop_length=160,
            win_length=400
        )
        print(f"Dataset loaded with {len(self.dataset)} samples")
    
    def process_audio(self, audio: np.ndarray) -> torch.Tensor:
        waveform = torch.from_numpy(audio).float()
        if len(waveform.shape) > 1:
            waveform = waveform.mean(dim=0)
        mel_spec = self.mel_transform(waveform)
        mel_spec = torch.log(mel_spec + 1e-9)
        return mel_spec
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        sample = self.dataset[idx]
        
        # Process audio
        audio_array = sample['audio']['array']
        audio_features = self.process_audio(audio_array)
        
        # Get transcript
        transcript = sample['text']
        
        # Convert to tokens
        text_tokens = self.tokenizer.encode(transcript)
        
        return {
            'audio_features': audio_features,
            'text_tokens': text_tokens,
            'transcript': transcript
        }

def try_download_dataset(max_retries=3, delay=5):
    """Try to download the dataset with retries"""
    for attempt in range(max_retries):
        try:
            print(f"Download attempt {attempt + 1}/{max_retries}")
            dataset = load_dataset(
                "patrickvonplaten/librispeech_asr_dummy",
                "clean",  # Specify the config
                split="validation"
            )
            print("Download successful!")
            return dataset
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                print(f"Waiting {delay} seconds before retrying...")
                time.sleep(delay)
                delay *= 2  # Exponential backoff
    raise Exception("Failed to download dataset after all attempts")

def collate_fn(batch):
    """Custom collate function to handle variable length sequences"""
    
    # Find max lengths in the batch
    max_audio_len = max(b['audio_features'].shape[1] for b in batch)
    
    # Initialize tensors
    batch_size = len(batch)
    audio_features = torch.zeros(batch_size, 80, max_audio_len)
    text_tokens = torch.stack([b['text_tokens'] for b in batch])
    
    # Fill in the tensors with padded data
    for i, sample in enumerate(batch):
        audio = sample['audio_features']
        audio_len = audio.shape[1]
        audio_features[i, :, :audio_len] = audio
    
    return {
        'audio_features': audio_features,
        'text_tokens': text_tokens,
        'transcript': [b['transcript'] for b in batch]
    }

def main():
    print("Testing LibriSpeech dataset loading...")
    
    try:
        # First try just downloading
        print("Testing dataset download...")
        test_dataset = try_download_dataset()
        print(f"Successfully downloaded dataset with {len(test_dataset)} samples")
        
        # Create dataset
        dataset = TinyLibriSpeechDataset()
        
        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=2,
            shuffle=True,
            collate_fn=collate_fn
        )
        
        # Test by accessing a few samples
        print("\nTesting access to samples:")
        for idx in range(min(3, len(dataset))):
            sample = dataset[idx]
            print(f"\nSample {idx}:")
            print(f"Audio shape: {sample['audio_features'].shape}")
            print(f"Text tokens shape: {sample['text_tokens'].shape}")
            print(f"Transcript: {sample['transcript']}")
        
        print("\nTesting batch loading:")
        for batch_idx, batch in enumerate(dataloader):
            print(f"\nBatch {batch_idx}:")
            print(f"Audio batch shape: {batch['audio_features'].shape}")
            print(f"Text tokens batch shape: {batch['text_tokens'].shape}")
            if batch_idx >= 2:
                break
                
        print("\nDataset test completed successfully!")
        
    except Exception as e:
        print(f"Error during testing: {str(e)}")
        raise

if __name__ == "__main__":
    # Set longer timeout for downloads
    import datasets.config as config
    config.HF_DATASETS_HTTP_TIMEOUT = 1000  # 1000 seconds timeout
    
    # Run main
    main()

Testing LibriSpeech dataset loading...
Testing dataset download...
Download attempt 1/3
Download successful!
Successfully downloaded dataset with 73 samples
Loading Tiny LibriSpeech dataset...
Dataset loaded with 73 samples

Testing access to samples:

Sample 0:
Audio shape: torch.Size([80, 466])
Text tokens shape: torch.Size([448])
Transcript: A MAN SAID TO THE UNIVERSE SIR I EXIST

Sample 1:
Audio shape: torch.Size([80, 654])
Text tokens shape: torch.Size([448])
Transcript: SWEAT COVERED BRION'S BODY TRICKLING INTO THE TIGHT LOINCLOTH THAT WAS THE ONLY GARMENT HE WORE

Sample 2:
Audio shape: torch.Size([80, 1334])
Text tokens shape: torch.Size([448])
Transcript: THE CUT ON HIS CHEST STILL DRIPPING BLOOD THE ACHE OF HIS OVERSTRAINED EYES EVEN THE SOARING ARENA AROUND HIM WITH THE THOUSANDS OF SPECTATORS WERE TRIVIALITIES NOT WORTH THINKING ABOUT

Testing batch loading:

Batch 0:
Audio batch shape: torch.Size([2, 80, 925])
Text tokens batch shape: torch.Size([2, 448])

Batch 1:
Audio b

In [9]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import math
import time

class SimpleTokenizer:
    def __init__(self, vocab_size=51865):
        self.pad_token = 0
        self.eos_token = 1
        self.bos_token = 2
        self.char_to_idx = {chr(i): i+10 for i in range(ord('a'), ord('z')+1)}
        self.char_to_idx.update({' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, "'": 8, '"': 9})
        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
        self.language_codes = {'en': 0}
        
        # Add task tokens
        self.task_tokens = {
            'transcribe': vocab_size - 5,
            'translate': vocab_size - 4,
            'language_id': vocab_size - 3,
            'no_timestamps': vocab_size - 2,
            'no_speech': vocab_size - 1
        }
        
    def encode(self, text: str, max_length: int = 448) -> torch.Tensor:
        tokens = [self.char_to_idx.get(c, self.char_to_idx[' ']) for c in text.lower()]
        tokens = [self.bos_token] + tokens + [self.eos_token]
        if len(tokens) < max_length:
            tokens.extend([self.pad_token] * (max_length - len(tokens)))
        else:
            tokens = tokens[:max_length-1] + [self.eos_token]
        return torch.tensor(tokens)
    
    def encode_language(self, language: str) -> torch.Tensor:
        return torch.tensor([self.language_codes.get(language, 0)])
    
    def decode(self, tokens: torch.Tensor) -> str:
        text = []
        for token in tokens:
            if token == self.eos_token:
                break
            if token.item() in self.idx_to_char:
                text.append(self.idx_to_char[token.item()])
        return ''.join(text)

class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(self, x: torch.Tensor, kv: torch.Tensor = None, 
                mask: torch.Tensor = None) -> torch.Tensor:
        if kv is None:
            kv = x

        q = self.query(x)
        k = self.key(kv)
        v = self.value(kv)

        head_dim = q.size(-1) // self.n_head
        q = q.view(*q.shape[:-1], self.n_head, head_dim).transpose(-3, -2)
        k = k.view(*k.shape[:-1], self.n_head, head_dim).transpose(-3, -2)
        v = v.view(*v.shape[:-1], self.n_head, head_dim).transpose(-3, -2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(-3, -2).contiguous()
        out = out.view(*out.shape[:-2], -1)
        
        return self.out(out)

class TransformerBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()
        self.cross_attention = cross_attention
        self.attn = MultiHeadAttention(n_state, n_head)
        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.ln_cross = nn.LayerNorm(n_state)
        self.ln1 = nn.LayerNorm(n_state)
        self.ln2 = nn.LayerNorm(n_state)
        self.mlp = nn.Sequential(
            nn.Linear(n_state, 4 * n_state),
            nn.GELU(),
            nn.Linear(4 * n_state, n_state)
        )

    def forward(self, x: torch.Tensor, encoder_out: torch.Tensor = None, 
                mask: torch.Tensor = None) -> torch.Tensor:
        x = x + self.attn(self.ln1(x), mask=mask)
        if self.cross_attention and encoder_out is not None:
            x = x + self.cross_attn(self.ln_cross(x), encoder_out)
        x = x + self.mlp(self.ln2(x))
        return x

class WhisperEncoder(nn.Module):
    def __init__(self, n_mels: int = 80, n_ctx: int = 1500, 
                 n_state: int = 512, n_head: int = 8, n_layer: int = 6):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, 3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, 3, stride=2, padding=1)
        self.gelu = nn.GELU()
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        self.blocks = nn.ModuleList([
            TransformerBlock(n_state, n_head) for _ in range(n_layer)
        ])
        self.ln = nn.LayerNorm(n_state)
        
        # Initialize weights
        torch.nn.init.normal_(self.positional_embedding, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.gelu(x)
        x = self.conv2(x)
        x = self.gelu(x)

        x = x.transpose(1, 2)
        x = x + self.positional_embedding[:x.shape[1], :]

        for block in self.blocks:
            x = block(x)

        x = self.ln(x)
        return x

class WhisperDecoder(nn.Module):
    def __init__(self, n_vocab: int = 51865, n_ctx: int = 448,
                 n_state: int = 512, n_head: int = 8, n_layer: int = 6):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        self.blocks = nn.ModuleList([
            TransformerBlock(n_state, n_head, cross_attention=True)
            for _ in range(n_layer)
        ])
        self.ln = nn.LayerNorm(n_state)
        self.fc = nn.Linear(n_state, n_vocab, bias=False)
        
        # Tie weights with embedding
        self.fc.weight = self.token_embedding.weight
        
        # Initialize weights
        torch.nn.init.normal_(self.positional_embedding, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor, encoder_out: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        # Ensure input sequence length doesn't exceed positional embedding size
        x = self.token_embedding(x[:, :self.positional_embedding.size(0)])
        x = x + self.positional_embedding[:x.shape[1], :]

        for block in self.blocks:
            x = block(x, encoder_out, mask)

        x = self.ln(x)
        x = self.fc(x)
        return x

class WhisperModel(nn.Module):
    def __init__(self, n_mels: int = 80, n_vocab: int = 51865, 
                 n_state: int = 512, n_head: int = 8, n_layer: int = 6):
        super().__init__()
        self.encoder = WhisperEncoder(n_mels, n_state=n_state, 
                                    n_head=n_head, n_layer=n_layer)
        self.decoder = WhisperDecoder(n_vocab, n_state=n_state, 
                                    n_head=n_head, n_layer=n_layer)
        
        # Task tokens are handled by the tokenizer
        self.tokenizer = SimpleTokenizer(n_vocab)

    def forward(self, audio_features: torch.Tensor, text_tokens: torch.Tensor,
                task_type: str = 'transcribe') -> torch.Tensor:
        encoder_out = self.encoder(audio_features)
        
        # Add task token to the beginning of text tokens and ensure proper length
        task_token = self.tokenizer.task_tokens[task_type]
        task_tensor = torch.tensor([[task_token]]).expand(text_tokens.shape[0], 1).to(text_tokens.device)
        
        # Truncate text_tokens if needed to ensure final sequence length is 448
        max_text_len = 447  # 448 - 1 (task token)
        if text_tokens.size(1) > max_text_len:
            text_tokens = text_tokens[:, :max_text_len]
        
        decoder_input = torch.cat([task_tensor, text_tokens], dim=1)
        logits = self.decoder(decoder_input, encoder_out)
        return logits

def collate_fn(batch):
    """Custom collate function to handle variable length sequences"""
    
    # Find max lengths in the batch
    max_audio_len = max(b['audio_features'].shape[1] for b in batch)
    max_text_len = min(448, max(b['text_tokens'].shape[0] for b in batch))  # Limit to 448
    
    # Initialize tensors
    batch_size = len(batch)
    audio_features = torch.zeros(batch_size, 80, max_audio_len)
    text_tokens = torch.zeros(batch_size, max_text_len).long()
    
    # Fill in the tensors with padded data
    for i, sample in enumerate(batch):
        # Audio features
        audio = sample['audio_features']
        audio_len = audio.shape[1]
        audio_features[i, :, :audio_len] = audio
        
        # Text tokens (truncate if needed)
        text = sample['text_tokens'][:max_text_len]
        text_len = text.shape[0]
        text_tokens[i, :text_len] = text
    
    return {
        'audio_features': audio_features,
        'text_tokens': text_tokens,
        'transcript': [b['transcript'] for b in batch]
    }

class WhisperTrainer:
    def __init__(self, model: WhisperModel, learning_rate: float = 1e-4):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=100000, eta_min=learning_rate/10
        )
        
    def train_step(self, batch: dict) -> float:
        self.optimizer.zero_grad()
        
        # Move batch to same device as model
        device = next(self.model.parameters()).device
        audio_features = batch['audio_features'].to(device)
        text_tokens = batch['text_tokens'].to(device)
        
        # Forward pass
        logits = self.model(audio_features, text_tokens, 'transcribe')
        
        # Calculate loss (ignore task token in target)
        loss = F.cross_entropy(
            logits[:, 1:].reshape(-1, logits.size(-1)),
            text_tokens.reshape(-1),
            ignore_index=0
        )
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        
        return loss.item()


In [10]:

class WhisperTrainer:
    def __init__(self, model: WhisperModel, learning_rate: float = 1e-4):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=100000, eta_min=learning_rate/10
        )
        
    def calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor, 
                      task_type: str) -> torch.Tensor:
        if task_type in ['transcribe', 'translate']:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=0  # Ignore padding tokens
            )
        elif task_type == 'language_id':
            loss = F.cross_entropy(logits[:, 0, :], targets[:, 0])
        elif task_type == 'no_speech':
            loss = F.binary_cross_entropy_with_logits(
                logits[:, 0, self.model.tokenizer.task_tokens['no_speech']],
                targets.float()
            )
        else:
            raise ValueError(f"Unknown task type: {task_type}")
        return loss
        
    def train_step(self, batch: dict) -> float:
        self.optimizer.zero_grad()
        total_loss = 0
        
        # Process each task in the batch
        task_type = 'transcribe'  # For this example, we'll focus on transcription
        
        # Move batch to same device as model
        audio_features = batch['audio_features'].to(next(self.model.parameters()).device)
        text_tokens = batch['text_tokens'].to(next(self.model.parameters()).device)
        
        # Forward pass
        logits = self.model(audio_features, text_tokens, task_type)
        loss = self.calculate_loss(logits, text_tokens, task_type)
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        
        return loss.item()

def collate_fn(batch):
    """Custom collate function to handle variable length sequences"""
    
    # Find max lengths in the batch
    max_audio_len = max(b['audio_features'].shape[1] for b in batch)
    
    # Initialize tensors
    batch_size = len(batch)
    audio_features = torch.zeros(batch_size, 80, max_audio_len)
    text_tokens = torch.stack([b['text_tokens'] for b in batch])
    
    # Fill in the tensors with padded data
    for i, sample in enumerate(batch):
        audio = sample['audio_features']
        audio_len = audio.shape[1]
        audio_features[i, :, :audio_len] = audio
    
    return {
        'audio_features': audio_features,
        'text_tokens': text_tokens,
        'transcript': [b['transcript'] for b in batch]
    }

def train_whisper():
    print("Initializing Whisper training...")
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize model
    model = WhisperModel(
        n_mels=80,
        n_vocab=51865,
        n_state=512,
        n_head=8,
        n_layer=6
    ).to(device)
    
    # Create dataset and dataloader
    dataset = TinyLibriSpeechDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0
    )
    
    # Initialize trainer
    trainer = WhisperTrainer(model)
    
    # Training loop
    num_epochs = 10
    best_loss = float('inf')
    
    print("Starting training...")
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
        
        for batch_idx, batch in enumerate(progress_bar):
            loss = trainer.train_step(batch)
            epoch_loss += loss
            
            # Update progress bar
            progress_bar.set_postfix({'loss': f'{loss:.4f}'})
            
            # Save checkpoint if loss improved
            if batch_idx % 10 == 0:
                avg_loss = epoch_loss / (batch_idx + 1)
                if avg_loss < best_loss:
                    best_loss = avg_loss
                    print(f"\nSaving best model (loss: {best_loss:.4f})...")
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': trainer.optimizer.state_dict(),
                        'loss': best_loss,
                    }, 'whisper_best.pt')
        
        avg_epoch_loss = epoch_loss / len(dataloader)
        print(f'\nEpoch {epoch+1} completed, Average Loss: {avg_epoch_loss:.4f}')
    
    print("\nTraining completed!")
    return model

def inference(model: WhisperModel, audio_features: torch.Tensor, 
             tokenizer: SimpleTokenizer, device: torch.device):
    model.eval()
    with torch.no_grad():
        # Prepare initial tokens
        decoder_input = torch.tensor([[
            tokenizer.task_tokens['transcribe'],
            tokenizer.bos_token
        ]]).to(device)
        
        # Move audio features to device
        audio_features = audio_features.unsqueeze(0).to(device)
        
        # Encode audio
        encoder_out = model.encoder(audio_features)
        
        # Decode sequence
        generated_tokens = []
        max_length = 448
        
        for _ in range(max_length):
            logits = model.decoder(decoder_input, encoder_out)
            next_token = torch.argmax(logits[:, -1, :], dim=-1)
            
            if next_token.item() == tokenizer.eos_token:
                break
                
            generated_tokens.append(next_token.item())
            decoder_input = torch.cat([
                decoder_input,
                next_token.unsqueeze(0).unsqueeze(0)
            ], dim=1)
        
        # Decode tokens to text
        transcription = tokenizer.decode(torch.tensor(generated_tokens))
        return transcription

def main():
    try:
        # Train the model
        model = train_whisper()
        
        # Test inference
        print("\nTesting inference...")
        test_dataset = TinyLibriSpeechDataset()
        test_sample = test_dataset[0]
        
        device = next(model.parameters()).device
        transcription = inference(
            model,
            test_sample['audio_features'],
            model.tokenizer,
            device
        )
        
        print("\nTest Results:")
        print(f"Original: {test_sample['transcript']}")
        print(f"Predicted: {transcription}")
        
    except Exception as e:
        print(f"\nError during execution: {str(e)}")
        raise

if __name__ == "__main__":
    # Set longer timeout for downloads
    import datasets.config as config
    config.HF_DATASETS_HTTP_TIMEOUT = 1000
    
    main()


Initializing Whisper training...
Using device: cpu
Loading Tiny LibriSpeech dataset...
Dataset loaded with 73 samples
Starting training...


Epoch 1/10:   0%|          | 0/19 [00:00<?, ?it/s]


Error during execution: The size of tensor a (449) must match the size of tensor b (448) at non-singleton dimension 1





RuntimeError: The size of tensor a (449) must match the size of tensor b (448) at non-singleton dimension 1

In [11]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import math
import time

class SimpleTokenizer:
    def __init__(self, vocab_size=51865):
        self.vocab_size = vocab_size
        self.pad_token = 0
        self.eos_token = 1
        self.bos_token = 2
        
        # Task tokens start from end of vocabulary
        self.task_tokens = {
            'transcribe': vocab_size - 5,
            'translate': vocab_size - 4,
            'language_id': vocab_size - 3,
            'no_timestamps': vocab_size - 2,
            'no_speech': vocab_size - 1
        }
        
        # Simple character-level tokenization
        self.char_to_idx = {chr(i): i+10 for i in range(ord('a'), ord('z')+1)}
        self.char_to_idx.update({
            ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, "'": 8, '"': 9
        })
        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
        self.language_codes = {'en': 0}
        
    def encode(self, text: str, max_length: int = 448) -> torch.Tensor:
        tokens = [self.char_to_idx.get(c, self.char_to_idx[' ']) for c in text.lower()]
        tokens = [self.bos_token] + tokens + [self.eos_token]
        
        if len(tokens) < max_length:
            tokens.extend([self.pad_token] * (max_length - len(tokens)))
        else:
            tokens = tokens[:max_length-1] + [self.eos_token]
        return torch.tensor(tokens)
    
    def decode(self, tokens: torch.Tensor) -> str:
        text = []
        for token in tokens:
            if token == self.eos_token:
                break
            if token.item() in self.idx_to_char:
                text.append(self.idx_to_char[token.item()])
        return ''.join(text)
    
    def encode_language(self, language: str) -> torch.Tensor:
        return torch.tensor([self.language_codes.get(language, 0)])

class TinyLibriSpeechDataset(Dataset):
    def __init__(self):
        print("Loading Tiny LibriSpeech dataset...")
        
        # Use the tiny dummy dataset with 'clean' config
        self.dataset = load_dataset(
            "patrickvonplaten/librispeech_asr_dummy",
            "clean",
            split="validation"
        )
        
        self.sample_rate = 16000
        self.tokenizer = SimpleTokenizer()
        
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=80,
            n_fft=2048,
            hop_length=160,
            win_length=400
        )
        print(f"Dataset loaded with {len(self.dataset)} samples")
    
    def process_audio(self, audio: np.ndarray) -> torch.Tensor:
        waveform = torch.from_numpy(audio).float()
        if len(waveform.shape) > 1:
            waveform = waveform.mean(dim=0)
        mel_spec = self.mel_transform(waveform)
        mel_spec = torch.log(mel_spec + 1e-9)
        return mel_spec
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        sample = self.dataset[idx]
        
        # Process audio
        audio_array = sample['audio']['array']
        audio_features = self.process_audio(audio_array)
        
        # Get transcript
        transcript = sample['text']
        
        # Convert to tokens
        text_tokens = self.tokenizer.encode(transcript)
        
        return {
            'audio_features': audio_features,
            'text_tokens': text_tokens,
            'transcript': transcript
        }

class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(self, x: torch.Tensor, kv: torch.Tensor = None,
                mask: torch.Tensor = None) -> torch.Tensor:
        if kv is None:
            kv = x

        q = self.query(x)
        k = self.key(kv)
        v = self.value(kv)

        head_dim = q.size(-1) // self.n_head
        q = q.view(*q.shape[:-1], self.n_head, head_dim).transpose(-3, -2)
        k = k.view(*k.shape[:-1], self.n_head, head_dim).transpose(-3, -2)
        v = v.view(*v.shape[:-1], self.n_head, head_dim).transpose(-3, -2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(-3, -2).contiguous()
        out = out.view(*out.shape[:-2], -1)
        
        return self.out(out)

class TransformerBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()
        self.cross_attention = cross_attention
        self.attn = MultiHeadAttention(n_state, n_head)
        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.ln_cross = nn.LayerNorm(n_state)
        self.ln1 = nn.LayerNorm(n_state)
        self.ln2 = nn.LayerNorm(n_state)
        self.mlp = nn.Sequential(
            nn.Linear(n_state, 4 * n_state),
            nn.GELU(),
            nn.Linear(4 * n_state, n_state)
        )

    def forward(self, x: torch.Tensor, encoder_out: torch.Tensor = None,
                mask: torch.Tensor = None) -> torch.Tensor:
        x = x + self.attn(self.ln1(x), mask=mask)
        if self.cross_attention and encoder_out is not None:
            x = x + self.cross_attn(self.ln_cross(x), encoder_out)
        x = x + self.mlp(self.ln2(x))
        return x


In [16]:
class WhisperEncoder(nn.Module):
    def __init__(self, n_mels: int = 80, n_ctx: int = 1500, 
                 n_state: int = 512, n_head: int = 8, n_layer: int = 6):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, 3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, 3, stride=2, padding=1)
        self.gelu = nn.GELU()
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        self.blocks = nn.ModuleList([
            TransformerBlock(n_state, n_head) for _ in range(n_layer)
        ])
        self.ln = nn.LayerNorm(n_state)
        
        # Initialize weights
        torch.nn.init.normal_(self.positional_embedding, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.gelu(x)
        x = self.conv2(x)
        x = self.gelu(x)

        x = x.transpose(1, 2)
        x = x + self.positional_embedding[:x.shape[1], :]

        for block in self.blocks:
            x = block(x)

        x = self.ln(x)
        return x

class WhisperDecoder(nn.Module):
    def __init__(self, n_vocab: int = 51865, n_ctx: int = 448,
                 n_state: int = 512, n_head: int = 8, n_layer: int = 6):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        self.blocks = nn.ModuleList([
            TransformerBlock(n_state, n_head, cross_attention=True)
            for _ in range(n_layer)
        ])
        self.ln = nn.LayerNorm(n_state)
        self.fc = nn.Linear(n_state, n_vocab, bias=False)
        
        # Tie weights with embedding
        self.fc.weight = self.token_embedding.weight
        
        # Initialize weights
        torch.nn.init.normal_(self.positional_embedding, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor, encoder_out: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        # Ensure input sequence length doesn't exceed positional embedding size
        x = self.token_embedding(x[:, :self.positional_embedding.size(0)])
        x = x + self.positional_embedding[:x.shape[1], :]

        for block in self.blocks:
            x = block(x, encoder_out, mask)

        x = self.ln(x)
        x = self.fc(x)
        return x

class WhisperModel(nn.Module):
    def __init__(self, n_mels: int = 80, n_vocab: int = 51865,
                 n_state: int = 512, n_head: int = 8, n_layer: int = 6):
        super().__init__()
        self.encoder = WhisperEncoder(n_mels, n_state=n_state,
                                    n_head=n_head, n_layer=n_layer)
        self.decoder = WhisperDecoder(n_vocab, n_state=n_state,
                                    n_head=n_head, n_layer=n_layer)
        
        # Task tokens are handled by the tokenizer
        self.tokenizer = SimpleTokenizer(n_vocab)
        
        # Initialize task embeddings
        self.task_tokens = nn.Parameter(torch.randn(len(self.tokenizer.task_tokens), n_state))

    def forward(self, audio_features: torch.Tensor, text_tokens: torch.Tensor,
                task_type: str = 'transcribe') -> torch.Tensor:
        encoder_out = self.encoder(audio_features)
        
        # Get task token embedding
        task_idx = self.tokenizer.task_tokens[task_type]
        task_embedding = self.task_tokens[task_idx - (self.tokenizer.vocab_size - len(self.tokenizer.task_tokens))]
        
        # Create decoder input sequence with task token
        batch_size = text_tokens.size(0)
        task_emb = task_embedding.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1)
        
        # Get text embeddings
        text_emb = self.decoder.token_embedding(text_tokens)
        
        # Combine task and text embeddings
        decoder_input = torch.cat([task_emb, text_emb], dim=1)
        
        # Add positional embeddings and decode
        decoder_input = decoder_input + self.decoder.positional_embedding[:decoder_input.size(1), :]
        
        # Run through decoder blocks
        for block in self.decoder.blocks:
            decoder_input = block(decoder_input, encoder_out)
            
        decoder_input = self.decoder.ln(decoder_input)
        logits = self.decoder.fc(decoder_input)
        
        return logits

class WhisperTrainer:
    def __init__(self, model: WhisperModel, learning_rate: float = 1e-4):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=100000, eta_min=learning_rate/10
        )
    
    def prepare_targets(self, text_tokens: torch.Tensor, task_type: str) -> torch.Tensor:
        """Prepare target sequence including task token"""
        batch_size = text_tokens.size(0)
        device = text_tokens.device
        
        # Get task token
        task_token = self.model.tokenizer.task_tokens[task_type]
        
        # Create targets with task token at the beginning
        targets = torch.cat([
            torch.full((batch_size, 1), task_token, device=device),
            text_tokens
        ], dim=1)
        
        return targets
        
    def train_step(self, batch: dict) -> float:
        self.optimizer.zero_grad()
        
        # Move batch to same device as model
        device = next(self.model.parameters()).device
        audio_features = batch['audio_features'].to(device)
        text_tokens = batch['text_tokens'].to(device)
        
        # Prepare targets with task token
        targets = self.prepare_targets(text_tokens, 'transcribe')
        
        # Forward pass
        logits = self.model(audio_features, text_tokens, 'transcribe')
        
        # Calculate loss including task token prediction
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=0  # Only ignore padding tokens
        )
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        
        return loss.item()

def collate_fn(batch):
    """Custom collate function to handle variable length sequences"""
    
    # Find max lengths in the batch
    max_audio_len = max(b['audio_features'].shape[1] for b in batch)
    max_text_len = min(447, max(b['text_tokens'].shape[0] for b in batch))  # 447 to leave room for task token
    
    # Initialize tensors
    batch_size = len(batch)
    audio_features = torch.zeros(batch_size, 80, max_audio_len)
    text_tokens = torch.zeros(batch_size, max_text_len).long()
    
    # Fill in the tensors with padded data
    for i, sample in enumerate(batch):
        # Audio features
        audio = sample['audio_features']
        audio_len = audio.shape[1]
        audio_features[i, :, :audio_len] = audio
        
        # Text tokens (truncate if needed)
        text = sample['text_tokens'][:max_text_len]
        text_len = text.shape[0]
        text_tokens[i, :text_len] = text
    
    return {
        'audio_features': audio_features,
        'text_tokens': text_tokens,
        'transcript': [b['transcript'] for b in batch]
    }
import os
import tempfile
from pathlib import Path

def create_checkpoint_dir():
    """Create checkpoint directory in temporary directory"""
    # Use system's temp directory
    temp_dir = tempfile.gettempdir()
    checkpoint_dir = Path(temp_dir) / "whisper_checkpoints"
    checkpoint_dir.mkdir(exist_ok=True)
    return checkpoint_dir

def main():
    print("Initializing Whisper training...")
    
    # Create checkpoint directory
    checkpoint_dir = create_checkpoint_dir()
    print(f"Saving checkpoints to: {checkpoint_dir}")
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize model
    model = WhisperModel(
        n_mels=80,
        n_vocab=51865,
        n_state=512,
        n_head=8,
        n_layer=6
    ).to(device)
    
    # Create dataset and dataloader
    dataset = TinyLibriSpeechDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0
    )
    
    # Initialize trainer
    trainer = WhisperTrainer(model)
    
    # Training loop
    num_epochs = 10
    best_loss = float('inf')
    
    print("Starting training...")
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
        
        for batch_idx, batch in enumerate(progress_bar):
            try:
                loss = trainer.train_step(batch)
                epoch_loss += loss
                
                progress_bar.set_postfix({'loss': f'{loss:.4f}'})
                
                # Save checkpoint if loss improved
                if batch_idx % 10 == 0:
                    avg_loss = epoch_loss / (batch_idx + 1)
                    if avg_loss < best_loss:
                        best_loss = avg_loss
                        checkpoint_path = checkpoint_dir / f'whisper_best_epoch_{epoch}_batch_{batch_idx}.pt'
                        print(f"\nSaving best model (loss: {best_loss:.4f}) to {checkpoint_path}")
                        try:
                            torch.save({
                                'epoch': epoch,
                                'batch_idx': batch_idx,
                                'model_state_dict': model.state_dict(),
                                'optimizer_state_dict': trainer.optimizer.state_dict(),
                                'loss': best_loss,
                            }, checkpoint_path)
                        except Exception as e:
                            print(f"Error saving checkpoint: {str(e)}")
                            continue
                            
            except Exception as e:
                print(f"Error in training step: {str(e)}")
                continue
        
        avg_epoch_loss = epoch_loss / len(dataloader)
        print(f'\nEpoch {epoch+1} completed, Average Loss: {avg_epoch_loss:.4f}')
        
        # Save epoch checkpoint
        try:
            checkpoint_path = checkpoint_dir / f'whisper_epoch_{epoch}.pt'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': trainer.optimizer.state_dict(),
                'loss': avg_epoch_loss,
            }, checkpoint_path)
            print(f"Saved epoch checkpoint to {checkpoint_path}")
        except Exception as e:
            print(f"Error saving epoch checkpoint: {str(e)}")
    
    print("\nTraining completed!")
    return model

if __name__ == "__main__":
    try:
        # Set longer timeout for downloads
        import datasets.config as config
        config.HF_DATASETS_HTTP_TIMEOUT = 1000
        
        main()
    except Exception as e:
        print(f"Error during execution: {str(e)}")

Initializing Whisper training...
Saving checkpoints to: /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints
Using device: cpu
Loading Tiny LibriSpeech dataset...
Dataset loaded with 73 samples
Starting training...


Epoch 1/10:   0%|          | 0/19 [00:05<?, ?it/s, loss=0.6520]


Saving best model (loss: 0.6520) to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_0_batch_0.pt


Epoch 1/10:  53%|█████▎    | 10/19 [00:30<00:20,  2.28s/it, loss=0.0318]


Saving best model (loss: 0.1465) to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_0_batch_10.pt


Epoch 1/10: 100%|██████████| 19/19 [00:52<00:00,  2.78s/it, loss=0.0000]



Epoch 1 completed, Average Loss: 0.0848
Saved epoch checkpoint to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_epoch_0.pt


Epoch 2/10:  53%|█████▎    | 10/19 [00:24<00:17,  1.97s/it, loss=0.0000]


Saving best model (loss: 0.0410) to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_1_batch_10.pt


Epoch 2/10: 100%|██████████| 19/19 [00:40<00:00,  2.13s/it, loss=0.0000]



Epoch 2 completed, Average Loss: 0.0238
Saved epoch checkpoint to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_epoch_1.pt


Epoch 3/10:   0%|          | 0/19 [00:02<?, ?it/s, loss=0.0043]


Saving best model (loss: 0.0043) to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_2_batch_0.pt


Epoch 3/10:  53%|█████▎    | 10/19 [00:24<00:16,  1.84s/it, loss=0.0000]


Saving best model (loss: 0.0004) to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_2_batch_10.pt


Epoch 3/10: 100%|██████████| 19/19 [00:44<00:00,  2.35s/it, loss=0.0000]



Epoch 3 completed, Average Loss: 0.0002
Saved epoch checkpoint to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_epoch_2.pt


Epoch 4/10:   0%|          | 0/19 [00:01<?, ?it/s, loss=0.0000]


Saving best model (loss: 0.0000) to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_3_batch_0.pt


Epoch 4/10: 100%|██████████| 19/19 [00:38<00:00,  2.04s/it, loss=0.0000]



Epoch 4 completed, Average Loss: 0.0052
Saved epoch checkpoint to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_epoch_3.pt


Epoch 5/10:   0%|          | 0/19 [00:01<?, ?it/s, loss=0.0000]


Saving best model (loss: 0.0000) to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_4_batch_0.pt


Epoch 5/10: 100%|██████████| 19/19 [00:37<00:00,  2.00s/it, loss=0.0000]



Epoch 5 completed, Average Loss: 0.0000
Saved epoch checkpoint to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_epoch_4.pt


Epoch 6/10: 100%|██████████| 19/19 [00:39<00:00,  2.06s/it, loss=0.0000]



Epoch 6 completed, Average Loss: 0.0000
Saved epoch checkpoint to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_epoch_5.pt


Epoch 7/10: 100%|██████████| 19/19 [00:37<00:00,  1.98s/it, loss=0.0000]



Epoch 7 completed, Average Loss: 0.0000
Saved epoch checkpoint to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_epoch_6.pt


Epoch 8/10: 100%|██████████| 19/19 [00:38<00:00,  2.04s/it, loss=0.0000]



Epoch 8 completed, Average Loss: 0.0000
Saved epoch checkpoint to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_epoch_7.pt


Epoch 9/10: 100%|██████████| 19/19 [00:40<00:00,  2.12s/it, loss=0.0000]



Epoch 9 completed, Average Loss: 0.0000
Saved epoch checkpoint to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_epoch_8.pt


Epoch 10/10: 100%|██████████| 19/19 [00:39<00:00,  2.06s/it, loss=0.0000]



Epoch 10 completed, Average Loss: 0.0000
Saved epoch checkpoint to /var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_epoch_9.pt

Training completed!


In [None]:
/private/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_4_batch_0.pt

Starting Whisper model testing...
Loading test dataset...


KeyboardInterrupt: 

In [20]:

def test_model_features():
    """Test different features of the model"""
    checkpoint_path = "/private/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_4_batch_0.pt"
    tester = WhisperTester(checkpoint_path)
    
    print("Testing model features...")
    
    # Load a single test sample
    dataset = load_dataset("facebook/voxpopuli", "en", split="test[:1]")
    sample = dataset[0]
    audio_features = tester.process_audio(sample['audio']['array'])
    
    print("\n1. Basic Transcription Test")
    transcription = tester.transcribe(audio_features)
    print(f"Original: {sample['text']}")
    print(f"Predicted: {transcription}")
    
    print("\n2. Testing Attention Patterns")
    with torch.no_grad():
        # Get encoder-decoder attention weights
        audio_features = audio_features.unsqueeze(0).to(tester.device)
        encoder_out = tester.model.encoder(audio_features)
        
        # Get initial decoder input
        decoder_input = torch.tensor([[
            tester.model.tokenizer.task_tokens['transcribe'],
            tester.model.tokenizer.bos_token
        ]]).to(tester.device)
        
        # Get attention patterns from first decoder block
        first_block = tester.model.decoder.blocks[0]
        _ = first_block(
            tester.model.decoder.token_embedding(decoder_input),
            encoder_out
        )
        
        print("Encoder output shape:", encoder_out.shape)
        print("Decoder input shape:", decoder_input.shape)
    
    print("\n3. Testing Task Token Impact")
    # Test transcription with different task tokens
    tasks = ['transcribe', 'translate', 'language_id']
    for task in tasks:
        print(f"\nTesting {task} task:")
        try:
            with torch.no_grad():
                # Prepare input with different task token
                decoder_input = torch.tensor([[
                    tester.model.tokenizer.task_tokens[task],
                    tester.model.tokenizer.bos_token
                ]]).to(tester.device)
                
                # Get initial output
                logits = tester.model.decoder(decoder_input, encoder_out)
                print(f"Output logits shape: {logits.shape}")
                
        except Exception as e:
            print(f"Error testing {task}: {str(e)}")
    
    print("\nFeature testing completed!")

if __name__ == "__main__":
    # Run basic tests
    print("Running basic transcription tests...")
    test_model()
    
    # Run feature tests
    print("\nRunning feature tests...")
    test_model_features()


Running basic transcription tests...
Starting Whisper model testing...
Loading test dataset...


KeyboardInterrupt: 

In [19]:

import torch
import torchaudio
from datasets import load_dataset
from pathlib import Path
import numpy as np
from tqdm import tqdm

class WhisperTester:
    def __init__(self, checkpoint_path, device=None):
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
            
        print(f"Using device: {self.device}")
        
        self.model = self.load_model(checkpoint_path)
        self.model.eval()
        
        self.sample_rate = 16000
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=80,
            n_fft=2048,
            hop_length=160,
            win_length=400
        )
    
    def load_model(self, checkpoint_path):
        print(f"Loading model from {checkpoint_path}")
        model = WhisperModel(
            n_mels=80,
            n_vocab=51865,
            n_state=512,
            n_head=8,
            n_layer=6
        ).to(self.device)
        
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded model from epoch {checkpoint['epoch']}, loss: {checkpoint['loss']:.4f}")
        return model
    
    def process_audio(self, audio: np.ndarray) -> torch.Tensor:
        waveform = torch.from_numpy(audio).float()
        if len(waveform.shape) > 1:
            waveform = waveform.mean(dim=0)
        mel_spec = self.mel_transform(waveform)
        mel_spec = torch.log(mel_spec + 1e-9)
        return mel_spec
    
    def transcribe(self, audio_features: torch.Tensor) -> str:
        self.model.eval()
        with torch.no_grad():
            # Ensure audio features have batch dimension
            if len(audio_features.shape) == 2:
                audio_features = audio_features.unsqueeze(0)
            audio_features = audio_features.to(self.device)
            
            # Get task token embedding
            task_idx = self.model.tokenizer.task_tokens['transcribe']
            task_embedding = self.model.task_tokens[task_idx - (self.model.tokenizer.vocab_size - 
                                                              len(self.model.tokenizer.task_tokens))]
            
            # Initialize decoder input with task token
            batch_size = audio_features.size(0)
            task_emb = task_embedding.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1)
            
            # Encode audio
            encoder_out = self.model.encoder(audio_features)
            
            # Initialize decoder sequence
            decoder_input = self.model.decoder.token_embedding(
                torch.tensor([[self.model.tokenizer.bos_token]], device=self.device)
            )
            decoder_input = torch.cat([task_emb, decoder_input], dim=1)
            
            # Generate sequence
            generated_tokens = []
            max_length = 100
            
            for _ in range(max_length):
                # Add positional embeddings
                pos_emb = self.model.decoder.positional_embedding[:decoder_input.size(1), :].unsqueeze(0)
                decoder_input = decoder_input + pos_emb
                
                # Run through decoder blocks
                x = decoder_input
                for block in self.model.decoder.blocks:
                    x = block(x, encoder_out)
                
                # Get next token
                x = self.model.decoder.ln(x)
                logits = self.model.decoder.fc(x)
                next_token = torch.argmax(logits[:, -1, :], dim=-1)
                
                if next_token.item() == self.model.tokenizer.eos_token:
                    break
                
                generated_tokens.append(next_token.item())
                
                # Prepare next iteration
                next_embedding = self.model.decoder.token_embedding(next_token).unsqueeze(1)
                decoder_input = torch.cat([decoder_input, next_embedding], dim=1)
            
            return self.model.tokenizer.decode(torch.tensor(generated_tokens))

def quick_test():
    print("Starting quick Whisper model test...")
    
    print("Loading minimal test dataset...")
    dataset = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation[:5]"
    )
    
    checkpoint_path = "/private/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_4_batch_0.pt"
    tester = WhisperTester(checkpoint_path)
    
    print("\nTesting transcription on samples:")
    for idx, sample in enumerate(dataset):
        print(f"\nSample {idx + 1}/5:")
        try:
            audio_features = tester.process_audio(sample['audio']['array'])
            transcription = tester.transcribe(audio_features)
            
            print(f"Original: {sample['text']}")
            print(f"Predicted: {transcription}")
            
            original_words = set(sample['text'].lower().split())
            predicted_words = set(transcription.lower().split())
            common_words = original_words.intersection(predicted_words)
            wer = 1 - (len(common_words) / len(original_words))
            print(f"Word Error Rate: {wer:.2%}")
            
        except Exception as e:
            print(f"Error processing sample {idx}: {str(e)}")
            import traceback
            traceback.print_exc()
        
        print("-" * 50)
    
    print("\nQuick test completed!")

if __name__ == "__main__":
    quick_test()


Starting quick Whisper model test...
Loading minimal test dataset...
Using device: cpu
Loading model from /private/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_4_batch_0.pt
Loaded model from epoch 4, loss: 0.0000

Testing transcription on samples:

Sample 1/5:
Original: A MAN SAID TO THE UNIVERSE SIR I EXIST
Predicted: 
Word Error Rate: 100.00%
--------------------------------------------------

Sample 2/5:
Original: SWEAT COVERED BRION'S BODY TRICKLING INTO THE TIGHT LOINCLOTH THAT WAS THE ONLY GARMENT HE WORE
Predicted: 
Word Error Rate: 100.00%
--------------------------------------------------

Sample 3/5:
Original: THE CUT ON HIS CHEST STILL DRIPPING BLOOD THE ACHE OF HIS OVERSTRAINED EYES EVEN THE SOARING ARENA AROUND HIM WITH THE THOUSANDS OF SPECTATORS WERE TRIVIALITIES NOT WORTH THINKING ABOUT
Predicted: 
Word Error Rate: 100.00%
--------------------------------------------------

Sample 4/5:
Original: HIS INSTANT OF PANIC WAS FOLLOWED

In [22]:

def test_model_features():
    """Quick test of different model features on small dataset"""
    print("Testing model features...")
    
    # Load checkpoint
    checkpoint_path = "/private/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_4_batch_0.pt"
    tester = WhisperTester(checkpoint_path)
    
    # Load small test dataset
    print("\nLoading small test dataset...")
    dataset = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation[:10]"  # Use 10 samples
    )
    
    print("\n1. Testing Basic Transcription")
    for idx, sample in enumerate(dataset):
        print(f"\nSample {idx + 1}/10:")
        try:
            # Process audio
            audio_features = tester.process_audio(sample['audio']['array'])
            
            # Test transcription
            transcription = tester.transcribe(audio_features)
            print(f"Original: {sample['text']}")
            print(f"Predicted: {transcription}")
            
            # Simple WER
            original_words = set(sample['text'].lower().split())
            predicted_words = set(transcription.lower().split())
            common_words = original_words.intersection(predicted_words)
            wer = 1 - (len(common_words) / len(original_words))
            print(f"WER: {wer:.2%}")
            
            # Test model internals on first sample only
            if idx == 0:
                print("\n2. Testing Model Architecture")
                with torch.no_grad():
                    # Check encoder output
                    audio_features = audio_features.unsqueeze(0).to(tester.device)
                    encoder_out = tester.model.encoder(audio_features)
                    print(f"Encoder output shape: {encoder_out.shape}")
                    
                    # Check task token embeddings
                    task_idx = tester.model.tokenizer.task_tokens['transcribe']
                    task_emb = tester.model.task_tokens[task_idx - (tester.model.tokenizer.vocab_size - 
                                                                   len(tester.model.tokenizer.task_tokens))]
                    print(f"Task embedding shape: {task_emb.shape}")
                    
                    # Check attention patterns
                    print("\n3. Testing Attention Patterns")
                    first_block = tester.model.decoder.blocks[0]
                    decoder_input = tester.model.decoder.token_embedding(
                        torch.tensor([[tester.model.tokenizer.bos_token]], device=tester.device)
                    )
                    decoder_out = first_block(decoder_input, encoder_out)
                    print(f"Decoder block output shape: {decoder_out.shape}")
                    
                    # Test different tasks
                    print("\n4. Testing Different Tasks")
                    tasks = ['transcribe', 'translate', 'language_id']
                    for task in tasks:
                        print(f"\nTesting {task} task:")
                        try:
                            task_token = torch.tensor([[
                                tester.model.tokenizer.task_tokens[task]
                            ]], device=tester.device)
                            task_emb = tester.model.decoder.token_embedding(task_token)
                            out = first_block(task_emb, encoder_out)
                            print(f"Task {task} output shape: {out.shape}")
                        except Exception as e:
                            print(f"Error testing {task}: {str(e)}")
                
        except Exception as e:
            print(f"Error processing sample {idx}: {str(e)}")
        print("-" * 50)
    
    print("\nFeature testing completed!")

def main():
    try:
        print("Starting quick Whisper model test...")
        test_model_features()
    except Exception as e:
        print(f"Error during testing: {str(e)}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()


Starting quick Whisper model test...
Testing model features...
Using device: cpu
Loading model from /private/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_4_batch_0.pt
Loaded model from epoch 4, loss: 0.0000

Loading small test dataset...

1. Testing Basic Transcription

Sample 1/10:
Original: A MAN SAID TO THE UNIVERSE SIR I EXIST
Predicted: 
WER: 100.00%

2. Testing Model Architecture
Encoder output shape: torch.Size([1, 233, 512])
Task embedding shape: torch.Size([512])

3. Testing Attention Patterns
Decoder block output shape: torch.Size([1, 1, 512])

4. Testing Different Tasks

Testing transcribe task:
Task transcribe output shape: torch.Size([1, 1, 512])

Testing translate task:
Task translate output shape: torch.Size([1, 1, 512])

Testing language_id task:
Task language_id output shape: torch.Size([1, 1, 512])
--------------------------------------------------

Sample 2/10:
Original: SWEAT COVERED BRION'S BODY TRICKLING INTO THE TIGHT LOI

In [23]:

def test_multitask_features():
    """Test different tasks and features of the model"""
    print("Testing model features and multi-task capabilities...")
    
    checkpoint_path = "/private/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_4_batch_0.pt"
    tester = WhisperTester(checkpoint_path)
    
    # Load test dataset
    print("\nLoading test dataset...")
    dataset = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation[:10]"
    )
    
    # Test different tasks on each sample
    tasks = ['transcribe', 'translate', 'language_id']
    
    for idx, sample in enumerate(dataset):
        print(f"\nSample {idx + 1}/10:")
        print(f"Original text: {sample['text']}")
        
        try:
            # Process audio
            audio_features = tester.process_audio(sample['audio']['array'])
            
            # Test each task
            for task in tasks:
                print(f"\nTesting {task.upper()} task:")
                try:
                    with torch.no_grad():
                        # Prepare input
                        audio_features_batch = audio_features.unsqueeze(0).to(tester.device)
                        encoder_out = tester.model.encoder(audio_features_batch)
                        
                        # Get task token embedding
                        task_idx = tester.model.tokenizer.task_tokens[task]
                        task_embedding = tester.model.task_tokens[task_idx - (tester.model.tokenizer.vocab_size - 
                                                                           len(tester.model.tokenizer.task_tokens))]
                        
                        # Initialize decoder with task token
                        batch_size = 1
                        task_emb = task_embedding.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1)
                        
                        # Initialize decoder sequence
                        decoder_input = tester.model.decoder.token_embedding(
                            torch.tensor([[tester.model.tokenizer.bos_token]], device=tester.device)
                        )
                        decoder_input = torch.cat([task_emb, decoder_input], dim=1)
                        
                        # Generate sequence
                        generated_tokens = []
                        max_length = 100
                        
                        for _ in range(max_length):
                            # Add positional embeddings
                            pos_emb = tester.model.decoder.positional_embedding[:decoder_input.size(1), :].unsqueeze(0)
                            current_input = decoder_input + pos_emb
                            
                            # Run through decoder
                            x = current_input
                            for block in tester.model.decoder.blocks:
                                x = block(x, encoder_out)
                            
                            x = tester.model.decoder.ln(x)
                            logits = tester.model.decoder.fc(x)
                            next_token = torch.argmax(logits[:, -1, :], dim=-1)
                            
                            if next_token.item() == tester.model.tokenizer.eos_token:
                                break
                            
                            generated_tokens.append(next_token.item())
                            next_embedding = tester.model.decoder.token_embedding(next_token).unsqueeze(1)
                            decoder_input = torch.cat([decoder_input, next_embedding], dim=1)
                        
                        # Decode output based on task
                        output = tester.model.tokenizer.decode(torch.tensor(generated_tokens))
                        
                        if task == 'transcribe':
                            print(f"Transcription: {output}")
                            # Calculate WER
                            original_words = set(sample['text'].lower().split())
                            predicted_words = set(output.lower().split())
                            common_words = original_words.intersection(predicted_words)
                            wer = 1 - (len(common_words) / len(original_words))
                            print(f"WER: {wer:.2%}")
                            
                        elif task == 'translate':
                            print(f"Translation: {output}")
                            
                        elif task == 'language_id':
                            print(f"Detected language: {output}")
                            
                        # Show attention visualization for first sample
                        if idx == 0:
                            print(f"Output sequence length: {len(generated_tokens)}")
                            print(f"Encoder attention shape: {encoder_out.shape}")
                            print(f"Decoder attention shape: {x.shape}")
                            
                except Exception as e:
                    print(f"Error in {task} task: {str(e)}")
                    import traceback
                    traceback.print_exc()
            
        except Exception as e:
            print(f"Error processing sample {idx}: {str(e)}")
            import traceback
            traceback.print_exc()
        
        print("-" * 50)
    
    print("\nMulti-task testing completed!")

def main():
    try:
        print("Starting Whisper multi-task test...")
        test_multitask_features()
    except Exception as e:
        print(f"Error during testing: {str(e)}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()


Starting Whisper multi-task test...
Testing model features and multi-task capabilities...
Using device: cpu
Loading model from /private/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_4_batch_0.pt
Loaded model from epoch 4, loss: 0.0000

Loading test dataset...

Sample 1/10:
Original text: A MAN SAID TO THE UNIVERSE SIR I EXIST

Testing TRANSCRIBE task:
Transcription: 
WER: 100.00%
Output sequence length: 100
Encoder attention shape: torch.Size([1, 233, 512])
Decoder attention shape: torch.Size([1, 101, 512])

Testing TRANSLATE task:
Translation: 
Output sequence length: 100
Encoder attention shape: torch.Size([1, 233, 512])
Decoder attention shape: torch.Size([1, 101, 512])

Testing LANGUAGE_ID task:
Detected language: 
Output sequence length: 100
Encoder attention shape: torch.Size([1, 233, 512])
Decoder attention shape: torch.Size([1, 101, 512])
--------------------------------------------------

Sample 2/10:
Original text: SWEAT COVERED BRION'

KeyboardInterrupt: 

In [24]:

def debug_model_output(tester, audio_features, task='transcribe'):
    """Debug function to inspect model's internal states"""
    print(f"\nDebugging {task} task:")
    
    with torch.no_grad():
        # 1. Check audio features
        audio_features = audio_features.unsqueeze(0).to(tester.device)
        print(f"Audio features shape: {audio_features.shape}")
        print(f"Audio features range: [{audio_features.min():.2f}, {audio_features.max():.2f}]")
        
        # 2. Check encoder output
        encoder_out = tester.model.encoder(audio_features)
        print(f"Encoder output shape: {encoder_out.shape}")
        print(f"Encoder output range: [{encoder_out.min():.2f}, {encoder_out.max():.2f}]")
        
        # 3. Check task token
        task_idx = tester.model.tokenizer.task_tokens[task]
        print(f"Task token index: {task_idx}")
        
        # 4. Check task embedding
        task_embedding = tester.model.task_tokens[task_idx - (tester.model.tokenizer.vocab_size - 
                                                           len(tester.model.tokenizer.task_tokens))]
        print(f"Task embedding shape: {task_embedding.shape}")
        print(f"Task embedding range: [{task_embedding.min():.2f}, {task_embedding.max():.2f}]")
        
        # 5. Check decoder input
        decoder_input = torch.tensor([[
            task_idx,
            tester.model.tokenizer.bos_token
        ]], device=tester.device)
        print(f"Initial decoder input: {decoder_input}")
        
        # 6. Check decoder embedding
        decoder_emb = tester.model.decoder.token_embedding(decoder_input)
        print(f"Decoder embedding shape: {decoder_emb.shape}")
        print(f"Decoder embedding range: [{decoder_emb.min():.2f}, {decoder_emb.max():.2f}]")
        
        # 7. Check decoder output
        logits = tester.model.decoder(decoder_input, encoder_out)
        print(f"Decoder logits shape: {logits.shape}")
        print(f"Logits range: [{logits.min():.2f}, {logits.max():.2f}]")
        
        # 8. Check top predictions
        top_probs, top_tokens = torch.topk(logits[0, -1], k=5)
        print("\nTop 5 predictions:")
        for prob, token in zip(top_probs, top_tokens):
            if token.item() in tester.model.tokenizer.idx_to_char:
                char = tester.model.tokenizer.idx_to_char[token.item()]
                print(f"Token {token.item()}: '{char}' (prob: {torch.softmax(prob, dim=0):.4f})")
            else:
                print(f"Token {token.item()} (prob: {torch.softmax(prob, dim=0):.4f})")
        
        return encoder_out, logits

def test_model():
    print("Starting model debugging...")
    
    checkpoint_path = "/private/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_4_batch_0.pt"
    tester = WhisperTester(checkpoint_path)
    
    # Load a single sample for debugging
    dataset = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation[:1]"
    )
    
    sample = dataset[0]
    print(f"\nOriginal text: {sample['text']}")
    
    # Process audio
    audio_features = tester.process_audio(sample['audio']['array'])
    
    # Debug each task
    for task in ['transcribe', 'translate', 'language_id']:
        encoder_out, logits = debug_model_output(tester, audio_features, task)
        
        print(f"\nAttempting {task} generation:")
        try:
            # Try to generate a few tokens
            current_input = torch.tensor([[
                tester.model.tokenizer.task_tokens[task],
                tester.model.tokenizer.bos_token
            ]], device=tester.device)
            
            generated_tokens = []
            for _ in range(10):  # Try to generate first 10 tokens
                logits = tester.model.decoder(current_input, encoder_out)
                next_token = torch.argmax(logits[:, -1, :], dim=-1)
                
                if next_token.item() == tester.model.tokenizer.eos_token:
                    break
                    
                generated_tokens.append(next_token.item())
                current_input = torch.cat([
                    current_input,
                    next_token.unsqueeze(0).unsqueeze(0)
                ], dim=1)
                
                # Print each generated token
                if next_token.item() in tester.model.tokenizer.idx_to_char:
                    print(f"Generated token: {next_token.item()} -> '{tester.model.tokenizer.idx_to_char[next_token.item()]}'")
                else:
                    print(f"Generated token: {next_token.item()} (special token)")
            
        except Exception as e:
            print(f"Error in generation: {str(e)}")
            import traceback
            traceback.print_exc()
            
        print("-" * 50)

if __name__ == "__main__":
    test_model()


Starting model debugging...
Using device: cpu
Loading model from /private/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_4_batch_0.pt
Loaded model from epoch 4, loss: 0.0000

Original text: A MAN SAID TO THE UNIVERSE SIR I EXIST

Debugging transcribe task:
Audio features shape: torch.Size([1, 80, 466])
Audio features range: [-13.49, 8.98]
Encoder output shape: torch.Size([1, 233, 512])
Encoder output range: [-3.29, 4.09]
Task token index: 51860
Task embedding shape: torch.Size([512])
Task embedding range: [-2.69, 2.90]
Initial decoder input: tensor([[51860,     2]])
Decoder embedding shape: torch.Size([1, 2, 512])
Decoder embedding range: [-3.29, 3.10]
Decoder logits shape: torch.Size([1, 2, 51865])
Logits range: [-107.76, 330.92]

Top 5 predictions:
Token 2 (prob: 1.0000)
Token 28525 (prob: 1.0000)
Token 25907 (prob: 1.0000)
Token 20667 (prob: 1.0000)
Token 50180 (prob: 1.0000)

Attempting transcribe generation:
Error in generation: Tensors must

Traceback (most recent call last):
  File "/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/ipykernel_49321/2688510538.py", line 95, in test_model
    current_input = torch.cat([
                    ^^^^^^^^^^^
RuntimeError: Tensors must have same number of dimensions: got 2 and 3
Traceback (most recent call last):
  File "/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/ipykernel_49321/2688510538.py", line 95, in test_model
    current_input = torch.cat([
                    ^^^^^^^^^^^
RuntimeError: Tensors must have same number of dimensions: got 2 and 3


Error in generation: Tensors must have same number of dimensions: got 2 and 3
--------------------------------------------------

Debugging language_id task:
Audio features shape: torch.Size([1, 80, 466])
Audio features range: [-13.49, 8.98]
Encoder output shape: torch.Size([1, 233, 512])
Encoder output range: [-3.29, 4.09]
Task token index: 51862
Task embedding shape: torch.Size([512])
Task embedding range: [-2.77, 3.63]
Initial decoder input: tensor([[51862,     2]])
Decoder embedding shape: torch.Size([1, 2, 512])
Decoder embedding range: [-3.08, 3.51]
Decoder logits shape: torch.Size([1, 2, 51865])
Logits range: [-99.45, 325.62]

Top 5 predictions:
Token 2 (prob: 1.0000)
Token 21385 (prob: 1.0000)
Token 28525 (prob: 1.0000)
Token 50180 (prob: 1.0000)
Token 1565 (prob: 1.0000)

Attempting language_id generation:
Error in generation: Tensors must have same number of dimensions: got 2 and 3
--------------------------------------------------


Traceback (most recent call last):
  File "/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/ipykernel_49321/2688510538.py", line 95, in test_model
    current_input = torch.cat([
                    ^^^^^^^^^^^
RuntimeError: Tensors must have same number of dimensions: got 2 and 3


In [25]:

def debug_model_output(tester, audio_features, task='transcribe'):
    """Debug function with weight verification"""
    print(f"\nDebugging {task} task:")
    
    with torch.no_grad():
        # Verify model weights
        print("\nChecking model weights:")
        encoder_weights = sum(p.sum() for p in tester.model.encoder.parameters())
        decoder_weights = sum(p.sum() for p in tester.model.decoder.parameters())
        print(f"Encoder weights sum: {encoder_weights:.4f}")
        print(f"Decoder weights sum: {decoder_weights:.4f}")
        
        # 1. Check audio features
        audio_features = audio_features.unsqueeze(0).to(tester.device)
        print(f"\nAudio features:")
        print(f"Shape: {audio_features.shape}")
        print(f"Range: [{audio_features.min():.4f}, {audio_features.max():.4f}]")
        print(f"Mean: {audio_features.mean():.4f}")
        
        # 2. Check encoder
        encoder_out = tester.model.encoder(audio_features)
        print(f"\nEncoder output:")
        print(f"Shape: {encoder_out.shape}")
        print(f"Range: [{encoder_out.min():.4f}, {encoder_out.max():.4f}]")
        print(f"Mean: {encoder_out.mean():.4f}")
        
        # 3. Verify tokenizer and embeddings
        print(f"\nTokenizer check:")
        vocab_size = len(tester.model.tokenizer.char_to_idx) + 10  # Special tokens
        print(f"Vocabulary size: {vocab_size}")
        print(f"Task token index: {tester.model.tokenizer.task_tokens[task]}")
        
        # 4. Check decoder embeddings
        decoder_input = torch.tensor([[
            tester.model.tokenizer.task_tokens[task],
            tester.model.tokenizer.bos_token
        ]], device=tester.device)
        
        print(f"\nDecoder embeddings:")
        emb = tester.model.decoder.token_embedding(decoder_input)
        print(f"Shape: {emb.shape}")
        print(f"Range: [{emb.min():.4f}, {emb.max():.4f}]")
        print(f"Mean: {emb.mean():.4f}")
        
        # 5. Check attention weights
        for i, block in enumerate(tester.model.decoder.blocks):
            print(f"\nDecoder block {i} attention:")
            if hasattr(block.attn, 'query'):
                q_weights = block.attn.query.weight
                print(f"Query weights range: [{q_weights.min():.4f}, {q_weights.max():.4f}]")
        
        # 6. Final output check
        logits = tester.model.decoder(decoder_input, encoder_out)
        print(f"\nFinal logits:")
        print(f"Shape: {logits.shape}")
        print(f"Range: [{logits.min():.4f}, {logits.max():.4f}]")
        print(f"Mean: {logits.mean():.4f}")
        
        # 7. Token distribution
        probs = F.softmax(logits[0, -1], dim=-1)
        top_probs, top_tokens = torch.topk(probs, k=5)
        print("\nTop 5 token probabilities:")
        for prob, token in zip(top_probs, top_tokens):
            if token.item() in tester.model.tokenizer.idx_to_char:
                char = tester.model.tokenizer.idx_to_char[token.item()]
                print(f"Token {token.item()} ('{char}'): {prob:.4f}")
            else:
                print(f"Token {token.item()}: {prob:.4f}")
        
        return encoder_out, logits

def test_model():
    print("Starting detailed model debugging...")
    
    checkpoint_path = "/private/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_0_batch_0.pt"
    
    # Load checkpoint directly to inspect
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    print("\nCheckpoint contents:")
    for key, value in checkpoint.items():
        if isinstance(value, torch.Tensor):
            print(f"{key}: tensor of shape {value.shape}")
        elif isinstance(value, dict):
            print(f"{key}: dictionary with {len(value)} items")
        else:
            print(f"{key}: {value}")
    
    # Initialize tester
    tester = WhisperTester(checkpoint_path)
    
    # Load test sample
    dataset = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation[:1]"
    )
    
    sample = dataset[0]
    print(f"\nTest sample text: {sample['text']}")
    
    # Process and debug
    audio_features = tester.process_audio(sample['audio']['array'])
    
    for task in ['transcribe']:  # Start with just transcription
        encoder_out, logits = debug_model_output(tester, audio_features, task)
        
        print(f"\nTesting token generation for {task}:")
        current_input = torch.tensor([[
            tester.model.tokenizer.task_tokens[task],
            tester.model.tokenizer.bos_token
        ]], device=tester.device)
        
        print("\nToken generation sequence:")
        for i in range(10):
            with torch.no_grad():
                logits = tester.model.decoder(current_input, encoder_out)
                probs = F.softmax(logits[0, -1], dim=-1)
                
                # Print probability distribution
                print(f"\nStep {i+1}:")
                top_probs, top_tokens = torch.topk(probs, k=3)
                for prob, token in zip(top_probs, top_tokens):
                    if token.item() in tester.model.tokenizer.idx_to_char:
                        char = tester.model.tokenizer.idx_to_char[token.item()]
                        print(f"  Token {token.item()} ('{char}'): {prob:.4f}")
                    else:
                        print(f"  Token {token.item()}: {prob:.4f}")
                
                next_token = torch.argmax(logits[0, -1])
                if next_token.item() == tester.model.tokenizer.eos_token:
                    print("  Generated EOS token, stopping.")
                    break
                
                current_input = torch.cat([
                    current_input,
                    next_token.unsqueeze(0).unsqueeze(0)
                ], dim=1)

if __name__ == "__main__":
    test_model()


Starting detailed model debugging...


  checkpoint = torch.load(checkpoint_path, map_location='cpu')



Checkpoint contents:
epoch: 0
batch_idx: 0
model_state_dict: dictionary with 265 items
optimizer_state_dict: dictionary with 2 items
loss: 0.651986837387085
Using device: cpu
Loading model from /private/var/folders/nt/mw9vwj4s4d341b0fpfl28ykr0000gn/T/whisper_checkpoints/whisper_best_epoch_0_batch_0.pt
Loaded model from epoch 0, loss: 0.6520

Test sample text: A MAN SAID TO THE UNIVERSE SIR I EXIST

Debugging transcribe task:

Checking model weights:
Encoder weights sum: 6553.4707
Decoder weights sum: 8322.0371

Audio features:
Shape: torch.Size([1, 80, 466])
Range: [-13.4864, 8.9791]
Mean: -4.2057

Encoder output:
Shape: torch.Size([1, 233, 512])
Range: [-3.6061, 4.4856]
Mean: -0.0000

Tokenizer check:
Vocabulary size: 43
Task token index: 51860

Decoder embeddings:
Shape: torch.Size([1, 2, 512])
Range: [-3.2855, 3.1024]
Mean: -0.0240

Decoder block 0 attention:
Query weights range: [-0.0443, 0.0443]

Decoder block 1 attention:
Query weights range: [-0.0443, 0.0443]

Decoder block 2 a

In [28]:

import torch
import torchaudio
from datasets import load_dataset
import numpy as np
from tqdm import tqdm



def process_audio(audio_array: np.ndarray) -> torch.Tensor:
    """Process audio array into mel spectrogram"""
    # Convert to tensor
    waveform = torch.from_numpy(audio_array).float()
    if len(waveform.shape) > 1:
        waveform = waveform.mean(dim=0)
        
    # Setup mel spectrogram transform
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,
        n_mels=80,
        n_fft=2048,
        hop_length=160,
        win_length=400
    )
    
    # Compute mel spectrogram
    mel_spec = mel_transform(waveform)
    mel_spec = torch.log(mel_spec + 1e-9)
    
    return mel_spec

def test_multitask_features():
    """Demonstrate how Whisper handles different tasks with the same architecture"""
    print("Testing Whisper's multi-task capabilities...")
    
    # Initialize model
    model = WhisperModel(
        n_mels=80,
        n_vocab=51865,
        n_state=512,
        n_head=8,
        n_layer=6
    )
    
    # Load a single test sample
    dataset = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation[:1]"
    )
    sample = dataset[0]
    print(f"\nTest sample text: {sample['text']}")
    
    # Process audio
    audio_features = process_audio(sample['audio']['array'])
    
    # Test each task's processing pipeline
    tasks = ['transcribe', 'translate', 'language_id']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    for task in tasks:
        print(f"\n=== Testing {task.upper()} task ===")
        
        try:
            with torch.no_grad():
                # 1. Encode audio (shared for all tasks)
                audio_features_batch = audio_features.unsqueeze(0).to(device)
                print(f"Audio features shape: {audio_features_batch.shape}")
                
                encoder_out = model.encoder(audio_features_batch)
                print(f"Encoder output shape: {encoder_out.shape}")
                
                # 2. Task-specific processing
                task_idx = model.tokenizer.task_tokens[task]
                task_embedding = model.task_tokens[task_idx - (model.tokenizer.vocab_size - 
                                                           len(model.tokenizer.task_tokens))]
                print(f"Task token index: {task_idx}")
                print(f"Task embedding shape: {task_embedding.shape}")
                
                # 3. Initialize decoder with task token
                task_emb = task_embedding.unsqueeze(0).unsqueeze(0)
                decoder_input = model.decoder.token_embedding(
                    torch.tensor([[model.tokenizer.bos_token]], device=device)
                )
                decoder_input = torch.cat([task_emb, decoder_input], dim=1)
                print(f"Initial decoder input shape: {decoder_input.shape}")
                
                # 4. Generate sequence with task-specific behavior
                print("\nGeneration process:")
                generated_tokens = []
                for step in range(5):  # Show first 5 steps
                    # Add positional embeddings
                    pos_emb = model.decoder.positional_embedding[:decoder_input.size(1), :].unsqueeze(0)
                    current_input = decoder_input + pos_emb
                    
                    # Run through decoder blocks
                    x = current_input
                    for block_idx, block in enumerate(model.decoder.blocks):
                        x = block(x, encoder_out)
                        if step == 0:  # Show shapes for first step
                            print(f"Block {block_idx + 1} output shape: {x.shape}")
                    
                    # Get next token
                    x = model.decoder.ln(x)
                    logits = model.decoder.fc(x)
                    next_token = torch.argmax(logits[:, -1, :], dim=-1)
                    
                    # Show token info
                    print(f"\nStep {step + 1}:")
                    print(f"Logits shape: {logits.shape}")
                    print(f"Selected token: {next_token.item()}")
                    
                    if next_token.item() == model.tokenizer.eos_token:
                        print("Generated EOS token, stopping.")
                        break
                    
                    generated_tokens.append(next_token.item())
                    next_embedding = model.decoder.token_embedding(next_token).unsqueeze(1)
                    decoder_input = torch.cat([decoder_input, next_embedding], dim=1)
                
                # 5. Task-specific output processing
                output = model.tokenizer.decode(torch.tensor(generated_tokens))
                print(f"\nTask output: {output}")
                
                if task == 'transcribe':
                    print("Calculating WER for transcription:")
                    original_words = set(sample['text'].lower().split())
                    predicted_words = set(output.lower().split())
                    common_words = original_words.intersection(predicted_words)
                    wer = 1 - (len(common_words) / len(original_words))
                    print(f"WER: {wer:.2%}")
                    
        except Exception as e:
            print(f"Error in {task} task: {str(e)}")
            import traceback
            traceback.print_exc()
        
        print("-" * 50)
    
    print("\nMulti-task demonstration completed!")

def main():
    test_multitask_features()

if __name__ == "__main__":
    main()

Testing Whisper's multi-task capabilities...

Test sample text: A MAN SAID TO THE UNIVERSE SIR I EXIST

=== Testing TRANSCRIBE task ===
Audio features shape: torch.Size([1, 80, 466])
Encoder output shape: torch.Size([1, 233, 512])
Task token index: 51860
Task embedding shape: torch.Size([512])
Initial decoder input shape: torch.Size([1, 2, 512])

Generation process:
Block 1 output shape: torch.Size([1, 2, 512])
Block 2 output shape: torch.Size([1, 2, 512])
Block 3 output shape: torch.Size([1, 2, 512])
Block 4 output shape: torch.Size([1, 2, 512])
Block 5 output shape: torch.Size([1, 2, 512])
Block 6 output shape: torch.Size([1, 2, 512])

Step 1:
Logits shape: torch.Size([1, 2, 51865])
Selected token: 2

Step 2:
Logits shape: torch.Size([1, 3, 51865])
Selected token: 2

Step 3:
Logits shape: torch.Size([1, 4, 51865])
Selected token: 2

Step 4:
Logits shape: torch.Size([1, 5, 51865])
Selected token: 2

Step 5:
Logits shape: torch.Size([1, 6, 51865])
Selected token: 2

Task output: 
Calcu

In [30]:
!pip install openai-whisper

[0mCollecting openai-whisper
  Downloading openai-whisper-20240930.tar.gz (800 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m800.5/800.5 kB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting more-itertools (from openai-whisper)
  Downloading more_itertools-10.5.0-py3-none-any.whl.metadata (36 kB)
Downloading more_itertools-10.5.0-py3-none-any.whl (60 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: openai-whisper
  Building wheel for openai-whisper (pyproject.toml) ... [?25ldone
[?25h  Created wheel for openai-whisper: filename=openai_whisper-20240930-py3-none-any.whl size=803319 sha256=cfc687a400cf0267ac3b402e2e2649dc11f01896bf6dd9fdb456e0ab5876d519
  Stor

In [32]:

import torch
import whisper
from datasets import load_dataset
import numpy as np

def test_whisper_multitask():
    """Demonstrate Whisper's multi-task capabilities using pretrained model"""
    print("Testing Whisper's multi-task capabilities using pretrained model...")
    
    # Load pretrained Whisper model
    model = whisper.load_model("small")
    print(f"Loaded Whisper model: {model.dims}")
    
    # Load test sample
    dataset = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation[:1]"
    )
    sample = dataset[0]
    print(f"\nTest sample text: {sample['text']}")
    
    # Convert audio to float32 numpy array and normalize
    audio = sample['audio']['array'].astype(np.float32)
    
    # Test different tasks
    print("\n1. TRANSCRIPTION TASK:")
    result = model.transcribe(audio)
    print(f"Transcription: {result['text']}")
    
    print("\n2. TRANSLATION TASK:")
    result = model.transcribe(audio, task="translate")
    print(f"Translation: {result['text']}")
    
    print("\n3. LANGUAGE DETECTION:")
    # Process audio using Whisper's utilities
    audio_input = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio_input).to(model.device)
    
    # Detect language
    _, probs = model.detect_language(mel)
    detected_lang = max(probs, key=probs.get)
    print(f"Detected language: {detected_lang} (confidence: {probs[detected_lang]:.2%})")
    
    # Show detailed model operation
    print("\nDetailed Model Operation:")
    print("1. Audio Processing:")
    print(f"Mel spectrogram shape: {mel.shape}")
    
    # Properly prepare audio features
    audio_features = mel.unsqueeze(0)
    print(f"Audio features shape: {audio_features.shape}")
    
    print("\n2. Generation Process:")
    options = whisper.DecodingOptions(
        language="en",
        task="transcribe",
        fp16=False  # Ensure we use FP32 on CPU
    )
    
    with torch.no_grad():
        # Use the prepared mel spectrogram
        encoder_output = model.encoder(audio_features)
        print(f"Encoder output shape: {encoder_output.shape}")
        
        # Decode and show tokens
        result = whisper.decode(model, encoder_output, options)
        
        # Show token-by-token generation
        print("\nToken-by-token generation:")
        tokens = result.tokens
        token_probs = result.token_probs
        
        # Show first few tokens and their probabilities
        print("\nFirst few tokens:")
        for i in range(min(5, len(tokens))):
            token = model.tokenizer.decode_with_timestamps([tokens[i]])
            prob = token_probs[i] if i < len(token_probs) else 0.0
            print(f"Token: {token:15} Probability: {prob:.4f}")

def main():
    test_whisper_multitask()

if __name__ == "__main__":
    main()


Testing Whisper's multi-task capabilities using pretrained model...
Loaded Whisper model: ModelDimensions(n_mels=80, n_audio_ctx=1500, n_audio_state=768, n_audio_head=12, n_audio_layer=12, n_vocab=51865, n_text_ctx=448, n_text_state=768, n_text_head=12, n_text_layer=12)

Test sample text: A MAN SAID TO THE UNIVERSE SIR I EXIST

1. TRANSCRIPTION TASK:
Transcription:  A man said to the universe, Sir, I exist.

2. TRANSLATION TASK:
Translation:  A man said to the universe, Sir, I exist.

3. LANGUAGE DETECTION:
Detected language: en (confidence: 99.24%)

Detailed Model Operation:
1. Audio Processing:
Mel spectrogram shape: torch.Size([80, 3000])
Audio features shape: torch.Size([1, 80, 3000])

2. Generation Process:
Encoder output shape: torch.Size([1, 1500, 768])

Token-by-token generation:


AttributeError: 'list' object has no attribute 'tokens'

In [34]:

import torch
import whisper
from datasets import load_dataset
import numpy as np

def test_whisper_multitask():
    """Demonstrate Whisper's multi-task capabilities using pretrained model"""
    print("Testing Whisper's multi-task capabilities using pretrained model...")
    
    # Load pretrained Whisper model
    model = whisper.load_model("small")
    print(f"Loaded Whisper model: {model.dims}")
    
    # Load test sample
    dataset = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation[:1]"
    )
    sample = dataset[0]
    print(f"\nTest sample text: {sample['text']}")
    
    # Convert audio to float32 numpy array
    audio = sample['audio']['array'].astype(np.float32)
    
    # 1. Transcription
    print("\n=== TRANSCRIPTION TASK ===")
    transcribe_result = model.transcribe(audio)
    print(f"Transcription: {transcribe_result['text']}")
    print(f"Language: {transcribe_result['language']}")
    
    # Show segments
    print("\nSegments:")
    for idx, segment in enumerate(transcribe_result["segments"]):
        print(f"Segment {idx + 1}:")
        print(f"Text: {segment['text']}")
        print(f"Timestamp: {segment['start']:.2f}s - {segment['end']:.2f}s")
    
    # 2. Translation
    print("\n=== TRANSLATION TASK ===")
    translate_result = model.transcribe(audio, task="translate")
    print(f"Translation: {translate_result['text']}")
    
    # 3. Language Detection
    print("\n=== LANGUAGE DETECTION ===")
    audio_input = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio_input).to(model.device)
    _, probs = model.detect_language(mel)
    detected_lang = max(probs, key=probs.get)
    print(f"Detected language: {detected_lang}")
    print("\nLanguage probabilities:")
    for lang, prob in sorted(probs.items(), key=lambda x: x[1], reverse=True)[:5]:
        print(f"{lang}: {prob:.2%}")
    
    # Show model internals
    print("\n=== MODEL INTERNALS ===")
    print("\n1. Audio Processing:")
    print(f"Input audio shape: {len(audio)}")
    print(f"Mel spectrogram shape: {mel.shape}")
    print(f"Model dimensions: {model.dims}")
    
    # Demonstrate multi-task processing
    print("\n2. Multi-task Processing:")
    tasks = ["transcribe", "translate"]
    
    for task in tasks:
        print(f"\nProcessing with task: {task}")
        result = model.transcribe(
            audio,
            task=task,
            temperature=0,  # Use greedy decoding
            verbose=True    # Show processing details
        )
        
        print(f"\nTask output:")
        print(f"Text: {result['text']}")
        print(f"Language: {result['language']}")
        
        # Show timing information
        if len(result['segments']) > 0:
            print("\nTiming analysis:")
            segment = result['segments'][0]
            print(f"Start time: {segment['start']:.2f}s")
            print(f"End time: {segment['end']:.2f}s")
            print(f"Average log probability: {segment['avg_logprob']:.2f}")
            if 'temperature' in segment:
                print(f"Temperature used: {segment['temperature']:.2f}")
            if 'compression_ratio' in segment:
                print(f"Compression ratio: {segment['compression_ratio']:.2f}")
            if 'no_speech_prob' in segment:
                print(f"No speech probability: {segment['no_speech_prob']:.2f}")

def main():
    test_whisper_multitask()

if __name__ == "__main__":
    main()


Testing Whisper's multi-task capabilities using pretrained model...
Loaded Whisper model: ModelDimensions(n_mels=80, n_audio_ctx=1500, n_audio_state=768, n_audio_head=12, n_audio_layer=12, n_vocab=51865, n_text_ctx=448, n_text_state=768, n_text_head=12, n_text_layer=12)

Test sample text: A MAN SAID TO THE UNIVERSE SIR I EXIST

=== TRANSCRIPTION TASK ===
Transcription:  A man said to the universe, Sir, I exist.
Language: en

Segments:
Segment 1:
Text:  A man said to the universe, Sir, I exist.
Timestamp: 0.00s - 4.24s

=== TRANSLATION TASK ===
Translation:  A man said to the universe, Sir, I exist.

=== LANGUAGE DETECTION ===
Detected language: en

Language probabilities:
en: 99.24%
ar: 0.20%
nn: 0.10%
la: 0.09%
cy: 0.09%

=== MODEL INTERNALS ===

1. Audio Processing:
Input audio shape: 74400
Mel spectrogram shape: torch.Size([80, 3000])
Model dimensions: ModelDimensions(n_mels=80, n_audio_ctx=1500, n_audio_state=768, n_audio_head=12, n_audio_layer=12, n_vocab=51865, n_text_ctx=448, n_