In [1]:
%%capture
!pip install git+https://github.com/huggingface/datasets.git@3.5.0
!pip install git+https://github.com/huggingface/transformers.git
!pip install soundfile
!pip install jiwer==3.1.0
!pip install evaluate==0.3.0

In [2]:
from google.colab import drive
drive.mount('/content/gdrive/')

Mounted at /content/gdrive/


In [58]:
from datasets import load_dataset

kinya_ = load_dataset("mbazaNLP/kinyarwanda-tts-dataset")

Resolving data files:   0%|          | 0/3993 [00:00<?, ?it/s]

In [59]:
kinya_

DatasetDict({
    train: Dataset({
        features: ['audio'],
        num_rows: 3992
    })
})

In [60]:
import pandas as pd
import os

csv_path = "/content/gdrive/MyDrive/tts-dataset.csv"      # ↔ 你的 csv 路径
df = pd.read_csv(csv_path)

# 把文件名提取成 key，建立 dict 便于快速查找
# 例如 "clips/001.wav" → 只保留 "001.wav"（取决于你 audio.path 的格式）
def basename(p):
    return os.path.basename(p.strip())

text_lookup = {basename(row["file"]): row["text"] for _, row in df.iterrows()}


In [61]:
text_lookup

{'TTS_1_2': 'ntivuga ko nehemiya yakoreye umugabo wa esiteri umwami ahasuwerusi ahubwo yakoreye uwamusimbuye',
 'TTS_1_3': 'iyo nzu ni nto ariko ni nini bihagije kuri twe',
 'TTS_1_4': 'amaze kubona izi mbwa ngo yageze ikigali atekereza icyo yakora ngo nibwo yahise azizana mu rwanda',
 'TTS_1_5': "abana banjye batandatu umugabo wanjye n'abavandimwe banjye babiri bahise bicwa",
 'TTS_1_6': 'seyoboka yavuze ko adahakana ko bariya bantu bishwe ariko ko atari we wabishe',
 'TTS_1_7': 'kugira ngo iyi kipe ikomeze biyisaba kwihagararaho ntizishyurwe ibi bitego cyangwa ngo itsindwe ibindi',
 'TTS_1_8': "kuko habaga umurishyo w'ingoma wavugaga umenyesha yuko kanaka uwo ari igicibwa",
 'TTS_1_9': "ryari ryegeranye n'iryacu ntangira no kumukunda yatumaga numva nkunzwe cyane",
 'TTS_1_10': 'ariko hari indirimbo ijya imfasha duhuriyeho kandi twese tuzi',
 'TTS_1_11': 'nakubwiye kenshi ko ntakunda gushyirwa ku karubanda none unshyize kuri radiyo',
 'TTS_1_12': "uwo mukinnyi yabajijwe icyo yifuza ku

In [63]:
kinya_["train"][20]["audio"]["path"]

'/root/.cache/huggingface/hub/datasets--mbazaNLP--kinyarwanda-tts-dataset/snapshots/0f8b622361419262cb23bd36e31c8e182fbff375/audio/TTS 10_117.wav'

In [69]:
from datasets import Audio

ds = kinya_["train"]                       # 原 Dataset，有 audio 一列

# 给 audio 列显式声明类型（若你还没 cast）
ds = ds.cast_column("audio", Audio(sampling_rate=16000))

def add_text(example):
    fname = os.path.basename(example["audio"]["path"])
    fname_spaced = fname.replace(" ", "_", 1)[:-4]
    example["text"] = text_lookup.get(fname_spaced, "")   # 如遇缺失给空串
    return example

ds = ds.map(add_text, num_proc =4)          # ↑ 多进程加速，可视机器核数调整


Map (num_proc=4):   0%|          | 0/3992 [00:00<?, ? examples/s]

In [70]:
from datasets import DatasetDict

kinya_split = DatasetDict({"train": ds})
print(kinya_split)


DatasetDict({
    train: Dataset({
        features: ['audio', 'text'],
        num_rows: 3992
    })
})


In [71]:
kinya_split["train"][0]

{'audio': {'path': None,
  'array': array([-1.52587891e-04, -1.83105469e-04, -1.22070312e-04, ...,
         -9.15527344e-05, -9.15527344e-05, -9.15527344e-05]),
  'sampling_rate': 16000},
 'text': "iracyahanganye n'ikibazo cy'abantu bayikunda kurenza urugero ku buryo bishobora kugira ingaruka ku buzima bwabo"}

In [72]:
from datasets import load_dataset, DatasetDict

ds = kinya_split["train"].shuffle(seed=42)      # 打乱后再切分
small = ds.train_test_split(train_size=3600, test_size=392, seed=42)

train_small = small["train"]
test_small  = small["test"]

print(len(train_small), len(test_small))    # 3600 392


3600 392


In [73]:
kinya = DatasetDict({
    "train": train_small,
    "test":  test_small
})

In [74]:
kinya

DatasetDict({
    train: Dataset({
        features: ['audio', 'text'],
        num_rows: 3600
    })
    test: Dataset({
        features: ['audio', 'text'],
        num_rows: 392
    })
})

In [75]:
kinya["train"][25]

{'audio': {'path': None,
  'array': array([-0.00039673, -0.00067139, -0.00054932, ..., -0.00027466,
         -0.00024414, -0.00024414]),
  'sampling_rate': 16000},
 'text': "data azatangira ikiruhuko cy'izabukuru nyuma y'imyaka hafi mirongo itatu amaze akora"}

In [76]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

In [77]:
show_random_elements(kinya["train"].remove_columns(["audio"]), num_examples=20)

Unnamed: 0,text
0,abajijwe agaciro k'iyi ndirimbo yavuze ko yamuhenze kuko yayitanzeho asaga miliyoni enye
1,igihe cy'ubukonje bukabije nigeze kubamo ni icyo namaze i sani furansisiko
2,abafasha ba trump bavuga ko ibya uwo madamu nta shingiro bifite
3,kuburyo habonetse amahirwe yo gukina filime zo ku rwego mpuzamahanga nakwitabira amajonjora y'abakinnyi
4,uyu mugoroba ni bwo biteganyijwe ko umukobwa wa mbere asezererwa muri iri rushanwa
5,gushaka inkunga bifasha igihugu mu buryo bw'iterambere ubukungu ndetse igihugu cyikazamuka muruhando mpuzamahanga
6,nashatse abantu banshyirira amabuye ku gikoni kuko igisenge cyari cyagurutse cyagiye
7,iyi korari izaririmba indirimbo zayo ziryoheye amatwi mu mbyino zinogeye ijisho
8,gukoresha abagenzuzi b'imari mu turere aho gukoresha abakozi b'inzego zirebwa n'iyi gahunda binyuranyije n'itegeko
9,intare nayo irye ingwe maze na njye mbarya mwese


In [78]:
import re
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"]'

def remove_special_characters(batch):
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower()
    return batch

In [79]:
kinya = kinya.map(remove_special_characters)

Map:   0%|          | 0/3600 [00:00<?, ? examples/s]

Map:   0%|          | 0/392 [00:00<?, ? examples/s]

In [80]:
show_random_elements(kinya["train"].remove_columns(["audio"]))

Unnamed: 0,text
0,kubera iki tutatangirira ku bizana cyane uwo mwotsi nk'imodoka zirekura iyo myuka mibi
1,hanagarajwe kandi ibyo nyampinga azaba asabwa gukora mu gihe yamaze kwegukana ikamba
2,abaturage bibumbiye mu makoperative bahawe imbuto y'imyumbati n'ikigo cy'igihugu gishinzwe ubuhinzi
3,si ibanga nta n'ubwo biteye isoni kuburyo icyo gihe cyiza kigeze nabyihererana
4,agira ati ababyeyi bagomba kumenya ko kwigisha abana ubuzima bw'imyororokere ari inshingano zabo
5,kuva mu cyiciro cya mbere ku isonga biraha amahirwe kiyovu yo kukigumamo
6,aha bagaragarizwaga ibitabo by'imishinga ishyirwa mu bikorwa n'abafungiye muri gereza ya nyanza
7,kalisa azi ko mu rubanza hazavamo impozamarira zitubutse akaba ashaka kuzihezaho umugore wa nyakwigendera
8,haravugwa ikibazo cy'ibitabo byanditse mu nyandiko z'abatabona bikiri bike mu gihugu
9,alisiya ni umuririmbyikazi wo muri leta zunze ubumwe za amerika


In [81]:
def extract_all_chars(batch):
  all_text = " ".join(batch["text"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

In [82]:
vocabs = kinya.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=kinya.column_names["train"])

Map:   0%|          | 0/3600 [00:00<?, ? examples/s]

Map:   0%|          | 0/392 [00:00<?, ? examples/s]

In [83]:
vocabs

DatasetDict({
    train: Dataset({
        features: ['vocab', 'all_text'],
        num_rows: 1
    })
    test: Dataset({
        features: ['vocab', 'all_text'],
        num_rows: 1
    })
})

In [84]:
vocabs["train"]["vocab"]

[['c',
  'q',
  'i',
  'j',
  'e',
  'u',
  't',
  'z',
  'm',
  's',
  'o',
  'a',
  'b',
  'w',
  'f',
  'p',
  ' ',
  'h',
  'k',
  'v',
  "'",
  'l',
  'd',
  'g',
  'n',
  'y',
  'r']]

In [85]:
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))

