### 训练whisper模型并完成测试

#### 重新加载一个模型，训练一个新的peft模型

In [None]:
from huggingface_hub import notebook_login
from datasets import load_dataset, DatasetDict, Dataset
from pprint import pprint
import os
from datasets import Audio
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
from peft import prepare_model_for_int8_training,LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model,PeftModel, PeftConfig
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl,WhisperForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
import gc
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from tokenizer import Tokenizer
import argparse
import yaml
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
from __init__ import load_model
from audio import (
    FRAMES_PER_SECOND,
    HOP_LENGTH,
    N_FRAMES,
    N_SAMPLES,
    SAMPLE_RATE,
    log_mel_spectrogram,
    pad_or_trim,
)
import torchaudio
from decoding import DecodingOptions, DecodingResult
from timing import add_word_timestamps
from tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from utils import (
    exact_div,
    format_timestamp,
    get_writer,
    make_safe,
    optional_float,
    optional_int,
    str2bool,
)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = 'large'
# Expand "~" here so the checkpoint cache resolves to the user's home directory
# even if the loader uses the path verbatim (which would otherwise create a
# literal "~" directory under the current working directory).
model_dir = os.path.expanduser('~/.cache/whisper')
use_peft = True
# Directories for reading / writing the PEFT adapter weights.
input_peft_dir = './model/'
output_peft_dir = './model/'

# Root folder of the audio dataset.
dataset_dir =  '/home/towardspring/hdd2/dataset/asr/whisper_en'
# NOTE(review): upstream Whisper only names the tasks "transcribe" and
# "translate" — confirm the local `tokenizer` module accepts this value.
task = 'translate-chinese'
language = "english"
language_abbr = "en"

In [None]:
# Load the data

# Stream the audio-folder dataset from the configured `dataset_dir` instead of
# repeating the path, so the configuration cell stays the single source of truth.
dataset = load_dataset("audiofolder", data_dir=dataset_dir, streaming=True)
train_data = dataset["train"]
test_data = dataset["test"]



In [None]:
# Load the model and the matching tokenizer
model = load_model(
    model_name,
    device=device,
    download_root=model_dir,
)
tokenizer = get_tokenizer(
    model.is_multilingual,
    language=language,
    task=task,
)

# Print every sub-module name together with its class.
pprint([(module_name, type(module)) for module_name, module in model.named_modules()])

# NOTE(review): the original comment described this step as "quantize the
# model"; whether the PEFT helper quantizes a model that was not loaded in
# 8-bit should be confirmed against the PEFT documentation.
model = prepare_model_for_int8_training(model)

def make_inputs_require_grad(module, input, output):
    """Forward hook that switches on gradient tracking for a module's output.

    The signature follows the ``register_forward_hook`` callback convention;
    ``module`` and ``input`` are accepted but not used.

    Args:
        module: The module the hook is attached to (unused).
        input: The positional inputs of the forward call (unused).
        output: The tensor produced by the forward call; modified in place.
    """
    # `requires_grad_()` defaults to True, so this marks the output as trainable.
    output.requires_grad_()

# Make the output of the first encoder convolution require gradients so that
# back-propagation can reach the adapter layers behind it.
model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

# Configure LoRA.
# `target_modules` is given explicitly: PEFT can only infer default targets for
# model types it knows from Transformers and raises for other models.
# NOTE(review): "query"/"value" assume the attention projections are named as
# in the OpenAI Whisper implementation — confirm against the module listing
# printed when the model is loaded.
config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["query", "value"],
    lora_dropout=0.05,
    bias="none",
)

model = get_peft_model(model, config)   # wrap the base model with the LoRA adapters
model.print_trainable_parameters()   # report how many parameters will be trained

