diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py index 5c6d56529d9fb..b1ad7d1a07358 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py @@ -551,6 +551,8 @@ def create_auto_ml_forecasting_training_job( is_default_version: bool | None = None, model_version_aliases: list[str] | None = None, model_version_description: str | None = None, + window_stride_length: int | None = None, + window_max_count: int | None = None, ) -> tuple[models.Model | None, str]: """ Create an AutoML Forecasting Training Job. @@ -703,6 +705,10 @@ def create_auto_ml_forecasting_training_job( :param sync: Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + :param window_stride_length: Optional. Step length used to generate input examples. Every + ``window_stride_length`` rows will be used to generate a sliding window. + :param window_max_count: Optional. Number of rows that should be used to generate input examples. If the + total row count is larger than this number, the input data will be randomly sampled to hit the count. """ if column_transformations: warnings.warn( @@ -758,6 +764,8 @@ def create_auto_ml_forecasting_training_job( is_default_version=is_default_version, model_version_aliases=model_version_aliases, model_version_description=model_version_description, + window_stride_length=window_stride_length, + window_max_count=window_max_count, ) training_id = self.extract_training_id(self._job.resource_name) if model: diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py index 14752320121f2..7e3d8bb0834d0 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py @@ -138,6 +138,8 @@ def __init__( region: str, impersonation_chain: str | Sequence[str] | None = None, parent_model: str | None = None, + window_stride_length: int | None = None, + window_max_count: int | None = None, **kwargs, ) -> None: super().__init__( @@ -170,6 +172,8 @@ def __init__( self.quantiles = quantiles self.validation_options = validation_options self.budget_milli_node_hours = budget_milli_node_hours + self.window_stride_length = window_stride_length + self.window_max_count = window_max_count def execute(self, context: Context): self.hook = AutoMLHook( @@ -220,6 +224,8 @@ def execute(self, context: Context): model_display_name=self.model_display_name, model_labels=self.model_labels, sync=self.sync, + window_stride_length=self.window_stride_length, + window_max_count=self.window_max_count, ) if model: diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 3f8649f588953..4b8264d6157a2 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -1340,6 +1340,8 @@ def test_execute(self, mock_hook, mock_dataset): is_default_version=None, model_version_aliases=None, model_version_description=None, + window_stride_length=None, + window_max_count=None, ) @mock.patch("google.cloud.aiplatform.datasets.TimeSeriesDataset") @@ -1405,6 +1407,8 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da is_default_version=None, model_version_aliases=None, model_version_description=None, + window_stride_length=None, + window_max_count=None, )