diff --git a/setup.py b/setup.py
index dcc1fde8..b6c6839c 100755
--- a/setup.py
+++ b/setup.py
@@ -25,6 +25,7 @@
     "tensorflow-gpu==2.3.1",
    "tensorflow-addons>=0.10.0",
    "setuptools>=38.5.1",
+    "huggingface_hub==0.0.8",
    "librosa>=0.7.0",
    "soundfile>=0.10.2",
    "matplotlib>=3.1.0",
diff --git a/tensorflow_tts/inference/auto_config.py b/tensorflow_tts/inference/auto_config.py
index c2318f2c..0b48272c 100644
--- a/tensorflow_tts/inference/auto_config.py
+++ b/tensorflow_tts/inference/auto_config.py
@@ -16,6 +16,7 @@
 import logging
 import yaml
+import os
 
 from collections import OrderedDict
 
 from tensorflow_tts.configs import (
@@ -28,6 +29,10 @@
     ParallelWaveGANGeneratorConfig,
 )
 
+from tensorflow_tts.utils import CACHE_DIRECTORY, CONFIG_FILE_NAME, LIBRARY_NAME
+from tensorflow_tts import __version__ as VERSION
+from huggingface_hub import hf_hub_url, cached_download
+
 CONFIG_MAPPING = OrderedDict(
     [
         ("fastspeech", FastSpeechConfig),
@@ -50,6 +55,20 @@ def __init__(self):
 
     @classmethod
     def from_pretrained(cls, pretrained_path, **kwargs):
+        # load config from hf hub when the path is not a local file
+        if not os.path.isfile(pretrained_path):
+            # retrieve correct hub url
+            download_url = hf_hub_url(repo_id=pretrained_path, filename=CONFIG_FILE_NAME)
+
+            pretrained_path = str(
+                cached_download(
+                    url=download_url,
+                    library_name=LIBRARY_NAME,
+                    library_version=VERSION,
+                    cache_dir=CACHE_DIRECTORY,
+                )
+            )
+
         with open(pretrained_path) as f:
             config = yaml.load(f, Loader=yaml.SafeLoader)
diff --git a/tensorflow_tts/inference/auto_model.py b/tensorflow_tts/inference/auto_model.py
index 38e98d77..03a54cf2 100644
--- a/tensorflow_tts/inference/auto_model.py
+++ b/tensorflow_tts/inference/auto_model.py
@@ -16,6 +16,8 @@
 import logging
 import warnings
+import os
+
 from collections import OrderedDict
 
 from tensorflow_tts.configs import (
@@ -40,6 +42,9 @@
     SavableTFFastSpeech2,
     SavableTFTacotron2
 )
+from tensorflow_tts.utils import CACHE_DIRECTORY, MODEL_FILE_NAME, LIBRARY_NAME
+from tensorflow_tts import __version__ as VERSION
+from huggingface_hub import hf_hub_url, cached_download
 
 
 TF_MODEL_MAPPING = OrderedDict(
@@ -62,8 +67,35 @@ def __init__(self):
         raise EnvironmentError("Cannot be instantiated using `__init__()`")
 
     @classmethod
-    def from_pretrained(cls, config, pretrained_path=None, **kwargs):
+    def from_pretrained(cls, config=None, pretrained_path=None, **kwargs):
         is_build = kwargs.pop("is_build", True)
+
+        # load weights from hf hub when the path is not a local file
+        if pretrained_path is not None:
+            if not os.path.isfile(pretrained_path):
+                # retrieve correct hub url
+                download_url = hf_hub_url(repo_id=pretrained_path, filename=MODEL_FILE_NAME)
+
+                downloaded_file = str(
+                    cached_download(
+                        url=download_url,
+                        library_name=LIBRARY_NAME,
+                        library_version=VERSION,
+                        cache_dir=CACHE_DIRECTORY,
+                    )
+                )
+
+                # load config from repo as well
+                if config is None:
+                    from tensorflow_tts.inference import AutoConfig
+
+                    config = AutoConfig.from_pretrained(pretrained_path)
+
+                pretrained_path = downloaded_file
+
+        assert config is not None, "Please make sure to pass a config along to load a model from a local file"
+
         for config_class, model_class in TF_MODEL_MAPPING.items():
             if isinstance(config, config_class) and str(config_class.__name__) in str(
                 config
@@ -79,6 +111,7 @@ def from_pretrained(cls, config, pretrained_path=None, **kwargs):
                         pretrained_path, by_name=True, skip_mismatch=True
                     )
             return model
+
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
diff --git a/tensorflow_tts/inference/auto_processor.py b/tensorflow_tts/inference/auto_processor.py
index ef82f16f..3a03bf0f 100644
--- a/tensorflow_tts/inference/auto_processor.py
+++ b/tensorflow_tts/inference/auto_processor.py
@@ -16,6 +16,7 @@
 import logging
 import json
+import os
 
 from collections import OrderedDict
 
 from tensorflow_tts.processor import (
@@ -26,6 +27,10 @@
     ThorstenProcessor,
 )
 
+from tensorflow_tts.utils import CACHE_DIRECTORY, PROCESSOR_FILE_NAME, LIBRARY_NAME
+from tensorflow_tts import __version__ as VERSION
+from huggingface_hub import hf_hub_url, cached_download
+
 CONFIG_MAPPING = OrderedDict(
     [
         ("LJSpeechProcessor", LJSpeechProcessor),
@@ -46,6 +51,19 @@ def __init__(self):
 
     @classmethod
     def from_pretrained(cls, pretrained_path, **kwargs):
+        # load processor mapping from hf hub when the path is not a local file
+        if not os.path.isfile(pretrained_path):
+            # retrieve correct hub url
+            download_url = hf_hub_url(repo_id=pretrained_path, filename=PROCESSOR_FILE_NAME)
+
+            pretrained_path = str(
+                cached_download(
+                    url=download_url,
+                    library_name=LIBRARY_NAME,
+                    library_version=VERSION,
+                    cache_dir=CACHE_DIRECTORY,
+                )
+            )
         with open(pretrained_path, "r") as f:
             config = json.load(f)
diff --git a/tensorflow_tts/processor/baker.py b/tensorflow_tts/processor/baker.py
index c17d53ba..465b63ef 100644
--- a/tensorflow_tts/processor/baker.py
+++ b/tensorflow_tts/processor/baker.py
@@ -27,6 +27,7 @@
 from pypinyin.converter import DefaultConverter
 from pypinyin.core import Pinyin
 from tensorflow_tts.processor import BaseProcessor
+from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
 
 _pad = ["pad"]
 _eos = ["eos"]
@@ -552,6 +553,13 @@ def __post_init__(self):
 
     def setup_eos_token(self):
         return _eos[0]
 
+    def save_pretrained(self, saved_path):
+        os.makedirs(saved_path, exist_ok=True)
+        self._save_mapper(
+            os.path.join(saved_path, PROCESSOR_FILE_NAME),
+            {"pinyin_dict": self.pinyin_dict},
+        )
+
     def create_items(self):
         items = []
         if self.data_dir:
diff --git a/tensorflow_tts/processor/base_processor.py b/tensorflow_tts/processor/base_processor.py
index b33440ca..ad3a2544 100644
--- a/tensorflow_tts/processor/base_processor.py
+++ b/tensorflow_tts/processor/base_processor.py
@@ -224,3 +224,8 @@ def _save_mapper(self, saved_path: str = None, extra_attrs_to_save: dict = None)
             if extra_attrs_to_save:
                 full_mapper = {**full_mapper, **extra_attrs_to_save}
             json.dump(full_mapper, f)
+
+    @abc.abstractmethod
+    def save_pretrained(self, saved_path):
+        """Save mappers to file"""
+        pass
diff --git a/tensorflow_tts/processor/kss.py b/tensorflow_tts/processor/kss.py
index 62228629..01fd4833 100644
--- a/tensorflow_tts/processor/kss.py
+++ b/tensorflow_tts/processor/kss.py
@@ -23,6 +23,7 @@
 from tensorflow_tts.processor import BaseProcessor
 from tensorflow_tts.utils import cleaners
 from tensorflow_tts.utils.korean import symbols as KSS_SYMBOLS
+from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
 
 # Regular expression matching text enclosed in curly braces:
 _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
@@ -57,6 +58,10 @@ def split_line(self, data_dir, line, split):
 
     def setup_eos_token(self):
         return "eos"
 
+    def save_pretrained(self, saved_path):
+        os.makedirs(saved_path, exist_ok=True)
+        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})
+
     def get_one_sample(self, item):
         text, wav_path, speaker_name = item
diff --git a/tensorflow_tts/processor/libritts.py b/tensorflow_tts/processor/libritts.py
index 22d6b483..27ed8b0a 100644
--- a/tensorflow_tts/processor/libritts.py
+++ b/tensorflow_tts/processor/libritts.py
@@ -24,6 +24,7 @@
 from g2p_en import g2p as grapheme_to_phonem
 
 from tensorflow_tts.processor.base_processor import BaseProcessor
+from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
 
 g2p = grapheme_to_phonem.G2p()
 
@@ -84,7 +85,11 @@ def get_one_sample(self, item):
         return sample
 
     def setup_eos_token(self):
-        return None  # because we do not use this
+        return None  # because we do not use this
+
+    def save_pretrained(self, saved_path):
+        os.makedirs(saved_path, exist_ok=True)
+        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})
 
     def text_to_sequence(self, text):
         if (
diff --git a/tensorflow_tts/processor/ljspeech.py b/tensorflow_tts/processor/ljspeech.py
index b44e4536..8ee38d46 100755
--- a/tensorflow_tts/processor/ljspeech.py
+++ b/tensorflow_tts/processor/ljspeech.py
@@ -22,6 +22,7 @@
 from dataclasses import dataclass
 from tensorflow_tts.processor import BaseProcessor
 from tensorflow_tts.utils import cleaners
+from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
 
 valid_symbols = [
     "AA",
@@ -158,6 +159,10 @@ def split_line(self, data_dir, line, split):
 
     def setup_eos_token(self):
         return _eos
 
+    def save_pretrained(self, saved_path):
+        os.makedirs(saved_path, exist_ok=True)
+        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})
+
     def get_one_sample(self, item):
         text, wav_path, speaker_name = item
diff --git a/tensorflow_tts/processor/thorsten.py b/tensorflow_tts/processor/thorsten.py
index 16d57abc..437d6bbd 100644
--- a/tensorflow_tts/processor/thorsten.py
+++ b/tensorflow_tts/processor/thorsten.py
@@ -22,6 +22,7 @@
 from dataclasses import dataclass
 from tensorflow_tts.processor import BaseProcessor
 from tensorflow_tts.utils import cleaners
+from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
 
 _pad = "pad"
 _eos = "eos"
@@ -67,6 +68,10 @@ def split_line(self, data_dir, line, split):
 
     def setup_eos_token(self):
         return _eos
 
+    def save_pretrained(self, saved_path):
+        os.makedirs(saved_path, exist_ok=True)
+        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})
+
     def get_one_sample(self, item):
         text, wav_path, speaker_name = item
diff --git a/tensorflow_tts/utils/__init__.py b/tensorflow_tts/utils/__init__.py
index 7660da1e..21d99e92 100755
--- a/tensorflow_tts/utils/__init__.py
+++ b/tensorflow_tts/utils/__init__.py
@@ -18,5 +18,5 @@
     calculate_3d_loss,
     return_strategy,
 )
-from tensorflow_tts.utils.utils import find_files
+from tensorflow_tts.utils.utils import find_files, MODEL_FILE_NAME, CONFIG_FILE_NAME, PROCESSOR_FILE_NAME, CACHE_DIRECTORY, LIBRARY_NAME
 from tensorflow_tts.utils.weight_norm import WeightNormalization
diff --git a/tensorflow_tts/utils/utils.py b/tensorflow_tts/utils/utils.py
index aa293f79..43aa6376 100755
--- a/tensorflow_tts/utils/utils.py
+++ b/tensorflow_tts/utils/utils.py
@@ -8,9 +8,16 @@
 import os
 import re
 import tempfile
+from pathlib import Path
 
 import tensorflow as tf
 
+MODEL_FILE_NAME = "model.h5"
+CONFIG_FILE_NAME = "config.yml"
+PROCESSOR_FILE_NAME = "processor.json"
+LIBRARY_NAME = "tensorflow_tts"
+CACHE_DIRECTORY = os.path.join(Path.home(), ".cache", LIBRARY_NAME)
+
 
 def find_files(root_dir, query="*.wav", include_root_dir=True):
     """Find files recursively.
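Note: the three from_pretrained overrides above (config, model, processor) repeat the same Hub-resolution pattern. The helper below restates that pattern outside the diff purely for readability; it is not part of the patch, and the name _resolve_from_hub is invented here for illustration.

import os

from huggingface_hub import hf_hub_url, cached_download

from tensorflow_tts import __version__ as VERSION
from tensorflow_tts.utils import CACHE_DIRECTORY, LIBRARY_NAME


def _resolve_from_hub(pretrained_path, filename):
    """Return a local path, downloading `filename` from the Hub repo when
    `pretrained_path` is a repo id rather than an existing local file."""
    if os.path.isfile(pretrained_path):
        return pretrained_path
    download_url = hf_hub_url(repo_id=pretrained_path, filename=filename)
    return str(
        cached_download(
            url=download_url,
            library_name=LIBRARY_NAME,
            library_version=VERSION,
            cache_dir=CACHE_DIRECTORY,
        )
    )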
diff --git a/test/test_base_processor.py b/test/test_base_processor.py
index 35429ee7..1c44d595 100644
--- a/test/test_base_processor.py
+++ b/test/test_base_processor.py
@@ -24,6 +24,9 @@ def text_to_sequence(self, text):
 
     def setup_eos_token(self):
         return None
 
+    def save_pretrained(self, saved_path):
+        return super().save_pretrained(saved_path)
+
 
 @pytest.fixture
 def processor(tmpdir):
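With the full patch applied, the Auto* entry points accept either a local file path or a Hugging Face Hub repo id, and every concrete processor gains save_pretrained. A minimal usage sketch follows; the repo id and file paths are illustrative only and are not defined by this patch.

from tensorflow_tts.inference import AutoConfig, AutoProcessor, TFAutoModel

# A string that is not an existing local file is treated as a Hub repo id:
# config.yml / processor.json / model.h5 are fetched via hf_hub_url +
# cached_download and cached under ~/.cache/tensorflow_tts.
processor = AutoProcessor.from_pretrained("tensorspeech/tts-tacotron2-ljspeech-en")
model = TFAutoModel.from_pretrained(pretrained_path="tensorspeech/tts-tacotron2-ljspeech-en")

# Local files still work; a config must then be passed explicitly, otherwise
# the new assert in TFAutoModel.from_pretrained fails.
config = AutoConfig.from_pretrained("examples/tacotron2/conf/tacotron2.v1.yaml")
model = TFAutoModel.from_pretrained(config=config, pretrained_path="model-120k.h5")

# Every concrete processor can now persist its mapper (processor.json).
processor.save_pretrained("./pretrained_processor")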