In [5]:
pip install datasets transformers torchaudio librosa

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import torchaudio
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datasets import load_metric
from jiwer import cer

# Set paths to the datasets
dataset_dir = os.path.expanduser("~/Desktop/nptel-pure")
audio_dir = os.path.join(dataset_dir, "wav")
original_txt_dir = os.path.join(dataset_dir, "original_txt")
corrected_txt_dir = os.path.join(dataset_dir, "corrected_txt")

# Function to preprocess audio files
def speech_file_to_array_fn(file_path):
    speech_array, sampling_rate = torchaudio.load(file_path)
    resampler = torchaudio.transforms.Resample(sampling_rate, 16000)
    speech = resampler(speech_array).squeeze().numpy()
    return speech

# Load pretrained Wav2Vec2 model and processor
model_name = "facebook/wav2vec2-large-960h-lv60-self"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Function to transcribe audio using the model
def predict(speech):
    inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.to(model.device)).logits
    pred_ids = torch.argmax(logits, dim=-1)
    predicted_text = processor.batch_decode(pred_ids)
    return predicted_text[0]

# Function to read transcripts from a file
def read_transcript(file_path):
    with open(file_path, 'r') as f:
        return f.read().strip()

# Function to process the dataset
def process_dataset(audio_dir, transcript_dir):
    audio_files = [os.path.join(audio_dir, file) for file in os.listdir(audio_dir) if file.endswith(".wav")]
    audio_files.sort()

    predictions = []
    references = []

    for audio_file in audio_files:
        transcript_file_name = os.path.splitext(os.path.basename(audio_file))[0] + ".txt"
        transcript_path = os.path.join(transcript_dir, transcript_file_name)
        
        if not os.path.exists(transcript_path):
            print(f"Transcript file not found for audio: {audio_file}")
            continue

        # Read ground truth transcription
        reference_text = read_transcript(transcript_path)

        # Skip empty reference texts
        if not reference_text:
            print(f"Empty reference text found for audio: {audio_file}")
            continue

        references.append(reference_text)

        # Transcribe the audio file
        speech = speech_file_to_array_fn(audio_file)
        predicted_text = predict(speech)
        predictions.append(predicted_text)

        # Debugging: Print the transcript and prediction
        print(f"Transcript for {audio_file}: {reference_text}")
        print(f"Prediction for {audio_file}: {predicted_text}")

    return predictions, references

# Process the datasets
predictions_original, references_original = process_dataset(audio_dir, original_txt_dir)
predictions_corrected, references_corrected = process_dataset(audio_dir, corrected_txt_dir)

# Ensure predictions and references are not empty
if not predictions_original or not references_original:
    raise ValueError("Predictions or references for original transcripts are empty.")
if not predictions_corrected or not references_corrected:
    raise ValueError("Predictions or references for corrected transcripts are empty.")

# Load the CER metric
cer_metric = load_metric("cer")

# Calculate CER for both datasets
cer_original = cer_metric.compute(predictions=predictions_original, references=references_original)
cer_corrected = cer_metric.compute(predictions=predictions_corrected, references=references_corrected)

print(f"CER for original transcripts: {cer_original:.4f}")
print(f"CER for corrected transcripts: {cer_corrected:.4f}")


Some weights of the model checkpoint at facebook/wav2vec2-large-960h-lv60-self were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.maske

Transcript for C:\Users\ASUS/Desktop/nptel-pure\wav\0000003b8fd9bc22877135b42b04c49d4860312b001be688723ecc5d.wav: IN THIS PARTICULAR SESSION WE WILL BE DISCUSSING HOW TO BENCHMARK SOMETHING THAT I GOT TO
Prediction for C:\Users\ASUS/Desktop/nptel-pure\wav\0000003b8fd9bc22877135b42b04c49d4860312b001be688723ecc5d.wav: THIS PARTICULAR SESSION WILL BE DISCUSSING HOW TO BENCH ONE SOMETHING THAT 'VE GOT
Transcript for C:\Users\ASUS/Desktop/nptel-pure\wav\00000682f31904acc560fa359512e7bdd487b11efe36145a56874e30.wav: ELECTRIC FIELD LIKE THIS WHY BECAUSE FIRST OF ALL THIS ELECTRIC FIELD INTERN CREATES
Prediction for C:\Users\ASUS/Desktop/nptel-pure\wav\00000682f31904acc560fa359512e7bdd487b11efe36145a56874e30.wav: FIELD LIKE THIS WHY ECAS FIRST OF ALL THIS ELECTRIC FIELD IN TURN CREATE
Transcript for C:\Users\ASUS/Desktop/nptel-pure\wav\0000068bfcd8e252fd9ec8225bd1fdb47378a009f8afa00d4e998df5.wav: TWO Z INVARIANT MAP AH J TILDE SO THAT THE FOLLOWING DIAGRAMS
Prediction for C:\Users\ASUS/Desktop/

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


CER for original transcripts: 0.2066
CER for corrected transcripts: 0.1325
