In [14]:
import os
import uuid
import torch
import shutil
import librosa
import difflib
import warnings
import requests
import numpy as np 
from pydub import AudioSegment
import speech_recognition as sr
from datasets import load_dataset, Dataset, Audio
from speechbrain.pretrained import SepformerSeparation as separator
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, pipeline, \
                         WhisperProcessor, WhisperForConditionalGeneration
warnings.filterwarnings("ignore")

In [15]:
# --- Model setup -------------------------------------------------------------
# Each checkpoint id is written once and shared by the processor/model pair
# that must stay in sync.

# SpeechBrain SepFormer enhancement model; used by `audio_denoising`.
denoiser = separator.from_hparams(
    source="speechbrain/sepformer-wham-enhancement",
    savedir='pretrained_models/sepformer-wham-enhancement',
)

# English wav2vec2 CTC model. In the visible part of this notebook it is only
# referenced by the commented-out legacy `speech2number`.
_EN_S2T_CHECKPOINT = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
s2t_processor = Wav2Vec2Processor.from_pretrained(_EN_S2T_CHECKPOINT)
s2t_model = Wav2Vec2ForCTC.from_pretrained(_EN_S2T_CHECKPOINT)

# Pipeline built from the BART-MNLI checkpoint; likewise only referenced by
# the commented-out legacy code in the visible part of this notebook.
zsc_pipeline = pipeline(model="facebook/bart-large-mnli")

# Sinhala Whisper model used by `speech2text`; the decoder prompt ids are
# requested for language="sinhala" and task="transcribe".
_SI_S2T_CHECKPOINT = "Subhaka/whisper-small-Sinhala-Fine_Tune"
sin_s2t_processor = WhisperProcessor.from_pretrained(_SI_S2T_CHECKPOINT)
sin_s2t_model = WhisperForConditionalGeneration.from_pretrained(_SI_S2T_CHECKPOINT)
sin_s2t_forced_decoder_ids = sin_s2t_processor.get_decoder_prompt_ids(
    language="sinhala",
    task="transcribe",
)

# Hosted inference endpoint for the SepFormer separation model.
API_URL_DENOISER = "https://api-inference.huggingface.co/models/speechbrain/sepformer-wsj03mix"

# SECURITY: the bearer token used to be hardcoded here. Credentials must never
# live in a notebook, so the token is now read from the HF_TOKEN environment
# variable (export it before starting the kernel). The previously committed
# token should be treated as leaked and revoked.
headers_DENOISER = {"Authorization": "Bearer " + os.environ.get("HF_TOKEN", "")}

In [16]:
def mp3toWav(audioFile):
    """Convert an audio file to WAV and return the path of the WAV file.

    The output path is derived from ``audioFile`` by swapping the extension
    for ``.wav`` and renaming any directory component that is exactly ``mp3``
    or ``ogg`` to ``wav`` (``data/mp3/7.ogg`` -> ``data/wav/7.wav``). Only the
    extension and whole directory names are rewritten, so a file called
    ``doggy.ogg`` becomes ``doggy.wav`` rather than ``dwavy.wav``.

    Parameters
    ----------
    audioFile : str
        Path to the source audio. ``.mp3`` files are read with
        ``AudioSegment.from_mp3``; everything else with
        ``AudioSegment.from_file``.

    Returns
    -------
    str
        Path of the exported WAV file, or ``audioFile`` itself when the
        derived output path equals the input (the input is left untouched).
    """
    head, sep, base = audioFile.rpartition('/')
    stem = os.path.splitext(base)[0]
    wav_head = '/'.join(
        'wav' if part in ('mp3', 'ogg') else part
        for part in head.split('/')
    )
    audioFileNew = wav_head + sep + stem + '.wav'

    # Input and output are the same file: converting would first delete the
    # input (see the stale-output cleanup below), so there is nothing to do.
    if audioFileNew == audioFile:
        return audioFile

    # The sibling wav/ directory may not exist yet on a fresh checkout.
    if wav_head:
        os.makedirs(wav_head, exist_ok=True)

    # Drop a stale output from a previous run before exporting.
    if os.path.exists(audioFileNew):
        os.remove(audioFileNew)

    if audioFile.endswith('.mp3'):
        sound = AudioSegment.from_mp3(audioFile)
    else:
        sound = AudioSegment.from_file(audioFile)
    sound.export(audioFileNew, format="wav")
    return audioFileNew

