Allow users to disable schema check and creation on load_file #1922

Merged: 8 commits, merged on May 5, 2023
2 changes: 2 additions & 0 deletions python-sdk/docs/astro/sql/operators/load_file.rst
@@ -43,6 +43,8 @@ Parameters to use when loading a file to a database table

Note that if you use ``if_exists='replace'``, the existing table will be dropped and the schema of the new data will be used.

#. **schema_exists** (default is False) - By default, the SDK checks whether the schema of the target table exists and, if it does not, tries to create it. This query can be costly. Setting this argument to ``True`` makes the SDK skip the check, since the user is asserting that the schema already exists (see the sketch after this parameter list).

#. **output_table** - This parameter defines the output table to load data to, which should be an instance of ``astro.sql.table.Table``. You can specify the schema of the table by providing a list of `sqlalchemy.Column <https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column>`_ instances to the ``columns`` parameter. If you don't specify a schema, it will be inferred using Pandas.
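
A minimal sketch combining both parameters (the file path, connection IDs, and column definitions below are illustrative placeholders, not part of this change):

.. code:: python

    from sqlalchemy import Column, Integer, String

    from astro import sql as aql
    from astro.files import File
    from astro.table import Metadata, Table

    homes = aql.load_file(
        input_file=File(path="s3://my-bucket/homes.csv", conn_id="aws_default"),
        output_table=Table(
            name="homes",
            conn_id="snowflake_default",
            metadata=Metadata(schema="existing_schema"),
            # Explicit columns skip the Pandas-based schema inference
            columns=[Column("sell", Integer), Column("list", Integer), Column("type", String(50))],
        ),
        schema_exists=True,  # assumes the schema "existing_schema" was created beforehand
    )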

.. literalinclude:: ../../../../example_dags/example_load_file.py
23 changes: 23 additions & 0 deletions python-sdk/docs/configurations.rst
@@ -45,6 +45,29 @@ or by updating Airflow's configuration
redshift_default_schema = "redshift_tmp"
mssql_default_schema = "mssql_tmp"

Configuring whether the SDK should check for and create schemas
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

By default, during ``aql.load_file``, the SDK checks if the schema of the target table exists, and if not, it tries to create it. This type of check can be costly.

The configuration ``AIRFLOW__ASTRO_SDK__LOAD_TABLE_SCHEMA_EXISTS`` allows users to inform the SDK that the schema already exists, skipping this check for all ``load_file`` tasks.

Users can also have more granular control by setting the ``load_file`` argument ``schema_exists`` on a per-task basis (see :ref:`load_file` and the sketch at the end of this section).

Example of how to disable schema existence check using environment variables:

.. code:: ini

AIRFLOW__ASTRO_SDK__LOAD_TABLE_SCHEMA_EXISTS = True

Or using Airflow's configuration file:

.. code:: ini

[astro_sdk]
load_table_schema_exists = True
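
For per-task control, the same flag can be passed directly to ``load_file``. A sketch (the table name, schema, and connection ID are placeholders):

.. code:: python

    from astro import sql as aql
    from astro.files import File
    from astro.table import Metadata, Table

    homes_data = aql.load_file(
        input_file=File(path="s3://my-bucket/homes.csv"),
        output_table=Table(
            name="homes",
            conn_id="snowflake_default",
            metadata=Metadata(schema="existing_schema"),
        ),
        schema_exists=True,  # skip the schema existence check for this task only
    )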


Configuring the unsafe dataframe storage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The dataframes generated by the ``dataframe`` or ``transform`` operators are pickled and stored in the XCom table of Airflow's metadata database. Since these dataframes are user-defined, a very large dataframe can exhaust the metadata database's resources. Hence, unsafe dataframe storage should only be set to ``True`` once you are aware of this risk and accept it. Alternatively, you can use a custom XCom backend to store the XCom data.
@@ -85,6 +85,7 @@ def example_snowflake_partial_table_with_append():
schema=os.getenv("SNOWFLAKE_SCHEMA"),
),
),
schema_exists=True,  # Skip queries that check if the table schema exists
)

homes_data2 = load_file(
@@ -96,6 +97,7 @@
schema=os.getenv("SNOWFLAKE_SCHEMA"),
),
),
schema_exists=True,
)

# Define task dependencies
17 changes: 13 additions & 4 deletions python-sdk/src/astro/databases/base.py
@@ -30,7 +30,12 @@
from astro.files.types.base import FileType as FileTypeConstants
from astro.options import LoadOptions
from astro.query_modifier import QueryModifier
from astro.settings import LOAD_FILE_ENABLE_NATIVE_FALLBACK, LOAD_TABLE_AUTODETECT_ROWS_COUNT, SCHEMA
from astro.settings import (
LOAD_FILE_ENABLE_NATIVE_FALLBACK,
LOAD_TABLE_AUTODETECT_ROWS_COUNT,
LOAD_TABLE_SCHEMA_EXISTS,
SCHEMA,
)
from astro.table import BaseTable, Metadata
from astro.utils.compat.functools import cached_property

@@ -359,7 +364,7 @@ def drop_table(self, table: BaseTable) -> None:
# Table load methods
# ---------------------------------------------------------

def create_schema_and_table_if_needed(
def create_table_if_needed(
self,
table: BaseTable,
file: File,
@@ -393,7 +398,6 @@ def create_schema_and_table_if_needed(
):
return

self.create_schema_if_needed(table.metadata.schema)
if if_exists == "replace" or not self.table_exists(table):
files = resolve_file_path_pattern(
file.path,
@@ -449,6 +453,7 @@ def load_file_to_table(
native_support_kwargs: dict | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: bool | None = LOAD_FILE_ENABLE_NATIVE_FALLBACK,
schema_exists: bool = LOAD_TABLE_SCHEMA_EXISTS,
**kwargs,
):
"""
@@ -465,6 +470,7 @@
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer
:param schema_exists: Declare the table schema already exists and that load_file should not check if it exists
"""
normalize_config = normalize_config or {}
if self.check_for_minio_connection(input_file=input_file):
@@ -474,7 +480,10 @@
)
use_native_support = False

self.create_schema_and_table_if_needed(
if not schema_exists:
self.create_schema_if_needed(output_table.metadata.schema)

self.create_table_if_needed(
file=input_file,
table=output_table,
columns_names_capitalization=columns_names_capitalization,
4 changes: 3 additions & 1 deletion python-sdk/src/astro/databases/databricks/delta.py
@@ -22,6 +22,7 @@
from astro.files import File
from astro.options import LoadOptions
from astro.query_modifier import QueryModifier
from astro.settings import LOAD_TABLE_SCHEMA_EXISTS
from astro.table import BaseTable, Metadata


@@ -123,6 +124,7 @@ def load_file_to_table(
native_support_kwargs: dict | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: bool | None = None,
schema_exists: bool = LOAD_TABLE_SCHEMA_EXISTS,
databricks_job_name: str = "",
**kwargs,
):
@@ -142,7 +144,7 @@
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer

:param schema_exists: Declare the table schema already exists and that load_file should not check if it exists
"""
load_file_to_delta(
input_file=input_file,
14 changes: 13 additions & 1 deletion python-sdk/src/astro/settings.py
@@ -68,10 +68,22 @@
section=SECTION_KEY, key="load_table_autodetect_rows_count", fallback=1000
)


#: Reduce response sizes returned by aql.run_raw_sql to avoid trashing the Airflow DB if the BaseXCom is used.
RAW_SQL_MAX_RESPONSE_SIZE = conf.getint(section=SECTION_KEY, key="run_raw_sql_response_size", fallback=-1)

# Temp changes
# Should Astro SDK automatically add inlets/outlets to take advantage of Airflow 2.4 Data-aware scheduling
AUTO_ADD_INLETS_OUTLETS = conf.getboolean(SECTION_KEY, "auto_add_inlets_outlets", fallback=True)

LOAD_TABLE_SCHEMA_EXISTS = False


def reload():
"""
Reload settings from environment variable during runtime.
"""
global LOAD_TABLE_SCHEMA_EXISTS # skipcq: PYL-W0603
LOAD_TABLE_SCHEMA_EXISTS = conf.getboolean(SECTION_KEY, "load_table_schema_exists", fallback=False)


reload()
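
A minimal sketch of how this runtime reload is meant to be used (it mirrors the new test in python-sdk/tests/test_settings.py below):

    import os

    from astro import settings

    # Simulate the Airflow option being supplied through an environment variable.
    os.environ["AIRFLOW__ASTRO_SDK__LOAD_TABLE_SCHEMA_EXISTS"] = "True"

    settings.reload()  # re-reads load_table_schema_exists using the current environment
    assert settings.LOAD_TABLE_SCHEMA_EXISTS is True
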
8 changes: 6 additions & 2 deletions python-sdk/src/astro/sql/operators/load_file.py
@@ -8,6 +8,7 @@
from airflow.hooks.base import BaseHook
from airflow.models.xcom_arg import XComArg

from astro import settings
from astro.airflow.datasets import kwargs_with_datasets
from astro.constants import DEFAULT_CHUNK_SIZE, ColumnCapitalization, LoadExistStrategy
from astro.databases import create_database
@@ -21,7 +22,6 @@
from astro.dataframes.pandas import PandasDataframe
from astro.files import File, resolve_file_path_pattern
from astro.options import LoadOptions, LoadOptionsList
from astro.settings import LOAD_FILE_ENABLE_NATIVE_FALLBACK
from astro.sql.operators.base_operator import AstroSQLBaseOperator
from astro.table import BaseTable
from astro.utils.compat.typing import Context
@@ -47,6 +47,7 @@ class LoadFileOperator(AstroSQLBaseOperator):
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer
:param schema_exists: Declare the table schema already exists and that load_file should not check if it exists

:return: If ``output_table`` is passed this operator returns a Table object. If not
passed, returns a dataframe.
@@ -65,7 +66,8 @@ def __init__(
native_support_kwargs: dict | None = None,
load_options: LoadOptions | list[LoadOptions] | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: bool | None = LOAD_FILE_ENABLE_NATIVE_FALLBACK,
enable_native_fallback: bool | None = settings.LOAD_FILE_ENABLE_NATIVE_FALLBACK,
schema_exists: bool = settings.LOAD_TABLE_SCHEMA_EXISTS,
**kwargs,
) -> None:
kwargs.setdefault("task_id", get_unique_task_id("load_file"))
@@ -112,6 +114,7 @@ def __init__(
self.native_support_kwargs: dict[str, Any] = native_support_kwargs or {}
self.columns_names_capitalization = columns_names_capitalization
self.enable_native_fallback = enable_native_fallback
self.schema_exists = schema_exists
self.load_options_list = LoadOptionsList(load_options)

def execute(self, context: Context) -> BaseTable | File: # skipcq: PYL-W0613
@@ -159,6 +162,7 @@ def load_data_to_table(self, input_file: File, context: Context) -> BaseTable:
native_support_kwargs=self.native_support_kwargs,
columns_names_capitalization=self.columns_names_capitalization,
enable_native_fallback=self.enable_native_fallback,
schema_exists=self.schema_exists,
databricks_job_name=f"Load data {self.dag_id}_{self.task_id}",
)
self.log.info("Completed loading the data into %s.", self.output_table)
39 changes: 33 additions & 6 deletions python-sdk/tests/databases/test_snowflake.py
@@ -20,6 +20,8 @@
SNOWFLAKE_STORAGE_INTEGRATION_AMAZON = SNOWFLAKE_STORAGE_INTEGRATION_AMAZON or "aws_int_python_sdk"
SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE = SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE or "gcs_int_python_sdk"

LOCAL_CSV_FILE = str(CWD.parent / "data/homes_main.csv")


def test_stage_set_name_after():
stage = SnowflakeStage()
@@ -111,11 +113,10 @@ def test_load_file_to_table_natively_for_fallback_raises_exception_if_not_enable


def test_snowflake_load_options():
path = str(CWD) + "/../../data/homes_main.csv"
database = SnowflakeDatabase(
conn_id="fake-conn", load_options=SnowflakeLoadOptions(file_options={"foo": "bar"})
)
file = File(path)
file = File(path=LOCAL_CSV_FILE)
with mock.patch(
"astro.databases.snowflake.SnowflakeDatabase.hook", new_callable=PropertyMock
), mock.patch(
@@ -132,9 +133,8 @@


def test_snowflake_load_options_default():
path = str(CWD) + "/../../data/homes_main.csv"
database = SnowflakeDatabase(conn_id="fake-conn", load_options=SnowflakeLoadOptions())
file = File(path)
file = File(path=LOCAL_CSV_FILE)
with mock.patch(
"astro.databases.snowflake.SnowflakeDatabase.hook", new_callable=PropertyMock
), mock.patch(
@@ -151,8 +151,7 @@


def test_snowflake_load_options_wrong_options():
path = str(CWD) + "/../../data/homes_main.csv"
file = File(path)
file = File(path=LOCAL_CSV_FILE)
with pytest.raises(ValueError, match="Error: Requires a SnowflakeLoadOptions"):
database = SnowflakeDatabase(conn_id="fake-conn", load_options=LoadOptions())
database.load_file_to_table_natively(source_file=file, target_table=Table())
@@ -211,3 +210,31 @@ def test_storage_integrations_params_in_load_options():
database.load_file_to_table_natively(source_file=file, target_table=table)

assert create_stage.call_args.kwargs["storage_integration"] == "some_integrations"


def test_load_file_to_table_by_default_checks_schema():
database = SnowflakeDatabase(conn_id="fake-conn")
database.run_sql = MagicMock()
database.hook = MagicMock()
database.create_table_using_schema_autodetection = MagicMock()

file_ = File(path=LOCAL_CSV_FILE)
table = Table(conn_id="fake-conn", metadata=Metadata(schema="abc"))
database.load_file_to_table(input_file=file_, output_table=table)
expected = (
"SELECT SCHEMA_NAME from information_schema.schemata WHERE LOWER(SCHEMA_NAME) = %(schema_name)s;"
)
assert database.hook.run.call_args_list[0].args[0] == expected
assert database.hook.run.call_args_list[0].kwargs["parameters"]["schema_name"] == "abc"


def test_load_file_to_table_skips_schema_check():
database = SnowflakeDatabase(conn_id="fake-conn")
database.run_sql = MagicMock()
database.hook = MagicMock()
database.create_table_using_schema_autodetection = MagicMock()

file_ = File(path=LOCAL_CSV_FILE)
table = Table(conn_id="fake-conn", metadata=Metadata(schema="abc"))
database.load_file_to_table(input_file=file_, output_table=table, schema_exists=True)
assert not database.hook.run.call_count
22 changes: 22 additions & 0 deletions python-sdk/tests/test_settings.py
@@ -0,0 +1,22 @@
import importlib
import os
from unittest.mock import patch

import astro
from astro import settings
from astro.files import File


def test_settings_load_table_schema_exists_default():
from astro.sql import LoadFileOperator

load_file = LoadFileOperator(input_file=File("dummy.csv"))
assert not load_file.schema_exists


@patch.dict(os.environ, {"AIRFLOW__ASTRO_SDK__LOAD_TABLE_SCHEMA_EXISTS": "True"})
def test_settings_load_table_schema_exists_override():
settings.reload()
importlib.reload(astro.sql.operators.load_file)
load_file = astro.sql.operators.load_file.LoadFileOperator(input_file=File("dummy.csv"))
assert load_file.schema_exists