In [1]:
from f5_tts.model import CFM, DurationPredictor, DiT
from cached_path import cached_path
from f5_tts.model.utils import get_tokenizer
import librosa
from IPython.display import Audio, display
import torch
import torchaudio
from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef, target_rms
import phonemizer

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 1.390 seconds.
Prefix dict has been built successfully.


Word segmentation module jieba initialized.



### Setup models

In [2]:
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"

In [5]:
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
tokenizer = "custom"
tokenizer_path = str(cached_path("hf://sinhprous/F5TTS-stabilized-LJSpeech/vocab.txt"))
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
mel_spec_kwargs = dict(
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    n_mel_channels=n_mel_channels,
    target_sample_rate=target_sample_rate,
    mel_spec_type=mel_spec_type,
)

In [6]:
model = CFM(
    transformer=DiT(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
    mel_spec_kwargs=mel_spec_kwargs,
    vocab_char_map=vocab_char_map,
)
duration_predictor = DurationPredictor(vocab_size, 512, 32, 3, 0.5)
setattr(model, 'duration_predictor', duration_predictor)

model.load_state_dict(torch.load(cached_path("hf://sinhprous/F5TTS-stabilized-LJSpeech/model_130000.pt"), map_location='cpu')['model_state_dict'], strict=False)



model_130000.pt:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

<All keys matched successfully>

In [7]:
model = model.to('cuda:0').eval()

In [8]:
vocoder_name = mel_spec_type
vocoder = load_vocoder(
    vocoder_name=vocoder_name, is_local=False, local_path=""
)
_ = vocoder.to('cuda:0').eval()

Download Vocos from huggingface charactr/vocos-mel-24khz


In [9]:
from phonemizer.backend import EspeakBackend
espeak_backend = EspeakBackend(language="en-us", with_stress=True, preserve_punctuation=True)


### Ref audio

In [10]:
ref_audio = 'data/LJSpeech-1.1/wavs/LJ050-0273.wav'
ref_text = "The Commission has, however, from its examination of the facts of President Kennedy's assassination"
print (ref_text)
audio, sr = librosa.load(ref_audio)
display(Audio(audio, rate=sr))

audio = torch.Tensor(audio).unsqueeze(0)
if audio.shape[0] > 1:
    audio = torch.mean(audio, dim=0, keepdim=True)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
    audio = audio * target_rms / rms
if sr != target_sample_rate:
    resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
    audio = resampler(audio)
audio = audio.to('cuda:0')
ref_audio_len = audio.shape[-1] // hop_length

The Commission has, however, from its examination of the facts of President Kennedy's assassination


In [11]:
from f5_tts.train.datasets.utils_alignment import load_alignment_model, generate_word_timestamps, word_to_character_alignment, create_attention_matrix, convert_word_timestamps_to_phonemes

alignment_model, alignment_tokenizer = load_alignment_model('cuda:0', dtype=torch.float16)
# alignment_model, alignment_tokenizer = load_alignment_model('cpu', dtype=torch.float16)
word_timestamps = generate_word_timestamps(ref_audio, ref_text, alignment_model, alignment_tokenizer, batch_size=16, language='en')
word_timestamps, phonemized_text = convert_word_timestamps_to_phonemes(word_timestamps)
char_alignments = word_to_character_alignment(word_timestamps, phonemized_text)
ref_len_mel = librosa.get_duration(filename=ref_audio) * target_sample_rate // hop_length
attention_matrix = create_attention_matrix(char_alignments, target_sample_rate, hop_length, ref_len_mel)
ref_attn = torch.tensor(attention_matrix).to('cuda:0').unsqueeze(0)
ref_text = phonemized_text

  return F.conv1d(input, weight, bias, self.stride,
	This alias will be removed in version 1.0.
  ref_len_mel = librosa.get_duration(filename=ref_audio) * target_sample_rate // hop_length


In [12]:
del alignment_model

### Generating audio

In [13]:
from f5_tts.model.utils import list_str_to_idx
gen_text = "The repetitive rhythm revealed: 'Tick tock, tick tock, tick tick tick tock, tick tock tock tick, tock tick tock tock tick.'"
gen_text = espeak_backend.phonemize([gen_text])[0]
print(gen_text)

infer_text = ref_text + ". " + gen_text
text_tokens = list_str_to_idx([infer_text], vocab_char_map).to('cuda:0')

with torch.inference_mode():
    durations = torch.exp(model.duration_predictor(text_tokens, torch.ones_like(text_tokens)).squeeze(1))
LENGTH_SCALE = 1.2
durations = torch.clamp(durations, min=1) * LENGTH_SCALE
gen_durations = durations[..., -len(gen_text)-2:]
gen_durations = torch.round(gen_durations).long()

ðə ɹᵻpˈɛɾᵻtˌɪv ɹˈɪðəm ɹᵻvˈiːld: tˈɪk tˈɑːk, tˈɪk tˈɑːk, tˈɪk tˈɪk tˈɪk tˈɑːk, tˈɪk tˈɑːk tˈɑːk tˈɪk, tˈɑːk tˈɪk tˈɑːk tˈɑːk tˈɪk.


In [14]:
for i in range(len(gen_text)+2):
    if (". " + gen_text)[i] == " ":
        gen_durations[0][i] = int(gen_durations[0][i]*1.3)
    # gen_durations[0][i] = int(4)
gen_attn = torch.zeros((1, len(gen_text)+2, torch.sum(gen_durations).item()), dtype=torch.int32)
start = 0
for i in range(len(gen_text)+2):
    duration = gen_durations[0, i]
    gen_attn[0, i, start:start + duration] = 1
    # print((". " + gen_text)[i], duration)
    start += duration

In [15]:
val_attn = torch.zeros([1, ref_attn.shape[1] + gen_attn.shape[1], ref_attn.shape[2] + gen_attn.shape[2]]).to(device=ref_attn.device)
val_attn[:, :ref_attn.shape[1], :ref_attn.shape[2]] = ref_attn
val_attn[:, -gen_attn.shape[1]:, -gen_attn.shape[2]:] = gen_attn

In [16]:
with torch.inference_mode():
    generated, _ = model.sample(
        cond=audio,
        text=[infer_text],
        duration=int(ref_len_mel + torch.sum(gen_durations).item()),
        steps=int(nfe_step),
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        attn=val_attn,
    )
    generated = generated.to(torch.float32)
    gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to('cuda:0')
    # ref_mel_spec = batch["mel"][0].unsqueeze(0)
    if vocoder_name == "vocos":
        gen_audio = vocoder.decode(gen_mel_spec).cpu()
        # ref_audio = vocoder.decode(ref_mel_spec).cpu()
    elif vocoder_name == "bigvgan":
        gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
        # ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
    display(Audio(gen_audio, rate=target_sample_rate))