In [69]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import os
from pathlib import Path
import pandas as pd
from datasets import load_dataset, Dataset, Audio,concatenate_datasets
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import transformers
import evaluate
import torchaudio
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift
import numpy as np
import pandas as pd

device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [18]:
# Interactive Hugging Face Hub login widget.
# NOTE(review): presumably required by the push_to_hub() calls at the end of the
# notebook — confirm, and consider moving this import into the top import cell.
from huggingface_hub import notebook_login

notebook_login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## Importing data

In [2]:
# Code generated by Gemini to get the paths

def get_audio_file_paths(base_path_str: str) -> dict:
    """Collect audio file paths grouped by language folder and sub-folder.

    The expected layout is
    ``<base_path_str>/processed data/<language>/<sub-folder>/<audio files>``.

    Args:
        base_path_str: Root data directory that contains ``processed data``.

    Returns:
        ``{language: {sub_folder: [absolute_path, ...]}}`` where every path is a
        resolved string. Folders and files are visited in sorted order, so the
        result is identical on every run and platform (callers that split a file
        list by position rely on a stable order). An empty dict is returned when
        ``processed data`` does not exist.
    """
    processed_dir = Path(base_path_str) / "processed data"
    audio_extensions = {'.wav', '.mp3', '.flac', '.m4a', '.ogg', '.opus'}

    audio_paths = {}
    if not processed_dir.is_dir():
        return audio_paths

    # sorted(): directory iteration order is filesystem-dependent.
    for lang_dir in sorted(processed_dir.iterdir()):
        if not lang_dir.is_dir():
            continue

        per_folder = {}
        for sub_dir in sorted(lang_dir.iterdir()):
            if not sub_dir.is_dir():
                continue

            # Keep only regular files with a known audio extension
            # (case-insensitive); plain lexicographic order of the full path.
            per_folder[sub_dir.name] = sorted(
                str(f.resolve()) for f in sub_dir.glob('*')
                if f.is_file() and f.suffix.lower() in audio_extensions
            )

        audio_paths[lang_dir.name] = per_folder

    return audio_paths

In [3]:

# Reference transcriptions for the three scripts (a, b, c), one English and one
# Arabic version each. The Arabic versions mix Arabic words, English words,
# Western digits and Arabic-Indic digits.
# NOTE(review): presumably these are the scripts the speakers read aloud — confirm.
text_a_english = "zero five twelve ninety-nine one hundred and five 2 plus 7 18 minus 4 6 times 3 20 divided by 5 ten plus thirty minus eight negative fifteen plus nine three to the power of two square root of sixteen clear equals repeat"
text_a_arabic = "احسب خمسة زائد اثنين عشرة ناقص ثلاثة ستة ضرب أربعة عشرون قسمة خمسة سالب سبعة زائد واحد خمسة أس اثنين الجذر التربيعي لأربعة وعشرين امسح تأكيد أعِد calculate 37 plus خمسة اطرح twelve من عشرة اضرب ثلاثة في twenty eighty divided by ثمانية اجمع ١٢ و ١٣ سبعة زائد ١٩ 45 minus تسعة 3.5 plus اثنين ونصف واحد فاصلة خمسة ضرب أربعة مية واثنا عشر ناقص ستة 1000 minus 250 999 plus 1 قل اللون: أزرق"

text_b_english = "one eight seventeen sixty-four one hundred and twenty 4 plus 9 22 minus 7 9 times 5 81 divided by 9 thirty plus fifty negative six minus ten plus three two to the power of five cube root of twenty-seven start stop undo"
text_b_arabic = "اجمع سبعة و تلاتين مع 12 خمسة وأربعون ناقص عشرين تسعة ضرب ستة أربعة وستون قسمة ثمانية سالب ثلاثة زائد خمسة اثنان أس ثلاثة الجذر التكعيبي لسبعة وعشرين امسح الشاشة تم كرر آخر عملية calculate twelve times خمسة اقسم 36 على ستة اطرح خمسة من twenty fifty plus سبعة اجمع ١٠٠ و ٢٥ مئتان ناقص ٩٩ 14 minus أربعة اثنين فاصلة خمسة زائد 0.5 7.25 divided by خمسة أربع مية وخمسة ناقص عشرة 500 plus 500 1234 minus 234 قل اللون: أخضر"

