diff --git a/model2vec/distill/distillation.py b/model2vec/distill/distillation.py
index 8df47b88..9703bba5 100644
--- a/model2vec/distill/distillation.py
+++ b/model2vec/distill/distillation.py
@@ -204,6 +204,7 @@ def distill(
     apply_zipf: bool = True,
     use_subword: bool = True,
     token_remove_pattern: str | None = r"\[unused\d+\]",
+    trust_remote_code: bool = False,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -223,11 +224,12 @@ def distill(
     :param apply_zipf: Whether to apply Zipf weighting to the embeddings.
     :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
+    :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
     :return: A StaticModel
     """
-    model: PreTrainedModel = AutoModel.from_pretrained(model_name)
-    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(model_name)
+    model: PreTrainedModel = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
+    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
 
     return distill_from_model(
         model=model,
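
As a minimal usage sketch of the new flag (the repository id below is a placeholder, and the import path assumes the package's public API, `from model2vec.distill import distill`):

from model2vec.distill import distill

# Hypothetical repo id: any Hub model whose modeling or tokenizer code
# lives in the repository itself, rather than in `transformers`,
# requires the explicit opt-in below.
m2v_model = distill(
    model_name="some-org/custom-embedding-model",
    trust_remote_code=True,  # allow executing the repo's custom code
)

With the default `trust_remote_code=False`, loading such a repository fails inside `AutoModel.from_pretrained`, so existing callers keep the previous behaviour.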