# Training Whisper for Bengali ASR 🎵👂➡️🇧🇩📝

The [Whisper models, developed by OpenAI](https://openai.com/research/whisper), are among the best open-source multilingual automatic speech recognition (ASR) models available.

## Architecture 🏗️


The architecture is relatively straightforward: it's a sequence-to-sequence model with an audio encoder and a text decoder. The feature extractor turns the 1D audio signal (amplitude over time) into a log-mel spectrogram. The encoder turns the spectrogram into hidden states, which the decoder attends to while generating text. It's basically BART with two convolution layers at the input.
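To make the feature-extraction step concrete, here is a minimal NumPy sketch of a Whisper-style log-mel front end. It assumes the usual Whisper settings (16 kHz audio, 400-sample window, 160-sample hop, 80 mel bins as in whisper-small) and follows the log and rescaling steps of the reference implementation, but it uses a simple HTK-style triangular filterbank instead of the Slaney-normalized filters Whisper actually uses, and it skips padding every clip to 30 seconds. In practice you would just call the model's `WhisperFeatureExtractor`.

```python
import numpy as np

SAMPLE_RATE = 16_000  # Whisper expects 16 kHz mono audio
N_FFT = 400           # 25 ms analysis window
HOP = 160             # 10 ms hop -> 100 frames per second
N_MELS = 80           # whisper-small uses 80 mel bins


def hz_to_mel(f):
    # HTK-style mel scale (assumption: Whisper itself uses librosa's Slaney filters)
    return 2595.0 * np.log10(1.0 + f / 700.0)


def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)


def mel_filterbank(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS):
    """Triangular filters mapping n_fft // 2 + 1 FFT bins onto n_mels mel bands."""
    fft_freqs = np.linspace(0, sr / 2, n_fft // 2 + 1)
    hz_points = mel_to_hz(np.linspace(hz_to_mel(0.0), hz_to_mel(sr / 2), n_mels + 2))
    fb = np.zeros((n_mels, len(fft_freqs)))
    for i in range(n_mels):
        left, center, right = hz_points[i], hz_points[i + 1], hz_points[i + 2]
        rising = (fft_freqs - left) / (center - left)
        falling = (right - fft_freqs) / (right - center)
        fb[i] = np.maximum(0.0, np.minimum(rising, falling))
    return fb


def log_mel_spectrogram(audio):
    """Return an (N_MELS, n_frames) log-mel spectrogram, roughly mirroring Whisper's front end."""
    audio = np.asarray(audio, dtype=np.float32)
    # Center the frames by reflect-padding half a window on each side
    padded = np.pad(audio, N_FFT // 2, mode="reflect")
    n_frames = 1 + (len(padded) - N_FFT) // HOP
    idx = np.arange(N_FFT)[None, :] + HOP * np.arange(n_frames)[:, None]
    frames = padded[idx] * np.hanning(N_FFT + 1)[:-1]  # periodic Hann window
    power = np.abs(np.fft.rfft(frames, axis=1)) ** 2   # (n_frames, N_FFT // 2 + 1)
    power = power[:-1]                                  # Whisper drops the last STFT frame
    mel = mel_filterbank() @ power.T                    # (N_MELS, n_frames)
    log_spec = np.log10(np.maximum(mel, 1e-10))
    # Dynamic-range clamp (8 orders of magnitude) and rescaling, as in Whisper's audio code
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    return (log_spec + 4.0) / 4.0
```

A 160-sample hop gives 100 frames per second, so a 30-second clip becomes the 80 × 3000 input the encoder expects; the stride-2 convolution then halves that to 1500 encoder positions.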

![](https://huggingface.co/blog/assets/111_fine_tune_whisper/whisper_architecture.svg)


## Training Data 📊

While the architecture is nothing novel, the OpenAI team built an enormous dataset of roughly 680k hours of labeled audio. Of that, 117k hours were multilingual ASR data. Sadly, there was barely any Bengali in that training set (less than 2 hours!), but that's why we're here! This competition provides 1200 hours of data that can be used to fine-tune the existing Whisper models. Even though Whisper saw very little Bengali during pre-training, it should be able to learn quickly.

Image from [figure 11 on page 27 of the Whisper paper](https://arxiv.org/pdf/2212.04356.pdf)
![](https://raw.githubusercontent.com/nbroad1881/kaggle-images/main/bengali-ai-asr/whisper-bn.png)




## In this notebook 📓


### Preprocessing 🎵➡️🔢

I show how to preprocess the data and how to train. Preprocessing can be a bit slow, so I recommend running it in a CPU notebook or, better yet, on a beefier CPU VM in your favorite cloud. I used a separate instance to preprocess 10k training samples and 1k eval samples.

### Training 🏋️

I show how to train on 2x T4 GPUs using Hugging Face transformers in PyTorch. These GPUs have tensor cores, which make them fast when training with mixed precision. The code doesn't do anything fancy, but it can serve as a starting point for understanding ASR training. [@mbmmurad](https://kaggle.com/mbmmurad) pointed out that `bangla-speech-processing/BanglaASR` is a Whisper model that has already been fine-tuned on Bangla. This notebook fine-tunes it further on 10k of the ~960k training samples.
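For a sense of scale, here is a small sketch of the step count and learning-rate curve for the training run in this notebook. It assumes exactly 10,000 training rows, 2 GPUs with `--per_device_train_batch_size 4`, 3 epochs, `--warmup_steps 50`, `--learning_rate 1e-5`, no gradient accumulation, and the Trainer's default linear schedule (linear warmup, then linear decay to zero). The helper names are hypothetical; the Trainer computes all of this internally, and the real step count may differ slightly if preprocessing filtered out some rows.

```python
import math


def total_training_steps(num_samples, per_device_batch_size, num_devices, num_epochs):
    """Approximate optimizer steps for a data-parallel run without gradient accumulation.

    Assumption: each process sees ceil(num_samples / num_devices) samples per epoch
    (the distributed sampler pads to an even split) and batches them per device.
    """
    samples_per_device = math.ceil(num_samples / num_devices)
    steps_per_epoch = math.ceil(samples_per_device / per_device_batch_size)
    return steps_per_epoch * num_epochs


def linear_warmup_decay(step, peak_lr, warmup_steps, total_steps):
    """Learning rate at `step`: linear ramp from 0 to peak_lr, then linear decay to 0."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Assumed run: 10,000 rows, 2 GPUs x batch 4, 3 epochs -> 1250 steps per epoch, 3750 in total
steps = total_training_steps(10_000, per_device_batch_size=4, num_devices=2, num_epochs=3)
peak_at_end_of_warmup = linear_warmup_decay(50, 1e-5, 50, steps)
```

With only 50 warmup steps out of roughly 3750, almost the whole run is spent in the decay phase, so the peak learning rate is the main knob to tune.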


### Validation 🕵️

I take a random split of files for train and validation, but since the domain of the test data is unknown, it is hard to tell how well the model will do on out-of-domain data. It would be nice if train.csv had a domain column.


| Notebook version | Model | WER |
| --- | --- | --- |
| 1 | openai/whisper-base | 0.69 |
| 2 | bangla-speech-processing/BanglaASR | 0.529 |
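WER (word error rate) is the number of word substitutions, deletions, and insertions needed to turn the prediction into the reference, divided by the number of reference words, so 0.529 means roughly one edit for every two reference words. The training script gets it from `evaluate.load("wer")`, which relies on `jiwer`. Here is a dependency-free sketch of the same corpus-level computation using word-level edit distance, assuming plain whitespace tokenization; it is an illustration, not the `evaluate` implementation.

```python
def word_error_rate(references, predictions):
    """Corpus-level WER: total word-level edits divided by total reference words."""
    total_edits = 0
    total_words = 0
    for ref, hyp in zip(references, predictions):
        r, h = ref.split(), hyp.split()
        # prev[j] holds the edit distance between the reference words seen so far and h[:j]
        prev = list(range(len(h) + 1))
        for i in range(1, len(r) + 1):
            curr = [i] + [0] * len(h)
            for j in range(1, len(h) + 1):
                cost = 0 if r[i - 1] == h[j - 1] else 1
                curr[j] = min(
                    prev[j] + 1,         # deletion (reference word missing from prediction)
                    curr[j - 1] + 1,     # insertion (extra word in prediction)
                    prev[j - 1] + cost,  # substitution, or a match when cost == 0
                )
            prev = curr
        total_edits += prev[-1]
        total_words += len(r)
    return total_edits / total_words


# word_error_rate(["a b c"], ["a x c"]) -> 0.333... (one substitution, three reference words)
```

Because insertions count as errors, WER can exceed 1.0 when a model hallucinates extra words.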

---

## How to improve the model 💪

1. Train on more data.
   - Like I mentioned above, there are over 900k files and I only used 10k of them.
2. Use a larger model.
   - I'm only using the small model, which has 244M params. The largest model has [1550M params](https://huggingface.co/openai/whisper-large-v2), which is probably too big, but there is also a [769M-param medium](https://huggingface.co/openai/whisper-medium) model.
3. Data augmentations.
   - I use [SpecAugment](https://arxiv.org/abs/1904.08779).
   - You could also consider [BPE dropout](https://arxiv.org/abs/1910.13267).
4. Better hyperparameters.
   - The learning rate is usually the most important one.
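In `transformers`, Whisper applies SpecAugment inside the model during training when `apply_spec_augment` is enabled, with the masking controlled by config fields such as `mask_time_prob` and `mask_feature_prob`. To show what the augmentation does, here is a minimal NumPy sketch of the core idea from the SpecAugment paper: random bands of mel channels and of time frames are overwritten. It is a simplification (no time warping, fixed mask counts, masks filled with the spectrogram mean) and not the `transformers` implementation; the default widths are arbitrary.

```python
import numpy as np


def spec_augment(spec, num_freq_masks=2, max_freq_width=10,
                 num_time_masks=2, max_time_width=20, rng=None):
    """Return a copy of an (n_mels, n_frames) spectrogram with random frequency and time masks.

    Simplified SpecAugment: no time warping, and masked cells are set to the mean value.
    """
    rng = np.random.default_rng() if rng is None else rng
    out = np.array(spec, dtype=float, copy=True)  # never modify the caller's array
    n_mels, n_frames = out.shape
    fill = out.mean()  # computed before masking so every mask uses the same value

    for _ in range(num_freq_masks):
        width = int(rng.integers(0, max_freq_width + 1))          # 0..max_freq_width mel bins
        start = int(rng.integers(0, max(1, n_mels - width + 1)))  # keep the band inside the array
        out[start:start + width, :] = fill

    for _ in range(num_time_masks):
        width = int(rng.integers(0, max_time_width + 1))          # 0..max_time_width frames
        start = int(rng.integers(0, max(1, n_frames - width + 1)))
        out[:, start:start + width] = fill

    return out
```

Masking forces the model to rely on context rather than on any single frequency band or short stretch of audio, which is why it tends to help when fine-tuning on a limited amount of data.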

In [1]:
# Necessary packages
!pip install -U evaluate datasets transformers jiwer -q

# Preprocessing

Run this step on CPU: it saves GPU time, and the CPUs in CPU notebooks are faster than the ones in GPU notebooks. Better still, use a more powerful CPU on a cloud VM.

In [2]:
%%writefile preprocess.py

import logging
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import datasets
from datasets import DatasetDict, load_dataset

from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoTokenizer,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
)


warnings.simplefilter("ignore")


@dataclass
class Config:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        }
    )

    apply_spec_augment: bool = field(
        default=False,
        metadata={
            "help": "Whether to apply *SpecAugment* data augmentation to the input features. This is currently only relevant for Wav2Vec2, HuBERT, WavLM and Whisper models."
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    forced_decoder_ids: List[List[int]] = field(
        default=None,
        metadata={
            "help": (
                "A list of pairs of integers which indicates a mapping from generation indices to token indices "
                "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
                "will always be a token of index 123."
            )
        },
    )
    suppress_tokens: List[int] = field(
        default=None,
        metadata={"help": "A list of tokens that will be suppressed at generation."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    audio_column_name: str = field(
        default="audio",
        metadata={
            "help": "The name of the dataset column containing the audio data. Defaults to 'audio'"
        },
    )
    text_column_name: str = field(
        default="text",
        metadata={
            "help": "The name of the dataset column containing the text data. Defaults to 'text'"
        },
    )
    max_duration_in_seconds: float = field(
        default=20.0,
        metadata={
            "help": "Filter out audio files that are longer than `max_duration_in_seconds` seconds"
        },
    )
    min_duration_in_seconds: float = field(
        default=0.0,
        metadata={
            "help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"
        },
    )
    preprocessing_only: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to only do data preprocessing and skip training. This is especially useful when data"
                " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
                " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
                " can consequently be loaded in distributed training"
            )
        },
    )
    language: str = field(
        default=None,
        metadata={
            "help": (
                "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
                "only. For English speech recognition, it should be set to `None`."
            )
        },
    )

    data_dir: str = field(
        default="/kaggle/input/bengaliai-speech",
        metadata={
            "help": "Directory containing `train.csv` and the `train_mp3s/` folder."
        },
    )


logger = logging.getLogger(__name__)


def main():
    parser = HfArgumentParser((Config, Seq2SeqTrainingArguments))

    cfg, training_args = parser.parse_args_into_dataclasses()

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config = AutoConfig.from_pretrained(cfg.model_name_or_path)

    config.update(
        {
            "forced_decoder_ids": cfg.forced_decoder_ids,
            "suppress_tokens": cfg.suppress_tokens,
        }
    )

    # SpecAugment for whisper models
    if getattr(config, "model_type", None) == "whisper":
        config.update({"apply_spec_augment": cfg.apply_spec_augment})

    feature_extractor = AutoFeatureExtractor.from_pretrained(cfg.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path)

    raw_datasets = DatasetDict()

    data_dir = Path(cfg.data_dir)

    raw_ds = load_dataset("csv", data_files=str(data_dir / "train.csv"), split="train")

    def add_mp3_path(examples):
        return {
            "audio": [str(data_dir / f"train_mp3s/{id_}.mp3") for id_ in examples["id"]]
        }

    raw_ds = raw_ds.map(add_mp3_path, batched=True, num_proc=cfg.preprocessing_num_workers)
    raw_ds = raw_ds.train_test_split(
        test_size=0.2, seed=training_args.seed, shuffle=True
    )

    raw_datasets["train"] = raw_ds["train"]
    raw_datasets["validation"] = raw_ds["test"]

    if cfg.max_train_samples:
        raw_datasets["train"] = raw_datasets["train"].select(
            range(min(cfg.max_train_samples, len(raw_datasets["train"])))
        )

    if cfg.max_eval_samples:
        raw_datasets["validation"] = raw_datasets["validation"].select(
            range(min(cfg.max_eval_samples, len(raw_datasets["validation"])))
        )

    # cast to audio
    raw_datasets = raw_datasets.cast_column(
        cfg.audio_column_name,
        datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
    )

    if cfg.language is not None:
        # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
        tokenizer.set_prefix_tokens(language=cfg.language, task="transcribe")

    # Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    max_input_length = cfg.max_duration_in_seconds * feature_extractor.sampling_rate
    min_input_length = cfg.min_duration_in_seconds * feature_extractor.sampling_rate
    audio_column_name = cfg.audio_column_name
    num_workers = cfg.preprocessing_num_workers
    text_column_name = cfg.text_column_name
    model_input_name = feature_extractor.model_input_names[0]
    # if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis
    forward_attention_mask = (
        getattr(config, "model_type", None) == "whisper"
        and getattr(config, "apply_spec_augment", False)
        and getattr(config, "mask_time_prob", 0) > 0
    )

    def prepare_dataset(batch):
        # process audio
        sample = batch[audio_column_name]
        inputs = feature_extractor(
            sample["array"],
            sampling_rate=sample["sampling_rate"],
            return_attention_mask=forward_attention_mask,
        )
        # process audio length
        batch[model_input_name] = inputs.get(model_input_name)[0]
        batch["input_length"] = len(sample["array"])
        if forward_attention_mask:
            batch["attention_mask"] = inputs.get("attention_mask")[0]

        # process targets
        input_str = batch[text_column_name]
        batch["labels"] = tokenizer(input_str).input_ids
        return batch

    with training_args.main_process_first(desc="dataset map pre-processing"):
        vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            remove_columns=next(iter(raw_datasets.values())).column_names,
            num_proc=cfg.preprocessing_num_workers,
            desc="preprocess train dataset",
        )

    # filter data that is shorter than min_input_length or longer than
    # max_input_length
    def is_audio_in_length_range(length):
        return length > min_input_length and length < max_input_length

    vectorized_datasets = vectorized_datasets.filter(
        is_audio_in_length_range,
        num_proc=num_workers,
        input_columns=["input_length"],
    )

    def save_chunks(ds, chunk_size, prefix):
        for i in range(0, len(ds), chunk_size):
            ii = min(i + chunk_size, len(ds))

            ds.select(range(i, ii)).to_parquet(f"{prefix}_{i}_to_{ii}.parquet")

    # For large datasets it is advised to run the preprocessing on a
    # single machine first with `--preprocessing_only`, since there will most likely
    # be a timeout when running the script in distributed mode.
    # The parquet chunks saved below are what train.py loads afterwards.
    if cfg.preprocessing_only:
        cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
        logger.info(f"Data preprocessing finished. Files cached at {cache}.")

        save_chunks(
            vectorized_datasets["train"], 1000, f"train_{training_args.output_dir}"
        )
        save_chunks(
            vectorized_datasets["validation"], 1000, f"eval_{training_args.output_dir}"
        )
        return


if __name__ == "__main__":
    main()

Writing preprocess.py
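The two bookkeeping steps at the end of `preprocess.py` are easy to misread, so here is a pure-Python sketch of them (the function names are hypothetical). The length filter compares raw sample counts, after resampling to 16 kHz, against exclusive bounds, and `save_chunks` writes one parquet file per `chunk_size` rows. The `train_` / `eval_` prefix is what lets `train.py` find the files with `train*.parquet` and `eval*.parquet`.

```python
import fnmatch

SAMPLING_RATE = 16_000  # the audio column is cast to the feature extractor's 16 kHz rate


def in_length_range(num_samples, min_seconds, max_seconds, sr=SAMPLING_RATE):
    """Same test as is_audio_in_length_range: both bounds are exclusive."""
    return min_seconds * sr < num_samples < max_seconds * sr


def chunk_bounds(num_rows, chunk_size):
    """(start, end) row ranges that save_chunks writes, one parquet file per range."""
    return [(i, min(i + chunk_size, num_rows)) for i in range(0, num_rows, chunk_size)]


def chunk_filenames(prefix, num_rows, chunk_size):
    """File names save_chunks would produce for a given prefix."""
    return [f"{prefix}_{start}_to_{end}.parquet" for start, end in chunk_bounds(num_rows, chunk_size)]


# With --output_dir "75k-samples", train chunks get the prefix "train_75k-samples",
# e.g. "train_75k-samples_0_to_1000.parquet", which train.py's "train*.parquet" glob matches.
names = chunk_filenames("train_75k-samples", 2500, 1000)
all_match = all(fnmatch.fnmatch(name, "train*.parquet") for name in names)
```

Note that a clip of exactly `--min_duration_in_seconds 2` (32,000 samples) is dropped, since the comparison is strict.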


# I've already uploaded a preprocessed dataset with 10k train samples and 1k eval samples, so I won't run this

In [3]:
# !python preprocess.py \
#  --model_name_or_path "openai/whisper-small" \
#  --language "Bengali" \
#  --output_dir "75k-samples" \
#  --preprocessing_num_workers 90 \
#  --preprocessing_only \
#  --text_column_name "sentence" \
#  --data_dir "data" \
#  --min_duration_in_seconds 2 \
#  --max_duration_in_seconds 30 \
#  --max_train_samples 75000 \
#  --max_eval_samples 5000 \
#  --apply_spec_augment

# Training script

Adapted from [here](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py)

In [4]:
%%writefile train.py

#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence speech recognition.
"""
# You can also adapt this script on your own sequence to sequence speech
# recognition task. Pointers for this are left as comments.

import logging
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import datasets
import evaluate
import torch
from datasets import DatasetDict, load_dataset

import transformers
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process

logger = logging.getLogger(__name__)


@dataclass
class Config:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        }
    )
    freeze_encoder: bool = field(
        default=False,
        metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."},
    )
    train_data_dir: str = field(default=None, metadata={"help": "Path to train files"})
    validation_data_dir: str = field(
        default=None, metadata={"help": "Path to eval files"}
    )
    forced_decoder_ids: List[List[int]] = field(
        default=None,
        metadata={
            "help": (
                "A list of pairs of integers which indicates a mapping from generation indices to token indices "
                "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
                "will always be a token of index 123."
            )
        },
    )
    suppress_tokens: List[int] = field(
        default=None,
        metadata={"help": "A list of tokens that will be suppressed at generation."},
    )
    apply_spec_augment: bool = field(
        default=False,
        metadata={
            "help": "Whether to apply *SpecAugment* data augmentation to the input features. This is currently only relevant for Wav2Vec2, HuBERT, WavLM and Whisper models."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    audio_column_name: str = field(
        default="audio",
        metadata={
            "help": "The name of the dataset column containing the audio data. Defaults to 'audio'"
        },
    )
    text_column_name: str = field(
        default="text",
        metadata={
            "help": "The name of the dataset column containing the text data. Defaults to 'text'"
        },
    )

    language: str = field(
        default=None,
        metadata={
            "help": (
                "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
                "only. For English speech recognition, it should be set to `None`."
            )
        },
    )


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor ([`WhisperProcessor`])
            The processor used for processing the data.
        decoder_start_token_id (`int`)
            The begin-of-sentence of the decoder.
        forward_attention_mask (`bool`)
            Whether to return attention_mask.
    """

    processor: Any
    decoder_start_token_id: int
    forward_attention_mask: bool

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        model_input_name = self.processor.model_input_names[0]
        input_features = [
            {model_input_name: feature[model_input_name]} for feature in features
        ]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt"
        )

        if self.forward_attention_mask:
            batch["attention_mask"] = torch.LongTensor(
                [feature["attention_mask"] for feature in features]
            )

        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # if the bos token was added during tokenization, cut it here,
        # since the model prepends it again when shifting the labels right
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


def main():
    # 1. Parse input arguments
    parser = HfArgumentParser((Config, Seq2SeqTrainingArguments))

    cfg, training_args = parser.parse_args_into_dataclasses()

    # 2. Detecting last checkpoint and eventually continue from last checkpoint
    last_checkpoint = None
    if (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif (
            last_checkpoint is not None and training_args.resume_from_checkpoint is None
        ):
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # 3. Load dataset
    vectorized_datasets = DatasetDict()

    if training_args.do_train:
        train_files = list(map(str, Path(cfg.train_data_dir).glob("train*.parquet")))
        vectorized_datasets["train"] = load_dataset(
            "parquet", data_files=train_files, split="train"
        )

    if training_args.do_eval:
        eval_files = list(map(str, Path(cfg.validation_data_dir).glob("eval*.parquet")))
        vectorized_datasets["eval"] = load_dataset(
            "parquet", data_files=eval_files, split="train"
        )

    # 4. Load pretrained model, tokenizer, and feature extractor
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        cfg.model_name_or_path,
    )

    config.update(
        {
            "forced_decoder_ids": cfg.forced_decoder_ids,
            "suppress_tokens": cfg.suppress_tokens,
        }
    )

    # SpecAugment for whisper models
    if getattr(config, "model_type", None) == "whisper":
        config.update({"apply_spec_augment": cfg.apply_spec_augment})

    feature_extractor = AutoFeatureExtractor.from_pretrained(
        cfg.model_name_or_path,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_name_or_path,
    )
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        cfg.model_name_or_path,
        config=config,
    )

    if model.config.decoder_start_token_id is None:
        raise ValueError(
            "Make sure that `config.decoder_start_token_id` is correctly defined"
        )

    if cfg.freeze_encoder:
        model.freeze_encoder()
        model.model.encoder.gradient_checkpointing = False

    if cfg.language is not None:
        # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
        tokenizer.set_prefix_tokens(language=cfg.language, task="transcribe")

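    if cfg.max_train_samples is not None and "train" in vectorized_datasets:
        # never select more rows than the dataset has
        vectorized_datasets["train"] = vectorized_datasets["train"].select(
            range(min(cfg.max_train_samples, len(vectorized_datasets["train"])))
        )

    if cfg.max_eval_samples is not None and "eval" in vectorized_datasets:
        vectorized_datasets["eval"] = vectorized_datasets["eval"].select(
            range(min(cfg.max_eval_samples, len(vectorized_datasets["eval"])))
        )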

    # 5. Load Metric
    metric = evaluate.load("wer")

    def compute_metrics(pred):
        pred_ids = pred.predictions

        pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id

        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)

        wer = metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    # 6. Create a single speech processor
    # make sure all processes wait until data is saved
    with training_args.main_process_first():
        # only the main process saves them
        if is_main_process(training_args.local_rank):
            # save feature extractor, tokenizer and config
            feature_extractor.save_pretrained(training_args.output_dir)
            tokenizer.save_pretrained(training_args.output_dir)
            config.save_pretrained(training_args.output_dir)

    processor = AutoProcessor.from_pretrained(training_args.output_dir)

    # if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis
    forward_attention_mask = (
        getattr(config, "model_type", None) == "whisper"
        and getattr(config, "apply_spec_augment", False)
        and getattr(config, "mask_time_prob", 0) > 0
    )

    # 7. Define data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
        forward_attention_mask=forward_attention_mask,
    )

    # 8. Initialize Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
        eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
        tokenizer=feature_extractor,
        data_collator=data_collator,
        compute_metrics=compute_metrics
        if training_args.predict_with_generate
        else None,
    )

    # 9. Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the feature extractor too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            cfg.max_train_samples
            if cfg.max_train_samples is not None
            else len(vectorized_datasets["train"])
        )
        metrics["train_samples"] = min(
            max_train_samples, len(vectorized_datasets["train"])
        )
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # 10. Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(
            metric_key_prefix="eval",
            max_length=training_args.generation_max_length,
            num_beams=training_args.generation_num_beams,
        )
        max_eval_samples = (
            cfg.max_eval_samples
            if cfg.max_eval_samples is not None
            else len(vectorized_datasets["eval"])
        )
        metrics["eval_samples"] = min(
            max_eval_samples, len(vectorized_datasets["eval"])
        )

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    main()

Writing train.py
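The label handling in `DataCollatorSpeechSeq2SeqWithPadding` and `compute_metrics` can be sketched with plain lists. This sketch assumes padding positions become `-100` (the default `ignore_index` of PyTorch's cross-entropy loss, so they contribute no loss) and that Whisper prepends `decoder_start_token_id` itself when it shifts the labels right, which is why a leading start token shared by every row is dropped. The function names are hypothetical and the token IDs in the comments are illustrative.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def collate_labels(label_seqs: List[List[int]], decoder_start_token_id: int) -> List[List[int]]:
    """Right-pad label sequences with IGNORE_INDEX and drop a shared leading start token."""
    max_len = max(len(seq) for seq in label_seqs)
    padded = [seq + [IGNORE_INDEX] * (max_len - len(seq)) for seq in label_seqs]
    # The model prepends decoder_start_token_id when it shifts labels right,
    # so keeping it here would duplicate it (same all-rows check as the collator).
    if all(row[0] == decoder_start_token_id for row in padded):
        padded = [row[1:] for row in padded]
    return padded


def restore_padding(label_ids: List[List[int]], pad_token_id: int) -> List[List[int]]:
    """Undo the -100 masking before decoding, as compute_metrics does."""
    return [[pad_token_id if tok == IGNORE_INDEX else tok for tok in row] for row in label_ids]


# Illustrative IDs (start token 50258, pad 50257):
# collate_labels([[50258, 1, 2, 3], [50258, 4]], 50258) -> [[1, 2, 3], [4, -100, -100]]
# restore_padding([[4, -100, -100]], 50257) -> [[4, 50257, 50257]]
```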


In [5]:
data_dir = "/kaggle/input/bengali-ai-asr-10k"

!torchrun --nproc_per_node 2 train.py \
 --model_name_or_path "bangla-speech-processing/BanglaASR" \
 --train_data_dir $data_dir \
 --validation_data_dir $data_dir \
 --language "Bengali" \
 --output_dir "whisper-base-bn" \
 --do_train \
 --do_eval \
 --fp16 \
 --group_by_length \
 --predict_with_generate \
 --dataloader_num_workers 1 \
 --overwrite_output_dir \
 --per_device_train_batch_size 4 \
 --length_column_name "input_length" \
 --report_to "none" \
 --metric_for_best_model "wer" \
 --greater_is_better False \
 --evaluation_strategy "epoch" \
 --save_strategy "epoch" \
 --save_total_limit 1 \
 --logging_steps 10 \
 --gradient_checkpointing \
 --warmup_steps 50 \
 --apply_spec_augment True \
 --num_train_epochs 3 \
 --learning_rate "1e-5"

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']
Downloading and preparing dataset parquet/default to /root/.cache/huggingface/datasets/parquet/default-81385bd261e94688/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7...
Downloading data files: 100%|███████