text_c_english = "two nine eleven seventy-three two hundred and three 8 plus 6 40 minus 12 7 times 7 90 divided by 10 twenty plus fifteen negative nine minus twenty plus eight five to the power of three square root of one hundred confirm repeat last slower please"
text_c_arabic = "احسب 23 زائد 15 سبعة ناقص اثنين ثلاثة ضرب تسعة ستة وثلاثون قسمة أربعة سالب اثنا عشر زائد عشرة عشرة أس اثنين الجذر التربيعي لتسعة افتح رجوع أعد الحساب calculate twenty minus ثلاثة اجمع five و خمسة اضرب 8 في twenty-one thirty divided by ثلاثة اجمع ٧ و ١١ أربعون ناقص ١٨ 16 plus سبعة واحد فاصلة خمسة ناقص 0.25 2.2 times اثنين تسعمية وتسعة وتسعين زائد واحد 1500 minus 300 333 plus 667 قل اللون: أحمر"

# Digit-normalised variants of the Arabic scripts: number words are written as
# Western digits while the operator/command words are kept.
# NOTE(review): text_a_digits still spells "eighty" as a word, whereas
# text_b_digits / text_c_digits write the corresponding English number words
# ("fifty", "thirty") as "50" / "30" — confirm whether that is intentional.
text_a_digits = "احسب 5 زائد 2 10 ناقص 3 6 ضرب 4 20 قسمة 5 سالب 7 زائد 1 5 أس 2 الجذر التربيعي لـ 24 امسح تأكيد أعِد calculate 37 plus 5 اطرح 12 من 10 اضرب 3 في 20 eighty divided by 8 اجمع 12 و 13 7 زائد 19 45 minus 9 3.5 plus 2.5 1.5 ضرب 4 112 ناقص 6 1000 minus 250 999 plus 1 قل اللون: أزرق"
text_b_digits = "اجمع 37 و 12 45 ناقص 20 9 ضرب 6 64 قسمة 8 سالب 3 زائد 5 2 أس 3 الجذر التكعيبي لـ 27 امسح الشاشة تم كرر آخر عملية calculate 12 times 5 اقسم 36 على 6 اطرح 5 من 20 50 plus 7 اجمع 100 و 25 200 ناقص 99 14 minus 4 2.5 زائد 0.5 7.25 divided by 5 405 ناقص 10 500 plus 500 1234 minus 234 قل اللون: أخضر"
text_c_digits = "احسب 23 زائد 15 7 ناقص 2 3 ضرب 9 36 قسمة 4 سالب 12 زائد 10 10 أس 2 الجذر التربيعي لـ 9 افتح رجوع أعد الحساب calculate 20 minus 3 اجمع 5 و 5 اضرب 8 في 21 30 divided by 3 اجمع 7 و 11 40 ناقص 18 16 plus 7 1.5 ناقص 0.25 2.2 times 2 999 زائد 1 1500 minus 300 333 plus 667 قل اللون: أحمر"


In [9]:
# Machine-specific absolute path to the data folder; adjust when running elsewhere.
data_path = r"C:\Users\lucar-work\Documents\GitHub\whisper-math\data"

# Split the discovered files by language folder.
all_files = get_audio_file_paths(data_path)
arabic_files, english_files = all_files['arabic'], all_files['english']

In [10]:
# Script label ('A' / 'B' / 'C') -> reference transcription, one mapping per
# language plus the digit-normalised Arabic variant.
transcriptions_arabic = dict(A=text_a_arabic, B=text_b_arabic, C=text_c_arabic)

transcriptions_english = dict(A=text_a_english, B=text_b_english, C=text_c_english)

transcriptions_arabic_digits = dict(A=text_a_digits, B=text_b_digits, C=text_c_digits)

