In [1]:
# Cell 1 - system

import sys
sys.path.append('src')

from __future__ import annotations
import torch

from typing import Callable

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint

import os
import re
from importlib.resources import files
from pathlib import Path

import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path

import cyrtranslit

from f5_tts.infer.utils_infer import (
    infer_single_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    preprocess_ref_audio_text_segment,
    remove_silence_for_generated_wav,
)


from f5_tts.model.utils import (
    default,
    exists,
    lens_to_mask,
    list_str_to_idx,
    list_str_to_tensor,
    mask_from_frac_lengths,
)

from f5_tts.model import DiT, UNetT

# Direct variable assignments instead of command line arguments
model = "F5-TTS"
# ckpt_file = "/home/k4/Python/F5-TTS-Fork/ckpts/russian_dataset_ft_translit_pinyin/model_last.pt"
# ckpt_file = "/home/k4/Python/F5-TTS-Fork/ckpts/russian_coarse_16nov/model_40000.pt"
ckpt_file = "/home/k4/Python/F5-TTS-Fork/ckpts/russian_coarse_16nov/model_0230.pt"

# Default configurations: output_dir / remove_silence are read from the packaged "basic" example TOML.
config_path = os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml")
# Context manager so the config file handle is closed instead of leaked for the kernel's lifetime.
with open(config_path, "rb") as config_file:
    config = tomli.load(config_file)

# Additional settings with default values
vocab_file = ""  # empty: load_model chooses the vocab itself (the cell output shows examples/vocab.txt)
output_dir = config["output_dir"]
remove_silence = config.get("remove_silence", False)
speed = 1.0
vocoder_name = "vocos"
load_vocoder_from_local = False
wave_path = Path(output_dir) / "infer_cli_out.wav"

# Vocoder settings
if vocoder_name == "vocos":
    vocoder_local_path = "../checkpoints/vocos-mel-24khz"
elif vocoder_name == "bigvgan":
    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
else:
    # Fail fast with a clear message instead of a NameError on vocoder_local_path below.
    raise ValueError(f"Unsupported vocoder_name {vocoder_name!r}; expected 'vocos' or 'bigvgan'")
mel_spec_type = vocoder_name

vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=load_vocoder_from_local, local_path=vocoder_local_path)

# Model configuration
if model == "F5-TTS":
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    # if not ckpt_file:
    #     if vocoder_name == "vocos":
    #         repo_name = "F5-TTS"
    #         exp_name = "F5TTS_Base"
    #         ckpt_step = 1200000
    #         ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
    #     elif vocoder_name == "bigvgan":
    #         repo_name = "F5-TTS"
    #         exp_name = "F5TTS_Base_bigvgan"
    #         ckpt_step = 1250000
    #         ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
else:
    # Only the F5-TTS (DiT) architecture is configured here; fail instead of a NameError on model_cls.
    raise ValueError(f"Unsupported model {model!r}; only 'F5-TTS' is configured in this notebook")


ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)

device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm
2024-11-17 03:24:10,746 - INFO - PyTorch version 2.5.1+cu121 available.


Download Vocos from huggingface charactr/vocos-mel-24khz


  state_dict = torch.load(model_path, map_location="cpu")



vocab :  /home/k4/Python/F5-TTS-Fork/src/f5_tts/infer/examples/vocab.txt
tokenizer :  custom
model :  /home/k4/Python/F5-TTS-Fork/ckpts/russian_coarse_16nov/model_0230.pt 



In [2]:
# Cell 2 - main process

import librosa
import librosa.display
import matplotlib.pyplot as plt

