Skip to content

Commit

Permalink
Deprecate AutoMLTrainModelOperator for NL (#34212)
Browse files Browse the repository at this point in the history
  • Loading branch information
VladaZakharova committed Sep 11, 2023
1 parent cad983d commit 25d463c
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 148 deletions.
21 changes: 20 additions & 1 deletion airflow/providers/google/cloud/operators/automl.py
Expand Up @@ -19,6 +19,7 @@
from __future__ import annotations

import ast
import warnings
from typing import TYPE_CHECKING, Sequence, Tuple

from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
Expand All @@ -31,6 +32,7 @@
TableSpec,
)

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.links.automl import (
AutoMLDatasetLink,
Expand All @@ -53,6 +55,10 @@ class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
"""
Creates Google Cloud AutoML model.
AutoMLTrainModelOperator for text prediction is deprecated. Please use
:class:`airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`
instead.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:AutoMLTrainModelOperator`
Expand Down Expand Up @@ -102,7 +108,6 @@ def __init__(
**kwargs,
) -> None:
super().__init__(**kwargs)

self.model = model
self.location = location
self.project_id = project_id
Expand All @@ -113,6 +118,20 @@ def __init__(
self.impersonation_chain = impersonation_chain

def execute(self, context: Context):
# Output warning if running AutoML Natural Language prediction job
automl_nl_model_keys = [
"text_classification_model_metadata",
"text_extraction_model_metadata",
"text_sentiment_dataset_metadata",
]
if any(key in automl_nl_model_keys for key in self.model):
warnings.warn(
"AutoMLTrainModelOperator for text prediction is deprecated. All the functionality of legacy "
"AutoML Natural Language and new features are available on the Vertex AI platform. "
"Please use `CreateAutoMLTextTrainingJobOperator`",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
hook = CloudAutoMLHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand Down
14 changes: 12 additions & 2 deletions docs/apache-airflow-providers-google/operators/cloud/automl.rst
Expand Up @@ -102,6 +102,16 @@ To create a Google AutoML model you can use
The operator will wait for the operation to complete. Additionally the operator
returns the id of model in :ref:`XCom <concepts:xcom>` under ``model_id`` key.

This operator is deprecated for text prediction and will be removed soon.
All the functionality of legacy AutoML Natural Language and new features are available on the
Vertex AI platform. Please use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`.
When running the Vertex AI Operator for training data, please ensure that your data is correctly stored in Vertex AI
datasets. To create and import data to the dataset please use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`
and
:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator`

.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
:language: python
:dedent: 4
Expand Down Expand Up @@ -164,7 +174,7 @@ the model must be deployed.
Listing And Deleting Datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You can get a list of AutoML models using
You can get a list of AutoML datasets using
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator`. The operator returns list
of datasets ids in :ref:`XCom <concepts:xcom>` under ``dataset_id_list`` key.

Expand All @@ -174,7 +184,7 @@ of datasets ids in :ref:`XCom <concepts:xcom>` under ``dataset_id_list`` key.
:start-after: [START howto_operator_list_dataset]
:end-before: [END howto_operator_list_dataset]

To delete a model you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`.
To delete a dataset you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`.
The delete operator also allows passing a list or a comma-separated string of dataset ids to be deleted.

.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
Expand Down
Expand Up @@ -24,47 +24,54 @@
from datetime import datetime
from typing import cast

from google.cloud.aiplatform import schema
from google.protobuf.struct_pb2 import Value

from airflow import models
from airflow.models.xcom_arg import XComArg
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.automl import (
AutoMLCreateDatasetOperator,
AutoMLDeleteDatasetOperator,
AutoMLDeleteModelOperator,
AutoMLDeployModelOperator,
AutoMLImportDataOperator,
AutoMLTrainModelOperator,
)
from airflow.providers.google.cloud.operators.gcs import (
GCSCreateBucketOperator,
GCSDeleteBucketOperator,
GCSSynchronizeBucketsOperator,
)
from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
CreateAutoMLTextTrainingJobOperator,
DeleteAutoMLTrainingJobOperator,
)
from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
CreateDatasetOperator,
DeleteDatasetOperator,
ImportDataOperator,
)
from airflow.utils.trigger_rule import TriggerRule

ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
DAG_ID = "example_automl_text_cls"
GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
DAG_ID = "example_automl_text_cls"

GCP_AUTOML_LOCATION = "us-central1"
DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"

MODEL_NAME = "text_clss_test_model"
MODEL = {
"display_name": MODEL_NAME,
"text_classification_model_metadata": {},
}
TEXT_CLSS_DISPLAY_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/classification.csv"

MODEL_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")

DATASET_NAME = f"ds_clss_{ENV_ID}".replace("-", "_")
DATASET = {
"display_name": DATASET_NAME,
"text_classification_dataset_metadata": {"classification_type": "MULTICLASS"},
"metadata_schema_uri": schema.dataset.metadata.text,
"metadata": Value(string_value="clss-dataset"),
}

AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/text_classification.csv"
IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}

DATA_CONFIG = [
{
"import_schema_uri": schema.dataset.ioformat.text.single_label_classification,
"gcs_source": {"uris": [AUTOML_DATASET_BUCKET]},
},
]
extract_object_id = CloudAutoMLHook.extract_object_id

# Example DAG for AutoML Natural Language Text Classification
Expand All @@ -85,67 +92,77 @@
move_dataset_file = GCSSynchronizeBucketsOperator(
task_id="move_dataset_to_bucket",
source_bucket=RESOURCE_DATA_BUCKET,
source_object="automl/datasets/text",
source_object="vertex-ai/automl/datasets/text",
destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
destination_object="automl",
recursive=True,
)

create_dataset = AutoMLCreateDatasetOperator(
task_id="create_dataset",
create_clss_dataset = CreateDatasetOperator(
task_id="create_clss_dataset",
dataset=DATASET,
location=GCP_AUTOML_LOCATION,
region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
)
clss_dataset_id = create_clss_dataset.output["dataset_id"]

dataset_id = cast(str, XComArg(create_dataset, key="dataset_id"))
MODEL["dataset_id"] = dataset_id
import_dataset = AutoMLImportDataOperator(
task_id="import_dataset",
dataset_id=dataset_id,
location=GCP_AUTOML_LOCATION,
input_config=IMPORT_INPUT_CONFIG,
import_clss_dataset = ImportDataOperator(
task_id="import_clss_data",
dataset_id=clss_dataset_id,
region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
import_configs=DATA_CONFIG,
)
MODEL["dataset_id"] = dataset_id

create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)
model_id = cast(str, XComArg(create_model, key="model_id"))

deploy_model = AutoMLDeployModelOperator(
task_id="deploy_model",
model_id=model_id,
location=GCP_AUTOML_LOCATION,
# [START howto_operator_automl_create_model]
create_clss_training_job = CreateAutoMLTextTrainingJobOperator(
task_id="create_clss_training_job",
display_name=TEXT_CLSS_DISPLAY_NAME,
prediction_type="classification",
multi_label=False,
dataset_id=clss_dataset_id,
model_display_name=MODEL_NAME,
training_fraction_split=0.7,
validation_fraction_split=0.2,
test_fraction_split=0.1,
sync=True,
region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
)
# [END howto_operator_automl_create_model]
model_id = cast(str, XComArg(create_clss_training_job, key="model_id"))

delete_model = AutoMLDeleteModelOperator(
task_id="delete_model",
model_id=model_id,
location=GCP_AUTOML_LOCATION,
delete_clss_training_job = DeleteAutoMLTrainingJobOperator(
task_id="delete_clss_training_job",
training_pipeline_id=create_clss_training_job.output["training_id"],
region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
trigger_rule=TriggerRule.ALL_DONE,
)

delete_dataset = AutoMLDeleteDatasetOperator(
task_id="delete_dataset",
dataset_id=dataset_id,
location=GCP_AUTOML_LOCATION,
delete_clss_dataset = DeleteDatasetOperator(
task_id="delete_clss_dataset",
dataset_id=clss_dataset_id,
region=GCP_AUTOML_LOCATION,
project_id=GCP_PROJECT_ID,
trigger_rule=TriggerRule.ALL_DONE,
)

delete_bucket = GCSDeleteBucketOperator(
task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE
task_id="delete_bucket",
bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
trigger_rule=TriggerRule.ALL_DONE,
)

(
# TEST SETUP
[create_bucket >> move_dataset_file, create_dataset]
[create_bucket >> move_dataset_file, create_clss_dataset]
# TEST BODY
>> import_dataset
>> create_model
>> deploy_model
>> import_clss_dataset
>> create_clss_training_job
# TEST TEARDOWN
>> delete_model
>> delete_dataset
>> delete_clss_training_job
>> delete_clss_dataset
>> delete_bucket
)

Expand Down

0 comments on commit 25d463c

Please sign in to comment.