Skip to content

Commit

Permalink
XCOM push ORA error code in OracleStoredProcedure (#27319)
Browse files Browse the repository at this point in the history
Co-authored-by: Dov Benyomin Sohacheski <b@kloud.email>
Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
  • Loading branch information
3 people committed Dec 13, 2022
1 parent db5995a commit 43530f5
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 1 deletion.
17 changes: 16 additions & 1 deletion airflow/providers/oracle/operators/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
# under the License.
from __future__ import annotations

import re
import warnings
from typing import TYPE_CHECKING, Sequence

import oracledb

from airflow.models import BaseOperator
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.oracle.hooks.oracle import OracleHook
Expand Down Expand Up @@ -69,6 +72,9 @@ class OracleStoredProcedureOperator(BaseOperator):
:param oracle_conn_id: The :ref:`Oracle connection id <howto/connection:oracle>`
reference to a specific Oracle database.
:param parameters: (optional, templated) the parameters provided in the call
If *do_xcom_push* is *True*, the numeric exit code emitted by
the database is pushed to XCom under key ``ORA`` in case of failure.
"""

template_fields: Sequence[str] = (
Expand All @@ -93,4 +99,13 @@ def __init__(
def execute(self, context: Context):
self.log.info("Executing: %s", self.procedure)
hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters)
try:
return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters)
except oracledb.DatabaseError as e:
if not self.do_xcom_push or not context:
raise
ti = context["ti"]
code_match = re.search("^ORA-(\\d+):.+", str(e))
if code_match:
ti.xcom_push(key="ORA", value=code_match.group(1))
raise
60 changes: 60 additions & 0 deletions tests/providers/oracle/operators/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,53 @@
# under the License.
from __future__ import annotations

from random import randrange
from unittest import mock

import oracledb
import pendulum
import pytest

from airflow.models import DAG, DagModel, DagRun, TaskInstance
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.oracle.hooks.oracle import OracleHook
from airflow.providers.oracle.operators.oracle import OracleOperator, OracleStoredProcedureOperator
from airflow.utils.session import create_session
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType

DEFAULT_DATE = datetime(2017, 1, 1)


def create_context(task, persist_to_db=False, map_index=None):
if task.has_dag():
dag = task.dag
else:
dag = DAG(dag_id="dag", start_date=pendulum.now())
dag.add_task(task)
dag_run = DagRun(
run_id=DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE),
run_type=DagRunType.MANUAL,
dag_id=dag.dag_id,
)
task_instance = TaskInstance(task=task, run_id=dag_run.run_id)
task_instance.dag_run = dag_run
if map_index is not None:
task_instance.map_index = map_index
if persist_to_db:
with create_session() as session:
session.add(DagModel(dag_id=dag.dag_id))
session.add(dag_run)
session.add(task_instance)
session.commit()
return {
"dag": dag,
"ts": DEFAULT_DATE.isoformat(),
"task": task,
"ti": task_instance,
"task_instance": task_instance,
"run_id": "test",
}


class TestOracleOperator:
Expand Down Expand Up @@ -78,3 +118,23 @@ def test_execute(self, mock_run):
parameters=parameters,
handler=mock.ANY,
)

@mock.patch.object(OracleHook, "callproc", autospec=OracleHook.callproc)
def test_push_oracle_exit_to_xcom(self, mock_callproc):
# Test pulls the value previously pushed to xcom and checks if it's the same
procedure = "test_push"
oracle_conn_id = "oracle_default"
parameters = {"parameter": "value"}
task_id = "test_push"
ora_exit_code = "%05d" % randrange(10**5)
task = OracleStoredProcedureOperator(
procedure=procedure, oracle_conn_id=oracle_conn_id, parameters=parameters, task_id=task_id
)
context = create_context(task, persist_to_db=True)
mock_callproc.side_effect = oracledb.DatabaseError(
"ORA-" + ora_exit_code + ": This is a five-digit ORA error code"
)
try:
task.execute(context=context)
except oracledb.DatabaseError:
assert task.xcom_pull(key="ORA", context=context, task_ids=[task_id])[0] == ora_exit_code

0 comments on commit 43530f5

Please sign in to comment.