Skip to content

Commit

Permalink
Use handler in run_sql (#1773)
Browse files Browse the repository at this point in the history
**Please describe the feature you'd like to see**
Currently, the handler param in the run_sql() method is only used in
Databricks; all other databases are not using it, which leads to
different response types, since the handler is processed within
`run_sql()` for Databricks but not for other databases. Also, the
signature of `run_sql()` in Databricks deviates from the base. We need
to standardize, and when we do, we can remove the check below as well.

**Describe the solution you'd like**
Ideally, we should have the same signature for all the run_sql()
database implementations. This will prevent adding explicit conditions
for Databricks in other parts of the code that should ideally be
database-agnostic.
example:
https://github.com/astronomer/astro-sdk/blob/2f8c8d0ccbff3cc13af78685cbefa39f473991c3/python-sdk/src/astro/sql/operators/raw_sql.py#L47

**Acceptance Criteria**
- [ ] Wherever we use run_sql(), we shouldn't have any
database-specific checks.

closes #1664
  • Loading branch information
rajaths010494 authored and utkarsharma2 committed Feb 21, 2023
1 parent 5df35f7 commit b367003
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 115 deletions.
30 changes: 14 additions & 16 deletions python-sdk/src/astro/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,13 @@
import logging
import warnings
from abc import ABC
from typing import TYPE_CHECKING, Any, Callable, Mapping
from typing import Any, Callable, Mapping

import pandas as pd
import sqlalchemy
from airflow.hooks.dbapi import DbApiHook
from pandas.io.sql import SQLDatabase
from sqlalchemy import column, insert, select

from astro.dataframes.pandas import PandasDataframe

if TYPE_CHECKING: # pragma: no cover
from sqlalchemy.engine.cursor import CursorResult

from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.schema import Table as SqlaTable
Expand All @@ -29,6 +23,7 @@
LoadExistStrategy,
MergeConflictStrategy,
)
from astro.dataframes.pandas import PandasDataframe
from astro.exceptions import DatabaseCustomError, NonExistentTableException
from astro.files import File, resolve_file_path_pattern
from astro.files.types import create_file_type
Expand Down Expand Up @@ -63,8 +58,6 @@ class BaseDatabase(ABC):
# illegal_column_name_chars[0] will be replaced by value in illegal_column_name_chars_replacement[0]
illegal_column_name_chars: list[str] = []
illegal_column_name_chars_replacement: list[str] = []
# In run_raw_sql operator decides if we want to return results directly or process them by handler provided
IGNORE_HANDLER_IN_RUN_RAW_SQL: bool = False
NATIVE_PATHS: dict[Any, Any] = {}
DEFAULT_SCHEMA = SCHEMA
NATIVE_LOAD_EXCEPTIONS: Any = DatabaseCustomError
Expand Down Expand Up @@ -107,8 +100,9 @@ def run_sql(
self,
sql: str | ClauseElement = "",
parameters: dict | None = None,
handler: Callable | None = None,
**kwargs,
) -> CursorResult:
) -> Any:
"""
Return the results to running a SQL statement.
Expand All @@ -118,6 +112,7 @@ def run_sql(
:param sql: Contains SQL query to be run against database
:param parameters: Optional parameters to be used to render the query
:param autocommit: Optional autocommit flag
:param handler: function that takes in a cursor as an argument.
"""
if parameters is None:
parameters = {}
Expand All @@ -139,7 +134,9 @@ def run_sql(
)
else:
result = self.connection.execute(sql, parameters)
return result
if handler:
return handler(result)
return None

def columns_exist(self, table: BaseTable, columns: list[str]) -> bool:
"""
Expand Down Expand Up @@ -407,7 +404,7 @@ def create_schema_and_table_if_needed(
use_native_support=use_native_support,
)

def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> Any:
"""
Fetches all rows for a table and returns as a list. This is needed because some
databases have different cursors that require different methods to fetch rows
Expand All @@ -419,8 +416,8 @@ def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
statement = f"SELECT * FROM {self.get_table_qualified_name(table)}"
if row_limit > -1:
statement = statement + f" LIMIT {row_limit}"
response = self.run_sql(statement)
return response.fetchall() # type: ignore
response: list = self.run_sql(statement, handler=lambda x: x.fetchall())
return response

