In [None]:
import os
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, AutoProcessor
from scipy.signal import resample
from tqdm import tqdm
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
import soundfile as sf
from transformers import Wav2Vec2ForCTC
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import re

In [2]:
torch.cuda.is_available()

  return torch._C._cuda_getDeviceCount() > 0


False

In [14]:
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/mms-1b-all", language="ara", task="transcribe")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/mms-1b-all")
processor = AutoProcessor.from_pretrained("facebook/mms-1b-all", language="ara", task="transcribe")

In [22]:
train_inputs_path = "/mnt/nfs/stt_project/dataset/reupload/train/"
train_labels_path = "/mnt/nfs/stt_project/dataset/train-txt/"
train_dataset = []
errors = 0
for audio_file in os.listdir(train_inputs_path):
    audio_data, sample_rate = sf.read(os.path.join(train_inputs_path, audio_file))
    label_file = os.path.join(train_labels_path, audio_file.split('.')[0] + ".txt")
    try:
        with open(label_file, "r", encoding="utf-8-sig") as f:
            text = f.read().strip()

        train_dataset.append(
            {"audio_data": audio_data, "sample_rate": sample_rate, "sentence": text}
        )
    except:
        print(f"Error openning {label_file}")
        errors += 1

Error openning /mnt/nfs/stt_project/dataset/train-txt/00080666-d88f-4107-91bd-21a417aad6a1.txt
Error openning /mnt/nfs/stt_project/dataset/train-txt/003a2894-a344-4b1f-a17f-b6436ee176e8.txt
Error openning /mnt/nfs/stt_project/dataset/train-txt/004e4c8d-035d-4286-a983-190fcdd41185.txt
Error openning /mnt/nfs/stt_project/dataset/train-txt/004f9eda-9978-459d-bd02-a85d4e89a527.txt
Error openning /mnt/nfs/stt_project/dataset/train-txt/0059e2a3-7c70-4afa-9624-eb6120f76d29.txt
Error openning /mnt/nfs/stt_project/dataset/train-txt/0077663c-da83-4290-92c1-ee77f537b228.txt
Error openning /mnt/nfs/stt_project/dataset/train-txt/00830b0e-f77b-4234-b449-56cb0ce4659b.txt
Error openning /mnt/nfs/stt_project/dataset/train-txt/0086a288-98f1-4bfc-bdf6-eb056c4a8052.txt
Error openning /mnt/nfs/stt_project/dataset/train-txt/008c5f45-aa06-42d1-9e24-137ee4e85fef.txt
Error openning /mnt/nfs/stt_project/dataset/train-txt/00961615-9a6b-464e-9914-50a174d90ccf.txt
Error openning /mnt/nfs/stt_project/dataset/train-

In [23]:
errors

4255

In [24]:
len(train_dataset)

31015

In [25]:
train_dataset[0]

{'audio_data': array([ 0.18359375, -0.4453125 , -0.4296875 , ...,  0.12890625,
         0.08007812,  0.11523438]),
 'sample_rate': 8000,
 'sentence': 'من المكان ارسالك م-'}

In [26]:
test_inputs_path = "/mnt/nfs/stt_project/dataset/reupload/test/"
test_labels_path = "/mnt/nfs/stt_project/dataset/test-txt/"
test_dataset = []
errors = 0
for audio_file in os.listdir(test_inputs_path):
    audio_data, sample_rate = sf.read(os.path.join(test_inputs_path, audio_file))
    label_file = os.path.join(test_labels_path, audio_file.split('.')[0] + ".txt")
    try:
        with open(label_file, "r", encoding="utf-8-sig") as f:
            text = f.read().strip()

        test_dataset.append(
            {"audio_data": audio_data, "sample_rate": sample_rate, "sentence": text}
        )
    except:
        print(f"Error openning {label_file}")
        errors += 1

Error openning /mnt/nfs/stt_project/dataset/test-txt/00a6d967-6af0-4018-bb31-6633be63dc1d.txt


In [27]:
# resample data - to remove once data is already in 16k sr
for record in train_dataset + test_dataset:
    data = record['audio_data']
    origin_sr = record['sample_rate']
    expected_sr = 16000
    data_resampled = resample(data, int(len(data) * expected_sr / origin_sr), axis=0)
    record['sample_rate'] = expected_sr
    record['audio_data'] = data_resampled

In [28]:
train_dataset[0]

{'audio_data': array([ 0.18359375, -0.01122654, -0.4453125 , ...,  0.11501621,
         0.11523437,  0.14891549]),
 'sample_rate': 16000,
 'sentence': 'من المكان ارسالك م-'}

In [29]:
test_dataset[0]

{'audio_data': array([ 0.00186157,  0.06555176,  0.00747681, ...,  0.29241943,
        -0.09805298, -0.30621338]),
 'sample_rate': 16000,
 'sentence': 'جيد استلمت الو نسر اثنين ناصر ثلاث وعشرون'}

In [30]:
def prepare_record(record):
    datum = {}
    # compute log-Mel input features from input audio array 
    datum["input_values"] = feature_extractor(record["audio_data"], sampling_rate=record["sample_rate"]).input_values[0]

    # encode target text to label ids 
    datum["labels"] = tokenizer(record["sentence"]).input_ids
    return datum

In [31]:
train_prepared_records = []
for record in tqdm(train_dataset):
    train_prepared_records.append(prepare_record(record))

100%|██████████| 31015/31015 [00:21<00:00, 1415.59it/s]


In [32]:
test_prepared_records = []
for record in tqdm(test_dataset):
    test_prepared_records.append(prepare_record(record))

100%|██████████| 1348/1348 [00:00<00:00, 1369.33it/s]


In [33]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        model_input_name = self.processor.model_input_names[0]
        input_features = [{"input_values": feature[model_input_name]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [34]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [36]:
metric = evaluate.load("wer")

In [37]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [38]:
model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")

Downloading (…)lve/main/config.json: 100%|██████████| 2.04k/2.04k [00:00<00:00, 182kB/s]
Downloading model.safetensors: 100%|██████████| 3.86G/3.86G [00:55<00:00, 70.0MB/s]
Some weights of the model checkpoint at facebook/mms-1b-all were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/mms-1b-all and are newly initialized: ['wav2vec2.encoder.pos_conv_e

In [54]:
model.init_adapter_layers()
adapter_weights = model._get_adapters()
for param in adapter_weights.values():
    param.requires_grad = True

In [67]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./mms-1b-all-adapter",  # change to a repo name of your choice
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    group_by_length=False,
    warmup_steps=100,
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    predict_with_generate=True,
    num_train_epochs=3,
    save_steps=2584,
    eval_steps=2584,
    logging_steps=25,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

In [68]:
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_prepared_records,
    eval_dataset=test_prepared_records,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

In [69]:
len(trainer.get_train_dataloader())

7754

False

In [70]:
trainer.train()



Step,Training Loss,Validation Loss


KeyboardInterrupt: 