In [None]:
import os
import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from jiwer import wer

#  Functionality to process NPTEL data from local directory
def load_NPTEL_data(main_directory):
    data = {}
    for folder_name in ["wav"]:
        folder_path = os.path.join(main_directory, folder_name)
        audiofiles = []
        original_transcriptions = []
        corrected_transcriptions = []

        # loop through each file in Folder
        for file_name in os.listdir(folder_path):
            if file_name.endswith(".wav"):
                audio_path = os.path.join(folder_path, file_name)
                original_transcription_path = os.path.join(main_directory, "original_txt", file_name.replace(".wav", ".txt"))
                corrected_transcription_path = os.path.join(main_directory, "corrected_txt", file_name.replace(".wav", ".txt"))

                if os.path.exists(original_transcription_path) and os.path.exists(corrected_transcription_path):
                    with open(original_transcription_path, "r") as f:
                        original_transcription = f.read().strip()
                    with open(corrected_transcription_path, "r") as f:
                        corrected_transcription = f.read().strip()

                    # Skip empty references
                    if not original_transcription or not corrected_transcription:
                        print(f"Skipping {file_name}: Empty reference found")
                        continue

                    audiofiles.append(audio_path)
                    original_transcriptions.append(original_transcription)
                    corrected_transcriptions.append(corrected_transcription)
                else:
                    print(f"Transcription file missing for {file_name}")

        if not audiofiles:
            print(f"No audio files found in directory: {folder_name}")

        data[folder_name] = {
            "audiofiles": audiofiles,
            "original_transcriptions": original_transcriptions,
            "corrected_transcriptions": corrected_transcriptions
        }

    return data

# Transcribe the  Audio using Wav2Vec2 model
def transcribe_audio(audio_path):
    waveform, sample_rate = librosa.load(audio_path, sr=None)
    inputs = processor(waveform, sampling_rate=sample_rate, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    return transcription[0]

# Load Audio files and transcriptions from the NPTEL dataset
main_directory = "C:/Users/tanya/OneDrive/Desktop/pytrch/New folder/nptel-pure"
NPTEL_data = load_NPTEL_data(main_directory)

# Load Pre-trained Wav2Vec2 model and processor
model_name = "facebook/wav2vec2-large-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

Skipping 0000a7be825f70cbe4c49acf9a8b7804d05a8a4701a2b42a343a694e.wav: Empty reference found
Skipping 0000b51ac562056d95d2ab9a1925f69257cc1018e909d8583dc92ed4.wav: Empty reference found


Some weights of the model checkpoint at facebook/wav2vec2-large-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You s

In [6]:

# Calculate CER for each directory
cer_results = {}

for dir_name, data in NPTEL_data.items():
    if not data["audiofiles"]:
        print(f"No audio files found in directory: {dir_name}")
        cer_results[dir_name] = {"CER Original": None, "CER Corrected": None}
        continue

    total_cer_original = 0
    total_cer_corrected = 0
    valid_samples = 0  

    audiofiles = data["audiofiles"]
    original_transcriptions = data["original_transcriptions"]
    corrected_transcriptions = data["corrected_transcriptions"]

    for audio_path, original_transcription, corrected_transcription in zip(audiofiles, original_transcriptions, corrected_transcriptions):
        try:
            # Transcribe audio
            transcription = transcribe_audio(audio_path)

            # Calculate CER against original transcription
            cer_original = wer(original_transcription, transcription)
            total_cer_original += cer_original

            # Calculate CER against corrected transcription
            cer_corrected = wer(corrected_transcription, transcription)
            total_cer_corrected += cer_corrected

            valid_samples += 1  

        except Exception as e:
            print(f"Error processing {audio_path}: {str(e)}")

    # Calculate average CER only if there are valid samples processed
    average_cer_original = total_cer_original / valid_samples if valid_samples > 0 else None
    average_cer_corrected = total_cer_corrected / valid_samples if valid_samples > 0 else None

    cer_results[dir_name] = {
        "CER Original": average_cer_original,
        "CER Corrected": average_cer_corrected
    }

    print(f"Avg CER for {dir_name} (Original): {average_cer_original}")
    print(f"Avg CER for {dir_name} (Corrected): {average_cer_corrected}")

print(cer_results)

Avg CER for wav (Original): 0.47939048968891124
Avg CER for wav (Corrected): 0.416193899455399
{'wav': {'CER Original': 0.47939048968891124, 'CER Corrected': 0.416193899455399}}
