In [1]:
from f5_tts.model import CFM, DurationPredictor, DiT
from cached_path import cached_path
from f5_tts.model.utils import get_tokenizer
import librosa
from IPython.display import Audio, display
import torch
import torchaudio
from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef, target_rms
import phonemizer

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.514 seconds.
Prefix dict has been built successfully.


Word segmentation module jieba initialized.



In [2]:
# Audio front-end settings. They are passed to the model through
# `mel_spec_kwargs`, and `mel_spec_type` also selects the vocoder.
mel_spec_type = "vocos"
target_sample_rate = 24_000  # Hz
n_mel_channels = 100
n_fft = win_length = 1024
hop_length = 256  # samples per mel frame

In [3]:
# Vocabulary: the "char" tokenizer mode, read from the LJSpeech_2k data directory.
# Alternative vocabulary source used previously:
#   str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt"))
tokenizer = "char"
tokenizer_path = "/workspace/F5-TTS/data/LJSpeech_2k"
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

# Keyword arguments forwarded to the DiT backbone constructor.
model_cfg = {
    "dim": 1024,
    "depth": 22,
    "heads": 16,
    "ff_mult": 2,
    "text_dim": 512,
    "conv_layers": 4,
}

# Mel-spectrogram settings forwarded to CFM; values come from the config cell.
mel_spec_kwargs = {
    "n_fft": n_fft,
    "hop_length": hop_length,
    "win_length": win_length,
    "n_mel_channels": n_mel_channels,
    "target_sample_rate": target_sample_rate,
    "mel_spec_type": mel_spec_type,
}

In [4]:
# Build the CFM model around a DiT backbone.
backbone = DiT(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)
model = CFM(
    transformer=backbone,
    mel_spec_kwargs=mel_spec_kwargs,
    vocab_char_map=vocab_char_map,
)

# The duration predictor is attached to the model *before* the checkpoint is
# loaded, so the checkpoint's keys are matched against it as well
# (presumably it is an nn.Module and is registered as a submodule -- TODO confirm).
# NOTE(review): the positional arguments after vocab_size are undocumented here;
# check DurationPredictor's signature before changing them.
duration_predictor = DurationPredictor(vocab_size, 512, 32, 3, 0.5)
model.duration_predictor = duration_predictor

# Earlier experiments loaded 'ckpts/2_with_duration_LJSpeech/model_210000.pt'.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint_path = '/workspace/F5-TTS/ckpts/F5TTS_Base_LJSpeech_2k/model_50000.pt'
model.load_state_dict(
    torch.load(checkpoint_path, map_location='cpu')['model_state_dict']
)

<All keys matched successfully>

In [5]:
# model.load_state_dict(torch.load('/workspace/model_80000_slim.pt', map_location='cpu')['model_state_dict'], strict=False)
# model.load_state_dict(torch.load('ckpts/LJSpeech/model_80000.pt', map_location='cpu')['model_state_dict'], strict=False)

In [6]:
# Move the model to the first GPU, then switch it to eval mode.
# The attached duration_predictor is not moved separately; presumably it
# follows the parent model as a submodule -- TODO confirm.
model = model.to('cuda:0')
model = model.eval()

In [7]:
# The vocoder type follows the mel-spectrogram type chosen in the config cell.
vocoder_name = mel_spec_type
vocoder_kwargs = dict(vocoder_name=vocoder_name, is_local=False, local_path="")
vocoder = load_vocoder(**vocoder_kwargs)
# Assign to `_` so the module repr is not echoed as the cell output.
_ = vocoder.to('cuda:0').eval()

Download Vocos from huggingface charactr/vocos-mel-24khz


In [8]:
from phonemizer.backend import EspeakBackend

# en-US eSpeak phonemizer configured to keep stress marks and punctuation.
espeak_backend = EspeakBackend(
    language="en-us",
    with_stress=True,
    preserve_punctuation=True,
)


In [9]:
# Reference (prompt) utterance and its transcript.
ref_audio = 'data/LJSpeech-1.1/wavs/LJ050-0273.wav'
ref_text = "The Commission has, however, from its examination of the facts of President Kennedy's assassination"
print(ref_text)

# sr=None keeps the file's native sampling rate. librosa's default (sr=22050)
# would resample every non-22.05 kHz file once here and a second time below.
# mono=True (the default, made explicit) already down-mixes multi-channel
# audio, so no separate channel-averaging step is needed afterwards.
ref_wave, sr = librosa.load(ref_audio, sr=None, mono=True)
display(Audio(ref_wave, rate=sr))