In [None]:
def dataset_g(transcriptions, digits_transcription, files, language: str):
    """Build a DataFrame that pairs every audio file with a transcription.

    Args:
        transcriptions: Mapping of label -> transcription text.
        digits_transcription: Mapping of label -> digit-form transcription, or
            ``None``. Only consulted when ``language`` is Arabic; when it is
            ``None`` every file gets the text from ``transcriptions``.
        files: Mapping of label -> list of audio file paths.
        language: Language name, stored as-is in the ``Language`` column
            (compared case-insensitively against ``"arabic"``).

    Returns:
        DataFrame with columns ``audio``, ``transcription`` and ``Language``,
        one row per file in the iteration order of ``files``. For Arabic, the
        second half of each label's file list (index >= n // 2) receives the
        digit-form transcription. A label missing from the mapping in use
        yields an empty transcription. The columns are present even when
        ``files`` is empty.
    """
    columns = ['audio', 'transcription', 'Language']
    # Guarding on None avoids an AttributeError when an Arabic dataset is
    # built without a digit-form mapping.
    use_digits = language.lower() == "arabic" and digits_transcription is not None

    rows = []
    for label, file_list in files.items():
        n = len(file_list)
        # Index from which the digit-form text is used; n means "never".
        digits_from = n // 2 if use_digits else n
        for i, file_path in enumerate(file_list):
            if i >= digits_from:
                text = digits_transcription.get(label, "")
            else:
                text = transcriptions.get(label, "")
            rows.append({
                'audio': file_path,
                'transcription': text,
                'Language': language
            })

    # Explicit columns keep the schema stable for an empty input.
    return pd.DataFrame(rows, columns=columns)

# Arabic rows use the digit-form text for the second half of each label's files;
# English has no digit-form mapping.
df_arabic = dataset_g(
    transcriptions=transcriptions_arabic,
    digits_transcription=transcriptions_arabic_digits,
    files=arabic_files,
    language='arabic',
)
df_english = dataset_g(
    transcriptions=transcriptions_english,
    digits_transcription=None,
    files=english_files,
    language='english',
)


In [56]:
# Preview only the first rows instead of rendering the whole frame.
df_arabic.head(10)

Unnamed: 0,audio,transcription,Language
0,C:\Users\lucar-work\Documents\GitHub\whisper-m...,احسب خمسة زائد اثنين عشرة ناقص ثلاثة ستة ضرب أ...,arabic
1,C:\Users\lucar-work\Documents\GitHub\whisper-m...,احسب خمسة زائد اثنين عشرة ناقص ثلاثة ستة ضرب أ...,arabic
2,C:\Users\lucar-work\Documents\GitHub\whisper-m...,احسب خمسة زائد اثنين عشرة ناقص ثلاثة ستة ضرب أ...,arabic
3,C:\Users\lucar-work\Documents\GitHub\whisper-m...,احسب خمسة زائد اثنين عشرة ناقص ثلاثة ستة ضرب أ...,arabic
4,C:\Users\lucar-work\Documents\GitHub\whisper-m...,احسب خمسة زائد اثنين عشرة ناقص ثلاثة ستة ضرب أ...,arabic
5,C:\Users\lucar-work\Documents\GitHub\whisper-m...,احسب خمسة زائد اثنين عشرة ناقص ثلاثة ستة ضرب أ...,arabic
6,C:\Users\lucar-work\Documents\GitHub\whisper-m...,احسب خمسة زائد اثنين عشرة ناقص ثلاثة ستة ضرب أ...,arabic
7,C:\Users\lucar-work\Documents\GitHub\whisper-m...,احسب خمسة زائد اثنين عشرة ناقص ثلاثة ستة ضرب أ...,arabic
8,C:\Users\lucar-work\Documents\GitHub\whisper-m...,احسب 5 زائد 2 10 ناقص 3 6 ضرب 4 20 قسمة 5 سالب...,arabic
9,C:\Users\lucar-work\Documents\GitHub\whisper-m...,احسب 5 زائد 2 10 ناقص 3 6 ضرب 4 20 قسمة 5 سالب...,arabic


