In [1]:
!pip install fairseq -q
!pip install g2p_en -q

[33mDEPRECATION: omegaconf 2.0.6 has a non-standard dependency specifier PyYAML>=5.1.*. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of omegaconf or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m[33mDEPRECATION: omegaconf 2.0.6 has a non-standard dependency specifier PyYAML>=5.1.*. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of omegaconf or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m

In [2]:
import torch
import numpy as np

from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface


class TTSModel(torch.nn.Module):
    """Wrapper around the pretrained "facebook/fastspeech2-en-ljspeech" model.

    Besides plain synthesis (``full_tts``) it exposes the encoder's predicted
    durations so that they can be randomly rescaled for selected token ids
    before the waveform is rendered (``pauses_and_phonemes`` + ``get_wav``).
    """

    def __init__(self, dur_factors, p, device):
        """
        Load the model, its task and a generator configured with a HiFi-GAN vocoder.

        Args:
            dur_factors: candidate multipliers applied to token durations.
            p: sampling probability of each entry of ``dur_factors``.
            device: device the model inputs are moved to before a forward pass.

        phonemes2change desc:
        {
            "a": [4, 31, 18],
            "e": [29, 17],
            "u": [40, 34, 45],
            "i": [33, 15],
            "o": [36, 32],
            ",": [11]
        }
        """
        super().__init__()
        models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
            "facebook/fastspeech2-en-ljspeech",
            arg_overrides={"vocoder": "hifigan", "fp16": False}
        )
        self.model = models[0]
        # nn.Module instances start in training mode. This class calls the
        # encoder directly, so switch to inference mode explicitly.
        self.model.eval()
        self.task = task

        self.phonemes2change = [4, 31, 18, 29, 17, 40, 34, 45, 33, 15, 36, 32]
        self.dur_factor = dur_factors
        self.p = p
        self.device = device

        TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
        self.generator = self.task.build_generator(models, cfg)

    def get_sample(self, text):
        """Convert ``text`` into the model input dict built by ``TTSHubInterface``."""
        return TTSHubInterface.get_model_input(self.task, text)

    def get_durations(self, sample):
        """Return ``exp(log_dur_out)`` of the encoder for ``sample`` as a CPU tensor.

        Side effect: every non-None entry of ``sample["net_input"]`` is moved
        to ``self.device`` in place.
        """
        net_input = sample["net_input"]
        for key, value in net_input.items():
            if value is not None:
                net_input[key] = value.to(self.device)
        # encoder.forward args: src_tokens, src_lengths=None, speaker=None, durations=None, pitches=None, energies=None,
        with torch.no_grad():  # inference only, no autograd graph needed
            _, _, _, log_dur_out, _, _ = self.model.encoder(**net_input)
        return torch.exp(log_dur_out).cpu()

    def pauses_and_phonemes(self, text,
                            p_add_pause=0.5,
                            p_change_phoneme_dur=0.5,
                            p_change_pause_dur=0.5):
        """Build a model input for ``text`` with randomly perturbed durations.

        Args:
            text: input sentence.
            p_add_pause: probability of inserting extra commas into the text.
            p_change_phoneme_dur: probability of rescaling the durations of
                three randomly chosen ids from ``self.phonemes2change``.
            p_change_pause_dur: probability of rescaling the durations of
                token id 11 (listed as "," in the ``__init__`` docstring).

        Returns:
            ``(sample, text, phoneme_change)`` where ``sample["net_input"]["durations"]``
            holds the integer durations, ``text`` is the possibly comma-augmented
            text and ``phoneme_change`` maps token id -> applied factors.
        """
        phoneme_change = {}

        if np.random.uniform() <= p_add_pause:
            text = self._add_pauses(text)

        sample = self.get_sample(text)
        durs = self.get_durations(sample)

        if np.random.uniform() <= p_change_phoneme_dur:
            phonemes_to_change = np.random.choice(self.phonemes2change, size=3, replace=False)
            for phoneme in phonemes_to_change:
                phoneme_change[int(phoneme)] = self._change_phoneme_dur(sample, durs, phoneme)

        if np.random.uniform() <= p_change_pause_dur:
            phoneme_change[11] = self._change_phoneme_dur(sample, durs, 11)

        # Truncate to integers before storing the durations on the sample.
        sample["net_input"]["durations"] = durs.long()

        return sample, text, phoneme_change

    def _add_pauses(self, text, p=0.3):
        """Append "," to ``int(p * n_words)`` random words (never the last one)."""
        chunks = text.split(' ')
        n = len(chunks)
        n_pauses = int(p * n)
        if n_pauses == 0:
            # Nothing to insert; also avoids sampling from an empty index range.
            return text

        inds = np.random.choice(range(n - 1), size=n_pauses, replace=False)
        for ind in inds:
            chunks[ind] += ","

        # Words that already ended with a comma would otherwise get ",,".
        return " ".join(chunks).replace(',,', ',')

    def _change_phoneme_dur(self, sample, durs, phoneme_ind):
        """Scale, in place, the entries of ``durs`` whose source token equals ``phoneme_ind``.

        One factor per matching position is drawn from ``self.dur_factor``
        with probabilities ``self.p``.

        Returns:
            The drawn factors as a nested list of shape ``[1, n_matches]``.
        """
        tokens = sample["net_input"]["src_tokens"]
        mask = (tokens == int(phoneme_ind)).to(durs.device)
        n_hits = int(mask.sum().item())
        factors = np.random.choice(
            self.dur_factor, size=n_hits, replace=True, p=self.p
        ).reshape(1, -1)
        # Convert explicitly instead of relying on implicit tensor/ndarray interop.
        scale = torch.as_tensor(factors, dtype=durs.dtype, device=durs.device).reshape(-1)
        durs[mask] = durs[mask] * scale
        return factors.tolist()

    def get_wav(self, sample):
        """Run encoder and vocoder on ``sample`` and return ``(waveform, sample_rate)``.

        Only the first item of the batch is vocoded and returned. If
        ``sample["net_input"]`` carries a ``durations`` entry, it is passed to
        the encoder together with the other inputs.
        """
        encoder = self.model.encoder
        bsz, _ = sample["net_input"]["src_tokens"].size()
        n_frames_per_step = encoder.n_frames_per_step
        raw_dim = encoder.out_dim // n_frames_per_step

        with torch.no_grad():  # inference only, no autograd graph needed
            feat, _, out_lens, _, _, _ = encoder(**sample["net_input"])
            feat = self.generator.gcmvn_denormalize(feat.view(bsz, -1, raw_dim))

            n_frames = int(out_lens[0]) * n_frames_per_step
            # Fall back to a single all-zero frame when the encoder produced no output.
            mel = feat[0, :n_frames] if n_frames > 0 else feat.new_zeros([1, raw_dim])
            waveform = self.generator.get_waveform(mel)

        return waveform, self.task.sr

    def full_tts(self, text):
        """Synthesize ``text`` with the unmodified fairseq pipeline; returns ``(wav, rate)``."""
        sample = TTSHubInterface.get_model_input(self.task, text)
        wav, rate = TTSHubInterface.get_prediction(self.task, self.model, self.generator, sample)
        return wav, rate
    

2024-06-13 14:47:24.775262: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-13 14:47:24.775374: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-13 14:47:24.899454: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Seed the random generators so the randomly augmented dataset can be regenerated.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Candidate duration multipliers and the probability of drawing each of them.
durations_factor = np.array([1., 2., 2.5, 3., 0.05, 0.1, 0.3]).astype('float32')
p = np.array([0.2, 0.2, 0.1, 0.1, 0.2, 0.1, 0.1]).astype('float32')
assert durations_factor.shape == p.shape, "one probability per duration factor is required"
assert np.isclose(p.sum(), 1.0), "probabilities must sum to 1"

device = "cuda" if torch.cuda.is_available() else "cpu"

tts = TTSModel(durations_factor, p, device).to(device)

Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

hifigan.bin:   0%|          | 0.00/55.8M [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.22k [00:00<?, ?B/s]

run_fast_speech_2.py:   0%|          | 0.00/306 [00:00<?, ?B/s]

fbank_mfa_gcmvn_stats.npz:   0%|          | 0.00/1.14k [00:00<?, ?B/s]

pytorch_model.pt:   0%|          | 0.00/495M [00:00<?, ?B/s]

hifigan.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

config.yaml:   0%|          | 0.00/612 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/2.13k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/602 [00:00<?, ?B/s]



In [4]:
# Download the TinyStories validation split from the Hugging Face Hub and store it as val_raw.txt.
!wget "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-valid.txt?download=true" -O val_raw.txt

--2024-06-13 14:47:59--  https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-valid.txt?download=true
Resolving huggingface.co (huggingface.co)... 18.239.50.16, 18.239.50.80, 18.239.50.103, ...
Connecting to huggingface.co (huggingface.co)|18.239.50.16|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/42/7f/427f7497b6c6596c18b46d5a72e61364fcad12aa433c60a0dbd4d344477b9d81/94e431816c4cce81ff71e4408ff8d3bda9a42e8d2663986697c3954288cb38b4?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27TinyStories-valid.txt%3B+filename%3D%22TinyStories-valid.txt%22%3B&response-content-type=text%2Fplain&Expires=1718549279&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxODU0OTI3OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy80Mi83Zi80MjdmNzQ5N2I2YzY1OTZjMThiNDZkNWE3MmU2MTM2NGZjYWQxMmFhNDMzYzYwYTBkYmQ0ZDM0NDQ3N2I5ZDgxLzk0ZTQzMTgxNm

In [5]:
# Read the downloaded corpus; every element of `texts` is one raw line of the file.
with open("val_raw.txt", mode='r', encoding="utf-8") as f:
    texts = list(f)

In [6]:
# import wandb
# from kaggle_secrets import UserSecretsClient


# secret_label = "wandb_key"
# secret_value = UserSecretsClient().get_secret(secret_label)
# wandb.login(key=secret_value) 
# wandb.init(project="fastspeech_audio")

In [7]:
!mkdir audio 
!mkdir audio/clean
!mkdir audio/aug

In [8]:
import librosa
import torchaudio
from torchaudio import save
import torch.nn.functional as F


class Augmentations:
    """Renders a clean utterance plus one augmented variant for every text.

    The augmentation type rotates on each ``process_text`` call in the fixed
    order "dur" -> "pitch" -> "room_resp" -> "swap".
    """

    # Room impulse response used by ``room_response_aug`` unless a path is passed in.
    DEFAULT_RIR_PATH = "/kaggle/input/room-response/RIRS_NOISES/real_rirs_isotropic_noises/RVB2014_type1_rir_mediumroom2_near_anglb.wav"

    def __init__(self, tts_model, rir_path=None):
        """
        Args:
            tts_model: object providing ``get_sample``, ``get_durations``,
                ``get_wav``, ``pauses_and_phonemes`` and a ``device`` attribute.
            rir_path: optional path of the room impulse response file;
                defaults to ``DEFAULT_RIR_PATH``.
        """
        self.tts = tts_model
        self.aug_num = 0
        # Letter substitutions used by ``letter_swap`` (each pair in both directions).
        self.letter_pairs = {
            "b": "p",
            "p": "b",
            "v": "f",
            "f": "v",
            "t": "d",
            "d": "t",
            "k": "g",
            "g": "k"
        }
        self.path = rir_path if rir_path is not None else self.DEFAULT_RIR_PATH
        # Loaded lazily on first use so the file is read from disk only once.
        self._rir = None

    def process_text(self, text, ind):
        """Synthesize ``text`` cleanly and with the next augmentation in the cycle.

        Both waveforms are written through ``log_wav`` under the index ``ind``.

        Returns:
            A dict describing the pair: augmentation type, original and
            modified text, relative wav paths, clean/augmented durations and
            the duration factors applied (empty unless the type is "dur").
        """
        # Use the `text` argument here, not a notebook-level variable.
        modified_text = text.replace('.', ',')

        sample = self.tts.get_sample(modified_text)
        old_durs = self.tts.get_durations(sample).long().detach().tolist()

        wav, sr = self.process_sample(sample, ind, part="clean", caption=text)

        aug_durations = old_durs
        phoneme_change = {}
        new_text = modified_text

        if self.aug_num == 0:
            # Duration / pause perturbation inside the TTS model.
            aug_type = "dur"
            sample, new_text, phoneme_change = self.tts.pauses_and_phonemes(modified_text)
            self.process_sample(sample, ind, part="aug", caption=aug_type)
            aug_durations = sample["net_input"]["durations"].detach().cpu().tolist()
        elif self.aug_num == 1:
            # Pitch shift applied to the clean waveform.
            aug_type = "pitch"
            aug_wav, pitch = self.pitch_aug(wav, sr)
            self.log_wav(aug_wav, sr, ind, part="aug", caption=f"{aug_type}={pitch}")
        elif self.aug_num == 2:
            # Room impulse response applied to the clean waveform.
            aug_type = "room_resp"
            aug_wav = self.room_response_aug(wav)
            self.log_wav(aug_wav, sr, ind, part="aug", caption=aug_type)
        elif self.aug_num == 3:
            # Letter substitutions in the text, then a fresh synthesis.
            aug_type = "swap"
            new_text = self.letter_swap(modified_text)
            sample = self.tts.get_sample(new_text)
            aug_durations = self.tts.get_durations(sample).long().detach().tolist()
            self.process_sample(sample, ind, part="aug", caption=aug_type)

        self.aug_num = (self.aug_num + 1) % 4

        return {
            "aug_type": aug_type,
            "text": text,
            "modified_text": new_text,
            "clean path": f"clean/{ind}.wav",
            "aug path": f"aug/{ind}.wav",
            "clean durations": old_durs,
            "aug durations": aug_durations,
            "phoneme changes": phoneme_change
        }

    def process_sample(self, sample, ind, part="clean", caption="some text"):
        """Move ``sample`` to the TTS device, synthesize it, save it and return ``(wav, sr)``."""
        target_device = self.tts.device  # use the model's device, not a notebook global
        net_input = sample["net_input"]
        for key, value in net_input.items():
            if value is not None:
                net_input[key] = value.to(target_device)
        wav, sr = self.tts.get_wav(sample)
        self.log_wav(wav, sr, ind, part, caption)
        return wav, sr

    def log_wav(self, wav, sr, ind, part="clean", caption="some text"):
        """Write ``wav`` to ``audio/{part}/{ind}.wav`` (``caption`` is only used by the disabled wandb logging)."""
        wav = wav.detach().cpu()
#         wandb.log({f"{part} audio": wandb.Audio(wav.numpy(), caption=caption, sample_rate=sr)})
        save(f"audio/{part}/{ind}.wav", wav.unsqueeze(0), sr)

    def pitch_aug(self, wav, sr, n_steps=None):
        """Pitch-shift ``wav`` by ``n_steps`` (random non-zero value in [-3, 3] when omitted).

        Returns:
            ``(shifted_wav, n_steps)`` with the waveform on the device of ``wav``.
        """
        if n_steps is None:
            n_steps = int(np.random.choice([-3, -2, -1, 1, 2, 3]))

        aug = librosa.effects.pitch_shift(wav.cpu().numpy().squeeze(), sr=sr, n_steps=n_steps)
        aug = torch.from_numpy(aug).to(wav.device)
        return aug, n_steps

    def _load_rir(self):
        """Return the first channel of the impulse response at ``self.path``, cached after first load."""
        rir = getattr(self, "_rir", None)
        if rir is None:
            # NOTE(review): the file's sample rate is not compared with the TTS
            # sample rate; confirm that both match.
            loaded, _ = torchaudio.load(self.path)
            rir = loaded[0]
            self._rir = rir
        return rir

    def room_response_aug(self, audio):
        """Convolve 1-D ``audio`` with the room impulse response (full convolution).

        The output has ``len(audio) + len(rir) - 1`` samples and is scaled
        down to a peak of 1 only if the peak exceeds 1.
        """
        rir = self._load_rir().to(audio.device)
        pad = rir.shape[-1] - 1

        # torch.conv1d computes cross-correlation, so flip the kernel to get a convolution.
        kernel = rir.flip(0).view(1, 1, -1)
        padded = F.pad(audio, [pad, pad]).view(1, 1, -1)
        convolved_audio = torch.conv1d(padded, kernel).squeeze()

        # Peak normalization.
        peak = convolved_audio.abs().max()
        if peak > 1:
            convolved_audio = convolved_audio / peak

        return convolved_audio

    def letter_swap(self, text):
        """Replace occurrences of 4 random letters from ``letter_pairs`` by their pair.

        Each occurrence is replaced when its uniform draw is >= 0.3. All
        decisions are taken against the original characters, so a later
        letter can never swap back a substitution made for an earlier one.
        """
        if not text:
            return text

        letters = np.random.choice(list(self.letter_pairs.keys()), 4, replace=False)
        source = np.array(list(text))
        swapped = source.copy()
        for letter in letters:
            positions = np.flatnonzero(source == letter)
            chosen = np.random.rand(len(positions)) >= 0.3
            swapped[positions[chosen]] = self.letter_pairs[letter]
        return ''.join(swapped)
        
        
        

In [9]:
import numpy as np
from tqdm import tqdm
from torchaudio import save


results = []

# Only texts with at least `min_len` characters are used; longer ones are cut to `max_len`.
min_len = 100
max_len = 160

# Stop once this many samples have been generated.
limit = 12000

augs = Augmentations(tts)
# wandb.init(project="fastspeech_audio", name="all augs")

# tqdm wraps `texts` itself (not the enumerate object) so it knows the total and can show an ETA.
for i, t in enumerate(tqdm(texts)):
    if len(results) == limit:
        break
    t = t.strip()
    if len(t) < min_len:
        continue
    if len(t) > max_len:
        t = t[:max_len]
        # Cut back to the last word boundary; keep the text as is when it has no space.
        cut = t.rfind(' ')
        if cut != -1:
            t = t[:cut]
    result = augs.process_text(t, i)
    results.append(result)
    

0it [00:00, ?it/s]

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /usr/share/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package cmudict to /usr/share/nltk_data...
[nltk_data]   Package cmudict is already up-to-date!


19609it [6:49:06,  1.25s/it]


In [10]:
import json


# First 90% of the generated records form the train split, the remainder the test split.
train_ind = int(len(results) * 0.9)

for split_path, split_rows in (
    ("audio/train.json", results[:train_ind]),
    ("audio/test.json", results[train_ind:]),
):
    with open(split_path, "w") as f:
        json.dump(split_rows, f)

In [11]:
# Pack the whole audio/ directory (wav files and JSON manifests) into the uncompressed archive audio.tar.
!tar cf audio.tar audio