In [1]:
import argparse
from transformers import pipeline
from datasets import load_dataset, Audio
import evaluate
from joblib import Parallel, delayed
from tqdm import tqdm
import json
import librosa
import pandas as pd
from torch.utils.data import Dataset
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory
import pyarrow as pa
import soundfile as sf
import jiwer
import os
import string
import re
import time
import torch

from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
    set_seed,
)

In [2]:
# Lowercase language name -> ISO 639-1 code.
lang_to_code = dict(
    hindi='hi',
    sanskrit='sa',
    bengali='bn',
    tamil='ta',
    telugu='te',
    gujarati='gu',
    kannada='kn',
    malayalam='ml',
    marathi='mr',
    odia='or',
    punjabi='pa',
    urdu='ur',
)

# One normalizer per language code, built lazily on first use and then reused.
_NORMALIZER_CACHE = {}


def normalize_sentence(sentence, lang_code):
    '''
    Normalize a sentence with the Indic NLP library normalizer for `lang_code`.

    The normalizer for each language is constructed once and cached, because this
    function runs for every hypothesis and reference during evaluation and building
    a fresh factory + normalizer per call is wasted work.

    sentence: string to normalize
    lang_code: language code in ISO format, e.g. 'hi'
    returns: the normalized string
    '''
    normalizer = _NORMALIZER_CACHE.get(lang_code)
    if normalizer is None:
        normalizer = IndicNormalizerFactory().get_normalizer(lang_code)
        _NORMALIZER_CACHE[lang_code] = normalizer
    return normalizer.normalize(sentence)

class eval_dataset(Dataset):
    """Map-style torch Dataset of (audio, reference transcript) pairs for the ASR pipeline.

    Each item is a dict: "raw" and "sampling_rate" carry the audio samples, while
    "reference", "path" and "duration" are extra fields that are read back from the
    pipeline's output when scoring.
    """

    def __init__(self):
        # Parallel lists: audios[k] is a dict with 'array', 'sampling_rate', 'path'
        # and 'duration'; sents[k] is the matching reference text.
        self.audios, self.sents = [], []

    def __len__(self):
        return len(self.audios)

    def __getitem__(self, i):
        audio = self.audios[i]
        return dict(
            raw=audio['array'],
            sampling_rate=audio['sampling_rate'],
            reference=self.sents[i],
            path=audio['path'],
            duration=audio['duration'],
        )

    def fill_data(self, aud, sent):
        """Append one utterance: its audio dict `aud` and reference text `sent`."""
        self.audios.append(aud)
        self.sents.append(sent)

# Where the manifest's audio paths were written vs. where the files live locally.
MANIFEST_AUDIO_PREFIX = '/nlsasfs/home/ai4bharat/ai4bharat-pr/speechteam/asr_datasets'
LOCAL_AUDIO_PREFIX = '/workspace/ai4bharat-pr/speechteam/ai4bp_upload/vistaar'


def get_data(split, src_prefix=MANIFEST_AUDIO_PREFIX, dst_prefix=LOCAL_AUDIO_PREFIX):
    '''
    Parse one manifest line and load the audio file it points to.

    split: one JSON line with 'audio_filepath', 'duration' and 'text' keys
    src_prefix, dst_prefix: every occurrence of src_prefix in the audio path is
        replaced with dst_prefix, to remap manifest paths onto local storage.
        The defaults reproduce the original hard-coded remapping; an empty
        src_prefix disables remapping.
    returns: (aud, text) where aud is a dict with 'path', 'duration', 'array'
        (samples as returned by soundfile.read) and 'sampling_rate'
    '''
    js_data = json.loads(split)
    path = js_data['audio_filepath']
    # str.replace('', x) would insert x between every character, so skip it.
    if src_prefix:
        path = path.replace(src_prefix, dst_prefix)
    samples, sr = sf.read(path)
    aud = {
        'path': path,
        'duration': js_data['duration'],
        'array': samples,
        'sampling_rate': sr,
    }
    return aud, js_data['text']

In [None]:
# --- Paths: fill these in before running ---
# Folder (e.g. on Google Drive) that receives predictions/ and results.csv.
dir_path = ""
# Model checkpoint (local path or Hub id) passed to the ASR pipeline.
model_path = "YOUR_MODEL_DRIVE_PATH"
# JSON-lines manifest from the kathbath/hindi/test/nemo_manifest folder.
manifest_path = ""

# --- Run settings ---
# ISO language code of the evaluation set ('hi' = Hindi).
lang_code = 'hi'
# Inference batch size; 16 is probably safe, lower it if the accelerator runs out of memory.
batch_size = 16

