diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
index 2d4d74e828f09..c259eb2963e43 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
@@ -17,8 +17,10 @@
 from __future__ import annotations
 
 import logging
+import multiprocessing
 import os
 import sys
+import threading
 from concurrent.futures import ProcessPoolExecutor
 from concurrent.futures.process import BrokenProcessPool
 from datetime import datetime
@@ -819,59 +821,62 @@ def _on_task_instance_manual_state_change(
 
     def _execute(self, callable, callable_name: str, use_fork: bool = False):
         if use_fork:
-            self._fork_execute(callable, callable_name)
+            self._thread_execute(callable, callable_name)
         else:
             callable()
 
-    def _terminate_with_wait(self, process: psutil.Process):
-        process.terminate()
-        try:
-            # Waiting for max 3 seconds to make sure process can clean up before being killed.
-            process.wait(timeout=3)
-        except psutil.TimeoutExpired:
-            # If it's not dead by then, then force kill.
-            process.kill()
-
-    def _fork_execute(self, callable, callable_name: str):
-        self.log.debug("Will fork to execute OpenLineage process.")
-        pid = os.fork()
-        if pid:
-            process = psutil.Process(pid)
-            try:
-                self.log.debug("Waiting for process %s", pid)
-                process.wait(conf.execution_timeout())
-            except psutil.TimeoutExpired:
-                self.log.warning(
-                    "OpenLineage process with pid `%s` expired and will be terminated by listener. "
-                    "This has no impact on actual task execution status.",
-                    pid,
-                )
-                self._terminate_with_wait(process)
-            except BaseException:
-                # Kill the process directly.
-                self._terminate_with_wait(process)
-            self.log.debug("Process with pid %s finished - parent", pid)
-        else:
-            setproctitle(getproctitle() + " - OpenLineage - " + callable_name)
-            if not AIRFLOW_V_3_0_PLUS:
-                configure_orm(disable_connection_pool=True)
-            self.log.debug("Executing OpenLineage process - %s - pid %s", callable_name, os.getpid())
+    def _thread_execute(self, callable, callable_name: str):
+        """Execute callable in a daemon thread with timeout.
+
+        Replaces the previous ``os.fork()`` approach to avoid the
+        ``DeprecationWarning`` on Python 3.12+ about forking in
+        multi-threaded processes (which can also lead to deadlocks
+        when a thread holds a lock at fork time).
+
+        A daemon thread shares the parent's address space, so the
+        callable's closures (which capture non-picklable ORM models
+        and extractors) work without serialization.
+
+        Unlike the old fork path, a timed-out callable cannot be
+        terminated: ``join(timeout)`` only bounds how long the listener
+        waits, and the daemon thread keeps running in the background
+        until it finishes or the interpreter exits.
+        """
+        # Read the timeout once so the wait and the warning report the same value.
+        timeout = conf.execution_timeout()
+        self.log.debug("Will execute OpenLineage callable in thread.")
+
+        def _target():
+            self.log.debug(
+                "Executing OpenLineage process - %s - thread %s",
+                callable_name,
+                threading.current_thread().name,
+            )
             try:
                 callable()
-                self.log.debug("Process with current pid finishes after %s", callable_name)
+                self.log.debug("Thread finishes after %s", callable_name)
             except Exception:
                 self.log.warning(
-                    "OpenLineage %s process failed. This has no impact on actual task execution status.",
+                    "OpenLineage %s thread failed. This has no impact on actual task execution status.",
                     callable_name,
                     exc_info=True,
                 )
-            finally:
-                # os._exit(0) bypasses Python's atexit/stdio flush. Explicitly shut down
-                # logging so buffered records (including any warnings above) are flushed
-                # before the process exits. Without this, the final log lines are silently
-                # dropped, making failures invisible.
-                logging.shutdown()
-                os._exit(0)
+
+        thread = threading.Thread(
+            target=_target,
+            name=f"OpenLineage-{callable_name}",
+            daemon=True,
+        )
+        thread.start()
+        thread.join(timeout=timeout)
+        if thread.is_alive():
+            self.log.warning(
+                "OpenLineage thread %r did not finish within %s seconds. "
+                "Continuing without waiting. "
+                "This has no impact on actual task execution status.",
+                callable_name,
+                timeout,
+            )
 
     @property
     def executor(self) -> ProcessPoolExecutor:
@@ -879,5 +884,6 @@ def executor(self) -> ProcessPoolExecutor:
             self._executor = ProcessPoolExecutor(
                 max_workers=conf.dag_state_change_process_pool_size(),
                 initializer=_executor_initializer,
+                mp_context=multiprocessing.get_context("forkserver"),
             )
         return self._executor