def audio_denoising(audioFileNew):
    """Run the SepFormer denoiser on a WAV file (best effort).

    ``denoiser.separate_file`` is called on ``audioFileNew``; afterwards the
    file with the same base name in the current working directory is moved to
    the parallel ``/denoised_wav/`` location (``.../wav/x.wav`` ->
    ``.../denoised_wav/x.wav``).

    NOTE(review): the return value of ``separate_file`` is discarded here.
    Confirm that the file it leaves in the working directory really is the
    enhanced signal and not just a fetched copy of the input; in SpeechBrain
    the separated sources are normally the *return value* of this call.

    Parameters
    ----------
    audioFileNew : str
        Path to the WAV file to denoise. It must contain a ``/wav/``
        component for a separate output location to exist.

    Returns
    -------
    str
        Path of the file in ``/denoised_wav/`` on success. ``audioFileNew``
        unchanged when denoising fails or when no distinct output path can be
        derived.
    """
    local_copy = os.path.split(audioFileNew)[-1]
    enhancedAudioFile = audioFileNew.replace('/wav/', '/denoised_wav/')

    # Without a '/wav/' component the "enhanced" path equals the input, and if
    # the input already sits in the working directory the "local copy" IS the
    # input. In both cases the cleanup/move below would destroy the input, so
    # skip denoising and hand the file back untouched.
    input_is_local_copy = os.path.abspath(local_copy) == os.path.abspath(audioFileNew)
    if enhancedAudioFile == audioFileNew or input_is_local_copy:
        print("Skipping denoising, no separate output path for: {}".format(audioFileNew))
        return audioFileNew

    try:
        denoiser.separate_file(path=audioFileNew)

        target_dir = os.path.dirname(enhancedAudioFile)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        if os.path.exists(enhancedAudioFile):
            os.remove(enhancedAudioFile)

        shutil.move(local_copy, enhancedAudioFile)
        return enhancedAudioFile

    except Exception as err:
        # Best effort: fall back to the un-denoised file, but say why, and
        # only clean up the working-directory file if it was actually created
        # (lexists also catches a dangling symlink).
        print("Denoising failed for {}: {}".format(audioFileNew, err))
        if os.path.lexists(local_copy):
            os.remove(local_copy)
        return audioFileNew

# def remove_punc(predicted_number):
#     predicted_number = predicted_number.replace('.', ' ')
#     predicted_number = predicted_number.replace(',', ' ')
#     predicted_number = predicted_number.replace('?', ' ')
#     predicted_number = predicted_number.replace('!', ' ')
#     predicted_number = predicted_number.replace('-', ' ')
#     predicted_number = predicted_number.replace('_', ' ')
#     predicted_number = predicted_number.replace(';', ' ')
#     predicted_number = predicted_number.replace(':', ' ')
#     predicted_number = predicted_number.replace('(', ' ')
#     predicted_number = predicted_number.replace(')', ' ')
#     predicted_number = predicted_number.replace('[', ' ')
#     predicted_number = predicted_number.replace(']', ' ')
#     predicted_number = predicted_number.replace('{', ' ')
#     predicted_number = predicted_number.replace('}', ' ')
#     predicted_number = predicted_number.replace('/', ' ')
#     predicted_number = predicted_number.replace('\\', ' ')
#     predicted_number = predicted_number.replace('|', ' ')
#     predicted_number = predicted_number.replace('\'', ' ')
#     predicted_number = predicted_number.replace('\"', ' ')
#     predicted_number = predicted_number.replace('~', ' ')
#     return predicted_number

# def speech2text(audioFile):
#     r = sr.Recognizer()
#     with sr.AudioFile(audioFile) as source:
#         audio = r.record(source)
#     text = r.recognize_google(audio)
#     return text

# def speech2number(
#                 audioFile,
#                 use_hf = True
#                 ):
    
#     word2num = {
#                 'one': 1,
#                 'two or to': 2,
#                 'three': 3,
#                 'four': 4,
#                 'five': 5,
#                 'six': 6,
#                 'seven': 7,
#                 'eight': 8,
#                 'nine': 9
#                 }

#     if not use_hf:
#         r = sr.Recognizer()
#         with sr.AudioFile(audioFile) as source:
#             audio = r.record(source)
#         text = r.recognize_google(audio)
#         return text
    
#     else:
#         speech_array, _ = librosa.load(audioFile, sr=16_000)
#         inputs = s2t_processor(
#                                 speech_array, 
#                                 sampling_rate=16_000, 
#                                 return_tensors="pt", 
#                                 padding=True
#                                 )
#         with torch.no_grad():
#             logits = s2t_model(inputs.input_values, attention_mask=inputs.attention_mask).logits
#             predicted_ids = torch.argmax(logits, dim=-1)
#             predicted_number = s2t_processor.batch_decode(predicted_ids)[0]
        
