diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d7eec086bf9bf6..76ff2db343843d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -89,7 +89,7 @@ replace_return_docstrings, strtobool, ) -from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files +from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files from .utils.import_utils import ( ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled, @@ -1172,6 +1172,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix config_class = None base_model_prefix = "" main_input_name = "input_ids" + model_tags = None + _auto_class = None _no_split_modules = None _skip_keys_device_placement = None @@ -1252,6 +1254,38 @@ def _backward_compatibility_gradient_checkpointing(self): # Remove the attribute now that is has been consumed, so it's no saved in the config. delattr(self.config, "gradient_checkpointing") + def add_model_tags(self, tags: Union[List[str], str]) -> None: + r""" + Add custom tags into the model that gets pushed to the Hugging Face Hub. Will + not overwrite existing tags in the model. + + Args: + tags (`Union[List[str], str]`): + The desired tags to inject in the model + + Examples: + + ```python + from transformers import AutoModel + + model = AutoModel.from_pretrained("bert-base-cased") + + model.add_model_tags(["custom", "custom-bert"]) + + # Push the model to your namespace with the name "my-custom-bert". + model.push_to_hub("my-custom-bert") + ``` + """ + if isinstance(tags, str): + tags = [tags] + + if self.model_tags is None: + self.model_tags = [] + + for tag in tags: + if tag not in self.model_tags: + self.model_tags.append(tag) + @classmethod def _from_config(cls, config, **kwargs): """ @@ -2212,6 +2246,7 @@ def save_pretrained( Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ use_auth_token = kwargs.pop("use_auth_token", None) + ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False) if use_auth_token is not None: warnings.warn( @@ -2438,6 +2473,14 @@ def save_pretrained( ) if push_to_hub: + # Eventually create an empty model card + model_card = create_and_tag_model_card( + repo_id, self.model_tags, token=token, ignore_metadata_errors=ignore_metadata_errors + ) + + # Update model card if needed: + model_card.save(os.path.join(save_directory, "README.md")) + self._upload_modified_files( save_directory, repo_id, @@ -2446,6 +2489,22 @@ def save_pretrained( token=token, ) + @wraps(PushToHubMixin.push_to_hub) + def push_to_hub(self, *args, **kwargs): + tags = self.model_tags if self.model_tags is not None else [] + + tags_kwargs = kwargs.get("tags", []) + if isinstance(tags_kwargs, str): + tags_kwargs = [tags_kwargs] + + for tag in tags_kwargs: + if tag not in tags: + tags.append(tag) + + if tags: + kwargs["tags"] = tags + return super().push_to_hub(*args, **kwargs) + def get_memory_footprint(self, return_buffers=True): r""" Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index c3fa757157cbee..6850f4dca067ea 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3581,6 +3581,15 @@ def create_model_card( library_name = ModelCard.load(model_card_filepath).data.get("library_name") is_peft_library = library_name == "peft" + # Append existing tags in `tags` + existing_tags = ModelCard.load(model_card_filepath).data.tags + if tags is not None and existing_tags is not None: + if isinstance(tags, str): + tags = [tags] + for tag in existing_tags: + if tag not in tags: + tags.append(tag) + training_summary = TrainingSummary.from_trainer( self, language=language, @@ -3699,6 +3708,18 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin if not self.is_world_process_zero(): return + # Add additional tags in the case the model has already some tags and users pass + # "tags" argument to `push_to_hub` so that trainer automatically handles internal tags + # from all models since Trainer does not call `model.push_to_hub`. + if "tags" in kwargs and getattr(self.model, "model_tags", None) is not None: + # If it is a string, convert it to a list + if isinstance(kwargs["tags"], str): + kwargs["tags"] = [kwargs["tags"]] + + for model_tag in self.model.model_tags: + if model_tag not in kwargs["tags"]: + kwargs["tags"].append(model_tag) + self.create_model_card(model_name=model_name, **kwargs) # Wait for the current upload to be finished. diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 83ef69b5f37213..6b427ed4df0af0 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -33,6 +33,8 @@ from huggingface_hub import ( _CACHED_NO_EXIST, CommitOperationAdd, + ModelCard, + ModelCardData, constants, create_branch, create_commit, @@ -762,6 +764,7 @@ def push_to_hub( safe_serialization: bool = True, revision: str = None, commit_description: str = None, + tags: Optional[List[str]] = None, **deprecated_kwargs, ) -> str: """ @@ -795,6 +798,8 @@ def push_to_hub( Branch to push the uploaded files to. commit_description (`str`, *optional*): The description of the commit that will be created + tags (`List[str]`, *optional*): + List of tags to push on the Hub. Examples: @@ -811,6 +816,7 @@ def push_to_hub( ``` """ use_auth_token = deprecated_kwargs.pop("use_auth_token", None) + ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False) if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", @@ -855,6 +861,11 @@ def push_to_hub( repo_id, private=private, token=token, repo_url=repo_url, organization=organization ) + # Create a new empty model card and eventually tag it + model_card = create_and_tag_model_card( + repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors + ) + if use_temp_dir is None: use_temp_dir = not os.path.isdir(working_dir) @@ -864,6 +875,9 @@ def push_to_hub( # Save all files. self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) + # Update model card if needed: + model_card.save(os.path.join(work_dir, "README.md")) + return self._upload_modified_files( work_dir, repo_id, @@ -1081,6 +1095,43 @@ def extract_info_from_url(url): return {"repo": cache_repo, "revision": revision, "filename": filename} +def create_and_tag_model_card( + repo_id: str, + tags: Optional[List[str]] = None, + token: Optional[str] = None, + ignore_metadata_errors: bool = False, +): + """ + Creates or loads an existing model card and tags it. + + Args: + repo_id (`str`): + The repo_id where to look for the model card. + tags (`List[str]`, *optional*): + The list of tags to add in the model card + token (`str`, *optional*): + Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token. + ignore_metadata_errors (`str`): + If True, errors while parsing the metadata section will be ignored. Some information might be lost during + the process. Use it at your own risk. + """ + try: + # Check if the model card is present on the remote repo + model_card = ModelCard.load(repo_id, token=token, ignore_metadata_errors=ignore_metadata_errors) + except EntryNotFoundError: + # Otherwise create a simple model card from template + model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated." + card_data = ModelCardData(tags=[] if tags is None else tags, library_name="transformers") + model_card = ModelCard.from_template(card_data, model_description=model_description) + + if tags is not None: + for model_tag in tags: + if model_tag not in model_card.data.tags: + model_card.data.tags.append(model_tag) + + return model_card + + def clean_files_for(file): """ Remove, if they exist, file, file.json and file.lock diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 398c0bd0949345..ec72cdab82b900 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1435,6 +1435,11 @@ def tearDownClass(cls): except HTTPError: pass + try: + delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags") + except HTTPError: + pass + @unittest.skip("This test is flaky") def test_push_to_hub(self): config = BertConfig( @@ -1522,6 +1527,28 @@ def test_push_to_hub_dynamic_model(self): new_model = AutoModel.from_config(config, trust_remote_code=True) self.assertEqual(new_model.__class__.__name__, "CustomModel") + def test_push_to_hub_with_tags(self): + from huggingface_hub import ModelCard + + new_tags = ["tag-1", "tag-2"] + + CustomConfig.register_for_auto_class() + CustomModel.register_for_auto_class() + + config = CustomConfig(hidden_size=32) + model = CustomModel(config) + + self.assertTrue(model.model_tags is None) + + model.add_model_tags(new_tags) + + self.assertTrue(model.model_tags == new_tags) + + model.push_to_hub("test-dynamic-model-with-tags", token=self._token) + + loaded_model_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags") + self.assertEqual(loaded_model_card.data.tags, new_tags) + @require_torch class AttentionMaskTester(unittest.TestCase):