Refactor assume_schema_exists and expose it in aql.transform (#1925)
This is a follow-up to #1922. In that PR we allowed users to skip the schema check and creation for `aql.load_file`, but we missed that `aql.transform` and `aql.transform_file` had the same issue. This PR addresses that limitation.

Changes included in this PR:
* Rename the config `load_table_schema_exists` to `assume_schema_exists`
* Rename the `load_file` argument `schema_exists` to `assume_schema_exists`
* Refactor where the `assume_schema_exists` check happens. Previously it ran only inside `load_file_to_table`; now it is part of `create_schema_if_applicable`, which makes the feature available to the `aql.transform` task as well
* Rename `Database.create_schema_if_needed` to `Database.create_schema_if_applicable`
* Expose `assume_schema_exists` in `aql.transform` (see the usage sketch after the file summary below)
* Release 1.7.0a2
tatiana authored May 5, 2023
1 parent 829f518 commit 21e8d47
Showing 20 changed files with 111 additions and 47 deletions.
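For context, a minimal sketch of how the renamed argument is used after this change. It is illustrative only: the connection ID, file path, schema name, and table names are placeholders, and the calls are assumed to run inside a DAG definition.

from astro import sql as aql
from astro.files import File
from astro.table import Metadata, Table

# Skip the (potentially costly) schema existence check and creation on load.
movies = aql.load_file(
    input_file=File("s3://my-bucket/movies.csv"),  # hypothetical path
    output_table=Table(
        conn_id="snowflake_default",  # hypothetical connection
        metadata=Metadata(schema="EXISTING_SCHEMA"),  # schema known to exist
    ),
    assume_schema_exists=True,
)

# The same opt-out is now available on transform tasks.
@aql.transform(assume_schema_exists=True)
def top_movies(movies: Table):
    return "SELECT * FROM {{ movies }} ORDER BY rating DESC LIMIT 10"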
2 changes: 1 addition & 1 deletion python-sdk/conftest.py
@@ -82,7 +82,7 @@ def database_table_fixture(request):
file = params.get("file")

database.populate_table_metadata(table)
-database.create_schema_if_needed(table.metadata.schema)
+database.create_schema_if_applicable(table.metadata.schema)

if file:
database.load_file_to_table(file, table)
3 changes: 2 additions & 1 deletion python-sdk/docs/CHANGELOG.md
@@ -1,8 +1,9 @@
# Changelog

-## 1.7.0a1
+## 1.7.0a2

### Feature
+- Allow users to disable schema check and creation on `transform` [#1925](https://github.com/astronomer/astro-sdk/pull/1925)
- Allow users to disable schema check and creation on `load_file` [#1922](https://github.com/astronomer/astro-sdk/pull/1922)

## 1.6.0
2 changes: 1 addition & 1 deletion python-sdk/docs/astro/sql/operators/load_file.rst
@@ -43,7 +43,7 @@ Parameters to use when loading a file to a database table

Note that if you use ``if_exists='replace'``, the existing table will be dropped and the schema of the new data will be used.

-#. **schema_exists** (default is False) - By default, the SDK checks if the schema of the target table exists, and if not, it tries to create it. This query can be costly. This argument makes the SDK skip this check, since the user is informing the schema already exists.
+#. **assume_schema_exists** (default is False) - By default, the SDK checks if the schema of the target table exists, and if not, it tries to create it. This query can be costly. This argument makes the SDK skip this check, since the user is informing the schema already exists.

#. **output_table** - This parameter defines the output table to load data to, which should be an instance of ``astro.sql.table.Table``. You can specify the schema of the table by providing a list of the instance of ``sqlalchemy.Column <https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column>`` to the ``columns`` parameter. If you don't specify a schema, it will be inferred using Pandas.

10 changes: 5 additions & 5 deletions python-sdk/docs/configurations.rst
@@ -48,24 +48,24 @@ or by updating Airflow's configuration
Configuring if schemas existence should be checked and if the SDK should create them
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-By default, during ``aql.load_file``, the SDK checks if the schema of the target table exists, and if not, it tries to create it. This type of check can be costly.
+By default, during ``aql.load_file`` and ``aql.transform``, the SDK checks if the schema of the target table exists, and if not, it tries to create it. This type of check can be costly.

-The configuration ``AIRFLOW__ASTRO_SDK__LOAD_TABLE_SCHEMA_EXISTS`` allows users to inform the SDK that the schema already exists, skipping this check for all ``load_file`` tasks.
+The configuration ``AIRFLOW__ASTRO_SDK__ASSUME_SCHEMA_EXISTS`` allows users to inform the SDK that the schema already exists, skipping this check for all ``load_file`` and ``transform`` tasks.

-The user can also have a more granular control, by defining the ``load_file`` argument ``schema_exists`` on a per-task basis :ref:load_file.
+The user can also have a more granular control, by defining the ``load_file`` argument ``assume_schema_exists`` on a per-task basis :ref:load_file.

Example of how to disable schema existence check using environment variables:

.. code:: ini
-AIRFLOW__ASTRO_SDK__LOAD_TABLE_SCHEMA_EXISTS = True
+AIRFLOW__ASTRO_SDK__ASSUME_SCHEMA_EXISTS = True
Or using Airflow's configuration file:

.. code:: ini
[astro_sdk]
-load_table_schema_exists = True
+assume_schema_exists = True
Configuring the unsafe dataframe storage
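Aside: a minimal sketch of the per-task override described in the configuration notes above. Even with AIRFLOW__ASTRO_SDK__ASSUME_SCHEMA_EXISTS=True set globally, a single task can opt back in to the check by passing the argument explicitly (the path and connection ID are placeholders).

from astro import sql as aql
from astro.files import File
from astro.table import Table

strict_load = aql.load_file(
    input_file=File("gs://my-bucket/data.csv"),  # hypothetical path
    output_table=Table(conn_id="bigquery_default"),  # hypothetical connection
    assume_schema_exists=False,  # explicit value beats the global config for this task
)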
@@ -13,13 +13,13 @@
from astro.table import Metadata, Table


-@aql.transform()
+@aql.transform(assume_schema_exists=True)
def combine_data(center_1: Table, center_2: Table):
return """SELECT * FROM {{center_1}}
UNION SELECT * FROM {{center_2}}"""


-@aql.transform()
+@aql.transform(assume_schema_exists=True)
def clean_data(input_table: Table):
return """SELECT *
FROM {{input_table}} WHERE type NOT LIKE 'Guinea Pig'
@@ -85,7 +85,7 @@ def example_snowflake_partial_table_with_append():
schema=os.getenv("SNOWFLAKE_SCHEMA"),
),
),
-schema_exists=True, # Skip queries that check if the table schema exists
+assume_schema_exists=True, # Skip queries that check if the table schema exists
)

homes_data2 = load_file(
Expand All @@ -97,7 +97,7 @@ def example_snowflake_partial_table_with_append():
schema=os.getenv("SNOWFLAKE_SCHEMA"),
),
),
-schema_exists=True,
+assume_schema_exists=True,
)

# Define task dependencies
2 changes: 1 addition & 1 deletion python-sdk/src/astro/__init__.py
@@ -1,6 +1,6 @@
"""A decorator that allows users to run SQL queries natively in Airflow."""

__version__ = "1.7.0a1"
__version__ = "1.7.0a2"


# This is needed to allow Airflow to pick up specific metadata fields it needs
16 changes: 9 additions & 7 deletions python-sdk/src/astro/databases/base.py
@@ -31,9 +31,9 @@
from astro.options import LoadOptions
from astro.query_modifier import QueryModifier
from astro.settings import (
+ASSUME_SCHEMA_EXISTS,
LOAD_FILE_ENABLE_NATIVE_FALLBACK,
LOAD_TABLE_AUTODETECT_ROWS_COUNT,
-LOAD_TABLE_SCHEMA_EXISTS,
SCHEMA,
)
from astro.table import BaseTable, Metadata
@@ -453,7 +453,7 @@ def load_file_to_table(
native_support_kwargs: dict | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: bool | None = LOAD_FILE_ENABLE_NATIVE_FALLBACK,
-schema_exists: bool = LOAD_TABLE_SCHEMA_EXISTS,
+assume_schema_exists: bool = ASSUME_SCHEMA_EXISTS,
**kwargs,
):
"""
Expand All @@ -470,7 +470,7 @@ def load_file_to_table(
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer
-:param schema_exists: Declare the table schema already exists and that load_file should not check if it exists
+:param assume_schema_exists: If True, do not check if the output table schema exists or attempt to create it
"""
normalize_config = normalize_config or {}
if self.check_for_minio_connection(input_file=input_file):
@@ -480,8 +480,7 @@
)
use_native_support = False

-if not schema_exists:
-    self.create_schema_if_needed(output_table.metadata.schema)
+self.create_schema_if_applicable(output_table.metadata.schema, assume_schema_exists)

self.create_table_if_needed(
file=input_file,
@@ -745,16 +744,19 @@ def export_table_to_file(
# Schema Management
# ---------------------------------------------------------

-def create_schema_if_needed(self, schema: str | None) -> None:
+def create_schema_if_applicable(
+    self, schema: str | None, assume_exists: bool = ASSUME_SCHEMA_EXISTS
+) -> None:
"""
This function checks if the expected schema exists in the database. If the schema does not exist,
it will attempt to create it.
:param schema: DB Schema - a namespace that contains named objects like (tables, functions, etc)
+:param assume_exists: If assume exists is True, does not check or attempt to create the schema
"""
# We check if the schema exists first because snowflake will fail on a create schema query even if it
# doesn't actually create a schema.
-if schema and not self.schema_exists(schema):
+if not assume_exists and schema and not self.schema_exists(schema):
statement = self._create_schema_statement.format(schema)
self.run_sql(statement)

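To make the new short-circuit concrete, here is a standalone sketch (not part of the PR) that exercises the base implementation above with a mocked database. With assume_exists=True, neither the existence query nor the CREATE SCHEMA statement should be issued.

from unittest import mock

from astro.databases.base import BaseDatabase

def test_assume_exists_short_circuits():
    db = mock.MagicMock(spec=BaseDatabase)
    # Invoke the real base implementation with the mock standing in for self.
    BaseDatabase.create_schema_if_applicable(db, "some_schema", assume_exists=True)
    db.schema_exists.assert_not_called()  # existence query skipped
    db.run_sql.assert_not_called()        # CREATE SCHEMA skipped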
10 changes: 6 additions & 4 deletions python-sdk/src/astro/databases/databricks/delta.py
@@ -22,7 +22,7 @@
from astro.files import File
from astro.options import LoadOptions
from astro.query_modifier import QueryModifier
-from astro.settings import LOAD_TABLE_SCHEMA_EXISTS
+from astro.settings import ASSUME_SCHEMA_EXISTS
from astro.table import BaseTable, Metadata


@@ -95,7 +95,9 @@ def schema_exists(self, schema: str) -> bool:
# Schemas do not need to be created for delta, so we can assume this is true
return True

-def create_schema_if_needed(self, schema: str | None) -> None:  # skipcq: PYL-W0613
+def create_schema_if_applicable(
+    self, schema: str | None, assume_exists: bool = ASSUME_SCHEMA_EXISTS
+) -> None:  # skipcq: PYL-W0613
# Schemas do not need to be created for delta, so we don't need to do anything here
return None

@@ -124,7 +126,7 @@ def load_file_to_table(
native_support_kwargs: dict | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: bool | None = None,
-schema_exists: bool = LOAD_TABLE_SCHEMA_EXISTS,
+assume_schema_exists: bool = ASSUME_SCHEMA_EXISTS,
databricks_job_name: str = "",
**kwargs,
):
@@ -144,7 +146,7 @@
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer
-:param schema_exists: Declare the table schema already exists and that load_file should not check if it exists
+:param assume_schema_exists: If True, skips check to see if output_table schema exists
"""
load_file_to_delta(
input_file=input_file,
2 changes: 1 addition & 1 deletion python-sdk/src/astro/databases/postgres.py
@@ -103,7 +103,7 @@ def load_pandas_dataframe_to_table(
"""
self._assert_not_empty_df(source_dataframe)

-self.create_schema_if_needed(target_table.metadata.schema)
+self.create_schema_if_applicable(target_table.metadata.schema)
if not self.table_exists(table=target_table) or if_exists == "replace":
self.create_table(table=target_table, dataframe=source_dataframe)

6 changes: 3 additions & 3 deletions python-sdk/src/astro/settings.py
@@ -75,15 +75,15 @@
# Should Astro SDK automatically add inlets/outlets to take advantage of Airflow 2.4 Data-aware scheduling
AUTO_ADD_INLETS_OUTLETS = conf.getboolean(SECTION_KEY, "auto_add_inlets_outlets", fallback=True)

-LOAD_TABLE_SCHEMA_EXISTS = False
+ASSUME_SCHEMA_EXISTS = False


def reload():
"""
Reload settings from environment variable during runtime.
"""
-global LOAD_TABLE_SCHEMA_EXISTS  # skipcq: PYL-W0603
-LOAD_TABLE_SCHEMA_EXISTS = conf.getboolean(SECTION_KEY, "load_table_schema_exists", fallback=False)
+global ASSUME_SCHEMA_EXISTS  # skipcq: PYL-W0603
+ASSUME_SCHEMA_EXISTS = conf.getboolean(SECTION_KEY, "assume_schema_exists", fallback=False)


reload()
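A sketch of how the renamed setting is picked up at runtime, relying on Airflow's standard AIRFLOW__SECTION__KEY environment-variable override:

import os

from astro import settings

os.environ["AIRFLOW__ASTRO_SDK__ASSUME_SCHEMA_EXISTS"] = "True"
settings.reload()  # re-reads [astro_sdk] assume_schema_exists from the config/env
assert settings.ASSUME_SCHEMA_EXISTS is True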
8 changes: 4 additions & 4 deletions python-sdk/src/astro/sql/operators/load_file.py
@@ -47,7 +47,7 @@ class LoadFileOperator(AstroSQLBaseOperator):
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer
-:param schema_exists: Declare the table schema already exists and that load_file should not check if it exists
+:param assume_schema_exists: If True, skips check to see if output_table schema exists
:return: If ``output_table`` is passed this operator returns a Table object. If not
passed, returns a dataframe.
@@ -67,7 +67,7 @@ def __init__(
load_options: LoadOptions | list[LoadOptions] | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: bool | None = settings.LOAD_FILE_ENABLE_NATIVE_FALLBACK,
-schema_exists: bool = settings.LOAD_TABLE_SCHEMA_EXISTS,
+assume_schema_exists: bool = settings.ASSUME_SCHEMA_EXISTS,
**kwargs,
) -> None:
kwargs.setdefault("task_id", get_unique_task_id("load_file"))
@@ -114,7 +114,7 @@ def __init__(
self.native_support_kwargs: dict[str, Any] = native_support_kwargs or {}
self.columns_names_capitalization = columns_names_capitalization
self.enable_native_fallback = enable_native_fallback
-self.schema_exists = schema_exists
+self.assume_schema_exists = assume_schema_exists
self.load_options_list = LoadOptionsList(load_options)

def execute(self, context: Context) -> BaseTable | File: # skipcq: PYL-W0613
@@ -162,7 +162,7 @@ def load_data_to_table(self, input_file: File, context: Context) -> BaseTable:
native_support_kwargs=self.native_support_kwargs,
columns_names_capitalization=self.columns_names_capitalization,
enable_native_fallback=self.enable_native_fallback,
-schema_exists=self.schema_exists,
+assume_schema_exists=self.assume_schema_exists,
databricks_job_name=f"Load data {self.dag_id}_{self.task_id}",
)
self.log.info("Completed loading the data into %s.", self.output_table)
13 changes: 12 additions & 1 deletion python-sdk/src/astro/sql/operators/transform.py
@@ -12,6 +12,7 @@
from airflow.models.xcom_arg import XComArg
from sqlalchemy.sql.functions import Function

+from astro.settings import ASSUME_SCHEMA_EXISTS
from astro.sql.operators.base_decorator import BaseSQLDecoratedOperator
from astro.utils.compat.typing import Context

@@ -33,9 +34,11 @@ def __init__(
response_size: int = -1,
sql: str = "",
task_id: str = "",
+assume_schema_exists: bool = ASSUME_SCHEMA_EXISTS,
**kwargs: Any,
):
task_id = task_id or get_unique_task_id("transform")
+self.assume_schema_exists = assume_schema_exists
super().__init__(
conn_id=conn_id,
parameters=parameters,
@@ -51,7 +54,9 @@

def execute(self, context: Context):
super().execute(context)
-self.database_impl.create_schema_if_needed(self.output_table.metadata.schema)
+self.database_impl.create_schema_if_applicable(
+    self.output_table.metadata.schema, self.assume_schema_exists
+)
self.database_impl.drop_table(self.output_table)
self.database_impl.create_table_from_select_statement(
statement=self.sql,
@@ -73,6 +78,7 @@ def transform(
parameters: Mapping | Iterable | None = None,
database: str | None = None,
schema: str | None = None,
+assume_schema_exists: bool = ASSUME_SCHEMA_EXISTS,
**kwargs: Any,
) -> TaskDecorator:
"""
Expand Down Expand Up @@ -111,6 +117,7 @@ def my_sql_statement(table1: Table, table2: Table, execution_date) -> Table:
table.metadata.database in the first Table passed to the function (required if there are no table arguments)
:param schema: Schema within the SQL instance you want to access. If left blank we will default to the
table.metadata.schema in the first Table passed to the function (required if there are no table arguments)
+:param assume_schema_exists: If True, do not check if the output table schema exists or attempt to create it
:param kwargs: Any keyword arguments supported by the BaseOperator is supported (e.g ``queue``, ``owner``)
:return: Transform functions return a ``Table`` object that can be passed to future tasks.
This table will be either an auto-generated temporary table,
@@ -124,6 +131,7 @@ def my_sql_statement(table1: Table, table2: Table, execution_date) -> Table:
"database": database,
"schema": schema,
"handler": None,
"assume_schema_exists": assume_schema_exists,
}
)
return task_decorator_factory(
@@ -140,6 +148,7 @@ def transform_file(
parameters: dict | None = None,
database: str | None = None,
schema: str | None = None,
+assume_schema_exists: bool = ASSUME_SCHEMA_EXISTS,
**kwargs: Any,
) -> XComArg:
"""
@@ -156,6 +165,7 @@
table.metadata.database in the first Table passed to the function (required if there are no table arguments)
:param schema: Schema within the SQL instance you want to access. If left blank we will default to the
table.metadata.schema in the first Table passed to the function (required if there are no table arguments)
+:param assume_schema_exists: If True, do not check if the output table schema exists or attempt to create it
:param kwargs: Any keyword arguments supported by the BaseOperator is supported (e.g ``queue``, ``owner``)
:return: Transform functions return a ``Table`` object that can be passed to future tasks.
This table will be either an auto-generated temporary table,
@@ -175,6 +185,7 @@
database=database,
schema=schema,
sql=file_path,
+assume_schema_exists=assume_schema_exists,
python_callable=lambda: (file_path, parameters),
**kwargs,
).output
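A brief sketch of passing the new flag through aql.transform_file as well (the SQL file path, table name, and connection ID are placeholders):

from astro import sql as aql
from astro.table import Table

top_animals = aql.transform_file(
    file_path="sql/top_animals.sql",  # hypothetical SQL file
    conn_id="postgres_conn",  # hypothetical connection
    parameters={"input_table": Table(name="adoption_center_1")},
    assume_schema_exists=True,  # skip the schema check for the output table
)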
7 changes: 3 additions & 4 deletions python-sdk/tests/databases/test_mssql.py
@@ -5,19 +5,18 @@

@mock.patch("astro.databases.mssql.MssqlDatabase.schema_exists", return_value=False)
@mock.patch("astro.databases.mssql.MssqlDatabase.run_sql")
-def test_create_schema_if_needed(mock_run_sql, mock_schema_exists):
+def test_create_schema_if_applicable(mock_run_sql, mock_schema_exists):
"""
Test that run_sql is called with expected arguments when
create_schema_if_needed method is called when the schema is not available
"""
db = MssqlDatabase(conn_id="fake_conn_id")
-db.create_schema_if_needed("non-existing-schema")
+db.create_schema_if_applicable("non-existing-schema")
mock_run_sql.assert_called_once_with(
"""
IF NOT EXISTS (SELECT 1 FROM sys.schemas WHERE name = 'non-existing-schema')
BEGIN
EXEC( 'CREATE SCHEMA non-existing-schema' );
END
""",
autocommit=True,
"""
)
2 changes: 1 addition & 1 deletion python-sdk/tests/databases/test_snowflake.py
@@ -236,5 +236,5 @@ def test_load_file_to_table_skips_schema_check():

file_ = File(path=LOCAL_CSV_FILE)
table = Table(conn_id="fake-conn", metadata=Metadata(schema="abc"))
-database.load_file_to_table(input_file=file_, output_table=table, schema_exists=True)
+database.load_file_to_table(input_file=file_, output_table=table, assume_schema_exists=True)
assert not database.hook.run.call_count