def save_mel_spectrogram(audio_path, save_path, title):
    """Render the dB-scaled mel-spectrogram of an audio file and save it as an image.

    The audio is loaded at its native sample rate. The mel analysis parameters are
    fixed: n_fft=2048, hop_length=512, 128 mel bands, fmax=8000 Hz.

    Args:
        audio_path: Path of the audio file to analyse.
        save_path: Destination image path passed to ``savefig``.
        title: Title drawn above the plot.
    """
    signal, native_sr = librosa.load(audio_path, sr=None)
    power_mel = librosa.feature.melspectrogram(
        y=signal, sr=native_sr, n_fft=2048, hop_length=512, n_mels=128, fmax=8000
    )
    # ref=np.max makes the loudest bin 0 dB; everything else is negative.
    mel_db = librosa.power_to_db(power_mel, ref=np.max)

    fig, ax = plt.subplots(figsize=(10, 4))
    mesh = librosa.display.specshow(
        mel_db, sr=native_sr, x_axis="time", y_axis="mel", fmax=8000, cmap="magma", ax=ax
    )
    fig.colorbar(mesh, ax=ax, format="%+2.0f dB")
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(save_path)
    # Close explicitly so repeated calls don't accumulate open figures.
    plt.close(fig)


def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed, end_time=None, cfg_strength=2.0, nfe_step=32, start_step=0, end_step=32):
    """Synthesize ``text_gen`` in the reference voice and save audio plus mel plots.

    Relies on notebook globals defined in Cell 1: ``vocoder``, ``output_dir`` and
    ``wave_path`` (the generated wav is written to ``wave_path``; the plots go to
    ``output_dir/ref_mel.png`` and ``output_dir/gen_mel.png``).

    Args:
        ref_audio: Path to the reference audio clip.
        ref_text: Transcript of the reference clip.
        text_gen: Text to synthesize.
        model_obj: Loaded TTS model (e.g. ``ema_model``).
        mel_spec_type: Vocoder / mel type name, forwarded to ``infer_single_process``.
        remove_silence: If True, run ``remove_silence_for_generated_wav`` on the written file.
        speed: Forwarded to ``infer_single_process``.
        end_time: Forwarded to ``preprocess_ref_audio_text_segment``; presumably the
            cut point of the reference clip — TODO confirm units.
        cfg_strength, nfe_step, start_step, end_step: Forwarded to ``infer_single_process``.

    Returns:
        The ``trajectory`` returned by ``infer_single_process``.
    """
    ref_audio, ref_text = preprocess_ref_audio_text_segment(
        ref_audio, ref_text, end_time=end_time, clip_short=False, bypass_cache=True
    )
    print("Ref_audio:", ref_audio)
    print("Ref_text:", ref_text)

    # Generate audio
    audio, final_sample_rate, _, trajectory = infer_single_process(
        ref_audio, ref_text, text_gen, model_obj, vocoder,
        mel_spec_type=mel_spec_type, nfe_step=nfe_step, cfg_strength=cfg_strength,
        speed=speed, start_step=start_step, end_step=end_step,
    )

    # exist_ok avoids the check-then-create race and is a no-op if the dir exists.
    os.makedirs(output_dir, exist_ok=True)

    # Write directly to the path. The previous version opened wave_path in "wb" only to
    # read f.name, keeping a truncating handle open while soundfile and the silence
    # remover rewrote the same file through their own handles.
    out_path = str(wave_path)
    sf.write(out_path, audio, final_sample_rate)
    if remove_silence:
        remove_silence_for_generated_wav(out_path)
    print(f"Saved generated audio to {out_path}")

    # Save mel-spectrograms
    ref_mel_path = Path(output_dir) / "ref_mel.png"
    gen_mel_path = Path(output_dir) / "gen_mel.png"

    save_mel_spectrogram(ref_audio, ref_mel_path, "Reference Audio Mel-Spectrogram")
    save_mel_spectrogram(wave_path, gen_mel_path, "Generated Audio Mel-Spectrogram")

    return trajectory


### Настройки генерации

In [3]:
from russian_songs_dataset_utils import process_music_ref_dataset

