<a href="https://colab.research.google.com/github/SunbirdAI/salt/blob/main/notebooks/Test_ASR_data_labels.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/sunbirdai/salt.git
!pip install -qr salt/requirements.txt

In [None]:
import yaml
import io
import soundfile as sf
import transformers
import datasets
import torch
import librosa
import numpy as np
import editdistance
from IPython import display
from huggingface_hub  import login
from tqdm.notebook import tqdm
import math
import salt.utils

In [None]:
# Authenticate with the Hugging Face Hub. Required for `push_to_hub` when
# `upload_transcriptions` is True, and presumably for loading any private
# datasets below — TODO confirm which repos need auth.
login()

In [None]:
# Hugging Face dataset repository whose labels are being checked, and the
# splits / languages to scan. Each language has its own dataset config,
# named e.g. 'corrected-lug'.
repository_path = 'Sunbird/salt-corrected'
splits = ['train', 'dev', 'test']
languages = ['lug', 'eng', 'nyn', 'lgg', 'teo', 'ach']
config_names = ['corrected-' + lang for lang in languages]

# Destination repository for the per-clip ASR transcriptions. Uploading is
# off by default; flip the flag to push the results to the Hub.
transcriptions_path = 'evie-8/salt-corrected-asr-data-transcriptions'
upload_transcriptions = False

# ASR checkpoint used to transcribe every recording.
asr_model = 'Sunbird/sunbird-mms'

# Apply an ASR model to the audio recordings

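We transcribe each recording with an ASR model for the corresponding language, and then check for discrepancies between the transcription and the text label. This only needs to be run once; set `upload_transcriptions = True` to save the transcriptions to the Hugging Face Hub.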

In [None]:
BATCH_SIZE = 64  # rows decoded and transcribed per progress-bar step
MODEL_SAMPLE_RATE = 16_000  # rate (Hz) audio is resampled to before ASR


def bytes_to_audio(byte_data):
    """Decode an in-memory audio file with soundfile.

    Args:
        byte_data: raw bytes of an audio file in a format soundfile can read.

    Returns:
        The ``(audio, sample_rate)`` pair produced by ``soundfile.read``.
    """
    return sf.read(io.BytesIO(byte_data))


def transcribe_batch(audio_batch, sample_rate_batch):
    """Transcribe a list of audio clips with the module-level ASR model.

    Each clip is downmixed to mono if it has more than one channel, resampled
    to MODEL_SAMPLE_RATE when needed, run through ``processor`` and ``model``
    one clip at a time, and greedily decoded (argmax over the logits).

    Args:
        audio_batch: sequence of numpy arrays as returned by
            ``bytes_to_audio`` (1-D for mono files, 2-D for multi-channel).
        sample_rate_batch: sample rate of each clip, in the same order.

    Returns:
        List of transcription strings, one per clip, in input order.

    Relies on the globals ``processor``, ``model`` and ``device``.
    """
    transcriptions = []
    with torch.no_grad():
        for audio, orig_sr in zip(audio_batch, sample_rate_batch):
            # soundfile returns (frames, channels) for multi-channel files.
            # librosa would resample along the channel axis, and the feature
            # extractor would treat a 2-D array as a batch, so average the
            # channels into a single mono signal first.
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
            if orig_sr != MODEL_SAMPLE_RATE:
                audio = librosa.resample(
                    audio, orig_sr=orig_sr, target_sr=MODEL_SAMPLE_RATE)
            # Processing clips individually avoids padding effects and keeps
            # only one clip's tensors on the device at a time.
            model_inputs = processor(
                audio, sampling_rate=MODEL_SAMPLE_RATE,
                return_tensors="pt").to(device)
            logits = model(**model_inputs).logits
            transcriptions.append(
                processor.decode(torch.argmax(logits, dim=-1)[0]))
    return transcriptions

# Load the ASR model once, then transcribe every language/split.
processor = transformers.AutoProcessor.from_pretrained(asr_model)
model = transformers.Wav2Vec2ForCTC.from_pretrained(asr_model)
# Fall back to CPU when no GPU is available (slow, but still runs).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model.to(device)

# Edit distance above which an entry is counted as suspicious in the summary
# printed for each split.
SUSPICIOUS_EDIT_DISTANCE = 30

for language, config_name in zip(languages, config_names):

    # English uses the 'eng' adapter; every other language uses '<lang>+eng'.
    if language == 'eng':
        model.load_adapter('eng')
    else:
        model.load_adapter(f'{language}+eng')

    processor.tokenizer.set_target_lang(language)

    for split in splits:

        print(f'dataset: {config_name}, split: {split}')

        ds = datasets.load_dataset(
            repository_path, config_name, split=split)
        df = ds.to_pandas()

        transcriptions = []
        # One progress-bar step per batch of BATCH_SIZE rows.
        for start in tqdm(range(0, len(df), BATCH_SIZE), desc="Transcribing"):
            sample = df.iloc[start:start + BATCH_SIZE]
            decoded = [bytes_to_audio(audio_bytes['bytes'])
                       for audio_bytes in sample['audio']]
            audio_batch = [audio for audio, _ in decoded]
            sample_rate_batch = [rate for _, rate in decoded]
            transcriptions.extend(
                transcribe_batch(audio_batch, sample_rate_batch))

        df["transcription"] = transcriptions

        # Compare case-insensitively on both sides so that capitalisation
        # differences never count towards the edit distance.
        edit_distances = [
            editdistance.eval(text.lower(), transcription.lower())
            for text, transcription in zip(df['text'], transcriptions)]
        df['edit_distance'] = edit_distances
        # The uploaded transcription dataset does not include the raw audio.
        df = df.drop(columns=['audio'])

        n_suspicious = sum(
            distance > SUSPICIOUS_EDIT_DISTANCE for distance in edit_distances)
        print(f'Found {n_suspicious} suspicious entries out of '
              f'{len(df)} in {config_name}/{split}')

        ds = datasets.Dataset.from_pandas(df)
        if upload_transcriptions:
            ds.push_to_hub(
                transcriptions_path,
                config_name=config_name, private=False, split=split)

# Investigate misalignments

Examine the 'suspicious' cases, where the edit distance between the ASR transcription and the text label exceeds `edit_distance_threshold`.

In [None]:
# An entry is flagged as 'suspicious' when the character edit distance between
# its text label and the ASR transcription exceeds this value. Lower it for a
# stricter check (more entries flagged).
edit_distance_threshold: int = 30

In [None]:
# When True, stop after the first language/split. This replaces the former
# temporary debugging `break`s, so that the next cells keep working on the
# first split's data. Set to False to scan every language/split; the next
# cells then use the last split scanned.
inspect_first_split_only = True

for language, config_name in zip(languages, config_names):
    print(f"Language: {language}")
    for split in splits:

        # TODO: change the split name in the transcription repo; this code
        # expects the 'dev' split to be stored there as 'validation'.
        ds_transcriptions = datasets.load_dataset(
            transcriptions_path,
            name=config_name,
            split='validation' if split == 'dev' else split
        )
        df_transcriptions = ds_transcriptions.to_pandas()

        # Source dataset (including audio); used by the next cell to listen
        # to the suspicious recordings.
        ds_source = datasets.load_dataset(
            repository_path, name=config_name, split=split)
        df_source = ds_source.to_pandas()

        is_suspicious = (
            df_transcriptions['edit_distance'] > edit_distance_threshold)
        df_suspicious = df_transcriptions[is_suspicious]
        suspicious_ids = df_suspicious['id'].tolist()

        print(
            f"   {language}/{split}: {len(suspicious_ids)} / {len(ds_transcriptions)} "
            "audio recordings don't match the text.")
        print(f"   suspicious: {suspicious_ids}\n\n" )

        if inspect_first_split_only:
            break
    if inspect_first_split_only:
        break


In [None]:
# Attach the source audio to the transcription rows. The two frames are
# combined by row position, so first check that they list the same ids in the
# same order. Otherwise recordings would silently be paired with the wrong
# labels.
if list(df_transcriptions['id']) != list(df_source['id']):
    raise ValueError(
        'Transcription and source datasets are not aligned by id; '
        'cannot attach audio to transcriptions.')
df_transcriptions['audio'] = df_source['audio']

# Take an explicit copy so the column assignments below modify a standalone
# frame rather than a slice of df_transcriptions (avoids SettingWithCopy).
df_suspicious = df_transcriptions.loc[
    df_transcriptions['edit_distance'] > edit_distance_threshold,
    ['id', 'text', 'transcription', 'audio']].copy()
# Decode the stored audio bytes into float32 sample arrays.
df_suspicious['audio'] = [
    bytes_to_audio(audio['bytes'])[0].astype(np.float32)
    for audio in df_suspicious['audio']]
df_suspicious['id'] = np.array(
    [int(entry_id) for entry_id in df_suspicious['id']],
    dtype=np.int32)

In [None]:
# Show suspicious entries (id, text label, ASR transcription, audio) with the
# salt helper. Presumably N limits the number of rows shown and the 'audio'
# column is rendered as playable audio — TODO confirm in salt.utils.
salt.utils.show_dataset(df_suspicious, audio_features=['audio'], N=5)

Next steps: find a clear way to present the suspicious entries, so that we can step through them one by one and decide what to do with each (remove invalid entries from the dataset, or correct the ID of misaligned entries).

Bonus: a nice table showing the number of suspicious entries for each language/split.