In [1]:
import os

# Process-wide environment flags, set before torch is imported further down.
# NOTE(review): presumably read by the tokenizers library and by PyTorch's MPS
# backend respectively — confirm they are still needed on the target machine.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "true",
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
    }
)

# TensorBoard logging (the SummaryWriter itself is created inside main())
from torch.utils.tensorboard import SummaryWriter

import torch

# Turn on autograd anomaly detection globally; this applies to every backward
# pass executed after this cell runs.
# NOTE(review): PyTorch documents this mode as a debugging aid that slows
# execution — confirm it should stay enabled for the full `num_epochs` run.
torch.autograd.set_detect_anomaly(True)

# Import custom modules
from dataset import get_dataset, get_tokenizer
from transcribe_model import TranscribeModel
from torch import nn

# --- VQ loss weight schedule -------------------------------------------------
# main() starts the VQ loss weight at the initial value and lowers it linearly
# with the step count, never going below the final value.
vq_initial_loss_weight = 10
vq_warmup_steps = 1_000
vq_final_loss_weight = 0.5

# --- Run settings ------------------------------------------------------------
# Number of passes over the dataloader.
num_epochs = 1_000
# Value the step counter starts from.
starting_steps = 0
# Forwarded to get_dataset(num_examples=...); presumably None means "no cap"
# — TODO confirm against dataset.py.
num_examples = None
# Names the TensorBoard run directory and the checkpoint directory.
model_id = "test37"
# How many consecutive optimisation steps are taken on each batch.
num_batch_repeats = 1

# --- Optimiser settings ------------------------------------------------------
# Batch size requested from get_dataset().
BATCH_SIZE = 64
# Learning rate handed to AdamW.
LEARNING_RATE = 5e-3


In [2]:
# Notebook-level device, chosen with the same CUDA > MPS > CPU priority that
# main() uses, so helpers called from other cells agree with the training loop.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [3]:
def run_loss_function(log_probs, target, blank_token):
    """Compute the CTC loss for a batch-first tensor of per-frame scores.

    Args:
        log_probs: 3-D tensor indexed as (batch, time, class). It is permuted
            to (time, batch, class) before being passed to ``nn.CTCLoss``.
            No log_softmax is applied here.
            NOTE(review): the caller is presumably expected to supply
            log-probabilities already — confirm against the model's output.
        target: 2-D tensor of label ids, one row per sample. Entries equal to
            ``blank_token`` are not counted towards a row's target length.
            NOTE(review): this assumes rows are padded with the blank id —
            verify against the dataset's collate function.
        blank_token: id of the CTC blank symbol.

    Returns:
        The scalar loss produced by ``nn.CTCLoss`` with its default reduction.
    """
    batch_size, num_frames = log_probs.shape[0], log_probs.shape[1]

    # Every sample is treated as using the full time axis.
    input_lengths = (num_frames,) * batch_size

    # Per-row count of non-blank labels, as plain Python ints.
    target_lengths = tuple((target != blank_token).sum(dim=1).tolist())

    criterion = nn.CTCLoss(blank=blank_token)
    time_major = log_probs.permute(1, 0, 2)
    return criterion(time_major, target, input_lengths, target_lengths)


In [4]:
def greedy_decoder(log_probs, blank_token=0):
    """Best-path CTC decoding.

    Takes the arg-max class along the last axis for every frame, then, for each
    row of the result, drops frames that repeat the previous frame's id and
    frames equal to ``blank_token``.

    Args:
        log_probs: tensor of per-frame scores; the last axis is the class axis
            and the first axis is iterated as the batch.
        blank_token: id that is removed from the decoded sequences.

    Returns:
        A list with one entry per batch row; each entry is a list of the kept
        token ids (numpy integer scalars).
    """
    best_path = torch.argmax(log_probs, dim=-1).cpu().numpy()

    def _collapse(frame_ids):
        kept = []
        last_id = -1  # sentinel that no arg-max index can equal
        for token_id in frame_ids:
            changed = token_id != last_id
            last_id = token_id
            if changed and token_id != blank_token:
                kept.append(token_id)
        return kept

    return [_collapse(row) for row in best_path]

