Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support listing Hugging Face model info #4619

Merged
merged 6 commits into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 124 additions & 3 deletions nemo/core/classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, Iterable, List, Optional, Union

import hydra
import wrapt
from huggingface_hub import hf_hub_download
from huggingface_hub.hf_api import HfFolder
from huggingface_hub import HfApi, HfFolder, ModelFilter, hf_hub_download
from huggingface_hub.hf_api import ModelInfo
from omegaconf import DictConfig, OmegaConf

import nemo
Expand Down Expand Up @@ -667,6 +667,105 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]:
"""
pass

@classmethod
def search_huggingface_models(
cls, model_filter: Optional[Union[ModelFilter, List[ModelFilter]]] = None
) -> List[ModelInfo]:
"""
Should list all pre-trained models available via Hugging Face Hub.
The following metadata can be passed via the `model_filter` for additional results.
Metadata:
resolve_card_info: Bool flag, if set, returns the model card metadata. Default: False.
limit_results: Optional int, limits the number of results returned.
.. code-block:: python
# You can replace <DomainSubclass> with any subclass of ModelPT.
from nemo.core import ModelPT
# Get default ModelFilter
filt = <DomainSubclass>.get_hf_model_filter()
# Make any modifications to the filter as necessary
filt.language = [...]
filt.task = ...
filt.tags = [...]
# Add any metadata to the filter as needed
filt.limit_results = 5
# Obtain model info
model_infos = <DomainSubclass>.search_huggingface_models(model_filter=filt)
# Browse through cards and select an appropriate one
card = model_infos[0]
# Restore model using `modelId` of the card.
model = ModelPT.from_pretrained(card.modelId)
Args:
model_filter: Optional ModelFilter or List[ModelFilter] (from Hugging Face Hub)
that filters the returned list of compatible model cards, and selects all results from each filter.
Users can then use `model_card.modelId` in `from_pretrained()` to restore a NeMo Model.
If no ModelFilter is provided, uses the classes default filter as defined by `get_hf_model_filter()`.
Returns:
A list of ModelInfo entries.
"""
# Resolve model filter if not provided as argument
if model_filter is None:
model_filter = cls.get_hf_model_filter()

# If single model filter, wrap into list
if not isinstance(model_filter, Iterable):
model_filter = [model_filter]

# Inject `nemo` library filter
for mfilter in model_filter:
if isinstance(mfilter.library, str) and mfilter.library != 'nemo':
logging.warning(f"Model filter's `library` tag updated be `nemo`. Original value: {mfilter.library}")
mfilter.library = "nemo"

elif isinstance(mfilter, Iterable) and 'nemo' not in mfilter.library:
logging.warning(
f"Model filter's `library` list updated to include `nemo`. Original value: {mfilter.library}"
)
mfilter.library = list(mfilter)
mfilter.library.append('nemo')

# Check if api token exists, use if it does
is_token_available = HfFolder.get_token() is not None

# Search for all valid models after filtering
api = HfApi()

# Setup extra arguments for model filtering
all_results = [] # type: List[ModelInfo]

for mfilter in model_filter:
cardData = None
limit = None

if hasattr(mfilter, 'resolve_card_info') and mfilter.resolve_card_info is True:
cardData = True

if hasattr(mfilter, 'limit_results') and mfilter.limit_results is not None:
limit = mfilter.limit_results

results = api.list_models(
filter=mfilter,
use_auth_token=is_token_available,
sort="lastModified",
direction=-1,
cardData=cardData,
limit=limit,
) # type: List[ModelInfo]

all_results.extend(results)

return all_results

@classmethod
def get_available_model_names(cls) -> List[str]:
"""
Expand All @@ -680,6 +779,28 @@ def get_available_model_names(cls) -> List[str]:
model_names = [model.pretrained_model_name for model in cls.list_available_models()]
return model_names

@classmethod
def get_hf_model_filter(cls) -> ModelFilter:
"""
Generates a filter for HuggingFace models.
Additionally includes default values of some metadata about results returned by the Hub.
Metadata:
resolve_card_info: Bool flag, if set, returns the model card metadata. Default: False.
limit_results: Optional int, limits the number of results returned.
Returns:
A Hugging Face Hub ModelFilter object.
"""
model_filter = ModelFilter(library='nemo')

# Attach some additional info
model_filter.resolve_card_info = False
model_filter.limit_results = None

return model_filter

@classmethod
def from_pretrained(
cls,
Expand Down
51 changes: 51 additions & 0 deletions tests/core/test_save_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import pytest
import torch
from huggingface_hub.hf_api import ModelFilter, ModelInfo
from omegaconf import DictConfig, OmegaConf, open_dict

from nemo.collections.asr.models import EncDecCTCModel, EncDecCTCModelBPE
Expand Down Expand Up @@ -605,3 +606,53 @@ class MockModelV2(MockModel):
restored_state_dict = restored_model.state_dict()
for orig, restored in zip(original_state_dict.keys(), restored_state_dict.keys()):
assert (original_state_dict[orig] - restored_state_dict[restored]).abs().mean() < 1e-6

@pytest.mark.unit
def test_hf_model_filter(self):
filt = ModelPT.get_hf_model_filter()
assert isinstance(filt, ModelFilter)
assert filt.library == 'nemo'

@pytest.mark.with_downloads()
@pytest.mark.unit
def test_hf_model_info(self):
filt = ModelPT.get_hf_model_filter()

# check no override results
model_infos = ModelPT.search_huggingface_models(model_filter=None)
assert len(model_infos) > 0

# check with default override results (should match above)
default_model_infos = ModelPT.search_huggingface_models(model_filter=filt)
assert len(model_infos) == len(default_model_infos)

@pytest.mark.with_downloads()
@pytest.mark.unit
def test_hf_model_info_with_card_data(self):
filt = ModelPT.get_hf_model_filter()

# check no override results
model_infos = ModelPT.search_huggingface_models(model_filter=filt)
assert len(model_infos) > 0
assert not hasattr(model_infos[0], 'cardData')

# check overriden defaults
filt.resolve_card_info = True
model_infos = ModelPT.search_huggingface_models(model_filter=filt)
assert len(model_infos) > 0
assert hasattr(model_infos[0], 'cardData') and model_infos[0].cardData is not None

@pytest.mark.with_downloads()
@pytest.mark.unit
def test_hf_model_info_with_limited_results(self):
filt = ModelPT.get_hf_model_filter()

# check no override results
model_infos = ModelPT.search_huggingface_models(model_filter=filt)
assert len(model_infos) > 0

# check overriden defaults
filt.limit_results = 5
new_model_infos = ModelPT.search_huggingface_models(model_filter=filt)
assert len(new_model_infos) <= 5
assert len(new_model_infos) < len(model_infos)