In [None]:
import nemo.collections.asr as nemo_asr
import torch
import torchaudio
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from jiwer import wer, mer, wil
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Load the pretrained QuartzNet 15x5 English CTC model from NeMo.
quartznet_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name="stt_en_quartznet15x5")

# Sample rate the waveform is converted to before transcription.
TARGET_SAMPLE_RATE = 16000

# Load one TIMIT utterance (speaker FCJF0, sentence SA1).
# NOTE(review): LDC TIMIT .WAV files carry NIST SPHERE headers; confirm the
# installed torchaudio backend can decode them (or point this at a RIFF copy).
audio_file = os.path.join("timit", "data", "TRAIN", "DR1", "FCJF0", "SA1.WAV")
waveform, sample_rate = torchaudio.load(audio_file)

# torchaudio.load returns a (channels, frames) tensor. Averaging over the
# channel axis gives a 1-D mono signal for any channel count, whereas
# squeeze() only flattened single-channel audio and left stereo input 2-D.
waveform = waveform.mean(dim=0)

# Resample if necessary, keeping `sample_rate` in step with the waveform it
# describes (previously it kept the file's original rate after resampling).
if sample_rate != TARGET_SAMPLE_RATE:
    waveform = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLE_RATE)(waveform)
    sample_rate = TARGET_SAMPLE_RATE

# Ground truth text for comparison (the TIMIT SA1 prompt).
ground_truth_text = "She had your dark suit in greasy wash water all year"

# Function to process audio and get predictions
def process_and_evaluate(model, waveform, ground_truth_text, sample_rate=16000):
    """Transcribe ``waveform`` with ``model``, then print and return error rates.

    Args:
        model: ASR model exposing ``transcribe``; it is called with a
            one-element list holding the waveform as a NumPy array, and the
            first element of the result is taken as the transcript.
        waveform: Audio samples as a torch tensor or NumPy array. No
            resampling happens here; presumably the caller has already
            converted to the rate the model expects (TODO confirm).
        ground_truth_text: Reference transcript.
        sample_rate: Kept for interface compatibility; not used.

    Returns:
        Tuple ``(predicted_text, wer, mer, wil)``. The three rates are jiwer's
        word error rate, match error rate and word information lost, computed
        on the lowercased reference and hypothesis.
    """
    # Import jiwer here instead of relying on the module-level `wer`/`mer`/
    # `wil` names. The script rebinds those names to floats when it unpacks
    # this function's return value, so any later call would fail with
    # "'float' object is not callable".
    import jiwer

    # detach()/cpu() let the conversion work for tensors that track gradients
    # or live on an accelerator; NumPy input is passed through unchanged.
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.detach().cpu().numpy()
    inputs = [waveform]

    # Get prediction. NOTE(review): some NeMo releases return Hypothesis
    # objects rather than plain strings; use their `.text` when present and
    # fall back to the raw value otherwise. Confirm for the installed version.
    transcription = model.transcribe(inputs)[0]
    predicted_text = getattr(transcription, "text", transcription)

    # Normalize case once and reuse it for all three metrics.
    reference = ground_truth_text.lower()
    hypothesis = predicted_text.lower()
    error_rate = jiwer.wer(reference, hypothesis)
    match_error_rate = jiwer.mer(reference, hypothesis)
    wil_rate = jiwer.wil(reference, hypothesis)

    print(f"Predicted Text: {predicted_text}")
    print("Word Error Rate:", error_rate)
    print("Match Error Rate:", match_error_rate)
    print("Word Information Lost Rate:", wil_rate)

    return predicted_text, error_rate, match_error_rate, wil_rate

# Run QuartzNet on the loaded utterance and collect its transcript and scores.
print("Evaluating QuartzNet...")
# NOTE: this unpacking rebinds the module-level names `wer`, `mer` and `wil`
# (imported from jiwer at the top of the file) to float scores. The later
# plot_performance_metrics(wer, mer, wil) call relies on the float values.
(
    predicted_text,
    wer,
    mer,
    wil,
) = process_and_evaluate(quartznet_model, waveform, ground_truth_text)

# Visualization of Pronunciation
def visualize_pronunciation(ground_truth, prediction):
    """Draw the reference words in one row: green if recognized, red if not.

    The two word sequences are aligned case-insensitively with
    ``difflib.SequenceMatcher`` before comparing. This way, one inserted or
    dropped word in ``prediction`` does not turn every later reference word
    red, as the previous index-by-index comparison did.

    Args:
        ground_truth: Reference transcript (whitespace-separated words).
        prediction: Predicted transcript (whitespace-separated words).
    """
    import difflib

    ground_truth_words = ground_truth.split()
    predicted_words = prediction.split()

    matcher = difflib.SequenceMatcher(
        a=[word.lower() for word in ground_truth_words],
        b=[word.lower() for word in predicted_words],
        autojunk=False,  # keep frequent words eligible for matching on long inputs
    )
    recognized = [False] * len(ground_truth_words)
    for start, _, size in matcher.get_matching_blocks():
        for offset in range(size):
            recognized[start + offset] = True

    _, ax = plt.subplots()
    # Spread the words evenly over the axes' default 0-1 x-range. The previous
    # fixed 0.1 spacing pushed words past x=1 beyond ten words.
    slot_width = 1.0 / max(len(ground_truth_words), 1)
    for idx, (word, ok) in enumerate(zip(ground_truth_words, recognized)):
        color = 'green' if ok else 'red'
        ax.text((idx + 0.5) * slot_width, 0.5, word, color=color, fontsize=12, ha='center')
    ax.axis('off')
    plt.show()