def _word_edit_distance(ref_words, hyp_words):
    """Return the Levenshtein distance between two word lists (single-row DP)."""
    row = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        diagonal, row[0] = row[0], i
        for j, hyp_word in enumerate(hyp_words, start=1):
            above = row[j]
            if ref_word == hyp_word:
                row[j] = diagonal
            else:
                row[j] = 1 + min(diagonal, above, row[j - 1])
            diagonal = above
    return row[-1]


def calculate_wer(predictions, references):
    """Calculate Word Error Rate between predictions and references.

    The rate is the total word-level edit distance (substitutions, deletions
    and insertions) over all pairs, divided by the total number of reference
    words. Sentences are split on whitespace.

    This is implemented with the standard library only: the previous version
    called ``jiwer``, which this notebook never imports, so the resulting
    NameError was swallowed by a bare ``except`` and 1.0 was always returned.

    Args:
        predictions: a hypothesis string, or a sequence of hypothesis strings.
        references: a reference string, or a sequence of reference strings,
            aligned one-to-one with ``predictions``.

    Returns:
        The word error rate as a float. Returns 1.0 when the two inputs have
        different lengths or when the references contain no words at all
        (the same worst-case fallback the old error path produced).
    """
    if isinstance(predictions, str):
        predictions = [predictions]
    if isinstance(references, str):
        references = [references]
    predictions = list(predictions)
    references = list(references)

    # Unpaired inputs cannot be scored; keep the historical 1.0 fallback.
    if len(predictions) != len(references):
        return 1.0

    total_words = 0
    total_errors = 0
    for hypothesis, reference in zip(predictions, references):
        ref_words = reference.split()
        total_words += len(ref_words)
        total_errors += _word_edit_distance(ref_words, hypothesis.split())

    return total_errors / total_words if total_words > 0 else 1.0
    
def calculate_cer(predictions, references):
    """Calculate Character Error Rate between predictions and references.

    Sums the character-level Levenshtein distance of each (prediction,
    reference) pair and divides by the total length of all references.
    Pairs are formed with ``zip``, so surplus items on either side are not
    compared, although every reference still counts towards the denominator.

    Returns:
        ``total_edits / total_reference_chars``, or 1.0 when the references
        contain no characters.
    """
    reference_chars = 0
    for ref in references:
        reference_chars += len(ref)

    edit_total = 0
    for hyp, ref in zip(predictions, references):
        # Levenshtein distance keeping only the previous and current DP rows.
        prev_row = list(range(len(ref) + 1))
        for row_idx, hyp_char in enumerate(hyp, start=1):
            cur_row = [row_idx]
            for col_idx, ref_char in enumerate(ref, start=1):
                if hyp_char == ref_char:
                    cur_row.append(prev_row[col_idx - 1])
                else:
                    cheapest = min(
                        prev_row[col_idx],
                        cur_row[col_idx - 1],
                        prev_row[col_idx - 1],
                    )
                    cur_row.append(cheapest + 1)
            prev_row = cur_row
        edit_total += prev_row[-1]

    if reference_chars > 0:
        return edit_total / reference_chars
    return 1.0

