Commit ad98964

Merge pull request #5389 from HatuneMiku/master
Add cache_dir config option for saving HFTransformersNLP pre-trained model data.
wochinge committed Mar 18, 2020
2 parents da25fef + 8d90465 commit ad98964
Showing 3 changed files with 11 additions and 2 deletions.
1 change: 1 addition & 0 deletions changelog/5389.feature.rst
@@ -0,0 +1 @@
+Add an optional path to a specific directory to download and cache the pre-trained model weights for :ref:`HFTransformersNLP`.
4 changes: 4 additions & 0 deletions docs/nlu/components.rst
@@ -119,6 +119,10 @@ HFTransformersNLP
     # can be found at https://huggingface.co/transformers/pretrained_models.html . If left empty, it uses the
     # default model architecture that original transformers library loads
     model_weights: "bert-base-uncased"
+    # An optional path to a specific directory to download and cache the pre-trained model weights.
+    # The `default` cache_dir is the same as https://huggingface.co/transformers/serialization.html#cache-directory .
+    cache_dir: null
     # +----------------+--------------+-------------------------+
     # | Language Model | Parameter    | Default value for       |
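
For illustration, here is a minimal Python sketch (not part of this diff) of configuring the component with the new option; the cache path is a hypothetical example value:

    # Minimal sketch, assuming rasa 1.x with the transformers dependency installed.
    # "/path/to/hf_cache" is a hypothetical example path, not a project default.
    from rasa.nlu.utils.hugging_face.hf_transformers import HFTransformersNLP

    component = HFTransformersNLP(
        component_config={
            "model_name": "bert",
            "model_weights": "bert-base-uncased",
            # New in this PR: where to download and cache pre-trained weights.
            # Leaving it as None keeps the transformers default cache directory.
            "cache_dir": "/path/to/hf_cache",
        }
    )
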
8 changes: 6 additions & 2 deletions rasa/nlu/utils/hugging_face/hf_transformers.py
@@ -36,6 +36,9 @@ class HFTransformersNLP(Component):
         "model_name": "bert",
         # Pre-Trained weights to be loaded(string)
         "model_weights": None,
+        # an optional path to a specific directory to download
+        # and cache the pre-trained model weights.
+        "cache_dir": None,
     }

     def __init__(self, component_config: Optional[Dict[Text, Any]] = None) -> None:
@@ -63,6 +66,7 @@ def _load_model(self) -> None:
         )

         self.model_weights = self.component_config["model_weights"]
+        self.cache_dir = self.component_config["cache_dir"]

         if not self.model_weights:
             logger.info(
@@ -74,10 +78,10 @@ def _load_model(self) -> None:
         logger.debug(f"Loading Tokenizer and Model for {self.model_name}")

         self.tokenizer = model_tokenizer_dict[self.model_name].from_pretrained(
-            self.model_weights
+            self.model_weights, cache_dir=self.cache_dir
         )
         self.model = model_class_dict[self.model_name].from_pretrained(
-            self.model_weights
+            self.model_weights, cache_dir=self.cache_dir
         )

         # Use a universal pad token since all transformer architectures do not have a
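
Under the hood, cache_dir is simply forwarded to the from_pretrained loaders of the transformers library. A standalone sketch of that behavior, with model and path chosen purely for illustration:

    # Standalone sketch, assuming transformers with TensorFlow support installed.
    # "/path/to/hf_cache" is a hypothetical path; cache_dir=None means the
    # library's default cache directory is used.
    from transformers import BertTokenizer, TFBertModel

    cache_dir = "/path/to/hf_cache"
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", cache_dir=cache_dir)
    model = TFBertModel.from_pretrained("bert-base-uncased", cache_dir=cache_dir)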
