In [1]:
!pip install fairseq -q
!pip install g2p_en -q

[33mDEPRECATION: omegaconf 2.0.6 has a non-standard dependency specifier PyYAML>=5.1.*. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of omegaconf or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m[33mDEPRECATION: omegaconf 2.0.6 has a non-standard dependency specifier PyYAML>=5.1.*. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of omegaconf or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m

In [2]:
import torch
import numpy as np

from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface


class TTSModel(torch.nn.Module):
    """Wrapper around the `facebook/fastspeech2-en-ljspeech` checkpoint that can
    synthesise speech with randomly perturbed phoneme / pause durations.

    Typical use:
        sample = tts.get_sample(text)                       # clean model input
        sample, new_text, changes = tts.pauses_and_phonemes(text)  # augmented input
        wav, sr = tts.get_wav(sample)
    """

    # Vocabulary index of the "," token (see the table in __init__'s docstring);
    # commas are what this notebook uses to produce pauses.
    PAUSE_TOKEN = 11

    def __init__(self, dur_factors, p, device):
        """
        Args:
            dur_factors: 1-D array of candidate duration multipliers.
            p: 1-D array of probabilities, one per entry of `dur_factors`
                (passed to `np.random.choice`).
            device: device the model inputs are moved to.

        phonemes2change desc:
        {
            "a": [4, 31, 18],
            "e": [29, 17],
            "u": [40, 34, 45],
            "i": [33, 15],
            "o": [36, 32],
            ",": [11]
        }
        """
        super().__init__()
        models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
            "facebook/fastspeech2-en-ljspeech",
            arg_overrides={"vocoder": "hifigan", "fp16": False}
        )
        self.model = models[0]
        self.task = task

        # Vowel token indices (a, e, u, i, o) eligible for duration changes.
        self.phonemes2change = [4, 31, 18, 29, 17, 40, 34, 45, 33, 15, 36, 32]
        self.dur_factor = dur_factors
        self.p = p
        self.device = device

        TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
        self.generator = self.task.build_generator(models, cfg)

        # This class calls `self.model.encoder(...)` directly, so nothing else
        # switches the network to inference mode. Without this, dropout stays
        # active and predicted durations / audio differ from call to call.
        self.model.eval()

    def get_sample(self, text):
        """Convert raw text into the fairseq model-input dict for this task."""
        return TTSHubInterface.get_model_input(self.task, text)

    def get_durations(self, sample):
        """Return the predicted per-token durations as a CPU float tensor.

        Side effect: every non-None tensor in `sample["net_input"]` is moved to
        `self.device` in place.

        NOTE(review): the value returned is exp(log_dur_out) unchanged. fairseq's
        own inference path is believed to use round(exp(log_dur) - 1), so these
        values may be offset from what the model uses when no durations are
        supplied — confirm against fairseq's VarianceAdaptor.
        """
        net_input = sample["net_input"]
        for key in net_input.keys():
            if net_input[key] is not None:
                net_input[key] = net_input[key].to(self.device)
        # encoder.forward args: src_tokens, src_lengths=None, speaker=None, durations=None, pitches=None, energies=None,
        with torch.no_grad():  # inference only; avoids building an autograd graph
            _, _, _, log_dur_out, _, _ = self.model.encoder(**net_input)
        return torch.exp(log_dur_out).cpu()

    def pauses_and_phonemes(self, text,
                            p_add_pause=0.5,
                            p_change_phoneme_dur=0.5,
                            p_change_pause_dur=0.5):
        """Build an augmented model input for `text`.

        Each augmentation is applied independently with its own probability:
        inserting extra commas, stretching three randomly chosen vowel tokens,
        and stretching the comma (pause) token.

        Returns:
            (sample, text, phoneme_change) where `sample["net_input"]["durations"]`
            holds the (possibly modified) integer durations, `text` is the text
            actually fed to the model, and `phoneme_change` maps token index ->
            list of factors applied to its occurrences.
        """
        phoneme_change = {}

        if np.random.uniform() <= p_add_pause:
            text = self._add_pauses(text)

        with torch.no_grad():
            sample = self.get_sample(text)
            durs = self.get_durations(sample)

        if np.random.uniform() <= p_change_phoneme_dur:
            phonemes_to_change = np.random.choice(self.phonemes2change, size=3, replace=False)
            for phoneme in phonemes_to_change:
                phoneme_change[int(phoneme)] = self._change_phoneme_dur(sample, durs, phoneme)

        if np.random.uniform() <= p_change_pause_dur:
            phoneme_change[self.PAUSE_TOKEN] = self._change_phoneme_dur(sample, durs, self.PAUSE_TOKEN)

        # .long() truncates the fractional part of the float durations.
        sample["net_input"]["durations"] = durs.long()

        return sample, text, phoneme_change

    def _add_pauses(self, text, p=0.3):
        """Append a comma to roughly a fraction `p` of the space-separated words.

        The last word never receives a comma. Doubled commas are collapsed.
        """
        chunks = text.split(' ')
        n = len(chunks)
        if n < 2:
            # Nothing to put a pause between; keep the ',,' normalisation only.
            return text.replace(',,', ',')

        # Only n - 1 positions are eligible, so never ask for more than that
        # (sampling without replacement would raise ValueError otherwise).
        n_pauses = max(0, min(int(p * n), n - 1))
        inds = np.random.choice(range(n - 1), size=n_pauses, replace=False)

        for ind in inds:
            chunks[ind] += ","

        new_text = " ".join(chunks)
        new_text = new_text.replace(',,', ',')
        return new_text

    def _change_phoneme_dur(self, sample, durs, phoneme_ind):
        """Scale, in place, the durations of every occurrence of `phoneme_ind`.

        A factor is drawn independently for each occurrence from
        `self.dur_factor` with probabilities `self.p`.

        Returns:
            The drawn factors as a nested list of shape [1, n_occurrences].
        """
        mask = (sample["net_input"]["src_tokens"] == phoneme_ind).cpu()
        factors = np.random.choice(self.dur_factor, size=mask.sum().item(), replace=True, p=self.p)
        # Explicit tensor conversion instead of relying on implicit
        # torch/numpy mixed arithmetic and broadcasting.
        durs[mask] *= torch.as_tensor(factors, dtype=durs.dtype)
        return factors.reshape(1, -1).tolist()

    def get_wav(self, sample):
        """Run the acoustic model and vocoder on `sample`.

        If `sample["net_input"]` contains `durations`, they are passed to the
        encoder along with the other inputs.

        Returns:
            (waveform, sample_rate) for the first item of the batch.
        """
        bsz = sample["net_input"]["src_tokens"].size(0)
        n_frames_per_step = self.model.encoder.n_frames_per_step
        out_dim = self.model.encoder.out_dim
        raw_dim = out_dim // n_frames_per_step

        with torch.no_grad():  # inference only
            feat, _, out_lens, _, _, _ = self.model.encoder(**sample["net_input"])

            feat = feat.view(bsz, -1, raw_dim)
            feat = self.generator.gcmvn_denormalize(feat)

            # Only the first batch item is returned, so only that one is vocoded.
            n_frames = out_lens[0] * n_frames_per_step
            mel = feat[0, :n_frames] if n_frames > 0 else feat.new_zeros([1, raw_dim])
            waveform = self.generator.get_waveform(mel)

        return waveform, self.task.sr

    def full_tts(self, text):
        """Synthesise `text` with the stock fairseq pipeline (no augmentation)."""
        sample = self.get_sample(text)
        wav, rate = TTSHubInterface.get_prediction(self.task, self.model, self.generator, sample)
        return wav, rate
    

2024-04-15 21:58:36.704277: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-15 21:58:36.704371: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-15 21:58:36.838193: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Seed the RNGs so the randomly augmented dataset is reproducible on
# Restart & Run All (augmentation draws come from np.random).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Candidate duration multipliers and the probability of drawing each one;
# a factor of 1.0 (no change) is drawn half of the time.
durations_factor = np.array([1., 2., 2.5, 3., 3.5]).astype('float32')
p = np.array([0.5, 0.125, 0.125, 0.125, 0.125]).astype('float32')
assert len(p) == len(durations_factor), "one probability per duration factor"
assert np.isclose(p.sum(), 1.0), "probabilities must sum to 1"

device = "cuda" if torch.cuda.is_available() else "cpu"

tts = TTSModel(durations_factor, p, device).to(device)

Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

config.yaml:   0%|          | 0.00/612 [00:00<?, ?B/s]

run_fast_speech_2.py:   0%|          | 0.00/306 [00:00<?, ?B/s]

fbank_mfa_gcmvn_stats.npz:   0%|          | 0.00/1.14k [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.22k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/2.13k [00:00<?, ?B/s]

hifigan.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/602 [00:00<?, ?B/s]

hifigan.bin:   0%|          | 0.00/55.8M [00:00<?, ?B/s]

pytorch_model.pt:   0%|          | 0.00/495M [00:00<?, ?B/s]



In [4]:
# Download the TinyStories validation split from the Hugging Face Hub and
# store it locally as val_raw.txt (source of the sentences to synthesise).
!wget "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-valid.txt?download=true" -O val_raw.txt

--2024-04-15 21:59:36--  https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-valid.txt?download=true
Resolving huggingface.co (huggingface.co)... 13.35.7.5, 13.35.7.57, 13.35.7.81, ...
Connecting to huggingface.co (huggingface.co)|13.35.7.5|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/42/7f/427f7497b6c6596c18b46d5a72e61364fcad12aa433c60a0dbd4d344477b9d81/94e431816c4cce81ff71e4408ff8d3bda9a42e8d2663986697c3954288cb38b4?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27TinyStories-valid.txt%3B+filename%3D%22TinyStories-valid.txt%22%3B&response-content-type=text%2Fplain&Expires=1713477576&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxMzQ3NzU3Nn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy80Mi83Zi80MjdmNzQ5N2I2YzY1OTZjMThiNDZkNWE3MmU2MTM2NGZjYWQxMmFhNDMzYzYwYTBkYmQ0ZDM0NDQ3N2I5ZDgxLzk0ZTQzMTgxNmM0Y2NlODFmZjcxZ

In [5]:
# Read the corpus line by line; each entry keeps its trailing newline.
with open("val_raw.txt", encoding="utf-8") as f:
    texts = list(f)

In [6]:
# import wandb
# from kaggle_secrets import UserSecretsClient


# secret_label = "wandb_key"
# secret_value = UserSecretsClient().get_secret(secret_label)
# wandb.login(key=secret_value) 
# wandb.init(project="fastspeech_audio")

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mkvdmitrieva[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
!mkdir audio 
!mkdir audio/clean
!mkdir audio/aug

In [8]:
import numpy as np
from tqdm import tqdm
from torchaudio import save


def process_sample(tts, sample, ind, part="clean", caption="some text"):
    """Synthesise `sample` with `tts` and write it to audio/{part}/{ind}.wav.

    Args:
        tts: model wrapper exposing `device` and `get_wav(sample)`.
        sample: model-input dict; tensors in `sample["net_input"]` are moved
            to the model's device in place.
        ind: identifier used as the file name.
        part: sub-directory of `audio/` to write into.
        caption: only used by the (currently disabled) wandb logging.
    """
    # Use the device the model wrapper was built with rather than a notebook
    # global, so the function works with any `tts` passed in.
    target_device = tts.device
    net_input = sample["net_input"]
    for key in net_input.keys():
        if net_input[key] is not None:
            net_input[key] = net_input[key].to(target_device)

    with torch.no_grad():  # inference only; no autograd graph needed
        wav, sr = tts.get_wav(sample)
    wav = wav.detach().cpu()

#     wandb.log({f"{part} audio": wandb.Audio(wav.numpy(), caption=caption, sample_rate=sr)})
    # torchaudio expects (channels, time); add the channel dimension.
    save(f"audio/{part}/{ind}.wav", wav.unsqueeze(0), sr)


results = []

# Only texts of at least min_len characters are used; longer ones are cut
# to at most max_len characters at a word boundary.
min_len = 100
max_len = 160

# Stop once this many samples have been generated.
limit = 5000

# total= gives tqdm a real progress bar (enumerate() hides the length).
for i, t in tqdm(enumerate(texts), total=len(texts)):
    if len(results) == limit:
        break
    t = t.strip()
    if len(t) < min_len:
        continue

    if len(t) > max_len:
        t = t[:max_len]
        last_space = t.rfind(' ')
        # rfind returns -1 when there is no space; slicing with -1 would
        # silently drop the last character, so only cut on a real boundary.
        if last_space > 0:
            t = t[:last_space]

    # This FastSpeech model ignores punctuation other than commas.
    modified_text = t.replace('.', ',')
    sample = tts.get_sample(modified_text)
    old_durs = tts.get_durations(sample).long().detach().tolist()

    process_sample(tts, sample, i, part="clean", caption=modified_text)

    sample, new_text, phoneme_change = tts.pauses_and_phonemes(modified_text)
    # The augmented audio is generated from new_text (with inserted commas).
    process_sample(tts, sample, i, part="aug", caption=new_text)

    result = {
        "text": t,
        "modified_text": new_text,
        "clean path": f"audio/clean/{i}.wav",
        "aug path": f"audio/aug/{i}.wav",
        "clean durations": old_durs,
        "aug durations": sample["net_input"]["durations"].detach().cpu().tolist(),
        "phoneme changes": phoneme_change
    }

    results.append(result)
    

0it [00:00, ?it/s]

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /usr/share/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package cmudict to /usr/share/nltk_data...
[nltk_data]   Package cmudict is already up-to-date!


13it [00:34,  2.62s/it]


In [9]:
import json


# First 90% of the generated records go to the train manifest, the rest to test.
train_ind = int(0.9 * len(results))

manifests = (
    ("audio/train.json", results[:train_ind]),
    ("audio/test.json", results[train_ind:]),
)
for manifest_path, records in manifests:
    with open(manifest_path, "w") as f:
        json.dump(records, f)

In [11]:
# Bundle the whole audio/ directory (wav files and JSON manifests) into a
# single uncompressed archive, audio.tar.
!tar cf audio.tar audio