# Root of the vocal-only (dereverbed) Russian songs collection. Parsed results are
# cached in ref_data_cache.json (the cell output shows later runs loading from it).
DATASET_ROOT = "/media/k4_nas/Datasets/Music_RU/Vocal_Dereverb"
base_ref_tracks = process_music_ref_dataset(DATASET_ROOT, cache_filename="ref_data_cache.json")
n_base_tracks = len(base_ref_tracks)
print(f"found {n_base_tracks} base tracks with lyrics")

Loading cached data from ref_data_cache.json
found 766 base tracks with lyrics


In [4]:
# Inspect one reference track (position 515 in the dataset's key order) and its first section.
ref_track_name = list(base_ref_tracks)[515]
ref_track = base_ref_tracks[ref_track_name]
print(ref_track_name)
print(ref_track['mir'])
print(ref_track['caption'])
print(ref_track['mir_data']['bpm'])
track_sections = ref_track['sections']
print(f"how many sections: {len(track_sections)}")
opening_section = track_sections[0]
print(opening_section['words'])
print(opening_section['mp3_path'])

03.Encelt - Дикая охота
/media/k4_nas/Datasets/Music_RU/Vocal_Dereverb/Encelt/Истории у костра [2017]/03.Encelt - Дикая охота.mir.json
﻿Женский голос, отчётливый, легенды, фолк, рок, мюзикл, эпичный
140
how many sections: 7
Между небом и землёй в дикой ясоте Слышишь ветер вой на протяжной ноте Мчатся вихрем в ночи,
/media/k4_nas/Datasets/Music_RU/Vocal_Dereverb/Encelt/Истории у костра [2017]/03.Encelt - Дикая охота_vocals_stretched_120bpm_section2.mp3


In [12]:
# # ref_audio = "src/f5_tts/infer/examples/basic/basic_ref_en.wav"
# # ref_audio = "/media/k4_nas/Datasets/Music_RU/Vocal_Dereverb/Tesla Boy/Андропов [2020]/05.Tesla Boy - Ватикан_vocals_stretched_120bpm_section5.mp3"
# # ref_text = "Prosti moi drug, prosti moi drug, ne opravdal, ne opravdal, instinkti vrut, instinkti vrut, ya vidishny, kreatariseksual"
# ref_audio = "/media/k4_nas/Datasets/Music_RU/Vocal_Dereverb/Tesla Boy/Андропов [2020]/05.Tesla Boy - Ватикан_vocals_stretched_120bpm_section5.mp3"
# ref_text = "Prosti moi drug, prosti moi drug, ne opravdal, ne opravdal, instinkti vrut, instinkti vrut, ya vidishny, kreatariseksual"
# # gen_text = "nichego na svete luchshe netu, chem brodit' druz'jam po belu svetu. tem, kto druzhen, ne strashny trevogi. nam ljubye dorogi dorogi"
# # gen_text = "nichego na svete luchshe netu, chem brodit' druz'jam po belu svetu."

def get_text_audio_and_end_time_path_for_ref(ref_id, section_num, base_ref_tracks):
    """Fetch one section of the ``ref_id``-th reference track.

    Tracks are addressed by their position in ``base_ref_tracks``' key order.

    Returns:
        A 3-tuple: (lowercased Latin transliteration of the section lyrics,
        the section's ``mp3_path``, the section's ``end_time``).
    """
    track_key = list(base_ref_tracks)[ref_id]
    section = base_ref_tracks[track_key]['sections'][section_num]
    lyrics = section['words']
    audio_path = section['mp3_path']
    latin_lyrics = cyrtranslit.to_latin(lyrics, "ru").lower()
    return latin_lyrics, audio_path, section['end_time']

def to_translit(text):
    """Transliterate Russian ``text`` into lowercase Latin script using cyrtranslit."""
    latin = cyrtranslit.to_latin(text, "ru")
    return latin.lower()

