Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use handler in run_sql #1773

Merged
merged 6 commits into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions python-sdk/src/astro/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from astro.dataframes.pandas import PandasDataframe

if TYPE_CHECKING: # pragma: no cover
from sqlalchemy.engine.cursor import CursorResult
pass

from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.elements import ColumnClause
Expand Down Expand Up @@ -63,8 +63,6 @@ class BaseDatabase(ABC):
# illegal_column_name_chars[0] will be replaced by value in illegal_column_name_chars_replacement[0]
illegal_column_name_chars: list[str] = []
illegal_column_name_chars_replacement: list[str] = []
# In run_raw_sql, the operator decides whether to return results directly or process them with the provided handler
IGNORE_HANDLER_IN_RUN_RAW_SQL: bool = False
NATIVE_PATHS: dict[Any, Any] = {}
DEFAULT_SCHEMA = SCHEMA
NATIVE_LOAD_EXCEPTIONS: Any = DatabaseCustomError
Expand Down Expand Up @@ -107,8 +105,9 @@ def run_sql(
self,
sql: str | ClauseElement = "",
parameters: dict | None = None,
handler: Callable | None = None,
**kwargs,
) -> CursorResult:
) -> Any:
"""
Return the results of running a SQL statement.

Expand Down Expand Up @@ -139,7 +138,9 @@ def run_sql(
)
else:
result = self.connection.execute(sql, parameters)
return result
if handler:
return handler(result)
return None

def columns_exist(self, table: BaseTable, columns: list[str]) -> bool:
"""
Expand Down Expand Up @@ -419,8 +420,8 @@ def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
statement = f"SELECT * FROM {self.get_table_qualified_name(table)}"
if row_limit > -1:
statement = statement + f" LIMIT {row_limit}"
response = self.run_sql(statement)
return response.fetchall() # type: ignore
response = self.run_sql(statement, handler=lambda x: x.fetchall()) # type: ignore
return response

def load_file_to_table(
self,
Expand Down Expand Up @@ -777,8 +778,9 @@ def row_count(self, table: BaseTable):
:return: The number of rows in the table
"""
result = self.run_sql(
f"select count(*) from {self.get_table_qualified_name(table)}" # skipcq: BAN-B608
).scalar()
f"select count(*) from {self.get_table_qualified_name(table)}", # skipcq: BAN-B608
handler=lambda x: x.scalar(),
)
return result

def parameterize_variable(self, variable: str):
Expand Down
8 changes: 3 additions & 5 deletions python-sdk/src/astro/databases/databricks/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
import warnings
from textwrap import dedent
from typing import Any, Callable

import pandas as pd
from airflow.providers.databricks.hooks.databricks import DatabricksHook
Expand All @@ -25,9 +26,6 @@

class DeltaDatabase(BaseDatabase):
LOAD_OPTIONS_CLASS_NAME = "DeltaLoadOptions"
# In run_raw_sql operator decides if we want to return results directly or process them by handler provided
# For delta tables we ignore the handler
IGNORE_HANDLER_IN_RUN_RAW_SQL: bool = True
_create_table_statement: str = "CREATE TABLE IF NOT EXISTS {} USING DELTA AS {} "

def __init__(self, conn_id: str, table: BaseTable | None = None, load_options: LoadOptions | None = None):
Expand Down Expand Up @@ -197,9 +195,9 @@ def run_sql(
self,
sql: str | ClauseElement = "",
parameters: dict | None = None,
handler=None,
handler: Callable | None = None,
**kwargs,
):
) -> Any:
"""
Run SQL against a delta table using spark SQL.

Expand Down
16 changes: 9 additions & 7 deletions python-sdk/src/astro/databases/mssql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Callable

import pandas as pd
import sqlalchemy
Expand All @@ -18,7 +18,7 @@

DEFAULT_CONN_ID = MsSqlHook.default_conn_name
if TYPE_CHECKING: # pragma: no cover
from sqlalchemy.engine.cursor import CursorResult
pass


class MssqlDatabase(BaseDatabase):
Expand Down Expand Up @@ -145,8 +145,9 @@ def run_sql(
self,
sql: str | ClauseElement = "",
parameters: dict | None = None,
handler: Callable | None = None,
**kwargs,
) -> CursorResult:
) -> Any:
"""
Return the results of running a SQL statement.
Whenever possible, this method should be implemented using Airflow Hooks,
Expand Down Expand Up @@ -177,11 +178,12 @@ def run_sql(
result = self.connection.execute(
sqlalchemy.text(sql).execution_options(autocommit=autocommit), parameters
)
return result
else:
# this is used for append
result = self.connection.execute(sql, parameters)
return result
if handler:
return handler(result)
return None

def create_schema_if_needed(self, schema: str | None) -> None:
"""
Expand Down Expand Up @@ -238,8 +240,8 @@ def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
statement = f"SELECT * FROM {self.get_table_qualified_name(table)}" # skipcq: BAN-B608
if row_limit > -1:
statement = f"SELECT TOP {row_limit} * FROM {self.get_table_qualified_name(table)}"
response = self.run_sql(statement)
return response.fetchall() # type: ignore
response = self.run_sql(statement, handler=lambda x: x.fetchall()) # type: ignore
return response

