Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class TableauOperator(BaseOperator):
:param blocking_refresh: By default will be blocking means it will wait until it has finished.
:param check_interval: time in seconds that the job should wait in
between each instance state checks until operation is completed
:param incremental_refresh: Whether to perform an incremental refresh instead of a full refresh.
Only applies to datasource and workbook refresh operations. Defaults to False (full refresh).
:param tableau_conn_id: The :ref:`Tableau Connection id <howto/connection:tableau>`
containing the credentials to authenticate to the Tableau Server.
"""
Expand All @@ -81,6 +83,7 @@ def __init__(
site_id: str | None = None,
blocking_refresh: bool = True,
check_interval: float = 20,
incremental_refresh: bool = False,
tableau_conn_id: str = "tableau_default",
**kwargs,
) -> None:
Expand All @@ -92,6 +95,7 @@ def __init__(
self.check_interval = check_interval
self.site_id = site_id
self.blocking_refresh = blocking_refresh
self.incremental_refresh = incremental_refresh
self.tableau_conn_id = tableau_conn_id

Comment thread
Subham-KRLX marked this conversation as resolved.
def execute(self, context: Context) -> str:
Expand All @@ -111,6 +115,13 @@ def execute(self, context: Context) -> str:
error_message = f"Method not found! Available methods for {self.resource}: {available_methods}"
raise AirflowException(error_message)

if self.incremental_refresh and self.method != "refresh":
self.log.warning(
"incremental_refresh parameter is set to True but method is '%s'. "
"This parameter only applies to 'refresh' operations and will be ignored.",
self.method,
)

with TableauHook(self.site_id, self.tableau_conn_id) as tableau_hook:
resource = getattr(tableau_hook.server, self.resource)
method = getattr(resource, self.method)
Expand All @@ -124,6 +135,10 @@ def execute(self, context: Context) -> str:
if not job_items:
raise ValueError("Tableau tasks.run returned no JobItem in response")
job_id = job_items[0].id
elif self.method == "refresh":
# For refresh operations, pass incremental_refresh parameter
response = method(resource_id, incremental=self.incremental_refresh)
job_id = response.id
else:
response = method(resource_id)
job_id = response.id
Expand Down
149 changes: 145 additions & 4 deletions providers/tableau/tests/unit/tableau/operators/test_tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_execute_workbooks(self, mock_tableau_hook):

job_id = operator.execute(context={})

mock_tableau_hook.server.workbooks.refresh.assert_called_once_with(2)
mock_tableau_hook.server.workbooks.refresh.assert_called_once_with(2, incremental=False)
assert mock_tableau_hook.server.workbooks.refresh.return_value.id == job_id

@patch("airflow.providers.tableau.operators.tableau.TableauHook")
Expand Down Expand Up @@ -106,7 +106,7 @@ def mock_wait_for_state(job_id, target_state, check_interval):

job_id = operator.execute(context={})

mock_tableau_hook.server.workbooks.refresh.assert_called_once_with(2)
mock_tableau_hook.server.workbooks.refresh.assert_called_once_with(2, incremental=False)
assert mock_tableau_hook.server.workbooks.refresh.return_value.id == job_id
mock_tableau_hook.wait_for_state.assert_called_once_with(
job_id=job_id, check_interval=20, target_state=TableauJobFinishCode.SUCCESS
Expand Down Expand Up @@ -135,7 +135,7 @@ def test_execute_datasources(self, mock_tableau_hook):

job_id = operator.execute(context={})

mock_tableau_hook.server.datasources.refresh.assert_called_once_with(2)
mock_tableau_hook.server.datasources.refresh.assert_called_once_with(2, incremental=False)
assert mock_tableau_hook.server.datasources.refresh.return_value.id == job_id

@patch("airflow.providers.tableau.operators.tableau.TableauHook")
Expand Down Expand Up @@ -167,7 +167,7 @@ def mock_wait_for_state(job_id, target_state, check_interval):

job_id = operator.execute(context={})

mock_tableau_hook.server.datasources.refresh.assert_called_once_with(2)
mock_tableau_hook.server.datasources.refresh.assert_called_once_with(2, incremental=False)
assert mock_tableau_hook.server.datasources.refresh.return_value.id == job_id
mock_tableau_hook.wait_for_state.assert_called_once_with(
job_id=job_id, check_interval=20, target_state=TableauJobFinishCode.SUCCESS
Expand Down Expand Up @@ -277,3 +277,144 @@ def test_get_resource_id(self):
resource_id = "res_id"
operator = TableauOperator(resource="tasks", find=resource_id, method="run", task_id="t", dag=None)
assert operator._get_resource_id(Mock()) == resource_id

@patch("airflow.providers.tableau.operators.tableau.TableauHook")
def test_execute_datasources_incremental_refresh(self, mock_tableau_hook):
"""
Test execute datasources with incremental refresh
"""
mock_tableau_hook.get_all = Mock(return_value=self.mock_datasources)
mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
operator = TableauOperator(
blocking_refresh=False,
find="ds_2",
resource="datasources",
incremental_refresh=True,
**self.kwargs,
)

job_id = operator.execute(context={})

mock_tableau_hook.server.datasources.refresh.assert_called_once_with(2, incremental=True)
assert mock_tableau_hook.server.datasources.refresh.return_value.id == job_id

@patch("airflow.providers.tableau.operators.tableau.TableauHook")
def test_execute_datasources_full_refresh(self, mock_tableau_hook):
"""
Test execute datasources with full refresh (default behavior)
"""
mock_tableau_hook.get_all = Mock(return_value=self.mock_datasources)
mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
operator = TableauOperator(
blocking_refresh=False,
find="ds_2",
resource="datasources",
incremental_refresh=False,
**self.kwargs,
)

job_id = operator.execute(context={})

mock_tableau_hook.server.datasources.refresh.assert_called_once_with(2, incremental=False)
assert mock_tableau_hook.server.datasources.refresh.return_value.id == job_id

@patch("airflow.providers.tableau.operators.tableau.TableauHook")
def test_execute_workbooks_incremental_refresh(self, mock_tableau_hook):
"""
Test execute workbooks with incremental refresh
"""
mock_tableau_hook.get_all = Mock(return_value=self.mocked_workbooks)
mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
operator = TableauOperator(
blocking_refresh=False,
find="wb_2",
resource="workbooks",
incremental_refresh=True,
**self.kwargs,
)

job_id = operator.execute(context={})

mock_tableau_hook.server.workbooks.refresh.assert_called_once_with(2, incremental=True)
assert mock_tableau_hook.server.workbooks.refresh.return_value.id == job_id

@patch("airflow.providers.tableau.operators.tableau.TableauHook")
def test_execute_workbooks_full_refresh(self, mock_tableau_hook):
"""
Test execute workbooks with full refresh (default behavior)
"""
mock_tableau_hook.get_all = Mock(return_value=self.mocked_workbooks)
mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
operator = TableauOperator(
blocking_refresh=False,
find="wb_2",
resource="workbooks",
incremental_refresh=False,
**self.kwargs,
)

job_id = operator.execute(context={})

mock_tableau_hook.server.workbooks.refresh.assert_called_once_with(2, incremental=False)
assert mock_tableau_hook.server.workbooks.refresh.return_value.id == job_id

@patch("airflow.providers.tableau.operators.tableau.TableauHook")
def test_execute_datasources_incremental_refresh_blocking(self, mock_tableau_hook):
"""
Test execute datasources with incremental refresh blocking
"""
mock_signed_in = [False]

def mock_hook_enter():
mock_signed_in[0] = True
return mock_tableau_hook

def mock_hook_exit(exc_type, exc_val, exc_tb):
mock_signed_in[0] = False

def mock_wait_for_state(job_id, target_state, check_interval):
if not mock_signed_in[0]:
raise Exception("Not signed in")
return True

mock_tableau_hook.return_value.__enter__ = Mock(side_effect=mock_hook_enter)
mock_tableau_hook.return_value.__exit__ = Mock(side_effect=mock_hook_exit)
mock_tableau_hook.wait_for_state = Mock(side_effect=mock_wait_for_state)
mock_tableau_hook.get_all = Mock(return_value=self.mock_datasources)

operator = TableauOperator(
find="ds_2",
resource="datasources",
incremental_refresh=True,
**self.kwargs,
)

job_id = operator.execute(context={})

mock_tableau_hook.server.datasources.refresh.assert_called_once_with(2, incremental=True)
assert mock_tableau_hook.server.datasources.refresh.return_value.id == job_id
mock_tableau_hook.wait_for_state.assert_called_once_with(
job_id=job_id, check_interval=20, target_state=TableauJobFinishCode.SUCCESS
)

@patch("airflow.providers.tableau.operators.tableau.TableauHook")
def test_incremental_refresh_warning_on_non_refresh_method(self, mock_tableau_hook, caplog):
"""
Test that a warning is logged when incremental_refresh is set but method is not 'refresh'
"""
mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
mock_tableau_hook.get_all = Mock(return_value=self.mock_datasources)

operator = TableauOperator(
find="ds_2",
resource="datasources",
method="delete",
incremental_refresh=True,
dag=None,
task_id="test",
)

operator.execute(context={})

assert "incremental_refresh parameter is set to True but method is 'delete'" in caplog.text
assert "This parameter only applies to 'refresh' operations" in caplog.text
Loading