#         predicted_number = remove_punc(predicted_number)
#         predicted_number = predicted_number.split(' ')
#         if len(predicted_number) in [9, 10]:
#                 if len(predicted_number) == 10:
#                     predicted_number = predicted_number[1:]
#                 predicted_number = [p.strip() for p in predicted_number]
#                 pred_json = zsc_pipeline(
#                                             predicted_number, 
#                                             candidate_labels = [
#                                                                 'one',
#                                                                 'two or to',
#                                                                 'three',
#                                                                 'four',
#                                                                 'five',
#                                                                 'six',
#                                                                 'seven',
#                                                                 'eight',
#                                                                 'nine'
#                                                                 ])
#                 pred_numbers = []
#                 for p in pred_json:
#                     labels = p['labels']
#                     scores = p['scores']
#                     max_score = max(scores)
#                     label = labels[scores.index(max_score)]
#                     pred_numbers.append(word2num[label])

#                 pred_numbers = '0' + ''.join([str(p) for p in pred_numbers])
#                 pred_numbers = pred_numbers.replace(' ', '')
#                 return pred_numbers
#         else:
#             print("Invalid number. only contains {} digits".format(len(predicted_number)))
#             return None
        
def speech2number(phone_number_text):
    """Turn a transcription of spoken Sinhala digits into a digit string.

    The text is split on whitespace and every token is fuzzy-matched
    (``difflib.get_close_matches``) against the known spellings of the digit
    words 0-9. Each matched token contributes one digit; tokens with no close
    match (noise words, empty tokens from repeated spaces) are skipped instead
    of raising ``IndexError``.

    Parameters
    ----------
    phone_number_text : str
        Whitespace-separated transcription.

    Returns
    -------
    str
        Concatenated digits, possibly empty.
    """
    # Both accepted spellings of each digit word map to the same digit.
    digit_by_word = {
        'බින්දුව': '0', 'බින්දුවයි': '0',
        'එක': '1', 'එකයි': '1',
        'දෙක': '2', 'දෙකයි': '2',
        'තුන': '3', 'තුනයි': '3',
        'හතර': '4', 'හතරයි': '4',
        'පහ': '5', 'පහයි': '5',
        'හය': '6', 'හයයි': '6',
        'හත': '7', 'හතයි': '7',
        'අට': '8', 'අටයි': '8',
        'නවය': '9', 'නවයයි': '9',
    }
    vocabulary = list(digit_by_word)

    digits = []
    for token in phone_number_text.split():
        # n=1 keeps only the best candidate, the one the digit is taken from.
        best = difflib.get_close_matches(token, vocabulary, n=1)
        if best:
            digits.append(digit_by_word[best[0]])
    return ''.join(digits)
        
def enhance_audio(
                 audio_file,
                 decible_increment = 10
                 ):
    """Filter and amplify a WAV file, writing the result under a unique name.

    The audio is passed through ``low_pass_filter(1000)`` and
    ``high_pass_filter(1000)`` and then raised by ``decible_increment``. The
    output file is named ``<prefix>_<uuid4><ext>``, where ``<prefix>`` is the
    part of the input file name before the first ``_`` or ``.``.

    NOTE(review): the low-pass and high-pass cutoffs are both 1000, which
    leaves essentially no pass band between them; confirm this is intended.

    Parameters
    ----------
    audio_file : str
        Path to a WAV file; backslashes are normalised to ``/``.
    decible_increment : int or float, optional
        Value added to the audio segment after filtering (default 10).

    Returns
    -------
    str
        Path of the exported file. It lives in the parallel
        ``/enhanced_wav/`` directory when the input is under
        ``/denoised_wav/``, otherwise next to the input.

    Raises
    ------
    Exception
        Whatever ``AudioSegment.from_wav`` raised when the input cannot be
        read (after printing a message). Previously this surfaced later as an
        unrelated ``UnboundLocalError``.
    """
    audio_file = audio_file.replace('\\', '/')
    try:
        audio = AudioSegment.from_wav(audio_file)
    except Exception:
        print("Error in reading audio file: {}".format(audio_file))
        raise

    audio = audio.low_pass_filter(1000)
    audio = audio.high_pass_filter(1000)
    audio = audio + decible_increment

    if '/denoised_wav/' in audio_file:
        target = audio_file.replace('/denoised_wav/', '/enhanced_wav/')
        # Remove a stale, non-uuid output left at the mirrored location.
        if os.path.exists(target):
            os.remove(target)
    else:
        target = audio_file

    # Rename only the final path component. A whole-path str.replace would
    # also rewrite directory names that happen to contain the file stem.
    head, sep, base = target.rpartition('/')
    stem, dot, rest = base.partition('.')
    unique_stem = stem.split('_')[0] + '_' + str(uuid.uuid4())
    audioFileEnhanced = head + sep + unique_stem + dot + rest

    audio.export(audioFileEnhanced, format="wav")
    return audioFileEnhanced

