Skip to content

Commit

Permalink
SQLQueryExtractTask read SQL files
Browse files Browse the repository at this point in the history
Why these changes are being introduced:

Some extract tasks that query the data warehouse will
have complex SQL queries.  Storing them in a dedicated
file will support syntax highlighting and testing of
those files directly, while keeping the task definitions
sipmler.

How this addresses that need:
* Add new property SQLQueryExtractTask.sql_file
* Required either sql_query OR sql_file defined

Side effects of this change:
* None

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/HRQB-11
  • Loading branch information
ghukill committed May 12, 2024
1 parent 23f3e0a commit 94b9ef1
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 14 deletions.
22 changes: 18 additions & 4 deletions hrqb/base/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,14 @@ def dwclient(self) -> DWClient:
return DWClient() # pragma: nocover

@property
@abstractmethod
def sql_query(self) -> str:
"""SQL query to run."""
def sql_query(self) -> str | None:
"""SQL query from string to execute."""
return None

@property
def sql_file(self) -> str | None:
"""SQL query loaded from file to execute."""
return None

@property
def sql_query_parameters(self) -> dict:
Expand All @@ -146,8 +151,17 @@ def sql_query_parameters(self) -> dict:

def get_dataframe(self) -> pd.DataFrame:
"""Perform SQL query and return DataFrame for required get_dataframe method."""
if self.sql_query:
query = self.sql_query
elif self.sql_file:
with open(self.sql_file) as f:
query = f.read()
else:
message = "Property sql_query or sql_file must be set."
raise AttributeError(message)
return self.dwclient.execute_query(
self.sql_query, params=self.sql_query_parameters
query,
params=self.sql_query_parameters,
)


Expand Down
4 changes: 4 additions & 0 deletions tests/fixtures/sql/animal_color_query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
select
animal_id,
name
from animal_name
4 changes: 4 additions & 0 deletions tests/fixtures/sql/animal_name_query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
select
animal_id,
name
from animal_name
6 changes: 2 additions & 4 deletions tests/fixtures/tasks/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,8 @@ def dwclient(self) -> DWClient:
)

@property
def sql_query(self) -> str:
return """
select animal_id, color from animal_color
"""
def sql_file(self) -> str:
return "tests/fixtures/sql/animal_color_query.sql"


class SQLExtractAnimalNames(SQLQueryExtractTask):
Expand Down
28 changes: 22 additions & 6 deletions tests/test_base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,15 @@ def test_base_pipeline_name(task_pipeline_animals):
assert task_pipeline_animals.pipeline_name == "Animals"


def test_base_sql_task_missing_sql_query_property_error(pipeline_name):
class MissingQueryTask(SQLQueryExtractTask):
# missing required sql_query property
def test_base_sql_task_missing_sql_query_or_sql_file_error(pipeline_name):
class MissingRequiredPropertiesQueryTask(SQLQueryExtractTask):
pass

with pytest.raises(TypeError, match="abstract method sql_query"):
MissingQueryTask(pipeline=pipeline_name, stage="Extract")
task = MissingRequiredPropertiesQueryTask(pipeline=pipeline_name, stage="Extract")
with pytest.raises(
AttributeError, match="Property sql_query or sql_file must be set."
):
task.get_dataframe()


def test_base_sql_task_custom_dwclient(task_sql_extract_animal_names):
Expand All @@ -176,13 +178,27 @@ def test_base_sql_task_sql_query(task_sql_extract_animal_names):
)


def test_base_sql_task_get_dataframe_executes_sql_query_return_dataframe(
def test_base_sql_task_sql_file(task_sql_extract_animal_colors):
assert (
task_sql_extract_animal_colors.sql_file
== "tests/fixtures/sql/animal_color_query.sql"
)


def test_base_sql_task_sql_query_get_dataframe_return_dataframe(
task_sql_extract_animal_names,
):
df = task_sql_extract_animal_names.get_dataframe()
assert isinstance(df, pd.DataFrame)


def test_base_sql_task_sql_file_get_dataframe_return_dataframe(
task_sql_extract_animal_colors,
):
df = task_sql_extract_animal_colors.get_dataframe()
assert isinstance(df, pd.DataFrame)


def test_base_sql_task_sql_query_parameters_used(
pipeline_name, sqlite_dwclient, task_sql_extract_animal_colors
):
Expand Down

0 comments on commit 94b9ef1

Please sign in to comment.