Convert RDS Export Sample DAG to System Test (AIP-47) (#25205)
* Convert RDS Export Sample DAG to System Test

* PR Fixes
ferruzzi committed Jul 21, 2022
1 parent 210ad64 commit f6bda38
Showing 5 changed files with 456 additions and 114 deletions.
71 changes: 0 additions & 71 deletions airflow/providers/amazon/aws/example_dags/example_rds_export.py

This file was deleted.

94 changes: 54 additions & 40 deletions airflow/providers/amazon/aws/operators/rds.py
@@ -70,7 +70,7 @@ def _await_status(
error_statuses: Optional[List[str]] = None,
) -> None:
"""
Continuously gets item description from `_describe_item()` and waits until:
Continuously gets item description from `_describe_item()` and waits while:
- status is in `wait_statuses`
- status not in `ok_statuses` and `error_statuses`
"""
@@ -117,6 +117,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
:param db_snapshot_identifier: The identifier for the DB snapshot
:param tags: A list of tags in format `[{"Key": "something", "Value": "something"},]
`USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
:param wait_for_completion: If True, waits for creation of the DB snapshot to complete. (default: True)
"""

template_fields = ("db_snapshot_identifier", "db_identifier", "tags")
@@ -128,6 +129,7 @@ def __init__(
db_identifier: str,
db_snapshot_identifier: str,
tags: Optional[Sequence[TagTypeDef]] = None,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_conn_id",
**kwargs,
):
@@ -136,6 +138,7 @@ def __init__(
self.db_identifier = db_identifier
self.db_snapshot_identifier = db_snapshot_identifier
self.tags = tags or []
self.wait_for_completion = wait_for_completion

def execute(self, context: 'Context') -> str:
self.log.info(
@@ -152,26 +155,24 @@ def execute(self, context: 'Context') -> str:
Tags=self.tags,
)
create_response = json.dumps(create_instance_snap, default=str)
self._await_status(
'instance_snapshot',
self.db_snapshot_identifier,
wait_statuses=['creating'],
ok_statuses=['available'],
)
item_type = 'instance_snapshot'

else:
create_cluster_snap = self.hook.conn.create_db_cluster_snapshot(
DBClusterIdentifier=self.db_identifier,
DBClusterSnapshotIdentifier=self.db_snapshot_identifier,
Tags=self.tags,
)
create_response = json.dumps(create_cluster_snap, default=str)
item_type = 'cluster_snapshot'

if self.wait_for_completion:
self._await_status(
'cluster_snapshot',
item_type,
self.db_snapshot_identifier,
wait_statuses=['creating'],
ok_statuses=['available'],
)

return create_response
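
With `wait_for_completion` exposed on the snapshot operator, a DAG can submit the request and return immediately instead of blocking inside `execute()`. A hypothetical task definition under that reading; the task_id and identifiers are placeholders, and `db_type` itself is not visible in this excerpt.

from airflow.providers.amazon.aws.operators.rds import RdsCreateDbSnapshotOperator

# Placeholder identifiers for illustration only.
create_snapshot = RdsCreateDbSnapshotOperator(
    task_id="create_snapshot",
    db_type="instance",  # selector for the instance vs. cluster branch shown above
    db_identifier="my-rds-instance",
    db_snapshot_identifier="my-rds-snapshot",
    wait_for_completion=False,  # return once the create call is made; check readiness elsewhere
)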


@@ -196,6 +197,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
:param target_custom_availability_zone: The external custom Availability Zone identifier for the target
Only when db_type='instance'
:param source_region: The ID of the region that contains the snapshot to be copied
:param wait_for_completion: If True, waits for snapshot copy to complete. (default: True)
"""

template_fields = (
@@ -219,6 +221,7 @@ def __init__(
option_group_name: str = "",
target_custom_availability_zone: str = "",
source_region: str = "",
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
**kwargs,
):
@@ -234,6 +237,7 @@ def __init__(
self.option_group_name = option_group_name
self.target_custom_availability_zone = target_custom_availability_zone
self.source_region = source_region
self.wait_for_completion = wait_for_completion

def execute(self, context: 'Context') -> str:
self.log.info(
@@ -255,12 +259,8 @@ def execute(self, context: 'Context') -> str:
SourceRegion=self.source_region,
)
copy_response = json.dumps(copy_instance_snap, default=str)
self._await_status(
'instance_snapshot',
self.target_db_snapshot_identifier,
wait_statuses=['creating'],
ok_statuses=['available'],
)
item_type = 'instance_snapshot'

else:
copy_cluster_snap = self.hook.conn.copy_db_cluster_snapshot(
SourceDBClusterSnapshotIdentifier=self.source_db_snapshot_identifier,
@@ -272,13 +272,15 @@ def execute(self, context: 'Context') -> str:
SourceRegion=self.source_region,
)
copy_response = json.dumps(copy_cluster_snap, default=str)
item_type = 'cluster_snapshot'

if self.wait_for_completion:
self._await_status(
'cluster_snapshot',
item_type,
self.target_db_snapshot_identifier,
wait_statuses=['copying'],
ok_statuses=['available'],
)

return copy_response


@@ -341,6 +343,7 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
:param kms_key_id: The ID of the Amazon Web Services KMS key to use to encrypt the snapshot.
:param s3_prefix: The Amazon S3 bucket prefix to use as the file name and path of the exported snapshot.
:param export_only: The data to be exported from the snapshot.
:param wait_for_completion: If True, waits for the DB snapshot export to complete. (default: True)
"""

template_fields = (
@@ -363,6 +366,7 @@ def __init__(
kms_key_id: str,
s3_prefix: str = '',
export_only: Optional[List[str]] = None,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
**kwargs,
):
@@ -375,6 +379,7 @@ def __init__(
self.kms_key_id = kms_key_id
self.s3_prefix = s3_prefix
self.export_only = export_only or []
self.wait_for_completion = wait_for_completion

def execute(self, context: 'Context') -> str:
self.log.info("Starting export task %s for snapshot %s", self.export_task_identifier, self.source_arn)
@@ -389,13 +394,14 @@ def execute(self, context: 'Context') -> str:
ExportOnly=self.export_only,
)

self._await_status(
'export_task',
self.export_task_identifier,
wait_statuses=['starting', 'in_progress'],
ok_statuses=['complete'],
error_statuses=['canceling', 'canceled'],
)
if self.wait_for_completion:
self._await_status(
'export_task',
self.export_task_identifier,
wait_statuses=['starting', 'in_progress'],
ok_statuses=['complete'],
error_statuses=['canceling', 'canceled'],
)

return json.dumps(start_export, default=str)
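
In a system test this flag lets the export start without blocking the task, deferring the wait to a sensor (see the docs change further below). A sketch with placeholder ARNs, bucket and key names; the constructor arguments follow the operator's public signature, only part of which appears in this excerpt.

from airflow.providers.amazon.aws.operators.rds import RdsStartExportTaskOperator

# Placeholder values; `wait_for_completion` is the behavior under discussion here.
start_export = RdsStartExportTaskOperator(
    task_id="start_export",
    export_task_identifier="my-snapshot-export",
    source_arn="arn:aws:rds:us-east-1:123456789012:snapshot:my-rds-snapshot",
    s3_bucket_name="my-export-bucket",
    iam_role_arn="arn:aws:iam::123456789012:role/my-export-role",
    kms_key_id="my-kms-key-id",
    wait_for_completion=False,  # do not poll for 'complete' inside the operator
)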

@@ -409,6 +415,7 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
:ref:`howto/operator:RdsCancelExportTaskOperator`
:param export_task_identifier: The identifier of the snapshot export task to cancel
:param wait_for_completion: If True, waits for DB snapshot export to cancel. (default: True)
"""

template_fields = ("export_task_identifier",)
@@ -417,25 +424,29 @@ def __init__(
self,
*,
export_task_identifier: str,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(aws_conn_id=aws_conn_id, **kwargs)

self.export_task_identifier = export_task_identifier
self.wait_for_completion = wait_for_completion

def execute(self, context: 'Context') -> str:
self.log.info("Canceling export task %s", self.export_task_identifier)

cancel_export = self.hook.conn.cancel_export_task(
ExportTaskIdentifier=self.export_task_identifier,
)
self._await_status(
'export_task',
self.export_task_identifier,
wait_statuses=['canceling'],
ok_statuses=['canceled'],
)

if self.wait_for_completion:
self._await_status(
'export_task',
self.export_task_identifier,
wait_statuses=['canceling'],
ok_statuses=['canceled'],
)

return json.dumps(cancel_export, default=str)
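
Cancellation can be made non-blocking the same way; a hypothetical task with a placeholder identifier:

from airflow.providers.amazon.aws.operators.rds import RdsCancelExportTaskOperator

cancel_export = RdsCancelExportTaskOperator(
    task_id="cancel_export",
    export_task_identifier="my-snapshot-export",  # placeholder
    wait_for_completion=False,  # return as soon as the cancel request is accepted
)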

@@ -458,6 +469,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
:param enabled: A value that indicates whether to activate the subscription (default True)
:param tags: A list of tags in format `[{"Key": "something", "Value": "something"},]
`USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
:param wait_for_completion: If True, waits for creation of the subscription to complete. (default: True)
"""

template_fields = (
@@ -479,6 +491,7 @@ def __init__(
source_ids: Optional[Sequence[str]] = None,
enabled: bool = True,
tags: Optional[Sequence[TagTypeDef]] = None,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
**kwargs,
):
@@ -491,6 +504,7 @@ def __init__(
self.source_ids = source_ids or []
self.enabled = enabled
self.tags = tags or []
self.wait_for_completion = wait_for_completion

def execute(self, context: 'Context') -> str:
self.log.info("Creating event subscription '%s' to '%s'", self.subscription_name, self.sns_topic_arn)
@@ -504,12 +518,14 @@ def execute(self, context: 'Context') -> str:
Enabled=self.enabled,
Tags=self.tags,
)
self._await_status(
'event_subscription',
self.subscription_name,
wait_statuses=['creating'],
ok_statuses=['active'],
)

if self.wait_for_completion:
self._await_status(
'event_subscription',
self.subscription_name,
wait_statuses=['creating'],
ok_statuses=['active'],
)

return json.dumps(create_subscription, default=str)
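
And for the event subscription operator, a hypothetical task under the same assumptions; values are placeholders, and `source_type` is not visible in this excerpt.

from airflow.providers.amazon.aws.operators.rds import RdsCreateEventSubscriptionOperator

create_subscription = RdsCreateEventSubscriptionOperator(
    task_id="create_subscription",
    subscription_name="my-rds-subscription",  # placeholder
    sns_topic_arn="arn:aws:sns:us-east-1:123456789012:my-topic",  # placeholder
    source_type="db-instance",  # assumed; not shown in the excerpt above
    source_ids=["my-rds-instance"],
    wait_for_completion=False,  # skip waiting for the subscription to become 'active'
)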

@@ -566,8 +582,7 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
:param rds_kwargs: Named arguments to pass to boto3 RDS client function ``create_db_instance``
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param wait_for_completion: Whether or not wait for creation of the DB instance to
complete. (default: True)
:param wait_for_completion: If True, waits for creation of the DB instance to complete. (default: True)
"""

def __init__(
@@ -619,8 +634,7 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
:param rds_kwargs: Named arguments to pass to boto3 RDS client function ``delete_db_instance``
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param wait_for_completion: Whether or not wait for deletion of the DB instance to
complete. (default: True)
:param wait_for_completion: If True, waits for deletion of the DB instance to complete. (default: True)
"""

def __init__(
6 changes: 3 additions & 3 deletions docs/apache-airflow-providers-amazon/operators/rds.rst
@@ -86,7 +86,7 @@ To export an Amazon RDS snapshot to Amazon S3 you can use
:class:`~airflow.providers.amazon.aws.operators.rds.RDSStartExportTaskOperator`.
The provided IAM role must have access to the S3 bucket.

.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_rds_export.py
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_rds_export.py
:language: python
:dedent: 4
:start-after: [START howto_operator_rds_start_export_task]
@@ -101,7 +101,7 @@ To cancel an Amazon RDS export task to S3 you can use
:class:`~airflow.providers.amazon.aws.operators.rds.RDSCancelExportTaskOperator`.
Any data that has already been written to the S3 bucket isn't removed.

.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_rds_export.py
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_rds_export.py
:language: python
:dedent: 4
:start-after: [START howto_operator_rds_cancel_export]
@@ -194,7 +194,7 @@ To wait for an Amazon RDS snapshot export task with specific statuses you can
:class:`~airflow.providers.amazon.aws.sensors.rds.RdsExportTaskExistenceSensor`.
By default, the sensor waits for the existence of a snapshot with status ``available``.

.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_rds_export.py
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_rds_export.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_rds_export_task_existence]

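The three `exampleinclude` directives above now pull from the system test file instead of the deleted example DAG. A rough sketch of the kind of block the last directive includes, reusing the marker names from the docs; the task_id and identifier are placeholders, not the actual system test content.

from airflow.providers.amazon.aws.sensors.rds import RdsExportTaskExistenceSensor

# [START howto_sensor_rds_export_task_existence]
export_sensor = RdsExportTaskExistenceSensor(
    task_id="export_sensor",
    export_task_identifier="my-snapshot-export",  # placeholder
)
# [END howto_sensor_rds_export_task_existence]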