# Choose the reference clip as (track index, section index). The note after each
# candidate records how it sounded and which `speed` worked for it.
# NOTE: most commented alternatives unpack only two values, but
# get_text_audio_and_end_time_path_for_ref returns three (text, audio path, end time);
# add `, ref_end_time` before re-enabling one of them.
# ref_text, ref_audio = get_text_audio_and_end_time_path_for_ref(80, 1, base_ref_tracks) # 80 is bad
# ref_text, ref_audio = get_text_audio_and_end_time_path_for_ref(105, 2, base_ref_tracks) # needs speed 0.1 here
# ref_text, ref_audio = get_text_audio_and_end_time_path_for_ref(500, 2, base_ref_tracks) # needs speed 0.3 here
ref_text, ref_audio, ref_end_time = get_text_audio_and_end_time_path_for_ref(600, 0, base_ref_tracks) # Monetochka, speed 0.3
# ref_text, ref_audio = get_text_audio_and_end_time_path_for_ref(630, 0, base_ref_tracks) # some female chanson, speed 0.3
# ref_text, ref_audio = get_text_audio_and_end_time_path_for_ref(638, 0, base_ref_tracks) # speed 0.6
# ref_text, ref_audio = get_text_audio_and_end_time_path_for_ref(640, 0, base_ref_tracks) # this one articulates everything, speed 0.5
# ref_text, ref_audio = get_text_audio_and_end_time_path_for_ref(645, 0, base_ref_tracks) # speed 0.7
# ref_text, ref_audio = get_text_audio_and_end_time_path_for_ref(648, 0, base_ref_tracks) # speed 0.7
# ref_text, ref_audio = get_text_audio_and_end_time_path_for_ref(680, 0, base_ref_tracks) # rap, speed 0.9
# ref_text, ref_audio, ref_end_time = get_text_audio_and_end_time_path_for_ref(360, 0, base_ref_tracks) # 

# Lyrics to synthesize, transliterated to lowercase Latin like the reference text.
gen_text = to_translit("Эх, полным полна коробочка! Есть в ней ситец и парча. Пожалей, душа-зазнобушка молодецкого плеча")
# gen_text = to_translit("О-о-о-о, зеленоеглазое такси. О-о-о-о, притормози, притормози! О-о-о-о, ты отвези меня туда! О-о-о-о, где будут рады мне всегда! ")
# gen_text = to_translit("Ты сидишь на стуле с голым торсом, я смотрю на пиццу, у меня вопросы.")
# gen_text = to_translit("Эх, полным полна коробочка!")

### Простая генерация

In [6]:
# # Cell 3 - main process call & visualization
# speed = 0.8
# print(f"Symbols in ref text: {len(ref_text)}")
# print(f"Symbols in gen text: {len(gen_text)}")
# main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed, end_time=ref_end_time, nfe_step=32, cfg_strength=3.0)
# print(len(ref_text))

# # Display mel-spectrograms
# from IPython.display import display, Image, Audio

# ref_mel_path = Path(output_dir) / "ref_mel.png"
# gen_mel_path = Path(output_dir) / "gen_mel.png"

# display(Image(filename=ref_mel_path))
# display(Image(filename=gen_mel_path))

# # Display audio
# display(Audio(ref_audio))
# display(Audio(wave_path))

### Частичная генерация с промежуточным растяжением

In [None]:
from IPython.display import display, Image, Audio
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchaudio

# Experiment: run the flow up to step N, stretch the partially denoised mel along the
# time axis, re-noise it, then let the model finish denoising from the stretched state.

# Parameters
N = 5  # Step at which to intervene
stretch_factor = 1.2  # How much to stretch along time axis
total_steps = 32
# Fall back to CPU instead of crashing on machines without CUDA (same check as Cell 1).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# First stage - generate trajectory up to step N
ref_audio_preprocessed, ref_text_preprocessed = preprocess_ref_audio_text(ref_audio, ref_text)
first_audio, sr, _, first_trajectory = infer_single_process(
    ref_audio_preprocessed, ref_text_preprocessed, gen_text, ema_model, vocoder,
    mel_spec_type=mel_spec_type,
    nfe_step=total_steps,
    speed=speed,
    end_step=N
)