def load_file_to_table(
self,
Expand Down Expand Up @@ -777,8 +774,9 @@ def row_count(self, table: BaseTable):
:return: The number of rows in the table
"""
result = self.run_sql(
f"select count(*) from {self.get_table_qualified_name(table)}" # skipcq: BAN-B608
).scalar()
f"select count(*) from {self.get_table_qualified_name(table)}", # skipcq: BAN-B608
handler=lambda x: x.scalar(),
)
return result

def parameterize_variable(self, variable: str):
Expand Down
8 changes: 3 additions & 5 deletions python-sdk/src/astro/databases/databricks/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
import warnings
from textwrap import dedent
from typing import Any, Callable

import pandas as pd
from airflow.providers.databricks.hooks.databricks import DatabricksHook
Expand All @@ -25,9 +26,6 @@

class DeltaDatabase(BaseDatabase):
LOAD_OPTIONS_CLASS_NAME = "DeltaLoadOptions"
# In run_raw_sql operator decides if we want to return results directly or process them by handler provided
# For delta tables we ignore the handler
IGNORE_HANDLER_IN_RUN_RAW_SQL: bool = True
_create_table_statement: str = "CREATE TABLE IF NOT EXISTS {} USING DELTA AS {} "

def __init__(self, conn_id: str, table: BaseTable | None = None, load_options: LoadOptions | None = None):
Expand Down Expand Up @@ -197,9 +195,9 @@ def run_sql(
self,
sql: str | ClauseElement = "",
parameters: dict | None = None,
handler=None,
handler: Callable | None = None,
**kwargs,
):
) -> Any:
"""
Run SQL against a delta table using spark SQL.
Expand Down
19 changes: 10 additions & 9 deletions python-sdk/src/astro/databases/mssql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING
from typing import Any, Callable

import pandas as pd
import sqlalchemy
Expand All @@ -17,8 +17,6 @@
from astro.utils.compat.functools import cached_property

DEFAULT_CONN_ID = MsSqlHook.default_conn_name
if TYPE_CHECKING: # pragma: no cover
from sqlalchemy.engine.cursor import CursorResult


class MssqlDatabase(BaseDatabase):
Expand Down Expand Up @@ -145,15 +143,17 @@ def run_sql(
self,
sql: str | ClauseElement = "",
parameters: dict | None = None,
handler: Callable | None = None,
**kwargs,
) -> CursorResult:
) -> Any:
"""
Return the results to running a SQL statement.
Whenever possible, this method should be implemented using Airflow Hooks,
since this will simplify the integration with Async operators.
:param sql: Contains SQL query to be run against database
:param parameters: Optional parameters to be used to render the query
:param handler: function that takes in a cursor as an argument.
"""
if parameters is None:
parameters = {}
Expand All @@ -177,11 +177,12 @@ def run_sql(
result = self.connection.execute(
sqlalchemy.text(sql).execution_options(autocommit=autocommit), parameters
)
return result
else:
# this is used for append
result = self.connection.execute(sql, parameters)
return result
if handler:
return handler(result)
return None

def create_schema_if_needed(self, schema: str | None) -> None:
"""
Expand Down Expand Up @@ -226,7 +227,7 @@ def drop_table(self, table: BaseTable) -> None:
statement = self._drop_table_statement.format(self.get_table_qualified_name(table))
self.run_sql(statement, autocommit=True)

def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> Any:
"""
Fetches all rows for a table and returns as a list. This is needed because some
databases have different cursors that require different methods to fetch rows
Expand All @@ -238,8 +239,8 @@ def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
statement = f"SELECT * FROM {self.get_table_qualified_name(table)}" # skipcq: BAN-B608
if row_limit > -1:
statement = f"SELECT TOP {row_limit} * FROM {self.get_table_qualified_name(table)}"
response = self.run_sql(statement)
return response.fetchall() # type: ignore
response: list = self.run_sql(statement, handler=lambda x: x.fetchall())
return response

def load_pandas_dataframe_to_table(
self,
Expand Down
17 changes: 2 additions & 15 deletions python-sdk/src/astro/sql/operators/raw_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
def execute(self, context: Context) -> Any:
super().execute(context)

self.handler = self.get_handler()
result = self.database_impl.run_sql(sql=self.sql, parameters=self.parameters, handler=self.handler)
if self.response_size == -1 and not settings.IS_CUSTOM_XCOM_BACKEND:
logging.warning(
Expand All @@ -60,22 +61,8 @@ def execute(self, context: Context) -> Any:
"backend."
)

# ToDo: Currently, the handler param in run_sql() method is only used in databricks all other databases are
# not using it. Which leads to different response types since handler is processed within `run_sql()` for
# databricks and not for other databases. Also the signature of `run_sql()` in databricks deviates from base.
# We need to standardise and when we do, we can remove below check as well.
if self.database_impl.IGNORE_HANDLER_IN_RUN_RAW_SQL:
return result

self.handler = self.get_handler()

if self.handler:
self.handler = self.get_wrapped_handler(
fail_on_empty=self.fail_on_empty, conversion_func=self.handler
)
# otherwise, call the handler and convert the result to a list
response = self.handler(result)
response = self.make_row_serializable(response)
response = self.make_row_serializable(result)
if 0 <= self.response_limit < len(response):
raise IllegalLoadToDatabaseException() # pragma: no cover
if self.response_size >= 0:
Expand Down
19 changes: 8 additions & 11 deletions python-sdk/tests/sql/operators/test_run_raw_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def test_make_row_serializable(rows):


@mock.patch("astro.sql.operators.raw_sql.RawSQLOperator.results_as_list")
@mock.patch("astro.databases.base.BaseDatabase.run_sql")
@mock.patch("astro.databases.base.BaseDatabase.connection")
def test_run_sql_calls_list_handler(run_sql, results_as_list, sample_dag):
results_as_list.return_value = []
run_sql.return_value = []
run_sql.execute.return_value = []
with sample_dag:

@aql.run_raw_sql(results_format="list", conn_id="sqlite_default")
Expand All @@ -40,10 +40,10 @@ def dummy_method():


@mock.patch("astro.sql.operators.raw_sql.RawSQLOperator.results_as_pandas_dataframe")
@mock.patch("astro.databases.base.BaseDatabase.run_sql")
@mock.patch("astro.databases.base.BaseDatabase.connection")
def test_run_sql_calls_pandas_dataframe_handler(run_sql, results_as_pandas_dataframe, sample_dag):
results_as_pandas_dataframe.return_value = []
run_sql.return_value = []
run_sql.execute.return_value = []
with sample_dag:

@aql.run_raw_sql(results_format="pandas_dataframe", conn_id="sqlite_default")
Expand All @@ -57,13 +57,13 @@ def dummy_method():


@mock.patch("astro.sql.operators.raw_sql.RawSQLOperator.results_as_pandas_dataframe")
@mock.patch("astro.databases.base.BaseDatabase.run_sql")
@mock.patch("astro.databases.base.BaseDatabase.connection")
def test_run_sql_gives_priority_to_pandas_dataframe_handler(run_sql, results_as_pandas_dataframe, sample_dag):
"""
Test that run_sql calls `results_format` specified handler over handler passed in decorator.
"""
results_as_pandas_dataframe.return_value = []
run_sql.return_value = []
run_sql.execute.return_value = []
with sample_dag:

@aql.run_raw_sql(
Expand Down Expand Up @@ -103,14 +103,11 @@ def dummy_method():


@mock.patch("astro.sql.operators.raw_sql.RawSQLOperator.results_as_pandas_dataframe")
@mock.patch("astro.databases.base.BaseDatabase.run_sql")
def test_run_sql_should_raise_exception(run_sql, results_as_pandas_dataframe, sample_dag):
def test_run_sql_should_raise_exception(results_as_pandas_dataframe, sample_dag):
"""
Test that run_sql should raise an exception when fail_on_empty=False
"""
results_as_pandas_dataframe.return_value = []
return_value = [1, 2, 3]
run_sql.return_value = return_value

def raise_exception(result):
raise ValueError("dummy exception")
Expand Down Expand Up @@ -156,7 +153,7 @@ def test_handlers():

class Val:
def __init__(self, val):
self.value = [val]
self.value: list = [val]

def values(self) -> list:
return self.value
Expand Down
16 changes: 8 additions & 8 deletions python-sdk/tests_integration/databases/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def test_bigquery_run_sql():
"""Test run_sql against bigquery database"""
statement = "SELECT 1 + 1;"
database = BigqueryDatabase(conn_id=DEFAULT_CONN_ID)
response = database.run_sql(statement)
assert response.first()[0] == 2
response = database.run_sql(statement, handler=lambda x: x.first())
assert response[0] == 2


@pytest.mark.integration
Expand Down Expand Up @@ -77,12 +77,12 @@ def test_bigquery_create_table_with_columns(database_table_fixture):
f"SELECT TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE "
f"FROM {table.metadata.schema}.INFORMATION_SCHEMA.COLUMNS WHERE table_name='{table.name}'"
)
response = database.run_sql(statement)
assert response.first() is None
response = database.run_sql(statement, handler=lambda x: x.first())
assert response is None

database.create_table(table)
response = database.run_sql(statement)
rows = response.fetchall()
response = database.run_sql(statement, handler=lambda x: x.fetchall())
rows = response
assert len(rows) == 2
assert rows[0] == (
"astronomer-dag-authoring",
Expand Down Expand Up @@ -121,9 +121,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture):
database.load_pandas_dataframe_to_table(pandas_dataframe, table)

statement = f"SELECT * FROM {database.get_table_qualified_name(table)};"
response = database.run_sql(statement)
response = database.run_sql(statement, handler=lambda x: x.fetchall())

rows = response.fetchall()
rows = response
assert len(rows) == 2
assert rows[0] == (1,)
assert rows[1] == (2,)
Expand Down
16 changes: 8 additions & 8 deletions python-sdk/tests_integration/databases/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def test_mssql_run_sql():
"""Test run_sql against mssql database"""
statement = "SELECT 1 + 1;"
database = MssqlDatabase(conn_id=CUSTOM_CONN_ID)
response = database.run_sql(statement)
assert response.first()[0] == 2
response = database.run_sql(statement, handler=lambda x: x.first())
assert response[0] == 2


@pytest.mark.integration
Expand Down Expand Up @@ -88,12 +88,12 @@ def test_mssql_create_table_with_columns(database_table_fixture):
database, table = database_table_fixture

statement = f"SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name='{table.name}'"
response = database.run_sql(statement)
assert response.first() is None
response = database.run_sql(statement, handler=lambda x: x.first())
assert response is None

database.create_table(table)
response = database.run_sql(statement)
rows = response.fetchall()
response = database.run_sql(statement, handler=lambda x: x.fetchall())
rows = response
assert len(rows) == 2
assert rows[0][0:4] == (
"astrodb",
Expand Down Expand Up @@ -126,9 +126,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture):
database.load_pandas_dataframe_to_table(pandas_dataframe, table)

statement = f"SELECT * FROM {database.get_table_qualified_name(table)};"
response = database.run_sql(statement)
response = database.run_sql(statement, handler=lambda x: x.fetchall())

rows = response.fetchall()
rows = response
assert len(rows) == 2
assert rows[0] == (1,)
assert rows[1] == (2,)
Loading

0 comments on commit b367003

Please sign in to comment.