In [86]:
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{'c': 0,
 'q': 1,
 'i': 2,
 'j': 3,
 'e': 4,
 'u': 5,
 't': 6,
 'm': 7,
 's': 8,
 'o': 9,
 'a': 10,
 'b': 11,
 'w': 12,
 'f': 13,
 'p': 14,
 ' ': 15,
 'h': 16,
 'k': 17,
 'v': 18,
 "'": 19,
 'n': 20,
 'l': 21,
 'd': 22,
 'g': 23,
 'z': 24,
 'y': 25,
 'r': 26}

In [87]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

In [88]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

29

In [89]:
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [90]:
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

In [91]:
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)

In [92]:
from transformers import Wav2Vec2Processor
processor= Wav2Vec2Processor (feature_extractor=feature_extractor,tokenizer=tokenizer)

In [93]:
processor.save_pretrained("/content/gdrive/MyDrive/wav2vec2-base-kinya-test")

[]

In [94]:
tokenizer.decoder

{0: 'c',
 1: 'q',
 2: 'i',
 3: 'j',
 4: 'e',
 5: 'u',
 6: 't',
 7: 'm',
 8: 's',
 9: 'o',
 10: 'a',
 11: 'b',
 12: 'w',
 13: 'f',
 14: 'p',
 16: 'h',
 17: 'k',
 18: 'v',
 19: "'",
 20: 'n',
 21: 'l',
 22: 'd',
 23: 'g',
 24: 'z',
 25: 'y',
 26: 'r',
 15: '|',
 27: '[UNK]',
 28: '[PAD]'}

