Skip to content

Commit

Permalink
Provide missing project id and creds for TabularDataset (#31991)
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-awala committed Jun 19, 2023
1 parent 10aa704 commit f2ebc29
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
Expand Up @@ -352,11 +352,16 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
credentials, _ = self.hook.get_credentials_and_project_id()
model, training_id = self.hook.create_auto_ml_tabular_training_job(
project_id=self.project_id,
region=self.region,
display_name=self.display_name,
dataset=datasets.TabularDataset(dataset_name=self.dataset_id),
dataset=datasets.TabularDataset(
dataset_name=self.dataset_id,
project=self.project_id,
credentials=credentials,
),
target_column=self.target_column,
optimization_prediction_type=self.optimization_prediction_type,
optimization_objective=self.optimization_objective,
Expand Down
12 changes: 10 additions & 2 deletions tests/providers/google/cloud/operators/test_vertex_ai.py
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock

from google.api_core.gapic_v1.method import DEFAULT
from google.api_core.retry import Retry
Expand Down Expand Up @@ -783,7 +784,12 @@ class TestVertexAICreateAutoMLTabularTrainingJobOperator:
@mock.patch("google.cloud.aiplatform.datasets.TabularDataset")
@mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
def test_execute(self, mock_hook, mock_dataset):
mock_hook.return_value.create_auto_ml_tabular_training_job.return_value = (None, "training_id")
mock_hook.return_value = MagicMock(
**{
"create_auto_ml_tabular_training_job.return_value": (None, "training_id"),
"get_credentials_and_project_id.return_value": ("creds", "project_id"),
}
)
op = CreateAutoMLTabularTrainingJobOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
Expand All @@ -798,7 +804,9 @@ def test_execute(self, mock_hook, mock_dataset):
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_dataset.assert_called_once_with(dataset_name=TEST_DATASET_ID)
mock_dataset.assert_called_once_with(
dataset_name=TEST_DATASET_ID, project=GCP_PROJECT, credentials="creds"
)
mock_hook.return_value.create_auto_ml_tabular_training_job.assert_called_once_with(
project_id=GCP_PROJECT,
region=GCP_LOCATION,
Expand Down

0 comments on commit f2ebc29

Please sign in to comment.