def evaluate_model(model, dataloader, tokenizer, device, blank_token, max_batches=5):
    """Greedy-decode up to ``max_batches`` batches and score them with WER/CER.

    Args:
        model: callable returning ``(output, _)`` for a batch of audio. It is
            put in eval mode while decoding and switched back to train mode
            afterwards, even if decoding raises.
        dataloader: iterable of dict batches providing ``"audio"`` (moved to
            ``device``) and ``"text"`` (the reference transcripts).
        tokenizer: object exposing ``get_vocab()`` and ``id_to_token(id)``.
        device: device the audio is moved to before the forward pass.
        blank_token: CTC blank id; removed during decoding.
        max_batches: maximum number of batches to evaluate.

    Returns:
        dict with ``'wer'``, ``'cer'``, ``'num_samples'``, the first ten
        ``'predictions'`` and ``'references'``, and ``'examples'`` — the same
        first ten pairs as ``{'reference': ..., 'prediction': ...}`` dicts,
        which is the shape ``print_evaluation_results`` iterates over.
        WER and CER default to 1.0 when nothing could be decoded.
    """
    all_predictions = []
    all_references = []

    # Loop-invariant: avoid rebuilding the vocabulary for every token.
    vocab_size = len(tokenizer.get_vocab())

    model.eval()
    try:
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                if i >= max_batches:
                    break

                audio = batch["audio"].to(device)
                text = batch["text"]

                # A failing forward pass skips the batch instead of aborting.
                try:
                    output, _ = model(audio)
                except Exception as e:
                    print(f"Model forward error: {e}")
                    continue

                decoded_preds = greedy_decoder(output, blank_token=blank_token)

                pred_texts = []
                for pred in decoded_preds:
                    tokens = []
                    for p in pred:
                        if 0 <= p < vocab_size and p != blank_token:
                            token = tokenizer.id_to_token(int(p))
                            # Keep letters AND spaces: the references contain
                            # spaces, so dropping them would fuse the whole
                            # prediction into one word and distort WER/CER.
                            # This matches the decoding used in main().
                            if token and len(token) == 1 and (token.isalpha() or token == " "):
                                tokens.append(token)
                    pred_texts.append("".join(tokens))

                all_predictions.extend(pred_texts)
                all_references.extend(text)
    finally:
        model.train()

    # Calculate metrics
    wer = calculate_wer(all_predictions, all_references) if all_predictions else 1.0
    cer = calculate_cer(all_predictions, all_references) if all_predictions else 1.0

    sample_predictions = all_predictions[:10]
    sample_references = all_references[:10]

    return {
        'wer': wer,
        'cer': cer,
        'num_samples': len(all_predictions),
        'predictions': sample_predictions,  # Show first 10 for debugging
        'references': sample_references,
        'examples': [
            {'reference': ref, 'prediction': pred}
            for pred, ref in zip(sample_predictions, sample_references)
        ],
    }
    
def print_evaluation_results(eval_results, step):
    """Print evaluation results in a nice format.

    Args:
        eval_results: mapping with ``'wer'``, ``'cer'`` and ``'num_samples'``.
            Sample pairs are taken from ``'examples'`` (a list of dicts with
            ``'reference'`` and ``'prediction'``) when that key is present;
            otherwise they are built by pairing ``'predictions'`` with
            ``'references'``. Previously a missing ``'examples'`` key raised
            KeyError.
        step: training step shown in the header.
    """
    if 'examples' in eval_results:
        examples = eval_results['examples']
    else:
        predictions = eval_results['predictions'] if 'predictions' in eval_results else []
        references = eval_results['references'] if 'references' in eval_results else []
        examples = [
            {'reference': ref, 'prediction': pred}
            for pred, ref in zip(predictions, references)
        ]

    print("\n" + "="*80)
    print(f"EVALUATION RESULTS AT STEP {step}")
    print("="*80)
    print(f"Word Error Rate (WER): {eval_results['wer']:.4f}")
    print(f"Character Error Rate (CER): {eval_results['cer']:.4f}")
    print(f"Number of samples evaluated: {eval_results['num_samples']}")
    print("\nSAMPLE PREDICTIONS:")
    print("-"*80)

    for i, example in enumerate(examples):
        print(f"\nExample {i+1}:")
        print(f"Reference:  '{example['reference']}'")
        print(f"Prediction: '{example['prediction']}'")

        # Calculate individual WER for this example
        individual_wer = calculate_wer([example['prediction']], [example['reference']])
        print(f"Individual WER: {individual_wer:.4f}")

    print("="*80 + "\n")

In [None]:
# def main():
#     # Set up logging directory with model ID
#     log_dir = f"runs/speech2text_training/{model_id}"
    
#     # Clean up previous logs if they exist
#     if os.path.exists(log_dir):
#         import shutil
#         shutil.rmtree(log_dir)
    
#     # Initialize TensorBoard writer for logging
#     writer = SummaryWriter(log_dir)
    
#     # Initialize tokenizer and get blank token ID
#     tokenizer = get_tokenizer()
#     blank_token = tokenizer.token_to_id("□")
    