In [None]:
# Data preprocessing: collate raw samples into padded model inputs
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate raw audio/transcription samples into one padded training batch.

    Each sample is expected to be a mapping with an ``"audio"`` entry (itself a
    mapping holding ``"array"`` and ``"sampling_rate"``) and a text entry named
    by ``text_column``.

    All fields have defaults, so the collator can still be constructed without
    arguments.

    NOTE(review): the implementation assumes the helpers imported from the
    local ``audio`` module (``pad_or_trim``, ``log_mel_spectrogram``) and the
    tokenizer follow the OpenAI Whisper API (``encode``, ``eot``,
    ``sot_sequence_including_notimestamps``) — confirm against the local
    modules.
    """

    # Tokenizer used to encode the transcriptions; when None, the module-level
    # `tokenizer` is looked up at call time.
    tokenizer: Any = None
    # Number of mel bins passed to `log_mel_spectrogram`.
    # NOTE(review): must match the loaded checkpoint — confirm for `model_name`.
    n_mels: int = 80
    # Sampling rate (Hz) the waveforms are converted to before feature extraction.
    target_sampling_rate: int = 16000
    # Name of the sample field that holds the reference text.
    text_column: str = "transcription"
    # Value used to pad label rows, so padded positions can be ignored by the loss.
    label_pad_token_id: int = -100

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Build the ``input_features`` and ``labels`` tensors for ``batch``.

        Args:
            batch: List of raw dataset samples (see the class docstring).

        Returns:
            A dict with ``"input_features"`` (stacked log-Mel spectrograms) and
            ``"labels"`` (token ids padded with ``label_pad_token_id``).

        Raises:
            ValueError: If ``batch`` is empty.
        """
        if not batch:
            raise ValueError("Cannot collate an empty batch")

        text_tokenizer = self.tokenizer if self.tokenizer is not None else tokenizer

        # Audio inputs and labels have different lengths and need different
        # padding, so they are processed separately.
        input_features = torch.stack(
            [self._extract_features(sample["audio"]) for sample in batch]
        )

        label_rows = [
            torch.tensor(
                self._encode_labels(text_tokenizer, sample[self.text_column]),
                dtype=torch.long,
            )
            for sample in batch
        ]
        labels = torch.nn.utils.rnn.pad_sequence(
            label_rows,
            batch_first=True,
            padding_value=self.label_pad_token_id,
        )

        return {"input_features": input_features, "labels": labels}

    def _extract_features(self, audio: Dict[str, Any]) -> torch.Tensor:
        """Return the log-Mel spectrogram of one decoded audio entry."""
        waveform = torch.as_tensor(audio["array"], dtype=torch.float32)

        # Resample from the sample's own rate instead of assuming a fixed source rate.
        source_rate = int(audio["sampling_rate"])
        if source_rate != self.target_sampling_rate:
            waveform = torchaudio.functional.resample(
                waveform, source_rate, self.target_sampling_rate
            )

        # Pad/trim first so that every spectrogram in the batch has the same
        # shape and can be stacked.
        return log_mel_spectrogram(pad_or_trim(waveform), n_mels=self.n_mels)

    @staticmethod
    def _encode_labels(text_tokenizer: Any, text: str) -> List[int]:
        """Return the label ids for ``text``: prompt tokens, text tokens, end token.

        The leading start-of-transcript token is left out, mirroring the
        original code, which cut the first token from the labels.
        """
        prompt_ids = list(text_tokenizer.sot_sequence_including_notimestamps)
        return prompt_ids[1:] + list(text_tokenizer.encode(text)) + [text_tokenizer.eot]

# Collator instance that is handed to the trainer as `data_collator`.
data_collator = DataCollatorSpeechSeq2SeqWithPadding()

In [None]:
# Configure the trainer arguments
metric = evaluate.load("wer")   # evaluation metric (word error rate)
training_args = Seq2SeqTrainingArguments(
    output_dir="reach-vb/test",  # change to a repo name of your choice
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-3,
    warmup_steps=50,
    num_train_epochs=100,
    evaluation_strategy="steps",
    save_steps = 1000,
    save_total_limit = 1,
    # Mixed precision needs a CUDA device; follow the same availability check
    # that selects `device`, so the notebook also runs on CPU-only machines.
    fp16=torch.cuda.is_available(),
    per_device_eval_batch_size=8,
    generation_max_length=128,
    logging_steps=100,
    # NOTE(review): the datasets are loaded with streaming=True; a streamed
    # dataset has no length, so the Trainer presumably needs `max_steps` to
    # know when to stop — confirm before removing this.
    max_steps=6000,
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    label_names=["labels"],  # same reason as above
)

# Keeps checkpoints small: stores only the adapter weights and deletes the base model weights.
class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that stores the PEFT adapter inside each checkpoint folder."""

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save the adapter into the checkpoint folder and drop the full weights file.

        Args:
            args: Training arguments; ``output_dir`` is the checkpoint root.
            state: Trainer state; ``global_step`` names the checkpoint folder.
            control: Trainer control object, returned unchanged.
            **kwargs: Must contain ``"model"``, whose ``save_pretrained`` is called.

        Returns:
            The ``control`` object that was passed in.
        """
        step_dir_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        checkpoint_dir = os.path.join(args.output_dir, step_dir_name)

        # Write the adapter into its own sub-folder of the checkpoint.
        adapter_dir = os.path.join(checkpoint_dir, "adapter_model")
        trained_model = kwargs["model"]
        trained_model.save_pretrained(adapter_dir)

        # Delete the full weights file if one was written into this checkpoint.
        full_weights_file = os.path.join(checkpoint_dir, "pytorch_model.bin")
        if os.path.exists(full_weights_file):
            os.remove(full_weights_file)

        return control

In [None]:
# Configure the model trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_data,
    # Evaluate on the held-out split; evaluating on the training data would
    # hide over-fitting and leave `test_data` unused.
    eval_dataset=test_data,
    data_collator=data_collator,
    tokenizer=tokenizer,
    callbacks=[SavePeftModelCallback],
)

# Silence the use_cache warnings during training. Please re-enable for inference!
# The lookup is guarded because a model that does not come from Transformers
# may not expose a config object with a `use_cache` attribute.
model_config = getattr(model, "config", None)
if hasattr(model_config, "use_cache"):
    model_config.use_cache = False

In [None]:
# 训练
