In [1]:
import torch
import numpy as np
import datasets
from dataset import get_tokenizer
from transcribe_model import TranscribeModel

def predict_on_idrak_dataset(model_path, num_samples=5, device='cuda'):
    """
    Test the model on the idrak_timit_subsample1 dataset.
    
    Args:
        model_path: Path to your trained model
        num_samples: Number of samples to test
        device: Device to use
    """
    # Load dataset
    print("Loading dataset...")
    dataset = datasets.load_dataset("m-aliabbas/idrak_timit_subsample1", split="train")
    
    # Load tokenizer and model
    tokenizer = get_tokenizer()
    blank_token = tokenizer.token_to_id("<□>")
    
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Load model
    try:
        model = TranscribeModel.load(model_path).to(device)
        model.eval()
        print(f"Model loaded from {model_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
        return
    
    print(f"Testing on {num_samples} samples from the dataset...")
    print("="*80)
    
    results = []
    
    for i in range(min(num_samples, len(dataset))):
        # Get data sample
        data = dataset[i]
        audio = data["audio"]["array"]
        sample_rate = data["audio"]["sampling_rate"]
        transcription = data["transcription"].upper()  # Convert to uppercase
        
        print(f"\nSample {i+1}:")
        print(f"Sample rate: {sample_rate} Hz")
        print(f"Audio length: {len(audio)} samples ({len(audio)/sample_rate:.2f} seconds)")
        print(f"Reference: '{transcription}'")
        
        # Preprocess audio
        audio_tensor = torch.tensor(audio).float().unsqueeze(0).to(device)
        
        # Limit audio length to prevent memory issues
        max_length = 80000  # 5 seconds at 16kHz
        if audio_tensor.shape[1] > max_length:
            audio_tensor = audio_tensor[:, :max_length]
            print(f"Audio truncated to {max_length} samples")
        
        # Forward pass
        try:
            with torch.no_grad():
                output, vq_loss = model(audio_tensor)
            
            # Decode prediction
            decoded_preds = greedy_decoder(output, blank_token=blank_token)
            
            # Convert tokens to text
            if len(decoded_preds) > 0:
                pred = decoded_preds[0]
                tokens = []
                for p in pred:
                    if p < len(tokenizer.get_vocab()):
                        token = tokenizer.id_to_token(p)
                        # Use FIXED special token filtering
                        if token and token not in ["<pad>", "<unk>", "<s>", "</s>", "<□>"]:
                            tokens.append(token)
                
                prediction = "".join(tokens)
                
                # Calculate simple accuracy metrics
                wer = calculate_simple_wer(prediction, transcription)
                
                print(f"Prediction: '{prediction}'")
                print(f"VQ Loss: {vq_loss.item():.4f}")
                print(f"Simple WER: {wer:.4f}")
                
                # Debug info
                print(f"Raw tokens (first 10): {pred[:10]}")
                print(f"Decoded tokens (first 10): {tokens[:10]}")
                
                results.append({
                    'sample_id': i,
                    'reference': transcription,
                    'prediction': prediction,
                    'wer': wer,
                    'vq_loss': vq_loss.item(),
                    'audio_length': len(audio)
                })
                
            else:
                print("Prediction: (empty)")
                results.append({
                    'sample_id': i,
                    'reference': transcription,
                    'prediction': "",
                    'wer': 1.0,
                    'vq_loss': vq_loss.item(),
                    'audio_length': len(audio)
                })
                
        except Exception as e:
            print(f"Error during prediction: {e}")
            continue
        
        print("-" * 40)
    
    # Summary
    print("\n" + "="*80)
    print("SUMMARY")
    print("="*80)
    
    if results:
        avg_wer = sum(r['wer'] for r in results) / len(results)
        avg_vq_loss = sum(r['vq_loss'] for r in results) / len(results)
        
        print(f"Average WER: {avg_wer:.4f}")
        print(f"Average VQ Loss: {avg_vq_loss:.4f}")
        print(f"Samples processed: {len(results)}")
        
        # Show best and worst predictions
        results.sort(key=lambda x: x['wer'])
        
        print(f"\nBest prediction (WER: {results[0]['wer']:.4f}):")
        print(f"  Reference: '{results[0]['reference']}'")
        print(f"  Prediction: '{results[0]['prediction']}'")
        
        print(f"\nWorst prediction (WER: {results[-1]['wer']:.4f}):")
        print(f"  Reference: '{results[-1]['reference']}'")
        print(f"  Prediction: '{results[-1]['prediction']}'")
    
    return results

def greedy_decoder(log_probs, blank_token=0):
    """Greedy decoder for CTC outputs."""
    predictions = torch.argmax(log_probs, dim=-1).cpu().numpy()
    decoded_predictions = []
    
    for pred in predictions:
        previous = -1
        decoded_seq = []
        for p in pred:
            if p != previous and p != blank_token:
                decoded_seq.append(p)
            previous = p
        decoded_predictions.append(decoded_seq)
    
    return decoded_predictions

def calculate_simple_wer(prediction, reference):
    """Simple WER calculation."""
    pred_words = prediction.split()
    ref_words = reference.split()
    
    if len(ref_words) == 0:
        return 1.0 if len(pred_words) > 0 else 0.0
    
    # Simple edit distance
    dp = [[0] * (len(ref_words) + 1) for _ in range(len(pred_words) + 1)]
    
    for i in range(len(pred_words) + 1):
        dp[i][0] = i
    for j in range(len(ref_words) + 1):
        dp[0][j] = j
    
    for i in range(1, len(pred_words) + 1):
        for j in range(1, len(ref_words) + 1):
            if pred_words[i-1] == ref_words[j-1]:
                dp[i][j] = dp[i-1][j-1]
            else:
                dp[i][j] = min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1]) + 1
    
    return dp[len(pred_words)][len(ref_words)] / len(ref_words)

