In [7]:
import os
import jiwer
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import math
from torch.utils.tensorboard import SummaryWriter
import torch
from torch.optim.lr_scheduler import OneCycleLR
torch.autograd.set_detect_anomaly(True)
from dataset import get_dataset, get_tokenizer
from transcribe_model import TranscribeModel
from torch import nn

# --- Run identity and length -------------------------------------------------
model_id = "test21"       # names the TensorBoard run directory and the checkpoint folder
num_epochs = 1000         # number of passes over the training dataloader
num_examples = None       # forwarded to get_dataset() for the training loader
num_batch_repeats = 1     # not referenced by the cells visible in this notebook
starting_steps = 0        # initial value of the global step counter used by main()

# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 64
LEARNING_RATE = 1e-4

# --- VQ loss weighting -------------------------------------------------------
# NOTE: main() assigns its own local values to these three names, so the
# module-level values below are not the ones used by the training loop.
vq_initial_loss_weight = 0.01
vq_warmup_steps = 1000
vq_final_loss_weight = 0.001

In [8]:
def run_loss_function(log_probs, target, blank_token):
    """Compute the CTC loss for a batch of batch-first log-probabilities.

    Args:
        log_probs: Tensor indexed as (batch, time, classes); it is permuted to
            time-first before being handed to ``nn.CTCLoss``.
        target: 2-D tensor of label ids, one row per sample. Every entry equal
            to ``blank_token`` is excluded from that row's label length
            (assumes padding uses the blank id -- TODO confirm against the
            dataset collate function).
        blank_token: Id of the CTC blank symbol.

    Returns:
        The scalar loss produced by ``nn.CTCLoss`` with its default reduction.
    """
    batch_size, num_frames = log_probs.shape[0], log_probs.shape[1]

    # Every sample is declared to use the full time axis.
    frame_counts = torch.full(
        (batch_size,), num_frames, dtype=torch.long, device=log_probs.device
    )

    # Label length per row = number of entries that differ from the blank id.
    label_counts = (target != blank_token).sum(dim=1).to(torch.long)

    ctc = nn.CTCLoss(blank=blank_token)
    time_major = log_probs.permute(1, 0, 2)
    return ctc(time_major, target, frame_counts, label_counts)

def safe_mean(losses):
    """Return the arithmetic mean of ``losses``, or 0.0 when it is empty."""
    count = len(losses)
    if count == 0:
        return 0.0
    return sum(losses) / count

In [9]:
def greedy_decoder(log_probs, blank_token=0):
    """Best-path CTC decoding.

    Takes the argmax over the last axis of ``log_probs``, then for each row
    drops every id that repeats the immediately preceding frame and every id
    equal to ``blank_token``.

    Args:
        log_probs: Tensor whose last axis indexes classes; the argmax result is
            iterated row by row (assumes (batch, time, classes) -- TODO confirm).
        blank_token: Id of the CTC blank symbol.

    Returns:
        A list with one list of NumPy integer ids per row.
    """
    best_path = torch.argmax(log_probs, dim=-1).cpu().numpy()

    results = []
    for frame_ids in best_path:
        collapsed = []
        last_id = -1  # sentinel that never matches a real class id
        for token_id in frame_ids:
            is_repeat = token_id == last_id
            last_id = token_id
            if is_repeat or token_id == blank_token:
                continue
            collapsed.append(token_id)
        results.append(collapsed)

    return results

def calculate_wer(predictions, references):
    """Calculate Word Error Rate between predictions and references.

    Args:
        predictions: Hypothesis transcript(s), passed to ``jiwer.wer`` as the
            second argument.
        references: Ground-truth transcript(s), passed to ``jiwer.wer`` as the
            first argument.

    Returns:
        The value returned by ``jiwer.wer``; 1.0 (treated as "everything
        wrong") when ``jiwer.wer`` raises an ordinary exception.
    """
    try:
        return jiwer.wer(references, predictions)
    except Exception:
        # Best-effort fallback so one bad sample does not abort an evaluation.
        # Deliberately NOT a bare ``except``: KeyboardInterrupt and SystemExit
        # must still propagate so the user can stop a long run.
        return 1.0
    
