In [3]:
import json
from jiwer import wer

def load_json(file_path):
    """Load and return the JSON document stored at *file_path*.

    The file is decoded explicitly as UTF-8 instead of the platform's locale
    default encoding, so non-ASCII transcripts load identically on every OS.
    The ``utf-8-sig`` codec also accepts (and discards) a leading UTF-8 BOM,
    which ``json.load`` would otherwise reject; BOM-less files decode exactly
    as with plain ``utf-8``.
    """
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        return json.load(f)

def _normalize_text(text):
    """Return *text* lower-cased, with punctuation removed and whitespace collapsed.

    Punctuation is deleted rather than replaced by a space, so "I'm" becomes
    "im" and "three-point" becomes "threepoint". Any character that is neither
    a word character (letter, digit, underscore) nor whitespace is treated as
    punctuation and removed.
    """
    import re  # local import keeps this notebook cell self-contained

    without_punctuation = re.sub(r"[^\w\s]", "", text.lower())
    return " ".join(without_punctuation.split())


def calculate_wer(gt_json_path, pred_json_path, *, normalize=False):
    """Compute per-pair and average word error rate (WER) of predictions.

    Ground-truth items are matched to predictions by comparing
    ``gt_item['video_id']`` with ``pred_item['filename']`` (presumably the
    prediction filename is the bare video id -- confirm against the script
    that wrote the predictions file).

    Ground-truth items are skipped when they have no ``subtitle`` (key absent
    or null), or when the subtitle -- after optional normalization -- is at
    most one non-whitespace character long. Ground-truth items with no
    matching prediction are reported on stdout and skipped. A prediction that
    exists but is empty/whitespace-only IS scored: every reference word counts
    as a deletion, so its WER is 1.0.

    Args:
        gt_json_path: path to a JSON list of objects with a ``video_id`` key and
            (optionally) a string ``subtitle`` key.
        pred_json_path: path to a JSON list of objects with ``filename`` and
            ``transcript`` keys.
        normalize: if True, lower-case both texts and strip punctuation before
            scoring, so casing/punctuation differences are not counted as word
            errors. The default (False) scores the raw texts.

    Returns:
        dict with keys
          ``"average_wer"``: unweighted mean of the per-pair WERs (this is not
              corpus-level WER, i.e. total errors / total reference words),
              or None when no pair was scored;
          ``"wer_scores"``: per-pair WER values;
          ``"gt_texts"`` / ``"pred_texts"``: the exact texts that were scored
              (normalized when ``normalize=True``), aligned with ``wer_scores``.
    """
    gt_data = load_json(gt_json_path)
    pred_data = load_json(pred_json_path)

    # Index predictions by filename for constant-time lookup; if two entries
    # share a filename, the later one wins.
    pred_dict = {item['filename']: item['transcript'] for item in pred_data}

    wer_scores = []
    gt_texts = []
    pred_texts = []

    for gt_item in gt_data:
        video_id = gt_item['video_id']

        gt_caption = gt_item.get('subtitle')
        if gt_caption is None:
            continue
        if normalize:
            gt_caption = _normalize_text(gt_caption)
        # Skip empty, whitespace-only and single-character references; an
        # empty reference has no words to compute a rate against.
        if len(gt_caption.strip()) <= 1:
            continue

        pred_caption = pred_dict.get(video_id)
        # Only a genuinely missing prediction is skipped. An empty string is a
        # real (maximally wrong) transcription and must not be dropped, or the
        # average would be biased toward the easy cases.
        if pred_caption is None:
            print(f"No prediction found for video ID: {video_id}")
            continue
        if normalize:
            pred_caption = _normalize_text(pred_caption)

        if pred_caption.strip():
            error_rate = wer(gt_caption, pred_caption)
        else:
            # Empty hypothesis: the minimal edit deletes every reference word,
            # so WER = N / N = 1.0 (computed directly rather than relying on
            # how a given jiwer version treats empty hypotheses).
            error_rate = 1.0

        wer_scores.append(error_rate)
        gt_texts.append(gt_caption)
        pred_texts.append(pred_caption)

    average_wer = sum(wer_scores) / len(wer_scores) if wer_scores else None

    return {
        "average_wer": average_wer,
        "wer_scores": wer_scores,
        "gt_texts": gt_texts,
        "pred_texts": pred_texts
    }

# Input locations: reference subtitles and the Whisper-small transcriptions.
gt_json_path = '../data/mix120/mix120.json'
pred_json_path = '../output/whisper_small_transcriptions.json'

# Score every matched pair, then report the mean followed by a per-pair
# breakdown (reference text, hypothesis text, WER), numbered from 1.
results = calculate_wer(gt_json_path, pred_json_path)
print("Average WER:", results["average_wer"])

scored_pairs = zip(
    results["gt_texts"],
    results["pred_texts"],
    results["wer_scores"],
)
for pair_no, (reference, hypothesis, pair_wer) in enumerate(scored_pairs, start=1):
    print(f"\nPair {pair_no}:")
    print("Ground Truth:", reference)
    print("Prediction:", hypothesis)
    print("WER:", pair_wer)


Average WER: 0.5909673952943849

Pair 1:
Ground Truth: What Im going to have to do is I want to have to take and cut to trying to a stress relief cut in here and here on both sides.
Prediction:  Well, what I'm gonna have to do is I'm gonna have to take and cut, try to do a stress relief cut in here, in here.
WER: 0.5666666666666667

Pair 2:
Ground Truth: The way the game is being played, the Golden State Warriors have mastered it, theyve mastered it from the threepoint line; they mastered it from the defensive end, so they have put together a full package.
Prediction:  way the game is being played. The Golden State Warriors have mastered it. They've mastered it from the three-point line. They've mastered it from the defensive end. So they have put together a full package
WER: 0.3055555555555556

Pair 3:
Ground Truth: That likely means that some of the batteries were damaged, but it doesn't matter now.
Prediction:  after a few minutes, which was all the battery power the ship had. That 