Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method 'callproc' on Oracle hook #20072

Merged
merged 3 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion airflow/providers/oracle/hooks/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,26 @@
# under the License.

from datetime import datetime
from typing import List, Optional
from typing import Dict, List, Optional, TypeVar

import cx_Oracle
import numpy

from airflow.hooks.dbapi import DbApiHook

PARAM_TYPES = {bool, float, int, str}

ParameterType = TypeVar('ParameterType', Dict, List, None)


def _map_param(value):
if value in PARAM_TYPES:
# In this branch, value is a Python type; calling it produces
# an instance of the type which is understood by the Oracle driver
# in the out parameter mapping mechanism.
value = value()
return value


class OracleHook(DbApiHook):
"""
Expand Down Expand Up @@ -266,3 +279,55 @@ def bulk_insert_rows(
self.log.info('[%s] inserted %s rows', table, row_count)
cursor.close()
conn.close() # type: ignore[attr-defined]

def callproc(
self,
identifier: str,
autocommit: bool = False,
parameters: ParameterType = None,
) -> ParameterType:
"""
Call the stored procedure identified by the provided string.

Any 'OUT parameters' must be provided with a value of either the
expected Python type (e.g., `int`) or an instance of that type.

The return value is a list or mapping that includes parameters in
both directions; the actual return type depends on the type of the
provided `parameters` argument.

See
https://cx-oracle.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.var
for further reference.
"""
if parameters is None:
parameters = ()

args = ",".join(
f":{name}"
for name in (parameters if isinstance(parameters, dict) else range(1, len(parameters) + 1))
)

sql = f"BEGIN {identifier}({args}); END;"

def handler(cursor):
if isinstance(cursor.bindvars, list):
return [v.getvalue() for v in cursor.bindvars]

if isinstance(cursor.bindvars, dict):
return {n: v.getvalue() for (n, v) in cursor.bindvars.items()}

raise TypeError(f"Unexpected bindvars: {cursor.bindvars!r}")

result = self.run(
sql,
autocommit=autocommit,
parameters=(
{name: _map_param(value) for (name, value) in parameters.items()}
if isinstance(parameters, dict)
else [_map_param(value) for value in parameters]
),
handler=handler,
)

return result
40 changes: 38 additions & 2 deletions airflow/providers/oracle/operators/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterable, List, Mapping, Optional, Union
from typing import Dict, Iterable, List, Mapping, Optional, Union

from airflow.models import BaseOperator
from airflow.providers.oracle.hooks.oracle import OracleHook
Expand Down Expand Up @@ -62,4 +62,40 @@ def __init__(
def execute(self, context) -> None:
self.log.info('Executing: %s', self.sql)
hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
if self.sql:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these should be two separate operators- one for running SQL statements and one for calling procedures.

Copy link
Contributor Author

@malthe malthe Dec 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mik-laj it does perhaps make sense considering that autocommit (see also #20085) is disabled in the operator that runs SQL statements while for calling stored procedures – in the context of an Airflow operator – it makes sense to autocommit (given that the stored procedure itself runs inside a subtransaction).

That is, for a stored procedure operator I would have autocommit enabled as the default setting.

(Probably even for the SQL statement operator, but such defaults are not easily changed.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mik-laj I have split the functionality now, introducing a new operator OracleStoredProcedureOperator. This always has autocommit enabled – because that's really what makes sense.

hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)


class OracleStoredProcedureOperator(BaseOperator):
"""
Executes stored procedure in a specific Oracle database.

:param procedure: name of stored procedure to call (templated)
:type procedure: str
:param oracle_conn_id: The :ref:`Oracle connection id <howto/connection:oracle>`
reference to a specific Oracle database.
:type oracle_conn_id: str
:param parameters: (optional) the parameters provided in the call
:type parameters: dict or iterable
"""

template_fields = ('procedure',)
ui_color = '#ededed'

def __init__(
self,
*,
procedure: str,
oracle_conn_id: str = 'oracle_default',
parameters: Optional[Union[Dict, List]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.oracle_conn_id = oracle_conn_id
self.procedure = procedure
self.parameters = parameters

def execute(self, context) -> None:
self.log.info('Executing: %s', self.procedure)
hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters)
38 changes: 38 additions & 0 deletions tests/providers/oracle/hooks/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,41 @@ def test_bulk_insert_rows_no_rows(self):
rows = []
with pytest.raises(ValueError):
self.db_hook.bulk_insert_rows('table', rows)

def test_callproc_dict(self):
parameters = {"a": 1, "b": 2, "c": 3}

class bindvar(int):
def getvalue(self):
return self

self.cur.bindvars = {k: bindvar(v) for k, v in parameters.items()}
result = self.db_hook.callproc('proc', True, parameters)
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', parameters)]
assert result == parameters

def test_callproc_list(self):
parameters = [1, 2, 3]

class bindvar(int):
def getvalue(self):
return self

self.cur.bindvars = list(map(bindvar, parameters))
result = self.db_hook.callproc('proc', True, parameters)
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', parameters)]
assert result == parameters

def test_callproc_out_param(self):
parameters = [1, int, float, bool, str]

def bindvar(value):
m = mock.Mock()
m.getvalue.return_value = value
return m

self.cur.bindvars = [bindvar(p() if type(p) is type else p) for p in parameters]
result = self.db_hook.callproc('proc', True, parameters)
expected = [1, 0, 0.0, False, '']
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)]
assert result == expected
28 changes: 27 additions & 1 deletion tests/providers/oracle/operators/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from unittest import mock

from airflow.providers.oracle.hooks.oracle import OracleHook
from airflow.providers.oracle.operators.oracle import OracleOperator
from airflow.providers.oracle.operators.oracle import OracleOperator, OracleStoredProcedureOperator


class TestOracleOperator(unittest.TestCase):
Expand All @@ -46,3 +46,29 @@ def test_execute(self, mock_run):
autocommit=autocommit,
parameters=parameters,
)


class TestOracleStoredProcedureOperator(unittest.TestCase):
@mock.patch.object(OracleHook, 'run', autospec=OracleHook.run)
def test_execute(self, mock_run):
procedure = 'test'
oracle_conn_id = 'oracle_default'
parameters = {'parameter': 'value'}
context = "test_context"
task_id = "test_task_id"

operator = OracleStoredProcedureOperator(
procedure=procedure,
oracle_conn_id=oracle_conn_id,
parameters=parameters,
task_id=task_id,
)
result = operator.execute(context=context)
assert result is mock_run.return_value
mock_run.assert_called_once_with(
mock.ANY,
'BEGIN test(:parameter); END;',
autocommit=True,
parameters=parameters,
handler=mock.ANY,
)