Mean Absolute Error


In [1]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import librosa

# Load processor & trained model
model_path = "D:/Speech_recognition/Model_Training/wav2vec2_finetuned"
processor = Wav2Vec2Processor.from_pretrained(model_path)
model = Wav2Vec2ForCTC.from_pretrained(model_path)

# Move model to GPU 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Wav2Vec2ForCTC(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=1024, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder

In [2]:
def transcribe_audio(audio_path):
    # Load audio file
    waveform, _ = librosa.load(audio_path, sr=16000)
    
    # Convert audio to tensor
    input_values = processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True).input_values
    input_values = input_values.to(device)

    # Perform inference
    with torch.no_grad():
        logits = model(input_values).logits

    # Decode prediction
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]

    return transcription

In [3]:
import time

# New test sentences for ASR evaluation (Ground Truth)
actual_transcripts = [
    "GET THE TRUST FUND TO THE BANK EARLY",
    "THE GRASS AND BUSHES WERE WET WITH DEW",
    "SMALL CHILDREN CAME TO SEE HIM",
    "A PINK SHELL WAS FOUND ON THE SANDY BEACH",
    "SHE SAW A CAT IN THE NEIGHBOUR'S HOUSE",
]

# List of test audio files (Replace with real file paths)
audio_files = [
    r"D:\Speech_recognition\Data\test_audio01.wav",
    r"D:\Speech_recognition\Data\test_audio02.wav",
    r"D:\Speech_recognition\Data\test_audio03.wav",
    r"D:\Speech_recognition\Data\test_audio04.wav",
    r"D:\Speech_recognition\Data\test_audio05.wav"
]

# Function to transcribe multiple audio files using ASR model
def generate_asr_predictions(audio_files):
    predicted_texts = []
    
    for audio in audio_files:
        transcript = transcribe_audio(audio)  # Your ASR function
        predicted_texts.append(transcript)

    return predicted_texts

# Generate ASR outputs dynamically
predicted_transcripts = generate_asr_predictions(audio_files)
print(predicted_transcripts)

# Function to calculate simple accuracy (exact matches)
def calculate_accuracy(actual, predicted):
    correct = sum(1 for a, p in zip(actual, predicted) if a.lower() == p.lower())
    accuracy = (correct / len(actual)) * 100
    return accuracy

# Run accuracy test
accuracy = calculate_accuracy(actual_transcripts, predicted_transcripts)
print(f"✅ ASR Accuracy: {accuracy:.2f}%")

# Measure inference speed
def measure_inference_time(audio_files):
    times = []
    
    for audio in audio_files:
        start_time = time.time()
        
        _ = transcribe_audio(audio)  # Run ASR
        
        end_time = time.time()
        times.append(end_time - start_time)

    avg_time = sum(times) / len(times)
    return avg_time

# Run inference time test
avg_inference_time = measure_inference_time(audio_files)
print(f"⏳ Average Inference Time: {avg_inference_time:.4f} seconds per file")


['GET THE TRUST FUN TO THE BANK EARLY', 'THE GRASS AND BUSHES WERE WHET WITH DEW', 'SMALL CHILDREN CAME TO SEE HIM', 'A PINK SHELL WAS FOUND ON THE SANDY BEACH', "SHE SAW A CAT IN THE NEIGHBOUR'S HOUSE"]
✅ ASR Accuracy: 60.00%
⏳ Average Inference Time: 3.4395 seconds per file


In [4]:
import os
from langchain_groq import ChatGroq
from dotenv import load_dotenv

# Load API Key
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")

# ✅ Initialize Groq LLM
groq_llm = ChatGroq(
    model_name="mistral-saba-24b",  # Choose supported model
    groq_api_key=groq_api_key,
    temperature=0.1
)

def batch_refine_transcriptions(asr_outputs, history_file="transcriptions_history.txt"):
    if not asr_outputs:
        return []
    
    # Load previous transcriptions as context if file exists
    previous_transcriptions = []
    if os.path.exists(history_file):
        with open(history_file, "r", encoding="utf-8") as file:
            previous_transcriptions = file.readlines()
    
    # Format transcriptions
    formatted_transcriptions = "\n".join([f"{i+1}. {text}" for i, text in enumerate(asr_outputs)])

    # Create the prompt with history for better corrections
    prompt = f"""
    You are an advanced ASR correction assistant that fixes phonetic and contextual mistakes in transcriptions.
    - Correct phonetic errors 
    - Correct name recognition 
    - Ensure proper grammar while preserving original meaning.
    - Refer to previous corrected transcriptions when necessary.

    Previous transcriptions:
    {"".join(previous_transcriptions)}

    Here are multiple ASR outputs. Correct each one and return ONLY the corrected texts, numbered accordingly:
    {formatted_transcriptions}
    """
    
    # Get response from Groq API
    response = groq_llm.invoke(prompt)
    if not response or not hasattr(response, "content"):
        return ["❌ Error: No response from Groq API"]
    
    corrected_texts = response.content.strip().split("\n")
    corrected_texts = [text.split(". ", 1)[1] if ". " in text else text for text in corrected_texts]

    # Save new transcriptions
    with open(history_file, "a", encoding="utf-8") as file:
        for text in corrected_texts:
            file.write(text + "\n")

    return corrected_texts

In [5]:
import os
Refined_transcripts = batch_refine_transcriptions(predicted_transcripts)

def calculate_accuracy(actual, predicted):
    correct = sum(1 for a, p in zip(actual, predicted) if a.lower() == p.lower())
    accuracy = (correct / len(actual)) * 100
    return accuracy

print(Refined_transcripts)

# Run accuracy test
accuracy = calculate_accuracy(actual_transcripts, Refined_transcripts)
print(f"✅ ASR Accuracy: {accuracy:.2f}%")

['GET THE TRUST FUND TO THE BANK EARLY', 'THE GRASS AND BUSHES WERE WET WITH DEW', 'SMALL CHILDREN CAME TO SEE HIM', 'A PINK SHELL WAS FOUND ON THE SANDY BEACH', "SHE SAW A CAT IN THE NEIGHBOR'S HOUSE"]
✅ ASR Accuracy: 80.00%


WER = Word Error Rate


In [6]:
from jiwer import wer

# Calculate WER for each transcript
wer_scores = [wer(actual, pred) for actual, pred in zip(actual_transcripts, predicted_transcripts)]

# Compute Average WER
avg_wer = sum(wer_scores) / len(wer_scores)

print(f"\n🔍 WER Before LLM Correction: {avg_wer:.2%}")



🔍 WER Before LLM Correction: 5.00%


In [7]:
# Apply LLM Refinement
refined_transcripts = batch_refine_transcriptions(predicted_transcripts)

# Compute WER After LLM Correction
wer_after = sum(wer(actual, refined) for actual, refined in zip(actual_transcripts, refined_transcripts)) / len(actual_transcripts)

print(f"\n🚀 WER After LLM Correction: {wer_after:.2%} (Lower is better!)")



🚀 WER After LLM Correction: 2.50% (Lower is better!)
