Skip to content

Commit

Permalink
Add method 'callproc' on Oracle hook
Browse files Browse the repository at this point in the history
  • Loading branch information
malthe committed Dec 7, 2021
1 parent 93a6e20 commit aa51274
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 1 deletion.
67 changes: 66 additions & 1 deletion airflow/providers/oracle/hooks/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,26 @@
# under the License.

from datetime import datetime
from typing import List, Optional
from typing import Dict, List, Optional, TypeVar, Union

import cx_Oracle
import numpy

from airflow.hooks.dbapi import DbApiHook

PARAM_TYPES = {bool, float, int, str}

ParameterType = TypeVar('ParameterType', Dict, List, None)


def _map_param(value):
if value in PARAM_TYPES:
# In this branch, value is a Python type; calling it produces
# an instance of the type which is understood by the Oracle driver
# in the out parameter mapping mechanism.
value = value()
return value


class OracleHook(DbApiHook):
"""
Expand Down Expand Up @@ -266,3 +279,55 @@ def bulk_insert_rows(
self.log.info('[%s] inserted %s rows', table, row_count)
cursor.close()
conn.close() # type: ignore[attr-defined]

def callproc(
self,
identifier: str,
autocommit: bool = False,
parameters: ParameterType = None,
) -> ParameterType:
"""
Call the stored procedure identified by the provided string.
Any 'OUT parameters' must be provided with a value of either the
expected Python type (e.g., `int`) or an instance of that type.
The return value is a list or mapping that includes parameters in
both directions; the actual return type depends on the type of the
provided `parameters` argument.
See
https://cx-oracle.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.var
for further reference.
"""
if parameters is None:
parameters = ()

args = ",".join(
f":{name}"
for name in (parameters if isinstance(parameters, dict) else range(1, len(parameters) + 1))
)

sql = f"BEGIN {identifier}({args}); END;"

def handler(cursor):
if isinstance(cursor.bindvars, list):
return [v.getvalue() for v in cursor.bindvars]

if isinstance(cursor.bindvars, dict):
return {n: v.getvalue() for (n, v) in cursor.bindvars.items()}

raise TypeError(f"Unexpected bindvars: {cursor.bindvars!r}")

result = self.run(
sql,
autocommit=autocommit,
parameters=(
{name: _map_param(value) for (name, value) in parameters.items()}
if isinstance(parameters, dict)
else [_map_param(value) for value in parameters]
),
handler=handler,
)

return result
38 changes: 38 additions & 0 deletions tests/providers/oracle/hooks/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,41 @@ def test_bulk_insert_rows_no_rows(self):
rows = []
with pytest.raises(ValueError):
self.db_hook.bulk_insert_rows('table', rows)

def test_callproc_dict(self):
parameters = {"a": 1, "b": 2, "c": 3}

class bindvar(int):
def getvalue(self):
return self

self.cur.bindvars = {k: bindvar(v) for k, v in parameters.items()}
result = self.db_hook.callproc('proc', True, parameters)
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', parameters)]
assert result == parameters

def test_callproc_list(self):
parameters = [1, 2, 3]

class bindvar(int):
def getvalue(self):
return self

self.cur.bindvars = list(map(bindvar, parameters))
result = self.db_hook.callproc('proc', True, parameters)
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', parameters)]
assert result == parameters

def test_callproc_out_param(self):
parameters = [1, int, float, bool, str]

def bindvar(value):
m = mock.Mock()
m.getvalue.return_value = value
return m

self.cur.bindvars = [bindvar(p() if type(p) is type else p) for p in parameters]
result = self.db_hook.callproc('proc', True, parameters)
expected = [1, 0, 0.0, False, '']
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)]
assert result == expected

0 comments on commit aa51274

Please sign in to comment.