In [57]:
def generate_audio_dataset(df_arabic, df_english, augment_factor=1):
    """Build the combined audio dataset, with augmented copies of the Arabic rows.

    Args:
        df_arabic: DataFrame of Arabic samples; its ``audio`` column is cast to
            an ``Audio`` feature at 16 kHz.
        df_english: DataFrame of English samples, cast the same way.
        augment_factor: Number of augmented copies of the Arabic dataset to append.

    Returns:
        The concatenation of the Arabic dataset, the English dataset and
        ``augment_factor`` augmented Arabic datasets, in that order. Only the
        Arabic dataset is augmented; the English rows appear once.
    """
    sample_rate = 16000

    # Augmentation chain; each transform is applied with probability 0.3.
    augment = Compose([
        AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.005, p=0.3),
        TimeStretch(min_rate=0.95, max_rate=1.05, p=0.3),
        PitchShift(min_semitones=-1, max_semitones=1, p=0.3),
    ])

    # Base datasets (lazy load).
    dataset_arabic = Dataset.from_pandas(df_arabic).cast_column(
        "audio", Audio(sampling_rate=sample_rate)
    )
    dataset_english = Dataset.from_pandas(df_english).cast_column(
        "audio", Audio(sampling_rate=sample_rate)
    )

    def augment_audio(batch):
        """Swap the item's waveform for an augmented one; non-ndarray audio is left untouched."""
        waveform = batch["audio"]["array"]
        if not isinstance(waveform, np.ndarray):
            return batch
        batch["audio"] = {
            "array": augment(samples=waveform, sample_rate=sample_rate),
            "sampling_rate": sample_rate,
        }
        return batch

    # One independently augmented copy of the Arabic dataset per requested factor.
    augmented_copies = [dataset_arabic.map(augment_audio) for _ in range(augment_factor)]

    # Originals first, augmented copies after.
    return concatenate_datasets([dataset_arabic, dataset_english, *augmented_copies])


## Generating dataset

In [58]:
# Originals for both languages plus two augmented copies of the Arabic rows.
df_final = generate_audio_dataset(
    df_arabic=df_arabic,
    df_english=df_english,
    augment_factor=2,
)
df_final

Map:   0%|          | 0/40 [00:00<?, ? examples/s]

Map:   0%|          | 0/40 [00:00<?, ? examples/s]

Dataset({
    features: ['audio', 'transcription', 'Language'],
    num_rows: 160
})

In [13]:
# Show the dataset summary (features and row count).
df_final

Dataset({
    features: ['audio', 'transcription', 'Language'],
    num_rows: 160
})

## Fine-Tuning

### Processing dataset

In [59]:
# Hold out 25% of the rows for evaluation; the fixed seed makes the split
# (and therefore the reported WER) reproducible across runs.
# NOTE(review): the augmented copies are created before this split, so an
# augmented version of a training clip can end up in the test split — confirm
# that this is acceptable for the reported metric.
df_final = df_final.train_test_split(test_size=0.25, seed=42)

In [60]:
# Checkpoint used for both the processor and the model.
model_name = "openai/whisper-medium"

# The processor provides the feature extractor and tokenizer used in preprocessing.
processor = WhisperProcessor.from_pretrained(model_name, task="transcribe")

model = WhisperForConditionalGeneration.from_pretrained(model_name)

#model.freeze_encoder() # because our dataset is small

# Disable the fixed (forced) language — very important
model.config.forced_decoder_ids = None

