In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['CUDA_VISIBLE_DEVICES'] = "3"
import json
import numpy as np
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor
from scipy.signal import resample
from tqdm import tqdm
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
import soundfile as sf
from transformers import Wav2Vec2ForCTC
from transformers import TrainingArguments, Trainer
from utils.clean_arabic import clean_arabic


In [2]:
torch.cuda.device_count()

1

In [3]:
train_inputs_path = "/mnt/nfs/stt_project/dataset/reupload/train/"
train_labels_path = "/mnt/nfs/dorten/cleaned_labels/train/"
train_dataset = []
errors = 0
for audio_file in os.listdir(train_inputs_path):
    audio_data, sample_rate = sf.read(os.path.join(train_inputs_path, audio_file))
    label_file = os.path.join(train_labels_path, audio_file.split('.')[0] + ".txt")
    try:
        with open(label_file, "r", encoding="utf-8-sig") as f:
            text = f.read().strip()
            text = clean_arabic(text)

        train_dataset.append(
            {"audio_data": audio_data, "sample_rate": sample_rate, "sentence": text}
        )
    except:
        print(f"Error openning {label_file}")
        errors += 1

Error openning /mnt/nfs/dorten/cleaned_labels/train/00080666-d88f-4107-91bd-21a417aad6a1.txt
Error openning /mnt/nfs/dorten/cleaned_labels/train/0027c38a-f0db-4cb9-bafc-27931e29c4e3.txt
Error openning /mnt/nfs/dorten/cleaned_labels/train/003a2894-a344-4b1f-a17f-b6436ee176e8.txt
Error openning /mnt/nfs/dorten/cleaned_labels/train/00403687-2817-4891-a240-c867bf819c7e.txt
Error openning /mnt/nfs/dorten/cleaned_labels/train/0040b773-61f0-4d12-b5b1-ff1f92d68463.txt
Error openning /mnt/nfs/dorten/cleaned_labels/train/004e4c8d-035d-4286-a983-190fcdd41185.txt
Error openning /mnt/nfs/dorten/cleaned_labels/train/004f9eda-9978-459d-bd02-a85d4e89a527.txt
Error openning /mnt/nfs/dorten/cleaned_labels/train/0059e2a3-7c70-4afa-9624-eb6120f76d29.txt
Error openning /mnt/nfs/dorten/cleaned_labels/train/0077663c-da83-4290-92c1-ee77f537b228.txt
Error openning /mnt/nfs/dorten/cleaned_labels/train/00830b0e-f77b-4234-b449-56cb0ce4659b.txt
Error openning /mnt/nfs/dorten/cleaned_labels/train/0086a288-98f1-4bfc

In [4]:
errors

7089

In [5]:
len(train_dataset)


28181

In [6]:
train_dataset[0]

{'audio_data': array([ 0.18359375, -0.4453125 , -0.4296875 , ...,  0.12890625,
         0.08007812,  0.11523438]),
 'sample_rate': 8000,
 'sentence': 'من المكان ارسالك م'}

In [7]:
test_inputs_path = "/mnt/nfs/stt_project/dataset/reupload/test/"
test_labels_path = "/mnt/nfs/stt_project/dataset/test-txt/"
test_dataset = []
errors = 0
for audio_file in os.listdir(test_inputs_path):
    audio_data, sample_rate = sf.read(os.path.join(test_inputs_path, audio_file))
    label_file = os.path.join(test_labels_path, audio_file.split('.')[0] + ".txt")
    try:
        with open(label_file, "r", encoding="utf-8-sig") as f:
            text = f.read().strip()
            text = clean_arabic(text)
        test_dataset.append(
            {"audio_data": audio_data, "sample_rate": sample_rate, "sentence": text}
        )
    except:
        print(f"Error openning {label_file}")
        errors += 1

Error openning /mnt/nfs/stt_project/dataset/test-txt/00a6d967-6af0-4018-bb31-6633be63dc1d.txt