# Get intermediate state and remove noise according to schedule.
# NOTE(review): assumes first_trajectory[0] is the initial noise, first_trajectory[-1] is the
# state at step N, and the remaining noise weight at step N is 1 - (N/total_steps)**2.
# Confirm against the time schedule used inside infer_single_process (e.g. sway sampling).
noise_scale = 1 - (N/total_steps) ** 2
intermediate_state = first_trajectory[-1] - first_trajectory[0] * noise_scale

# Draw mel spectrum after noise removal
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.imshow(intermediate_state.squeeze().cpu().numpy().T, aspect='auto', origin='lower', cmap='magma')
plt.colorbar()
plt.title('After noise removal at step N')

# # Get the audio and process it
# audio, sr = torchaudio.load(ref_audio_preprocessed)
# if audio.shape[0] > 1:
#     audio = torch.mean(audio, dim=0, keepdim=True)
# audio = audio.to(device=device, dtype=torch.float32)

# if sr != target_sample_rate:
#     resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
#     resampler = resampler.to(device)
#     audio = resampler(audio)

# # Calculate mel conditioning
# mel = vocoder.mel_spec(audio) if mel_spec_type == "bigvgan" else model_obj.mel_spec(audio)
# cond = mel.permute(0, 2, 1)
# cond_seq_len = cond.shape[1]

# Stretch intermediate state along dim 1 (treated as time, per the permutes below).
orig_len = intermediate_state.shape[1]
new_len = int(orig_len * stretch_factor)
stretched_state = F.interpolate(
    intermediate_state.permute(0, 2, 1),
    size=new_len,
    mode='linear',
    align_corners=False
).permute(0, 2, 1).to(device)

# Handle conditioning - only stretch the non-padded part
# real_cond = cond[:, :cond_seq_len, :]
# stretched_cond = F.interpolate(
#     real_cond.permute(0, 2, 1),
#     size=int(cond_seq_len * stretch_factor),
#     mode='linear',
#     align_corners=False
# ).permute(0, 2, 1)

# Pad the stretched conditioning to match new length
# stretched_cond = F.pad(
#     stretched_cond,
#     (0, 0, 0, new_len - stretched_cond.shape[1]),
#     value=0.0
# )

# Add new noise scaled to the same amount we removed (same shape/device/dtype as the state).
new_noise = torch.randn_like(stretched_state) * noise_scale

stretched_state_with_noise = stretched_state + new_noise

# Continue flow with the stretched state.
# NOTE(review): stage one stopped at end_step=N; confirm whether resuming at N+1 skips a step.
audio, sr, mel_spec, final_trajectory = infer_single_process(
    ref_audio_preprocessed, ref_text_preprocessed, gen_text, ema_model, vocoder,
    mel_spec_type=mel_spec_type,
    nfe_step=total_steps,
    # NOTE(review): if a larger speed shortens the output (as in upstream F5-TTS duration
    # estimation), multiplying here runs opposite to stretching; confirm how the forked
    # infer_single_process uses `speed` when `initial_state` is given.
    speed=speed * stretch_factor,
    start_step=N+1,
    initial_state=stretched_state_with_noise
)

# Draw final mel spectrum
plt.subplot(1, 2, 2)
plt.imshow(mel_spec.T, aspect='auto', origin='lower', cmap='magma')
plt.colorbar()
plt.title('Final result after full denoising')

plt.tight_layout()
plt.show()

# Save and play final audio (MP3 output via soundfile needs a libsndfile build with MP3 support).
sf.write("stretch_test.mp3", audio, sr)
display(Audio("stretch_test.mp3"))

Converting audio...
Audio is over 15s, clipping short. (2)
Using cached reference text...
