[BUG] pyfunc.load_model ignores logged model with trust_remote_code set to True #12033

Closed
4 of 23 tasks
mfeller2000 opened this issue May 17, 2024 · 1 comment · Fixed by #12067
Labels
area/model-registry, area/models, bug, has-closing-pr, integrations/databricks

mfeller2000 commented May 17, 2024

Issues Policy acknowledgement

  • I have read and agree to submit bug reports in accordance with the issues policy

Where did you encounter this bug?

Databricks

Willingness to contribute

Yes. I would be willing to contribute a fix for this bug with guidance from the MLflow community.

MLflow version

  • mlflow, version 2.12.2

System information

  • Databricks Runtime Version 14.3 LTS (includes Apache Spark 3.5.0, Scala 2.12)
  • Python 3.10.12

Describe the problem

mlflow.pyfunc.load_model() ignores the trust_remote_code=True setting of the logged model, so the remote code the model depends on is not used. As a result, the model loads as a plain BertModel with newly initialized weights and is unusable. Loading the same model via mlflow.sentence_transformers.load_model() works fine.

Tracking information

No response

Code to reproduce issue

import mlflow
from sentence_transformers import SentenceTransformer
from mlflow.models import infer_signature

example_sentences = ["This is a sentence.", "This is another sentence."]
model = SentenceTransformer('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True) # trust_remote_code is needed to use the encode method
embeddings = model.encode(example_sentences)

# Define the signature
signature = infer_signature(
    model_input=example_sentences,
    model_output=embeddings,
)

# Logging the model
with mlflow.start_run():
    logged_model = mlflow.sentence_transformers.log_model(
        model=model,
        artifact_path="jina-embeddings-v2-base-en", 
        signature=signature, 
        input_example=example_sentences,
    )

bad_loader = mlflow.pyfunc.load_model(logged_model.model_uri)
bad_results = bad_loader.predict(example_sentences)
print(bad_results)

Stack trace

Some weights of BertModel were not initialized from the model checkpoint at /local_disk0/repl_tmp_data/ReplId-57630-ea81d-38e78-7/tmpumav12m6/jina-embeddings-v2-base-en/model.sentence_transformer and are newly initialized: ['embeddings.position_embeddings.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.output.LayerNorm.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.dense.bias', 
'encoder.layer.4.output.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[[-1.3923457  -1.6869806   0.37395296 ... -0.9141787  -1.1804446
   0.4207688 ]
 [-1.0506808  -0.5709704   0.00941753 ... -0.8040266   0.4940365
   1.2246021 ]]

Other info / logs

For completeness, loading the same logged model via the sentence_transformers flavor works as expected:

good_loader = mlflow.sentence_transformers.load_model(logged_model.model_uri)
good_results = good_loader.encode(example_sentences)
print(good_results)

Output:

/local_disk0/.ephemeral_nfs/envs/pythonEnv-d2592314-586c-4092-9f07-e9f1d54bb4a4/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
[[ 0.4217119  -0.7744012   0.83123386 ...  0.24540496 -0.20734143
  -0.87755406]
 [ 0.21493115 -0.5608038   0.95852095 ...  0.08686836 -0.28218246
  -0.8243381 ]]
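
To check programmatically whether the two loaders agree, rather than comparing the printed arrays by eye, something like the following sketch could be used. The helper names `cosine_similarity` and `embeddings_agree` and the 0.999 threshold are illustrative assumptions, not part of MLflow or sentence-transformers.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def embeddings_agree(left, right, min_similarity: float = 0.999) -> bool:
    """Hypothetical helper: True if every pair of rows is nearly parallel.

    `left` and `right` are sequences of embedding vectors (nested lists or
    2-D NumPy arrays), one row per input sentence.
    """
    if len(left) != len(right):
        return False
    return all(
        cosine_similarity(row_l, row_r) >= min_similarity
        for row_l, row_r in zip(left, right)
    )
```

Calling `embeddings_agree(bad_results, good_results)` on the outputs from this report would be expected to return False if the pyfunc-loaded model really was re-initialized with fresh weights, and True once both loaders return the same embeddings.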

What component(s) does this bug affect?

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

What interface(s) does this bug affect?

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

What language(s) does this bug affect?

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

What integration(s) does this bug affect?

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations
mfeller2000 added the bug label May 17, 2024
github-actions bot added the area/model-registry, area/models, and integrations/databricks labels May 17, 2024
github-actions bot added the has-closing-pr label May 20, 2024
@serena-ruan
Collaborator

@mfeller2000 Thanks for reporting this! Verified your sample works on PR #12067
