Skip to content

Commit

Permalink
Code review updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ghukill committed May 13, 2024
1 parent 94b9ef1 commit 82fe684
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 47 deletions.
19 changes: 5 additions & 14 deletions hrqb/utils/data_warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,11 @@ class DWClient:
factory=lambda: Config().DATA_WAREHOUSE_CONNECTION_STRING,
repr=False,
)
engine_parameters: dict | None = field(default=None)
engine_parameters: dict = field(factory=lambda: {"thick_mode": True})
engine: Engine = field(default=None)

@staticmethod
def default_engine_parameters() -> dict:
return {"thick_mode": True}

def validate_data_warehouse_connection_string(self) -> None:
"""Validates that a proper connection is configured."""
def verify_connection_string_set(self) -> None:
"""Verify that a connection string is set explicitly or by env var default."""
if not self.connection_string:
message = (
"Data Warehouse connection string not found. Please pass explicitly to "
Expand All @@ -41,14 +37,9 @@ def init_engine(self) -> None:
User provided engine parameters will override self.default_engine_parameters.
"""
self.validate_data_warehouse_connection_string()
self.verify_connection_string_set()
if not self.engine:
engine_parameters = (
self.engine_parameters
if self.engine_parameters is not None
else self.default_engine_parameters()
)
self.engine = create_engine(self.connection_string, **engine_parameters)
self.engine = create_engine(self.connection_string, **self.engine_parameters)

def execute_query(self, query: str, params: dict | None = None) -> pd.DataFrame:
"""Execute SQL query, with optional parameters, returning a pandas Dataframe.
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ExtractAnimalNames,
SQLExtractAnimalColors,
SQLExtractAnimalNames,
SQLQueryWithParameters,
)
from tests.fixtures.tasks.load import LoadAnimals
from tests.fixtures.tasks.pipelines import Animals, AnimalsDebug
Expand Down Expand Up @@ -150,6 +151,11 @@ def task_pipeline_animals_debug(pipeline_name):
return AnimalsDebug()


@pytest.fixture
def task_extract_sql_query_with_parameters(pipeline_name):
return SQLQueryWithParameters(pipeline=pipeline_name)


@pytest.fixture
def task_extract_animal_names_target(targets_directory, task_extract_animal_names):
shutil.copy(
Expand Down
29 changes: 27 additions & 2 deletions tests/fixtures/tasks/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from hrqb.base import PandasPickleTask, SQLQueryExtractTask
from hrqb.utils.data_warehouse import DWClient

SQLITE_CONNECTION_STRING = "sqlite:///tests/fixtures/sql_extract_task_test_data.sqlite"


class ExtractAnimalColors(PandasPickleTask):
pipeline = luigi.Parameter()
Expand Down Expand Up @@ -40,7 +42,7 @@ class SQLExtractAnimalColors(SQLQueryExtractTask):
@property
def dwclient(self) -> DWClient:
return DWClient(
connection_string="sqlite:///tests/fixtures/sql_extract_task_test_data.sqlite",
connection_string=SQLITE_CONNECTION_STRING,
engine_parameters={},
)

Expand All @@ -56,7 +58,7 @@ class SQLExtractAnimalNames(SQLQueryExtractTask):
@property
def dwclient(self) -> DWClient:
return DWClient(
connection_string="sqlite:///tests/fixtures/sql_extract_task_test_data.sqlite",
connection_string=SQLITE_CONNECTION_STRING,
engine_parameters={},
)

Expand All @@ -65,3 +67,26 @@ def sql_query(self) -> str:
return """
select animal_id, name from animal_name
"""


class SQLQueryWithParameters(SQLQueryExtractTask):
stage = luigi.Parameter("Extract")

@property
def dwclient(self) -> DWClient:
return DWClient(
connection_string=SQLITE_CONNECTION_STRING,
engine_parameters={},
)

@property
def sql_query(self) -> str:
return """
select
:foo_val as foo,
:bar_val as bar
"""

@property
def sql_query_parameters(self) -> dict:
return {"foo_val": 42, "bar_val": "apple"}
30 changes: 3 additions & 27 deletions tests/test_base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,35 +199,11 @@ def test_base_sql_task_sql_file_get_dataframe_return_dataframe(
assert isinstance(df, pd.DataFrame)


def test_base_sql_task_sql_query_parameters_used(
pipeline_name, sqlite_dwclient, task_sql_extract_animal_colors
):
foo_val, bar_val = 42, "apple"

class SQLQueryWithParameters(SQLQueryExtractTask):
stage = luigi.Parameter("Extract")

@property
def dwclient(self) -> DWClient:
return sqlite_dwclient

@property
def sql_query(self) -> str:
return """
select
:foo_val as foo,
:bar_val as bar
"""

@property
def sql_query_parameters(self) -> dict:
return {"foo_val": foo_val, "bar_val": bar_val}

task = SQLQueryWithParameters(pipeline=pipeline_name)
df = task.get_dataframe()
def test_base_sql_task_sql_query_parameters_used(task_extract_sql_query_with_parameters):
df = task_extract_sql_query_with_parameters.get_dataframe()
assert isinstance(df, pd.DataFrame)
row = df.iloc[0]
assert (row.foo, row.bar) == (foo_val, bar_val)
assert (row.foo, row.bar) == (42, "apple")


def test_base_sql_task_run_writes_pickled_dataframe(task_sql_extract_animal_names):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_data_warehouse_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def test_dwclient_default_engine_parameters():
assert DWClient.default_engine_parameters() == {"thick_mode": True}
assert DWClient().engine_parameters == {"thick_mode": True}


def test_dwclient_default_connection_string_from_env_var_success(
Expand All @@ -27,14 +27,14 @@ def test_dwclient_validate_connection_string_explicit_success(
):
monkeypatch.delenv("DATA_WAREHOUSE_CONNECTION_STRING")
dwclient = DWClient(connection_string=data_warehouse_connection_string)
assert dwclient.validate_data_warehouse_connection_string() is None
assert dwclient.verify_connection_string_set() is None


def test_dwclient_validate_connection_string_env_var_success(
monkeypatch, data_warehouse_connection_string
):
dwclient = DWClient(connection_string=data_warehouse_connection_string)
assert dwclient.validate_data_warehouse_connection_string() is None
assert dwclient.verify_connection_string_set() is None


def test_dwclient_validate_connection_string_missing_error(
Expand All @@ -43,7 +43,7 @@ def test_dwclient_validate_connection_string_missing_error(
monkeypatch.delenv("DATA_WAREHOUSE_CONNECTION_STRING")
dwclient = DWClient()
with pytest.raises(AttributeError, match="connection string not found"):
dwclient.validate_data_warehouse_connection_string()
dwclient.verify_connection_string_set()


def test_dwclient_sqlite_connection_string_and_engine_success(sqlite_dwclient):
Expand Down

0 comments on commit 82fe684

Please sign in to comment.