In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
import re
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from tqdm.auto import tqdm
import torchaudio
from dataclasses import dataclass, field
import evaluate
import os
import soundfile as sf
import numpy as np

device = 'cuda'
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
wer_metric = evaluate.load("wer")

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
MODEL = 'facebook/wav2vec2-base-960h'

In [3]:
origin_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h").to(device)
origin_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")

Some weights of the model checkpoint at facebook/wav2vec2-large-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You s

Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize


In [4]:
ROOT_DIR = '../data'
TEST_INDICES = pd.read_csv(f'{ROOT_DIR}/cv-valid-test.csv')
TEST_DATA_BASE = f"{ROOT_DIR}/cv-valid-test"

In [5]:
def transcribe(file, eval_model, processor):
    # try:
    # Load audio file with soundfile
    audio_array, sample_rate = sf.read(file)
    
    # Convert to mono if stereo
    if len(audio_array.shape) > 1:
        audio_array = audio_array.mean(axis=1)
    
    # Resample to 16kHz if needed
    if sample_rate != 16000:
        # Calculate new length for 16kHz
        new_length = int(len(audio_array) * 16000 / sample_rate)
        audio_array = np.interp(
            np.linspace(0, len(audio_array), new_length),
            np.arange(len(audio_array)),
            audio_array
        )
        sample_rate = 16000
    
    # Get duration
    
    # Normalize audio array
    audio_array = audio_array / np.max(np.abs(audio_array))
    
        # Process audio with Wav2Vec2
    input_values = processor(
        audio_array, 
        return_tensors="pt", 
        padding="longest",
        sampling_rate=sample_rate
    ).input_values.to(device)

    # Get model predictions
    with torch.no_grad():
        logits = eval_model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]


    return transcription


In [6]:
def evaluate_model(eval_model, eval_processor, data_base, data_indices):
    references, predictions = [], []
    for i in tqdm(range(len(data_indices))):
        row = data_indices.iloc[i]
        file_path = f"{data_base}/{row['filename']}"
        
        if not os.path.exists(file_path):
            print(f"File {file_path} does not exist.")
            continue
        
        transcription = transcribe(file_path, eval_model, eval_processor)
        
        # print(transcription)
        # print(row)
        references.append(row['text'].upper())
        predictions.append(transcription)
        # break

    performance = wer_metric.compute(predictions=predictions, references=references)
    print(f"Word Error Rate: {performance:.4f}")    
    return performance


In [7]:
FINTUNED_MODEL = '../models/wav2vec2-large-960h-cv/best_model'
cv_model = Wav2Vec2ForCTC.from_pretrained(FINTUNED_MODEL).to(device)
cv_processor = Wav2Vec2Processor.from_pretrained(FINTUNED_MODEL)

# Test eval

In [8]:
wer_origin = evaluate_model(origin_model, origin_processor, TEST_DATA_BASE, TEST_INDICES)
wer_finetuned = evaluate_model(cv_model, cv_processor, TEST_DATA_BASE, TEST_INDICES)

  audio_array = audio_array / np.max(np.abs(audio_array))
100%|██████████| 3995/3995 [01:38<00:00, 40.49it/s]


Word Error Rate: 0.1041


100%|██████████| 3995/3995 [01:37<00:00, 40.88it/s]


Word Error Rate: 0.0826


In [9]:
print(f"WER for original model: {wer_origin:.4f}")
print(f"WER for finetuned model: {wer_finetuned:.4f}")

WER for original model: 0.1041
WER for finetuned model: 0.0826


# Dev eval

In [10]:
DEV_INDICES = pd.read_csv(f'{ROOT_DIR}/cv-valid-dev.csv')
DEV_DATA_BASE = f"{ROOT_DIR}/cv-valid-dev"

In [11]:
wer_origin = evaluate_model(origin_model, origin_processor, DEV_DATA_BASE, DEV_INDICES)
wer_finetuned = evaluate_model(cv_model, cv_processor, DEV_DATA_BASE, DEV_INDICES)

  audio_array = audio_array / np.max(np.abs(audio_array))
100%|██████████| 4076/4076 [01:46<00:00, 38.15it/s]


Word Error Rate: 0.1100


100%|██████████| 4076/4076 [02:47<00:00, 24.29it/s]


Word Error Rate: 0.0861
