In [None]:
!pip install fairseq -q
!pip install g2p_en -q

In [None]:
import torch

from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface


class TTSModel:
    """Wrapper around the fairseq ``facebook/fastspeech2-en-ljspeech`` checkpoint.

    Besides plain text-to-speech (``full_tts``), the class can synthesize
    speech with the durations of selected tokens (commas by default) scaled
    by a factor, which lengthens or shortens the pauses in the audio.
    """

    def __init__(self):
        """Load the model, task and generator (HiFi-GAN vocoder, fp32)."""
        models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
            "facebook/fastspeech2-en-ljspeech",
            arg_overrides={"vocoder": "hifigan", "fp16": False},
        )
        self.model = models[0]
        self.task = task

        # get_durations/get_wav call the encoder directly instead of going
        # through the fairseq generator, so inference mode has to be enabled
        # here; otherwise dropout may still be active and make the output
        # noisy and non-deterministic.
        self.model.eval()

        TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
        self.generator = self.task.build_generator(models, cfg)

    def get_sample(self, text):
        """Convert ``text`` into the model input dict built by ``TTSHubInterface``."""
        return TTSHubInterface.get_model_input(self.task, text)

    @torch.no_grad()
    def get_durations(self, sample):
        """Return the predicted per-token durations (in frames) as a float tensor.

        The values are not rounded, so callers can rescale them before
        converting to integers.
        """
        # encoder.forward args: src_tokens, src_lengths=None, speaker=None, durations=None, pitches=None, energies=None,
        _, _, _, log_dur_out, _, _ = self.model.encoder(**sample["net_input"])
        # The fairseq FastSpeech2 duration predictor works in log(d + 1) space
        # (see the fairseq VarianceAdaptor), so the inverse is exp(x) - 1,
        # clamped because a duration cannot be negative.
        return torch.clamp(torch.exp(log_dur_out) - 1, min=0)

    def simple_change(self, text, dur_factor=1., pause_token=11):
        """Build a model input for ``text`` with scaled pause durations.

        Args:
            text: Text to synthesize.
            dur_factor: Multiplier applied to the predicted duration of every
                ``pause_token`` in the input.
            pause_token: Source-dictionary index whose durations are scaled.
                Defaults to 11, which the original author identified as ','.

        Returns:
            The sample dict with ``sample["net_input"]["durations"]`` set to
            integer (long) durations.
        """
        sample = self.get_sample(text)
        durs = self.get_durations(sample)

        pause_mask = sample["net_input"]["src_tokens"] == pause_token
        durs[pause_mask] *= dur_factor
        # Round instead of truncating so that dur_factor == 1 reproduces the
        # durations the model itself would use.
        sample["net_input"]["durations"] = torch.round(durs).long()

        return sample

    @torch.no_grad()
    def get_wav(self, sample):
        """Synthesize a waveform from ``sample`` and return ``(waveform, sample_rate)``.

        Any ``durations`` stored in ``sample["net_input"]`` are forwarded to
        the encoder together with the other inputs. Only the first item of
        the batch is vocoded and returned.
        """
        bsz, _ = sample["net_input"]["src_tokens"].size()
        n_frames_per_step = self.model.encoder.n_frames_per_step
        raw_dim = self.model.encoder.out_dim // n_frames_per_step

        feat, feat_post, out_lens, _, _, _ = self.model.encoder(**sample["net_input"])
        # Prefer the post-net output when the encoder provides one, as the
        # fairseq non-autoregressive generator does.
        if feat_post is not None:
            feat = feat_post

        feat = feat.view(bsz, -1, raw_dim)
        feat = self.generator.gcmvn_denormalize(feat)

        n_frames = int(out_lens[0]) * n_frames_per_step
        # Fall back to a single all-zero frame when the model predicts no frames.
        mel = feat[0, :n_frames] if n_frames > 0 else feat.new_zeros([1, raw_dim])

        return self.generator.get_waveform(mel), self.task.sr

    def full_tts(self, text):
        """Run the unmodified fairseq pipeline and return ``(waveform, sample_rate)``."""
        sample = self.get_sample(text)
        wav, rate = TTSHubInterface.get_prediction(self.task, self.model, self.generator, sample)
        return wav, rate
    

In [None]:
# Instantiate the wrapper once; the constructor loads the
# "facebook/fastspeech2-en-ljspeech" checkpoint and builds its generator.
tts = TTSModel()

In [None]:
# Download the TinyStories validation text file and store it as val_raw.txt.
!wget "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-valid.txt?download=true" -O val_raw.txt

In [None]:
# Read the downloaded file into a list with one entry per line
# (each entry keeps its trailing newline).
with open("val_raw.txt", mode='r', encoding="utf-8") as f:
    texts = list(f)

In [None]:
import IPython.display as ipd

# Sanity check on the first line of `texts`: print it, replace periods with
# commas, synthesize it with the comma durations multiplied by 1, and show an
# inline audio player for the result.
print(texts[0])
t = texts[0].replace('.', ',')
wav, sr = tts.get_wav(tts.simple_change(t, dur_factor=1.))
ipd.Audio(wav, rate=sr)

In [None]:
import wandb
from kaggle_secrets import UserSecretsClient


# Fetch the Weights & Biases API key from the Kaggle secret named "wandb_key",
# log in with it, and start a run in the "fastspeech_audio" project.
secret_label = "wandb_key"
secret_value = UserSecretsClient().get_secret(secret_label)
wandb.login(key=secret_value) 
wandb.init(project="fastspeech_audio")

In [None]:
!mkdir audio
!mkdir audio/clean
!mkdir audio/aug

In [None]:
import numpy as np
from tqdm import tqdm
from torchaudio import save


durations_factor = np.linspace(2, 5, 13)  # candidate pause stretch factors: 2.0, 2.25, ..., 5.0
limit = 10  # number of leading entries of `texts` to consider
rng = np.random.default_rng(0)  # fixed seed so the sampled factors are reproducible

# Wrap the sliced list (not the enumerate object) so tqdm knows the total
# and can show a real progress bar.
for i, t in enumerate(tqdm(texts[:limit])):
    # This FastSpeech model ignores punctuation other than commas,
    # so periods are turned into commas.
    modified_text = t.replace('.', ',')
    if ',' not in modified_text:
        continue  # no comma, hence no pause to stretch

    dur_factor = rng.choice(durations_factor)

    # Synthesize the same text twice: once with unchanged comma durations
    # ("clean") and once with them stretched by the sampled factor ("aug").
    for log_key, out_dir, factor in (
        ("test audio", "clean", 1.),
        ("aug audio", "aug", dur_factor),
    ):
        wav, sr = tts.get_wav(tts.simple_change(modified_text, dur_factor=factor))
        # Detach and move to CPU so the conversion to NumPy and the file
        # write also work if the waveform carries autograd state or is on GPU.
        wav = wav.detach().cpu()
        wandb.log({log_key: wandb.Audio(wav.numpy(), caption=modified_text, sample_rate=sr)})
        save(f"audio/{out_dir}/{i}.wav", wav.unsqueeze(0), sr)
    

In [None]:
# Inspect the shape of the most recently synthesized waveform.
wav.shape

In [None]:
# Pack the whole `audio` directory into an uncompressed archive, audio.tar.
!tar cf audio.tar audio