# Import MLflow Models into Verta

Whether moving models selectively or in bulk, it is possible to import your MLflow models directly into Verta. The following notebook demonstrates how.

# Environment Setup

In [None]:
!pip install verta mlflow

In [None]:
import os
from typing import Dict, List

import verta
from verta.registry.stage_change import Archived, Staging, Production

import mlflow
from mlflow import MlflowClient
from mlflow.entities.model_registry import RegisteredModel as MLflow_RegisteredModel
from mlflow.store.entities.paged_list import PagedList

## Verta Client
 These values will be pulled from your local environment variables, or you can set them manually by uncommenting the lines below and replacing the values with your own.

In [None]:
# Optional overrides: uncomment these lines and fill in your own values to set
# the Verta credentials and host from inside the notebook. Avoid committing
# real keys to version control.
# os.environ['VERTA_EMAIL'] = ""
# os.environ['VERTA_DEV_KEY'] = ""
# os.environ['VERTA_HOST'] = ""

# Verta client shared by the cells below; the import cell uses it to create
# the registered models.
verta_client = verta.Client()

## MLflow Client
These values will be pulled from your default MLflow settings, or you can set them manually by uncommenting the lines below and replacing the values with your own.  This example assumes a Databricks MLflow tracking server, and some configurations may require different environment variables.

In [None]:
# Optional overrides: uncomment these lines and fill in your own values to
# point MLflow at a specific tracking server. Avoid committing real tokens to
# version control.
# os.environ["MLFLOW_TRACKING_URI"] = ""
# os.environ["DATABRICKS_HOST"] = ""
# os.environ["DATABRICKS_TOKEN"] = ""

# MLflow client shared by the cells below; it is used to search the registered
# models and their model versions.
mlflow_client = MlflowClient()

## Fetch MLflow Models
Get a list of the models registered in MLflow and print them out for inspection. The entire result set can be passed along for import into Verta, or pared down to any desired subset, either by passing search criteria to the `search_registered_models` method (see [MLflow API Docs](https://mlflow.org/docs/latest/python_api/mlflow.client.html?highlight=search_registered_models#mlflow.client.MlflowClient.search_registered_models)) or by filtering the results after they are returned.

In [None]:
# Collect every registered model from the MLflow registry, following the
# pagination token until no further page is reported.
mlflow_models: List[MLflow_RegisteredModel] = list()

result: PagedList = mlflow_client.search_registered_models()
mlflow_models += result.to_list()
while result.token:  # handle paginated results
    result = mlflow_client.search_registered_models(page_token=result.token)
    mlflow_models += result.to_list()

mlflow_models.sort(key=lambda x: x.name)
for registered_model in mlflow_models:
    # A registered model has no single `version` attribute; report the version
    # numbers of its latest model versions instead (may be empty or unset).
    latest_versions = [
        mv.version for mv in (registered_model.latest_versions or [])
    ]
    print(f"name={registered_model.name}; latest_versions={latest_versions}")

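Optionally, reduce the list to only the models you want to import. Run the next cell either way (editing the filter as needed), because it defines the `models_for_import` list used by the import step.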

In [None]:
# Example: keep only the models whose name does not contain "test"
models_for_import = list(
    filter(lambda candidate: "test" not in candidate.name, mlflow_models)
)

# Import Models to Verta

In [None]:
# Import every model in `models_for_import` into Verta.
#
# Results produced for the reporting cell:
#   successful_model_imports - models for which at least one version was created
#   failed_model_imports     - models that were skipped or had no version created
#   model_messages           - per-model log of what happened during the import
#
# Each model appears in exactly one of the two result lists.

# Maps an MLflow stage name to the Verta stage-change class that represents it.
MLFLOW_TO_VERTA_STAGE = {
    "Staging": Staging,
    "Production": Production,
    "Archived": Archived,
}

failed_model_imports: List[MLflow_RegisteredModel] = list()
successful_model_imports: List[MLflow_RegisteredModel] = list()
model_messages: Dict[MLflow_RegisteredModel, List[str]] = dict()

for rm in models_for_import:
    model_messages[rm] = list()

    # Step 1: create the registered model in Verta.
    try:
        verta_rm = verta_client.create_registered_model(
            name=rm.name,
            desc=rm.description,
            labels=[
                "mlflow_import",
                "mlflow_creation_time:" + str(rm.creation_timestamp),
                "mlflow_last_updated_time:" + str(rm.last_updated_timestamp),
                # NOTE(review): if `rm.tags` is a dict, joining it yields the
                # tag keys only - confirm that dropping the values is intended.
                "mlflow_tags:" + ",".join(rm.tags),
            ],
        )
        model_messages[rm].append(
            f"created new registered model in Verta for {rm.name}"
        )
    except ValueError:
        # A ValueError here is treated as a name collision with an existing model.
        model_messages[rm].append(
            f'a registered model named "{rm.name}" already exists in Verta. Skipping import.'
        )
        failed_model_imports.append(rm)
        continue

    # Step 2: fetch the model's versions from MLflow.
    try:
        rm_versions = mlflow_client.search_model_versions(f"name='{rm.name}'")
        if not rm_versions:
            failed_model_imports.append(rm)
            model_messages[rm].append(
                f"unable to find any model versions for {rm.name}.  Skipping import."
            )
            continue
    except Exception as err:
        model_messages[rm].append(
            f'failed to fetch versions for registered model "{rm.name}". Skipping import. Error: {err}'
        )
        failed_model_imports.append(rm)
        continue

    # Step 3: re-create each version in Verta, then copy artifacts and stage.
    successful_versions = list()
    for version in rm_versions:
        try:
            verta_version = verta_rm.create_version(
                name=str(version.version),
                attrs={
                    # NOTE(review): "er_id" presumably stands for the experiment
                    # run id - confirm the key name is intentional.
                    "er_id": version.run_id,
                    "mlflow_source": version.source,
                    "mlflow_user_id": version.user_id,
                    "mlflow_run_link": version.run_link,
                    "mlflow_creation_time": version.creation_timestamp,
                    "mlflow_last_updated_time": version.last_updated_timestamp,
                    "mlflow_status": version.status,
                    "mlflow_current_stage": version.current_stage,
                    "mlflow_tags": version.tags,
                },
                labels=["mlflow_import"],
            )
            model_messages[rm].append(f"successfully created version {version.version}")
            successful_versions.append(version.version)
        except Exception as err:
            model_messages[rm].append(
                f"failed to create model version in Verta for {rm.name} - version: {version.version} due to {err}"
            )
            continue

        # Import artifacts for the model version (best effort: a failure is
        # logged but does not undo the version that was just created).
        try:
            outpath = mlflow.artifacts.download_artifacts(run_id=version.run_id)
        except Exception as err:
            model_messages[rm].append(
                f"unable to download artifacts from {rm.name} - version run id; {version.run_id} due to {err}"
            )
        else:
            for file_name in os.listdir(outpath):
                try:
                    verta_version.log_artifact(
                        file_name, os.path.join(outpath, file_name)
                    )
                    model_messages[rm].append(f"artifact logged in Verta: {file_name}")
                except ValueError as err:
                    model_messages[rm].append(
                        f"cannot upload artifact {file_name} for {rm.name} due to {err}"
                    )

        # Set model version's current stage (best effort, like the artifacts).
        current_stage = version.current_stage
        stage_error_message = (
            f"unable to set stage in Verta for {rm.name} - "
            f"version: {version.version}, current_stage: {current_stage}"
        )
        if current_stage is None or current_stage == "None":
            # No stage assigned in MLflow, so there is nothing to transfer.
            model_messages[rm].append(
                f"version {version.version} has no stage in MLflow; stage left unchanged in Verta"
            )
        elif current_stage not in MLFLOW_TO_VERTA_STAGE:
            # Unknown stage name: report it instead of guessing a mapping.
            model_messages[rm].append(stage_error_message)
        else:
            try:
                verta_version.change_stage(MLFLOW_TO_VERTA_STAGE[current_stage]())
            except Exception as err:
                model_messages[rm].append(f"{stage_error_message}, due to: {str(err)}")

    # Record the model exactly once, based on whether any version was created.
    if successful_versions:
        successful_model_imports.append(rm)
    else:
        model_messages[rm].append(
            f"no versions of {rm.name} could be created in Verta."
        )
        failed_model_imports.append(rm)

## Print Results

In [None]:
# Print the per-model message log, grouped by import outcome: first the models
# that were imported, then the ones that were not.
report_sections = (
    ("IMPORTED SUCCESSFULLY:\n", successful_model_imports),
    ("\nFAILED TO IMPORT:\n", failed_model_imports),
)

for heading, reported_models in report_sections:
    print(heading)
    for reported_model in reported_models:
        print(f"\n{reported_model.name}\n---------------------")
        for entry in model_messages[reported_model]:
            print(f"  - {entry}")