In [1]:
# Prepend "./src" (resolved against the kernel's working directory) to the
# module search path before importing the project packages.
import sys

sys.path.insert(0, "./src")

# Smoke test: the package must import cleanly; list what it exposes.
import f5_tts.model

f5_tts_model_names = dir(f5_tts.model)
print("f5_tts.model contents:", f5_tts_model_names)


  import pkg_resources
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.635 seconds.
Prefix dict has been built successfully.


Word segmentation module jieba initialized.



  from .autonotebook import tqdm as notebook_tqdm


f5_tts.model contents: ['CFM', 'DiT', 'MMDiT', 'Trainer', 'UNetT', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'backbones', 'cfm', 'dataset', 'modules', 'trainer', 'utils']


In [2]:
# Smoke test: duration_predictor must import cleanly; list what it exposes.
import duration_predictor

duration_predictor_names = dir(duration_predictor)
print("duration_predictor contents:", duration_predictor_names)


duration_predictor contents: ['PositionalEncoding', 'SpeechLengthPredictor', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', 'calculate_remaining_lengths', 'nn', 'torch']


In [3]:
# Smoke test: guidance_model must import cleanly; list what it exposes.
import guidance_model

guidance_model_names = dir(guidance_model)
print("guidance_model contents:", guidance_model_names)


guidance_model contents: ['Callable', 'ConformerCTC', 'ConformerDiscirminator', 'DiT', 'ECAPA_TDNN', 'F', 'Guidance', 'NoOpContext', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', '_kl_dist_func', 'annotations', 'default', 'exists', 'lens_to_mask', 'list_str_to_idx', 'list_str_to_tensor', 'mask_from_frac_lengths', 'nn', 'np', 'predict_flow', 'random', 'torch']


In [4]:
# Show the kernel's working directory: the relative paths used in this
# notebook ("./src", "../ckpts/...") are resolved against it.
import os

working_dir = os.getcwd()
print("Current directory explicitly is:", working_dir)
# Disabled check kept from the original cell:
# print("Explicit contents of ckpts folder:", os.listdir('ckpts'))

Current directory explicitly is: /home/mike/github/wrightmikea/DMOSpeech2/src


In [None]:
import sys
sys.path.insert(0, "./src")

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import FileResponse, JSONResponse
import torchaudio
import torch
import uuid
import nest_asyncio
import uvicorn

from transformers import AutoTokenizer
from unimodel import UniModel
from duration_predictor import SpeechLengthPredictor
from guidance_model import Guidance
from f5_tts.model import DiT
from f5_tts.model.utils import get_tokenizer

# nest_asyncio patches asyncio so an event loop can be entered while another
# one is already running, which is the situation inside a notebook kernel.
nest_asyncio.apply()

# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer vocabulary (original author's note: matches the infer.py pattern).
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")

# DiT backbone (original author's note: F5TTS_Base config from infer.py).
dit_model = DiT(
    dim=1024,
    depth=22,
    heads=16,
    ff_mult=2,
    text_dim=512,
    conv_layers=4,
    text_num_embeds=vocab_size,
    mel_dim=100,
    second_time=True,  # num_student_step > 1
)

# Wrap the backbone in UniModel and move it to the selected device.
# NOTE(review): checkpoint_path is passed as an empty string and the weights
# are loaded explicitly below — confirm UniModel tolerates "" here.
mel_model = UniModel(
    model=dit_model,
    checkpoint_path="",
    vocab_char_map=vocab_char_map,
    frac_lengths_mask=(0.5, 0.9),
    real_guidance_scale=2.0,
    fake_guidance_scale=0.0,
    gen_cls_loss=False,
    sway_coeff=0,
).to(DEVICE)

# Load the checkpoint onto the CPU first; load_state_dict copies the values
# into the model's own (possibly CUDA) parameters.
# The path is relative to the kernel's working directory.
checkpoint = torch.load("../ckpts/model_85000.pt", map_location="cpu")

# strict=False lets the load succeed even when the checkpoint and the model
# disagree on parameter names. Without inspecting the result, a wrong
# checkpoint would silently leave parameters at their initial values, so the
# mismatch is reported instead of being discarded.
mel_load_result = mel_model.load_state_dict(checkpoint["model_state_dict"], strict=False)
if mel_load_result.missing_keys or mel_load_result.unexpected_keys:
    print(
        "WARNING: ../ckpts/model_85000.pt does not fully match mel_model: "
        f"{len(mel_load_result.missing_keys)} missing key(s), "
        f"{len(mel_load_result.unexpected_keys)} unexpected key(s)"
    )
    print("  missing (first 5):", mel_load_result.missing_keys[:5])
    print("  unexpected (first 5):", mel_load_result.unexpected_keys[:5])

# Inference only: switch off training-mode behaviour.
mel_model.eval()

# Build the speech-length predictor (original author's note: parameters match
# infer.py) and move it to the selected device.
# NOTE(review): vocab_size is hard-coded to 2545 here, while the DiT model is
# sized from the tokenizer's vocab_size — confirm the two are meant to differ.
duration_model = SpeechLengthPredictor(
    vocab_size=2545,
    n_mel=100,
    hidden_dim=512,
    n_text_layer=4,
    n_cross_layer=4,
    n_head=8,
    output_dim=301,
)
duration_model = duration_model.to(DEVICE)

# Restore the predictor's weights. The checkpoint is read onto the CPU and the
# path is relative to the kernel's working directory. The load uses the default
# strict behaviour.
duration_checkpoint = torch.load("../ckpts/model_1500.pt", map_location="cpu")
duration_model.load_state_dict(duration_checkpoint["model_state_dict"])

# Inference only: switch off training-mode behaviour.
duration_model.eval()

# FastAPI application on which the endpoint handlers below are registered.
app = FastAPI()

# Module-level reference-voice state shared by the handlers: /init_voice fills
# it in and /generate_audio reads it. Both entries stay None until a voice has
# been initialised.
style_state = dict(wav=None, prompt_ids=None)

@app.post("/init_voice")
async def init_voice(audio_file: UploadFile = File(...), reference_text: str = Form(...)):
    """Store the reference voice used by later /generate_audio requests.

    Args:
        audio_file: Uploaded reference recording, decoded with torchaudio.load.
        reference_text: Text form field, converted to token ids with
            list_str_to_idx and vocab_char_map.

    Returns:
        JSONResponse with a status message once the state has been stored.
    """
    # Function-scope import, as in the other handler of this cell.
    from f5_tts.model.utils import list_str_to_idx

    target_sr = 22050  # NOTE(review): confirm this matches the models' training sample rate

    wav, sr = torchaudio.load(audio_file.file)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=target_sr)
    # torchaudio.load yields (channels, time) by default, so this adds a leading
    # batch axis. NOTE(review): multi-channel uploads are not downmixed — confirm
    # what the models expect.
    wav = wav.unsqueeze(0).to(DEVICE)

    # NOTE(review): the original comment says list_str_to_idx returns a list of
    # indices per input string — confirm against f5_tts.model.utils.
    prompt_ids = list_str_to_idx([reference_text], vocab_char_map)[0]
    prompt_ids = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0).to(DEVICE)

    # Commit both entries together, and only after everything above succeeded.
    # Writing "wav" first would leave the new audio paired with the previous
    # prompt_ids if tokenisation raised.
    style_state.update(wav=wav, prompt_ids=prompt_ids)

    return JSONResponse({"status": "Voice style initialized."})

