In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm_notebook
import evaluate
import syllables
from audiomentations import (
    AddBackgroundNoise, 
    AddGaussianNoise,
    AddGaussianSNR, 
    LoudnessNormalization, 
    PitchShift,
    Shift,
    TimeStretch
)
from datasets import load_dataset
from transformers import (
    Speech2TextForConditionalGeneration,
    Speech2TextProcessor
)

In [2]:
device = torch.device("cuda")
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
model.to(device)

Speech2TextForConditionalGeneration(
  (model): Speech2TextModel(
    (encoder): Speech2TextEncoder(
      (conv): Conv1dSubsampler(
        (conv_layers): ModuleList(
          (0): Conv1d(80, 1024, kernel_size=(5,), stride=(2,), padding=(2,))
          (1): Conv1d(512, 512, kernel_size=(5,), stride=(2,), padding=(2,))
        )
      )
      (embed_positions): Speech2TextSinusoidalPositionalEmbedding()
      (layers): ModuleList(
        (0-11): 12 x Speech2TextEncoderLayer(
          (self_attn): Speech2TextAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (activation_fn): ReLU()
          (fc1): Linear(in_features=2

In [3]:
wer = evaluate.load("wer")

In [4]:
ds = load_dataset("google/fleurs", 'en_us', split="validation")


Found cached dataset fleurs (/home/dcek/.cache/huggingface/datasets/google___fleurs/en_us/2.0.0/af82dbec419a815084fa63ebd5d5a9f24a6e9acdf9887b9e3b8c6bbd64e0b7ac)


In [5]:
gender_map = {'female':1, 'male':0}
def get_wer_scores(dataset, transcriptions=None, sampling_rates=None, is_hg_ds=False):
    all_wer_score = []

    for idx, audio_data in tqdm_notebook(enumerate(dataset), total=len(dataset)):
        inputs = processor(
            audio_data["audio"]["array"] if is_hg_ds else audio_data, 
            sampling_rate=audio_data["audio"]["sampling_rate"] if is_hg_ds else sampling_rates[idx],
            return_tensors="pt"
        )
        generated_ids = model.generate(
            inputs["input_features"].to(device), 
            attention_mask=inputs["attention_mask"].to(device)
        )
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
        wer_score = wer.compute(
            predictions=transcription, 
            references=[audio_data['transcription'] if is_hg_ds else transcriptions[idx]]
        )
        all_wer_score.append(wer_score)
    all_wer_score = np.array(all_wer_score)
    wer_score_results = {}
    for gender in gender_map.keys():
        gender_idx = np.where(all_gender == gender_map[gender])[0]
        wer_score_results[gender + '_wer_score'] = all_wer_score[gender_idx].mean()
    wer_score_results['wer_score'] = all_wer_score.mean()
    return wer_score_results



In [6]:
def get_augmented_samples_wer_results(
    all_baseline_samples, augment, transcriptions, all_sampling_rates
):
    all_augmented_samples = []
    for idx, audio_sample in enumerate(all_baseline_samples):
        augmented_samples = augment(samples=audio_sample, sample_rate=all_sampling_rates[idx])
        all_augmented_samples.append(augmented_samples)
    results = get_wer_scores(
        all_augmented_samples, transcriptions, sampling_rates=all_sampling_rates, is_hg_ds=False
    )
    return results

In [7]:
all_syllables_per_second = []
for audio_data in ds:
    num_syllables = syllables.estimate(audio_data['transcription'])
    syllables_per_second = num_syllables / (audio_data['num_samples'] / audio_data['audio']['sampling_rate'])
    all_syllables_per_second.append(syllables_per_second)
    average_syllables_per_second = np.mean(all_syllables_per_second)

In [8]:
all_baseline_speed_audio_samples = []
transcriptions = []
all_sampling_rates = []
for idx, audio_data in tqdm_notebook(enumerate(ds), total=len(ds)):
    rate = average_syllables_per_second / all_syllables_per_second[idx]
    augment = TimeStretch(min_rate=rate, max_rate=rate, p=1.0)
    augmented_samples = augment(
        samples=audio_data['audio']['array'], 
        sample_rate=audio_data['audio']['sampling_rate']
    )
    transcriptions.append(audio_data['transcription'])
    all_sampling_rates.append(audio_data['audio']['sampling_rate'])
    all_baseline_speed_audio_samples.append(augmented_samples)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for idx, audio_data in tqdm_notebook(enumerate(ds), total=len(ds)):


  0%|          | 0/394 [00:00<?, ?it/s]



In [9]:
all_gender = np.array(ds['gender'])

In [None]:
rates = np.linspace(0.1, 1, 9).tolist() + list(range(1, 11))
wer_results_by_rate = []
for rate_to_change in tqdm_notebook(rates): 
    augment = TimeStretch(min_rate=rate_to_change, max_rate=rate_to_change, p=1.0)
    results = get_augmented_samples_wer_results(
        all_baseline_speed_audio_samples, augment, transcriptions, all_sampling_rates
    )
    wer_results_by_rate.append(results)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for rate_to_change in tqdm_notebook(rates):


  0%|          | 0/19 [00:00<?, ?it/s]

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for idx, audio_data in tqdm_notebook(enumerate(dataset), total=len(dataset)):


  0%|          | 0/394 [00:00<?, ?it/s]

  0%|          | 0/394 [00:00<?, ?it/s]

In [None]:
labels = ["female", "male", "overall"]
plt.xlabel("Syllables per second")
plt.ylabel("WER")
for idx, gender in enumerate(["female_", "male_", ""]):
    plt.plot(
        [average_syllables_per_second * i for i in rates],
        [wr[gender + 'wer_score'] for wr in 
        wer_results_by_rate], 
        label=labels[idx]
    )
    plt.legend()

In [None]:
baseline_samples = [audio_data['audio']['array'] for 
audio_data in ds]
snr_rates = np.linspace(1, 100, 25)
wer_results_by_snr = []
for snr_rate in tqdm_notebook(snr_rates): 
    all_augmented_samples = []
    augment = AddGaussianSNR(
        min_snr_in_db=snr_rate,
        max_snr_in_db=snr_rate,
        p=1.0
    )
    results = get_augmented_samples_wer_results(
        baseline_samples, augment, transcriptions, all_sampling_rates
    )
    wer_results_by_snr.append(results)

In [None]:
plt.xlabel(“SNR (dB)”)
plt.ylabel(“WER”)
for idx, gender in enumerate(["female_", "male_", ""]):
    plt.plot(
        snr_rates,
        [wr[gender + ‘wer_score’] for wr in wer_results_by_snr], 
        label=labels[idx]
    )
    plt.legend()

In [None]:
wer_results_by_loudness = []
loudness_db = np.linspace(-31, 100, 25)
for db in tqdm_notebook(loudness_db): 
    augment = LoudnessNormalization(
        min_lufs_in_db=db,
        max_lufs_in_db=db,
        p=1.0
    )
    results = get_augmented_samples_wer_results(
        baseline_samples, augment, transcriptions, all_sampling_rates
    )
    wer_results_by_loudness.append(results)

In [None]:
labels = ["female", "male", "overall"]
plt.xlabel("SNR (dB)")
plt.ylabel("WER")
for idx, gender in enumerate(["female_", "male_", ""]):
    plt.plot(
        loudness_db,
        [wr[gender + 'wer_score'] for wr in 
         wer_results_by_loudness], 
        label=labels[idx]
    )
    plt.legend()

In [None]:
snrs = np.linspace(-50, 50, 20)
wer_results_by_background_noise_snr = []
for snr in tqdm_notebook(snrs): 
    augment = AddBackgroundNoise(
        sounds_path="motorbikes",
        min_snr_in_db=snr,
        max_snr_in_db=snr,
        p=1.0
    )
    results = get_augmented_samples_wer_results(
        baseline_samples, augment, transcriptions, all_sampling_rates
    )
    wer_results_by_background_noise_snr.append(results)


In [None]:
plt.xlabel("SNR (dB)")
plt.ylabel("WER")
for idx, gender in enumerate(["female_", "male_", ""]):
    plt.plot(
        snrs,
        [wr[gender + ‘wer_score’] for wr in wer_results_by_background_noise_snr], 
        label=labels[idx]
    )
    plt.legend()
