22 changes: 17 additions & 5 deletions airflow/sensors/sql_sensor.py
@@ -19,6 +19,7 @@
 
 from builtins import str
 
+from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
 from airflow.sensors.base_sensor_operator import BaseSensorOperator
 from airflow.utils.decorators import apply_defaults
@@ -34,22 +35,33 @@ class SqlSensor(BaseSensorOperator):
     :param sql: The sql to run. To pass, it needs to return at least one cell
         that contains a non-zero / empty string value.
     :type sql: str
+    :param parameters: The parameters to render the SQL query with (optional).
+    :type parameters: mapping or iterable
     """
     template_fields = ('sql',)
     template_ext = ('.hql', '.sql',)
     ui_color = '#7c7287'
 
     @apply_defaults
-    def __init__(self, conn_id, sql, *args, **kwargs):
-        self.sql = sql
+    def __init__(self, conn_id, sql, parameters=None, *args, **kwargs):
         self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
         super(SqlSensor, self).__init__(*args, **kwargs)
 
     def poke(self, context):
-        hook = BaseHook.get_connection(self.conn_id).get_hook()
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {'google_cloud_platform', 'jdbc', 'mssql',
+                             'mysql', 'oracle', 'postgres',
+                             'presto', 'sqlite', 'vertica'}
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException("The connection type is not supported by SqlSensor. " +
+                                   "Supported connection types: {}".format(list(allowed_conn_type)))
+        hook = conn.get_hook()
 
-        self.log.info('Poking: %s', self.sql)
-        records = hook.get_records(self.sql)
+        self.log.info('Poking: %s (with parameters %s)', self.sql, self.parameters)
+        records = hook.get_records(self.sql, self.parameters)

Member: Do all the SQL-type hooks (druid, mysql, postgres, etc.) inherit from DbApiHook? If that's the case, the change LGTM.

Member: If not, we should change that hook to inherit from DbApiHook.

Member Author (@XD-DENG, Feb 18, 2019): Yes, the hooks for druid, mysql, mssql, postgres, oracle, etc. are all inherited from DbApiHook.

For SqlSensor, the get_hook() in BaseHook.get_connection(self.conn_id).get_hook() decides which exact hook to use based on the connection type (reference: https://github.com/apache/airflow/blob/master/airflow/models/connection.py#L184).
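
For context, the dispatch inside Connection.get_hook() follows roughly this pattern (an abridged sketch of the 1.10-era code, not a verbatim excerpt):

    # Abridged sketch of Connection.get_hook(): the conn_type string picks
    # which hook class to instantiate; unknown types fall through to None.
    def get_hook(self):
        if self.conn_type == 'mysql':
            from airflow.hooks.mysql_hook import MySqlHook
            return MySqlHook(mysql_conn_id=self.conn_id)
        elif self.conn_type == 'postgres':
            from airflow.hooks.postgres_hook import PostgresHook
            return PostgresHook(postgres_conn_id=self.conn_id)
        # ... analogous branches for 'mssql', 'oracle', 'presto', 'sqlite',
        # 'jdbc', 'vertica', 'google_cloud_platform', and others.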

Member: But I checked CloudSqlDatabaseHook (class CloudSqlDatabaseHook(BaseHook):), which doesn't seem to inherit from DbApiHook, and I assume SqlSensor could still be handed one. Given how many connection hooks are defined, a safer approach would be to check whether the hook is an instance of DbApiHook and use records = hook.get_records(self.sql, self.parameters) if so, otherwise fall back to records = hook.get_records(self.sql). What do you think?
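
The suggested fallback would look roughly like this (a sketch of the reviewer's proposal, not code from this PR; it assumes hook has already been resolved inside poke()):

    from airflow.hooks.dbapi_hook import DbApiHook

    # Sketch of the proposed defensive check (ultimately not adopted here):
    # pass parameters only to hooks that implement the DB-API interface.
    if isinstance(hook, DbApiHook):
        records = hook.get_records(self.sql, self.parameters)
    else:
        records = hook.get_records(self.sql)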

Member Author: Actually, SqlSensor cannot support CloudSqlDatabaseHook: SqlSensor ultimately uses the get_records method to retrieve records from the database, and get_records is not available in CloudSqlDatabaseHook at all.

Member: I see, but wouldn't a check on the hook type be safer?

Member Author: I get what you mean.

Given the limitation in https://github.com/apache/airflow/blob/master/airflow/models/connection.py#L184 and the implementation of each hook, only the connection types below are supported by SqlSensor:

  • 'google_cloud_platform'
  • 'jdbc'
  • 'mssql'
  • 'mysql'
  • 'oracle'
  • 'postgres'
  • 'presto'
  • 'sqlite'
  • 'vertica'

I will add a check.
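
For illustration, once this change is in, the new parameters argument would be used like so (a hypothetical DAG snippet: the table, filter value, and dag object are placeholders, not part of the PR):

    from airflow.sensors.sql_sensor import SqlSensor

    # Hypothetical usage: the %s placeholder is bound by the underlying
    # DbApiHook's driver rather than interpolated into the SQL string.
    wait_for_done_rows = SqlSensor(
        task_id='wait_for_done_rows',
        conn_id='mysql_default',
        sql="SELECT count(1) FROM my_table WHERE state = %s",
        parameters=['done'],
        dag=dag,
    )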

         if not records:
             return False
         return str(records[0][0]) not in ('0', '')
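
As the unchanged context lines above show, the pass criterion stays the same: the sensor keeps poking until the first cell of the first row stringifies to something other than '0' or the empty string. A minimal standalone illustration of that rule:

    def _passes(records):
        # Mirrors SqlSensor.poke(): no rows, a zero, or an empty string
        # in the first cell all mean "keep waiting".
        if not records:
            return False
        return str(records[0][0]) not in ('0', '')

    assert _passes([(1,)])          # non-zero value -> sensor succeeds
    assert not _passes([(0,)])      # zero -> keep poking
    assert not _passes([('',)])     # empty string -> keep poking
    assert not _passes([])          # no rows -> keep poking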
39 changes: 35 additions & 4 deletions tests/sensors/test_sql_sensor.py
@@ -21,6 +21,7 @@
 
 from airflow import DAG
 from airflow import configuration
+from airflow.exceptions import AirflowException
 from airflow.sensors.sql_sensor import SqlSensor
 from airflow.utils.timezone import datetime
 
@@ -40,27 +41,56 @@ def setUp(self):
         }
         self.dag = DAG(TEST_DAG_ID, default_args=args)
 
+    def test_unsupported_conn_type(self):
+        t = SqlSensor(
+            task_id='sql_sensor_check',
+            conn_id='redis_default',
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            dag=self.dag
+        )
+
+        with self.assertRaises(AirflowException):
+            t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
     @unittest.skipUnless(
         'mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), "this is a mysql test")
     def test_sql_sensor_mysql(self):
-        t = SqlSensor(
+        t1 = SqlSensor(
             task_id='sql_sensor_check',
             conn_id='mysql_default',
             sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
             dag=self.dag
         )
-        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+        t1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+        t2 = SqlSensor(
+            task_id='sql_sensor_check',
+            conn_id='mysql_default',
+            sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES",
+            parameters=["table_name"],
+            dag=self.dag
+        )
+        t2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
     @unittest.skipUnless(
         'postgresql' in configuration.conf.get('core', 'sql_alchemy_conn'), "this is a postgres test")
     def test_sql_sensor_postgres(self):
-        t = SqlSensor(
+        t1 = SqlSensor(
             task_id='sql_sensor_check',
             conn_id='postgres_default',
             sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
             dag=self.dag
         )
-        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+        t1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+        t2 = SqlSensor(
+            task_id='sql_sensor_check',
+            conn_id='postgres_default',
+            sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES",
+            parameters=["table_name"],
+            dag=self.dag
+        )
+        t2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
     @mock.patch('airflow.sensors.sql_sensor.BaseHook')
     def test_sql_sensor_postgres_poke(self, mock_hook):
@@ -70,6 +100,7 @@ def test_sql_sensor_postgres_poke(self, mock_hook):
             sql="SELECT 1",
         )
 
+        mock_hook.get_connection('postgres_default').conn_type = "postgres"
         mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
 
         mock_get_records.return_value = []
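
One note on the mocked poke test: because mock_hook is a MagicMock, calling mock_hook.get_connection(...) returns mock_hook.get_connection.return_value regardless of the argument, so the conn_type assignment above lands on the very connection object that poke() receives, letting it pass the new allowed-type check. A minimal sketch of that MagicMock behavior:

    from unittest import mock

    m = mock.MagicMock()
    m.get_connection('postgres_default').conn_type = "postgres"
    # Any call to the same mock returns the same child object:
    assert m.get_connection('anything').conn_type == "postgres"
    assert m.get_connection.return_value.conn_type == "postgres"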