# Plot performance metrics for QuartzNet
def plot_performance_metrics(wer, mer, wil):
    """Show a bar chart of QuartzNet's WER, MER and WIL scores.

    Args:
        wer: Word error rate.
        mer: Match error rate.
        wil: Word information lost.
    """
    metrics = {"WER": wer, "MER": mer, "WIL": wil}
    labels = list(metrics.keys())
    scores = list(metrics.values())

    plt.figure(figsize=(8, 5))
    sns.barplot(x=labels, y=scores, palette="Blues_d")
    # WER = (S + D + I) / N is not capped at 1: enough insertions push it
    # higher. Extend the axis in that case instead of clipping the bar.
    plt.ylim(0, max(1.0, max(scores) * 1.05))
    plt.title("QuartzNet Model Performance Metrics (WER, MER, WIL)")
    plt.ylabel("Score")
    plt.xlabel("Metric")
    plt.show()

# Colour the reference words by whether the QuartzNet transcript matched them.
visualize_pronunciation(ground_truth=ground_truth_text, prediction=predicted_text)

# Bar chart of the three error rates computed for QuartzNet.
plot_performance_metrics(wer=wer, mer=mer, wil=wil)

# Detailed evaluation (accuracy, precision, recall, F1) for the QuartzNet model
def evaluate_word_level_metrics(ground_truth, prediction):
    """Score a transcript word by word against its reference.

    Words are lowercased, then the two word sequences are aligned with
    ``difflib.SequenceMatcher``. A reference word is a *hit* when the
    alignment pairs it with an identical predicted word. Aligning, rather than
    zipping by position, means that missing or extra words count as errors
    instead of being silently dropped, and one inserted word does not misalign
    every later comparison. Note that difflib's alignment is a
    longest-matching-block heuristic, not a minimum edit-distance alignment.

    Args:
        ground_truth: Reference transcript (whitespace-separated words).
        prediction: Predicted transcript (whitespace-separated words).

    Returns:
        Tuple ``(accuracy, precision, recall, f1)`` of floats in [0, 1]:

        - accuracy: hits / aligned positions (hits + substitutions +
          deletions + insertions);
        - precision: hits / number of predicted words;
        - recall: hits / number of reference words;
        - f1: harmonic mean of precision and recall.

        Any ratio whose denominator is zero (e.g. an empty prediction) is 0.0.
    """
    import difflib

    reference_words = ground_truth.lower().split()
    predicted_words = prediction.lower().split()

    matcher = difflib.SequenceMatcher(a=reference_words, b=predicted_words, autojunk=False)
    hits = 0
    aligned_positions = 0
    for tag, ref_start, ref_end, hyp_start, hyp_end in matcher.get_opcodes():
        ref_span = ref_end - ref_start
        hyp_span = hyp_end - hyp_start
        if tag == "equal":
            hits += ref_span
        # A 'replace' of unequal spans is min(spans) substitutions plus the
        # remainder as deletions or insertions, i.e. max(spans) positions.
        aligned_positions += max(ref_span, hyp_span)

    def _safe_ratio(numerator, denominator):
        # Treat an empty denominator as a zero score rather than raising.
        return numerator / denominator if denominator else 0.0

    accuracy = _safe_ratio(hits, aligned_positions)
    precision = _safe_ratio(hits, len(predicted_words))
    recall = _safe_ratio(hits, len(reference_words))
    f1 = _safe_ratio(2 * precision * recall, precision + recall)

    return accuracy, precision, recall, f1

# Word-level scores of the QuartzNet transcript against the reference.
accuracy, precision, recall, f1 = evaluate_word_level_metrics(
    ground_truth_text,
    predicted_text,
)

# Report each score on its own line, rounded to two decimals.
report_lines = (
    f"Accuracy: {accuracy:.2f}",
    f"Precision: {precision:.2f}",
    f"Recall: {recall:.2f}",
    f"F1 Score: {f1:.2f}",
)
for line in report_lines:
    print(line)

# Plotting the detailed metrics
def plot_detailed_metrics(accuracy, precision, recall, f1):
    """Show a bar chart of word-level accuracy, precision, recall and F1.

    Args:
        accuracy: Word-level accuracy.
        precision: Word-level precision.
        recall: Word-level recall.
        f1: Word-level F1 score.
    """
    metrics = {"Accuracy": accuracy, "Precision": precision, "Recall": recall, "F1 Score": f1}
    labels = list(metrics.keys())
    scores = list(metrics.values())

    plt.figure(figsize=(8, 5))
    sns.barplot(x=labels, y=scores, palette="Blues_d")
    # All four scores are fractions, so a fixed 0-1 axis fits.
    plt.ylim(0, 1)
    # Plain string literal: the title has no placeholders, so no f-prefix.
    plt.title("QuartzNet Model Performance Metrics")
    plt.ylabel("Score")
    plt.xlabel("Metric")
    plt.show()

# Bar chart of the word-level scores computed above.
plot_detailed_metrics(accuracy=accuracy, precision=precision, recall=recall, f1=f1)


ImportError: cannot import name 'WavLMProcessor' from 'transformers' (d:\Anaconda\Lib\site-packages\transformers\__init__.py)