Skip to content

Commit

Permalink
Update the DMS Sample DAG and Docs (#23681)
Browse files Browse the repository at this point in the history
  • Loading branch information
ferruzzi committed May 19, 2022
1 parent a80b2fc commit fb3b980
Show file tree
Hide file tree
Showing 4 changed files with 445 additions and 125 deletions.
347 changes: 347 additions & 0 deletions airflow/providers/amazon/aws/example_dags/example_dms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,347 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Note: DMS requires you to configure specific IAM roles/permissions. For more information, see
https://docs.aws.amazon.com/dms/latest/userguide/CHAP_Security.html#CHAP_Security.APIRole
"""

import json
import os
from datetime import datetime

import boto3
from sqlalchemy import Column, MetaData, String, Table, create_engine

from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.operators.python import get_current_context
from airflow.providers.amazon.aws.operators.dms import (
DmsCreateTaskOperator,
DmsDeleteTaskOperator,
DmsDescribeTasksOperator,
DmsStartTaskOperator,
DmsStopTaskOperator,
)
from airflow.providers.amazon.aws.sensors.dms import DmsTaskBaseSensor, DmsTaskCompletedSensor

S3_BUCKET = os.getenv('S3_BUCKET', 's3_bucket_name')
ROLE_ARN = os.getenv('ROLE_ARN', 'arn:aws:iam::1234567890:role/s3_target_endpoint_role')

# The project name will be used as a prefix for various entity names.
# Use either PascalCase or camelCase. While some names require kebab-case
# and others require snake_case, they all accept mixedCase strings.
PROJECT_NAME = 'DmsDemo'

# Config values for setting up the "Source" database.
RDS_ENGINE = 'postgres'
RDS_PROTOCOL = 'postgresql'
RDS_USERNAME = 'username'
# NEVER store your production password in plaintext in a DAG like this.
# Use Airflow Secrets or a secret manager for this in production.
RDS_PASSWORD = 'rds_password'

# Config values for RDS.
RDS_INSTANCE_NAME = f'{PROJECT_NAME}-instance'
RDS_DB_NAME = f'{PROJECT_NAME}_source_database'

# Config values for DMS.
DMS_REPLICATION_INSTANCE_NAME = f'{PROJECT_NAME}-replication-instance'
DMS_REPLICATION_TASK_ID = f'{PROJECT_NAME}-replication-task'
SOURCE_ENDPOINT_IDENTIFIER = f'{PROJECT_NAME}-source-endpoint'
TARGET_ENDPOINT_IDENTIFIER = f'{PROJECT_NAME}-target-endpoint'

# Sample data.
TABLE_NAME = f'{PROJECT_NAME}-table'
TABLE_HEADERS = ['apache_project', 'release_year']
SAMPLE_DATA = [
('Airflow', '2015'),
('OpenOffice', '2012'),
('Subversion', '2000'),
('NiFi', '2006'),
]
TABLE_DEFINITION = {
'TableCount': '1',
'Tables': [
{
'TableName': TABLE_NAME,
'TableColumns': [
{
'ColumnName': TABLE_HEADERS[0],
'ColumnType': 'STRING',
'ColumnNullable': 'false',
'ColumnIsPk': 'true',
},
{"ColumnName": TABLE_HEADERS[1], "ColumnType": 'STRING', "ColumnLength": "4"},
],
'TableColumnsTotal': '2',
}
],
}
TABLE_MAPPINGS = {
'rules': [
{
'rule-type': 'selection',
'rule-id': '1',
'rule-name': '1',
'object-locator': {
'schema-name': 'public',
'table-name': TABLE_NAME,
},
'rule-action': 'include',
}
]
}


def _create_rds_instance():
print('Creating RDS Instance.')

rds_client = boto3.client('rds')
rds_client.create_db_instance(
DBName=RDS_DB_NAME,
DBInstanceIdentifier=RDS_INSTANCE_NAME,
AllocatedStorage=20,
DBInstanceClass='db.t3.micro',
Engine=RDS_ENGINE,
MasterUsername=RDS_USERNAME,
MasterUserPassword=RDS_PASSWORD,
)

rds_client.get_waiter('db_instance_available').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)

response = rds_client.describe_db_instances(DBInstanceIdentifier=RDS_INSTANCE_NAME)
return response['DBInstances'][0]['Endpoint']


def _create_rds_table(rds_endpoint):
print('Creating table.')

hostname = rds_endpoint['Address']
port = rds_endpoint['Port']
rds_url = f'{RDS_PROTOCOL}://{RDS_USERNAME}:{RDS_PASSWORD}@{hostname}:{port}/{RDS_DB_NAME}'
engine = create_engine(rds_url)

table = Table(
TABLE_NAME,
MetaData(engine),
Column(TABLE_HEADERS[0], String, primary_key=True),
Column(TABLE_HEADERS[1], String),
)

with engine.connect() as connection:
# Create the Table.
table.create()
load_data = table.insert().values(SAMPLE_DATA)
connection.execute(load_data)

# Read the data back to verify everything is working.
connection.execute(table.select())


def _create_dms_replication_instance(ti, dms_client):
print('Creating replication instance.')
instance_arn = dms_client.create_replication_instance(
ReplicationInstanceIdentifier=DMS_REPLICATION_INSTANCE_NAME,
ReplicationInstanceClass='dms.t3.micro',
)['ReplicationInstance']['ReplicationInstanceArn']

ti.xcom_push(key='replication_instance_arn', value=instance_arn)
return instance_arn


def _create_dms_endpoints(ti, dms_client, rds_instance_endpoint):
print('Creating DMS source endpoint.')
source_endpoint_arn = dms_client.create_endpoint(
EndpointIdentifier=SOURCE_ENDPOINT_IDENTIFIER,
EndpointType='source',
EngineName=RDS_ENGINE,
Username=RDS_USERNAME,
Password=RDS_PASSWORD,
ServerName=rds_instance_endpoint['Address'],
Port=rds_instance_endpoint['Port'],
DatabaseName=RDS_DB_NAME,
)['Endpoint']['EndpointArn']

print('Creating DMS target endpoint.')
target_endpoint_arn = dms_client.create_endpoint(
EndpointIdentifier=TARGET_ENDPOINT_IDENTIFIER,
EndpointType='target',
EngineName='s3',
S3Settings={
'BucketName': S3_BUCKET,
'BucketFolder': PROJECT_NAME,
'ServiceAccessRoleArn': ROLE_ARN,
'ExternalTableDefinition': json.dumps(TABLE_DEFINITION),
},
)['Endpoint']['EndpointArn']

ti.xcom_push(key='source_endpoint_arn', value=source_endpoint_arn)
ti.xcom_push(key='target_endpoint_arn', value=target_endpoint_arn)


def _await_setup_assets(dms_client, instance_arn):
print("Awaiting asset provisioning.")
dms_client.get_waiter('replication_instance_available').wait(
Filters=[{'Name': 'replication-instance-arn', 'Values': [instance_arn]}]
)


def _delete_rds_instance():
print('Deleting RDS Instance.')

rds_client = boto3.client('rds')
rds_client.delete_db_instance(
DBInstanceIdentifier=RDS_INSTANCE_NAME,
SkipFinalSnapshot=True,
)

rds_client.get_waiter('db_instance_deleted').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)


def _delete_dms_assets(dms_client):
ti = get_current_context()['ti']
replication_instance_arn = ti.xcom_pull(key='replication_instance_arn')
source_arn = ti.xcom_pull(key='source_endpoint_arn')
target_arn = ti.xcom_pull(key='target_endpoint_arn')

print('Deleting DMS assets.')
dms_client.delete_replication_instance(ReplicationInstanceArn=replication_instance_arn)
dms_client.delete_endpoint(EndpointArn=source_arn)
dms_client.delete_endpoint(EndpointArn=target_arn)


def _await_all_teardowns(dms_client):
print('Awaiting tear-down.')
dms_client.get_waiter('replication_instance_deleted').wait(
Filters=[{'Name': 'replication-instance-id', 'Values': [DMS_REPLICATION_INSTANCE_NAME]}]
)

dms_client.get_waiter('endpoint_deleted').wait(
Filters=[
{
'Name': 'endpoint-id',
'Values': [SOURCE_ENDPOINT_IDENTIFIER, TARGET_ENDPOINT_IDENTIFIER],
}
]
)


@task
def set_up():
ti = get_current_context()['ti']
dms_client = boto3.client('dms')

rds_instance_endpoint = _create_rds_instance()
_create_rds_table(rds_instance_endpoint)
instance_arn = _create_dms_replication_instance(ti, dms_client)
_create_dms_endpoints(ti, dms_client, rds_instance_endpoint)
_await_setup_assets(dms_client, instance_arn)


@task(trigger_rule='all_done')
def clean_up():
dms_client = boto3.client('dms')

_delete_rds_instance()
_delete_dms_assets(dms_client)
_await_all_teardowns(dms_client)


with DAG(
dag_id='example_dms',
schedule_interval=None,
start_date=datetime(2021, 1, 1),
tags=['example'],
catchup=False,
) as dag:

# [START howto_operator_dms_create_task]
create_task = DmsCreateTaskOperator(
task_id='create_task',
replication_task_id=DMS_REPLICATION_TASK_ID,
source_endpoint_arn='{{ ti.xcom_pull(key="source_endpoint_arn") }}',
target_endpoint_arn='{{ ti.xcom_pull(key="target_endpoint_arn") }}',
replication_instance_arn='{{ ti.xcom_pull(key="replication_instance_arn") }}',
table_mappings=TABLE_MAPPINGS,
)
# [END howto_operator_dms_create_task]

# [START howto_operator_dms_start_task]
start_task = DmsStartTaskOperator(
task_id='start_task',
replication_task_arn=create_task.output,
)
# [END howto_operator_dms_start_task]

# [START howto_operator_dms_describe_tasks]
describe_tasks = DmsDescribeTasksOperator(
task_id='describe_tasks',
describe_tasks_kwargs={
'Filters': [
{
'Name': 'replication-instance-arn',
'Values': ['{{ ti.xcom_pull(key="replication_instance_arn") }}'],
}
]
},
do_xcom_push=False,
)
# [END howto_operator_dms_describe_tasks]

await_task_start = DmsTaskBaseSensor(
task_id='await_task_start',
replication_task_arn=create_task.output,
target_statuses=['running'],
termination_statuses=['stopped', 'deleting', 'failed'],
)

# [START howto_operator_dms_stop_task]
stop_task = DmsStopTaskOperator(
task_id='stop_task',
replication_task_arn=create_task.output,
)
# [END howto_operator_dms_stop_task]

# TaskCompletedSensor actually waits until task reaches the "Stopped" state, so it will work here.
# [START howto_operator_dms_task_completed_sensor]
await_task_stop = DmsTaskCompletedSensor(
task_id='await_task_stop',
replication_task_arn=create_task.output,
)
# [END howto_operator_dms_task_completed_sensor]

# [START howto_operator_dms_delete_task]
delete_task = DmsDeleteTaskOperator(
task_id='delete_task',
replication_task_arn=create_task.output,
trigger_rule='all_done',
)
# [END howto_operator_dms_delete_task]

chain(
set_up()
>> create_task
>> start_task
>> describe_tasks
>> await_task_start
>> stop_task
>> await_task_stop
>> delete_task
>> clean_up()
)
Loading

0 comments on commit fb3b980

Please sign in to comment.