<a href="https://colab.research.google.com/github/YeralyK/TTS_Learning/blob/main/RussianTTS(SpeechT5_FineTuning).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets soundfile accelerate speechbrain==0.5.16

In [None]:
from huggingface_hub import notebook_login

notebook_login()

In [None]:
from datasets import load_dataset, Audio
# Dataset card: https://huggingface.co/datasets/0x7o/klara-voice
# Only the "train" split is loaded; the bare `dataset` expression shows its summary.
dataset = load_dataset("0x7o/klara-voice", split="train")
dataset

In [None]:
len(dataset)

In [None]:
print(dataset)

In [None]:
# NOTE: despite its name, `half_size` is one third of the dataset length
# (integer division by 3), not one half.
half_size = len(dataset) // 3

# Keep only the first `half_size` examples.
dataset = dataset.select(range(half_size))

print(dataset)

In [None]:
# Request the "audio" column at 16 kHz.
# NOTE(review): presumably `datasets` resamples on access — confirm against the library docs.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

In [None]:
from transformers import SpeechT5Processor

checkpoint = "microsoft/speecht5_tts"
processor = SpeechT5Processor.from_pretrained(checkpoint)

In [None]:
tokenizer = processor.tokenizer

In [None]:
dataset[2:5]

In [None]:
def extract_all_chars(batch):
    """Join every raw transcript in ``batch`` and collect its distinct characters.

    Returns a dict of single-element lists: ``"audio"`` holds the list of
    unique characters (despite the key's name) and ``"all_text"`` holds the
    space-joined text.
    """
    joined = " ".join(batch["text"])
    charset = set(joined)
    return {"audio": [list(charset)], "all_text": [joined]}


# Run the extractor once over the entire dataset (batch_size=-1) and drop
# all original columns so that only the extractor's outputs remain.
vocabs = dataset.map(
    extract_all_chars,
    remove_columns=dataset.column_names,
    keep_in_memory=True,
    batch_size=-1,
    batched=True,
)

# The unique characters are stored under the "audio" key by extract_all_chars.
dataset_vocab = set(vocabs["audio"][0])
# Iterating a dict yields its keys, i.e. the token strings.
tokenizer_vocab = set(tokenizer.get_vocab())

In [None]:
dataset_vocab - tokenizer_vocab

In [None]:
import re

def normalize_text(text):
    """Lowercase ``text``, strip punctuation and collapse whitespace.

    Every character that is not a word character, whitespace or an
    apostrophe is removed; runs of whitespace become a single space and
    leading/trailing whitespace is dropped.
    """
    lowered = text.lower()
    without_punct = re.sub(r'[^\w\s\']', '', lowered)
    words = without_punct.split()
    return ' '.join(words)

def add_normalized_text(example):
    """Store the normalized form of ``example['text']`` under ``'normalized_text'``.

    The example is modified in place and returned.
    """
    normalized = normalize_text(example['text'])
    example['normalized_text'] = normalized
    return example

# Add a "normalized_text" column to every example.
dataset = dataset.map(add_normalized_text)

# Print a few rows for a quick check.
print(dataset[2:5])

In [None]:
def extract_all_chars(batch):
    """Join every normalized transcript in ``batch`` and collect its distinct characters.

    Returns a dict of single-element lists: ``"audio"`` holds the list of
    unique characters (despite the key's name) and ``"all_text"`` holds the
    space-joined text.
    """
    joined = " ".join(batch["normalized_text"])
    charset = set(joined)
    return {"audio": [list(charset)], "all_text": [joined]}


# Run the extractor once over the entire dataset (batch_size=-1) and drop
# all original columns so that only the extractor's outputs remain.
vocabs = dataset.map(
    extract_all_chars,
    remove_columns=dataset.column_names,
    keep_in_memory=True,
    batch_size=-1,
    batched=True,
)

# The unique characters are stored under the "audio" key by extract_all_chars.
dataset_vocab = set(vocabs["audio"][0])
# Iterating a dict yields its keys, i.e. the token strings.
tokenizer_vocab = set(tokenizer.get_vocab())

In [None]:
dataset_vocab - tokenizer_vocab

