From 4be2aaf162821c57925d51f75aeeab433d65490f Mon Sep 17 00:00:00 2001 From: Malthe Borch Date: Mon, 6 Dec 2021 16:10:42 +0100 Subject: [PATCH] Add optional 'procedure' parameter to Oracle operator --- airflow/providers/oracle/hooks/oracle.py | 2 +- airflow/providers/oracle/operators/oracle.py | 40 ++++++++++++++++++- .../providers/oracle/operators/test_oracle.py | 29 +++++++++++++- 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/airflow/providers/oracle/hooks/oracle.py b/airflow/providers/oracle/hooks/oracle.py index 9231b2ade3876..f07919777cfae 100644 --- a/airflow/providers/oracle/hooks/oracle.py +++ b/airflow/providers/oracle/hooks/oracle.py @@ -17,7 +17,7 @@ # under the License. from datetime import datetime -from typing import Dict, List, Optional, TypeVar, Union +from typing import Dict, List, Optional, TypeVar import cx_Oracle import numpy diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py index dcc07a207fbaa..b80d570b7dc95 100644 --- a/airflow/providers/oracle/operators/oracle.py +++ b/airflow/providers/oracle/operators/oracle.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Iterable, List, Mapping, Optional, Union +from typing import Dict, Iterable, List, Mapping, Optional, Union from airflow.models import BaseOperator from airflow.providers.oracle.hooks.oracle import OracleHook @@ -62,4 +62,40 @@ def __init__( def execute(self, context) -> None: self.log.info('Executing: %s', self.sql) hook = OracleHook(oracle_conn_id=self.oracle_conn_id) - hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + if self.sql: + hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + + +class OracleStoredProcedureOperator(BaseOperator): + """ + Executes stored procedure in a specific Oracle database. + + :param procedure: name of stored procedure to call (templated) + :type procedure: str + :param oracle_conn_id: The :ref:`Oracle connection id ` + reference to a specific Oracle database. + :type oracle_conn_id: str + :param parameters: (optional) the parameters provided in the call + :type parameters: dict or iterable + """ + + template_fields = ('procedure',) + ui_color = '#ededed' + + def __init__( + self, + *, + procedure: str, + oracle_conn_id: str = 'oracle_default', + parameters: Optional[Union[Dict, List]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.oracle_conn_id = oracle_conn_id + self.procedure = procedure + self.parameters = parameters + + def execute(self, context) -> None: + self.log.info('Executing: %s', self.procedure) + hook = OracleHook(oracle_conn_id=self.oracle_conn_id) + return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters) diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py index 8565efe6aea46..6fc77105a66eb 100644 --- a/tests/providers/oracle/operators/test_oracle.py +++ b/tests/providers/oracle/operators/test_oracle.py @@ -19,7 +19,7 @@ from unittest import mock from airflow.providers.oracle.hooks.oracle import OracleHook -from airflow.providers.oracle.operators.oracle import OracleOperator +from airflow.providers.oracle.operators.oracle import OracleOperator, OracleStoredProcedureOperator class TestOracleOperator(unittest.TestCase): @@ -46,3 +46,30 @@ def test_execute(self, mock_run): autocommit=autocommit, parameters=parameters, ) + + +class TestOracleStoredProcedureOperator(unittest.TestCase): + @mock.patch.object(OracleHook, 'run', autospec=OracleHook.run) + def test_execute(self, mock_run): + procedure = 'test' + oracle_conn_id = 'oracle_default' + parameters = {'parameter': 'value'} + autocommit = False + context = "test_context" + task_id = "test_task_id" + + operator = OracleStoredProcedureOperator( + procedure=procedure, + oracle_conn_id=oracle_conn_id, + parameters=parameters, + task_id=task_id, + ) + result = operator.execute(context=context) + assert result is mock_run.return_value + mock_run.assert_called_once_with( + mock.ANY, + 'BEGIN test(:parameter); END;', + autocommit=True, + parameters=parameters, + handler=mock.ANY, + )