In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from os                import walk
from pydub             import AudioSegment
from pydub.utils       import get_array_type
from pydub.utils       import mediainfo
from pydub.silence     import split_on_silence
from datasets          import load_dataset
from torchaudio.utils  import download_asset
from scipy             import signal
from scipy.io          import wavfile
from matplotlib.pyplot import figure
from tqdm              import tqdm
from os                import listdir
from os.path           import isfile, join
from datetime          import datetime
from pyctcdecode       import build_ctcdecoder
from pprint            import pprint
from torch.utils.data  import Dataset, DataLoader

import matplotlib.pyplot as plt
import pandas            as pd
import numpy             as np
import soundfile         as sf

import re
import json
import requests
import whisper
import git
import os
import jiwer
import IPython
import array
import librosa
import torch
import torchaudio

<h1 style="background-color:LightGreen;"> <center> <a id='pipeline_cell'></a> Utils </center></h1>

In [4]:
def convert_to_16sr_file(source_path, dest_path):    
    speech, sr = librosa.load(source_path, sr=16000)
    sf.write(dest_path, speech, sr)


In [5]:
def convert_to_8sr_file(source_path, dest_path):    
    speech, sr = librosa.load(source_path, sr=8000)
    sf.write(dest_path, speech, sr)


In [6]:
def get_sample_rate(file):
    info          = mediainfo(file)
    sampling_rate = info['sample_rate']
    sampling_rate = int(sampling_rate)
    return sampling_rate

In [7]:
def HebrewNormalizer(hebrew_text):
    # --- step 1: remove sign characters
    #
    # ignore_characters = ",~!@#%^&*()-+/|<>[]*'?.{}"
    # for character in ignore_characters:
    #     hebrew_text = hebrew_text.replace(character, '')

        # --- step 2: replace signs
    hebrew_text = hebrew_text.replace('$', " דולר")
    hebrew_text = hebrew_text.replace('₪', " שח")
    hebrew_text = hebrew_text.replace('€', " יורו")
    # hebrew_text = hebrew_text.replace('.', " נקודה")
    hebrew_text = hebrew_text.replace('ת"א', "תל אביב")
    hebrew_text = hebrew_text.replace('ב"ש', "באר שבע")
    hebrew_text = hebrew_text.replace('ע"י', "על ידי")
    hebrew_text = hebrew_text.replace('אח"כ', "אחר כך")
    hebrew_text = hebrew_text.replace('\"', "")

    # for now we will not handle digits, we will have to handle digits if it costs us in the performance of the model
    # TODO: handle dates: 3/7 -> third of july
    # # --- step 3: replace numbers to words
    # dict_nums = {
    #     "0": "אפס",
    #     "1": "אחד",
    #     "2": "שתיים",
    #     "3": "שלוש",
    #     "4": "ארבע",
    #     "5": "חמש",
    #     "6": "שש",
    #     "7": "שבע",
    #     "8": "שמונה",
    #     "9": "תשע",
    #     "10": "עשר",
    # }
    # for digit, word in dict_nums.items():
    #     hebrew_text = hebrew_text.replace(digit, word)
    #
    # # --- step 4: replace female numbers to male numbers
    # dict_male = {
    #     "אחת": "אחד",
    #     "שתיים": "שניים",
    #     "שלושה": "שלוש",
    #     "ארבעה": "ארבע",
    #     "חמישה": "חמש",
    #     "שישה": "שש",
    #     "שבעה": "שבע",
    #     "תשעה": "תשע",
    #
    # }
    # for female, male in dict_male.items():
    #     hebrew_text = hebrew_text.replace(female, male)
    # postproccessing, removing special charcteres after handling and translating them
    valid_tokens = "פ ם ן ו ט א ר ק ף ך ל ח י ע כ ג ד ש ץ ת צ מ נ ה ב ס ז 1 2 3 4 5 6 7 8 9 0"
    valid_tokens = set([x.lower() for x in valid_tokens])
    # The caret in the character class ([^) means match anything but
    invalid_chars_regex = f"[^\s{re.escape(''.join(set(valid_tokens)))}]"


    """ DO ADAPT FOR YOUR USE CASE. this function normalizes the target text. """
    hebrew_text = re.sub(invalid_chars_regex, " ", hebrew_text)
    hebrew_text = re.sub("\s+", " ", hebrew_text).strip()
    # --- return result
    return hebrew_text    
   

