In [1]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import torchaudio
from jiwer import wer
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [2]:
model_name = "google/speech-to-text"
processor = Wav2Vec2Processor.from_pretrained(model_name,use_auth_token=True)
model = Wav2Vec2ForCTC.from_pretrained(model_name,use_auth_token=True)

model.eval()



OSError: google/speech-to-text is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass `use_auth_token=True`.

In [None]:

# Load and preprocess the audio file
audio_path = "your_audio_file.wav"  # Replace with your actual audio file path
waveform, sample_rate = torchaudio.load(audio_path)

# Resample the audio to 16kHz if needed (Conformer typically works best with 16kHz)
if sample_rate != 16000:
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
    waveform = resampler(waveform)

# Preprocess the audio for the model
input_values = processor(waveform.squeeze().numpy(), return_tensors="pt").input_values

# Perform inference using the Conformer model
with torch.no_grad():
    logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)

# Decode the predicted_ids to text
predicted_text = processor.decode(predicted_ids[0])

# Print the transcription result
print(f"Predicted text: {predicted_text}")

# Define the ground truth text (replace with actual transcription)
ground_truth_text = "This is an example of transcription"

# Calculate WER (Word Error Rate) for evaluation
error_rate = wer(ground_truth_text.lower(), predicted_text.lower())
print(f"Word Error Rate (WER): {error_rate:.4f}")

# Define function for visualization of pronunciation and performance metrics
def visualize_pronunciation(ground_truth, prediction):
    ground_truth_words = ground_truth.split()
    predicted_words = prediction.split()
    
    fig, ax = plt.subplots(figsize=(10, 4))
    for idx, word in enumerate(ground_truth_words):
        color = 'green' if idx < len(predicted_words) and word.lower() == predicted_words[idx].lower() else 'red'
        ax.text(idx * 0.1, 0.5, word, color=color, fontsize=12, ha='center')
    ax.axis('off')
    plt.title("Pronunciation Comparison")
    plt.show()

visualize_pronunciation(ground_truth_text, predicted_text)

# Function to plot performance metrics (Accuracy, Precision, Recall, F1)
def plot_performance_metrics(accuracy, precision, recall, f1):
    metrics = {"Accuracy": accuracy, "Precision": precision, "Recall": recall, "F1 Score": f1}
    
    plt.figure(figsize=(8, 5))
    sns.barplot(x=list(metrics.keys()), y=list(metrics.values()), palette="Blues_d")
    plt.ylim(0, 1)
    plt.title("Pronunciation Prediction Model Performance Metrics")
    plt.ylabel("Score")
    plt.xlabel("Metric")
    plt.show()

# Evaluation using accuracy, precision, recall, and F1 score
def get_labels(ground_truth, prediction):
    ground_truth_words = ground_truth.split()
    predicted_words = prediction.split()
    labels = [1 if gt.lower() == pd.lower() else 0 for gt, pd in zip(ground_truth_words, predicted_words)]
    return labels

labels = get_labels(ground_truth_text, predicted_text)

accuracy = accuracy_score([1]*len(labels), labels)
precision = precision_score([1]*len(labels), labels)
recall = recall_score([1]*len(labels), labels)
f1 = f1_score([1]*len(labels), labels)

plot_performance_metrics(accuracy, precision, recall, f1)

# Function to visualize audio waveform with text overlay
def visualize_audio_with_text_overlay(waveform, ground_truth, prediction, sample_rate=16000, downsample_factor=10):
    waveform = waveform[0, ::downsample_factor]  # downsample for better visualization

    ground_truth_words = ground_truth.split()
    predicted_words = prediction.split()

    total_time = waveform.size(0) / sample_rate
    word_times = np.linspace(0, total_time, len(ground_truth_words) + 1)

    fig, ax = plt.subplots(figsize=(10, 4))  # size of the plot
    ax.plot(np.linspace(0, total_time, waveform.size(0)), waveform.numpy(), label="Audio Signal")
    
    for idx, word in enumerate(ground_truth_words):
        color = 'green' if idx < len(predicted_words) and word.lower() == predicted_words[idx].lower() else 'red'
        ax.text(word_times[idx], waveform.max(), word, color=color, fontsize=9, ha='center', va='bottom')
    
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude")
    ax.set_xlim(0, total_time)  
    plt.legend()
    plt.title("Audio Waveform with Transcription Overlay")
    plt.show()

# Visualize audio with text overlay
visualize_audio_with_text_overlay(waveform, ground_truth_text, predicted_text)

# Print the evaluation metrics
print(f"Accuracy: {accuracy:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1 Score: {f1:.2f}")
