
Commit

Merge branch 'main' into databricks/create-example
dimberman committed Feb 21, 2023
2 parents d8546b1 + 3610312 commit b1623a3
Showing 42 changed files with 5,314 additions and 4,059 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci-benchmark.yaml
@@ -75,7 +75,10 @@ jobs:
ASTRO_DEPLOYMENT_ID: ${{ secrets.ASTRO_DEPLOYMENT_ID }}
ASTRO_KEY_ID: ${{ secrets.ASTRO_KEY_ID }}
ASTRO_KEY_SECRET: ${{ secrets.ASTRO_KEY_SECRET }}
ASTRO_DEPLOYMENT_ID_SINGLE_WORKER: ${{ secrets.ASTRO_DEPLOYMENT_ID_SINGLE_WORKER }}
ASTRO_KEY_ID_SINGLE_WORKER: ${{ secrets.ASTRO_KEY_ID_SINGLE_WORKER }}
ASTRO_KEY_SECRET_SINGLE_WORKER: ${{ secrets.ASTRO_KEY_SECRET_SINGLE_WORKER }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- run: cd python-sdk/tests_integration/astro_deploy && sh deploy.sh $ASTRO_DOCKER_REGISTRY $ASTRO_ORGANIZATION_ID $ASTRO_DEPLOYMENT_ID $ASTRO_KEY_ID $ASTRO_KEY_SECRET
- run: cd python-sdk/tests_integration/astro_deploy && sh deploy.sh $ASTRO_DOCKER_REGISTRY $ASTRO_ORGANIZATION_ID $ASTRO_DEPLOYMENT_ID $ASTRO_KEY_ID $ASTRO_KEY_SECRET $ASTRO_DEPLOYMENT_ID_SINGLE_WORKER $ASTRO_KEY_ID_SINGLE_WORKER $ASTRO_KEY_SECRET_SINGLE_WORKER
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -71,7 +71,7 @@ repos:
additional_dependencies: [black>=22.10.0]

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.246'
rev: 'v0.0.249'
hooks:
- id: ruff
args:
@@ -86,7 +86,7 @@
types: [text]
exclude: ^mk/.*\.mk$|^python-sdk/docs/Makefile|^python-sdk/Makefile$|^python-sdk/tests/modified_constraint_file.txt$|^python-sdk/tests/benchmark/Makefile$|^sql-cli/poetry.lock$
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.0.0'
rev: 'v1.0.1'
hooks:
- id: mypy
name: mypy-python-sdk
2 changes: 1 addition & 1 deletion python-sdk/dev/Dockerfile
@@ -1,4 +1,4 @@
FROM quay.io/astronomer/astro-runtime:7.2.0-base
FROM quay.io/astronomer/astro-runtime:7.3.0-base

USER root
RUN apt-get update -y && apt-get install -y git
14 changes: 14 additions & 0 deletions python-sdk/docs/astro/sql/operators/load_file.rst
@@ -212,6 +212,20 @@ Supported native transfers

Reference on how to create such a role is here: https://www.dataliftoff.com/iam-roles-for-loading-data-from-s3-into-redshift/

Loading from MinIO
~~~~~~~~~~~~~~~~~~
MinIO is a high-performance object storage server released under the GNU Affero General Public License v3.0 and is API-compatible with the Amazon S3 cloud storage service. When loading files from MinIO into a database, make sure the MinIO server is up and running. It is also important to set ``endpoint_url`` in the connection extras: this is how a MinIO location is distinguished from S3 and how the loading strategy is chosen. Since there are no native load options for MinIO, ``use_native_support`` is overridden to ``False`` so the file is loaded via pandas. The following is an example of a MinIO connection:

.. code-block::

   - conn_id: minio_conn
     conn_type: aws
     description: null
     extra:
       aws_access_key_id: "dummy access key"
       aws_secret_access_key: "dummy secret key"
       endpoint_url: "http://127.0.0.1:9000"
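
The example below is a usage sketch rather than part of this change set: it assumes the ``minio_conn`` connection above plus a hypothetical bucket and SQLite table, and shows ``load_file`` loading a CSV from MinIO via pandas. Passing ``use_native_support=False`` explicitly is optional, since it is set automatically once ``endpoint_url`` is detected.

.. code-block:: python

    from astro import sql as aql
    from astro.files import File
    from astro.table import Table

    # Load a CSV stored on a MinIO server into a SQLite table.
    # The endpoint_url in "minio_conn" makes the SDK skip native transfers
    # and load the file via pandas instead.
    sample_table = aql.load_file(
        input_file=File(path="s3://my-bucket/sample.csv", conn_id="minio_conn"),
        output_table=Table(name="sample_table", conn_id="sqlite_default"),
        use_native_support=False,  # redundant here; disabled automatically when endpoint_url is present
    )
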
Loading to MS SQL
~~~~~~~~~~~~~~~~~

50 changes: 33 additions & 17 deletions python-sdk/src/astro/databases/base.py
@@ -3,19 +3,13 @@
import logging
import warnings
from abc import ABC
from typing import TYPE_CHECKING, Any, Callable, Mapping
from typing import Any, Callable, Mapping

import pandas as pd
import sqlalchemy
from airflow.hooks.dbapi import DbApiHook
from pandas.io.sql import SQLDatabase
from sqlalchemy import column, insert, select

from astro.dataframes.pandas import PandasDataframe

if TYPE_CHECKING: # pragma: no cover
from sqlalchemy.engine.cursor import CursorResult

from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.schema import Table as SqlaTable
@@ -29,6 +23,7 @@
LoadExistStrategy,
MergeConflictStrategy,
)
from astro.dataframes.pandas import PandasDataframe
from astro.exceptions import DatabaseCustomError, NonExistentTableException
from astro.files import File, resolve_file_path_pattern
from astro.files.types import create_file_type
@@ -63,8 +58,6 @@ class BaseDatabase(ABC):
# illegal_column_name_chars[0] will be replaced by value in illegal_column_name_chars_replacement[0]
illegal_column_name_chars: list[str] = []
illegal_column_name_chars_replacement: list[str] = []
# In run_raw_sql operator decides if we want to return results directly or process them by handler provided
IGNORE_HANDLER_IN_RUN_RAW_SQL: bool = False
NATIVE_PATHS: dict[Any, Any] = {}
DEFAULT_SCHEMA = SCHEMA
NATIVE_LOAD_EXCEPTIONS: Any = DatabaseCustomError
@@ -107,8 +100,9 @@ def run_sql(
self,
sql: str | ClauseElement = "",
parameters: dict | None = None,
handler: Callable | None = None,
**kwargs,
) -> CursorResult:
) -> Any:
"""
Return the results of running a SQL statement.
@@ -118,6 +112,7 @@ def run_sql(
:param sql: Contains SQL query to be run against database
:param parameters: Optional parameters to be used to render the query
:param autocommit: Optional autocommit flag
:param handler: function that takes in a cursor as an argument.
"""
if parameters is None:
parameters = {}
@@ -139,7 +134,9 @@ def run_sql(
)
else:
result = self.connection.execute(sql, parameters)
return result
if handler:
return handler(result)
return None

def columns_exist(self, table: BaseTable, columns: list[str]) -> bool:
"""
@@ -407,7 +404,7 @@ def create_schema_and_table_if_needed(
use_native_support=use_native_support,
)

def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> Any:
"""
Fetches all rows for a table and returns as a list. This is needed because some
databases have different cursors that require different methods to fetch rows
@@ -419,8 +416,21 @@ def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
statement = f"SELECT * FROM {self.get_table_qualified_name(table)}"
if row_limit > -1:
statement = statement + f" LIMIT {row_limit}"
response = self.run_sql(statement)
return response.fetchall() # type: ignore
response: list = self.run_sql(statement, handler=lambda x: x.fetchall())
return response

@staticmethod
def check_for_minio_connection(input_file: File) -> bool:
"""Automatically check if the connection is minio or S3"""
is_minio = False
if input_file.location.location_type == FileLocation.S3 and input_file.conn_id:
conn = input_file.location.hook.get_connection(input_file.conn_id)
try:
conn.extra_dejson["endpoint_url"]
is_minio = True
except KeyError:
pass
return is_minio

def load_file_to_table(
self,
@@ -451,6 +461,12 @@ def load_file_to_table(
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer
"""
normalize_config = normalize_config or {}
if self.check_for_minio_connection(input_file=input_file):
logging.info(
"No native support available for the service provided via endpoint_url! Setting use_native_support"
" to False."
)
use_native_support = False

self.create_schema_and_table_if_needed(
file=input_file,
@@ -551,7 +567,6 @@ def load_file_to_table_natively_with_fallback(
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer
:param normalize_config: pandas json_normalize params config
"""

try:
logging.info("Loading file(s) with Native Support...")
self.load_file_to_table_natively(
@@ -777,8 +792,9 @@ def row_count(self, table: BaseTable):
:return: The number of rows in the table
"""
result = self.run_sql(
f"select count(*) from {self.get_table_qualified_name(table)}" # skipcq: BAN-B608
).scalar()
f"select count(*) from {self.get_table_qualified_name(table)}", # skipcq: BAN-B608
handler=lambda x: x.scalar(),
)
return result

def parameterize_variable(self, variable: str):
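
Note on the base.py changes above: run_sql() no longer returns a CursorResult directly; it accepts an optional handler callback and returns whatever the handler produces (or None when no handler is given), and callers such as fetch_all_rows() and row_count() now pass lambdas instead of calling fetchall()/scalar() on the result. A minimal sketch of the new calling convention (the connection id and table name are assumptions, not part of this diff):

    from astro.databases import create_database

    # run_sql() now returns the handler's return value, or None when no handler is supplied.
    db = create_database("sqlite_default")  # assumed connection id
    rows = db.run_sql(
        "SELECT * FROM my_table LIMIT 5",  # assumed table name
        handler=lambda result: result.fetchall(),
    )
    count = db.run_sql(
        "SELECT COUNT(*) FROM my_table",
        handler=lambda result: result.scalar(),
    )
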
8 changes: 3 additions & 5 deletions python-sdk/src/astro/databases/databricks/delta.py
@@ -4,6 +4,7 @@
import uuid
import warnings
from textwrap import dedent
from typing import Any, Callable

import pandas as pd
from airflow.providers.databricks.hooks.databricks import DatabricksHook
@@ -25,9 +26,6 @@

class DeltaDatabase(BaseDatabase):
LOAD_OPTIONS_CLASS_NAME = "DeltaLoadOptions"
# In run_raw_sql operator decides if we want to return results directly or process them by handler provided
# For delta tables we ignore the handler
IGNORE_HANDLER_IN_RUN_RAW_SQL: bool = True
_create_table_statement: str = "CREATE TABLE IF NOT EXISTS {} USING DELTA AS {} "

def __init__(self, conn_id: str, table: BaseTable | None = None, load_options: LoadOptions | None = None):
@@ -197,9 +195,9 @@ def run_sql(
self,
sql: str | ClauseElement = "",
parameters: dict | None = None,
handler=None,
handler: Callable | None = None,
**kwargs,
):
) -> Any:
"""
Run SQL against a delta table using spark SQL.
19 changes: 10 additions & 9 deletions python-sdk/src/astro/databases/mssql.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING
from typing import Any, Callable

import pandas as pd
import sqlalchemy
@@ -17,8 +17,6 @@
from astro.utils.compat.functools import cached_property

DEFAULT_CONN_ID = MsSqlHook.default_conn_name
if TYPE_CHECKING: # pragma: no cover
from sqlalchemy.engine.cursor import CursorResult


class MssqlDatabase(BaseDatabase):
@@ -145,15 +143,17 @@ def run_sql(
self,
sql: str | ClauseElement = "",
parameters: dict | None = None,
handler: Callable | None = None,
**kwargs,
) -> CursorResult:
) -> Any:
"""
Return the results of running a SQL statement.
Whenever possible, this method should be implemented using Airflow Hooks,
since this will simplify the integration with Async operators.
:param sql: Contains SQL query to be run against database
:param parameters: Optional parameters to be used to render the query
:param handler: function that takes in a cursor as an argument.
"""
if parameters is None:
parameters = {}
@@ -177,11 +177,12 @@
result = self.connection.execute(
sqlalchemy.text(sql).execution_options(autocommit=autocommit), parameters
)
return result
else:
# this is used for append
result = self.connection.execute(sql, parameters)
return result
if handler:
return handler(result)
return None

def create_schema_if_needed(self, schema: str | None) -> None:
"""
@@ -226,7 +227,7 @@ def drop_table(self, table: BaseTable) -> None:
statement = self._drop_table_statement.format(self.get_table_qualified_name(table))
self.run_sql(statement, autocommit=True)

def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> Any:
"""
Fetches all rows for a table and returns as a list. This is needed because some
databases have different cursors that require different methods to fetch rows
@@ -238,8 +239,8 @@ def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
statement = f"SELECT * FROM {self.get_table_qualified_name(table)}" # skipcq: BAN-B608
if row_limit > -1:
statement = f"SELECT TOP {row_limit} * FROM {self.get_table_qualified_name(table)}"
response = self.run_sql(statement)
return response.fetchall() # type: ignore
response: list = self.run_sql(statement, handler=lambda x: x.fetchall())
return response

def load_pandas_dataframe_to_table(
self,
17 changes: 2 additions & 15 deletions python-sdk/src/astro/sql/operators/raw_sql.py
@@ -51,6 +51,7 @@ def __init__(
def execute(self, context: Context) -> Any:
super().execute(context)

self.handler = self.get_handler()
result = self.database_impl.run_sql(sql=self.sql, parameters=self.parameters, handler=self.handler)
if self.response_size == -1 and not settings.IS_CUSTOM_XCOM_BACKEND:
logging.warning(
@@ -60,22 +61,8 @@ def execute(self, context: Context) -> Any:
"backend."
)

# ToDo: Currently, the handler param in run_sql() method is only used in databricks all other databases are
# not using it. Which leads to different response types since handler is processed within `run_sql()` for
# databricks and not for other databases. Also the signature of `run_sql()` in databricks deviates from base.
# We need to standardise and when we do, we can remove below check as well.
if self.database_impl.IGNORE_HANDLER_IN_RUN_RAW_SQL:
return result

self.handler = self.get_handler()

if self.handler:
self.handler = self.get_wrapped_handler(
fail_on_empty=self.fail_on_empty, conversion_func=self.handler
)
# otherwise, call the handler and convert the result to a list
response = self.handler(result)
response = self.make_row_serializable(response)
response = self.make_row_serializable(result)
if 0 <= self.response_limit < len(response):
raise IllegalLoadToDatabaseException() # pragma: no cover
if self.response_size >= 0:
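
Note on the raw_sql.py change above: with IGNORE_HANDLER_IN_RUN_RAW_SQL removed, the operator wraps the user handler and forwards it into Database.run_sql() for every backend, Databricks included, so run_raw_sql behaves consistently everywhere. A hedged sketch of typical decorator usage (the connection id and table name are assumptions, not part of this diff):

    from astro import sql as aql

    # The decorated function returns a SQL string; the handler processes the cursor/result.
    @aql.run_raw_sql(conn_id="sqlite_default", handler=lambda cursor: cursor.fetchall())
    def top_rows():
        return "SELECT * FROM my_table LIMIT 5"  # assumed table name
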
36 changes: 36 additions & 0 deletions python-sdk/tests/databases/test_base_database.py
@@ -1,6 +1,8 @@
import pathlib
from unittest import mock

import pytest
from airflow.models.connection import Connection
from pandas import DataFrame

from astro.constants import FileType
@@ -96,3 +98,37 @@ def test_subclass_missing_append_table_raises_exception():
target_table = Table()
with pytest.raises(NotImplementedError):
db.append_table(source_table, target_table, source_to_target_columns_map={})


@mock.patch("astro.files.locations.base.BaseFileLocation.validate_conn")
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
def test_database_for_minio_conn_with_check_for_minio_connection(get_connection, validate_conn):
database = create_database("sqlite_default")
get_connection.return_value = Connection(
conn_id="minio_conn",
conn_type="aws",
extra={
"aws_access_key_id": "",
"aws_secret_access_key": "",
"endpoint_url": "http://127.0.0.1:9000",
},
)
assert (
database.check_for_minio_connection(
input_file=File(path="S3://somebucket/test.csv", conn_id="minio_conn")
)
is True
)


@mock.patch("astro.files.locations.base.BaseFileLocation.validate_conn")
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
def test_database_for_s3_conn_with_check_for_minio_connection(get_connection, validate_conn):
database = create_database("sqlite_default")
get_connection.return_value = Connection(conn_id="aws", conn_type="aws")
assert (
database.check_for_minio_connection(
input_file=File(path="S3://somebucket/test.csv", conn_id="aws_conn")
)
is False
)