<h1 style="background-color:LightGreen;"> <center> Preprocessing </center></h1>

In [11]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda', index=0)

In [8]:
ROBO_SOURCE = "/home/amitli/Datasets/Roboshaul/saspeech_gold_standard_v1.0/saspeech_gold_standard/wavs/"
ROBO_SR_16  = "/home/amitli/Datasets/Roboshaul/saspeech_gold_standard_v1.0/saspeech_gold_standard/WAV_SR_16K/"
ROBO_SR_8   = "/home/amitli/Datasets/Roboshaul/saspeech_gold_standard_v1.0/saspeech_gold_standard/WAV_SR_8K/"

In [9]:
convert = False

if convert is True:
    list_file_name      = [entry for entry in os.listdir(ROBO_SOURCE) if os.path.isfile(os.path.join(ROBO_SOURCE, entry))]

    for file_name in tqdm(list_file_name):

        full_file_name = f"{ROBO_SOURCE}/{file_name}"
        full_sr_16     = f"{ROBO_SR_16}/{file_name}"
        full_sr_8      = f"{ROBO_SR_8}/{file_name}"

        current_sr     = get_sample_rate(full_file_name)
        if 44100 != current_sr:
            print(f"Source Sample Rate: {current_sr}")
            continue

        convert_to_16sr_file(full_file_name, full_sr_16)
        convert_to_8sr_file(full_file_name, full_sr_8)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2986/2986 [15:33<00:00,  3.20it/s]


In [10]:
df_meta = pd.read_csv("/home/amitli/Datasets/Roboshaul/saspeech_gold_standard_v1.0/saspeech_gold_standard/metadata.csv", sep="|", names=["file", "gt", "gt2"], header=None)

In [12]:
df_meta.head()

Unnamed: 0,file,gt,gt2
0,gold_000_line_000,"שָׁלוֹם, צְלִיל אַבְרָהָם.","שָׁלוֹם, צְלִיל אַבְרָהָם."
1,gold_000_line_001,"לְגַמְרֵי, מַדְהִים, לֹא?","לְגַמְרֵי, מַדְהִים, לֹא?"
2,gold_000_line_002,וְדַוְוקָא בִּגְלַל שֶׁכּוּלָּנוּ הָיִינוּ עֲס...,וְדַוְוקָא בִּגְלַל שֶׁכּוּלָּנוּ הָיִינוּ עֲס...
3,gold_000_line_003,אָז הַיּוֹם אֲנַחְנוּ נְדַבֵּר עַל הַחֲקִירָה ...,אָז הַיּוֹם אֲנַחְנוּ נְדַבֵּר עַל הַחֲקִירָה ...
4,gold_000_line_004,הָרָמַת מָסַךְ מֵעַל סְבַךְ שֶׁל אִינְטֶרֶסִים...,הָרָמַת מָסַךְ מֵעַל סְבַךְ שֶׁל אִינְטֶרֶסִים...


In [33]:
class RoboshaulGoldDataset(Dataset):
        
    def __init__(self, source_folder, df_gt, device):        
        
        self.list_file_names = [entry for entry in os.listdir(ROBO_SOURCE) if os.path.isfile(os.path.join(ROBO_SOURCE, entry))]        
        self.l_file_name     = []
        self.l_full_path     = []
        self.l_gt            = []
        self.device          = device
        
        for file_name in tqdm(list_file_name):

            # --- get GT text            
            gt         = df_gt[df_gt['file'] == file_name[:-4]]['gt'].values[0]            
            full_src   = f"{source_folder}/{file_name}"       
            
            self.l_file_name.append(file_name)
            self.l_full_path.append(full_src)                
            self.l_gt.append(gt)
        

    def __len__(self):
        return len(self.l_gt)

    def __getitem__(self, idx):
        
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        audio_file_path = self.l_full_path[idx]        
        audio           = whisper.load_audio(str(audio_file_path))
        audio           = whisper.pad_or_trim(audio)
        
        mel             = whisper.log_mel_spectrogram(audio).to(self.device)        
        
        sample          = {'mel':      mel, 
                          'text':      HebrewNormalizer(self.l_gt[idx]), 
                          'file':      self.l_file_name[idx],                          
                          'full_path': self.l_full_path[idx]}
        
        return sample

