In [1]:
import os
import torch
import torchaudio
from datasets import Dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import evaluate
import re

# 1. 加载数据
def load_audio_files(audio_dir):
    audio_files = []
    for filename in os.listdir(audio_dir):
        if filename.endswith(".wav"):
            audio_path = os.path.join(audio_dir, filename)
            audio_files.append(audio_path)
    audio_files.sort(key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
    return audio_files

def load_transcripts(transcript_dir):
    transcripts = []
    txt_files = [f for f in os.listdir(transcript_dir) if f.endswith('.txt')]
    txt_files.sort(key=lambda x: int(os.path.splitext(x)[0]))
    for file in txt_files:
        with open(os.path.join(transcript_dir, file), 'r', encoding='utf-8', errors='ignore') as f:
            text = f.read().strip()
            transcripts.append(text)
    return transcripts

def resample_audio(waveform, original_rate, target_rate):
    resampler = torchaudio.transforms.Resample(orig_freq=original_rate, new_freq=target_rate)
    return resampler(waveform)

def create_dataset(audio_files, transcripts):
    data = {"audio": [], "sentence": []}
    for audio_file, transcript in zip(audio_files, transcripts):
        waveform, sample_rate = torchaudio.load(audio_file)
        if sample_rate != 16000:
            waveform = resample_audio(waveform, sample_rate, 16000)
        audio = {"array": waveform.squeeze().numpy(), "sampling_rate": 16000}
        data["audio"].append(audio)
        data["sentence"].append(transcript)
    return Dataset.from_dict(data)

audio_dir = "./training_data"
transcript_dir = "./reference"

audio_files = load_audio_files(audio_dir)
transcripts = load_transcripts(transcript_dir)
dataset = create_dataset(audio_files, transcripts)

# 2. 划分数据集
total_size = len(dataset)
eval_size = int(0.1 * total_size)

eval_dataset = dataset.select(range(eval_size))

# 3. 加载预训练模型和处理器
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="yue", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model.eval()

# 4. 对评估数据集进行转录
def transcribe_audio(batch):
    input_features = processor.feature_extractor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
    input_features = input_features.to(model.device)
    with torch.no_grad():
        predicted_ids = model.generate(input_features)
    transcription = processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription

eval_dataset = eval_dataset.map(lambda batch: {"prediction": transcribe_audio(batch)}, remove_columns=["audio"])

# 去掉标点符号和空格
def clean_text(text):
    text = re.sub(r'[^\w\s]', '', text)  # 去掉标点符号
    text = re.sub(r'\s+', '', text)  # 去掉多余的空格
    return text

def split_chars(text):
    return " ".join(list(text))

def tokenize_text(text):
    cleaned_text = clean_text(text)
    tokenized_text = split_chars(cleaned_text)
    return tokenized_text

# 5. 计算WER
predictions = eval_dataset["prediction"]
references = eval_dataset["sentence"]

# 对预测和参考句子进行分词和清理
predictions_tokenized = [tokenize_text(pred) for pred in predictions]
references_tokenized = [tokenize_text(ref) for ref in references]

# 打印分词结果以供检查
for pred_tok, ref_tok in zip(predictions_tokenized, references_tokenized):
    print(f"Tokenized Prediction: {pred_tok}")
    print(f"Tokenized Reference: {ref_tok}")
    print("------")

# 使用 evaluate 计算并输出WER
metric = evaluate.load("wer")
wer_result = metric.compute(predictions=predictions_tokenized, references=references_tokenized)
print(f"WER: {wer_result * 100:.2f}%")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Map:   0%|          | 0/18 [00:00<?, ? examples/s]

Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Tokenized Prediction: 油 烧 要 从 薄 到 好 再 从 厚 度 薄
Tokenized Reference: 讀 書 要 從 薄 到 厚 再 從 厚 到 薄
------
Tokenized Prediction: 走 為 為 讀 書 走 海 奔 住 行 義 回 讀 有 家 子 一 個 書
Tokenized Reference: 所 謂 會 讀 書 就 係 本 住 誠 意 去 讀 有 價 值 嘅 書
------
Tokenized Prediction: 侯 修 喜 人 更 懂 得 行 手 人 生
Tokenized Reference: 好 書 使 人 更 懂 得 享 受 人 生
------
Tokenized Prediction: 打 開 你 們 有 幸 好 笑 我 們 周 侯 已 用 一 多 雲 南 坤 秀 聆 聽 中 外 不 同 舉 作
Tokenized Reference: 打 開 呢 本 有 聲 好 書 我 哋 就 可 以 用 耳 朵 閱 覽 群 書 聆 聽 中 外 不 同 著 作
------
Tokenized Prediction: 香 港 點 台 在 座
Tokenized Reference: 香 港 電 臺 製 作
------
Tokenized Prediction: 有 聲 好 笑
Tokenized Reference: 有 聲 好 書
------
Tokenized Prediction: 所 以 這 疼 人 的 仇 情 留 後 故 事
Tokenized Reference: 最 折 騰 人 的 籌 錢 留 學 故 事
------
Tokenized Prediction: 為 了 掃 請 六 號 以 利 卡 家 人 其 中 最 令 人 燈 無 紮 絲 的 女 子 是 這 時 好 家 李 東 方
Tokenized Reference: 為 咗 索 錢 留 學 而 累 及 家 人 其 中 最 令 人 燈 目 咋 舌 嘅 例 子 係 歷 史 學 家 黎 東 方
------
Tokenized Prediction: 可 能 鴨 家 人 連 龐 姬 曾 佑 都 輪 到 一 團
Tokenized Reference: 佢 令 一 家 人 連 旁 支 親 友 都 亂 作 一 團
------
To

Using the latest cached version of the module from C:\Users\lenovo\.cache\huggingface\modules\evaluate_modules\metrics\evaluate-metric--wer\85bee9e4216a78bb09b2d0d500f6af5c23da58f9210e661add540f5df6630fcd (last modified on Fri Jul 12 21:27:14 2024) since it couldn't be found locally at evaluate-metric--wer, or remotely on the Hugging Face Hub.


WER: 60.92%


In [7]:
import os
import torchaudio
from datasets import load_dataset, Dataset, DatasetDict
from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments, GenerationConfig
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch
import re
import evaluate

# Load and preprocess the data
def load_audio_files(audio_dir):
    audio_files = []
    for filename in os.listdir(audio_dir):
        if filename.endswith(".wav"):
            audio_path = os.path.join(audio_dir, filename)
            audio_files.append(audio_path)
    # Sort files by filename
    audio_files.sort(key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
    return audio_files

def load_transcripts(transcript_dir):
    transcripts = []
    txt_files = [f for f in os.listdir(transcript_dir) if f.endswith('.txt')]
    txt_files.sort(key=lambda x: int(os.path.splitext(x)[0]))  # Sort files by filename
    for file in txt_files:
        with open(os.path.join(transcript_dir, file), 'r', encoding='utf-8', errors='ignore') as f:
            text = f.read().strip()
            transcripts.append(text)
    return transcripts

def create_dataset(audio_files, transcripts):
    data = {"audio": [], "sentence": []}
    for audio_file, transcript in zip(audio_files, transcripts):
        audio = {"path": audio_file, "array": torchaudio.load(audio_file)[0].numpy(), "sampling_rate": 16000}
        data["audio"].append(audio)
        data["sentence"].append(transcript)
    return Dataset.from_dict(data)

audio_dir = "./training_data"
transcript_dir = "./reference"

audio_files = load_audio_files(audio_dir)
transcripts = load_transcripts(transcript_dir)
dataset = create_dataset(audio_files, transcripts)

# Calculate the dataset indices
total_size = len(dataset)
eval_size = int(0.1 * total_size)
train_size = total_size - eval_size

# Split the dataset
eval_dataset = dataset.select(range(eval_size))
train_dataset = dataset.select(range(eval_size, total_size))

# Initialize the processor
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="yue", task="transcribe")

# Preprocess the dataset
def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids
    return batch

train_dataset = train_dataset.map(prepare_dataset, remove_columns=train_dataset.column_names)
eval_dataset = eval_dataset.map(prepare_dataset, remove_columns=eval_dataset.column_names)

# Define data collator
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

# Define a function to remove punctuation marks and spaces.
def clean_text(text):
    text = re.sub(r'[^\w\s]', '', text)  # remove punctuation marks
    text = re.sub(r'\s+', '', text)  # remove spaces
    return text

def split_chars(text):
    return " ".join(list(text))

def tokenize_text(text):
    cleaned_text = clean_text(text)
    tokenized_text = split_chars(cleaned_text)
    return tokenized_text

# Define compute metrics function
metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    
    pred_str_tokenized = [tokenize_text(pred) for pred in pred_str]
    label_str_tokenized = [tokenize_text(label) for label in label_str]

    wer = 100 * metric.compute(predictions=pred_str_tokenized, references=label_str_tokenized)
    return {"wer": wer}

# Load the pre-trained Whisper model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# Define generation configuration
generation_config = GenerationConfig.from_model_config(model.config)
generation_config.max_length = 448
generation_config.suppress_tokens = []
generation_config.begin_suppress_tokens = [220, 50257]
model.generation_config = generation_config

# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-finetuned",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=50,
    max_steps=1000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=2,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=100,
    eval_steps=100,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

# Initialize the Trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

# Start training
trainer.train()



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Map:   0%|          | 0/162 [00:00<?, ? examples/s]

Map:   0%|          | 0/18 [00:00<?, ? examples/s]

Using the latest cached version of the module from C:\Users\lenovo\.cache\huggingface\modules\evaluate_modules\metrics\evaluate-metric--wer\85bee9e4216a78bb09b2d0d500f6af5c23da58f9210e661add540f5df6630fcd (last modified on Fri Jul 12 21:27:14 2024) since it couldn't be found locally at evaluate-metric--wer, or remotely on the Hugging Face Hub.
max_steps is given, it will override any value given in num_train_epochs


Step,Training Loss,Validation Loss,Wer
100,0.5249,1.225501,47.509579
200,0.041,1.337362,48.659004
300,0.0079,1.37144,42.145594
400,0.0014,1.380215,42.145594
500,0.0009,1.34477,43.678161
600,0.0003,1.389543,42.911877
700,0.0003,1.397231,42.911877
800,0.0002,1.401869,42.911877
900,0.0002,1.405211,42.911877
1000,0.0002,1.406217,42.911877


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618

Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618

TrainOutput(global_step=1000, training_loss=0.2644736334555782, metrics={'train_runtime': 1204.4842, 'train_samples_per_second': 3.321, 'train_steps_per_second': 0.83, 'total_flos': 1.15434160128e+18, 'train_loss': 0.2644736334555782, 'epoch': 24.691358024691358})

In [10]:
# 生成预测结果
def generate_predictions(trainer, eval_dataset):
    predictions = trainer.predict(eval_dataset)
    pred_ids = predictions.predictions
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    return pred_str

# 获取评估数据集的实际标签
def get_eval_labels(eval_dataset):
    label_str = []
    for sample in eval_dataset:
        labels = sample["labels"]
        labels[labels == -100] = processor.tokenizer.pad_token_id
        label_str.append(processor.tokenizer.decode(labels, skip_special_tokens=True))
    return label_str

# 生成预测结果并获取实际标签
predictions = generate_predictions(trainer, eval_dataset)
eval_labels = get_eval_labels(eval_dataset)

# 展示前几个样本的预测结果与实际标签对比
for i in range(18):
    print(f"Sample {i+1}:")
    print(f"Prediction: {predictions[i]}")
    print(f"Reference: {eval_labels[i]}")
    print("-" * 50)


Sample 1:
Prediction: 讀書要從國到學再從學到博
Reference: 讀書要從薄到厚再從厚到薄
--------------------------------------------------
Sample 2:
Prediction: 讀為回讀書就係搬出成績去讀要家子嘅書
Reference: 所謂會讀書就係本住誠意去讀有價值嘅書
--------------------------------------------------
Sample 3:
Prediction: 學生使人更懂得喊愁人心
Reference: 好書使人更懂得享受人生
--------------------------------------------------
Sample 4:
Prediction: 但係呢本有先學校我哋就可以用二多元來管住我
Reference: 打開呢本有聲好書我哋就可以用耳朵閱覽群書聆聽中外不同著作
--------------------------------------------------
Sample 5:
Prediction: 香港電台接住
Reference: 香港電臺製作
--------------------------------------------------
Sample 6:
Prediction: 有生學生
Reference: 有聲好書
--------------------------------------------------
Sample 7:
Prediction: 最積擋人的潮情留學故事
Reference: 最折騰人的籌錢留學故事
--------------------------------------------------
Sample 8:
Prediction: 為咗授清留學而離開家人其中最令人擔心父母
Reference: 為咗索錢留學而累及家人其中最令人{燈}目咋舌嘅例子係歷史學家黎東方
--------------------------------------------------
Sample 9:
Prediction: 距離一家人連同學親友都練做一天
Reference: 佢令一家人連旁支親友都亂作一團
---------------------