In [1]:
import os
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
import librosa
from scipy.signal import resample
from tqdm import tqdm
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
from transformers import WhisperForConditionalGeneration
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

In [2]:
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Arabic", task="transcribe")
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Arabic", task="transcribe")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
inputs_path = "../dataset/test"
labels_path = "../dataset/test-txt"
dataset = []
errors = 0
for audio_file in os.listdir(inputs_path):
    audio_data, sample_rate = librosa.load(os.path.join(inputs_path, audio_file))
    label_file = os.path.join(labels_path, audio_file.split('.')[0] + ".txt")
    try:
        with open(label_file, "r", encoding="utf-8-sig") as f:
            text = f.read().strip()

        dataset.append(
            {"audio_data": audio_data, "sample_rate": sample_rate, "sentence": text}
        )
    except:
        print(f"Error openning {label_file}")
        errors += 1

Error openning ../dataset/test-txt/00a6d967-6af0-4018-bb31-6633be63dc1d.txt


In [4]:
errors

1

In [5]:
# resample data - to remove once data is already in 16k sr
for record in dataset:
    data = record['audio_data']
    origin_sr = record['sample_rate']
    expected_sr = 16000
    data_resampled = resample(data, int(len(data) * expected_sr / origin_sr), axis=0)
    record['sample_rate'] = expected_sr
    record['audio_data'] = data_resampled

In [6]:
dataset[0]

{'audio_data': array([-0.03481539,  0.02195274,  0.04504649, ...,  0.17630008,
         0.14643314,  0.03439737], dtype=float32),
 'sample_rate': 16000,
 'sentence': 'جيد استلمت لك افضل التحيات'}

In [7]:
def prepare_record(record):
    # compute log-Mel input features from input audio array 
    record["input_features"] = feature_extractor(record["audio_data"], sampling_rate=record["sample_rate"]).input_features[0]

    # encode target text to label ids 
    record["labels"] = tokenizer(record["sentence"]).input_ids
    return record

In [8]:
prepared_records = []
for record in tqdm(dataset):
    prepared_records.append(prepare_record(record))

100%|██████████| 1348/1348 [01:11<00:00, 18.75it/s]


In [9]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [10]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [11]:
metric = evaluate.load("wer")

In [12]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [13]:
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

In [14]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

In [15]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-ar",  # change to a repo name of your choice
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=100,
    max_steps=4000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=1,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

In [16]:
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=prepared_records[:1000],
    eval_dataset=prepared_records[1000:],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

In [None]:
trainer.train()

`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...