# Shape (1, samples): a single-item batch for the model.
audio = torch.as_tensor(ref_wave, dtype=torch.float32).unsqueeze(0)

# Boost quiet references up to target_rms; louder ones are left untouched.
# The `0 <` guard skips all-zero (silent) input, which would divide by zero.
rms = torch.sqrt(torch.mean(torch.square(audio)))
if 0 < rms < target_rms:
    audio = audio * target_rms / rms

# Bring the waveform to the sampling rate set in the config cell.
if sr != target_sample_rate:
    resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
    audio = resampler(audio)

audio = audio.to('cuda:0')
# Number of whole mel frames covered by the reference audio.
ref_audio_len = audio.shape[-1] // hop_length

The Commission has, however, from its examination of the facts of President Kennedy's assassination


In [10]:
from f5_tts.train.datasets.utils_alignment import load_alignment_model, generate_word_timestamps, word_to_character_alignment, create_attention_matrix, convert_word_timestamps_to_phonemes

# Align the reference transcript to the reference audio: word timestamps,
# then their phonemized form, then per-character alignments.
alignment_model, alignment_tokenizer = load_alignment_model('cuda:0', dtype=torch.float16)
word_timestamps = generate_word_timestamps(ref_audio, ref_text, alignment_model, alignment_tokenizer, batch_size=16, language='en')
word_timestamps, phonemized_text = convert_word_timestamps_to_phonemes(word_timestamps)
char_alignments = word_to_character_alignment(word_timestamps, phonemized_text)

# Reference length in mel frames (a float, because get_duration returns seconds
# as a float). `path=` replaces the deprecated `filename=` alias, which librosa
# warns will be removed in version 1.0.
ref_len_mel = librosa.get_duration(path=ref_audio) * target_sample_rate // hop_length
attention_matrix = create_attention_matrix(char_alignments, target_sample_rate, hop_length, ref_len_mel)
ref_attn = torch.tensor(attention_matrix).to('cuda:0').unsqueeze(0)

