In [1]:
from huggingface_hub import notebook_login

# Show the Hugging Face login widget so a Hub access token can be entered
# interactively instead of being hard-coded in the notebook.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
!pip install datasets evaluate jiwer transformers[torch]

Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate
  Downloading evaluate-0.4.2-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jiwer
  Downloading jiwer-3.0.4-py3-none-any.whl (21 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (40.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.8/40.8 MB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
Collecting requests>=2.32.2 (from datasets)
  Downloading requests

In [1]:
import torch
from transformers import WhisperTokenizer, WhisperForConditionalGeneration, WhisperProcessor
from datasets import load_dataset, DatasetDict
from transformers import WhisperFeatureExtractor
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer


In [8]:
checkpoint = "openai/whisper-small"
model = WhisperForConditionalGeneration.from_pretrained(checkpoint)

# Load the processor once and reuse its two components. The data collator and
# the inference cells already go through `processor.tokenizer` /
# `processor.feature_extractor`, so sharing the same objects guarantees that
# preprocessing, padding and decoding all use one configuration, and avoids
# loading the tokenizer and feature extractor a second time.
processor = WhisperProcessor.from_pretrained(checkpoint, language='en', task="transcribe")
tokenizer = processor.tokenizer
feature_extractor = processor.feature_extractor

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
from datasets import Audio  # only needed for the resampling cast below


def prepare_dataset(batch):
    """Turn one raw example into the inputs used for training.

    Reads the decoded waveform and its sampling rate from ``batch["audio"]``
    and the reference text from ``batch["transcription"]``.

    Returns a dict with:
      * ``input_features`` - first (only) element of the feature extractor's
        output for this waveform;
      * ``labels`` - token ids of the transcription produced by ``tokenizer``.
    """
    audio = batch["audio"]
    features = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
    return {
        "input_features": features.input_features[0],
        "labels": tokenizer(batch["transcription"]).input_ids,
    }


torgo_raw = load_dataset("jmaczan/TORGO-very-small", split="train")

# Decode every clip at the rate the feature extractor was configured for, so
# preprocessing does not depend on the rate the files were recorded at.
torgo_raw = torgo_raw.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))

# The dataset ships a single split; hold out 20% for evaluation (fixed seed
# keeps the partition reproducible).
split_datasets = torgo_raw.train_test_split(test_size=0.2, seed=42)

torgo = DatasetDict()
for split_name, source_key in (("train", "train"), ("eval", "test")):
    subset = split_datasets[source_key]
    # Drop the raw columns: only `input_features` and `labels` are kept.
    torgo[split_name] = subset.map(prepare_dataset, remove_columns=subset.column_names)

print("Train dataset:", torgo["train"])
print("Eval dataset:", torgo["eval"])

Resolving data files:   0%|          | 0/92 [00:00<?, ?it/s]

Map:   0%|          | 0/72 [00:00<?, ? examples/s]

Map:   0%|          | 0/19 [00:00<?, ? examples/s]

Train dataset: Dataset({
    features: ['input_features', 'labels'],
    num_rows: 72
})
Eval dataset: Dataset({
    features: ['input_features', 'labels'],
    num_rows: 19
})


In [10]:
@dataclass
class DataCollator:
    """Pad a list of preprocessed examples into one training batch.

    Audio features and label ids are padded independently: the former with
    ``processor.feature_extractor.pad`` and the latter with
    ``processor.tokenizer.pad``. Label positions that the tokenizer marks as
    padding (attention mask != 1) are replaced by -100. If every label row
    begins with ``decoder_start_token_id``, that first column is removed.
    """

    # Object exposing `.feature_extractor` and `.tokenizer`, each with a `pad` method.
    processor: Any
    # Token id that is stripped from the front of the labels when all rows start with it.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split each example into its audio part and its label part.
        audio_items = []
        token_items = []
        for example in features:
            audio_items.append({"input_features": example["input_features"]})
            token_items.append({"input_ids": example["labels"]})

        batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")
        padded_tokens = self.processor.tokenizer.pad(token_items, return_tensors="pt")

        # Positions added by padding get -100 in the labels.
        is_padding = padded_tokens.attention_mask.ne(1)
        labels = padded_tokens["input_ids"].masked_fill(is_padding, -100)

        # Drop the leading start token only when it is present in every row.
        first_column = labels[:, 0]
        all_start_with_bos = (first_column == self.decoder_start_token_id).all().cpu().item()
        batch["labels"] = labels[:, 1:] if all_start_with_bos else labels

        return batch

