Merge pull request #555 from patrickvonplaten/add_tf_hub
Proposal to integrate with the 🤗 Hub
dathudeptrai authored May 14, 2021
2 parents ba46b47 + f4efa38 commit f53ecd9
Showing 13 changed files with 117 additions and 3 deletions.
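
The net effect of this change: the Auto* loaders in tensorflow_tts.inference accept a Hugging Face Hub repo id wherever a local file path was previously required, downloading config.yml, model.h5, or processor.json into a per-library cache. A minimal sketch of the intended usage, assuming these classes are exported from tensorflow_tts.inference and using a placeholder repo id:

    from tensorflow_tts.inference import AutoConfig, AutoProcessor, TFAutoModel

    # The same string may be a local file path (unchanged behavior) or a Hub repo id.
    repo_id = "some-user/some-tts-model"  # placeholder repo id, not a real repository

    processor = AutoProcessor.from_pretrained(repo_id)  # resolves processor.json
    config = AutoConfig.from_pretrained(repo_id)        # resolves config.yml
    model = TFAutoModel.from_pretrained(repo_id)        # resolves model.h5 (and config.yml when no config is passed)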
1 change: 1 addition & 0 deletions setup.py
@@ -25,6 +25,7 @@
"tensorflow-gpu==2.3.1",
"tensorflow-addons>=0.10.0",
"setuptools>=38.5.1",
"huggingface_hub==0.0.8",
"librosa>=0.7.0",
"soundfile>=0.10.2",
"matplotlib>=3.1.0",
19 changes: 19 additions & 0 deletions tensorflow_tts/inference/auto_config.py
@@ -16,6 +16,7 @@

import logging
import yaml
import os
from collections import OrderedDict

from tensorflow_tts.configs import (
@@ -28,6 +29,10 @@
ParallelWaveGANGeneratorConfig,
)

from tensorflow_tts.utils import CACHE_DIRECTORY, CONFIG_FILE_NAME, LIBRARY_NAME
from tensorflow_tts import __version__ as VERSION
from huggingface_hub import hf_hub_url, cached_download

CONFIG_MAPPING = OrderedDict(
[
("fastspeech", FastSpeechConfig),
@@ -50,6 +55,20 @@ def __init__(self):

@classmethod
def from_pretrained(cls, pretrained_path, **kwargs):
# load weights from hf hub
if not os.path.isfile(pretrained_path):
# retrieve correct hub url
download_url = hf_hub_url(repo_id=pretrained_path, filename=CONFIG_FILE_NAME)

pretrained_path = str(
cached_download(
url=download_url,
library_name=LIBRARY_NAME,
library_version=VERSION,
cache_dir=CACHE_DIRECTORY,
)
)

with open(pretrained_path) as f:
config = yaml.load(f, Loader=yaml.SafeLoader)

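
In AutoConfig.from_pretrained, the os.path.isfile check keeps the original local-file behavior intact and only falls back to the Hub when the argument is not an existing file. A sketch under the same placeholder-repo assumption (the YAML path is illustrative):

    from tensorflow_tts.inference import AutoConfig

    # Existing file path: loaded directly with yaml.load, exactly as before.
    local_config = AutoConfig.from_pretrained("conf/tacotron2.v1.yaml")  # illustrative path

    # Anything else is treated as a Hub repo id: config.yml is downloaded
    # via hf_hub_url + cached_download into ~/.cache/tensorflow_tts.
    hub_config = AutoConfig.from_pretrained("some-user/some-tts-model")  # placeholder repo id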
35 changes: 34 additions & 1 deletion tensorflow_tts/inference/auto_model.py
@@ -16,6 +16,8 @@

import logging
import warnings
import os

from collections import OrderedDict

from tensorflow_tts.configs import (
@@ -40,6 +42,9 @@
SavableTFFastSpeech2,
SavableTFTacotron2
)
from tensorflow_tts.utils import CACHE_DIRECTORY, MODEL_FILE_NAME, LIBRARY_NAME
from tensorflow_tts import __version__ as VERSION
from huggingface_hub import hf_hub_url, cached_download


TF_MODEL_MAPPING = OrderedDict(
@@ -62,8 +67,35 @@ def __init__(self):
raise EnvironmentError("Cannot be instantiated using `__init__()`")

@classmethod
def from_pretrained(cls, config, pretrained_path=None, **kwargs):
def from_pretrained(cls, config=None, pretrained_path=None, **kwargs):
is_build = kwargs.pop("is_build", True)

# load weights from hf hub
if pretrained_path is not None:
if not os.path.isfile(pretrained_path):
# retrieve correct hub url
download_url = hf_hub_url(repo_id=pretrained_path, filename=MODEL_FILE_NAME)

downloaded_file = str(
cached_download(
url=download_url,
library_name=LIBRARY_NAME,
library_version=VERSION,
cache_dir=CACHE_DIRECTORY,
)
)

# load config from repo as well
if config is None:
from tensorflow_tts.inference import AutoConfig

config = AutoConfig.from_pretrained(pretrained_path)

pretrained_path = downloaded_file


assert config is not None, "Please make sure to pass a config along to load a model from a local file"

for config_class, model_class in TF_MODEL_MAPPING.items():
if isinstance(config, config_class) and str(config_class.__name__) in str(
config
@@ -79,6 +111,7 @@ def from_pretrained(cls, config, pretrained_path=None, **kwargs):
pretrained_path, by_name=True, skip_mismatch=True
)
return model

raise ValueError(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
"Model type should be one of {}.".format(
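
config is now optional in TFAutoModel.from_pretrained: for a Hub repo id, the weights file is cached locally and, if no config was supplied, one is fetched from the same repo via AutoConfig; a plain local .h5 path still requires an explicit config. A sketch of both call styles (paths and repo id are illustrative, not guaranteed by this diff):

    from tensorflow_tts.inference import AutoConfig, TFAutoModel

    # Hub repo id only: model.h5 and config.yml both come from the repo.
    model = TFAutoModel.from_pretrained("some-user/some-tts-model")  # placeholder repo id

    # Local weights file: a config must still be passed explicitly.
    config = AutoConfig.from_pretrained("conf/fastspeech2.v1.yaml")  # illustrative path
    model = TFAutoModel.from_pretrained(
        config=config,
        pretrained_path="checkpoints/model.h5",  # illustrative path
    )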
18 changes: 18 additions & 0 deletions tensorflow_tts/inference/auto_processor.py
@@ -16,6 +16,7 @@

import logging
import json
import os
from collections import OrderedDict

from tensorflow_tts.processor import (
@@ -26,6 +27,10 @@
ThorstenProcessor,
)

from tensorflow_tts.utils import CACHE_DIRECTORY, PROCESSOR_FILE_NAME, LIBRARY_NAME
from tensorflow_tts import __version__ as VERSION
from huggingface_hub import hf_hub_url, cached_download

CONFIG_MAPPING = OrderedDict(
[
("LJSpeechProcessor", LJSpeechProcessor),
@@ -46,6 +51,19 @@ def __init__(self):

@classmethod
def from_pretrained(cls, pretrained_path, **kwargs):
# load weights from hf hub
if not os.path.isfile(pretrained_path):
# retrieve correct hub url
download_url = hf_hub_url(repo_id=pretrained_path, filename=PROCESSOR_FILE_NAME)

pretrained_path = str(
cached_download(
url=download_url,
library_name=LIBRARY_NAME,
library_version=VERSION,
cache_dir=CACHE_DIRECTORY,
)
)
with open(pretrained_path, "r") as f:
config = json.load(f)

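
AutoProcessor.from_pretrained applies the same pattern to processor.json: a local JSON path is opened directly, anything else is resolved on the Hub. A short sketch (placeholder repo id; the sentence is only an example input):

    from tensorflow_tts.inference import AutoProcessor

    processor = AutoProcessor.from_pretrained("some-user/some-tts-model")  # placeholder repo id
    input_ids = processor.text_to_sequence("Hello, world.")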
8 changes: 8 additions & 0 deletions tensorflow_tts/processor/baker.py
@@ -27,6 +27,7 @@
from pypinyin.converter import DefaultConverter
from pypinyin.core import Pinyin
from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

_pad = ["pad"]
_eos = ["eos"]
@@ -552,6 +553,13 @@ def __post_init__(self):
def setup_eos_token(self):
return _eos[0]

def save_pretrained(self, saved_path):
os.makedirs(saved_path, exist_ok=True)
self._save_mapper(
os.path.join(saved_path, PROCESSOR_FILE_NAME),
{"pinyin_dict": self.pinyin_dict},
)

def create_items(self):
items = []
if self.data_dir:
5 changes: 5 additions & 0 deletions tensorflow_tts/processor/base_processor.py
@@ -224,3 +224,8 @@ def _save_mapper(self, saved_path: str = None, extra_attrs_to_save: dict = None)
if extra_attrs_to_save:
full_mapper = {**full_mapper, **extra_attrs_to_save}
json.dump(full_mapper, f)

@abc.abstractmethod
def save_pretrained(self, saved_path):
"""Save mappers to file"""
pass
5 changes: 5 additions & 0 deletions tensorflow_tts/processor/kss.py
@@ -23,6 +23,7 @@
from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.korean import symbols as KSS_SYMBOLS
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
@@ -57,6 +58,10 @@ def split_line(self, data_dir, line, split):
def setup_eos_token(self):
return "eos"

def save_pretrained(self, saved_path):
os.makedirs(saved_path, exist_ok=True)
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

def get_one_sample(self, item):
text, wav_path, speaker_name = item

7 changes: 6 additions & 1 deletion tensorflow_tts/processor/libritts.py
@@ -24,6 +24,7 @@
from g2p_en import g2p as grapheme_to_phonem

from tensorflow_tts.processor.base_processor import BaseProcessor
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

g2p = grapheme_to_phonem.G2p()

@@ -84,7 +85,11 @@ def get_one_sample(self, item):
return sample

def setup_eos_token(self):
return None # because we do not use this
return None # because we do not use this

def save_pretrained(self, saved_path):
os.makedirs(saved_path, exist_ok=True)
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

def text_to_sequence(self, text):
if (
5 changes: 5 additions & 0 deletions tensorflow_tts/processor/ljspeech.py
@@ -22,6 +22,7 @@
from dataclasses import dataclass
from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

valid_symbols = [
"AA",
@@ -158,6 +159,10 @@ def split_line(self, data_dir, line, split):
def setup_eos_token(self):
return _eos

def save_pretrained(self, saved_path):
os.makedirs(saved_path, exist_ok=True)
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

def get_one_sample(self, item):
text, wav_path, speaker_name = item

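
Each concrete processor (Baker, KSS, LibriTTS, LJSpeech, Thorsten) now implements save_pretrained by writing its mapper to processor.json in the target directory, which AutoProcessor can load back from disk. A hedged round-trip sketch, assuming the processor instance comes from the Hub loader and that the placeholder repo id exists:

    import os

    from tensorflow_tts.inference import AutoProcessor
    from tensorflow_tts.utils import PROCESSOR_FILE_NAME

    processor = AutoProcessor.from_pretrained("some-user/some-tts-model")  # placeholder repo id
    processor.save_pretrained("./exported_processor")  # writes ./exported_processor/processor.json

    reloaded = AutoProcessor.from_pretrained(
        os.path.join("./exported_processor", PROCESSOR_FILE_NAME)
    )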
5 changes: 5 additions & 0 deletions tensorflow_tts/processor/thorsten.py
@@ -22,6 +22,7 @@
from dataclasses import dataclass
from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

_pad = "pad"
_eos = "eos"
@@ -67,6 +68,10 @@ def split_line(self, data_dir, line, split):
def setup_eos_token(self):
return _eos

def save_pretrained(self, saved_path):
os.makedirs(saved_path, exist_ok=True)
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

def get_one_sample(self, item):
text, wav_path, speaker_name = item

2 changes: 1 addition & 1 deletion tensorflow_tts/utils/__init__.py
@@ -18,5 +18,5 @@
calculate_3d_loss,
return_strategy,
)
from tensorflow_tts.utils.utils import find_files
from tensorflow_tts.utils.utils import find_files, MODEL_FILE_NAME, CONFIG_FILE_NAME, PROCESSOR_FILE_NAME, CACHE_DIRECTORY, LIBRARY_NAME
from tensorflow_tts.utils.weight_norm import WeightNormalization
7 changes: 7 additions & 0 deletions tensorflow_tts/utils/utils.py
@@ -8,9 +8,16 @@
import os
import re
import tempfile
from pathlib import Path

import tensorflow as tf

MODEL_FILE_NAME = "model.h5"
CONFIG_FILE_NAME = "config.yml"
PROCESSOR_FILE_NAME = "processor.json"
LIBRARY_NAME = "tensorflow_tts"
CACHE_DIRECTORY = os.path.join(Path.home(), ".cache", LIBRARY_NAME)


def find_files(root_dir, query="*.wav", include_root_dir=True):
"""Find files recursively.
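
These constants fix the file names expected in a Hub repo (model.h5, config.yml, processor.json) and the local cache root (~/.cache/tensorflow_tts). The download path used by the Auto* classes can also be exercised directly with the same huggingface_hub calls; a sketch with a placeholder repo id:

    from huggingface_hub import hf_hub_url, cached_download

    from tensorflow_tts import __version__ as VERSION
    from tensorflow_tts.utils import CACHE_DIRECTORY, CONFIG_FILE_NAME, LIBRARY_NAME

    url = hf_hub_url(repo_id="some-user/some-tts-model", filename=CONFIG_FILE_NAME)  # placeholder repo id
    local_path = cached_download(
        url=url,
        library_name=LIBRARY_NAME,
        library_version=VERSION,
        cache_dir=CACHE_DIRECTORY,
    )
    print(local_path)  # cached copy of config.yml under ~/.cache/tensorflow_tts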
3 changes: 3 additions & 0 deletions test/test_base_processor.py
@@ -24,6 +24,9 @@ def text_to_sequence(self, text):
def setup_eos_token(self):
return None

def save_pretrained(self, saved_path):
return super().save_pretrained(saved_path)


@pytest.fixture
def processor(tmpdir):