In [8]:
# resample data - to remove once data is already in 16k sr
for record in train_dataset + test_dataset:
    data = record['audio_data']
    origin_sr = record['sample_rate']
    expected_sr = 16000
    if origin_sr == expected_sr:
        data_resampled = data
    else:
        data_resampled = resample(data, int(len(data) * expected_sr / origin_sr), axis=0)
    record['sample_rate'] = expected_sr
    record['audio_data'] = data_resampled

In [9]:
train_dataset[0]

{'audio_data': array([ 0.18359375, -0.01122654, -0.4453125 , ...,  0.11501621,
         0.11523437,  0.14891549]),
 'sample_rate': 16000,
 'sentence': 'من المكان ارسالك م'}

In [10]:
test_dataset[0]

{'audio_data': array([ 0.00186157,  0.06555176,  0.00747681, ...,  0.29241943,
        -0.09805298, -0.30621338]),
 'sample_rate': 16000,
 'sentence': 'جيد استلمت الو نسر اثنين ناصر ثلاث وعشرون'}

In [11]:
def extract_all_chars(record):
  all_text = " ".join(record["sentence"])
  vocab = list(set(all_text))
  return {"vocab": vocab, "all_text": [all_text]}

In [12]:
vocab_parts = []
for record in train_dataset + test_dataset:
    vocab_parts.append(extract_all_chars(record))

In [13]:
vocab_parts[0]

{'vocab': ['ن', ' ', 'ا', 'ر', 'ل', 'س', 'ك', 'م'],
 'all_text': ['م ن   ا ل م ك ا ن   ا ر س ا ل ك   م']}

In [14]:
vocab_list = []
for v in vocab_parts:
    vocab_list.extend(v['vocab'])

vocab_set = set(vocab_list)

In [15]:
vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
vocab_dict

{' ': 0,
 'A': 1,
 'B': 2,
 'D': 3,
 'E': 4,
 'F': 5,
 'G': 6,
 'J': 7,
 'K': 8,
 'M': 9,
 'N': 10,
 'O': 11,
 'P': 12,
 'R': 13,
 'S': 14,
 'T': 15,
 'X': 16,
 'Y': 17,
 'Z': 18,
 'a': 19,
 'd': 20,
 'e': 21,
 'g': 22,
 'h': 23,
 'i': 24,
 'k': 25,
 'l': 26,
 'm': 27,
 'n': 28,
 'o': 29,
 'p': 30,
 'r': 31,
 's': 32,
 't': 33,
 'u': 34,
 'v': 35,
 'w': 36,
 'y': 37,
 'א': 38,
 'ל': 39,
 'ם': 40,
 'ס': 41,
 'ء': 42,
 'ا': 43,
 'ب': 44,
 'ت': 45,
 'ث': 46,
 'ج': 47,
 'ح': 48,
 'خ': 49,
 'د': 50,
 'ذ': 51,
 'ر': 52,
 'ز': 53,
 'س': 54,
 'ش': 55,
 'ص': 56,
 'ض': 57,
 'ط': 58,
 'ظ': 59,
 'ع': 60,
 'غ': 61,
 'ف': 62,
 'ق': 63,
 'ك': 64,
 'ل': 65,
 'م': 66,
 'ن': 67,
 'ه': 68,
 'و': 69,
 'ي': 70,
 '٧': 71,
 '\u200f': 72}

In [16]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

In [17]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

75

In [18]:
target_lang="ara"

In [19]:
new_vocab_dict = {target_lang: vocab_dict}

In [20]:
with open('vocab.json', 'w') as vocab_file:
    json.dump(new_vocab_dict, vocab_file)

In [21]:
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|", target_lang=target_lang)

In [22]:
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

In [23]:
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [24]:
def prepare_record(record):
    datum = {}
    # compute log-Mel input features from input audio array 
    datum["input_values"] = feature_extractor(record["audio_data"], sampling_rate=record["sample_rate"]).input_values[0]

    # encode target text to label ids 
    datum["labels"] = tokenizer(record["sentence"]).input_ids
    return datum

