Skip to content

Commit

Permalink
整理: TTSEngineManager クラスの新設 (#1234)
Browse files Browse the repository at this point in the history
* refactor: `TTSEngines` クラスの新設

* add: test

* fix: `TTSEngines` → `TTSEngineManager` へリネーム

* fix: リネーム忘れを修正

* fix: TTSEngineManager の property を関数へ変更
  • Loading branch information
tarepan committed May 22, 2024
1 parent 0cfb336 commit 4475b16
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 44 deletions.
6 changes: 4 additions & 2 deletions build_util/make_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from voicevox_engine.dev.tts_engine.mock import MockTTSEngine
from voicevox_engine.preset.PresetManager import PresetManager
from voicevox_engine.setting.SettingLoader import USER_SETTING_PATH, SettingHandler
from voicevox_engine.tts_pipeline.tts_engine import CoreAdapter
from voicevox_engine.tts_pipeline.tts_engine import CoreAdapter, TTSEngineManager
from voicevox_engine.user_dict.user_dict import UserDictionary
from voicevox_engine.utility.path_utility import engine_root

Expand Down Expand Up @@ -36,9 +36,11 @@ def generate_api_docs_html(schema: str) -> str:

if __name__ == "__main__":
mock_core = MockCoreWrapper()
tts_engines = TTSEngineManager()
tts_engines.register_engine(MockTTSEngine(), "mock")
# FastAPI の機能を用いて OpenAPI schema を生成する
app = generate_app(
tts_engines={"mock": MockTTSEngine()},
tts_engines=tts_engines,
cores={"mock": CoreAdapter(mock_core)},
latest_core_version="mock",
setting_loader=SettingHandler(USER_SETTING_PATH),
Expand Down
5 changes: 2 additions & 3 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from voicevox_engine.setting.SettingLoader import USER_SETTING_PATH, SettingHandler
from voicevox_engine.tts_pipeline.tts_engine import make_tts_engines_from_cores
from voicevox_engine.user_dict.user_dict import UserDictionary
from voicevox_engine.utility.core_version_utility import get_latest_version
from voicevox_engine.utility.path_utility import engine_root


Expand Down Expand Up @@ -267,8 +266,8 @@ def main() -> None:
load_all_models=load_all_models,
)
tts_engines = make_tts_engines_from_cores(cores)
assert len(tts_engines) != 0, "音声合成エンジンがありません。"
latest_core_version = get_latest_version(list(tts_engines.keys()))
assert len(tts_engines.versions()) != 0, "音声合成エンジンがありません。"
latest_core_version = tts_engines.latest_version()

# Cancellable Engine
enable_cancellable_synthesis: bool = args.enable_cancellable_synthesis
Expand Down
3 changes: 1 addition & 2 deletions test/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
from voicevox_engine.setting.SettingLoader import SettingHandler
from voicevox_engine.tts_pipeline.tts_engine import make_tts_engines_from_cores
from voicevox_engine.user_dict.user_dict import UserDictionary
from voicevox_engine.utility.core_version_utility import get_latest_version


@pytest.fixture()
def app_params(tmp_path: Path) -> dict[str, Any]:
cores = initialize_cores(use_gpu=False, enable_mock=True)
tts_engines = make_tts_engines_from_cores(cores)
latest_core_version = get_latest_version(list(tts_engines.keys()))
latest_core_version = tts_engines.latest_version()
setting_loader = SettingHandler(Path("./not_exist.yaml"))

# 隔離されたプリセットの生成
Expand Down
124 changes: 124 additions & 0 deletions test/tts_pipeline/test_tts_engines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
""" `TTSEngineManager` クラスのテスト"""

import pytest
from fastapi import HTTPException

from voicevox_engine.dev.tts_engine.mock import MockTTSEngine
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager


def test_tts_engines_register_engine() -> None:
    """TTSEngineManager.register_engine() registers a TTS engine."""
    # Inputs
    tts_engines = TTSEngineManager()

    # Outputs
    tts_engines.register_engine(MockTTSEngine(), "0.0.1")

    # Test
    # Without an explicit check this test would only prove that registration
    # does not raise; confirm the engine is actually known to the manager.
    assert tts_engines.has_engine("0.0.1")


def test_tts_engines_versions() -> None:
    """TTSEngineManager.versions() lists the registered versions."""
    # Arrange: a manager holding two mock engines.
    manager = TTSEngineManager()
    for version in ("0.0.1", "0.0.2"):
        manager.register_engine(MockTTSEngine(), version)

    # Act
    listed_versions = manager.versions()

    # Assert
    assert ["0.0.1", "0.0.2"] == listed_versions


def test_tts_engines_latest_version() -> None:
    """TTSEngineManager.latest_version() reports the latest version."""
    # Arrange: a manager holding two mock engines.
    manager = TTSEngineManager()
    for version in ("0.0.1", "0.0.2"):
        manager.register_engine(MockTTSEngine(), version)

    # Act
    reported_version = manager.latest_version()

    # Assert
    assert "0.0.2" == reported_version


def test_tts_engines_get_engine_specified() -> None:
    """TTSEngineManager.get_engine() returns the engine for a given version."""
    # Arrange: keep references so the returned engine can be compared.
    manager = TTSEngineManager()
    engine_v1 = MockTTSEngine()
    engine_v2 = MockTTSEngine()
    manager.register_engine(engine_v1, "0.0.1")
    manager.register_engine(engine_v2, "0.0.2")

    # Act
    fetched_engine = manager.get_engine("0.0.2")

    # Assert
    assert engine_v2 == fetched_engine


def test_tts_engines_get_engine_latest() -> None:
    """TTSEngineManager.get_engine() without a version returns the latest engine."""
    # Arrange: keep references so the returned engine can be compared.
    manager = TTSEngineManager()
    engine_v1 = MockTTSEngine()
    engine_v2 = MockTTSEngine()
    manager.register_engine(engine_v1, "0.0.1")
    manager.register_engine(engine_v2, "0.0.2")

    # Act: no version argument.
    fetched_engine = manager.get_engine()

    # Assert
    assert engine_v2 == fetched_engine


def test_tts_engines_get_engine_missing() -> None:
    """TTSEngineManager.get_engine() raises for a version that is not registered."""
    # Arrange: only 0.0.1 and 0.0.2 are registered.
    manager = TTSEngineManager()
    for version in ("0.0.1", "0.0.2"):
        manager.register_engine(MockTTSEngine(), version)

    # Act / Assert: asking for 0.0.3 must raise HTTPException.
    with pytest.raises(HTTPException):
        manager.get_engine("0.0.3")


def test_tts_engines_has_engine_true() -> None:
    """TTSEngineManager.has_engine() confirms that a registered engine exists."""
    # Arrange: a manager holding two mock engines.
    manager = TTSEngineManager()
    for version in ("0.0.1", "0.0.2"):
        manager.register_engine(MockTTSEngine(), version)
    expected = True

    # Act
    is_registered = manager.has_engine("0.0.1")

    # Assert
    assert expected == is_registered


def test_tts_engines_has_engine_false() -> None:
    """TTSEngineManager.has_engine() confirms that an unregistered engine is absent."""
    # Arrange: a manager holding two mock engines.
    manager = TTSEngineManager()
    for version in ("0.0.1", "0.0.2"):
        manager.register_engine(MockTTSEngine(), version)
    expected = False

    # Act: query a version that was never registered.
    is_registered = manager.has_engine("0.0.3")

    # Assert
    assert expected == is_registered
15 changes: 4 additions & 11 deletions voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
from voicevox_engine.preset.PresetManager import PresetManager
from voicevox_engine.setting.Setting import CorsPolicyMode
from voicevox_engine.setting.SettingLoader import SettingHandler
from voicevox_engine.tts_pipeline.tts_engine import TTSEngine
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager
from voicevox_engine.user_dict.user_dict import UserDictionary
from voicevox_engine.utility.path_utility import engine_root, get_save_dir


def generate_app(
tts_engines: dict[str, TTSEngine],
tts_engines: TTSEngineManager,
cores: dict[str, CoreAdapter],
latest_core_version: str,
setting_loader: SettingHandler,
Expand Down Expand Up @@ -69,13 +69,6 @@ def generate_app(

metas_store = MetasStore(root_dir / "speaker_info")

def get_engine(core_version: str | None) -> TTSEngine:
if core_version is None:
return tts_engines[latest_core_version]
if core_version in tts_engines:
return tts_engines[core_version]
raise HTTPException(status_code=422, detail="不明なバージョンです")

def get_core(core_version: str | None) -> CoreAdapter:
"""指定したバージョンのコアを取得する"""
if core_version is None:
Expand All @@ -86,10 +79,10 @@ def get_core(core_version: str | None) -> CoreAdapter:

app.include_router(
generate_tts_pipeline_router(
get_engine, get_core, preset_manager, cancellable_engine
tts_engines, get_core, preset_manager, cancellable_engine
)
)
app.include_router(generate_morphing_router(get_engine, get_core, metas_store))
app.include_router(generate_morphing_router(tts_engines, get_core, metas_store))
app.include_router(generate_preset_router(preset_manager))
app.include_router(generate_speaker_router(get_core, metas_store, root_dir))
if engine_manifest_data.supported_features.manage_library:
Expand Down
6 changes: 3 additions & 3 deletions voicevox_engine/app/routers/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from voicevox_engine.morphing import (
synthesis_morphing_parameter as _synthesis_morphing_parameter,
)
from voicevox_engine.tts_pipeline.tts_engine import TTSEngine
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager
from voicevox_engine.utility.path_utility import delete_file

# キャッシュを有効化
Expand All @@ -31,7 +31,7 @@


def generate_morphing_router(
get_engine: Callable[[str | None], TTSEngine],
tts_engines: TTSEngineManager,
get_core: Callable[[str | None], CoreAdapter],
metas_store: MetasStore,
) -> APIRouter:
Expand Down Expand Up @@ -94,7 +94,7 @@ def _synthesis_morphing(
指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。
モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。
"""
engine = get_engine(core_version)
engine = tts_engines.get_engine(core_version)
core = get_core(core_version)

try:
Expand Down
26 changes: 13 additions & 13 deletions voicevox_engine/app/routers/tts_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
connect_base64_waves,
)
from voicevox_engine.tts_pipeline.kana_converter import create_kana, parse_kana
from voicevox_engine.tts_pipeline.tts_engine import TTSEngine
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager
from voicevox_engine.utility.path_utility import delete_file


def generate_tts_pipeline_router(
get_engine: Callable[[str | None], TTSEngine],
tts_engines: TTSEngineManager,
get_core: Callable[[str | None], CoreAdapter],
preset_manager: PresetManager,
cancellable_engine: CancellableEngine | None,
Expand All @@ -53,7 +53,7 @@ def audio_query(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = get_engine(core_version)
engine = tts_engines.get_engine(core_version)
core = get_core(core_version)
accent_phrases = engine.create_accent_phrases(text, style_id)
return AudioQuery(
Expand Down Expand Up @@ -82,7 +82,7 @@ def audio_query_from_preset(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = get_engine(core_version)
engine = tts_engines.get_engine(core_version)
core = get_core(core_version)
try:
presets = preset_manager.load_presets()
Expand Down Expand Up @@ -139,7 +139,7 @@ def accent_phrases(
* アクセント位置を`'`で指定する。全てのアクセント句にはアクセント位置を1つ指定する必要がある。
* アクセント句末に`?`(全角)を入れることにより疑問文の発音ができる。
"""
engine = get_engine(core_version)
engine = tts_engines.get_engine(core_version)
if is_kana:
try:
return engine.create_accent_phrases_from_kana(text, style_id)
Expand All @@ -160,7 +160,7 @@ def mora_data(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> list[AccentPhrase]:
engine = get_engine(core_version)
engine = tts_engines.get_engine(core_version)
return engine.update_length_and_pitch(accent_phrases, style_id)

@router.post(
Expand All @@ -173,7 +173,7 @@ def mora_length(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> list[AccentPhrase]:
engine = get_engine(core_version)
engine = tts_engines.get_engine(core_version)
return engine.update_length(accent_phrases, style_id)

@router.post(
Expand All @@ -186,7 +186,7 @@ def mora_pitch(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> list[AccentPhrase]:
engine = get_engine(core_version)
engine = tts_engines.get_engine(core_version)
return engine.update_pitch(accent_phrases, style_id)

@router.post(
Expand All @@ -213,7 +213,7 @@ def synthesis(
] = True,
core_version: str | None = None,
) -> FileResponse:
engine = get_engine(core_version)
engine = tts_engines.get_engine(core_version)
wave = engine.synthesize_wave(
query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak
)
Expand Down Expand Up @@ -285,7 +285,7 @@ def multi_synthesis(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> FileResponse:
engine = get_engine(core_version)
engine = tts_engines.get_engine(core_version)
sampling_rate = queries[0].outputSamplingRate

with NamedTemporaryFile(delete=False) as f:
Expand Down Expand Up @@ -327,7 +327,7 @@ def sing_frame_audio_query(
"""
歌唱音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま歌唱音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = get_engine(core_version)
engine = tts_engines.get_engine(core_version)
core = get_core(core_version)
phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume(
score, style_id
Expand All @@ -353,7 +353,7 @@ def sing_frame_volume(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> list[float]:
engine = get_engine(core_version)
engine = tts_engines.get_engine(core_version)
return engine.create_sing_volume_from_phoneme_and_f0(
score, frame_audio_query.phonemes, frame_audio_query.f0, style_id
)
Expand All @@ -378,7 +378,7 @@ def frame_synthesis(
"""
歌唱音声合成を行います。
"""
engine = get_engine(core_version)
engine = tts_engines.get_engine(core_version)
wave = engine.frame_synthsize_wave(query, style_id)

with NamedTemporaryFile(delete=False) as f:
Expand Down
10 changes: 4 additions & 6 deletions voicevox_engine/cancellable_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from .metas.Metas import StyleId
from .model import AudioQuery
from .tts_pipeline.tts_engine import make_tts_engines_from_cores
from .utility.core_version_utility import get_latest_version


class CancellableEngine:
Expand Down Expand Up @@ -238,15 +237,14 @@ def start_synthesis_subprocess(
)
tts_engines = make_tts_engines_from_cores(cores)

assert len(tts_engines) != 0, "音声合成エンジンがありません。"
latest_core_version = get_latest_version(list(tts_engines.keys()))
assert len(tts_engines.versions()) != 0, "音声合成エンジンがありません。"
while True:
try:
query, style_id, core_version = sub_proc_con.recv()
if core_version is None:
_engine = tts_engines[latest_core_version]
elif core_version in tts_engines:
_engine = tts_engines[core_version]
_engine = tts_engines.get_engine()
elif tts_engines.has_engine(core_version):
_engine = tts_engines.get_engine(core_version)
else:
# バージョンが見つからないエラー
sub_proc_con.send("")
Expand Down
Loading

0 comments on commit 4475b16

Please sign in to comment.