In [1]:
from pathlib import Path

import pandas as pd
import torch
import torchaudio
from jiwer import wer, cer  # Importing WER and CER libraries
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer

# Pick the compute device once; later cells can reuse `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Number of GPU: ", torch.cuda.device_count())
# get_device_name() raises when CUDA is unavailable, so only query it when a GPU exists.
if device.type == 'cuda':
    print("GPU Name: ", torch.cuda.get_device_name())
print('Using device:', device)

Number of GPU:  1
GPU Name:  NVIDIA L4
Using device: cuda


In [2]:
# Load the fine-tuned checkpoint onto the device chosen in the first cell rather than
# hard-coding "cuda", so the notebook can also run on a CPU-only machine.
model = Wav2Vec2ForCTC.from_pretrained("./trained_models/wav2vec0.8dropout").to(device)
# from_pretrained already returns the model in eval mode (dropout disabled); say so explicitly
# because every cell below runs inference only.
model.eval()
# NOTE(review): `processor` is rebuilt from vocab.json in the next cell, which replaces this
# one. Keep only one definition once it is confirmed which vocabulary the model was trained with.
processor = Wav2Vec2Processor.from_pretrained("./processor/wav2vec0.8dropout")

In [3]:
# Assemble a processor from its two parts: a raw-waveform feature extractor
# (1 channel, 16 kHz, 0.0 padding, normalised, with attention masks) and a CTC
# character tokenizer read from the project vocabulary file.
# NOTE(review): this replaces the `processor` loaded from ./processor/wav2vec0.8dropout in the
# previous cell. Decoding is only correct if vocab.json matches the vocabulary the model was
# fine-tuned with. Confirm this, or drop one of the two definitions.
vocab_path = './input/cleaned-asr-data/data/vocabulary/vocab.json'

feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# "|" marks word boundaries in the vocabulary; [UNK]/[PAD] are the special tokens.
tokenizer = Wav2Vec2CTCTokenizer(
    vocab_path,
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [4]:
# Define function to handle large audio files (segmentation)
def segmentLargeArray(inputTensor, chunksize=200000):
    """Split a 2-D tensor into consecutive chunks along its second (sample) axis.

    Args:
        inputTensor: tensor indexed as ``[:, samples]``; in this notebook a
            (1, samples) waveform.
        chunksize: maximum number of samples per chunk; must be positive.

    Returns:
        A list of slices of ``inputTensor`` that together cover every sample in
        order. Every chunk has ``chunksize`` samples except possibly the last,
        which is shorter but never empty. A zero-length input yields ``[]``.

    Raises:
        ValueError: if ``chunksize`` is not positive.
    """
    if chunksize <= 0:
        raise ValueError(f"chunksize must be positive, got {chunksize}")
    tensor_length = inputTensor.shape[1]
    # Stop before tensor_length so that a length that is an exact multiple of
    # chunksize does not produce a trailing empty chunk, which the model cannot process.
    return [inputTensor[:, start:start + chunksize]
            for start in range(0, tensor_length, chunksize)]
    


In [5]:
# Function to predict speech-to-text from audio
def predict_from_speech(file):
    """Transcribe one audio file with the global Wav2Vec2 `model` and `processor`.

    The audio is downmixed to mono and resampled to 16 kHz. Clips of 200000
    samples or more are cut with `segmentLargeArray`. Each segment is decoded
    on its own, and the pieces are concatenated in order.

    NOTE(review): a later cell redefines `predict_from_speech(file, actual_text)`,
    which replaces this function for the rest of the notebook.

    Args:
        file: path to the audio file (anything `torchaudio.load` accepts).

    Returns:
        The decoded transcript. For segmented audio this is the concatenation of
        every segment's transcript, not just the last one.
    """
    speech_array, sampling_rate = torchaudio.load(file)

    # Average the channels so stereo audio is transcribed as one (1, samples) signal
    # instead of being fed as a batch of per-channel rows.
    if speech_array.shape[0] > 1:
        speech_array = speech_array.mean(dim=0, keepdim=True)

    resampler = torchaudio.transforms.Resample(sampling_rate, 16000)
    resampled_array = resampler(speech_array)

    if resampled_array.shape[1] >= 200000:
        segments = segmentLargeArray(resampled_array)
    else:
        segments = [resampled_array]

    output = ''
    # Inference only: skip autograd bookkeeping to save memory on long clips.
    with torch.no_grad():
        for segment in segments:
            # Zero-length chunks cannot go through the model; skip them.
            if segment.shape[1] == 0:
                continue
            logits = model(segment.to(model.device)).logits
            pred_ids = torch.argmax(logits, dim=-1)[0]
            # NOTE(review): segments are cut at fixed sample offsets, so a word can be
            # split across two segments. Transcripts are joined with no separator.
            output += processor.decode(pred_ids)
    return output

In [6]:
# Load test data CSV
csv_file_path = "./csv_files/test.csv"
test_data = pd.read_csv(csv_file_path)
# The evaluation loop reads `path` and `labels` from every row. Fail here with a clear
# message rather than with a KeyError partway through inference.
assert {"path", "labels"}.issubset(test_data.columns), (
    f"{csv_file_path} must contain 'path' and 'labels' columns, got {list(test_data.columns)}"
)
test_data

Unnamed: 0,path,labels
0,./training_data/audio/ca2305bc0b.wav,सिद्ध भए छन्।
1,./training_data/audio/fbfe903fb2.wav,आर्टस् एन्ड साइन्सेसमा
2,./training_data/audio/2a1efeef3e.wav,अल्फा एमानाइटिन जस्ता च्याउमा पाइने विषहरूले त...
3,./training_data/audio/9a3fb596b5.wav,अत यो मुख्य
4,./training_data/audio/be310dff43.wav,सूक्ष्मजीव थिए जसको
...,...,...
495,./training_data/audio/3306588e61.wav,सम्बन्ध ४ अर्ब
496,./training_data/audio/dd7e49af93.wav,२९ जुन २००९
497,./training_data/audio/2aa70648ea.wav,सेवा सञ्चालन गर्दा
498,./training_data/audio/e3a96720cc.wav,पल्टिँदा निन्द्रा आउने


In [7]:
# Accumulator for per-sample evaluation rows: [actual text, predicted text, CER, WER]
results = list()

In [8]:
def predict_from_speech(file, actual_text):
    """Transcribe one test clip, score it against its reference, and record the result.

    Loads `file` with torchaudio, downmixes to mono, and resamples to 16 kHz. It then
    runs the global Wav2Vec2 CTC `model` with greedy (argmax) decoding via `processor`,
    strips "[UNK]" tokens, and computes CER/WER with jiwer.

    NOTE(review): this redefinition replaces the one-argument `predict_from_speech(file)`
    from an earlier cell (different signature, no scoring).

    Args:
        file: path to the audio file (anything `torchaudio.load` accepts).
        actual_text: reference transcript for the clip.

    Returns:
        The row `[actual_text, predicted_text, cer, wer]` that was appended to the
        global `results` list, or None when the clip was skipped. A clip is skipped
        when the reference is empty or missing, the audio is unreadable, or the audio
        has fewer than 10 samples.
    """
    # jiwer raises on an empty reference, and pandas reads a blank CSV cell as NaN (a float).
    if not isinstance(actual_text, str) or not actual_text.strip():
        print(f"Skipping file with empty reference text: {file}")
        return None

    try:
        waveform, sample_rate = torchaudio.load(file)
    except Exception as e:
        print(f"Error loading audio file {file}: {e}")
        return None

    if waveform.shape[1] < 10:  # Skip very short audio
        print(f"Skipping short audio file: {file}")
        return None

    # Convert to mono (Wav2Vec2 expects a single channel)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample if necessary (Wav2Vec2 expects 16kHz)
    if sample_rate != 16000:
        transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = transform(waveform)

    # Shape (1, samples): the single channel doubles as a batch of one.
    # Follow the model's own device rather than hard-coding "cuda".
    input_tensor = waveform.to(model.device)

    with torch.no_grad():
        logits = model(input_tensor).logits
        pred_ids = torch.argmax(logits, dim=-1)[0]
        predicted_text = processor.decode(pred_ids)

    predicted_text = predicted_text.replace("[UNK]", "")

    # Compute CER & WER
    cer_value = cer(actual_text, predicted_text)
    wer_value = wer(actual_text, predicted_text)

    result_row = [actual_text, predicted_text, cer_value, wer_value]
    results.append(result_row)
    return result_row


# Process all test samples. Empty `results` first so re-running this cell does not
# append a second copy of every row.
results.clear()
for _, row in test_data.iterrows():
    predict_from_speech(row["path"], row["labels"])

In [9]:
# Tabulate the collected rows: one DataFrame row per scored sample, one column per field
results_df = pd.DataFrame(
    data=results,
    columns=["Actual Text", "Predicted Text", "CER", "WER"],
)

In [10]:
# Beautify the DataFrame
# 1. Round CER and WER to 2 decimal places. This rewrites the columns of results_df
#    itself, so the CSV saved afterwards also holds the rounded values.
results_df['CER'] = results_df['CER'].apply(lambda x: round(x, 2))
results_df['WER'] = results_df['WER'].apply(lambda x: round(x, 2))

# 2. Build a styled view for display: header colour, cell alignment, column widths, and
#    metrics rendered with exactly two decimals. Without the format call, a Styler uses the
#    pandas default precision, e.g. 0.310000.
styled_df = results_df.style.set_table_styles(
    [{'selector': 'th', 'props': [('background-color', 'lightblue'), ('text-align', 'center')]},  # Header styling
     {'selector': 'td', 'props': [('text-align', 'left')]},  # Data cell alignment
     {'selector': '.col0', 'props': [('width', '300px')]},  # Column width adjustment for Actual Text
     {'selector': '.col1', 'props': [('width', '300px')]},  # Column width adjustment for Predicted Text
     {'selector': '.col2', 'props': [('width', '80px')]},   # Column width adjustment for CER
     {'selector': '.col3', 'props': [('width', '80px')]},   # Column width adjustment for WER
    ]).format('{:.2f}', subset=['CER', 'WER'])

# NOTE(review): in the recorded output, every WER shown is 1.0 and no predicted text contains
# a space, so word boundaries are never being decoded. Check that the "|" word delimiter in
# the tokenizer vocabulary matches the one used in training before trusting WER.

# End with the Styler, not results_df, so the styling above is actually rendered.
styled_df

Unnamed: 0,Actual Text,Predicted Text,CER,WER
0,सिद्ध भए छन्।,सिन्धभएछन्,0.31,1.0
1,आर्टस् एन्ड साइन्सेसमा,आर्टएन्टसएन्सेजमा,0.36,1.0
2,अल्फा एमानाइटिन जस्ता च्याउमा पाइने विषहरूले त...,अल्फाएमानाइटिजस्ताच्याउमापाइनेविसहरूलेताप्टनसँ...,0.27,1.0
3,अत यो मुख्य,बटयोमख्य,0.45,1.0
4,सूक्ष्मजीव थिए जसको,सूचमजीवथिएजसको,0.32,1.0
...,...,...,...,...
495,सम्बन्ध ४ अर्ब,सम्बन्धचारअर्प,0.29,1.0
496,२९ जुन २००९,३९जुन२०९,0.36,1.0
497,सेवा सञ्चालन गर्दा,सेवासञ्चालनगर्दा,0.11,1.0
498,पल्टिँदा निन्द्रा आउने,पल्टितानितआउने,0.45,1.0


In [11]:
# Save the per-sample results to CSV. This writes the plain results_df (with CER/WER as
# stored in the frame), not the Styler used for display.
output_csv_path = Path('./test_result_csv/test_results(s14).csv')
# Create the output folder if needed, so a fresh checkout does not fail on the last cell.
output_csv_path.parent.mkdir(parents=True, exist_ok=True)
results_df.to_csv(output_csv_path, index=False)