In [None]:
# Fetch the Sunbird `leb` repo (skipped when it is already present, so the
# cell can be re-run) and install the notebook's dependencies.
# `%pip` installs into the environment of the running kernel, unlike `!pip`.
![ -d leb ] || git clone https://github.com/sunbirdai/leb.git
%pip install -qr leb/requirements.txt
%pip install -q editdistance transformers

In [10]:
import yaml
import transformers
import datasets
import torch
import librosa
import editdistance
from IPython import display
from tqdm.notebook import tqdm
import math
import numpy as np

In [21]:
# Number of clips handed to transcribe_batch per progress-bar step.
BATCH_SIZE = 64
# Audio is resampled to this rate (Hz) before it is passed to the processor.
MODEL_SAMPLE_RATE = 16000

# Collected results: all_transcriptions[language][split] holds the
# DataFrame.to_dict() of that split's metadata, transcriptions and distances.
all_transcriptions = dict()

def transcribe_batch(audio_batch, sample_rate_batch):
    """Transcribe a batch of audio clips with the module-level CTC model.

    Each clip is resampled to MODEL_SAMPLE_RATE when its rate differs, run
    through `processor` and `model` on its own (clips are never padded
    together), and decoded greedily by taking the argmax token per frame.
    Clips are streamed one at a time so only a single clip's tensors live on
    the device at any moment.

    Args:
        audio_batch: sequence of 1-D waveforms (array-likes of samples).
        sample_rate_batch: sampling rate of each waveform; must have the same
            length as `audio_batch`.

    Returns:
        list[str]: one transcription per input clip, in input order.

    Raises:
        ValueError: if the two input sequences differ in length (``zip`` would
            otherwise silently drop the extra clips).

    Relies on the globals `processor`, `model` and `device` created in the
    model-loading code of this notebook.
    """
    if len(audio_batch) != len(sample_rate_batch):
        raise ValueError(
            f'audio_batch has {len(audio_batch)} clips but sample_rate_batch '
            f'has {len(sample_rate_batch)} rates')

    transcriptions = []
    with torch.no_grad():
        for audio, orig_sr in zip(audio_batch, sample_rate_batch):
            # librosa.resample requires a floating-point ndarray; values coming
            # out of pandas may be lists or non-float arrays.
            waveform = np.asarray(audio, dtype=np.float32)
            if orig_sr != MODEL_SAMPLE_RATE:
                waveform = librosa.resample(
                    waveform, orig_sr=orig_sr, target_sr=MODEL_SAMPLE_RATE)
            model_inputs = processor(
                waveform, sampling_rate=MODEL_SAMPLE_RATE, return_tensors="pt"
            ).to(device)
            logits = model(**model_inputs).logits
            # Batch dimension is 1 here (single clip), hence the [0].
            predicted_ids = torch.argmax(logits, dim=-1)[0]
            transcriptions.append(processor.decode(predicted_ids))
    return transcriptions

# Transcribe every split of each selected SALT multispeaker language with the
# Sunbird MMS model and score the output against the reference text.
# Other SALT multispeaker languages: 'eng', 'nyn', 'lgg', 'teo', 'ach'.
languages = ['lug']

# Rows whose character edit distance between reference text and ASR output
# exceeds this value are counted as suspicious (possible audio/label mismatch).
SUSPICIOUS_EDIT_DISTANCE = 30

model_id = "Sunbird/sunbird-mms"
processor = transformers.AutoProcessor.from_pretrained(model_id)
model = transformers.Wav2Vec2ForCTC.from_pretrained(model_id)
# Fall back to CPU so the notebook still runs (slowly) without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model.to(device)

for language in languages:
    all_transcriptions[language] = {}

    config_name = f'multispeaker-{language}'
    # Point both the model's language adapter and the tokenizer vocabulary at
    # the current language before transcribing it.
    model.load_adapter(language)
    processor.tokenizer.set_target_lang(language)

    for split in ['dev', 'test', 'train']:
        print(f'dataset: {config_name}, split: {split}')

        ds = datasets.load_dataset('Sunbird/salt', config_name, split=split)
        df = ds.to_pandas()

        n = len(df)
        n_batches = math.ceil(n / BATCH_SIZE)

        transcriptions = []
        # The context manager closes the progress bar even if a batch raises.
        with tqdm(total=n_batches, desc="Transcribing") as pbar:
            for start in range(0, n, BATCH_SIZE):
                sample = df.iloc[start:start + BATCH_SIZE]
                transcriptions.extend(transcribe_batch(
                    sample["audio"].tolist(), sample["sample_rate"].tolist()))
                pbar.update()

        assert len(transcriptions) == n, 'expected one transcription per row'
        df["transcription"] = transcriptions

        # Character-level edit distance. Case is normalised on both sides so
        # capitalisation differences alone never count as errors.
        edit_distances = [editdistance.eval(ref.lower(), hyp.lower())
                          for ref, hyp in zip(df['text'], transcriptions)]
        df['edit_distance'] = edit_distances

        # Raw audio arrays are large (and likely not JSON-serialisable), so
        # drop them before the table is stored.
        df = df.drop(columns=['audio'])

        suspicious = np.where(np.array(edit_distances) > SUSPICIOUS_EDIT_DISTANCE)[0]
        print(f'Found {len(suspicious)} suspicious entries out of {len(df)} in {config_name}/{split}')

        all_transcriptions[language][split] = df.to_dict()

Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream

adapter.lug.safetensors:   0%|          | 0.00/9.04M [00:00<?, ?B/s]

dataset: multispeaker-lug, split: dev


Downloading data:   0%|          | 0.00/194M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/186M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/234M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/175M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/183M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/19.1M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/19.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5016 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/103 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/97 [00:00<?, ? examples/s]

Transcribing:   0%|          | 0/2 [00:00<?, ?it/s]

Found 1 suspicious entries out of 103 in multispeaker-lug/dev
dataset: multispeaker-lug, split: test


Transcribing:   0%|          | 0/2 [00:00<?, ?it/s]

Found 1 suspicious entries out of 97 in multispeaker-lug/test
dataset: multispeaker-lug, split: train


Transcribing:   0%|          | 0/79 [00:00<?, ?it/s]

Found 19 suspicious entries out of 5016 in multispeaker-lug/train


In [25]:
import json

# Persist every language/split result table to one JSON file for later review.
with open('SALT-transcriptions.json', 'w') as output_file:
    output_file.write(json.dumps(all_transcriptions))