In [1]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", language="Korean", task="transcribe")

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="Korean", task="transcribe")

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:
from datasets import load_dataset
dataset = load_dataset("Bingsu/zeroth-korean")

In [3]:
dataset

DatasetDict({
    train: Dataset({
        features: ['audio', 'text'],
        num_rows: 22263
    })
    test: Dataset({
        features: ['audio', 'text'],
        num_rows: 457
    })
})

In [4]:
dataset['train'][0]

{'audio': {'path': None,
  'array': array([-3.05175781e-05,  0.00000000e+00, -3.05175781e-05, ...,
          0.00000000e+00,  0.00000000e+00, -6.10351562e-05]),
  'sampling_rate': 16000},
 'text': '인사를 결정하는 과정에서 당 지도부가 우 원내대표 및 원내지도부와 충분한 상의를 거치지 않은 채 일방적으로 인사를 했다는 불만도 원내지도부를 중심으로 흘러나왔다'}

In [5]:
def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["text"]).input_ids
    return batch

In [6]:
input_dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names["train"], num_proc=4)

In [7]:
input_dataset

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 22263
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 457
    })
})

In [8]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


In [9]:
import evaluate
metric = evaluate.load('cer')

In [10]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [11]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

In [12]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    cer = 100 * metric.compute(predictions=pred_str, references=label_str)

    # return {"wer": wer}
    return {"cer": cer}


In [24]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-tiny-ko",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    num_train_epochs = 2,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="cer", # 한국어의 경우 wer보다 cer이 나음
    greater_is_better=False,
    push_to_hub=False,  # True로 하면 huggingface에 push
)

In [25]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=input_dataset["train"],
    eval_dataset=input_dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)


In [26]:
trainer.train()



Epoch,Training Loss,Validation Loss,Cer
1,0.3503,0.406906,13.46184
2,0.2425,0.346621,11.569567


Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}
There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


TrainOutput(global_step=2784, training_loss=0.3755353070881175, metrics={'train_runtime': 4302.8633, 'train_samples_per_second': 10.348, 'train_steps_per_second': 0.647, 'total_flos': 1.09618047885312e+18, 'train_loss': 0.3755353070881175, 'epoch': 2.0})

In [None]:
# Package                   Version
# ------------------------- ------------
# absl-py                   2.1.0
# accelerate                0.30.1
# aiofiles                  23.2.1
# aiohttp                   3.9.5
# aiosignal                 1.3.1
# altair                    5.3.0
# annotated-types           0.6.0
# anyio                     4.3.0
# asttokens                 2.4.1
# attrs                     23.2.0
# audioread                 3.0.1
# Brotli                    1.0.9
# certifi                   2024.2.2
# cffi                      1.16.0
# charset-normalizer        2.0.4
# click                     8.1.7
# comm                      0.2.2
# contourpy                 1.2.1
# cycler                    0.12.1
# datasets                  2.19.1
# debugpy                   1.6.7
# decorator                 5.1.1
# dill                      0.3.8
# dnspython                 2.6.1
# email_validator           2.1.1
# evaluate                  0.4.2
# exceptiongroup            1.2.0
# executing                 2.0.1
# fastapi                   0.111.0
# fastapi-cli               0.0.3
# ffmpy                     0.3.2
# filelock                  3.13.1
# fonttools                 4.51.0
# frozenlist                1.4.1
# fsspec                    2024.3.1
# gmpy2                     2.1.2
# gradio                    4.31.3
# gradio_client             0.16.3
# grpcio                    1.63.0
# h11                       0.14.0
# httpcore                  1.0.5
# httptools                 0.6.1
# httpx                     0.27.0
# huggingface-hub           0.23.0
# idna                      3.7
# importlib_metadata        7.1.0
# importlib_resources       6.4.0
# ipykernel                 6.29.3
# ipython                   8.24.0
# jedi                      0.19.1
# Jinja2                    3.1.3
# jiwer                     3.0.4
# joblib                    1.4.2
# jsonschema                4.22.0
# jsonschema-specifications 2023.12.1
# jupyter_client            8.6.1
# jupyter_core              5.7.2
# kiwisolver                1.4.5
# lazy_loader               0.4
# librosa                   0.10.2.post1
# llvmlite                  0.42.0
# Markdown                  3.6
# markdown-it-py            3.0.0
# MarkupSafe                2.1.3
# matplotlib                3.9.0
# matplotlib-inline         0.1.7
# mdurl                     0.1.2
# mkl-fft                   1.3.8
# mkl-random                1.2.4
# mkl-service               2.4.0
# mpmath                    1.3.0
# msgpack                   1.0.8
# multidict                 6.0.5
# multiprocess              0.70.16
# nest_asyncio              1.6.0
# networkx                  3.1
# numba                     0.59.1
# numpy                     1.26.4
# orjson                    3.10.3
# packaging                 24.0
# pandas                    2.2.2
# parso                     0.8.4
# pexpect                   4.9.0
# pickleshare               0.7.5
# pillow                    10.3.0
# pip                       24.0
# platformdirs              4.2.2
# pooch                     1.8.1
# prompt-toolkit            3.0.42
# protobuf                  5.26.1
# psutil                    5.9.8
# ptyprocess                0.7.0
# pure-eval                 0.2.2
# pyarrow                   16.1.0
# pyarrow-hotfix            0.6
# pycparser                 2.22
# pydantic                  2.7.1
# pydantic_core             2.18.2
# pydub                     0.25.1
# Pygments                  2.18.0
# pyparsing                 3.1.2
# PySocks                   1.7.1
# python-dateutil           2.9.0
# python-dotenv             1.0.1
# python-multipart          0.0.9
# pytz                      2024.1
# PyYAML                    6.0.1
# pyzmq                     25.1.2
# rapidfuzz                 3.9.0
# referencing               0.35.1
# regex                     2024.5.15
# requests                  2.31.0
# rich                      13.7.1
# rpds-py                   0.18.1
# ruff                      0.4.4
# safetensors               0.4.3
# scikit-learn              1.4.2
# scipy                     1.13.0
# semantic-version          2.10.0
# setuptools                69.5.1
# shellingham               1.5.4
# six                       1.16.0
# sniffio                   1.3.1
# soundfile                 0.12.1
# soxr                      0.3.7
# stack-data                0.6.2
# starlette                 0.37.2
# sympy                     1.12
# tensorboard               2.16.2
# tensorboard-data-server   0.7.2
# threadpoolctl             3.5.0
# tiktoken                  0.7.0
# tokenizers                0.19.1
# tomlkit                   0.12.0
# toolz                     0.12.1
# torch                     2.1.2
# torchaudio                2.1.2
# torchvision               0.16.2
# tornado                   6.4
# tqdm                      4.66.4
# traitlets                 5.14.3
# transformers              4.40.2
# triton                    2.1.0
# typer                     0.12.3
# typing_extensions         4.11.0
# tzdata                    2024.1
# ujson                     5.10.0
# urllib3                   2.2.1
# uvicorn                   0.29.0
# uvloop                    0.19.0
# watchfiles                0.21.0
# wcwidth                   0.2.13
# websockets                11.0.3
# Werkzeug                  3.0.3
# wheel                     0.43.0
# xxhash                    3.4.1
# yarl                      1.9.4
# zipp                      3.17.0