def calculate_cer(predictions, references):
    """Character Error Rate: summed edit distance over summed reference length.

    Predictions and references are paired positionally with ``zip``. The
    Levenshtein distance of each pair is accumulated and divided by the total
    number of characters across *all* references.

    Returns:
        ``total_edits / total_reference_chars``, or 1.0 when the references
        contain no characters at all.
    """
    reference_chars = 0
    for ref in references:
        reference_chars += len(ref)

    edit_total = 0
    for hyp, ref in zip(predictions, references):
        # Rolling-row Levenshtein: prev_row[j] is the distance between the
        # hypothesis prefix processed so far and ref[:j].
        prev_row = list(range(len(ref) + 1))
        for i, hyp_char in enumerate(hyp, start=1):
            row = [i]
            for j, ref_char in enumerate(ref, start=1):
                if hyp_char == ref_char:
                    row.append(prev_row[j - 1])
                else:
                    row.append(1 + min(prev_row[j], row[j - 1], prev_row[j - 1]))
            prev_row = row
        edit_total += prev_row[-1]

    if reference_chars > 0:
        return edit_total / reference_chars
    return 1.0

def evaluate_model(model, dataloader, tokenizer, device, blank_token, max_batches=5):
    """Evaluate the model and return metrics with sample predictions.

    Runs the model under ``torch.no_grad()`` on at most ``max_batches`` batches,
    greedy-decodes the outputs, converts ids to text, and scores the result
    with ``calculate_wer`` / ``calculate_cer``.

    Args:
        model: Callable returning ``(output, extra)`` for an audio tensor; put
            into eval mode for the duration of the call.
        dataloader: Iterable of batches; each batch is indexed with the keys
            ``"audio"`` (tensor) and ``"text"`` (reference strings).
        tokenizer: Object providing ``get_vocab()`` and ``id_to_token(id)``.
        device: Device the audio tensor is moved to.
        blank_token: CTC blank id; passed to the decoder and skipped when
            building text.
        max_batches: Upper bound on the number of batches consumed.

    Returns:
        dict with keys ``'wer'``, ``'cer'``, ``'num_samples'`` and
        ``'examples'`` (up to 6 reference/prediction pairs taken from the
        first 3 batches).

    Note:
        The model is always switched back to training mode on exit -- also
        when evaluation raises -- regardless of the mode it was in before.
    """
    special_tokens = ("<pad>", "<unk>", "<s>", "</s>", "<□>")
    # get_vocab() is loop-invariant; look the size up once instead of once per
    # decoded token.
    vocab_size = len(tokenizer.get_vocab())

    all_predictions = []
    all_references = []
    sample_examples = []

    model.eval()
    try:
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                if i >= max_batches:
                    break

                audio = batch["audio"].to(device)
                text = batch["text"]

                # A 2-D audio tensor gets an extra axis inserted at position 1
                # (presumably a channel axis -- TODO confirm against the model).
                if audio.dim() == 2:
                    audio = audio.unsqueeze(1)

                output, _ = model(audio)
                decoded_preds = greedy_decoder(output, blank_token=blank_token)

                # Convert token ids to text, dropping blanks, out-of-vocabulary
                # ids and all special tokens.
                pred_texts = []
                for pred in decoded_preds:
                    tokens = []
                    for p in pred:
                        token_id = int(p)  # decoder yields NumPy ints
                        if token_id >= vocab_size or token_id == blank_token:
                            continue
                        token = tokenizer.id_to_token(token_id)
                        if token and token not in special_tokens:
                            tokens.append(token)
                    pred_texts.append("".join(tokens))

                all_predictions.extend(pred_texts)
                all_references.extend(text)

                # Store first few examples for display
                if i < 3:
                    for j, (pred, ref) in enumerate(zip(pred_texts, text)):
                        if len(sample_examples) < 6:
                            sample_examples.append({
                                'reference': ref,
                                'prediction': pred,
                                'batch': i,
                                'sample': j
                            })
    finally:
        # Restore training mode even if a batch failed, so a caller that
        # catches the error does not keep training a model stuck in eval mode.
        model.train()

    # Calculate metrics
    wer = calculate_wer(all_predictions, all_references)
    cer = calculate_cer(all_predictions, all_references)

    return {
        'wer': wer,
        'cer': cer,
        'num_samples': len(all_predictions),
        'examples': sample_examples
    }
    