# Usage example
if __name__ == "__main__":
    model_path = r"C:\Users\Kamil\Desktop\Coding\WUM_PROJECT\models\test21\model_step_1500.pth"
    
    # Test on 10 samples
    results = predict_on_idrak_dataset(model_path, num_samples=10)


Loading dataset...
Using device: cuda
Loading model from C:\Users\Kamil\Desktop\Coding\WUM_PROJECT\models\test21\model_step_1500.pth
Model loaded from C:\Users\Kamil\Desktop\Coding\WUM_PROJECT\models\test21\model_step_1500.pth
Testing on 10 samples from the dataset...

Sample 1:
Sample rate: 16000 Hz
Audio length: 35840 samples (2.24 seconds)
Reference: 'DON T ASK ME TO CARRY AN OILY RAG LIKE THAT'


  return F.conv1d(


Prediction: ''
VQ Loss: 2.0000
Simple WER: 1.0000
Raw tokens (first 10): []
Decoded tokens (first 10): []
----------------------------------------

Sample 2:
Sample rate: 16000 Hz
Audio length: 33485 samples (2.09 seconds)
Reference: 'BY EATING YOGURT  YOU MAY LIVE LONGER'
Prediction: ''
VQ Loss: 2.0000
Simple WER: 1.0000
Raw tokens (first 10): []
Decoded tokens (first 10): []
----------------------------------------

Sample 3:
Sample rate: 16000 Hz
Audio length: 48845 samples (3.05 seconds)
Reference: 'THE OVERWEIGHT CHARMER COULD SLIP POISON INTO ANYONE S TEA'
Prediction: ''
VQ Loss: 2.0000
Simple WER: 1.0000
Raw tokens (first 10): []
Decoded tokens (first 10): []
----------------------------------------

Sample 4:
Sample rate: 16000 Hz
Audio length: 58061 samples (3.63 seconds)
Reference: 'HE PICKED UP NINE PAIRS OF SOCKS FOR EACH BROTHER'
Prediction: ''
VQ Loss: 2.0000
Simple WER: 1.0000
Raw tokens (first 10): []
Decoded tokens (first 10): []
---------------------------------------