#     # Device selection with fallback priority: CUDA > MPS > CPU
#     device = torch.device(
#         "cuda" 
#         if torch.cuda.is_available() 
#         else "mps" if torch.backends.mps.is_available() else "cpu"
#     )
#     print(f"Using device: {device}")
    
#     # Model loading or initialization
#     if os.path.exists(f"models/{model_id}/model_latest.pth"):
#         print(f"Loading model from models/{model_id}/model_latest.pth")
#         model = TranscribeModel.load(f"models/{model_id}/model_latest.pth").to(device)
#     else:
#         # Initialize new model with specified hyperparameters
#         model = TranscribeModel(
#             num_codebooks=2,
#             codebook_size=32,
#             embedding_dim=16,
#             num_transformer_layers=2,
#             vocab_size=len(tokenizer.get_vocab()),
#             strides=[6, 6, 6],  # Less aggressive downsampling
#             initial_mean_pooling_kernel_size=4,
#             max_seq_length=400,  # Reduced from 400
#         ).to(device)

#     num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
#     print(f"Number of trainable parameters: {num_trainable_params}")

#     optimizer = torch.optim.AdamW(
#         model.parameters(),
#         lr=LEARNING_RATE,
#     )

#     dataloader = get_dataset(
#         batch_size=BATCH_SIZE,
#         num_examples=num_examples,
#         num_workers=1,
#     )

#     # Initialize tracking variables
#     ctc_losses = []
#     vq_losses = []
#     num_batches = len(dataloader)
#     steps = starting_steps

#     # Main training loop
#     for i in range(num_epochs):
#         for idx, batch in enumerate(dataloader):
#             for repeat_batch in range(num_batch_repeats):
#                 # Extract batch components
#                 audio = batch["audio"]
#                 target = batch["input_ids"]
#                 text = batch["text"]
                
#                 # Handle sequence length mismatch
#                 if target.shape[1] > audio.shape[1]:
#                     print(
#                         "Padding audio, target is longer than audio. Audio Shape: ",
#                         audio.shape,
#                         "Target Shape: ",
#                         target.shape,
#                     )
#                     # Pad audio to match target length
#                     audio = torch.nn.functional.pad(
#                         audio, (0, 0, 0, target.shape[1] - audio.shape[1])
#                     )
#                     print("After padding: ", audio.shape)
                
#                 # Move tensors to device
#                 audio = audio.to(device)
#                 target = target.to(device)
# # Clear gradients from previous step
#                 optimizer.zero_grad()

#                 # Forward pass through the model
#                 output, vq_loss = model(audio)

#                 # Compute CTC loss for sequence alignment
#                 ctc_loss = run_loss_function(output, target, blank_token)

#                 # Calculate VQ loss weight using linear warmup schedule
#                 vq_loss_weight = max(
#                     vq_final_loss_weight,
#                     vq_initial_loss_weight
#                     - (vq_initial_loss_weight - vq_final_loss_weight)
#                     * (steps / vq_warmup_steps),
#                 )

#                 # Combine losses based on VQ availability
#                 if vq_loss is None:
#                     loss = ctc_loss
#                 else:
#                     loss = ctc_loss + vq_loss_weight * vq_loss

#                 # Skip training step if loss is infinite (numerical instability)
#                 if torch.isinf(loss):
#                     print("Loss is inf, skipping step", audio.shape, target.shape)
#                     continue

#                 # Backpropagation
#                 loss.backward()

#                 # Gradient clipping to prevent exploding gradients
#                 torch.nn.utils.clip_grad_norm_(
#                     model.parameters(), max_norm=10.0
#                 )
#                 optimizer.step()
                
#                 ctc_losses.append(ctc_loss.item())
#                 vq_losses.append(vq_loss.item())
#                 steps += 1
#                                 # Periodic logging and evaluation
#                 if steps % 20 == 0:
#                     avg_ctc_loss = sum(ctc_losses) / len(ctc_losses)
#                     avg_vq_loss = sum(vq_losses) / len(vq_losses)
#                     avg_loss = avg_ctc_loss + vq_loss_weight * avg_vq_loss
                    