In [36]:
def run_whisper_on_rambo(loader, res_file_name, lang):
    df = pd.DataFrame()
    for batch in tqdm(loader):
        
        languages        = []
        if lang is not None:
            languages = [lang]        
        mel              = batch['mel']
        audio_data       = {'wav': json.dumps(mel.tolist()), 'languages': languages}
        gt               = batch['text']

        #res              = requests.get('http://10.53.140.33:80/batch_inference/', json=audio_data)
        res              = requests.get('http://10.53.140.33:80/batch_inference_beam/', json=audio_data)
        res_list         = res.json()[0]

        l_wer            = []
        l_whisper        = []    
        l_res_lang       = []
        l_avg_logprob    = []
        l_no_speech_prob = []
        l_compres_ratio  = []

        for i, res in enumerate(res_list):
            whisper_text = res['text']
            whisper_text = HebrewNormalizer(whisper_text)
            l_whisper.append(whisper_text)                
            if whisper_text == '':
                l_wer.append(1)
            else:
                l_wer.append(jiwer.wer(whisper_text, gt[i]))
            l_res_lang.append(res['language'])
            l_avg_logprob.append(res['avg_logprob'])
            l_no_speech_prob.append(res['no_speech_prob'])
            l_compres_ratio.append(res['compression_ratio'])


        df_tmp     = pd.DataFrame({
            "whisper":           l_whisper,
            "gt":                gt,
            "wer":               l_wer,
            "file":              batch['file'],            
            "full_path":         batch['full_path'],
            "detect_lang":       l_res_lang,
            "avg_logprob":       l_avg_logprob,
            "no_speech_prob":    l_no_speech_prob,
            "compression_ratio": l_compres_ratio,

        })
        df = pd.concat([df, df_tmp], ignore_index=True)        
        df.to_csv(res_file_name)
        
            
                        

<h1 style="background-color:LightGreen;"> <center> Test With SR=16K </center></h1>

In [37]:
dataset_16k = RoboshaulGoldDataset(source_folder = "/home/amitli/Datasets/Roboshaul/saspeech_gold_standard_v1.0/saspeech_gold_standard/WAV_SR_16K",
                               df_gt         = df_meta,
                               device        = DEVICE)

loader_16k = DataLoader(dataset_16k, batch_size=20, shuffle=False)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2986/2986 [00:01<00:00, 2984.51it/s]


In [None]:
run_whisper_on_sr_16 = False
if run_whisper_on_sr_16 is True:
    run_whisper_on_rambo(loader_16k, "/home/amitli/Datasets/Roboshaul/saspeech_gold_standard_v1.0/saspeech_gold_standard/robo_whisper_16k.csv"):

<h1 style="background-color:LightGreen;"> <center> Test With SR=8K </center></h1>

In [38]:
dataset_8k = RoboshaulGoldDataset(source_folder = "/home/amitli/Datasets/Roboshaul/saspeech_gold_standard_v1.0/saspeech_gold_standard/WAV_SR_8K",
                               df_gt         = df_meta,
                               device        = DEVICE)

loader_8k = DataLoader(dataset_8k, batch_size=20, shuffle=False)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2986/2986 [00:00<00:00, 3053.25it/s]


In [None]:
run_whisper_on_sr_8 = False
if run_whisper_on_sr_8 is True:
    run_whisper_on_rambo(loader_8k, "/home/amitli/Datasets/Roboshaul/saspeech_gold_standard_v1.0/saspeech_gold_standard/robo_whisper_8k.csv"):

<h1 style="background-color:LightGreen;"> <center> Compare Results </center></h1>