In [None]:
# Cyrillic -> Latin transliteration table as (source, replacement) pairs.
# Only lowercase letters are listed, so text must be lowercased before the
# table is applied.
replacements = [
    ("а", "a"),
    ("б", "b"),
    ("в", "v"),
    ("г", "g"),
    ("д", "d"),
    ("е", "e"),
    ("ё", "yo"),
    ("ж", "zh"),
    ("з", "z"),
    ("и", "i"),
    ("й", "y"),
    ("к", "k"),
    ("л", "l"),
    ("м", "m"),
    ("н", "n"),
    ("о", "o"),
    ("п", "p"),
    ("р", "r"),
    ("с", "s"),
    ("т", "t"),
    ("у", "u"),
    ("ф", "f"),
    ("х", "kh"),
    ("ц", "ts"),
    ("ч", "ch"),
    ("ш", "sh"),
    ("щ", "shch"),
    ("ъ", ""),      # hard sign removed
    ("ы", "y"),
    ("ь", ""),      # soft sign removed
    ("э", "e"),
    ("ю", "yu"),
    ("я", "ya"),
]
def cleanup_text(inputs):
    """Transliterate ``inputs["normalized_text"]`` using the ``replacements`` table.

    The example is modified in place and returned.
    """
    # Every source in the table is a single character and no replacement
    # contains a source character, so one translate pass gives the same
    # result as applying the pairs one after another.
    table = str.maketrans(dict(replacements))
    inputs["normalized_text"] = inputs["normalized_text"].translate(table)
    return inputs

# Transliterate the "normalized_text" column of every example.
dataset = dataset.map(cleanup_text)

In [None]:
import os
import torch

import huggingface_hub

# Remember the library's own download function exactly once, so that
# re-running this cell never stores the wrapper itself as the "original".
if not hasattr(huggingface_hub, "_orig_hf_hub_download"):
    huggingface_hub._orig_hf_hub_download = huggingface_hub.hf_hub_download

def _hf_hub_download_compat(repo_id, filename, *args, **kwargs):
    """Wrapper around the saved ``hf_hub_download`` with two adjustments.

    * A ``use_auth_token`` keyword is renamed to ``token`` (or dropped when
      ``token`` is already given).
    * A request for ``custom.py`` from ``speechbrain/spkrec-xvect-voxceleb``
      is answered with the path of a locally created stub file instead of
      being forwarded.

    Everything else is forwarded to ``huggingface_hub._orig_hf_hub_download``.
    """
    had_legacy_token = "use_auth_token" in kwargs
    legacy_token = kwargs.pop("use_auth_token", None)
    if had_legacy_token and "token" not in kwargs:
        kwargs["token"] = legacy_token

    wants_stub = (
        repo_id == "speechbrain/spkrec-xvect-voxceleb" and filename == "custom.py"
    )
    if not wants_stub:
        return huggingface_hub._orig_hf_hub_download(repo_id, filename, *args, **kwargs)

    stub_dir = os.path.join("/tmp", "speechbrain_hf_stubs", "speechbrain_spkrec_xvect")
    os.makedirs(stub_dir, exist_ok=True)
    stub_path = os.path.join(stub_dir, "custom.py")
    # An existing stub is left as it is.
    if not os.path.exists(stub_path):
        with open(stub_path, "w", encoding="utf-8") as f:
            f.write("# Stub file auto-created to satisfy SpeechBrain downloader.\n")
    return stub_path

# Swap in the wrapper; this happens before the speechbrain import below.
huggingface_hub.hf_hub_download = _hf_hub_download_compat

from speechbrain.pretrained import EncoderClassifier

# Checkpoint used to compute speaker embeddings.
spk_model_name = "speechbrain/spkrec-xvect-voxceleb"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the encoder on `device`; its files are saved under /tmp/speechbrain_spkrec_xvect.
speaker_model = EncoderClassifier.from_hparams(
    source=spk_model_name,
    run_opts={"device": device},
    savedir=os.path.join("/tmp", "speechbrain_spkrec_xvect"),
)