# NOTE: from here on ref_text holds the *phonemized* transcript, which later
# cells rely on. Re-running this cell without first re-running the cell that
# defines ref_text would feed phonemes into the aligner.
ref_text = phonemized_text

  return F.conv1d(input, weight, bias, self.stride,
	This alias will be removed in version 1.0.
  ref_len_mel = librosa.get_duration(filename=ref_audio) * target_sample_rate // hop_length


In [74]:
from f5_tts.model.utils import list_str_to_idx

# Test prompts. Only the entry named by PROMPT_NAME is synthesised; previously
# these were consecutive assignments to gen_text, each overwriting the last.
PROMPTS = {
    "f5_spelled_out": "Hello, this is a test of ef five tee tee es model. The ey eye said, 'Data is key, data is key, data is the key, the data is key, the data is the key to everything.'",
    "f5_acronyms": "Hello, this is a test of the F5-TTS model. The AI said, 'Data is key, data is key, data is the key, the data is key, the data is the key to everything.'",
    "saw_tongue_twister": "I saw a saw that could out-saw any saw I ever saw. If you ever saw a saw that could out-saw the saw I saw, you’d say it’s the best saw you ever saw.",
    "tick_tock": "The repetitive rhythm revealed: 'Tick-tock, tick-tock, tick-tick-tick-tock, tick-tock-tock-tick, tock-tick-tock-tock-tick.'",
    "data_is_key": "Data-driven AI systems said, 'Key data is the key, data is key, and the key to the data is key; the data key is the key to the data that is key to the key'. Can you keep up?",
    "ljspeech_sentence": "leaving everything to the widow, and constituting William sole executor.",
    "short": "this is everything",
}
PROMPT_NAME = "data_is_key"

# Phonemize the chosen prompt and show it.
gen_text = espeak_backend.phonemize([PROMPTS[PROMPT_NAME]])[0]
print(gen_text)

# The model sees reference text and generation text joined by SEPARATOR.
# Here ref_text is the phonemized reference transcript set by the alignment cell.
SEPARATOR = ". "
infer_text = ref_text + SEPARATOR + gen_text
text_tokens = list_str_to_idx([infer_text], vocab_char_map).to('cuda:0')

# The predictor's output is exponentiated, so it is presumably a per-token
# log-duration -- TODO confirm against DurationPredictor.
with torch.inference_mode():
    durations = torch.exp(model.duration_predictor(text_tokens, torch.ones_like(text_tokens)).squeeze(1))

# Every token gets at least one frame, then all durations are stretched.
LENGTH_SCALE = 1.45
durations = torch.clamp(durations, min=1) * LENGTH_SCALE

# Keep only the trailing tokens: the separator plus the generated text.
n_gen_tokens = len(gen_text) + len(SEPARATOR)
gen_durations = durations[..., -n_gen_tokens:]
gen_durations = torch.round(gen_durations).long()

dˈeɪɾədɹˈɪvən ˌeɪˈaɪ sˈɪstəmz sˈɛd, kˈiː dˈeɪɾə ɪz ðə kˈiː, dˈeɪɾə ɪz kˈiː, ænd ðə kˈiː tə ðə dˈeɪɾə ɪz kˈiː; ðə dˈeɪɾə kˈiː ɪz ðə kˈiː tə ðə dˈeɪɾə ðæt ɪz kˈiː tə ðə kˈiː. kæn juː kˈiːp ˈʌp? 


In [75]:
# Hard monotonic alignment for the generated part: row i (a token) is 1 on the
# contiguous run of frames assigned to that token and 0 elsewhere.
# The +2 accounts for the ". " separator placed before gen_text.
n_rows = len(gen_text) + 2
total_frames = torch.sum(gen_durations).item()
gen_attn = torch.zeros((1, n_rows, total_frames), dtype=torch.int32)

# Pull the per-token frame counts to plain Python ints once.
frames_per_token = gen_durations[0].tolist()
start = 0
for i in range(n_rows):
    duration = frames_per_token[i]
    gen_attn[0, i, start:start + duration] = 1
    start += duration

In [76]:
# Sanity check: character counts of the phonemized reference and generation texts.
len(ref_text), len(gen_text)

(114, 192)

In [77]:
# Sanity check: gen_attn is (1, len(gen_text) + 2, total generated frames);
# ref_attn is presumably (1, reference tokens, reference frames) -- TODO confirm.
ref_attn.shape, gen_attn.shape

(torch.Size([1, 114, 511]), torch.Size([1, 194, 1044]))

In [78]:
# gen_attn[:,-10:,-50:]

In [79]:
# Combine both alignments into one block-diagonal matrix: the reference block
# sits in the top-left corner, the generated block in the bottom-right, and
# everything else stays zero.
n_ref_tokens, n_ref_frames = ref_attn.shape[1:]
n_gen_tokens, n_gen_frames = gen_attn.shape[1:]
val_attn = torch.zeros(
    (1, n_ref_tokens + n_gen_tokens, n_ref_frames + n_gen_frames),
    device=ref_attn.device,
)
val_attn[:, :n_ref_tokens, :n_ref_frames] = ref_attn
val_attn[:, n_ref_tokens:, n_ref_frames:] = gen_attn

In [80]:
# Sanity check: (1, ref tokens + gen tokens, ref frames + gen frames).
val_attn.shape

torch.Size([1, 308, 1555])

In [81]:
# Sample a mel spectrogram conditioned on the reference audio, the joined
# text and the explicit alignment, then vocode only the newly generated part.
with torch.inference_mode():
    # Total length in mel frames: the reference plus the predicted durations.
    total_frames = int(ref_len_mel + torch.sum(gen_durations).item())
    generated, _ = model.sample(
        cond=audio,
        text=[infer_text],
        duration=total_frames,
        steps=nfe_step,
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        attn=val_attn,
    )
    generated = generated.to(torch.float32)

    # Drop the frames that belong to the reference prompt, then reorder the
    # last two axes for the vocoder.
    gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to('cuda:0')

    if vocoder_name == "vocos":
        gen_audio = vocoder.decode(gen_mel_spec).cpu()
    elif vocoder_name == "bigvgan":
        gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
    else:
        # Fail loudly instead of leaving a stale gen_audio from an earlier run
        # to be played by the next cell.
        raise ValueError(f"Unsupported vocoder_name: {vocoder_name!r}")

In [82]:
# Play the synthesised audio at the sampling rate set in the config cell.
display(Audio(gen_audio, rate=target_sample_rate))

In [18]:
# Scratch calculation: 740 * 256 / 24,000. The constants match this notebook's
# hop_length and target_sample_rate, so this is presumably 740 mel frames
# converted to seconds -- TODO confirm or remove.
740*256/24_000

7.8933333333333335