def load_pandas_dataframe_to_table(
self,
Expand Down
17 changes: 2 additions & 15 deletions python-sdk/src/astro/sql/operators/raw_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
def execute(self, context: Context) -> Any:
super().execute(context)

self.handler = self.get_handler()
result = self.database_impl.run_sql(sql=self.sql, parameters=self.parameters, handler=self.handler)
if self.response_size == -1 and not settings.IS_CUSTOM_XCOM_BACKEND:
logging.warning(
Expand All @@ -60,22 +61,8 @@ def execute(self, context: Context) -> Any:
"backend."
)

# ToDo: Currently, the handler param in the run_sql() method is only used in Databricks; all other databases
# do not use it. This leads to different response types, since the handler is processed within `run_sql()` for
# Databricks but not for other databases. The signature of `run_sql()` in Databricks also deviates from the base.
# We need to standardise this, and when we do, we can remove the check below as well.
if self.database_impl.IGNORE_HANDLER_IN_RUN_RAW_SQL:
return result

self.handler = self.get_handler()

if self.handler:
self.handler = self.get_wrapped_handler(
fail_on_empty=self.fail_on_empty, conversion_func=self.handler
)
# otherwise, call the handler and convert the result to a list
response = self.handler(result)
response = self.make_row_serializable(response)
response = self.make_row_serializable(result)
if 0 <= self.response_limit < len(response):
raise IllegalLoadToDatabaseException() # pragma: no cover
if self.response_size >= 0:
Expand Down
16 changes: 8 additions & 8 deletions python-sdk/tests_integration/databases/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def test_bigquery_run_sql():
"""Test run_sql against bigquery database"""
statement = "SELECT 1 + 1;"
database = BigqueryDatabase(conn_id=DEFAULT_CONN_ID)
response = database.run_sql(statement)
assert response.first()[0] == 2
response = database.run_sql(statement, handler=lambda x: x.first())
assert response[0] == 2


@pytest.mark.integration
Expand Down Expand Up @@ -77,12 +77,12 @@ def test_bigquery_create_table_with_columns(database_table_fixture):
f"SELECT TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE "
f"FROM {table.metadata.schema}.INFORMATION_SCHEMA.COLUMNS WHERE table_name='{table.name}'"
)
response = database.run_sql(statement)
assert response.first() is None
response = database.run_sql(statement, handler=lambda x: x.first())
assert response is None

database.create_table(table)
response = database.run_sql(statement)
rows = response.fetchall()
response = database.run_sql(statement, handler=lambda x: x.fetchall())
rows = response
assert len(rows) == 2
assert rows[0] == (
"astronomer-dag-authoring",
Expand Down Expand Up @@ -121,9 +121,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture):
database.load_pandas_dataframe_to_table(pandas_dataframe, table)

statement = f"SELECT * FROM {database.get_table_qualified_name(table)};"
response = database.run_sql(statement)
response = database.run_sql(statement, handler=lambda x: x.fetchall())

rows = response.fetchall()
rows = response
assert len(rows) == 2
assert rows[0] == (1,)
assert rows[1] == (2,)
Expand Down
16 changes: 8 additions & 8 deletions python-sdk/tests_integration/databases/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def test_mssql_run_sql():
"""Test run_sql against mssql database"""
statement = "SELECT 1 + 1;"
database = MssqlDatabase(conn_id=CUSTOM_CONN_ID)
response = database.run_sql(statement)
assert response.first()[0] == 2
response = database.run_sql(statement, handler=lambda x: x.first())
assert response[0] == 2


@pytest.mark.integration
Expand Down Expand Up @@ -88,12 +88,12 @@ def test_mssql_create_table_with_columns(database_table_fixture):
database, table = database_table_fixture

statement = f"SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name='{table.name}'"
response = database.run_sql(statement)
assert response.first() is None
response = database.run_sql(statement, handler=lambda x: x.first())
assert response is None

database.create_table(table)
response = database.run_sql(statement)
rows = response.fetchall()
response = database.run_sql(statement, handler=lambda x: x.fetchall())
rows = response
assert len(rows) == 2
assert rows[0][0:4] == (
"astrodb",
Expand Down Expand Up @@ -126,9 +126,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture):
database.load_pandas_dataframe_to_table(pandas_dataframe, table)

statement = f"SELECT * FROM {database.get_table_qualified_name(table)};"
response = database.run_sql(statement)
response = database.run_sql(statement, handler=lambda x: x.fetchall())