@torch.no_grad()
def create_speaker_embedding(waveform):
    """Return a normalized speaker embedding for ``waveform`` as a NumPy array.

    ``waveform`` may be a tensor or anything ``torch.tensor`` accepts. A 1D
    input is treated as a single utterance; a 2D input is passed through
    as-is (expected layout: [batch, time]).

    Raises:
        ValueError: if the waveform is neither 1D nor 2D.
    """
    signal = waveform if torch.is_tensor(waveform) else torch.tensor(waveform)
    signal = signal.float()

    if signal.ndim not in (1, 2):
        raise ValueError(f"Expected 1D or 2D waveform, got shape {tuple(signal.shape)}")
    if signal.ndim == 1:
        # Add a leading dimension so the encoder receives [batch, time].
        signal = signal.unsqueeze(0)

    encoded = speaker_model.encode_batch(signal.to(device))  # usually [B, 1, D]
    unit_norm = torch.nn.functional.normalize(encoded, dim=2)
    # squeeze() drops every size-1 dimension before the move to CPU / NumPy.
    return unit_norm.squeeze().cpu().numpy()


In [None]:
def prepare_dataset(example):
    """Turn one dataset example into model-ready features.

    The transliterated text and the waveform go through ``processor``; the
    result gains a ``speaker_embeddings`` entry computed from the same
    waveform and is returned in place of the original example.
    """
    clip = example["audio"]
    waveform = clip["array"]

    features = processor(
        text=example["normalized_text"],
        audio_target=waveform,
        sampling_rate=clip["sampling_rate"],
        return_attention_mask=False,
    )

    # Keep only the first element of "labels"
    # (presumably this drops a leading batch dimension — confirm).
    features["labels"] = features["labels"][0]
    features["speaker_embeddings"] = create_speaker_embedding(waveform)
    return features

In [None]:
processed_example = prepare_dataset(dataset[0])
list(processed_example.keys())

In [None]:
processed_example["speaker_embeddings"].shape

In [None]:
# Featurize every example; all original columns are dropped, so only the
# outputs of prepare_dataset remain.
dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names)

In [None]:
def is_not_too_long(input_ids, max_length=200):
    """Return True when ``input_ids`` is strictly shorter than ``max_length``.

    Args:
        input_ids: Token id sequence of one example.
        max_length: Exclusive upper bound on the number of tokens. Defaults
            to 200, the limit used for filtering in this notebook.
    """
    return len(input_ids) < max_length

dataset = dataset.filter(is_not_too_long, input_columns=["input_ids"])
len(dataset)