#                     print(
#                         f"Epoch {i}, Batch {idx}, Step {steps}: "
#                         f"CTC Loss: {avg_ctc_loss:.4f}, "
#                         f"VQ Loss: {avg_vq_loss:.4f}, "
#                         f"Total Loss: {avg_loss:.4f}, "
#                         f"VQ Weight: {vq_loss_weight:.4f}"
#                     )
                    
#                     # Log to TensorBoard
#                     writer.add_scalar("Loss/CTC", avg_ctc_loss, steps)
#                     writer.add_scalar("Loss/VQ", avg_vq_loss, steps)
#                     writer.add_scalar("Loss/Total", avg_loss, steps)
#                     writer.add_scalar("Loss/VQ_Weight", vq_loss_weight, steps)
                    
#                     # Clear loss lists for next period
#                     ctc_losses = []
#                     vq_losses = []
                
#                 # Model checkpointing
#                 if steps % 500 == 0:
#                     checkpoint_dir = f"models/{model_id}"
#                     os.makedirs(checkpoint_dir, exist_ok=True)
                    
#                     # Save latest model
#                     model.save(f"{checkpoint_dir}/model_latest.pth")
                    
#                     # Save step-specific checkpoint
#                     model.save(f"{checkpoint_dir}/model_step_{steps}.pth")
                    
#                     print(f"Model saved at step {steps}")
                    
#     checkpoint_dir = f"models/{model_id}"
#     os.makedirs(checkpoint_dir, exist_ok=True)
#     model.save(f"{checkpoint_dir}/model_final.pth")
    
#     # Close TensorBoard writer
#     writer.close()
    
#     print("Training completed!")

In [12]:
def _ids_to_text(token_ids, tokenizer, blank_token, vocab_size):
    """Join the single-character letter/space tokens for ``token_ids`` into a string."""
    chars = []
    for token_id in token_ids:
        if 0 <= token_id < vocab_size and token_id != blank_token:
            token = tokenizer.id_to_token(token_id)
            if token and len(token) == 1 and (token.isalpha() or token == " "):
                chars.append(token)
    return "".join(chars)


