In [1]:
!pip install jiwer datasets transformers accelerate evaluate soundfile librosa

Collecting jiwer
  Downloading jiwer-3.0.3-py3-none-any.whl.metadata (2.6 kB)
Collecting datasets
  Downloading datasets-2.18.0-py3-none-any.whl.metadata (20 kB)
Collecting transformers
  Downloading transformers-4.39.3-py3-none-any.whl.metadata (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting accelerate
  Downloading accelerate-0.28.0-py3-none-any.whl.metadata (18 kB)
Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl.metadata (9.4 kB)
Collecting soundfile
  Downloading soundfile-0.12.1-py2.py3-none-manylinux_2_31_x86_64.whl.metadata (14 kB)
Collecting librosa
  Downloading librosa-0.10.1-py3-none-any.whl.metadata (8.3 kB)
Collecting rapidfuzz<4,>=3 (from jiwer)
  Downloading rapidfuzz-3.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting pyarrow>=12.0.0 (from datasets)
  Downloading pyarrow-15.0.2-cp310-cp310-manylin

In [1]:
from datasets import disable_caching

# Stop `datasets` from reloading previously cached transform results in this session,
# so the `.map(...)` preprocessing below is recomputed instead of reused from an earlier run.
disable_caching()

In [2]:
from datasets import Audio, load_dataset

# Load the local "audiofolder" export (audio files plus `transcription` metadata).
asr_data = load_dataset("audiofolder", data_dir="./export", keep_in_memory=True)

# WhisperFeatureExtractor raises a ValueError for any sampling rate other than 16 kHz.
# Casting the column makes `datasets` resample on decode; it is a no-op for files
# that are already 16 kHz.
asr_data = asr_data.cast_column("audio", Audio(sampling_rate=16_000))

Resolving data files:   0%|          | 0/31234 [00:00<?, ?it/s]

In [3]:
# Hold out 3% of the examples for evaluation. The fixed seed makes the split
# reproducible, so WER numbers from different runs are measured on the same test set.
# NOTE: this rebinds `asr_data`; re-running this cell alone would re-split the
# already-reduced train split.
asr_data = asr_data["train"].train_test_split(test_size=0.03, seed=42)

In [4]:
# Show split sizes and column names of the train/test DatasetDict.
asr_data

DatasetDict({
    train: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 30296
    })
    test: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 937
    })
})

In [5]:
# Hub checkpoint used for the processor, feature extractor, tokenizer and model below.
model_id = "openai/whisper-large-v3-turbo"

In [6]:
from transformers import WhisperProcessor

# Fix the language as well as the task so the tokenizer's label prefix is
# <|startoftranscript|><|cs|><|transcribe|><|notimestamps|>, matching the Czech
# decoding used at evaluation time. Without `language`, no language token is emitted.
processor = WhisperProcessor.from_pretrained(model_id, language="cs", task="transcribe")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
from transformers import WhisperFeatureExtractor

# Converts raw waveforms into the `input_features` used by `prepare_dataset`.
# Loaded from the same checkpoint as `processor.feature_extractor`.
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)

In [8]:
from transformers import WhisperTokenizer

# Encodes transcriptions into label ids (and decodes predictions in `compute_metrics`).
# `language="cs"` adds the <|cs|> token to the label prefix so training targets match
# Czech decoding at evaluation time; without it the prefix has no language token.
tokenizer = WhisperTokenizer.from_pretrained(model_id, language="cs", task="transcribe")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
def prepare_dataset(batch):
    """Attach Whisper model inputs to a single dataset example.

    Mutates ``batch`` in place and returns it, adding:
      * ``input_features`` -- features of ``batch["audio"]["array"]`` produced by the
        module-level ``feature_extractor`` (first row of its output);
      * ``labels`` -- token ids of ``batch["transcription"]`` from the module-level
        ``tokenizer``.
    """
    clip = batch["audio"]
    extracted = feature_extractor(clip["array"], sampling_rate=clip["sampling_rate"])
    # The extractor returns a batch; this example is its only row.
    batch["input_features"] = extracted.input_features[0]

    encoded = tokenizer(batch["transcription"])
    batch["labels"] = encoded.input_ids
    return batch

In [10]:
# Replace the raw `audio` / `transcription` columns with model inputs
# (`input_features`, `labels`) in both splits.
# NOTE(review): with `keep_in_memory=True` and caching disabled, the mapped features are
# likely held in RAM. Assuming the v3 extractor's 128 mel bins x 3000 frames of float32,
# that is roughly 1.5 MB per example; consider `with_transform` for on-the-fly features
# if memory is tight -- confirm.
asr_data = asr_data.map(prepare_dataset, remove_columns=asr_data.column_names["train"])

Map:   0%|          | 0/30296 [00:00<?, ? examples/s]

Map:   0%|          | 0/937 [00:00<?, ? examples/s]