# Word error rate metric used by the evaluation callback.
metric = evaluate.load("wer")

# The collator is given the model's decoder start token id so it can strip
# that token from the labels when every sequence in a batch begins with it.
data_collator = DataCollator(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [11]:
def compute_metrics(pred):
    """Compute the word error rate (in percent) for one evaluation pass.

    Args:
        pred: object with ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, with -100 marking ignored
            positions), both array-like with ``.copy()`` and boolean-mask
            assignment (e.g. NumPy arrays).

    Returns:
        ``{"wer": <float>}`` where the value is 100 times the score returned
        by the module-level ``metric``.
    """
    pad_id = tokenizer.pad_token_id

    # Work on copies: the arrays belong to the caller and must keep their
    # -100 markers after this function returns.
    pred_ids = pred.predictions.copy()
    label_ids = pred.label_ids.copy()

    # -100 is not a valid token id, so swap it for the pad token before
    # decoding; padding is then dropped together with the special tokens.
    pred_ids[pred_ids == -100] = pad_id
    label_ids[label_ids == -100] = pad_id

    # We do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [12]:
# fp16 mixed precision needs a CUDA device; fall back to full precision
# elsewhere so this cell also runs on a CPU-only runtime.
use_fp16 = torch.cuda.is_available()

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-hi",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    remove_unused_columns=False,
    warmup_steps=5,
    max_steps=200,
    gradient_checkpointing=True,
    fp16=use_fp16,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,  # evaluate on generated text so WER can be computed
    generation_max_length=225,
    # Save and evaluate on the same schedule so the best checkpoint can be
    # picked by WER at the end of training.
    save_steps=20,
    eval_steps=20,
    logging_steps=10,
    # A checkpoint is written every 20 steps; cap how many are kept on disk.
    # With load_best_model_at_end the best one is retained alongside the latest.
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
    push_to_hub=False,
)

# Wire the model, data, collator and metric callback into a single trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=torgo["train"],
    eval_dataset=torgo["eval"],
    # The feature extractor (not the text tokenizer) is what gets passed in
    # the `tokenizer` slot here.
    tokenizer=processor.feature_extractor,
)

max_steps is given, it will override any value given in num_train_epochs


In [13]:
# Run fine-tuning. The returned TrainOutput summary is shown as the cell result.
trainer.train()



Step,Training Loss,Validation Loss,Wer
20,1.1405,0.781267,62.0
40,0.0281,0.500371,44.0
60,0.0009,0.489232,36.0
80,0.0004,0.490731,38.0
100,0.0003,0.492956,38.0
120,0.0002,0.493935,36.0
140,0.0002,0.494853,36.0
160,0.0002,0.495201,36.0
180,0.0002,0.49522,36.0
200,0.0002,0.495602,36.0


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618

TrainOutput(global_step=200, training_loss=0.18683763453154825, metrics={'train_runtime': 1470.7967, 'train_samples_per_second': 2.176, 'train_steps_per_second': 0.136, 'total_flos': 8.311259529216e+17, 'train_loss': 0.18683763453154825, 'epoch': 40.0})

In [None]:
# Write the model and the processor into the same directory so both can be
# reloaded from one path.
for component in (model, processor):
    component.save_pretrained("./whisper-small")

In [None]:
#using the finetuned model
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
import soundfile as sf

In [None]:
# Reload the processor and the model from the local "./whisper-small" directory.
processor = WhisperProcessor.from_pretrained("./whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("./whisper-small")


In [None]:
# Use a forward slash: a backslash separator only resolves on Windows, and
# "\d" is an invalid escape sequence inside a normal string literal.
audio_input, sample_rate = sf.read("resources/download.wav")

# soundfile returns a (frames, channels) array for multi-channel files;
# average the channels so a single mono waveform is passed on.
if audio_input.ndim > 1:
    audio_input = audio_input.mean(axis=1)

# The file's own sampling rate is forwarded unchanged.
# NOTE(review): the feature extractor is expected to reject a rate that
# differs from the one it was configured with - resample beforehand if needed.
inputs = processor.feature_extractor(audio_input, sampling_rate=sample_rate, return_tensors="pt")

# Inference only: no gradients are needed for generation.
with torch.no_grad():
    predicted_ids = model.generate(inputs["input_features"])

# A single clip was passed in, so take the first (only) decoded string.
transcription = processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]

print("Transcription:", transcription)