In [25]:
train_prepared_records = []
for record in tqdm(train_dataset):
    train_prepared_records.append(prepare_record(record))

100%|██████████| 28181/28181 [00:20<00:00, 1407.95it/s]


In [26]:
test_prepared_records = []
for record in tqdm(test_dataset):
    test_prepared_records.append(prepare_record(record))

100%|██████████| 1348/1348 [00:00<00:00, 1351.84it/s]


In [27]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        model_input_name = self.processor.model_input_names[0]
        input_features = [{"input_values": feature[model_input_name]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [28]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [29]:
wer_metric = evaluate.load("wer")

In [30]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [33]:
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/mms-1b-all",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=True,
)
model.config.ctc_zero_infinity = True

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/mms-1b-all and are newly initialized because the shapes did not match:
- lm_head.bias: found shape torch.Size([154]) in the checkpoint and torch.Size([77]) in the model instantiated
- lm_head.weight: found shape torch.Size([154, 1280]) in the checkpoint and torch.Size([77, 1280]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
model.freeze_base_model()
model.init_adapter_layers()
adapter_weights = model._get_adapters()
for param in adapter_weights.values():
    param.requires_grad = True

In [33]:
training_args = TrainingArguments(
    output_dir="./models/mms-1b-all-adapters",  # change to a repo name of your choice
    group_by_length=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-3,
    warmup_steps=100,
    fp16=True,
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    num_train_epochs=4,
    save_steps=293,
    eval_steps=293,
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    save_total_limit=9
)

In [34]:
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=train_prepared_records,
    eval_dataset=test_prepared_records,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

In [35]:
torch.cuda.device_count()

1

In [36]:
trainer.train()

Step,Training Loss,Validation Loss,Wer
293,2.0239,1.356219,0.796817
586,1.7137,1.204746,0.741351
879,1.5866,1.132704,0.708487
1172,1.5057,1.075981,0.695111
1465,1.4497,1.036485,0.687154
1758,1.4228,1.009812,0.677006
2051,1.3844,0.982294,0.664437
2344,1.3484,0.964068,0.665129
2637,1.3114,0.961575,0.659363
2930,1.2812,0.942114,0.64322


### Continue training

In [31]:
model = Wav2Vec2ForCTC.from_pretrained(
    "models/mms-1b-all-adapters-cont-2/checkpoint-5274",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=True,
)
model.config.ctc_zero_infinity = True

In [32]:
model.freeze_base_model()
model.init_adapter_layers()
adapter_weights = model._get_adapters()
for param in adapter_weights.values():
    param.requires_grad = True

In [35]:
training_args = TrainingArguments(
    output_dir="./models/mms-1b-all-adapters-cont-3",  # change to a repo name of your choice
    group_by_length=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-3,
    warmup_steps=0,
    fp16=True,
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    num_train_epochs=9,
    save_steps=293,
    eval_steps=293,
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    save_total_limit=27
)

In [36]:
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=train_prepared_records,
    eval_dataset=test_prepared_records,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

In [37]:
trainer.train()

Step,Training Loss,Validation Loss,Wer
293,1.1841,0.880709,0.582911
586,1.1812,0.850915,0.575185
879,1.1721,0.851799,0.581873
1172,1.1356,0.841264,0.563307
1465,1.1261,0.834595,0.576684
1758,1.112,0.835884,0.565959
2051,1.0932,0.822016,0.568842
2344,1.0994,0.819848,0.570457
2637,1.0589,0.809865,0.549124
2930,1.0689,0.792242,0.546817


TrainOutput(global_step=7929, training_loss=1.098443360508488, metrics={'train_runtime': 13471.0841, 'train_samples_per_second': 18.828, 'train_steps_per_second': 0.589, 'total_flos': 7.691592727893066e+19, 'train_loss': 1.098443360508488, 'epoch': 9.0})