In [95]:
kinya["train"][0]

{'audio': {'path': None,
  'array': array([-3.66210938e-04, -6.10351562e-04, -5.79833984e-04, ...,
          3.05175781e-05,  6.10351562e-05,  0.00000000e+00]),
  'sampling_rate': 16000},
 'text': "kayumba yavuze ko kuba abanyamakuru bakennye biterwa n'uko ubukungu bwifashe"}

In [96]:
import soundfile as sf

def speech_file_to_array_fn(batch):
    speech_array = batch["audio"]["array"]
    sampling_rate = batch["audio"]["sampling_rate"]
    batch["speech"] = speech_array
    batch["sampling_rate"] = sampling_rate
    batch["target_text"] = batch["text"]
    return batch

In [97]:
kinya = kinya.map(speech_file_to_array_fn, remove_columns=kinya.column_names["train"], num_proc=4)

Map (num_proc=4):   0%|          | 0/3600 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/392 [00:00<?, ? examples/s]

In [98]:
import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(kinya["train"]))

ipd.Audio(data=np.asarray(kinya["train"][rand_int]["speech"]), autoplay=True, rate=16000)

In [99]:
rand_int = random.randint(0, len(kinya["train"]))

print("Target text:", kinya["train"][rand_int]["target_text"])
print("Input array shape:", np.asarray(kinya["train"][rand_int]["speech"]).shape)
print("Sampling rate:", kinya["train"][rand_int]["sampling_rate"])

Target text: imfungwa za politiki zifatwa mu buryo butandukanye n'ubw'izindi zifatwamo
Input array shape: (72000,)
Sampling rate: 16000


In [100]:
def prepare_dataset(batch):
    # check that all files have the correct sampling rate
    assert (
        len(set(batch["sampling_rate"])) == 1
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values

    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch

In [101]:
kinya_prepared = kinya.map(prepare_dataset, remove_columns=kinya.column_names["train"], batch_size=8, num_proc=4, batched=True)

Map (num_proc=4):   0%|          | 0/3600 [00:00<?, ? examples/s]



Map (num_proc=4):   0%|          | 0/392 [00:00<?, ? examples/s]



In [102]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from transformers import Wav2Vec2Processor
@dataclass
class DataCollatorCTCWithPadding:

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [103]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [104]:
from evaluate import load

wer_metric = load("wer")

Downloading builder script: 0.00B [00:00, ?B/s]

In [105]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [106]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    gradient_checkpointing=True,
    vocab_size=len(processor.tokenizer),
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)


config.json: 0.00B [00:00, ?B/s]



pytorch_model.bin:   0%|          | 0.00/380M [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [107]:
model.config.update({"vocab_size":len(tokenizer)})

In [108]:
model.freeze_feature_extractor()



In [109]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  # output_dir="/content/gdrive/MyDrive/wav2vec2-base-wolof-test",
  output_dir="/content/gdrive/MyDrive/project7_1_kinya",
  group_by_length=True,
  per_device_train_batch_size=4,
  eval_strategy="steps",
  num_train_epochs=30,
  fp16=False,
  save_steps=500,
  eval_steps=500,
  logging_steps=500,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=2,
)

