Skip to content

Commit

Permalink
Support listing Hugging Face model info (#4619)
Browse files Browse the repository at this point in the history
* Support listing Hugging Face model info

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Add documentation about usage

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Add documentation about usage

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Update name of method, support list of model filters

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Improve docstring

Signed-off-by: smajumdar <smajumdar@nvidia.com>
  • Loading branch information
titu1994 authored Jul 27, 2022
1 parent cbf3f66 commit 90ad5af
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 3 deletions.
127 changes: 124 additions & 3 deletions nemo/core/classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, Iterable, List, Optional, Union

import hydra
import wrapt
from huggingface_hub import hf_hub_download
from huggingface_hub.hf_api import HfFolder
from huggingface_hub import HfApi, HfFolder, ModelFilter, hf_hub_download
from huggingface_hub.hf_api import ModelInfo
from omegaconf import DictConfig, OmegaConf

import nemo
Expand Down Expand Up @@ -667,6 +667,105 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]:
"""
pass

@classmethod
def search_huggingface_models(
cls, model_filter: Optional[Union[ModelFilter, List[ModelFilter]]] = None
) -> List[ModelInfo]:
"""
Should list all pre-trained models available via Hugging Face Hub.
The following metadata can be passed via the `model_filter` for additional results.
Metadata:
resolve_card_info: Bool flag, if set, returns the model card metadata. Default: False.
limit_results: Optional int, limits the number of results returned.
.. code-block:: python
# You can replace <DomainSubclass> with any subclass of ModelPT.
from nemo.core import ModelPT
# Get default ModelFilter
filt = <DomainSubclass>.get_hf_model_filter()
# Make any modifications to the filter as necessary
filt.language = [...]
filt.task = ...
filt.tags = [...]
# Add any metadata to the filter as needed
filt.limit_results = 5
# Obtain model info
model_infos = <DomainSubclass>.search_huggingface_models(model_filter=filt)
# Browse through cards and select an appropriate one
card = model_infos[0]
# Restore model using `modelId` of the card.
model = ModelPT.from_pretrained(card.modelId)
Args:
model_filter: Optional ModelFilter or List[ModelFilter] (from Hugging Face Hub)
that filters the returned list of compatible model cards, and selects all results from each filter.
Users can then use `model_card.modelId` in `from_pretrained()` to restore a NeMo Model.
If no ModelFilter is provided, uses the classes default filter as defined by `get_hf_model_filter()`.
Returns:
A list of ModelInfo entries.
"""
# Resolve model filter if not provided as argument
if model_filter is None:
model_filter = cls.get_hf_model_filter()

# If single model filter, wrap into list
if not isinstance(model_filter, Iterable):
model_filter = [model_filter]

# Inject `nemo` library filter
for mfilter in model_filter:
if isinstance(mfilter.library, str) and mfilter.library != 'nemo':
logging.warning(f"Model filter's `library` tag updated be `nemo`. Original value: {mfilter.library}")
mfilter.library = "nemo"

elif isinstance(mfilter, Iterable) and 'nemo' not in mfilter.library:
logging.warning(
f"Model filter's `library` list updated to include `nemo`. Original value: {mfilter.library}"
)
mfilter.library = list(mfilter)
mfilter.library.append('nemo')

# Check if api token exists, use if it does
is_token_available = HfFolder.get_token() is not None

# Search for all valid models after filtering
api = HfApi()

# Setup extra arguments for model filtering
all_results = [] # type: List[ModelInfo]

for mfilter in model_filter:
cardData = None
limit = None

if hasattr(mfilter, 'resolve_card_info') and mfilter.resolve_card_info is True:
cardData = True

if hasattr(mfilter, 'limit_results') and mfilter.limit_results is not None:
limit = mfilter.limit_results

results = api.list_models(
filter=mfilter,
use_auth_token=is_token_available,
sort="lastModified",
direction=-1,
cardData=cardData,
limit=limit,
) # type: List[ModelInfo]

all_results.extend(results)

return all_results

@classmethod
def get_available_model_names(cls) -> List[str]:
"""
Expand All @@ -680,6 +779,28 @@ def get_available_model_names(cls) -> List[str]:
model_names = [model.pretrained_model_name for model in cls.list_available_models()]
return model_names

@classmethod
def get_hf_model_filter(cls) -> ModelFilter:
"""
Generates a filter for HuggingFace models.
Additionally includes default values of some metadata about results returned by the Hub.
Metadata:
resolve_card_info: Bool flag, if set, returns the model card metadata. Default: False.
limit_results: Optional int, limits the number of results returned.
Returns:
A Hugging Face Hub ModelFilter object.
"""
model_filter = ModelFilter(library='nemo')

# Attach some additional info
model_filter.resolve_card_info = False
model_filter.limit_results = None

return model_filter

@classmethod
def from_pretrained(
cls,
Expand Down
51 changes: 51 additions & 0 deletions tests/core/test_save_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import pytest
import torch
from huggingface_hub.hf_api import ModelFilter, ModelInfo
from omegaconf import DictConfig, OmegaConf, open_dict

from nemo.collections.asr.models import EncDecCTCModel, EncDecCTCModelBPE
Expand Down Expand Up @@ -605,3 +606,53 @@ class MockModelV2(MockModel):
restored_state_dict = restored_model.state_dict()
for orig, restored in zip(original_state_dict.keys(), restored_state_dict.keys()):
assert (original_state_dict[orig] - restored_state_dict[restored]).abs().mean() < 1e-6

@pytest.mark.unit
def test_hf_model_filter(self):
filt = ModelPT.get_hf_model_filter()
assert isinstance(filt, ModelFilter)
assert filt.library == 'nemo'

@pytest.mark.with_downloads()
@pytest.mark.unit
def test_hf_model_info(self):
filt = ModelPT.get_hf_model_filter()

# check no override results
model_infos = ModelPT.search_huggingface_models(model_filter=None)
assert len(model_infos) > 0

# check with default override results (should match above)
default_model_infos = ModelPT.search_huggingface_models(model_filter=filt)
assert len(model_infos) == len(default_model_infos)

@pytest.mark.with_downloads()
@pytest.mark.unit
def test_hf_model_info_with_card_data(self):
filt = ModelPT.get_hf_model_filter()

# check no override results
model_infos = ModelPT.search_huggingface_models(model_filter=filt)
assert len(model_infos) > 0
assert not hasattr(model_infos[0], 'cardData')

# check overriden defaults
filt.resolve_card_info = True
model_infos = ModelPT.search_huggingface_models(model_filter=filt)
assert len(model_infos) > 0
assert hasattr(model_infos[0], 'cardData') and model_infos[0].cardData is not None

@pytest.mark.with_downloads()
@pytest.mark.unit
def test_hf_model_info_with_limited_results(self):
filt = ModelPT.get_hf_model_filter()

# check no override results
model_infos = ModelPT.search_huggingface_models(model_filter=filt)
assert len(model_infos) > 0

# check overriden defaults
filt.limit_results = 5
new_model_infos = ModelPT.search_huggingface_models(model_filter=filt)
assert len(new_model_infos) <= 5
assert len(new_model_infos) < len(model_infos)

0 comments on commit 90ad5af

Please sign in to comment.