<a href="https://colab.research.google.com/github/BrunoSader/WhisperTraining/blob/main/fine_tune_whisper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers

In this Colab, we present a step-by-step guide on how to fine-tune Whisper
for any multilingual ASR dataset using Hugging Face 🤗 Transformers. This is a
more "hands-on" version of the accompanying [blog post](https://huggingface.co/blog/fine-tune-whisper).
For a more in-depth explanation of Whisper, the Common Voice dataset and the theory behind fine-tuning, the reader is advised to refer to the blog post.

## Introduction

Whisper is a pre-trained model for automatic speech recognition (ASR)
published in [September 2022](https://openai.com/blog/whisper/) by the authors
Alec Radford et al. from OpenAI. Unlike many of its predecessors, such as
[Wav2Vec 2.0](https://arxiv.org/abs/2006.11477), which are pre-trained
on un-labelled audio data, Whisper is pre-trained on a vast quantity of
**labelled** audio-transcription data, 680,000 hours to be precise.
This is an order of magnitude more data than the un-labelled audio data used
to train Wav2Vec 2.0 (60,000 hours). What is more, 117,000 hours of this
pre-training data is multilingual ASR data. This results in checkpoints
that can be applied to over 96 languages, many of which are considered
_low-resource_.

When scaled to 680,000 hours of labelled pre-training data, Whisper models
demonstrate a strong ability to generalise to many datasets and domains.
The pre-trained checkpoints achieve competitive results to state-of-the-art
ASR systems, with near 3% word error rate (WER) on the test-clean subset of
LibriSpeech ASR and a new state-of-the-art on TED-LIUM with 4.7% WER (_c.f._
Table 8 of the [Whisper paper](https://cdn.openai.com/papers/whisper.pdf)).
The extensive multilingual ASR knowledge acquired by Whisper during pre-training
can be leveraged for other low-resource languages; through fine-tuning, the
pre-trained checkpoints can be adapted for specific datasets and languages
to further improve upon these results. We'll show just how Whisper can be fine-tuned
for low-resource languages in this Colab.

<figure>
<img src="https://raw.githubusercontent.com/sanchit-gandhi/notebooks/main/whisper_architecture.svg" alt="Trulli" style="width:100%">
<figcaption align = "center"><b>Figure 1:</b> Whisper model. The architecture
follows the standard Transformer-based encoder-decoder model. A
log-Mel spectrogram is input to the encoder. The last encoder
hidden states are input to the decoder via cross-attention mechanisms. The
decoder autoregressively predicts text tokens, jointly conditional on the
encoder hidden states and previously predicted tokens. Figure source:
<a href="https://openai.com/blog/whisper/">OpenAI Whisper Blog</a>.</figcaption>
</figure>

The Whisper checkpoints come in five configurations of varying model sizes.
The smallest four are trained on either English-only or multilingual data.
The largest checkpoints are multilingual only. All 11 of the pre-trained checkpoints
are available on the [Hugging Face Hub](https://huggingface.co/models?search=openai/whisper). The
checkpoints are summarised in the following table with links to the models on the Hub:

| Size     | Layers | Width | Heads | Parameters | English-only                                         | Multilingual                                        |
|----------|--------|-------|-------|------------|------------------------------------------------------|-----------------------------------------------------|
| tiny     | 4      | 384   | 6     | 39 M       | [✓](https://huggingface.co/openai/whisper-tiny.en)   | [✓](https://huggingface.co/openai/whisper-tiny.)    |
| base     | 6      | 512   | 8     | 74 M       | [✓](https://huggingface.co/openai/whisper-base.en)   | [✓](https://huggingface.co/openai/whisper-base)     |
| small    | 12     | 768   | 12    | 244 M      | [✓](https://huggingface.co/openai/whisper-small.en)  | [✓](https://huggingface.co/openai/whisper-small)    |
| medium   | 24     | 1024  | 16    | 769 M      | [✓](https://huggingface.co/openai/whisper-medium.en) | [✓](https://huggingface.co/openai/whisper-medium)   |
| large    | 32     | 1280  | 20    | 1550 M     | x                                                    | [✓](https://huggingface.co/openai/whisper-large)    |
| large-v2 | 32     | 1280  | 20    | 1550 M     | x                                                    | [✓](https://huggingface.co/openai/whisper-large-v2) |
| large-v3 | 32     | 1280  | 20    | 1550 M     | x                                                    | [✓](https://huggingface.co/openai/whisper-large-v3) |


For demonstration purposes, we'll fine-tune the multilingual version of the
[`"small"`](https://huggingface.co/openai/whisper-small) checkpoint with 244M params (~= 1GB).
As for our data, we'll train and evaluate our system on a low-resource language
taken from the [Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0)
dataset. We'll show that with as little as 8 hours of fine-tuning data, we can achieve
strong performance in this language.

------------------------------------------------------------------------

\\({}^1\\) The name Whisper follows from the acronym “WSPSR”, which stands for “Web-scale Supervised Pre-training for Speech Recognition”.

## Prepare Environment

First of all, let's try to secure a decent GPU for our Colab! Unfortunately, it's becoming much harder to get access to a good GPU with the free version of Google Colab. However, with Google Colab Pro one should have no issues in being allocated a V100 or P100 GPU.

To get a GPU, click _Runtime_ -> _Change runtime type_, then change _Hardware accelerator_ from _CPU_ to one of the available GPUs, e.g. _T4_ (or better if you have one available). Next, click `Connect T4` in the top right-hand corner of your screen (or `Connect {V100, A100}` if you selected a different GPU).

We can verify that we've been assigned a GPU and view its specifications:

In [2]:
# Colab keep-alive (prevents timeout)
import IPython
from google.colab import drive

# JavaScript to click in the background
js_code = '''
function ClickConnect(){
    console.log("Keeping Colab alive...");
    document.querySelector("colab-toolbar-button#connect").click()
}
setInterval(ClickConnect, 60000)
'''

display(IPython.display.Javascript(js_code))
print("✅ Keep-alive enabled - Colab won't timeout")

# Also mount drive with better error handling
try:
    drive.mount('/content/drive', force_remount=True)
    print("✅ Drive mounted successfully")
except:
    print("⚠️ Drive mounting failed - checkpoints won't be saved!")

<IPython.core.display.Javascript object>

✅ Keep-alive enabled - Colab won't timeout
Mounted at /content/drive
✅ Drive mounted successfully


In [3]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Mon Jun  2 13:10:18 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P0             44W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

We'll employ several popular Python packages to fine-tune the Whisper model.
We'll use `datasets[audio]` to download and prepare our training data, alongside
`transformers` and `accelerate` to load and train our Whisper model.
We'll also require the `soundfile` package to pre-process audio files,
`evaluate` and `jiwer` to assess the performance of our model, and
`tensorboard` to log our metrics. Finally, we'll use `gradio` to build a
flashy demo of our fine-tuned model.

In [4]:
!pip install --upgrade --quiet pip
!pip install --upgrade --quiet datasets[audio] transformers accelerate evaluate jiwer tensorboard==2.18 gradio
!pip install numba==0.60.0
!pip install pyarabic

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m20.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.5/10.5 MB[0m [31m128.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.2/54.2 MB[0m [31m140.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m129.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.5/11.5 MB[0m [31m168.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m55.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m130.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m123.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

We strongly advise you to upload model checkpoints directly the [Hugging Face Hub](https://huggingface.co/)
whilst training. The Hub provides:
- Integrated version control: you can be sure that no model checkpoint is lost during training.
- Tensorboard logs: track important metrics over the course of training.
- Model cards: document what a model does and its intended use cases.
- Community: an easy way to share and collaborate with the community!

Linking the notebook to the Hub is straightforward - it simply requires entering your
Hub authentication token when prompted. Find your Hub authentication token [here](https://huggingface.co/settings/tokens):

In [5]:
from huggingface_hub import notebook_login
#hf _nUozDOLgoglAeCkKiGQuJlwPzEUMRwBRgv
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## Prepare Feature Extractor, Tokenizer and Data

The ASR pipeline can be de-composed into three stages:

1. A feature extractor which pre-processes the raw audio-inputs
2. The model which performs the sequence-to-sequence mapping
3. A tokenizer which post-processes the model outputs to text format

In 🤗 Transformers, the Whisper model has an associated feature extractor and tokenizer,
called [WhisperFeatureExtractor](https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperFeatureExtractor)
and [WhisperTokenizer](https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperTokenizer)
respectively.

We'll go through details for setting-up the feature extractor and tokenizer one-by-one!

### Load WhisperFeatureExtractor

The Whisper feature extractor performs two operations:
1. Pads / truncates the audio inputs to 30s: any audio inputs shorter than 30s are padded to 30s with silence (zeros), and those longer that 30s are truncated to 30s
2. Converts the audio inputs to _log-Mel spectrogram_ input features, a visual representation of the audio and the form of the input expected by the Whisper model

<figure>
<img src="https://raw.githubusercontent.com/sanchit-gandhi/notebooks/main/spectrogram.jpg" alt="Trulli" style="width:100%">
<figcaption align = "center"><b>Figure 2:</b> Conversion of sampled audio array to log-Mel spectrogram.
Left: sampled 1-dimensional audio signal. Right: corresponding log-Mel spectrogram. Figure source:
<a href="https://ai.googleblog.com/2019/04/specaugment-new-data-augmentation.html">Google SpecAugment Blog</a>.
</figcaption>

We'll load the feature extractor from the pre-trained checkpoint with the default values:

In [6]:
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v3-turbo")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/340 [00:00<?, ?B/s]

### Load WhisperTokenizer

The Whisper model outputs a sequence of _token ids_. The tokenizer maps each of these token ids to their corresponding text string. For Hindi, we can load the pre-trained tokenizer and use it for fine-tuning without any further modifications. We simply have to
specify the target language and the task. These arguments inform the
tokenizer to prefix the language and task tokens to the start of encoded
label sequences:

In [7]:
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3-turbo", language="ar", task="transcribe")

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.71M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

### Combine To Create A WhisperProcessor

To simplify using the feature extractor and tokenizer, we can _wrap_
both into a single `WhisperProcessor` class. This processor object
inherits from the `WhisperFeatureExtractor` and `WhisperProcessor`,
and can be used on the audio inputs and model predictions as required.
In doing so, we only need to keep track of two objects during training:
the `processor` and the `model`:

In [8]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3-turbo", language="ar", task="transcribe")

## Full Data Pipeline

In [9]:
from datasets import load_dataset
import requests
from pathlib import Path
from tqdm.auto import tqdm
import os

# Parameters
TARGET_SADA_HOURS = 30  # How many hours of SADA to download
LOCAL_SADA_DIR = "./sada_audio_files/"

print(f"🎯 Goal: Download {TARGET_SADA_HOURS} hours of SADA audio files locally")

# Step 1: Load SADA metadata (fast - no audio)
print("Loading SADA metadata...")
sada_meta = load_dataset("khaledalganem/sada2022", trust_remote_code=True, split="train")

# Step 2: Calculate which files we need
print("Calculating required files...")
total_seconds_needed = TARGET_SADA_HOURS * 3600
current_seconds = 0
files_needed = {}  # filename -> list of segments

for example in tqdm(sada_meta, desc="Scanning segments"):
    if current_seconds >= total_seconds_needed:
        break

    filename = example["FileName"]
    segment_info = {
        'start': example["SegmentStart"],
        'end': example["SegmentEnd"],
        'length': example["SegmentLength"],
        'text': example.get("GroundTruthText", ""),
        'index': example['index'] if 'index' in example else len(files_needed)
    }

    if filename not in files_needed:
        files_needed[filename] = []
    files_needed[filename].append(segment_info)

    current_seconds += example.get("SegmentLength", 0)

print(f"📊 Need {len(files_needed)} unique audio files containing {current_seconds/3600:.1f} hours")

# Step 3: Download files
Path(LOCAL_SADA_DIR).mkdir(exist_ok=True)
SADA_BASE_URL = "https://huggingface.co/datasets/khaledalganem/sada2022/resolve/main/"

downloaded = 0
skipped = 0

for filename in tqdm(files_needed.keys(), desc="Downloading SADA files"):
    local_path = Path(LOCAL_SADA_DIR) / filename

    # Create subdirectories if needed
    local_path.parent.mkdir(parents=True, exist_ok=True)

    # Skip if already exists
    if local_path.exists() and local_path.stat().st_size > 1000:  # Check size to ensure not corrupted
        skipped += 1
        continue

    try:
        # Download
        url = SADA_BASE_URL + filename
        response = requests.get(url, timeout=120, stream=True)
        response.raise_for_status()

        # Save with streaming to handle large files
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        downloaded += 1

    except Exception as e:
        print(f"\n❌ Error downloading {filename}: {e}")
        # Remove partial file
        if local_path.exists():
            local_path.unlink()

print(f"\n✅ Download complete! Downloaded: {downloaded}, Skipped: {skipped}")
print(f"📁 Files saved to: {LOCAL_SADA_DIR}")

# Save segment info for faster loading later
import json
segment_info_path = Path(LOCAL_SADA_DIR) / "segment_info.json"
with open(segment_info_path, 'w') as f:
    json.dump(files_needed, f)
print(f"💾 Saved segment info to: {segment_info_path}")

🎯 Goal: Download 30 hours of SADA audio files locally
Loading SADA metadata...


README.md:   0%|          | 0.00/6.28k [00:00<?, ?B/s]

train.csv:   0%|          | 0.00/123M [00:00<?, ?B/s]

valid.csv:   0%|          | 0.00/2.70M [00:00<?, ?B/s]

test.csv:   0%|          | 0.00/3.18M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/241834 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5139 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6193 [00:00<?, ? examples/s]

Calculating required files...


Scanning segments:   0%|          | 0/241834 [00:00<?, ?it/s]

📊 Need 374 unique audio files containing 30.0 hours


Downloading SADA files:   0%|          | 0/374 [00:00<?, ?it/s]


✅ Download complete! Downloaded: 374, Skipped: 0
📁 Files saved to: ./sada_audio_files/
💾 Saved segment info to: sada_audio_files/segment_info.json


#### Loading into RAM

In [64]:
from datasets import load_dataset, DatasetDict, Audio, Dataset, Sequence, concatenate_datasets, Value, Features
import json
from pathlib import Path
import numpy as np
from tqdm.auto import tqdm
import os

# Parameters
TARGET_SADA_HOURS = 30
TARGET_CV_HOURS = 10
LOCAL_SADA_DIR = "./sada_audio_files/"
BATCH_SIZE = 32  # Optimal for A100

print("⚡ Fast dataset loading with direct numpy array storage")

# Fix tokenizer parallelism from the start
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Load SADA segment info
segment_info_path = Path(LOCAL_SADA_DIR) / "segment_info.json"
with open(segment_info_path, 'r') as f:
    sada_files_info = json.load(f)

# Step 1: Create SADA dataset from local files
print(f"\n📂 Creating SADA dataset from {len(sada_files_info)} local files...")

sada_examples = []
total_sada_duration = 0
for filename, segments in sada_files_info.items():
    local_path = Path(LOCAL_SADA_DIR) / filename
    if local_path.exists():
        for seg in segments:
            sada_examples.append({
                'audio': str(local_path),
                'segment_start': float(seg['start']),
                'segment_end': float(seg['end']),
                'segment_length': float(seg['length']),
                'sentence': seg['text'],
                'speaker_dialect': 'SADA'
            })
            total_sada_duration += seg['length']

print(f"Found {len(sada_examples)} SADA segments ({total_sada_duration/3600:.1f} hours)")

# Create SADA dataset
sada_dataset = Dataset.from_list(sada_examples)
sada_dataset = sada_dataset.cast_column("audio", Audio(sampling_rate=16000))

# Step 2: Load Common Voice normally (fast!)
print("\n📥 Loading Common Voice datasets...")

# Calculate samples needed
cv_samples_needed = int(TARGET_CV_HOURS * 3600 / 5)  # ~5s per sample
cv_train_samples = int(cv_samples_needed * 0.95)
cv_test_samples = cv_samples_needed - cv_train_samples

# Load all CV splits
cv_train = load_dataset("mozilla-foundation/common_voice_17_0", "ar", split=f"train[:{cv_train_samples}]", trust_remote_code=True)
cv_val = load_dataset("mozilla-foundation/common_voice_17_0", "ar", split=f"validated[:{cv_train_samples}]", trust_remote_code=True)
cv_test = load_dataset("mozilla-foundation/common_voice_17_0", "ar", split=f"test[:{cv_test_samples*2}]", trust_remote_code=True)

print(f"Loaded CV: {len(cv_train)} train, {len(cv_val)} val, {len(cv_test)} test")

# Step 3: Process Common Voice efficiently
def prepare_cv_for_whisper(dataset, num_samples, split_name):
    """Prepare CV dataset with only needed columns"""
    # Shuffle and select required samples
    dataset = dataset.shuffle(seed=42)
    dataset = dataset.filter(lambda x: bool(x['sentence'].strip()))
    if len(dataset) > num_samples:
        dataset = dataset.select(range(num_samples))

    print(f"Processing {len(dataset)} {split_name} samples...")

    # Remove unnecessary columns first (faster)
    columns_to_keep = ['audio', 'sentence']
    columns_to_remove = [col for col in dataset.column_names if col not in columns_to_keep]
    dataset = dataset.remove_columns(columns_to_remove)

    # Add segment info with FLOAT types to match SADA
    def add_segment_info(example):
        example['segment_start'] = 0.0
        example['segment_end'] = -1.0
        example['segment_length'] = -1.0
        example['speaker_dialect'] = 'Common_Voice'
        return example

    dataset = dataset.map(add_segment_info, desc=f"Adding segment info to {split_name}")

    # Cast audio to 16kHz
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

    return dataset

# Process CV datasets
cv_train_processed = prepare_cv_for_whisper(cv_train, cv_train_samples // 2, "CV train")
cv_val_processed = prepare_cv_for_whisper(cv_val, cv_train_samples // 2, "CV val")
cv_test_processed = prepare_cv_for_whisper(cv_test, cv_test_samples, "CV test")

# Combine train and validation
cv_train_full = concatenate_datasets([cv_train_processed, cv_val_processed])

# Step 4: Split SADA and combine with CV
print("\n🔄 Creating final dataset splits...")

# Split SADA 90/10
sada_split = sada_dataset.train_test_split(test_size=0.1, seed=42)

# Combine datasets
train_dataset = concatenate_datasets([sada_split['train'], cv_train_full])
test_dataset = concatenate_datasets([sada_split['test'], cv_test_processed])

# Shuffle final datasets
train_dataset = train_dataset.shuffle(seed=42)
test_dataset = test_dataset.shuffle(seed=42)

print(f"\n📊 Combined dataset:")
print(f"  Train: {len(train_dataset)} samples ({len(sada_split['train'])} SADA + {len(cv_train_full)} CV)")
print(f"  Test: {len(test_dataset)} samples ({len(sada_split['test'])} SADA + {len(cv_test_processed)} CV)")

# Step 5: OPTIMIZED - Process directly to numpy arrays
def process_to_numpy_arrays(dataset, split_name):
    print(f"\n🚀 Processing {split_name} split to numpy arrays...")

    all_input_features = []
    all_labels = []

    batch_size = 16
    total_samples = len(dataset)

    for idx in tqdm(range(0, total_samples, batch_size), desc=f"Processing {split_name}"):
        end_idx = min(idx + batch_size, total_samples)
        batch = dataset[idx:end_idx]

        for i in range(len(batch['audio'])):
            try:
                audio = batch['audio'][i]
                audio_array = np.array(audio['array'], dtype=np.float32)
                sr = audio['sampling_rate']

                if batch['segment_end'][i] > 0:
                    start_sample = int(batch['segment_start'][i] * sr)
                    end_sample = int(batch['segment_end'][i] * sr)
                    audio_array = audio_array[start_sample:end_sample]

                features = feature_extractor( # Assuming feature_extractor is globally defined
                    audio_array,
                    sampling_rate=sr,
                    return_tensors="np"
                )
                all_input_features.append(features.input_features[0])

                tokens = tokenizer( # Assuming tokenizer is globally defined
                    batch['sentence'][i],
                    truncation=True,
                    max_length=448 # Whisper's max length for labels
                )
                all_labels.append(tokens.input_ids)

            except Exception as e:
                print(f"Error processing sample {idx+i}: {str(e)[:100]}") # More context
                # Append a consistent placeholder if error occurs
                # Ensure the placeholder for input_features has the correct number of mel bins (80)
                all_input_features.append(np.zeros((80, 100), dtype=np.float32)) # Small placeholder
                all_labels.append([tokenizer.pad_token_id if tokenizer else 0])


    print(f"Creating dataset from {len(all_input_features)} processed samples...")

    # --- FIX: Define features with large_list for input_features ---
    # 'input_features': Each item is a 2D numpy array (mel_bins, sequence_length).
    #                   This translates to a Sequence of Sequence of floats for Arrow.
    #                   We mark the outer Sequence (the list of all spectrograms) as "large_list".
    # 'labels': Each item is a 1D list of integers (token ids).
    #           This translates to a Sequence of integers.
    defined_features = Features({
        'input_features': Sequence(Sequence(Value("float32")), id="large_list"),
        'labels': Sequence(Value("int32"))
    })
    # -----------------------------------------------------------------

    processed_dataset = Dataset.from_dict({
        'input_features': all_input_features,
        'labels': all_labels
    }, features=defined_features) # Pass the defined features here

    processed_dataset.set_format(type="numpy", columns=["input_features", "labels"])

    return processed_dataset

# Step 6: Process both splits to numpy
print("\n⚡ Processing audio and creating numpy arrays...")

# Process train and test sets
processed_train = process_to_numpy_arrays(train_dataset, "train")
processed_test = process_to_numpy_arrays(test_dataset, "test")

# Create final dataset
combined_dataset = DatasetDict({
    'train': processed_train,
    'test': processed_test
})

print("\n✅ Dataset processing complete!")
print(combined_dataset)

# Verify numpy storage
sample = combined_dataset['train'][0]
print(f"\n📊 Verification:")
print(f"Input features type: {type(sample['input_features'])}")
print(f"Input features dtype: {sample['input_features'].dtype}")
print(f"Input features shape: {sample['input_features'].shape}")

⚡ Fast dataset loading with direct numpy array storage

📂 Creating SADA dataset from 374 local files...
Found 24662 SADA segments (30.0 hours)

📥 Loading Common Voice datasets...
Loaded CV: 6840 train, 6840 val, 720 test


Filter:   0%|          | 0/6840 [00:00<?, ? examples/s]

Processing 3420 CV train samples...


Adding segment info to CV train:   0%|          | 0/3420 [00:00<?, ? examples/s]

Filter:   0%|          | 0/6840 [00:00<?, ? examples/s]

Processing 3420 CV val samples...


Adding segment info to CV val:   0%|          | 0/3420 [00:00<?, ? examples/s]

Filter:   0%|          | 0/720 [00:00<?, ? examples/s]

Processing 360 CV test samples...


Adding segment info to CV test:   0%|          | 0/360 [00:00<?, ? examples/s]


🔄 Creating final dataset splits...

📊 Combined dataset:
  Train: 29035 samples (22195 SADA + 6840 CV)
  Test: 2827 samples (2467 SADA + 360 CV)

⚡ Processing audio and creating numpy arrays...

🚀 Processing train split to numpy arrays...


Processing train:   0%|          | 0/1815 [00:00<?, ?it/s]

Exception ignored from cffi callback <function SoundFile._init_virtual_io.<locals>.vio_read at 0x79dc3d4305e0>:
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/soundfile.py", line 1290, in vio_read
    @_ffi.callback("sf_vio_read")

KeyboardInterrupt: 


KeyboardInterrupt: 

In [None]:
# Calculate and display total audio hours
print("\n📊 Calculating Audio Duration Statistics...")

# FAST: Use map to get all durations at once
def get_duration(example):
    return {'duration': example['segment_length']}

# Get SADA durations in batch
print("⏱️  Computing SADA durations...")
train_durations = sada_split['train'].map(
    get_duration,
    keep_in_memory=True,
    desc="Getting train durations"
)['duration']

test_durations = sada_split['test'].map(
    get_duration,
    keep_in_memory=True,
    desc="Getting test durations"
)['duration']

# Calculate hours
sada_train_hours = sum(train_durations) / 3600
sada_test_hours = sum(test_durations) / 3600

# CV estimates (5s per sample)
cv_train_hours = len(cv_train_full) * 5 / 3600
cv_test_hours = len(cv_test_processed) * 5 / 3600

# Total hours
total_train_hours = sada_train_hours + cv_train_hours
total_test_hours = sada_test_hours + cv_test_hours

print("\n🎵 Training Set Breakdown:")
print(f"  • SADA: {sada_train_hours:.1f} hours ({len(sada_split['train']):,} segments)")
print(f"  • Common Voice: ~{cv_train_hours:.1f} hours ({len(cv_train_full):,} samples)")
print(f"  • TOTAL: ~{total_train_hours:.1f} hours")

print(f"\n🎵 Test Set Breakdown:")
print(f"  • SADA: {sada_test_hours:.1f} hours ({len(sada_split['test']):,} segments)")
print(f"  • Common Voice: ~{cv_test_hours:.1f} hours ({len(cv_test_processed):,} samples)")
print(f"  • TOTAL: ~{total_test_hours:.1f} hours")

print(f"\n📈 Grand Total: ~{total_train_hours + total_test_hours:.1f} hours of audio data")
print("─" * 50)


📊 Calculating Audio Duration Statistics...
⏱️  Computing SADA durations...


Getting train durations:   0%|          | 0/949 [00:00<?, ? examples/s]

Getting test durations:   0%|          | 0/106 [00:00<?, ? examples/s]


🎵 Training Set Breakdown:
  • SADA: 0.9 hours (949 segments)
  • Common Voice: ~0.9 hours (684 samples)
  • TOTAL: ~1.8 hours

🎵 Test Set Breakdown:
  • SADA: 0.1 hours (106 segments)
  • Common Voice: ~0.1 hours (36 samples)
  • TOTAL: ~0.2 hours

📈 Grand Total: ~2.0 hours of audio data
──────────────────────────────────────────────────


## Training and Evaluation

Now that we've prepared our data, we're ready to dive into the training pipeline.
The [🤗 Trainer](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer)
will do much of the heavy lifting for us. All we have to do is:

- Load a pre-trained checkpoint: we need to load a pre-trained checkpoint and configure it correctly for training.

- Define a data collator: the data collator takes our pre-processed data and prepares PyTorch tensors ready for the model.

- Evaluation metrics: during evaluation, we want to evaluate the model using the [word error rate (WER)](https://huggingface.co/metrics/wer) metric. We need to define a `compute_metrics` function that handles this computation.

- Define the training configuration: this will be used by the 🤗 Trainer to define the training schedule.

Once we've fine-tuned the model, we will evaluate it on the test data to verify that we have correctly trained it
to transcribe speech in Hindi.

### Load a Pre-Trained Checkpoint

We'll start our fine-tuning run from the pre-trained Whisper `small` checkpoint,
the weights for which we need to load from the Hugging Face Hub. Again, this
is trivial through use of 🤗 Transformers!

In [43]:
from transformers import WhisperForConditionalGeneration
import torch

model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
)
model.half()

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 1280)
      (layers): ModuleList(
        (0-31): 32 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
            (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (fc2): Linear(in_features=5120, out_features=1280, bia

In [44]:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params}")
print(f"Total parameters: {total_params}")
print(f"Percentage trainable: {100 * trainable_params / total_params:.2f}%")

Trainable parameters: 806958080
Total parameters: 808878080
Percentage trainable: 99.76%


We can disable the automatic language detection task performed during inference, and force the model to generate in Hindi. To do so, we set the [langauge](https://huggingface.co/docs/transformers/en/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate.language)
and [task](https://huggingface.co/docs/transformers/en/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate.task)
arguments to the generation config. We'll also set any [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)
to None, since this was the legacy way of setting the language and
task arguments:

In [45]:
model.generation_config.language = "ar"
model.generation_config.task = "transcribe"

model.generation_config.forced_decoder_ids = None

In [46]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model.gradient_checkpointing_enable()
model.config.use_cache = False  # Required for gradient checkpointing

# Define LoRA configuration
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                    "fc1", "fc2"],
                    #"proj_out",
                    #],
    lora_dropout=0.05,
    bias="none",
)

# Apply LoRA to model
model = get_peft_model(model, lora_config)
model.to('cuda')
model.print_trainable_parameters()

trainable params: 49,152,000 || all params: 858,030,080 || trainable%: 5.7285


### Define a Data Collator

The data collator for a sequence-to-sequence speech model is unique in the sense that it
treats the `input_features` and `labels` independently: the  `input_features` must be
handled by the feature extractor and the `labels` by the tokenizer.

The `input_features` are already padded to 30s and converted to a log-Mel spectrogram
of fixed dimension by action of the feature extractor, so all we have to do is convert the `input_features`
to batched PyTorch tensors. We do this using the feature extractor's `.pad` method with `return_tensors=pt`.

The `labels` on the other hand are un-padded. We first pad the sequences
to the maximum length in the batch using the tokenizer's `.pad` method. The padding tokens
are then replaced by `-100` so that these tokens are **not** taken into account when
computing the loss. We then cut the BOS token from the start of the label sequence as we
append it later during training.

We can leverage the `WhisperProcessor` we defined earlier to perform both the
feature extractor and the tokenizer operations:

In [56]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int # Or however you pass this

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels
        input_features_list = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features_list, return_tensors="pt")

        # --- MODIFICATION: Explicitly cast input_features to float16 ---
        if "input_features" in batch and batch["input_features"].is_floating_point():
            batch["input_features"] = batch["input_features"].to(torch.float16)
        # --------------------------------------------------------------

        # Get the tokenized label sequences
        label_features_list = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features_list, return_tensors="pt")

        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        # Ensure decoder_start_token_id is correctly defined and passed
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

Let's initialise the data collator we've just defined:

In [57]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

### Evaluation Metrics

We'll use the word error rate (WER) metric, the 'de-facto' metric for assessing
ASR systems. For more information, refer to the WER [docs](https://huggingface.co/metrics/wer). We'll load the WER metric from 🤗 Evaluate:

In [58]:
import evaluate

metric = evaluate.load("wer")

We then simply have to define a function that takes our model
predictions and returns the WER metric. This function, called
`compute_metrics`, first replaces `-100` with the `pad_token_id`
in the `label_ids` (undoing the step we applied in the
data collator to ignore padded tokens correctly in the loss).
It then decodes the predicted and label ids to strings. Finally,
it computes the WER between the predictions and reference labels:

In [59]:
import unicodedata
import re

def normalize_arabic_for_wer(text):
    """Comprehensive Arabic normalization for WER calculation"""
    if not text:
        return ""

    # Remove diacritics (تشكيل)
    text = unicodedata.normalize('NFD', text)
    text = ''.join(char for char in text if unicodedata.category(char) != 'Mn')

    # Normalize alef variants (أ إ آ → ا)
    text = re.sub(r'[أإآ]', 'ا', text)

    # Normalize alef maksura (ى → ي)
    text = text.replace('ى', 'ي')

    # Normalize teh marbuta (ة → ه)
    text = text.replace('ة', 'ه')

    # Remove tatweel (kashida)
    text = text.replace('ـ', '')

    # Normalize whitespace
    text = ' '.join(text.split())

    # Remove punctuation
    text = re.sub(r'[.,!?؛،؟:\"\'()-]', '', text)

    # Convert to lowercase (if needed for your use case)
    text = text.lower()

    return text.strip()

# Test the function
print("Testing normalization:")
print(f"Original: 'مَرْحَباً بِكَ'")
print(f"Normalized: '{normalize_arabic_for_wer('مَرْحَباً بِكَ')}'")
print(f"Original: 'الطَّائِرة'")
print(f"Normalized: '{normalize_arabic_for_wer('الطَّائِرة')}'")

Testing normalization:
Original: 'مَرْحَباً بِكَ'
Normalized: 'مرحبا بك'
Original: 'الطَّائِرة'
Normalized: 'الطايره'


In [60]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Replace -100 with pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # Decode predictions and labels
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Apply Arabic normalization BEFORE filtering
    pred_str = [normalize_arabic_for_wer(p) for p in pred_str]
    label_str = [normalize_arabic_for_wer(l) for l in label_str]

    # Filter out empty references to avoid jiwer error
    filtered_preds = []
    filtered_labels = []

    for pred, label in zip(pred_str, label_str):
        # Skip if reference is empty after stripping
        if label.strip():
            filtered_preds.append(pred)
            filtered_labels.append(label)
        else:
            # Or add a placeholder
            filtered_preds.append(pred if pred.strip() else "[EMPTY]")
            filtered_labels.append("[EMPTY]")

    # Calculate WER only on non-empty samples
    if filtered_labels:
        wer = 100 * metric.compute(predictions=filtered_preds, references=filtered_labels)
    else:
        wer = 100.0  # Default if all empty

    return {"wer": wer}

### Define the Training Configuration

In the final step, we define all the parameters related to training. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments).

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [61]:
from transformers import Seq2SeqTrainingArguments

drive_output_dir = "/content/drive/MyDrive/whisper-large-v3-ar-80h"
# Give your model a specific name on the Hub
hub_repo_name = "Bruno7/whisper-large-v3-turbo-ar-sa-tuned" # <--- CHANGE THIS!

training_args = Seq2SeqTrainingArguments(
    output_dir=drive_output_dir,
    hub_model_id=hub_repo_name,  # The name for your model on the Hub

    # --- Batch Size & Steps ---
    per_device_train_batch_size=8,  # A100 40GB can often handle more, especially with LoRA, but this is a safe start.
    gradient_accumulation_steps=8,  # Effective batch size of 64. Good balance.
    learning_rate=5.5e-5,          #for LoRA.
    warmup_steps=100,
    weight_decay=0.01,
    num_train_epochs=5,

    # --- Performance & Memory ---
    torch_compile=False,             # Excellent for A100 speed.
    gradient_checkpointing=True,    # Essential for saving memory with large models.
    fp16=True,                      # Best for A100 speed & memory.
    fp16_full_eval=True,
    # bf16=False,                   # Correctly False if fp16 is True.
    tf32=True,                      # Leverages A100 Tensor Cores.

    # --- CRITICAL DATA LOADING CHANGES ---
    dataloader_num_workers=8,       # <-- SIGNIFICANT CHANGE: Was 0. Set to a positive number (e.g., 4, 8, 12 based on CPU cores)
                                    # This enables parallel data loading, preventing CPU bottleneck and GPU starvation.
                                    # 0 means data is loaded in the main process, which is very slow.
    dataloader_pin_memory=True,     # <-- SIGNIFICANT CHANGE: Was False. Set to True.
                                    # Speeds up CPU to GPU data transfer by pinning memory.
    dataloader_prefetch_factor=2, # Can uncomment if using num_workers > 0, helps prefetch batches.
    dataloader_persistent_workers=True, # Can uncomment if using num_workers > 0, keeps workers alive.

    group_by_length=False,          # Your setting: Kept False as you had issues previously.
                                    # If stable, True can sometimes improve efficiency by minimizing padding.
    optim="adamw_torch_fused",      # Good, modern, and fast optimizer.

    # --- Evaluation & Saving ---
    do_eval=True,
    save_strategy="steps",
    eval_strategy="steps",
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=150,
    save_steps=50,
    eval_steps=10,
    logging_steps=5,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,

    # --- Reporting & Hub ---
    report_to=["tensorboard"],
    push_to_hub=True,
    logging_first_step=True,
    save_total_limit=3,
    # ddp_find_unused_parameters=False, # Correctly commented out for single-GPU.
)

**Note**: if one does not want to upload the model checkpoints to the Hub,
set `push_to_hub=False`.

We can forward the training arguments to the 🤗 Trainer along with our model,
dataset, data collator and `compute_metrics` function:

In [62]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=combined_dataset["train"],
    eval_dataset=combined_dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

  trainer = Seq2SeqTrainer(
No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [63]:
import torch
import time
import numpy as np

print("=== A100 PERFORMANCE DIAGNOSTIC ===\n")

# 1. Check GPU
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"CUDA: {torch.version.cuda}")
print(f"PyTorch: {torch.__version__}\n")

# 2. Check current memory usage
print(f"Memory Allocated: {torch.cuda.memory_allocated() / 1e9:.1f} GB")
print(f"Memory Reserved: {torch.cuda.memory_reserved() / 1e9:.1f} GB")
print(f"Free Memory: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved()) / 1e9:.1f} GB\n")

# 3. Test data loading speed
print("Testing data loading speed...")
start = time.time()
for i, batch in enumerate(trainer.get_train_dataloader()):
    if i >= 10:
        break
data_time = (time.time() - start) / 10
print(f"Average batch load time: {data_time:.3f}s")
print(f"Data loading speed: {1/data_time:.1f} batches/sec\n")

# 4. Test model forward pass
print("Testing model forward pass...")
model.eval()
with torch.no_grad():
    # Get a sample batch
    batch = next(iter(trainer.get_train_dataloader()))
    batch = {k: v.to('cuda') for k, v in batch.items() if isinstance(v, torch.Tensor)}

    # Time forward pass
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(5):
        outputs = model(**batch)
    torch.cuda.synchronize()
    forward_time = (time.time() - start) / 5

print(f"Average forward pass time: {forward_time:.3f}s")
print(f"Forward pass speed: {1/forward_time:.1f} batches/sec\n")

# 5. Check batch sizes
print(f"Train batch size: {trainer.args.per_device_train_batch_size}")
print(f"Gradient accumulation: {trainer.args.gradient_accumulation_steps}")
print(f"Effective batch size: {trainer.args.per_device_train_batch_size * trainer.args.gradient_accumulation_steps}")

# 6. Check for CPU bottlenecks
print(f"\nDataloader workers: {trainer.args.dataloader_num_workers}")
print(f"Pin memory: {trainer.args.dataloader_pin_memory}")

# 7. Test if we're CPU bound
import psutil
print(f"\nCPU count: {psutil.cpu_count()}")
print(f"CPU usage: {psutil.cpu_percent(interval=1)}%")

=== A100 PERFORMANCE DIAGNOSTIC ===

GPU: NVIDIA A100-SXM4-40GB
Memory: 42.5 GB
CUDA: 12.4
PyTorch: 2.6.0+cu124

Memory Allocated: 3.8 GB
Memory Reserved: 3.8 GB
Free Memory: 38.6 GB

Testing data loading speed...


ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/transformers/feature_extraction_utils.py", line 192, in convert_to_tensors
    tensor = as_tensor(value)
             ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/feature_extraction_utils.py", line 141, in as_tensor
    value = np.array(value)
            ^^^^^^^^^^^^^^^
ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (8, 128) + inhomogeneous part.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "<ipython-input-56-9f6622f9954f>", line 13, in __call__
    batch = self.processor.feature_extractor.pad(input_features_list, return_tensors="pt")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/feature_extraction_sequence_utils.py", line 224, in pad
    return BatchFeature(batch_outputs, tensor_type=return_tensors)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/feature_extraction_utils.py", line 78, in __init__
    self.convert_to_tensors(tensor_type=tensor_type)
  File "/usr/local/lib/python3.11/dist-packages/transformers/feature_extraction_utils.py", line 198, in convert_to_tensors
    raise ValueError(
ValueError: Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.


We'll save the processor object once before starting training. Since the processor is not trainable, it won't change over the course of training:

In [55]:
processor.save_pretrained(training_args.output_dir)

[]

### Training

In [None]:
import time
import torch
import gc
import os
from transformers.trainer_callback import TrainerCallback

# before starting
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()

# Check for existing checkpoints (resume training)
checkpoints = [d for d in os.listdir(drive_output_dir) if d.startswith('checkpoint-')]
checkpoints = False
if checkpoints:
    latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('-')[1]))
    resume_from = os.path.join(drive_output_dir, latest_checkpoint)
    print(f"🔄 Resuming from checkpoint: {resume_from}")
else:
    resume_from = None
    print("🆕 Starting fresh training")

# Training with error handling
print("\n🚀 Training starting...")
start_time = time.time()

try:
    # Train (will auto-resume if checkpoint exists)
    trainer.train(resume_from_checkpoint=resume_from)

    # Calculate final stats
    total_time = time.time() - start_time
    print(f"\n✅ Training completed successfully!")
    print(f"   Total time: {total_time/3600:.1f} hours ({total_time:.1f} seconds)")

except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user")
    print("Saving checkpoint...")
    trainer.save_model(os.path.join(drive_output_dir, "checkpoint-interrupted"))
    print("✅ Checkpoint saved. You can resume later.")

finally:
    # Always save the training log
    if hasattr(trainer.state, 'log_history'):
        import json
        log_file = os.path.join(drive_output_dir, "training_log.json")
        with open(log_file, 'w') as f:
            json.dump(trainer.state.log_history, f, indent=2)
        print(f"\n📊 Training log saved to {log_file}")

# Quick evaluation on best model
print("\n🔍 Loading best model for final evaluation...")
trainer.model = trainer.model.from_pretrained(drive_output_dir)
final_metrics = trainer.evaluate()
print(f"📈 Final evaluation WER: {final_metrics.get('eval_wer', 'N/A'):.2f}%")

In [None]:
# Analyze training results
import matplotlib.pyplot as plt
import json

# Load training history
log_file = os.path.join(drive_output_dir, "training_log.json")
with open(log_file, 'r') as f:
    history = json.load(f)

# Extract metrics
steps = [h['step'] for h in history if 'loss' in h]
train_loss = [h['loss'] for h in history if 'loss' in h]
eval_steps = [h['step'] for h in history if 'eval_loss' in h]
eval_loss = [h['eval_loss'] for h in history if 'eval_loss' in h]
eval_wer = [h['eval_wer'] for h in history if 'eval_wer' in h]

# Plot results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Loss plot
ax1.plot(steps, train_loss, label='Train Loss', alpha=0.7)
ax1.plot(eval_steps, eval_loss, label='Eval Loss', marker='o')
ax1.set_xlabel('Steps')
ax1.set_ylabel('Loss')
ax1.set_title('Training Progress')
ax1.legend()
ax1.grid(True, alpha=0.3)

# WER plot
ax2.plot(eval_steps, eval_wer, label='WER', marker='o', color='red')
ax2.set_xlabel('Steps')
ax2.set_ylabel('WER (%)')
ax2.set_title('Word Error Rate Progress')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(drive_output_dir, 'training_curves.png'), dpi=150)
plt.show()

print(f"\n📊 Training Summary:")
print(f"Initial Loss: {train_loss[0]:.4f} → Final Loss: {train_loss[-1]:.4f}")
print(f"Initial WER: {eval_wer[0]:.1f}% → Final WER: {eval_wer[-1]:.1f}%")
print(f"Best WER: {min(eval_wer):.1f}% at step {eval_steps[eval_wer.index(min(eval_wer))]}")

In [None]:
trainer.push_to_hub(**kwargs)