In [None]:
# Hold out 10% of the examples as the "test" split.
# NOTE(review): no seed is passed, so the split presumably differs between runs — confirm.
dataset = dataset.train_test_split(test_size=0.1)

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class TTSDataCollatorWithPadding:
    """Collate featurized TTS examples into one padded batch of tensors.

    Each feature must provide ``input_ids``, ``labels`` and
    ``speaker_embeddings``. The reduction factor is read from the
    module-level ``model`` at call time, so ``model`` must be defined before
    the first batch is built.
    """

    # Processor whose ``pad`` method pads token ids and label frames together.
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Pad ``features`` and return the resulting batch.

        The returned batch holds whatever ``processor.pad`` produced (minus
        ``decoder_attention_mask``), with padded label positions set to -100
        and a ``speaker_embeddings`` tensor added.
        """
        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        label_features = [{"input_values": feature["labels"]} for feature in features]
        speaker_features = [feature["speaker_embeddings"] for feature in features]

        # Use the processor this collator was constructed with, not whatever
        # module-level ``processor`` happens to exist.
        batch = self.processor.pad(
            input_ids=input_ids, labels=label_features, return_tensors="pt"
        )

        # Overwrite padded label positions with -100
        # (presumably the value the loss ignores — confirm).
        batch["labels"] = batch["labels"].masked_fill(
            batch.decoder_attention_mask.unsqueeze(-1).ne(1), -100
        )

        # The mask was only needed for the fill above.
        del batch["decoder_attention_mask"]

        reduction_factor = model.config.reduction_factor
        if reduction_factor > 1:
            # Round each target length down to a multiple of the reduction
            # factor and trim the labels to the longest rounded length.
            max_length = max(
                length - length % reduction_factor
                for length in (
                    len(feature["input_values"]) for feature in label_features
                )
            )
            batch["labels"] = batch["labels"][:, :max_length]

        batch["speaker_embeddings"] = torch.tensor(speaker_features)

        return batch

In [None]:
data_collator = TTSDataCollatorWithPadding(processor=processor)

In [None]:
from transformers import SpeechT5ForTextToSpeech

model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)

In [None]:
from functools import partial

# Turn caching off in the model config.
# NOTE(review): presumably because it conflicts with gradient checkpointing during training — confirm.
model.config.use_cache = False

# Re-enable caching for generation only, by pre-binding use_cache=True on model.generate.
model.generate = partial(model.generate, use_cache=True)

In [None]:
from transformers import Seq2SeqTrainingArguments

# Fine-tuning configuration: a short 500-step run, evaluated and checkpointed every 100 steps.
training_args = Seq2SeqTrainingArguments(
    output_dir="speecht5_finetuned_russian_speech",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,  # 4 * 8 = 32 examples per optimizer step on each device
    learning_rate=1e-4,
    warmup_steps=100,
    max_steps=500,
    gradient_checkpointing=True,
    fp16=True,  # NOTE(review): presumably requires a CUDA device — confirm before running on CPU
    eval_strategy="steps",
    per_device_eval_batch_size=2,
    save_steps=100,  # same interval as eval_steps
    eval_steps=100,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    greater_is_better=False,  # a lower value of the selection metric counts as better
    label_names=["labels"],
    push_to_hub=True,
)

In [None]:
from transformers import Seq2SeqTrainer

# Wire the model, both dataset splits, the padding collator and the processor into the trainer.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    processing_class=processor,
)


In [None]:
trainer.train()

In [None]:
trainer.push_to_hub()

In [None]:
# Reload the fine-tuned model; the name matches `output_dir` in the training arguments.
model = SpeechT5ForTextToSpeech.from_pretrained(
    "speecht5_finetuned_russian_speech"
)

In [None]:
# Take the speaker embedding of one held-out example for use at synthesis time.
# NOTE(review): index 304 is hard-coded — assumes the test split has at least 305 examples; confirm.
example = dataset["test"][304]
# unsqueeze(0) adds a leading dimension of size 1.
speaker_embeddings = torch.tensor(example["speaker_embeddings"]).unsqueeze(0)

In [None]:
text = "Привет, меня зовут Ералы, я студент университета"

In [None]:
import re

# Russian words for the ten ASCII digits.
RU_DIGITS = {
    "0": "ноль", "1": "один", "2": "два", "3": "три", "4": "четыре",
    "5": "пять", "6": "шесть", "7": "семь", "8": "восемь", "9": "девять",
}

def replace_numbers_with_words(text: str) -> str:
    """Spell out every ASCII digit in ``text`` as a Russian word.

    Digits are read one by one (no numeral grammar): a run such as ``"12"``
    becomes ``"один два"``, with the words separated by single spaces so
    they do not fuse together. Only ``0-9`` are replaced; other Unicode
    digits, which a bare ``\\d`` would also match and which have no entry in
    ``RU_DIGITS``, are left untouched instead of raising ``KeyError``.
    """
    return re.sub(
        r"[0-9]+",
        lambda m: " ".join(RU_DIGITS[digit] for digit in m.group(0)),
        text,
    )


In [None]:
def cleanup_text(text):
    """Return ``text`` transliterated with the ``replacements`` table.

    Note that this string-based definition takes over the name of the
    earlier dict-based ``cleanup_text``.
    """
    # Every source in the table is a single character and no replacement
    # contains a source character, so one translate pass gives the same
    # result as applying the pairs one after another.
    return text.translate(str.maketrans(dict(replacements)))

In [None]:
# Follow the training-time order: normalize (lowercase, strip punctuation)
# first, then transliterate. The transliteration table only lists lowercase
# Cyrillic letters, so transliterating first would leave capitals such as
# the "П" in "Привет" as Cyrillic in the final text.
converted_text = replace_numbers_with_words(text)
cleaned_text = normalize_text(converted_text)
final_text = cleanup_text(cleaned_text)
final_text

In [None]:
inputs = processor(text=final_text, return_tensors="pt")

In [None]:
from transformers import SpeechT5HifiGan

vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)

In [None]:
from IPython.display import Audio
import soundfile as sf

# Save the waveform first, then leave the player as the cell's last
# expression: a notebook only renders the value of the final expression, so
# an Audio(...) on an earlier line is created but never shown.
sf.write('output.wav', speech.numpy(), 16000)
Audio(speech.numpy(), rate=16000)