In [None]:
# Read the manifest: one JSON object per line. Blank lines (including a trailing
# newline or gaps in the middle) are dropped so json.loads never sees ''.
with open(manifest_path, 'r', encoding='utf-8') as f:
    splits = [line for line in f.read().splitlines() if line.strip()]

# 'tpu' is not a valid torch device string, so pipeline(device='tpu') raises.
# Use the GPU when one is visible, otherwise the CPU. (A Colab TPU would need
# torch_xla and an XLA device, which this notebook does not set up.)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
whisper_asr = pipeline(
    "automatic-speech-recognition", model=model_path, device=device,
)



In [None]:
# Parse every manifest line and read its audio in parallel workers. The worker
# count is capped at the CPU count: a fixed 128 processes oversubscribes small
# machines such as a Colab VM.
n_jobs = min(128, os.cpu_count() or 1)
loaded_items = Parallel(n_jobs=n_jobs)(
    delayed(get_data)(line) for line in tqdm(splits, desc="loading audio")
)

dataset = eval_dataset()
for aud, sent in loaded_items:
    dataset.fill_data(aud, sent)

assert len(dataset) == len(splits), "some manifest lines were not loaded"

In [None]:
# Pin decoding to transcription in the evaluation language via the model's
# generation config, using Whisper's "<|xx|>" language-token form.
gen_config = whisper_asr.model.generation_config
gen_config.task = 'transcribe'
gen_config.language = f"<|{lang_code}|>"

In [None]:

# Run the ASR pipeline over the dataset, clean hypothesis and reference identically,
# and write one JSON line per utterance to <dir_path>/predictions/predictions.json.

# Characters removed before scoring: ASCII punctuation plus danda (।), double
# danda (॥) and the Arabic-script full stop (۔). Built once, not per utterance.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation + "।۔'-॥")


def _clean_transcript(text, lang_code):
    '''
    Strip punctuation, apply Indic normalization (skipped for Urdu) and tidy whitespace.

    Runs of whitespace collapse to one space and the ends are stripped, so a
    whitespace-only string becomes '' and is caught by the empty-reference check.
    '''
    text = text.translate(_PUNCT_TABLE)
    if lang_code[:2] != 'ur':
        text = normalize_sentence(text, lang_code[:2])
    return re.sub(r'\s+', ' ', text).strip()


hypothesis = []
ground_truth = []

# os.path.join rather than dir_path + '/' + ...: with dir_path == "" the
# concatenation pointed at the filesystem root ('/predictions').
pred_dir = os.path.join(dir_path, 'predictions')
os.makedirs(pred_dir, exist_ok=True)

out_name = 'predictions.json'
pred_path = os.path.join(pred_dir, out_name)

st = time.time()

# Open the output once ('w' truncates any previous run) instead of reopening per utterance.
with open(pred_path, 'w', encoding='utf-8') as pred_fp:
    for out in tqdm(whisper_asr(dataset, batch_size=batch_size), total=len(dataset)):
        hyp = _clean_transcript(out['text'], lang_code)
        # Extra dataset fields come back wrapped in a list, hence the [0].
        ref = _clean_transcript(out['reference'][0], lang_code)
        if ref == '':
            # jiwer raises on empty reference strings, so use a placeholder token.
            ref = '<empty>'
        hypothesis.append(hyp)
        ground_truth.append(ref)
        res = {
            "audio_filepath": out['path'][0],
            "duration": out['duration'][0],
            "text": ref,
            "pred_text": hyp,
        }
        # ensure_ascii=False keeps Indic script readable in the saved file.
        json.dump(res, pred_fp, ensure_ascii=False)
        pred_fp.write('\n')

et = time.time()
    
# Corpus-level scores for this run. The dict's insertion order is the column
# order appended to results.csv, so keep it stable.
code_to_lang = {code: name for name, code in lang_to_code.items()}

summary = {
    'model': model_path,
    # Derived from lang_code instead of hard-coding "kathbath_hindi".
    'dataset': f"kathbath_{code_to_lang.get(lang_code, lang_code)}",
    'language': lang_code,
    'cer': jiwer.cer(ground_truth, hypothesis),
    'time': (et - st) / 60,  # minutes
    'batch_size': batch_size,
    # jiwer.wer gives the same value as compute_measures(...)['wer']; compute_measures
    # is deprecated in jiwer 3.x and dropped in newer releases.
    'wer': jiwer.wer(ground_truth, hypothesis),
}

print(summary)

# os.path.join so an empty dir_path means the working directory, not '/'.
with open(os.path.join(dir_path, 'results.csv'), 'a') as results_fp:
    print(','.join(str(v) for v in summary.values()), file=results_fp)