@app.post("/generate_audio")
async def generate_audio(target_text: str = Form(...)):
    """Synthesise target_text in the voice stored by /init_voice and return a WAV file.

    Args:
        target_text: Text form field, converted to token ids with
            list_str_to_idx and vocab_char_map.

    Returns:
        A 400 JSONResponse if no voice has been initialised yet; otherwise a
        FileResponse streaming the generated audio as "generated.wav". The
        temporary file behind the response is deleted after it has been sent.
    """
    # Function-scope imports, matching the style already used in this cell.
    import os

    from fastapi import BackgroundTasks
    from f5_tts.model.utils import list_str_to_idx

    if style_state["wav"] is None or style_state["prompt_ids"] is None:
        return JSONResponse({"error": "Initialize voice first."}, status_code=400)

    tgt_ids = list_str_to_idx([target_text], vocab_char_map)[0]
    tgt_ids = torch.tensor(tgt_ids, dtype=torch.long).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # NOTE(review): the duration predictor only sees the reference audio and
        # reference token ids; target_text does not influence it — confirm this
        # is intended and matches SpeechLengthPredictor's expected inputs.
        duration_pred = duration_model(style_state["wav"], style_state["prompt_ids"])
        generated_mel = mel_model.infer(
            style_state["wav"], style_state["prompt_ids"], tgt_ids,
            duration_pred
        )

    # torchaudio.functional.griffinlim has no default arguments (window,
    # hop_length, win_length, power, n_iter, momentum, length and rand_init are
    # all required), so calling it with only n_fft raises TypeError. The
    # transform wrapper supplies its documented defaults for those.
    # NOTE(review): Griffin-Lim expects a linear-frequency spectrogram with
    # n_fft // 2 + 1 bins laid out as (freq, time), and the transform defaults
    # to power=2.0. The models in this cell are configured with 100 mel bins,
    # so an inverse-mel step (or a neural vocoder) is probably needed before
    # this call — confirm against the output of mel_model.infer.
    griffin_lim = torchaudio.transforms.GriffinLim(n_fft=1024)
    waveform = griffin_lim(generated_mel.squeeze().cpu())

    filename = f"/tmp/{uuid.uuid4().hex}.wav"
    torchaudio.save(filename, waveform.unsqueeze(0), 22050)

    # Delete the temporary file once the response has been sent; otherwise every
    # request leaves a WAV behind in /tmp.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, filename)

    return FileResponse(
        filename,
        media_type='audio/wav',
        filename="generated.wav",
        background=cleanup,
    )

# Start the HTTP server from inside the notebook. uvicorn.run blocks, so this
# cell keeps running until the server is stopped (interrupt the kernel).
# NOTE(review): host "0.0.0.0" listens on every network interface and the
# endpoints above have no authentication — consider "127.0.0.1" for local use.
uvicorn.run(
    app,
    host="0.0.0.0",
    port=8000,
)