In [17]:
def load_audio(audio_file):
    """Load ``audio_file`` through ``datasets`` and return its sample array.

    A one-row dataset is built around the path, its ``"audio"`` column is cast
    first to the default ``Audio()`` feature and then to
    ``Audio(sampling_rate=16000)``, and the ``'array'`` entry of the first row
    is returned.
    """
    single_row = Dataset.from_dict({"audio": [audio_file]})
    for feature in (Audio(), Audio(sampling_rate=16000)):
        single_row = single_row.cast_column("audio", feature)

    first_row = single_row[0]
    return first_row['audio']['array']

def speech2text(audio_file):
    """Transcribe ``audio_file`` with the Sinhala Whisper processor/model.

    The audio is loaded with ``load_audio``, converted to input features at a
    sampling rate of 16000, decoded with the forced Sinhala/transcribe prompt
    ids, and the first decoded string (special tokens removed) is returned.
    """
    waveform = load_audio(audio_file)
    encoded = sin_s2t_processor(waveform, sampling_rate=16000, return_tensors="pt")

    token_ids = sin_s2t_model.generate(
        encoded.input_features,
        forced_decoder_ids=sin_s2t_forced_decoder_ids,
    )

    decoded = sin_s2t_processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]

In [18]:
def preprocessing_number_pipeline(audioFile):
    """Transcribe a spoken phone number and return it as a digit string.

    The convert -> denoise -> transcribe steps are exactly those of
    ``preprocessing_speech_pipeline``, so this delegates to it instead of
    repeating them, then maps the transcription to digits with
    ``speech2number``.

    Parameters
    ----------
    audioFile : str
        Path to the recording (``.mp3``/``.ogg`` inputs are converted to WAV
        by the speech pipeline).

    Returns
    -------
    str
        The digits recognised in the recording.
    """
    transcription = preprocessing_speech_pipeline(audioFile)
    return speech2number(transcription)

def preprocessing_speech_pipeline(audioFile):
    """Convert (if needed), denoise and transcribe a recording.

    ``.mp3`` and ``.ogg`` inputs go through ``mp3toWav`` first; any other
    path is passed to ``audio_denoising`` as is. The denoised file is then
    transcribed with ``speech2text`` and the text is returned.
    """
    needs_conversion = audioFile.endswith(('.mp3', '.ogg'))
    wav_path = mp3toWav(audioFile) if needs_conversion else audioFile
    return speech2text(audio_denoising(wav_path))

In [19]:
# Smoke test: run the phone-number pipeline end to end on a sample .ogg file.
number = preprocessing_number_pipeline('data/audio_store/mp3/7.ogg')
# Bare expression so the notebook displays the recognised digit string.
number

'0705082391'

In [20]:
def preprocessing_answer_recognition(audioFile):
    """Classify a spoken answer as 'ඔව්' (yes) or 'නැත' (no).

    The recording is transcribed with ``preprocessing_speech_pipeline`` and
    resolved in three steps:

    1. Fuzzy-match the whole transcription against the known yes/no words.
    2. Otherwise fall back to substring hints. The negative stems ('නෑ', 'නැ')
       are tested before the positive hints, because 'නැහැ' (no) contains the
       letter 'හ' that also serves as a yes hint; testing yes first would
       classify such answers as yes.
    3. If nothing matches, pick one of the two answers at random.

    Parameters
    ----------
    audioFile : str
        Path to the recording.

    Returns
    -------
    str
        Either 'ඔව්' or 'නැත'.
    """
    yes_words = ['ඔව්', 'ඕඕ', 'හරි']
    no_words = ['නෑ', 'නැහැ', 'නැත', 'එපා']

    text = preprocessing_speech_pipeline(audioFile).strip()

    matches = difflib.get_close_matches(text, yes_words + no_words)
    if matches:
        return 'ඔව්' if matches[0] in yes_words else 'නැත'

    # Substring fallback for transcriptions too long/noisy to fuzzy-match.
    if ('නෑ' in text) or ('නැ' in text):
        return 'නැත'
    if ('ඔ' in text) or ('ඕ' in text) or ('හ' in text):
        return 'ඔව්'
    if 'එ' in text:
        return 'නැත'

    # Last resort kept from the original design: guess. str() so callers
    # always get a plain str rather than numpy.str_.
    return str(np.random.choice(['ඔව්', 'නැත']))
        

In [21]:
# Smoke test: classify a sample recording as 'ඔව්' or 'නැත'.
preprocessing_answer_recognition('data/audio_store/wav/OW2.wav')

'ඔව්'