# Fine-tuning Whisper for the Any2Chinese task

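Whisper is a speech recognition model released by OpenAI. It recognizes speech in many languages and transcribes it into text in the same language. It can also translate: it transcribes speech in many languages into English text. Since Whisper's release, many open-source projects have fine-tuned or improved it, further strengthening its recognition of low-resource languages. With the help of C/C++ implementations, PEFT and related techniques, Whisper can now run efficiently on small devices.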

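However, few projects have yet extended Whisper to transcribe new languages, or fine-tuned it to transcribe speech in any language directly into Chinese text. Doing so would unlock more of the model's potential.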

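This tutorial starts from the Whisper model included in the transformers package and modifies its tokenizer so that Whisper can transcribe new languages and transcribe speech directly into Chinese text. We call the result Whisper-Any2Chinese.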

In [1]:
from huggingface_hub import notebook_login, login
from datasets import load_dataset, DatasetDict, Audio
from pprint import pprint
import os
import gc
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
import evaluate

from transformers import (
    WhisperFeatureExtractor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
# prepare_model_for_int8_training was deprecated in later peft releases in favor of prepare_model_for_kbit_training
from peft import prepare_model_for_int8_training, LoraConfig, LoraModel, PeftModel, PeftConfig, get_peft_model

# locally modified tokenizer (tokenization_whisper.py) that supports the new tags
from tokenization_whisper import WhisperTokenizer
from config.whisper_config import (
    MODEL_CONFIG_FILE_ROOT_PATH, DATASET_NAME_OR_PATH, LANGUAGE, LANGUAGE_ABBR,
    MODEL_NAME_OR_PATH, TASK, NEW_TOKENS, OUTPUT_DIR,
)






In [None]:
# login()

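Before fine-tuning, you need to prepare training data. The data format and metadata files follow the two datasets under "./data/*" in this tutorial: "en_mini" is an English dataset and "ug_mini" is a Uyghur dataset, each with the corresponding Chinese transcriptions. You can prepare your own datasets as needed. The data in this tutorial comes from the Common Voice dataset, which you can download [here](https://commonvoice.mozilla.org/). (This is only a demo, so the datasets contain very little data.)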
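The exact layout is defined by the files under "./data/*", which are not reproduced here. As a minimal sketch, assuming each split directory follows the Hugging Face AudioFolder convention (a `metadata.csv` whose `file_name` column points to the audio file) and uses the `sentence` and `chinese` columns seen in the loaded dataset, a metadata file could be written as follows. The helper `write_audiofolder_metadata` and the file names are hypothetical.

```python
import csv
import os
from typing import Iterable, Tuple


def write_audiofolder_metadata(split_dir: str, rows: Iterable[Tuple[str, str, str]]) -> str:
    """Write metadata.csv for one split of an (assumed) AudioFolder-style dataset.

    Each row is (file_name, sentence, chinese): the audio path relative to
    split_dir, the source-language transcript, and the Chinese target text.
    """
    os.makedirs(split_dir, exist_ok=True)
    path = os.path.join(split_dir, "metadata.csv")
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # AudioFolder links each row to its audio file through the file_name column
        writer.writerow(["file_name", "sentence", "chinese"])
        writer.writerows(rows)
    return path


if __name__ == "__main__":
    # hypothetical clip name and transcripts, for illustration only
    print(write_audiofolder_metadata(
        "data/my_mini/train",
        [("clip_0001.mp3", "<source transcript>", "“我不能滥用这个，”女孩想。")],
    ))
```

Splits are assumed to be separate `train/`, `test/` and `validation/` directories, matching the three splits loaded in the next cell.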

In [2]:
# load dataset
common_voice = DatasetDict()

common_voice["train"] = load_dataset(DATASET_NAME_OR_PATH, split="train")
common_voice["test"] = load_dataset(DATASET_NAME_OR_PATH, split="test")
common_voice["validation"] = load_dataset(DATASET_NAME_OR_PATH, split="validation")

print(common_voice)

Downloading and preparing dataset audiofolder/ug_mini to /home/towardspring/.cache/huggingface/datasets/audiofolder/ug_mini-3bff67ba257ec961/0.0.0/6cbdd16f8688354c63b4e2a36e1585d05de285023ee6443ffd71c4182055c0fc...

Dataset audiofolder downloaded and prepared to /home/towardspring/.cache/huggingface/datasets/audiofolder/ug_mini-3bff67ba257ec961/0.0.0/6cbdd16f8688354c63b4e2a36e1585d05de285023ee6443ffd71c4182055c0fc. Subsequent calls will reuse this data.


DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence', 'chinese'],
        num_rows: 55
    })
    test: Dataset({
        features: ['audio', 'sentence', 'chinese'],
        num_rows: 12
    })
    validation: Dataset({
        features: ['audio', 'sentence', 'chinese'],
        num_rows: 15
    })
})


In [3]:
# load model
print(f'MODEL_NAME_OR_PATH:{MODEL_NAME_OR_PATH}')
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_NAME_OR_PATH)
tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME_OR_PATH, language=LANGUAGE_ABBR, task=TASK)
processor = WhisperProcessor.from_pretrained(MODEL_NAME_OR_PATH, language=LANGUAGE_ABBR, task=TASK)
print('Model has been loaded.')

MODEL_NAME_OR_PATH:/home/towardspring/.cache/huggingface/hub/models--openai--whisper-small/snapshots/e34e8ae444c29815eca53e11383ea13b2e362eb0
Model has been loaded.


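Whisper does not natively support new languages or transcribing multilingual speech into Chinese. Therefore, before fine-tuning, we need to add new language tags and a Chinese-translation task tag. For example, to transcribe Uyghur we add a <|ug|> token to the tokenizer; to translate speech in many languages directly into Chinese we add a <|translate-chinese|> token. This tutorial only demonstrates transcribing English speech into Chinese text; you can add whatever language or translation tags you need.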
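The new tags only change the decoder prompt. Judging from the encoded sample printed later in this notebook, the prompt becomes `<|startoftranscript|><|ug|><|translate-chinese|><|notimestamps|>`. The sketch below builds such a prompt from a token-to-id mapping. It is a hypothetical helper (`build_prefix_ids`), not the code in `tokenization_whisper.py`; the toy vocabulary reuses the ids printed in that sample. Note that the new ids (51865 and 51866) lie outside whisper-small's original 51,865-entry vocabulary, so the model's token embeddings must be enlarged to cover them.

```python
from typing import Dict, List


def build_prefix_ids(vocab: Dict[str, int], language: str, task: str,
                     no_timestamps: bool = True) -> List[int]:
    """Return decoder prompt ids: <|startoftranscript|><|lang|><|task|>[<|notimestamps|>]."""
    tokens = ["<|startoftranscript|>", f"<|{language}|>", f"<|{task}|>"]
    if no_timestamps:
        tokens.append("<|notimestamps|>")
    missing = [t for t in tokens if t not in vocab]
    if missing:
        # a new language or task tag must be added to the tokenizer before it can be used
        raise KeyError(f"tokens not in vocabulary: {missing}")
    return [vocab[t] for t in tokens]


# toy vocabulary holding only the ids shown in this notebook's encoded sample
toy_vocab = {
    "<|startoftranscript|>": 50258,
    "<|notimestamps|>": 50363,
    "<|translate-chinese|>": 51865,
    "<|ug|>": 51866,
}
print(build_prefix_ids(toy_vocab, "ug", "translate-chinese"))  # [50258, 51866, 51865, 50363]
```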

In [4]:
import json

tokenizer.add_tokens(NEW_TOKENS)
# Note: set MODEL_CONFIG_FILE_ROOT_PATH in config/whisper_config.py to your own path
tokenizer.save_pretrained(MODEL_CONFIG_FILE_ROOT_PATH)

# register the new tokens as additional special tokens in the saved special_tokens_map.json
special_tokens_file = os.path.join(MODEL_CONFIG_FILE_ROOT_PATH, 'special_tokens_map.json')
with open(special_tokens_file, 'r', encoding='utf-8') as f:
    special_tokens_map = json.load(f)
print(special_tokens_map)
additional_special_tokens = special_tokens_map['additional_special_tokens']
for new_item in NEW_TOKENS:
    if new_item not in additional_special_tokens:
        additional_special_tokens.append(new_item)

# save to special_tokens_map.json (this changes the file on disk only;
# the tokenizer object already in memory is not modified)
special_tokens_map['additional_special_tokens'] = additional_special_tokens
with open(special_tokens_file, 'w', encoding='utf-8') as f:
    json.dump(special_tokens_map, f, ensure_ascii=False)

{'additional_special_tokens': ['<|endoftext|>', '<|startoftranscript|>', '<|en|>', '<|zh|>', '<|de|>', '<|es|>', '<|ru|>', '<|ko|>', '<|fr|>', '<|ja|>', '<|pt|>', '<|tr|>', '<|pl|>', '<|ca|>', '<|nl|>', '<|ar|>', '<|sv|>', '<|it|>', '<|id|>', '<|hi|>', '<|fi|>', '<|vi|>', '<|he|>', '<|uk|>', '<|el|>', '<|ms|>', '<|cs|>', '<|ro|>', '<|da|>', '<|hu|>', '<|ta|>', '<|no|>', '<|th|>', '<|ur|>', '<|hr|>', '<|bg|>', '<|lt|>', '<|la|>', '<|mi|>', '<|ml|>', '<|cy|>', '<|sk|>', '<|te|>', '<|fa|>', '<|lv|>', '<|bn|>', '<|sr|>', '<|az|>', '<|sl|>', '<|kn|>', '<|et|>', '<|mk|>', '<|br|>', '<|eu|>', '<|is|>', '<|hy|>', '<|ne|>', '<|mn|>', '<|bs|>', '<|kk|>', '<|sq|>', '<|sw|>', '<|gl|>', '<|mr|>', '<|pa|>', '<|si|>', '<|km|>', '<|sn|>', '<|yo|>', '<|so|>', '<|af|>', '<|oc|>', '<|ka|>', '<|be|>', '<|tg|>', '<|sd|>', '<|gu|>', '<|am|>', '<|yi|>', '<|lo|>', '<|uz|>', '<|fo|>', '<|ht|>', '<|ps|>', '<|tk|>', '<|nn|>', '<|mt|>', '<|sa|>', '<|lb|>', '<|my|>', '<|bo|>', '<|tl|>', '<|mg|>', '<|as|>', '<|tt|>

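Before training, the tokenizer itself must also be modified so that it supports the newly added prompt tokens (this tutorial imports its modified version from the local `tokenization_whisper.py`).

After making these changes, load a sample from the data and test the tokenizer to check that the language and task tokens are encoded correctly. If your output looks like the sample below, the tokenizer has been modified correctly.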

In [5]:
input_str = common_voice["train"][0]["chinese"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)


print(f"Input: {input_str}")
print(f"Decoded with special: {decoded_with_special}")
print(f"Decoded without special: {decoded_str}")
print(f"Are equal: {input_str == decoded_str}")

ids = [50258, 51866, 51865, 50363, 913, 250, 25395, 8225, 15868, 98, 9254, 15368, 171, 120, 234, 913, 251, 17015, 36269, 7093, 1543, 50257]
ids = [50258, 51866, 51865, 50363, 913, 250, 25395, 8225, 15868, 98, 9254, 15368, 171, 120, 234, 913, 251, 17015, 36269, 7093, 1543, 50257]
Input: “我不能滥用这个，”女孩想。
Decoded with special: <|startoftranscript|><|ug|><|translate-chinese|><|notimestamps|>“我不能滥用这个，”女孩想。<|endoftext|>
Decoded without special: <|ug|>“我不能滥用这个，”女孩想。
Are equal: False


In [None]:
# Data processing: resample audio to 16 kHz, extract log-Mel features, tokenize targets
pprint(common_voice["train"][1])
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

pprint(common_voice["train"][1])


def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["chinese"]).input_ids   # Note: change "chinese" to the target-text column of your own dataset
    return batch


common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=21)

common_voice["train"]
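The shapes produced by `prepare_dataset` follow from Whisper's feature-extractor settings. A small worked example, assuming the defaults used by whisper-small (16 kHz sampling rate, 30-second window, hop length of 160 samples, 80 Mel bins) and the 48 kHz source audio mentioned in the code comment above:

```python
# Assumed whisper-small feature-extractor defaults
SAMPLING_RATE = 16_000   # Hz
CHUNK_LENGTH_S = 30      # every clip is padded or truncated to 30 s
HOP_LENGTH = 160         # samples between consecutive frames
N_MELS = 80              # Mel bins per frame


def resampled_length(num_samples: int, orig_sr: int, target_sr: int = SAMPLING_RATE) -> int:
    """Approximate sample count after resampling, e.g. 48 kHz Common Voice audio to 16 kHz."""
    return round(num_samples * target_sr / orig_sr)


def whisper_input_shape() -> tuple:
    """Shape of one log-Mel input_features example."""
    n_samples = SAMPLING_RATE * CHUNK_LENGTH_S  # 480,000 samples
    n_frames = n_samples // HOP_LENGTH          # 3,000 frames
    return (N_MELS, n_frames)


print(resampled_length(48_000 * 5, orig_sr=48_000))  # a 5 s clip -> 80000 samples
print(whisper_input_shape())                          # (80, 3000)
```

Because of the fixed 30-second window, every `input_features` entry has the same shape regardless of clip length, while the label sequences still vary in length.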

In [None]:
# Data collator: pads audio features and label sequences separately
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if <|startoftranscript|> was added during tokenization, cut it here:
        # the model prepends it (decoder_start_token_id) when building decoder inputs.
        # Whisper's bos_token_id is <|endoftext|>, so checking bos_token_id would not match.
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=tokenizer.convert_tokens_to_ids("<|startoftranscript|>"),
)

# metric = evaluate.load("wer")

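model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME_OR_PATH, load_in_8bit=True, device_map="auto")
# the added tags received ids beyond the original vocabulary (e.g. 51865, 51866),
# so enlarge the token embeddings to match the extended tokenizer
model.resize_token_embeddings(len(tokenizer))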
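The label handling in the data collator can be illustrated without any libraries. The sketch below is a hypothetical helper (`pad_labels`) operating on plain lists rather than tensors: it right-pads label sequences with `-100`, the index ignored by PyTorch's cross-entropy loss, and drops a leading `<|startoftranscript|>` (id 50258, the token the model prepends to decoder inputs) when every sequence starts with it.

```python
from typing import List, Optional

IGNORE_INDEX = -100  # label positions with this value are skipped by the loss


def pad_labels(label_seqs: List[List[int]], start_id: Optional[int] = None) -> List[List[int]]:
    """Right-pad label id sequences with IGNORE_INDEX and optionally strip a shared start token."""
    max_len = max(len(seq) for seq in label_seqs)
    # mask by position (like attention_mask), not by token id
    padded = [seq + [IGNORE_INDEX] * (max_len - len(seq)) for seq in label_seqs]
    # drop the start token only if every sequence has it
    if start_id is not None and all(seq and seq[0] == start_id for seq in label_seqs):
        padded = [seq[1:] for seq in padded]
    return padded


print(pad_labels([[50258, 913, 250, 50257], [50258, 1543, 50257]], start_id=50258))
# [[913, 250, 50257], [1543, 50257, -100]]
```

Masking by position rather than by token id matters because Whisper's pad token is `<|endoftext|>` (id 50257): masking every 50257 would also hide the genuine end-of-text target the model must learn to emit.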

In [None]:
# import evaluate
# metric = evaluate.load("wer")


# def compute_metrics(pred):
#     pred_ids = pred.predictions
#     label_ids = pred.label_ids

#     # replace -100 with the pad_token_id
#     label_ids[label_ids == -100] = tokenizer.pad_token_id

#     # we do not want to group tokens when computing the metrics
#     pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
#     label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
#     wer = 100 * metric.compute(predictions=pred_str, references=label_str)

#     return {"wer": wer}
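The commented-out metric above uses WER, which splits text on whitespace. Chinese text has no spaces between words, so a whole sentence would count as a single word. Character error rate (CER) is a common substitute for Chinese output. A minimal standard-library sketch follows; the function names are illustrative, and `evaluate.load("cer")` provides a library version whose whitespace handling may differ.

```python
from typing import List


def levenshtein(a: str, b: str) -> int:
    """Character-level edit distance (insertions, deletions, substitutions)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                 # delete ca
                            curr[j - 1] + 1,             # insert cb
                            prev[j - 1] + (ca != cb)))   # substitute (free if equal)
        prev = curr
    return prev[-1]


def char_error_rate(predictions: List[str], references: List[str]) -> float:
    """Corpus CER: total edits divided by total reference characters."""
    edits = sum(levenshtein(p, r) for p, r in zip(predictions, references))
    chars = sum(len(r) for r in references)
    if chars == 0:
        raise ValueError("references contain no characters")
    return edits / chars


print(char_error_rate(["我不能滥用"], ["我不能乱用"]))  # one substitution in 5 chars -> 0.2
```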

In [None]:
# model = prepare_model_for_int8_training(model, output_embedding_layer_name="proj_out")
model = prepare_model_for_int8_training(model)

def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")

model = get_peft_model(model, config)
model.print_trainable_parameters()


training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-3,
    warmup_steps=50,
    num_train_epochs=500,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=6,
    fp16=True,
    per_device_eval_batch_size=8,
    generation_max_length=128,
    logging_steps=100,
    # max_steps=6000,  # only for testing purposes, remove this from your final run
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    label_names=["labels"],  # same reason as above
)

# This callback helps to save only the adapter weights and remove the base model weights.
class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control
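As a rough check on `model.print_trainable_parameters()`, LoRA adds r·(d_in + d_out) parameters to each targeted linear layer (matrix A of shape r×d_in plus matrix B of shape d_out×r). Assuming whisper-small's shapes (d_model = 768, 12 encoder and 12 decoder layers) and that `q_proj` and `v_proj` appear in every encoder self-attention block and in every decoder self- and cross-attention block, r = 32 gives the count below. If those assumptions hold, the printed trainable-parameter figure should be close to it.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one Linear(d_in, d_out): A (r x d_in) + B (d_out x r)."""
    return r * (d_in + d_out)


# Assumed whisper-small shapes
D_MODEL, ENC_LAYERS, DEC_LAYERS = 768, 12, 12

# attention blocks with q_proj/v_proj: encoder self-attn, decoder self-attn and cross-attn
attention_blocks = ENC_LAYERS + 2 * DEC_LAYERS   # 36
target_linears = 2 * attention_blocks            # q_proj and v_proj in each -> 72
trainable = target_linears * lora_param_count(D_MODEL, D_MODEL, r=32)
print(trainable)  # 3538944
```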

In [None]:
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["validation"],
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
    callbacks=[SavePeftModelCallback],
)

model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
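With the arguments above, the length of the run can be worked out before calling `trainer.train()`. A small sketch, assuming a single device and the 55 training examples loaded earlier (the helper name is illustrative):

```python
import math


def total_optimizer_steps(num_examples: int, per_device_batch: int, grad_accum: int,
                          num_epochs: int, num_devices: int = 1) -> int:
    """Optimizer steps for a run: batches per epoch, divided by accumulation, times epochs."""
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return steps_per_epoch * num_epochs


steps = total_optimizer_steps(num_examples=55, per_device_batch=16, grad_accum=1, num_epochs=500)
print(steps)  # 4 steps per epoch * 500 epochs = 2000
```

Under these assumptions the 50 warmup steps cover the first 12.5 epochs, `logging_steps=100` yields 20 log entries, and the per-epoch evaluation and save strategies trigger 500 evaluations and checkpoints, of which `save_total_limit=6` keeps the most recent six.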

In [None]:
print(torch.cuda.device_count())
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # synchronous CUDA kernel launches for debugging; only takes effect if set before CUDA is initialized
trainer.train()

In [None]:
trainer.save_model("trained_model/whisper-small-en2chinese")