### 训练 Whisper 模型并完成测试

#### 重新加载模型，并训练一个新的 PEFT 模型

In [2]:
from huggingface_hub import notebook_login
from datasets import load_dataset, DatasetDict, Dataset
from pprint import pprint
import os
from datasets import Audio
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
from peft import prepare_model_for_int8_training,LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model,PeftModel, PeftConfig
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl,WhisperForConditionalGeneration, Seq2SeqTrainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
import gc
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from tokenizer import Tokenizer
import argparse
import yaml
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
from __init__ import load_model
from audio import (
    FRAMES_PER_SECOND,
    HOP_LENGTH,
    N_FRAMES,
    N_SAMPLES,
    SAMPLE_RATE,
    log_mel_spectrogram,
    pad_or_trim,
)

from decoding import DecodingOptions, DecodingResult
from timing import add_word_timestamps
from tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from utils import (
    exact_div,
    format_timestamp,
    get_writer,
    make_safe,
    optional_float,
    optional_int,
    str2bool,
)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = 'large'
# Expand "~" explicitly: os.path.join / os.makedirs never expand it, so a
# download_root of '~/...' would otherwise be taken literally (a "~" folder
# relative to the CWD), assuming load_model uses the path as given.
model_dir = os.path.expanduser('~/.cache/whisper')
use_peft = True
# NOTE(review): input and output PEFT dirs are the same path, so saving the
# new adapter overwrites the one that was loaded — confirm this is intended.
input_peft_dir = './model/'
output_peft_dir = './model/'

dataset_dir = '/home/towardspring/hdd2/dataset/asr/whisper_en'
# NOTE(review): in upstream openai-whisper any task other than "transcribe"
# selects the <|translate|> token; confirm the local tokenizer treats this
# custom task name as intended.
task = 'translate-chinese'
language = "english"
language_abbr = "en"

In [5]:
# Load the data as a stream so the ~115k files are not materialised up front.
# Reuse dataset_dir from the config cell rather than repeating the path.
dataset = load_dataset("audiofolder", data_dir=dataset_dir, streaming=True)

# Whisper's front end works on SAMPLE_RATE (16 kHz) audio, while the raw clips
# print as 32 kHz. Resample lazily on the stream here: a collated batch is a
# plain list of rows and has no cast_column.
train_data = dataset["train"].cast_column("audio", Audio(sampling_rate=SAMPLE_RATE))
test_data = dataset["test"].cast_column("audio", Audio(sampling_rate=SAMPLE_RATE))



Resolving data files:   0%|          | 0/80511 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/34506 [00:00<?, ?it/s]

{'audio': {'path': '/home/towardspring/hdd2/dataset/asr/whisper_en/data/train/common_voice_en_34925857.mp3', 'array': array([0.00000000e+00, 6.67174345e-12, 8.06623171e-12, ...,
       6.04606248e-05, 1.68784260e-04, 9.82716447e-05]), 'sampling_rate': 32000}, 'transcription': 'It is consumed domestically and exported to other countries.', 'chinese': '它在国内消费并出口到其他国家。'}


In [None]:
# Load the Whisper checkpoint, then build the tokenizer that matches it
# (multilingual vs English-only vocabulary) for the configured language/task.
model = load_model(
    model_name,
    device=device,
    download_root=model_dir,
)
tokenizer = get_tokenizer(
    model.is_multilingual,
    task=task,
    language=language,
)

