In [None]:
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import time
import os

  from .autonotebook import tqdm as notebook_tqdm


### Download the pretrained models and save them to a local directory

In [None]:
import os

# Root folder that receives one sub-folder per checkpoint
save_dir = "models"
os.makedirs(save_dir, exist_ok=True)

# Short local folder name -> model id passed to from_pretrained
models = dict(
    base_960h="facebook/wav2vec2-base-960h",
    large_960h="facebook/wav2vec2-large-960h",
    large_960h_lv60="facebook/wav2vec2-large-960h-lv60-self",
    base_100h_lm="patrickvonplaten/wav2vec2-base-100h-with-lm",
)

for name in models:
    model_id = models[name]
    print(f"Loading and saving {model_id}...")
    processor = Wav2Vec2Processor.from_pretrained(model_id)
    model = Wav2Vec2ForCTC.from_pretrained(model_id)

    # Each checkpoint lives in its own sub-folder: <save_dir>/<name>
    model_save_path = os.path.join(save_dir, name)
    os.makedirs(model_save_path, exist_ok=True)

    # Persist the processor first, then the model weights, into the same folder
    for artifact in (processor, model):
        artifact.save_pretrained(model_save_path)

print("All models loaded and saved successfully!")


In [3]:
from jiwer import wer

# Directory whose sub-folders are treated as saved models
models_dir = "models"
# Directory that is scanned for the .wav files to transcribe
test_data_dir = "test_data"
# Directory holding the reference transcripts (<name>.txt for each <name>.wav)
transcribs_dir = "transcribs"

### Find the saved models in the local directory

In [4]:
# Every sub-directory of models_dir is treated as one saved model.
# os.listdir returns entries in arbitrary order, so sort for a reproducible evaluation order.
model_names = sorted(
    name for name in os.listdir(models_dir)
    if os.path.isdir(os.path.join(models_dir, name))
)

### Load a WAV file and resample it if needed

In [5]:
def load_audio(path, target_sr = 16000):
    """Read an audio file and return its samples at ``target_sr``.

    Args:
        path: Location of the audio file handed to ``torchaudio.load``.
        target_sr: Sample rate the returned waveform should have.

    Returns:
        The loaded waveform with all size-1 dimensions squeezed out. It is
        resampled only when the file's own rate differs from ``target_sr``.
    """
    audio, native_sr = torchaudio.load(path)

    # Already at the requested rate: nothing to convert
    if native_sr == target_sr:
        return audio.squeeze()

    to_target = torchaudio.transforms.Resample(native_sr, target_sr)
    return to_target(audio).squeeze()


### Transcribe audio with given model and processor

In [6]:
def transcribe(waveform, processor, model, sampling_rate=16000):
    """Transcribe a waveform using greedy (argmax) decoding of the model logits.

    Args:
        waveform: Audio samples. A 2-D input is reduced to its first row
            (treated as the first channel); other ranks are passed through.
        processor: Callable that turns the waveform into model inputs
            (exposing ``.input_values``) and provides ``decode``.
        model: Callable returning an object with a ``.logits`` attribute.
        sampling_rate: Sample rate of ``waveform`` forwarded to the processor.
            Defaults to 16000, the value that was previously hard-coded.

    Returns:
        The decoded transcription, lower-cased.
    """
    # Ensure mono audio: keep only the first channel if several are present
    if waveform.ndim == 2:
        waveform = waveform[0]

    # presumably the processor returns a batch of size 1 here - TODO confirm
    input_values = processor(
        waveform, return_tensors="pt", sampling_rate=sampling_rate
    ).input_values

    # Inference only: no gradients needed
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy decoding: most likely token id along the last dimension
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])

    return transcription.lower()


### Results

In [7]:
# One dict per (model, audio file) pair; consumed by the summary and export cells
results = []

