Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Render sql code with parameters in BaseSQLDecoratedOperator #897

Merged
merged 7 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 74 additions & 2 deletions python-sdk/src/astro/sql/operators/base_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import inspect
from typing import Any, Callable, Sequence, cast

import jinja2
import pandas as pd
from airflow.decorators.base import DecoratedOperator
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.models.xcom_arg import XComArg
from sqlalchemy.sql.functions import Function

from astro.airflow.datasets import kwargs_with_datasets
Expand All @@ -20,7 +23,8 @@
class BaseSQLDecoratedOperator(UpstreamTaskMixin, DecoratedOperator):
"""Handles all decorator classes that can return a SQL function"""

template_fields: Sequence[str] = ("parameters", "op_args", "op_kwargs")
template_fields: Sequence[str] = ("sql", "parameters", "op_args", "op_kwargs")
feluelle marked this conversation as resolved.
Show resolved Hide resolved
template_ext: Sequence[str] = (".sql",)

database_impl: BaseDatabase

Expand Down Expand Up @@ -59,7 +63,56 @@ def __init__(
**kwargs_with_datasets(kwargs=kwargs, output_datasets=self.output_table),
)

def execute(self, context: Context) -> None:
def _resolve_xcom_op_kwargs(self, context: Context) -> None:
"""
Iterate through self.op_kwargs, resolving any XCom values with the given context.
Replace those values in-place.

:param context: The Airflow Context to be used to resolve the op_kwargs.
"""
# TODO: confirm if it makes sense for us to always replace the op_kwargs or if we should
# only replace those that are within the decorator signature, by using
# inspect.signature(self.python_callable).parameters.values()
kwargs = {}
for kwarg_name, kwarg_value in self.op_kwargs.items():
if isinstance(kwarg_value, XComArg):
kwargs[kwarg_name] = kwarg_value.resolve(context)
else:
kwargs[kwarg_name] = kwarg_value
self.op_kwargs = kwargs

def _resolve_xcom_op_args(self, context: Context) -> None:
"""
Iterates through self.op_args, resolving any XCom values with the given context.
Replace those values in-place.

:param context: The Airflow Context used to resolve the op_args.
"""
args = []
for arg_value in self.op_args:
if isinstance(arg_value, XComArg):
item = arg_value.resolve(context)
else:
item = arg_value
args.append(item)
self.op_args = args # type: ignore

def _enrich_context(self, context: Context) -> Context:
"""
Prepare the sql and context for execution.

Specifically, it will do the following:
1. Preset database settings
2. Load dataframes into tables
3. Render sql as sqlalchemy executable string

:param context: The Airflow Context which will be extended.

:return: the enriched context with astro specific information.
"""
self._resolve_xcom_op_args(context)
self._resolve_xcom_op_kwargs(context)

first_table = find_first_table(
op_args=self.op_args, # type: ignore
op_kwargs=self.op_kwargs,
Expand Down Expand Up @@ -108,6 +161,25 @@ def execute(self, context: Context) -> None:
# if there is no SQL to run we raise an error
if self.sql == "" or not self.sql:
raise AirflowException("There's no SQL to run")
return context

def render_template_fields(
self,
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> BaseOperator | None:
"""Template all attributes listed in template_fields.

This mutates the attributes in-place and is irreversible.

:param context: Dict with values to apply on content
:param jinja_env: Jinja environment
"""
context = self._enrich_context(context)
return super().render_template_fields(context, jinja_env)

def execute(self, context: Context) -> None:
self._enrich_context(context)
tatiana marked this conversation as resolved.
Show resolved Hide resolved

# TODO: remove pushing to XCom once we update the airflow version.
context["ti"].xcom_push(key="base_sql_query", value=str(self.sql))
Expand Down
9 changes: 9 additions & 0 deletions python-sdk/tests/sql/operators/test_base_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,12 @@ def test_load_op_kwarg_dataframes_into_sql():

assert isinstance(results["table"], BaseTable)
assert isinstance(results["some_str"], str)


def test_base_sql_decorated_operator_template_fields_and_template_ext_with_sql():
"""
Test that sql is in BaseSQLDecoratedOperator template_fields and template_ext
as this required for rending the sql in the task instance rendered section.
"""
assert "sql" in BaseSQLDecoratedOperator.template_fields
assert ".sql" in BaseSQLDecoratedOperator.template_ext