Skip to content

Commit

Permalink
Add optional 'procedure' parameter to Oracle operator
Browse files Browse the repository at this point in the history
  • Loading branch information
malthe committed Dec 7, 2021
1 parent aa51274 commit 4be2aaf
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 4 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/oracle/hooks/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.

from datetime import datetime
from typing import Dict, List, Optional, TypeVar, Union
from typing import Dict, List, Optional, TypeVar

import cx_Oracle
import numpy
Expand Down
40 changes: 38 additions & 2 deletions airflow/providers/oracle/operators/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterable, List, Mapping, Optional, Union
from typing import Dict, Iterable, List, Mapping, Optional, Union

from airflow.models import BaseOperator
from airflow.providers.oracle.hooks.oracle import OracleHook
Expand Down Expand Up @@ -62,4 +62,40 @@ def __init__(
def execute(self, context) -> None:
self.log.info('Executing: %s', self.sql)
hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
if self.sql:
hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)


class OracleStoredProcedureOperator(BaseOperator):
"""
Executes stored procedure in a specific Oracle database.
:param procedure: name of stored procedure to call (templated)
:type procedure: str
:param oracle_conn_id: The :ref:`Oracle connection id <howto/connection:oracle>`
reference to a specific Oracle database.
:type oracle_conn_id: str
:param parameters: (optional) the parameters provided in the call
:type parameters: dict or iterable
"""

template_fields = ('procedure',)
ui_color = '#ededed'

def __init__(
self,
*,
procedure: str,
oracle_conn_id: str = 'oracle_default',
parameters: Optional[Union[Dict, List]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.oracle_conn_id = oracle_conn_id
self.procedure = procedure
self.parameters = parameters

def execute(self, context) -> None:
self.log.info('Executing: %s', self.procedure)
hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters)
29 changes: 28 additions & 1 deletion tests/providers/oracle/operators/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from unittest import mock

from airflow.providers.oracle.hooks.oracle import OracleHook
from airflow.providers.oracle.operators.oracle import OracleOperator
from airflow.providers.oracle.operators.oracle import OracleOperator, OracleStoredProcedureOperator


class TestOracleOperator(unittest.TestCase):
Expand All @@ -46,3 +46,30 @@ def test_execute(self, mock_run):
autocommit=autocommit,
parameters=parameters,
)


class TestOracleStoredProcedureOperator(unittest.TestCase):
@mock.patch.object(OracleHook, 'run', autospec=OracleHook.run)
def test_execute(self, mock_run):
procedure = 'test'
oracle_conn_id = 'oracle_default'
parameters = {'parameter': 'value'}
autocommit = False
context = "test_context"
task_id = "test_task_id"

operator = OracleStoredProcedureOperator(
procedure=procedure,
oracle_conn_id=oracle_conn_id,
parameters=parameters,
task_id=task_id,
)
result = operator.execute(context=context)
assert result is mock_run.return_value
mock_run.assert_called_once_with(
mock.ANY,
'BEGIN test(:parameter); END;',
autocommit=True,
parameters=parameters,
handler=mock.ANY,
)

0 comments on commit 4be2aaf

Please sign in to comment.