In [None]:
import os
import re
import sys

import librosa
import numpy as np
import pandas as pd
import torch
import whisper
import jiwer
from datasets import load_dataset
from TTS.api import TTS
from tqdm.notebook import tqdm

In [None]:
class hide_print:
    """
    Context manager that silences ``print`` output.

    On entry ``sys.stdout`` is replaced by a write handle on ``os.devnull``;
    on exit the stream that was active at entry is put back and the devnull
    handle is closed. Exceptions raised inside the block are not suppressed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference to the handle so that __exit__ closes exactly
        # this file, even if the body of the `with` block reassigns sys.stdout.
        self._devnull = open(os.devnull, "w")
        sys.stdout = self._devnull
        return self

    def __exit__(self, *args):
        # Restore first: stdout is usable again even if close() were to fail.
        sys.stdout = self._original_stdout
        self._devnull.close()

In [None]:
# Background-noise clip that is later mixed into the synthesized audio;
# `sr` is the second value returned by librosa.load for that clip.
NOISE_PATH = "restaurant.wav"
noise, sr = librosa.load(NOISE_PATH)


def add_noise(waveform, noise, snr):
    """
    Mix background noise into an audio signal at a requested SNR.

    Parameters
    ----------
    waveform : array-like
        Samples of the clean signal (treated as a 1-D sequence).
    noise : array-like
        Samples of the noise clip. If it is shorter than ``waveform`` it is
        looped end-to-end; if it is longer it is truncated.
    snr : float
        Target signal-to-noise ratio in dB.

    Returns
    -------
    numpy.ndarray
        ``waveform + scale * noise``, with the same length as ``waveform``.
        If either the signal or the (cropped) noise has zero energy the SNR
        is undefined and the waveform is returned without any noise added.

    Raises
    ------
    ValueError
        If ``noise`` is empty while ``waveform`` is not.
    """
    waveform = np.asarray(waveform)
    noise = np.asarray(noise)

    L = len(waveform)
    if L > len(noise):
        if len(noise) == 0:
            raise ValueError("noise must contain at least one sample")
        # np.tile loops the whole clip; np.repeat would instead duplicate each
        # individual sample in place, stretching the noise in time.
        noise = np.tile(noise, 1 + L // len(noise))
    noise = noise[:L]

    energy_signal = np.linalg.norm(waveform, ord=2) ** 2
    energy_noise = np.linalg.norm(noise, ord=2) ** 2
    if energy_signal == 0 or energy_noise == 0:
        # No finite gain can reach the target SNR; avoid log/division by zero
        # (which would yield NaN samples) and add nothing.
        return waveform + noise * 0.0

    # Amplitude gain g such that energy_signal / (g**2 * energy_noise)
    # equals 10 ** (snr / 10).
    scale = np.sqrt(energy_signal / energy_noise) * 10 ** (-snr / 20.0)

    return waveform + noise * scale

In [None]:
TTS_MODEL_NAME = "tts_models/en/ljspeech/tacotron2-DDC"
WHISPER_MODEL_NAME = "small.en"

# Text-to-speech model; stdout is silenced while it is being constructed.
with hide_print():
    tts = TTS(TTS_MODEL_NAME)

# Speech-to-text model.
model = whisper.load_model(WHISPER_MODEL_NAME)

In [None]:
# Test split of the "dyda_da" configuration of the "silicone" dataset.
test_data = load_dataset(path="silicone", name="dyda_da", split="test")

In [None]:
# Matches any run made of characters other than ASCII letters, digits,
# apostrophes and spaces, and/or of one or more spaces.
expr = re.compile(r"([^a-zA-Z0-9' ]| +)+")


def normalize(x):
    """
    Canonicalize a string before comparing transcriptions.

    Every run matched by ``expr`` is collapsed to a single space, the result
    is lower-cased, and leading/trailing spaces are removed.
    """
    return expr.sub(" ", x).lower().strip(" ")


def wer(s1, s2):
    """
    Word error rate between two strings after normalization.

    Both inputs go through ``normalize`` and are then handed to ``jiwer.wer``
    as its first and second positional arguments, in that order.
    """
    cleaned_first = normalize(s1)
    cleaned_second = normalize(s2)
    return jiwer.wer(cleaned_first, cleaned_second)

In [None]:
# Only the first rows are processed here; the loop below reads the slice as
# column lists keyed by column name.
# Use test_data[:] to regenerate all of NoDA. It will take several hours
data = test_data[:10]

# Whisper interprets a raw sample array as audio at whisper.audio.SAMPLE_RATE,
# so the synthesized waveform must be resampled before transcription. The
# noise clip is likewise brought to the TTS rate before the two are mixed
# sample-by-sample.
tts_sample_rate = tts.synthesizer.output_sample_rate
whisper_sample_rate = whisper.audio.SAMPLE_RATE
if sr == tts_sample_rate:
    noise_for_tts = noise
else:
    noise_for_tts = librosa.resample(noise, orig_sr=sr, target_sr=tts_sample_rate)

texts, labels = [], []

utterances, utterance_labels = data["Utterance"], data["Label"]
for utterance, label in tqdm(zip(utterances, utterance_labels), total=len(utterances)):
    with hide_print():
        audio = np.asarray(tts.tts(utterance), dtype=np.float32)
    audio = add_noise(audio, noise_for_tts, snr=5)
    if tts_sample_rate != whisper_sample_rate:
        audio = librosa.resample(
            np.asarray(audio, dtype=np.float32),
            orig_sr=tts_sample_rate,
            target_sr=whisper_sample_rate,
        )
    stt_utterance = model.transcribe(torch.tensor(audio).float())["text"]
    texts.append(stt_utterance.lower())
    labels.append(label)

In [None]:
# One row per processed utterance: the lower-cased transcription and its label.
df = pd.DataFrame(dict(text=texts, label=labels))
df