def main():
    """Train the transcription model with a CTC loss plus a scheduled VQ loss.

    Settings come from the notebook's configuration cell: ``model_id``,
    ``num_epochs``, ``starting_steps``, ``num_examples``, ``num_batch_repeats``,
    ``BATCH_SIZE``, ``LEARNING_RATE`` and the ``vq_*`` schedule values.

    Side effects:
        * Deletes any existing TensorBoard log directory for ``model_id`` and
          creates a new writer there; the writer is closed even if training is
          interrupted or raises.
        * Loads ``models/{model_id}/model_latest.pth`` when it exists, otherwise
          builds a new ``TranscribeModel``.
        * Every 20 steps prints and logs the averaged losses; every 40 steps
          prints greedy transcriptions of the current batch; every 500 steps
          writes ``model_latest.pth`` and a step-numbered checkpoint.
        * Writes ``model_final.pth`` after the last epoch.
    """
    # Set up logging directory with model ID
    log_dir = f"runs/speech2text_training/{model_id}"

    # Clean up previous logs if they exist
    if os.path.exists(log_dir):
        import shutil
        shutil.rmtree(log_dir)

    # Initialize TensorBoard writer for logging
    writer = SummaryWriter(log_dir)

    try:
        # Initialize tokenizer and get blank token ID
        tokenizer = get_tokenizer()
        blank_token = tokenizer.token_to_id("□")
        vocab_size = len(tokenizer.get_vocab())

        # Device selection with fallback priority: CUDA > MPS > CPU
        device = torch.device(
            "cuda"
            if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available() else "cpu"
        )
        print(f"Using device: {device}")

        checkpoint_dir = f"models/{model_id}"
        latest_path = f"{checkpoint_dir}/model_latest.pth"

        # Model loading or initialization
        if os.path.exists(latest_path):
            print(f"Loading model from {latest_path}")
            model = TranscribeModel.load(latest_path).to(device)
        else:
            # Initialize new model with specified hyperparameters
            model = TranscribeModel(
                num_codebooks=2,
                codebook_size=32,
                embedding_dim=16,
                num_transformer_layers=2,
                vocab_size=vocab_size,
                strides=[6, 6, 6],
                initial_mean_pooling_kernel_size=4,
                max_seq_length=400,
            ).to(device)

        num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Number of trainable parameters: {num_trainable_params}")

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=LEARNING_RATE,
        )

        # `num_examples` is the module-level setting from the config cell. The
        # preview code below uses a different local name so that it no longer
        # shadows this setting (which had forced a hard-coded override here).
        dataloader = get_dataset(
            batch_size=BATCH_SIZE,
            num_examples=num_examples,
            num_workers=1,
        )

        # Initialize tracking variables
        ctc_losses = []
        vq_losses = []
        num_batches = len(dataloader)
        steps = starting_steps

        # Main training loop
        for _epoch in range(num_epochs):
            for idx, batch in enumerate(dataloader):
                for _ in range(num_batch_repeats):
                    # Extract batch components
                    audio = batch["audio"]
                    target = batch["input_ids"]
                    text = batch["text"]

                    # Handle sequence length mismatch
                    if target.shape[1] > audio.shape[1]:
                        print(
                            "Padding audio, target is longer than audio. Audio Shape: ",
                            audio.shape,
                            "Target Shape: ",
                            target.shape,
                        )
                        # Pad audio to match target length.
                        # NOTE(review): this pad spec extends the second-to-last
                        # axis, while the check above compares axis 1 — the two
                        # agree only for 3-D audio. TODO confirm the audio layout.
                        audio = torch.nn.functional.pad(
                            audio, (0, 0, 0, target.shape[1] - audio.shape[1])
                        )
                        print("After padding: ", audio.shape)

                    # Move tensors to device
                    audio = audio.to(device)
                    target = target.to(device)

                    # Clear gradients from previous step
                    optimizer.zero_grad()

                    # Forward pass through the model
                    output, vq_loss = model(audio)

                    # Compute CTC loss for sequence alignment
                    ctc_loss = run_loss_function(output, target, blank_token)

                    # VQ loss weight: decays linearly from the initial to the
                    # final value over `vq_warmup_steps`, then stays at the final value.
                    vq_loss_weight = max(
                        vq_final_loss_weight,
                        vq_initial_loss_weight
                        - (vq_initial_loss_weight - vq_final_loss_weight)
                        * (steps / vq_warmup_steps),
                    )

                    # Combine losses based on VQ availability
                    if vq_loss is None:
                        loss = ctc_loss
                    else:
                        loss = ctc_loss + vq_loss_weight * vq_loss

                    # Skip training step if loss is infinite (numerical instability)
                    if torch.isinf(loss):
                        print("Loss is inf, skipping step", audio.shape, target.shape)
                        continue

                    # Backpropagation
                    loss.backward()

                    # Gradient clipping to prevent exploding gradients
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), max_norm=10.0
                    )
                    optimizer.step()

                    ctc_losses.append(ctc_loss.item())
                    # vq_loss may be None (handled above); record 0.0 in that
                    # case instead of crashing on `.item()`.
                    vq_losses.append(0.0 if vq_loss is None else vq_loss.item())
                    steps += 1

                    # Periodic logging
                    if steps % 20 == 0:
                        avg_ctc_loss = sum(ctc_losses) / len(ctc_losses)
                        avg_vq_loss = sum(vq_losses) / len(vq_losses)
                        avg_loss = avg_ctc_loss + vq_loss_weight * avg_vq_loss

                        print(
                            f"Num Steps: {steps}, Batch: {idx}/{num_batches}, "
                            f"ctc_loss: {avg_ctc_loss:.3f}, vq_loss: {avg_vq_loss:.3f}, "
                            f"total_loss: {avg_loss:.3f}"
                        )

                        # Log to TensorBoard
                        writer.add_scalar("Loss/CTC", avg_ctc_loss, steps)
                        writer.add_scalar("Loss/VQ", avg_vq_loss, steps)
                        writer.add_scalar("Loss/Total", avg_loss, steps)
                        writer.add_scalar("Loss/VQ_Weight", vq_loss_weight, steps)

                        # Clear loss lists for next period
                        ctc_losses = []
                        vq_losses = []

                    # Every 40 steps, show transcription examples for the current batch
                    if steps % 40 == 0:
                        print("\n" + "="*60)
                        print("Transcription Examples")
                        print("="*60)

                        model.eval()
                        with torch.no_grad():
                            # Get model output for current batch
                            output, _ = model(audio)

                            # Decode predictions using greedy decoder
                            decoded_preds = greedy_decoder(output, blank_token=blank_token)

                            # Show first few examples
                            num_shown = min(4, len(text))
                            for ex_idx in range(num_shown):
                                pred_text = _ids_to_text(
                                    decoded_preds[ex_idx], tokenizer, blank_token, vocab_size
                                )
                                ground_truth = text[ex_idx]

                                print(f"Example {ex_idx}:")
                                print(f"Model Output: {pred_text}")
                                print(f"Ground Truth: {ground_truth}")
                                print("-" * 40)

                        model.train()  # Switch back to training mode
                        print("="*60 + "\n")

                    # Model checkpointing
                    if steps % 500 == 0:
                        os.makedirs(checkpoint_dir, exist_ok=True)

                        # Save latest model
                        model.save(f"{checkpoint_dir}/model_latest.pth")

                        # Save step-specific checkpoint
                        model.save(f"{checkpoint_dir}/model_step_{steps}.pth")

                        print(f"Model saved at step {steps}")

        os.makedirs(checkpoint_dir, exist_ok=True)
        model.save(f"{checkpoint_dir}/model_final.pth")
    finally:
        # Close TensorBoard writer, including on KeyboardInterrupt or errors
        writer.close()

    print("Training completed!")