rows = response.fetchall()
rows = response
assert len(rows) == 2
assert rows[0] == (1,)
assert rows[1] == (2,)
16 changes: 8 additions & 8 deletions python-sdk/tests_integration/databases/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def test_postgres_run_sql():
"""Test run_sql against postgres database"""
statement = "SELECT 1 + 1;"
database = PostgresDatabase(conn_id=CUSTOM_CONN_ID)
response = database.run_sql(statement)
assert response.first()[0] == 2
response = database.run_sql(statement, handler=lambda x: x.first())
assert response[0] == 2


@pytest.mark.integration
Expand Down Expand Up @@ -90,12 +90,12 @@ def test_postgres_create_table_with_columns(database_table_fixture):
database, table = database_table_fixture

statement = f"SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name='{table.name}'"
response = database.run_sql(statement)
assert response.first() is None
response = database.run_sql(statement, handler=lambda x: x.first())
assert response is None

database.create_table(table)
response = database.run_sql(statement)
rows = response.fetchall()
response = database.run_sql(statement, handler=lambda x: x.fetchall())
rows = response
assert len(rows) == 2
assert rows[0][0:4] == (
"postgres",
Expand Down Expand Up @@ -128,9 +128,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture):
database.load_pandas_dataframe_to_table(pandas_dataframe, table)

statement = f"SELECT * FROM {database.get_table_qualified_name(table)};"
response = database.run_sql(statement)
response = database.run_sql(statement, handler=lambda x: x.fetchall())

rows = response.fetchall()
rows = response
assert len(rows) == 2
assert rows[0] == (1,)
assert rows[1] == (2,)
Expand Down
16 changes: 8 additions & 8 deletions python-sdk/tests_integration/databases/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def test_redshift_run_sql():
"""Test run_sql against redshift database"""
statement = "SELECT 1 + 1;"
database = RedshiftDatabase(conn_id=CUSTOM_CONN_ID)
response = database.run_sql(statement)
assert response.first()[0] == 2
response = database.run_sql(statement, handler=lambda x: x.first())
assert response[0] == 2


@pytest.mark.integration
Expand Down Expand Up @@ -73,12 +73,12 @@ def test_redshift_create_table_with_columns(database_table_fixture):
database, table = database_table_fixture

statement = f"SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name='{table.name}'"
response = database.run_sql(statement)
assert response.first() is None
response = database.run_sql(statement, handler=lambda x: x.first())
assert response is None

database.create_table(table)
response = database.run_sql(statement)
rows = response.fetchall()
response = database.run_sql(statement, handler=lambda x: x.fetchall())
rows = response
assert len(rows) == 2
assert rows[0][0:4] == (
"dev",
Expand Down Expand Up @@ -111,9 +111,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture):
database.load_pandas_dataframe_to_table(pandas_dataframe, table)

statement = f"SELECT * FROM {database.get_table_qualified_name(table)};"
response = database.run_sql(statement)
response = database.run_sql(statement, handler=lambda x: x.fetchall())

rows = response.fetchall()
rows = response
assert len(rows) == 2
assert rows[0] == (1,)
assert rows[1] == (2,)
Expand Down
18 changes: 9 additions & 9 deletions python-sdk/tests_integration/databases/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def test_snowflake_run_sql():
"""Test run_sql against snowflake database"""
statement = "SELECT 1 + 1;"
database = SnowflakeDatabase(conn_id=CUSTOM_CONN_ID)
response = database.run_sql(statement)
assert response.first()[0] == 2
response = database.run_sql(statement, handler=lambda x: x.first())
assert response[0] == 2


@pytest.mark.integration
Expand Down Expand Up @@ -79,8 +79,8 @@ def test_snowflake_create_table_with_columns(database_table_fixture):
assert e.match("does not exist or not authorized")

database.create_table(table)
response = database.run_sql(statement)
rows = response.fetchall()
response = database.run_sql(statement, handler=lambda x: x.fetchall())
rows = response
assert len(rows) == 2
assert rows[0] == (
"ID",
Expand Down Expand Up @@ -137,11 +137,11 @@ def test_snowflake_create_table_using_native_schema_autodetection(

file = File("s3://astro-sdk/sample.parquet", conn_id="aws_conn")
database.create_table(table, file)
response = database.run_sql(statement)
rows = response.fetchall()
response = database.run_sql(statement, handler=lambda x: x.fetchall())
rows = response
assert len(rows) == 2
statement = f"SELECT COUNT(*) FROM {database.get_table_qualified_name(table)}"
count = database.run_sql(statement).scalar()
count = database.run_sql(statement, handler=lambda x: x.scalar())
assert count == 0


Expand All @@ -165,9 +165,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture):
database.load_pandas_dataframe_to_table(pandas_dataframe, table)

statement = f"SELECT * FROM {database.get_table_qualified_name(table)}"
response = database.run_sql(statement)
response = database.run_sql(statement, handler=lambda x: x.fetchall())

rows = response.fetchall()
rows = response
assert len(rows) == 2
assert rows[0] == (1,)
assert rows[1] == (2,)
Expand Down
Loading