In [111]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=kinya_prepared["train"],
    eval_dataset=kinya_prepared["test"],
    tokenizer=processor.feature_extractor,
)

  trainer = Trainer(


In [112]:
trainer.train()



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mvickyliuqy[0m ([33mvickyliuqy-t-bingen[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss,Validation Loss,Wer
500,3.5221,1.65546,0.992499
1000,0.9343,0.55761,0.820189
1500,0.5826,0.394237,0.661595
2000,0.4642,0.355213,0.600729
2500,0.4041,0.326972,0.565581
3000,0.3471,0.290958,0.507072
3500,0.3222,0.276342,0.473853
4000,0.2809,0.277218,0.454565
4500,0.271,0.26017,0.44985
5000,0.2332,0.255855,0.422417




TrainOutput(global_step=27000, training_loss=0.21907731049149123, metrics={'train_runtime': 9714.6591, 'train_samples_per_second': 11.117, 'train_steps_per_second': 2.779, 'total_flos': 5.71011090290622e+18, 'train_loss': 0.21907731049149123, 'epoch': 30.0})

In [113]:
trainer.save_model("/content/gdrive/MyDrive/project7_1_kinya/wav2vec2-base-kinya-final")
processor.save_pretrained("/content/gdrive/MyDrive/project7_1_kinya/processor/wav2vec2-base-kinya-final")

[]

In [114]:
processor = Wav2Vec2Processor.from_pretrained("/content/gdrive/MyDrive/project7_1_kinya/processor/wav2vec2-base-kinya-final")

In [115]:
model = Wav2Vec2ForCTC.from_pretrained("/content/gdrive/MyDrive/project7_1_kinya/wav2vec2-base-kinya-final")

In [116]:
def map_to_result(batch):
  with torch.no_grad():
    input_values = torch.tensor(batch["input_values"]).unsqueeze(0)
    logits = model(input_values).logits

  pred_ids = torch.argmax(logits, dim=-1)
  batch["pred_str"] = processor.batch_decode(pred_ids)[0]
  batch["text"] = processor.decode(batch["labels"], group_tokens=False)

  return batch

results = kinya_prepared["test"].map(map_to_result, remove_columns=kinya_prepared["test"].column_names)




Map:   0%|          | 0/392 [00:00<?, ? examples/s]

In [117]:
print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))


Test WER: 0.241


In [118]:
show_random_elements(results)

Unnamed: 0,pred_str,text
0,agakomeza avuga ko ubwo yari atashye avuye kubwiriza yasanze wa mugabo yinjiye urugo rwe,agakomeza avuga ko ubwo yari atashye avuye kubwiriza yasanze wa mugabo yinjiye urugo rwe
1,ambasaderi yavuze ko yitabira iyo myigaragambyo mu rwego rwo kugaragaza ibirimo bibera muri icyo kigo,ambasaderi yavuze ko yitabiriye iyo myigaragambyo mu rwego rwo kugaragaza ibirimo bibera muri icyo gihugu
2,toma ahompe e rukira ku mwumva acuranga nabonye ari umuhanga mu gucuranga inanga,tomu aho mperukira kumwumva acuranga nabonye ari umuhanga mu gucuranga inanga
3,banakomoje kurindi tsinda ry'abanyarwanda batawe muri yombi mu minsi ishize,banakomoje ku rindi tsinda ry'abanyarwanda batawe muri yombi mu minsi ishize
4,iyo nzu yaguzwe niyo banki ntabwo yigeze igurwa n'undi muntu uwo ari we wese,iyo nzu yaguzwe n'iyo banki ntabwo yigeze igurwa n'undi muntu uwo ari we wese
5,amatora yatumye abadashaka manda ya gatatu ya perezida nkurunziza bigaragambya abandi barahunga,amatora yatumye abadashaka manda ya gatatu ya perezida nkurunziza bigaragambya abandi barahunga
6,bakaba bafite intego yo kujya ku masoko hirya no hino mu gihugu,bakaba bafite intego yo kujya ku masoko hirya no hino mu gihugu
7,bwa mbere i serukiramuco rya sinema rya kane mu bufaransa ryaratangijwe,bwa mbere iserukiramuco rya sinema rya kane mu bufaransa ryaratangijwe
8,izindi mpamvu zagaragajwe na njyanama harimo kudakorera hamwe kw'abayobozi no kuba buri wese atiha agahunda,izindi mpamvu zagaragajwe na njyanama harimo kudakorera hamwe kw'abayobozi no kuba buri wese atiha gahunda
9,polisi yagereranyije ububwicanyi n'iyicarubozo ryahitanye benshi kuva intambara yatangira mu buye apani,polisi yagereranyije ubu bwicanyi n'iyicarubozo ryahitanye benshi kuva intambara yatangira mu buyapani