In [11]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate Whisper examples into a padded training batch.

    Each feature is a dict with ``input_features`` and ``labels`` (token ids).
    Audio features are padded with the processor's feature extractor. Labels are
    padded with its tokenizer, and padded positions are set to -100 so the loss
    ignores them.

    Attributes:
        processor: object exposing ``feature_extractor.pad`` and ``tokenizer.pad``
            (e.g. a ``WhisperProcessor``).
        decoder_start_token_id: id the model prepends to the decoder inputs when it
            shifts ``labels`` right. If every label row already starts with it, it is
            stripped here so the start token is not duplicated. Defaults to the
            tokenizer's ``<|startoftranscript|>`` id.
    """

    processor: Any
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Audio side: pad/stack the per-example input features into a tensor batch.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Text side: pad label ids to the longest sequence in the batch.
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # -100 is ignored by the cross-entropy loss, so padded positions do not count.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # For WhisperTokenizer, `bos_token_id` is <|endoftext|>, not the
        # <|startoftranscript|> token that actually starts the labels. So compare
        # against the decoder start id instead.
        start_id = self.decoder_start_token_id
        if start_id is None:
            start_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

        if labels.shape[1] > 0 and (labels[:, 0] == start_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [12]:
# Collator the trainer uses to pad audio features and labels (padding -> -100) per batch.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [13]:
import evaluate

# Word error rate metric consumed by `compute_metrics` (requires `jiwer`).
metric = evaluate.load("wer")

In [14]:
def compute_metrics(pred):
    """Compute word error rate (in percent) for a Seq2SeqTrainer evaluation.

    Args:
        pred: evaluation output exposing ``predictions`` (generated token ids, since
            ``predict_with_generate=True``) and ``label_ids``.

    Returns:
        ``{"wer": 100 * WER}`` as computed by the module-level ``metric``.
    """
    import numpy as np

    pad_id = tokenizer.pad_token_id

    # -100 marks ignored/padded positions: the data collator uses it for labels, and
    # the Trainer may use it when padding predictions of different lengths (version
    # dependent). It is not a valid token id, so map it to the pad token before
    # decoding. np.where builds new arrays instead of overwriting the caller's in place.
    pred_ids = np.where(pred.predictions != -100, pred.predictions, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [15]:
from transformers import WhisperForConditionalGeneration

# Pretrained Whisper checkpoint to fine-tune.
model = WhisperForConditionalGeneration.from_pretrained(model_id)

In [22]:
# Let the explicit language/task below define the decoder prompt. Clear the legacy
# forced_decoder_ids on both configs so they cannot conflict with them.
model.config.forced_decoder_ids = None
model.generation_config.forced_decoder_ids = None
# NOTE(review): generate() reads `model.generation_config`; whether this `model.config`
# value has any effect on decoding depends on the transformers version -- confirm.
model.config.suppress_tokens = []
# Transcribe (not translate) Czech audio during evaluation.
model.generation_config.language = "cs"
model.generation_config.task = "transcribe"

In [None]:
from transformers import Seq2SeqTrainingArguments

# Training hyperparameters. The effective train batch size per device is
# per_device_train_batch_size * gradient_accumulation_steps = 4 * 8 = 32.
training_args = Seq2SeqTrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,
    warmup_steps=500,
    num_train_epochs=3,
    gradient_checkpointing=True,  # trade extra compute for lower activation memory
    fp16=True,
    # NOTE(review): renamed to `eval_strategy` in newer transformers releases -- check the installed version.
    evaluation_strategy="steps",
    per_device_eval_batch_size=2,
    predict_with_generate=True,  # evaluate on generated token ids so WER can be computed
    generation_max_length=225,
    save_steps=500,  # must be a multiple of eval_steps when load_best_model_at_end=True
    eval_steps=500,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",  # key returned by compute_metrics
    greater_is_better=False,  # lower WER is better
)

In [27]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=asr_data["train"],
    eval_dataset=asr_data["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # Passed so the Trainer saves the feature extractor alongside checkpoints.
    # NOTE(review): the tokenizer is not saved this way -- consider
    # `processor.save_pretrained(training_args.output_dir)` after training.
    tokenizer=processor.feature_extractor,
)

In [None]:
# Run fine-tuning; evaluation, checkpointing and logging cadence come from `training_args`.
trainer.train()

Step,Training Loss,Validation Loss


In [23]:
# Ad-hoc memory cleanup (note: this cell ran as In [23], out of order with the cells above).
# NOTE(review): `torch.cuda.empty_cache()` only releases cached blocks that are no longer
# referenced; while `model` / `trainer` stay bound in the kernel, most GPU memory remains
# allocated (e.g. `del trainer, model` first if the goal is to free it).
import torch, gc
gc.collect()
torch.cuda.empty_cache() 