for model_name in model_names:
    print(f"\n--- Using model: {model_name} ---")
    model_path = os.path.join(models_dir, model_name)

    # Load processor and model from the local copy and switch to eval mode
    processor = Wav2Vec2Processor.from_pretrained(model_path)
    model = Wav2Vec2ForCTC.from_pretrained(model_path)
    model.eval()

    # Loop through all wav files in a stable (sorted) order
    for wav_file in sorted(os.listdir(test_data_dir)):
        if not wav_file.endswith(".wav"):
            continue

        wav_path = os.path.join(test_data_dir, wav_file)
        # Swap only the trailing extension; str.replace would also rewrite
        # any ".wav" that happens to occur earlier in the file name.
        txt_file = wav_file[:-len(".wav")] + ".txt"
        txt_path = os.path.join(transcribs_dir, txt_file)

        if not os.path.isfile(txt_path):
            print(f"Transcription file missing for {wav_file}, skipping...")
            continue

        # Load audio & ground truth text
        waveform = load_audio(wav_path)
        with open(txt_path, "r", encoding="utf-8") as f:
            ground_truth = f.read().strip().lower()

        # Transcribe & measure time; perf_counter is monotonic and
        # high-resolution, unlike the wall clock returned by time.time()
        start_time = time.perf_counter()
        pred_text = transcribe(waveform, processor, model)
        duration = time.perf_counter() - start_time

        # Calculate WER of the prediction against the reference transcript
        error_rate = wer(ground_truth, pred_text)

        # Print results
        print(f"{model_name} | {wav_file} | Duration: {duration:.2f}s | WER: {error_rate:.3f}")
        print(f"Transcription: {pred_text}")
        print(f"Ground Truth : {ground_truth}\n")

        # Store result
        results.append({
            "model": model_name,
            "audio_file": wav_file,
            "duration_sec": duration,
            "wer": error_rate,
            "prediction": pred_text,
            "ground_truth": ground_truth
        })



--- Using model: base_100h_lm ---
base_100h_lm | 1.wav | Duration: 3.56s | WER: 1.000
Transcription: this is dist and t now there is anoyse silr levio es finished
Ground Truth : this is test now there is noise finised

base_100h_lm | 10.wav | Duration: 0.78s | WER: 0.778
Transcription: thence disano ts plareaty eats lord is easy hav pronounce and understanding but itis tar outslo
Ground Truth : sentence designed to be spoken clearly each word is easy to pronounce and understand when read aloud slowly

base_100h_lm | 2.wav | Duration: 1.74s | WER: 0.921
Transcription: sometimes life disegas bank if his tdes laps an dit tat eveny demdin of con disi a his was if his tad even disto once  sadinein disiistroge an micor
Ground Truth : sometimes life does not go as planned we face struggles doubts and delays but even in the middle of confusion your path has purpose every step even the slow ones is shaping you into someone stronger and wiser

base_100h_lm | 3.wav | Duration: 2.72s | WER: 0.923

### Average word error rate (WER) per model

In [None]:
def summarize_model_wer(results):
    """Print the mean WER of each model present in ``results``.

    Args:
        results: Iterable of dicts, each providing a ``"model"`` key (the
            group label) and a ``"wer"`` key (the numeric value to average).

    Returns:
        None. A header line is printed, followed by one line per model in
        order of first appearance.
    """
    wer_sums = {}
    file_counts = {}

    # Accumulate the running WER total and the file count for each model
    for entry in results:
        name = entry["model"]
        wer_sums[name] = wer_sums.get(name, 0.0) + entry["wer"]
        file_counts[name] = file_counts.get(name, 0) + 1

    print("\n=== Average WER per Model ===")
    for name, n_files in file_counts.items():
        # n_files is at least 1 for every recorded model, so the division is safe
        mean_wer = wer_sums[name] / n_files
        print(f"{name}: Average WER = {mean_wer:.3f} over {n_files} files")


In [9]:
# Print the average WER of each model over the collected results
summarize_model_wer(results)



=== Average WER per Model ===
base_100h_lm: Average WER = 0.678 over 10 files
base_960h: Average WER = 0.630 over 10 files
large_960h: Average WER = 0.482 over 10 files
large_960h_lv60: Average WER = 0.360 over 10 files


In [None]:
import csv
from tabulate import tabulate  

# Define CSV output path
csv_output_path = "transcription_results.csv"

# Write results to CSV
with open(csv_output_path, mode="w", newline="", encoding="utf-8") as csv_file:
    fieldnames = ["model", "audio_file", "duration_sec", "wer", "prediction", "ground_truth"]
    writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

    writer.writeheader()
    for row in results:
        writer.writerow(row)

print(f"\nResults exported to: {csv_output_path}")

# Print structured table
print("\nSummary Table:")
print(tabulate(results, headers="keys", tablefmt="grid", floatfmt=".3f"))


Results exported to: transcription_results.csv

Summary Table:
+-----------------+--------------+----------------+-------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------