Expand dynamic supported objects to configs and tokenizers (huggingface#14296)

* Dynamic configs

* Add config test

* Better tests

* Add tokenizer and test

* Add to from_config

* With save
sgugger authored and Alberto Bégué committed Jan 27, 2022
1 parent 082ae69 commit dd83b4b
Showing 7 changed files with 272 additions and 10 deletions.
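
Taken together, the changes below let `AutoConfig`, `AutoTokenizer`, and `AutoModel.from_config` resolve classes defined in a repo's own code files, gated behind `trust_remote_code`. A minimal sketch of the intended end-to-end usage, assuming a hypothetical Hub repo `user/custom-model` whose `config.json` carries an `auto_map` entry:

from transformers import AutoConfig, AutoModel, AutoTokenizer

# The classes named in `auto_map` live in .py files inside the repo itself,
# so loading them executes code from the Hub on the local machine.
# trust_remote_code=True acknowledges that risk; the new code also warns
# when no `revision` is pinned.
config = AutoConfig.from_pretrained("user/custom-model", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("user/custom-model", trust_remote_code=True)
model = AutoModel.from_config(config, trust_remote_code=True)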
21 changes: 19 additions & 2 deletions src/transformers/models/auto/auto_factory.py
@@ -378,7 +378,24 @@ def __init__(self, *args, **kwargs):

@classmethod
def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys():
trust_remote_code = kwargs.pop("trust_remote_code", False)
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code:
raise ValueError(
"Loading this model requires you to execute the modeling file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warn(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision."
)
class_ref = config.auto_map[cls.__name__]
module_file, class_name = class_ref.split(".")
model_class = get_class_from_dynamic_module(config.name_or_path, module_file + ".py", class_name, **kwargs)
return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
return model_class._from_config(config, **kwargs)

@@ -394,7 +411,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
kwargs["_from_auto"] = True
if not isinstance(config, PretrainedConfig):
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **kwargs
)
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code:
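
For context, `from_config` now mirrors the existing `from_pretrained` path: it reads the class reference out of `config.auto_map`, splits it into a module file and a class name, and loads that class from the repo. A sketch of the lookup, reusing the `auto_map` shape from the tests further down:

# Sketch of the resolution performed in from_config above.
auto_map = {"AutoConfig": "configuration.FakeConfig", "AutoModel": "modeling.FakeModel"}
class_ref = auto_map["AutoModel"]               # "modeling.FakeModel"
module_file, class_name = class_ref.split(".")  # ("modeling", "FakeModel")
# get_class_from_dynamic_module then fetches modeling.py from the repo
# and returns the FakeModel class defined inside it.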
30 changes: 29 additions & 1 deletion src/transformers/models/auto/configuration_auto.py
@@ -21,8 +21,12 @@

from ...configuration_utils import PretrainedConfig
from ...file_utils import CONFIG_NAME
from ...utils import logging
from .dynamic import get_class_from_dynamic_module


logger = logging.get_logger(__name__)

CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
@@ -523,6 +527,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
If :obj:`True`, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs`
is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e.,
the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs(additional keyword arguments, `optional`):
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
@@ -555,8 +563,28 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
{'foo': False}
"""
kwargs["_from_auto"] = True
kwargs["name_or_path"] = pretrained_model_name_or_path
trust_remote_code = kwargs.pop("trust_remote_code", False)
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "model_type" in config_dict:
if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warn(
"Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
"ensure no malicious code has been contributed in a newer revision."
)
class_ref = config_dict["auto_map"]["AutoConfig"]
module_file, class_name = class_ref.split(".")
config_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "model_type" in config_dict:
config_class = CONFIG_MAPPING[config_dict["model_type"]]
return config_class.from_dict(config_dict, **kwargs)
else:
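
Note that the new branch keys off the raw config dict and runs before any `model_type` dispatch. A repo opting into this mechanism would therefore carry a `config.json` along these lines (a sketch based on the tests below, with unrelated fields omitted):

# Illustrative config.json contents for a repo with a dynamic config class.
# "configuration.FakeConfig" means: the FakeConfig class defined in the
# repo's configuration.py file.
config_json = {
    "attribute": 42,
    "auto_map": {"AutoConfig": "configuration.FakeConfig"},
}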
38 changes: 36 additions & 2 deletions src/transformers/models/auto/tokenization_auto.py
@@ -41,6 +41,7 @@
model_type_to_module_name,
replace_list_option_in_docstrings,
)
from .dynamic import get_class_from_dynamic_module


logger = logging.get_logger(__name__)
@@ -412,6 +413,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
Whether or not to try to load the fast version of the tokenizer.
tokenizer_type (:obj:`str`, `optional`):
Tokenizer type to be loaded.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs (additional keyword arguments, `optional`):
Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
@@ -436,6 +441,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):

use_fast = kwargs.pop("use_fast", True)
tokenizer_type = kwargs.pop("tokenizer_type", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)

# First, let's see whether the tokenizer_type is passed so that we can leverage it
if tokenizer_type is not None:
@@ -464,17 +470,45 @@
# Next, let's try to use the tokenizer_config file to get the tokenizer class.
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
tokenizer_auto_map = tokenizer_config.get("auto_map")

# If that did not work, let's try to use the config.
if config_tokenizer_class is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
config_tokenizer_class = config.tokenizer_class
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
tokenizer_auto_map = config.auto_map["AutoTokenizer"]

# If we have the tokenizer class from the tokenizer config or the model config we're good!
if config_tokenizer_class is not None:
tokenizer_class = None
if use_fast and not config_tokenizer_class.endswith("Fast"):
if tokenizer_auto_map is not None:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the tokenizer file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warn(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision."
)

if use_fast and tokenizer_auto_map[1] is not None:
class_ref = tokenizer_auto_map[1]
else:
class_ref = tokenizer_auto_map[0]

module_file, class_name = class_ref.split(".")
tokenizer_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)

elif use_fast and not config_tokenizer_class.endswith("Fast"):
tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
if tokenizer_class is None:
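
One detail worth noting: unlike configs and models, the `AutoTokenizer` entry of `auto_map` is a pair — slow tokenizer reference first, fast one second, either possibly `None` — and the code above indexes it accordingly. A sketch of that selection, with hypothetical class names:

# auto_map["AutoTokenizer"] holds (slow_ref, fast_ref); either slot may be None.
tokenizer_auto_map = ["tokenization.FakeTokenizer", "tokenization.FakeTokenizerFast"]

use_fast = True
if use_fast and tokenizer_auto_map[1] is not None:
    class_ref = tokenizer_auto_map[1]  # prefer the fast implementation
else:
    class_ref = tokenizer_auto_map[0]  # fall back to the slow one
module_file, class_name = class_ref.split(".")  # ("tokenization", "FakeTokenizerFast")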
3 changes: 3 additions & 0 deletions src/transformers/tokenization_utils_base.py
@@ -1784,6 +1784,7 @@ def _from_pretrained(
# First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
config_tokenizer_class = init_kwargs.get("tokenizer_class")
init_kwargs.pop("tokenizer_class", None)
init_kwargs.pop("auto_map", None)
saved_init_inputs = init_kwargs.pop("init_inputs", ())
if not init_inputs:
init_inputs = saved_init_inputs
@@ -2028,6 +2029,8 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
tokenizer_class = tokenizer_class[:-4]
tokenizer_config["tokenizer_class"] = tokenizer_class
if getattr(self, "_auto_map", None) is not None:
tokenizer_config["auto_map"] = self._auto_map

with open(tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
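
These two small changes close the save/load loop for tokenizers: `save_pretrained` persists a `_auto_map` attribute into `tokenizer_config.json`, and `_from_pretrained` pops `auto_map` from the init kwargs so it never reaches the tokenizer constructor. A hedged sketch of the round trip, with an entirely hypothetical custom tokenizer:

# FakeTokenizer is hypothetical; _auto_map points at the repo file holding its code.
tokenizer = FakeTokenizer(vocab_file="vocab.txt")
tokenizer._auto_map = ["tokenization.FakeTokenizer", None]  # (slow_ref, fast_ref)
tokenizer.save_pretrained("my-custom-tokenizer")
# tokenizer_config.json now contains "auto_map": ["tokenization.FakeTokenizer", null],
# which AutoTokenizer.from_pretrained(..., trust_remote_code=True) will honor.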
43 changes: 41 additions & 2 deletions tests/test_configuration_common.py
@@ -19,9 +19,9 @@
import tempfile
import unittest

from huggingface_hub import delete_repo, login
from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
from transformers import BertConfig, GPT2Config, is_torch_available
from transformers import AutoConfig, BertConfig, GPT2Config, is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import PASS, USER, is_staging_test

@@ -190,6 +190,23 @@ def run_common_tests(self):
self.check_config_arguments_init()


class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)


# Make sure this is synchronized with the config above.
FAKE_CONFIG_CODE = """
from transformers import PretrainedConfig
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
"""


@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
@classmethod
@@ -208,6 +225,11 @@ def tearDownClass(cls):
except HTTPError:
pass

try:
delete_repo(token=cls._token, name="test-dynamic-config")
except HTTPError:
pass

def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -238,6 +260,23 @@ def test_push_to_hub_in_organization(self):
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))

def test_push_to_hub_dynamic_config(self):
config = FakeConfig(attribute=42)
config.auto_map = {"AutoConfig": "configuration.FakeConfig"}

with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-config", use_auth_token=self._token)
config.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
f.write(FAKE_CONFIG_CODE)

repo.push_to_hub()

new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
self.assertEqual(new_config.__class__.__name__, "FakeConfig")
self.assertEqual(new_config.attribute, 42)


class ConfigTestUtils(unittest.TestCase):
def test_config_from_string(self):
74 changes: 72 additions & 2 deletions tests/test_modeling_common.py
@@ -30,7 +30,14 @@
import transformers
from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
from transformers import (
AutoConfig,
AutoModel,
AutoModelForSequenceClassification,
PretrainedConfig,
is_torch_available,
logging,
)
from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
@@ -67,7 +74,6 @@
AdaptiveEmbedding,
BertConfig,
BertModel,
PretrainedConfig,
PreTrainedModel,
T5Config,
T5ForConditionalGeneration,
@@ -2078,6 +2084,23 @@ def test_model_from_pretrained_torch_dtype(self):
self.assertEqual(model.dtype, torch.float16)


class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)


# Make sure this is synchronized with the config above.
FAKE_CONFIG_CODE = """
from transformers import PretrainedConfig
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
"""


if is_torch_available():

class FakeModel(PreTrainedModel):
Expand Down Expand Up @@ -2140,6 +2163,11 @@ def tearDownClass(cls):
except HTTPError:
pass

try:
delete_repo(token=cls._token, name="test-dynamic-model-config")
except HTTPError:
pass

def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -2185,5 +2213,47 @@ def test_push_to_hub_dynamic_model(self):
repo.push_to_hub()

new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
# Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "FakeModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))

config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "FakeModel")

def test_push_to_hub_dynamic_model_and_config(self):
config = FakeConfig(
attribute=42,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
)
config.auto_map = {"AutoConfig": "configuration.FakeConfig", "AutoModel": "modeling.FakeModel"}
model = FakeModel(config)

with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model-config", use_auth_token=self._token)
model.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
f.write(FAKE_CONFIG_CODE)
with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
f.write(FAKE_MODEL_CODE)

repo.push_to_hub()

new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model-config", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(new_model.config.__class__.__name__, "FakeConfig")
self.assertEqual(new_model.config.attribute, 42)

# Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "FakeModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))

config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "FakeModel")