In [6]:
# Data preprocessing: collate raw dataset rows into a padded training batch.
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Turn a list of raw dataset rows into a padded Whisper training batch.

    Each row is expected to be a dict with an ``"audio"`` entry holding
    ``{"array": ..., "sampling_rate": ...}`` and a target-text entry named by
    ``text_column``.

    Audio must already be at ``sampling_rate`` Hz. A list of rows cannot be
    resampled with ``cast_column`` (that is a ``datasets`` method). Resample
    the dataset instead, e.g.
    ``ds.cast_column("audio", Audio(sampling_rate=16000))``.

    Attributes:
        tokenizer: whisper ``Tokenizer`` used to encode targets. When None,
            the module-level ``tokenizer`` is looked up at call time, as the
            original collator did.
        text_column: row key holding the target text.
        n_mels: number of mel bins. NOTE(review): must match the checkpoint,
            presumably ``model.dims.n_mels`` — confirm for the model in use.
        sampling_rate: sampling rate (Hz) the audio is required to have.
    """

    tokenizer: Any = None
    text_column: str = "transcription"
    n_mels: int = 80
    sampling_rate: int = SAMPLE_RATE

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate ``batch`` into ``input_features`` and ``labels`` tensors.

        Returns:
            A dict with:
            - ``input_features``: stacked float32 log-mel spectrograms, one
              per row, each computed from audio padded/trimmed to
              ``N_SAMPLES``.
            - ``labels``: int64 tensor ``(batch, max_len)``, right-padded
              with ``-100`` so padding is ignored by the loss.

        Raises:
            ValueError: if ``batch`` is empty or a clip has the wrong
                sampling rate.
        """
        if not batch:
            raise ValueError("cannot collate an empty batch")
        tok = self.tokenizer if self.tokenizer is not None else tokenizer

        features = [self._log_mel(row["audio"]) for row in batch]
        label_seqs = [self._label_ids(tok, row[self.text_column]) for row in batch]

        return {
            "input_features": torch.stack(features),
            "labels": self._pad_labels(label_seqs),
        }

    def _log_mel(self, audio: Dict[str, Any]) -> torch.Tensor:
        """Return the log-mel spectrogram of one clip padded/trimmed to 30 s."""
        if audio["sampling_rate"] != self.sampling_rate:
            raise ValueError(
                f"expected {self.sampling_rate} Hz audio, got "
                f"{audio['sampling_rate']} Hz; resample the dataset first, e.g. "
                f"ds.cast_column('audio', Audio(sampling_rate={self.sampling_rate}))"
            )
        # Dataset arrays are float64; convert to float32 before the STFT / mel
        # projection so dtypes match.
        samples = torch.as_tensor(np.asarray(audio["array"], dtype=np.float32))
        return log_mel_spectrogram(pad_or_trim(samples), n_mels=self.n_mels)

    @staticmethod
    def _label_ids(tok: Any, text: str) -> List[int]:
        """Encode ``text`` as SOT-sequence + text tokens + EOT, minus the SOT."""
        ids = list(tok.sot_sequence_including_notimestamps) + tok.encode(text) + [tok.eot]
        # Cut the leading <|startoftranscript|>, as the original collator
        # intended to cut the bos token. The training step is assumed to
        # prepend it to the decoder inputs itself — confirm against the loop.
        return ids[1:]

    @staticmethod
    def _pad_labels(seqs: List[List[int]]) -> torch.Tensor:
        """Right-pad label sequences with -100 into a single int64 tensor."""
        max_len = max(len(seq) for seq in seqs)
        labels = torch.full((len(seqs), max_len), -100, dtype=torch.long)
        for i, seq in enumerate(seqs):
            labels[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
        return labels

data_collator = DataCollatorSpeechSeq2SeqWithPadding()

{'audio': {'path': '/home/towardspring/hdd2/dataset/asr/whisper_en/data/train/common_voice_en_34925857.mp3', 'array': array([0.00000000e+00, 6.67174345e-12, 8.06623171e-12, ...,
       6.04606248e-05, 1.68784260e-04, 9.82716447e-05]), 'sampling_rate': 32000}, 'transcription': 'It is consumed domestically and exported to other countries.', 'chinese': '它在国内消费并出口到其他国家。'}
{'audio': {'path': '/home/towardspring/hdd2/dataset/asr/whisper_en/data/train/common_voice_en_34925857.mp3', 'array': array([ 2.91038305e-11,  0.00000000e+00,  4.72937245e-11, ...,
       -1.29326945e-05,  2.53636681e-05,  1.27412495e-04]), 'sampling_rate': 16000}, 'transcription': 'It is consumed domestically and exported to other countries.', 'chinese': '它在国内消费并出口到其他国家。'}