preprocessor_config.json: 0.00B [00:00, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/3.06G [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

In [61]:
def preprocess_function(batch):
    """Turn one dataset example into model inputs.

    Args:
        batch: A single example with an ``audio`` entry (providing ``array``
            and ``sampling_rate``) and a ``transcription`` string.

    Returns:
        The same example with two added keys: ``input_features`` (first item of
        the feature extractor output) and ``labels`` (token ids of the
        transcription).
    """
    audio = batch["audio"]

    # Pass the sample's own sampling rate instead of assuming 16 kHz, so the
    # feature extractor sees the real rate of the waveform it is given.
    batch["input_features"] = processor.feature_extractor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
    ).input_features[0]

    batch["labels"] = processor.tokenizer(batch["transcription"]).input_ids
    return batch

In [62]:
# Featurize both splits and drop the original columns so only the fields
# produced by preprocess_function remain.
dataset = df_final.map(
    preprocess_function,
    remove_columns=df_final["train"].column_names,
)

Map:   0%|          | 0/120 [00:00<?, ? examples/s]

Map:   0%|          | 0/40 [00:00<?, ? examples/s]

In [63]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech seq2seq examples into one padded batch of tensors.

    Audio features and label ids are padded separately (by the processor's
    feature extractor and tokenizer respectively), label padding is replaced
    by -100, and a leading start token shared by every label row is removed.
    """

    # Object exposing ``feature_extractor.pad`` and ``tokenizer.pad``.
    processor: Any
    # Token id that is stripped from the front of the labels when every row starts with it.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Inputs and labels have different lengths and padding rules, so they
        # are gathered into two separate lists.
        audio_items = []
        label_items = []
        for example in features:
            audio_items.append({"input_features": example["input_features"]})
            label_items.append({"input_ids": example["labels"]})

        # Audio side: simply returned as torch tensors.
        batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")

        # Label side: pad to the longest sequence in the batch.
        padded_labels = self.processor.tokenizer.pad(label_items, return_tensors="pt")

        # Positions outside the attention mask become -100 so they are ignored by the loss.
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # If tokenization already prepended the start token to every row,
        # drop it here because it is appended again later.
        every_row_starts_with_bos = (labels[:, 0] == self.decoder_start_token_id).all().cpu().item()
        batch["labels"] = labels[:, 1:] if every_row_starts_with_bos else labels

        return batch


In [64]:
# Collator that pads audio features and labels; the start-token id comes from the model config.
start_token_id = model.config.decoder_start_token_id
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor, start_token_id)


### Defining evaluation metrics

In [65]:
metric = evaluate.load("wer")

def compute_metrics(pred):
    """Compute the word error rate (in percent) for one evaluation pass.

    Args:
        pred: Prediction object exposing ``predictions`` (generated token ids)
            and ``label_ids`` (reference token ids, with -100 at ignored positions).

    Returns:
        ``{"wer": value}`` where ``value`` is 100 times the metric result.
    """
    pred_ids = pred.predictions
    pad_token_id = processor.tokenizer.pad_token_id

    # Replace -100 with the pad token id so the labels can be decoded. np.where
    # builds a new array, leaving pred.label_ids untouched for any later consumer.
    label_ids = np.where(pred.label_ids == -100, pad_token_id, pred.label_ids)

    # Decode both sides to plain text, dropping special tokens.
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}


Downloading builder script: 0.00B [00:00, ?B/s]

### Training

In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-medium-finetuned",

    per_device_train_batch_size=2,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    num_train_epochs=10,  # best number of epochs!
    warmup_steps=30,

    gradient_checkpointing=False,
    # bf16=True, -> turn this line on if you have a GPU that supports it
    # Mixed precision only when a CUDA device is present, so the cell also runs on CPU.
    # (The original line was missing its trailing comma, which is a SyntaxError.)
    fp16=torch.cuda.is_available(),

    eval_strategy="epoch",
    save_strategy="best",
    logging_strategy="steps",
    logging_steps=10,

    predict_with_generate=True,
    # NOTE(review): the reference transcriptions are long; confirm that 150
    # generated tokens are enough, otherwise WER is computed on truncated output.
    generation_max_length=150,

    dataloader_num_workers=0,
    # Keep the checkpoint with the lowest WER.
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # NOTE(review): newer transformers releases may expect `processing_class`
    # instead of `tokenizer` — confirm against the installed version.
    tokenizer=processor.feature_extractor,
    # Stop after 2 evaluations without any improvement of the tracked metric.
    callbacks=[
        transformers.EarlyStoppingCallback(
            early_stopping_patience=2,
            early_stopping_threshold=0.0,
        )
    ],
)

trainer.train()


  trainer = Seq2SeqTrainer(
You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


## Testing


In [46]:
# Load one Arabic clip for a manual check. The result is indexed with [0] in the
# next cell — presumably a (waveform, sample_rate) pair; confirm against torchaudio.load.
# NOTE(review): machine-specific absolute path.
audio_arabic = torchaudio.load(r"C:\Users\lucar-work\Documents\GitHub\whisper-math\data\processed data\arabic\C\arabic 6.wav")

In [47]:
# Resample to 16 kHz from the rate reported by torchaudio.load rather than
# assuming the file is 48 kHz.
waveform, source_rate = audio_arabic
resampler = torchaudio.transforms.Resample(orig_freq=source_rate, new_freq=16000)
audio_arabic = resampler(waveform)

In [48]:
# Convert the resampled tensor to a NumPy array before passing it to the processor.
audio_arabic = audio_arabic.numpy()

In [49]:
# Run the manual check on CPU; the input features built in the inference cell are not moved to a GPU.
model.to(device='cpu')

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 1024, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(1024, 1024, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 1024)
      (layers): ModuleList(
        (0-23): 24 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias

In [50]:
# Decoder prompt ids that force Arabic transcription during generation.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="ar", task="transcribe")

# Model input features for the clip (row 0 of the array — presumably the first channel).
input_features = processor(audio_arabic[0], sampling_rate=16000, return_tensors="pt").input_features

# Generate token ids.
predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)

# Decode the token ids to text without special tokens. (A first decode that kept
# the special tokens was overwritten immediately and has been removed.)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

In [51]:
# Show the decoded transcription of the test clip.
transcription

['احسب 23 زائد 15 سبعة ناقص اثنين ثلاثة ضرب تسعة ستة وثلاثون قسمة أربعة سالب اثنا عشر زائد عشرة عشرة أس اثنين الجذر التربيعي لتسعة افتح رجوع أعد الحساب calculate twenty minus ثلاثة اجمع five و خمسة اضرب 8 في twenty-one thirty divided by ثلاثة اجمع ٧ و ١١ أربعون ناقص ١٨ 16 plus سبعة واحد فاصلة خمسة ناقص 0.25 2.2 times اثنين تسعمية وتسعة وتسعين زائد واحد 1500 minus 300 333 plus 667 قل اللون أحمر']

In [None]:
# NOTE(review): this re-assigns text_c_arabic, which is already defined near the
# top of the notebook — presumably kept here only to compare against the model
# output above; confirm and consider removing the duplicate.
text_c_arabic = "احسب 23 زائد 15 سبعة ناقص اثنين ثلاثة ضرب تسعة ستة وثلاثون قسمة أربعة سالب اثنا عشر زائد عشرة عشرة أس اثنين الجذر التربيعي لتسعة افتح رجوع أعد الحساب calculate twenty minus ثلاثة اجمع five و خمسة اضرب 8 في twenty-one thirty divided by ثلاثة اجمع ٧ و ١١ أربعون ناقص ١٨ 16 plus سبعة واحد فاصلة خمسة ناقص 0.25 2.2 times اثنين تسعمية وتسعة وتسعين زائد واحد 1500 minus 300 333 plus 667 قل اللون: أحمر"

In [26]:
# Upload the trainer's model weights and training arguments to the Hugging Face Hub.
trainer.push_to_hub()


Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  ...-medium-finetuned/model.safetensors:   0%|          | 99.6kB / 3.06GB            

  ...-medium-finetuned/training_args.bin:  12%|#2        |   722B / 5.91kB            

CommitInfo(commit_url='https://huggingface.co/manushya-ai/whisper-medium-finetuned/commit/cc031347d094f55501a6ff094ee522777291f6bf', commit_message='End of training', commit_description='', oid='cc031347d094f55501a6ff094ee522777291f6bf', pr_url=None, repo_url=RepoUrl('https://huggingface.co/manushya-ai/whisper-medium-finetuned', endpoint='https://huggingface.co', repo_type='model', repo_id='manushya-ai/whisper-medium-finetuned'), pr_revision=None, pr_num=None)

In [28]:
# Upload the processor to the given Hub repository (the same repo id that appears
# in the model upload's CommitInfo output).
processor.push_to_hub(repo_id="manushya-ai/whisper-medium-finetuned")

README.md: 0.00B [00:00, ?B/s]

CommitInfo(commit_url='https://huggingface.co/manushya-ai/whisper-medium-finetuned/commit/fc7b0397be7e88803eed0565b4a39ee6f890926d', commit_message='Upload processor', commit_description='', oid='fc7b0397be7e88803eed0565b4a39ee6f890926d', pr_url=None, repo_url=RepoUrl('https://huggingface.co/manushya-ai/whisper-medium-finetuned', endpoint='https://huggingface.co', repo_type='model', repo_id='manushya-ai/whisper-medium-finetuned'), pr_revision=None, pr_num=None)