def print_evaluation_results(eval_results, step):
    """Pretty-print an evaluation summary to stdout.

    Args:
        eval_results: Mapping with ``'wer'``, ``'cer'``, ``'num_samples'`` and
            ``'examples'`` (each example has ``'reference'`` and
            ``'prediction'``).
        step: Step number shown in the banner.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    # Banner and aggregate metrics
    print("\n" + heavy_rule)
    print(f"EVALUATION RESULTS AT STEP {step}")
    print(heavy_rule)
    print(f"Word Error Rate (WER): {eval_results['wer']:.4f}")
    print(f"Character Error Rate (CER): {eval_results['cer']:.4f}")
    print(f"Number of samples evaluated: {eval_results['num_samples']}")
    print("\nSAMPLE PREDICTIONS:")
    print(light_rule)

    # One section per stored sample, each with its own WER
    for number, sample in enumerate(eval_results['examples'], start=1):
        print(f"\nExample {number}:")
        print(f"Reference:  '{sample['reference']}'")
        print(f"Prediction: '{sample['prediction']}'")

        sample_wer = calculate_wer([sample['prediction']], [sample['reference']])
        print(f"Individual WER: {sample_wer:.4f}")

    print(heavy_rule + "\n")

In [10]:
def _cosine_vq_weight(step, total_steps, initial_weight, final_weight):
    """Cosine-anneal the VQ loss weight from ``initial_weight`` to ``final_weight``.

    The weight reaches ``final_weight`` once ``step`` is half of ``total_steps``
    and stays there for the remainder of training.
    """
    progress = step / total_steps
    return final_weight + (initial_weight - final_weight) * \
        (1 + math.cos(math.pi * min(progress * 2, 1))) / 2


def _save_checkpoint(path, model, optimizer, **metadata):
    """Write model and optimizer state plus ``metadata`` to ``path`` via torch.save."""
    payload = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    payload.update(metadata)
    torch.save(payload, path)


def main():
    """Train the transcription model with CTC + VQ loss and periodic evaluation.

    Side effects:
        * Deletes and recreates the TensorBoard directory
          ``runs/speech2text_training/{model_id}``.
        * Writes checkpoints under ``models/{model_id}/`` every 500 steps
          (``model_step_{n}.pth`` and ``model_latest.pth``) and
          ``model_final.pth`` at the end; save errors are printed, not raised.
        * Prints progress every 20 steps and an evaluation report every 1000.

    The TensorBoard writer is closed even when training is interrupted.
    """
    log_dir = f"runs/speech2text_training/{model_id}"
    if os.path.exists(log_dir):
        import shutil
        shutil.rmtree(log_dir)
    writer = SummaryWriter(log_dir)

    try:
        tokenizer = get_tokenizer()
        blank_token = tokenizer.token_to_id("<□>")

        device = torch.device(
            "cuda"
            if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available() else "cpu"
        )
        print(f"Using device: {device}")

        # Load or create model
        # NOTE(review): machine-specific absolute path; consider deriving it
        # from model_id. Also note that `starting_steps` and the optimizer /
        # scheduler state are not restored when this checkpoint is loaded.
        resume_path = r"C:\Users\Kamil\Desktop\Coding\models\test21\model_step_3000.pth"
        if os.path.exists(resume_path):
            # Report the path that is actually loaded.
            print(f"Loading model from {resume_path}")
            model = TranscribeModel.load(resume_path).to(device)
        else:
            model = TranscribeModel(
                num_codebooks=4,        # increased from 2
                codebook_size=64,       # increased from 32
                embedding_dim=256,      # increased from 128
                num_transformer_layers=6,  # increased from 3
                vocab_size=len(tokenizer.get_vocab()),  # added - required
                strides=[8, 8, 4],      # more aggressive downsampling
                initial_mean_pooling_kernel_size=2,     # added - required
                max_seq_length=400,     # reduced to save memory
            ).to(device)

        num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Number of trainable parameters: {num_trainable_params}")

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=LEARNING_RATE,
            weight_decay=0.01,
            betas=(0.9, 0.98)
        )

        dataloader = get_dataset(
            batch_size=BATCH_SIZE,
            num_examples=num_examples,
            num_workers=0,
        )

        # Training configuration. These local values (not the module-level
        # ones of the same name) drive the loss weighting below.
        vq_initial_loss_weight = 0.1
        vq_final_loss_weight = 0.01
        gradient_accumulation_steps = 4

        batches_per_epoch = len(dataloader)
        total_batches = num_epochs * batches_per_epoch

        # The optimizer (and therefore the scheduler) advances once per
        # accumulation window, plus once for a trailing partial window. The
        # schedule must be sized to that count; sizing it per batch would keep
        # the learning rate in the warm-up phase for most of the run.
        optimizer_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
        tail_batches = batches_per_epoch % gradient_accumulation_steps
        last_full_window_end = batches_per_epoch - tail_batches

        scheduler = OneCycleLR(
            optimizer,
            max_lr=LEARNING_RATE,
            steps_per_epoch=optimizer_steps_per_epoch,
            epochs=num_epochs,
            pct_start=0.1
        )

        # Create evaluation dataloader (smaller batch size for evaluation)
        eval_dataloader = get_dataset(
            batch_size=16,
            num_examples=100,  # Evaluate on 100 samples
            num_workers=0,
        )

        # Initialize loss tracking
        ctc_losses = []
        vq_losses = []
        steps = starting_steps

        # Create directory for saving models
        os.makedirs(f"models/{model_id}", exist_ok=True)

        print("Starting training...")
        print(f"Total steps per epoch: {batches_per_epoch}")
        print(f"Evaluation every 1000 steps")

        for i in range(num_epochs):
            model.train()
            # Epoch wall time is only measured when CUDA events are available.
            epoch_start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
            epoch_end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None

            if epoch_start_time:
                epoch_start_time.record()

            for idx, batch in enumerate(dataloader):
                audio = batch["audio"].to(device)
                target = batch["input_ids"].to(device)

                if audio.dim() == 2:
                    audio = audio.unsqueeze(1)

                output, vq_loss = model(audio)
                ctc_loss = run_loss_function(output, target, blank_token)

                # Loss weighting with cosine annealing
                vq_weight = _cosine_vq_weight(
                    steps, total_batches, vq_initial_loss_weight, vq_final_loss_weight
                )

                if vq_loss is not None:
                    total_loss = ctc_loss + vq_weight * vq_loss
                else:
                    total_loss = ctc_loss

                # Average over the batches that actually make up this
                # accumulation window; the trailing window of an epoch may be
                # shorter than gradient_accumulation_steps.
                if idx < last_full_window_end:
                    window_size = gradient_accumulation_steps
                else:
                    window_size = tail_batches
                total_loss = total_loss / window_size
                total_loss.backward()

                ctc_losses.append(ctc_loss.item())
                vq_losses.append(vq_loss.item() if vq_loss is not None else 0.0)

                if (idx + 1) % gradient_accumulation_steps == 0 or (idx + 1) == batches_per_epoch:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

                steps += 1

                # Regular logging every 20 steps (averages cover the batches
                # seen since the previous log line).
                if steps % 20 == 0:
                    avg_ctc_loss = safe_mean(ctc_losses)
                    avg_vq_loss = safe_mean(vq_losses)
                    avg_loss = avg_ctc_loss + vq_weight * avg_vq_loss

                    print(
                        f"Epoch {i}, Batch {idx}/{batches_per_epoch}, Step {steps}, "
                        f"Loss: {avg_loss:.4f}, CTC Loss: {avg_ctc_loss:.4f}, "
                        f"VQ Loss: {avg_vq_loss:.4f}, VQ Weight: {vq_weight:.4f}"
                    )

                    writer.add_scalar("Loss/train", avg_loss, steps)
                    writer.add_scalar("Loss/ctc", avg_ctc_loss, steps)
                    writer.add_scalar("Loss/vq", avg_vq_loss, steps)
                    writer.add_scalar("Loss/vq_weight", vq_weight, steps)

                    ctc_losses = []
                    vq_losses = []

                # Evaluation every 1000 steps
                if steps % 1000 == 0:
                    print(f"\nRunning evaluation at step {steps}...")
                    eval_results = evaluate_model(model, eval_dataloader, tokenizer, device, blank_token)

                    # Log metrics to tensorboard
                    writer.add_scalar("Metrics/WER", eval_results['wer'], steps)
                    writer.add_scalar("Metrics/CER", eval_results['cer'], steps)

                    # Print detailed results
                    print_evaluation_results(eval_results, steps)

                # Save model periodically
                if steps % 500 == 0:
                    model_path = f"models/{model_id}/model_step_{steps}.pth"
                    os.makedirs(os.path.dirname(model_path), exist_ok=True)

                    # Best-effort: a failed save is reported but does not stop training.
                    try:
                        _save_checkpoint(
                            model_path, model, optimizer,
                            step=steps, epoch=i, vq_weight=vq_weight,
                        )
                        print(f"Model saved to {model_path}")

                        latest_path = f"models/{model_id}/model_latest.pth"
                        _save_checkpoint(
                            latest_path, model, optimizer,
                            step=steps, epoch=i, vq_weight=vq_weight,
                        )

                    except Exception as e:
                        print(f"Error saving model: {e}")

            if epoch_end_time:
                epoch_end_time.record()
                torch.cuda.synchronize()
                epoch_time = epoch_start_time.elapsed_time(epoch_end_time) / 1000.0
                print(f"Epoch {i} completed in {epoch_time:.2f} seconds")

        # Final evaluation
        print("\n" + "="*80)
        print("FINAL EVALUATION")
        print("="*80)
        final_eval_results = evaluate_model(model, eval_dataloader, tokenizer, device, blank_token, max_batches=10)
        print_evaluation_results(final_eval_results, steps)

        # Save final model
        try:
            final_path = f"models/{model_id}/model_final.pth"
            _save_checkpoint(
                final_path, model, optimizer,
                step=steps,
                epoch=num_epochs,
                final_wer=final_eval_results['wer'],
                final_cer=final_eval_results['cer'],
            )
            print(f"Final model saved to {final_path}")
        except Exception as e:
            print(f"Error saving final model: {e}")
    finally:
        # Flush and close the TensorBoard writer even on KeyboardInterrupt or
        # any other failure, so already-logged scalars are not lost.
        writer.close()

    print("Training completed!")

In [11]:
# Add to train.py to test the tokenizer fix:
def test_tokenizer_filtering():
    tokenizer = get_tokenizer()
    
    # Test the filtering logic
    test_tokens = ["<pad>", "<unk>", "H", "E", "L", "L", "O", "<□>"]
    filtered = []
    
    for token in test_tokens:
        if token and token not in ["<pad>", "<unk>", "<s>", "</s>", "<□>"]:
            filtered.append(token)
    
    print(f"Original tokens: {test_tokens}")
    print(f"Filtered tokens: {filtered}")
    print(f"Expected: ['H', 'E', 'L', 'L', 'O']")
    
    if filtered == ['H', 'E', 'L', 'L', 'O']:
        print("✅ Token filtering is working correctly!")
    else:
        print("❌ Token filtering still has issues")

# Run the filtering self-check as soon as this cell executes, i.e. before the
# training cell below is started.
test_tokenizer_filtering()


Original tokens: ['<pad>', '<unk>', 'H', 'E', 'L', 'L', 'O', '<□>']
Filtered tokens: ['H', 'E', 'L', 'L', 'O']
Expected: ['H', 'E', 'L', 'L', 'O']
✅ Token filtering is working correctly!


In [12]:
# Entry point: start training when this cell (or the exported script) is run
# directly rather than imported.
if __name__ == "__main__":
    main()

Using device: cpu
Number of trainable parameters: 2877489
Starting training...
Total steps per epoch: 21
Evaluation every 1000 steps


  return F.conv1d(


Epoch 0, Batch 19/21, Step 20, Loss: 6.9521, CTC Loss: 6.6293, VQ Loss: 3.2285, VQ Weight: 0.1000
Epoch 1, Batch 18/21, Step 40, Loss: 6.9502, CTC Loss: 6.6266, VQ Loss: 3.2353, VQ Weight: 0.1000
Epoch 2, Batch 17/21, Step 60, Loss: 6.9208, CTC Loss: 6.5959, VQ Loss: 3.2488, VQ Weight: 0.1000
Epoch 3, Batch 16/21, Step 80, Loss: 6.9807, CTC Loss: 6.6540, VQ Loss: 3.2673, VQ Weight: 0.1000
Epoch 4, Batch 15/21, Step 100, Loss: 6.9327, CTC Loss: 6.6038, VQ Loss: 3.2901, VQ Weight: 0.1000
Epoch 5, Batch 14/21, Step 120, Loss: 6.9425, CTC Loss: 6.6108, VQ Loss: 3.3181, VQ Weight: 0.1000
Epoch 6, Batch 13/21, Step 140, Loss: 6.9188, CTC Loss: 6.5838, VQ Loss: 3.3519, VQ Weight: 0.1000
Epoch 7, Batch 12/21, Step 160, Loss: 6.9299, CTC Loss: 6.5913, VQ Loss: 3.3884, VQ Weight: 0.0999
Epoch 8, Batch 11/21, Step 180, Loss: 6.9299, CTC Loss: 6.5870, VQ Loss: 3.4317, VQ Weight: 0.0999
Epoch 9, Batch 10/21, Step 200, Loss: 6.9295, CTC Loss: 6.5817, VQ Loss: 3.4809, VQ Weight: 0.0999
Epoch 10, Batc

KeyboardInterrupt: 