In [7]:
# # Add to train.py to test the tokenizer fix:
# def test_tokenizer_filtering():
#     tokenizer = get_tokenizer()
    
#     # Test the filtering logic
#     test_tokens = ["<pad>", "<unk>", "H", "E", "L", "L", "O", "<□>"]
#     filtered = []
    
#     for token in test_tokens:
#         if token and token not in ["<pad>", "<s>", "</s>", "<unk>", "<mask>", "<blank>","<□>"]:
#             filtered.append(token)
    
#     print(f"Original tokens: {test_tokens}")
#     print(f"Filtered tokens: {filtered}")
#     print(f"Expected: ['H', 'E', 'L', 'L', 'O']")
    
#     if filtered == ['H', 'E', 'L', 'L', 'O']:
#         print("✅ Token filtering is working correctly!")
#     else:
#         print("❌ Token filtering still has issues")

# # Call this before training starts
# test_tokenizer_filtering()


In [None]:
# Entry point: start training when this cell/script runs as the main module.
if __name__ == "__main__":
    main()

Using device: cuda
Number of trainable parameters: 7532
Num Steps: 20, Batch: 19/21, ctc_loss: 5.318, vq_loss: 0.836, total_loss: 13.526
Num Steps: 40, Batch: 18/21, ctc_loss: 4.207, vq_loss: 1.383, total_loss: 17.524

Transcription Examples
Example 0:
Model Output: 
Ground Truth: IN MOST DISCUSSIONS OF THIS PHENOMENON  THE FIGURES ARE SUBSTANTIALLY INFLATED
----------------------------------------
Example 1:
Model Output: 
Ground Truth: BY EATING YOGURT  YOU MAY LIVE LONGER
----------------------------------------
Example 2:
Model Output: 
Ground Truth: HE CHUCKLED  THE MEMORY VIVID
----------------------------------------
Example 3:
Model Output: 
Ground Truth: ANY RETALIATORY GAS ATTACK WOULD BE AIRBORNE
----------------------------------------

Num Steps: 60, Batch: 17/21, ctc_loss: 3.101, vq_loss: 0.520, total_loss: 8.005
Num Steps: 80, Batch: 16/21, ctc_loss: 2.982, vq_loss: 0.232, total_loss: 5.126

Transcription Examples
Example 0:
Model Output: 
Ground Truth: DON T ASK ME TO C