Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 66 additions & 30 deletions be/src/udf/python/python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,28 +812,44 @@ def load(self) -> AdaptivePythonUDF:
class ModuleUDFLoader(UDFLoader):
"""Loads a UDF from a Python module file (.py)."""

# Module names that are forbidden for UDFs because they conflict with
# modules already imported by the server process. Loading a user module
# with one of these names would overwrite the entry in sys.modules and
# could break the server itself.
_FORBIDDEN_MODULE_NAMES: frozenset = frozenset({
"argparse", "base64", "gc", "importlib", "inspect", "ipaddress",
"json", "sys", "os", "traceback", "logging", "time", "threading",
"pickle", "abc", "contextlib", "typing", "datetime", "enum",
"pathlib", "pandas", "pd", "pyarrow", "pa", "flight",
"logging.handlers",
})

# Class-level lock dictionary for thread-safe module imports
# Using RLock allows the same thread to acquire the lock multiple times
# Key: (location, module_name) tuple to avoid conflicts between different locations
_import_locks: Dict[Tuple[str, str], threading.RLock] = {}

# Key for _import_locks: module_name only (not location)
# sys.modules is a global dict keyed by module name.
# we need to ensure that imports with the same module name
# do not interfere with each other across different threads,
# even if they come from different file paths.
_import_locks: Dict[str, threading.Lock] = {}
_import_locks_lock = threading.Lock()
_module_cache: Dict[Tuple[str, str], Any] = {}

# Key for _module_cache: location only
# since location already contains a unique function_id
_module_cache: Dict[str, Any] = {}
_module_cache_lock = threading.Lock()

@classmethod
def _get_import_lock(cls, location: str, module_name: str) -> threading.RLock:
def _get_import_lock(cls, module_name: str) -> threading.Lock:
"""
Get or create a reentrant lock for the given location and module name.
Get or create a reentrant lock for the given module name.

Uses double-checked locking pattern for optimal performance:
- Fast path: return existing lock without acquiring global lock
- Slow path: create new lock under global lock protection

Args:
location: The directory path where the module is located
module_name: The full module name to import
"""
cache_key = (location, module_name)
cache_key = module_name

# Fast path: check without lock (read-only, safe for most cases)
if cache_key in cls._import_locks:
Expand All @@ -843,7 +859,7 @@ def _get_import_lock(cls, location: str, module_name: str) -> threading.RLock:
with cls._import_locks_lock:
# Double-check: another thread might have created it while we waited
if cache_key not in cls._import_locks:
cls._import_locks[cache_key] = threading.RLock()
cls._import_locks[cache_key] = threading.Lock()
return cls._import_locks[cache_key]

def load(self) -> AdaptivePythonUDF:
Expand Down Expand Up @@ -911,20 +927,46 @@ def parse_symbol(self, symbol: str):

return package_name, module_name, func_name

@staticmethod
def _clear_modules_from_sys(full_module_name: str) -> None:
"""Remove a module and all its ancestor packages from sys.modules.

To prevent the same module from being polluted by old caches
when loaded from different paths.
e.g., the pkg under path_a affecting the pkg.mdu_a under path_b,
the ancestor chain is cleared after each import.

This ensures that subsequent imports always start from a fresh state.
"""
parts = full_module_name.split(".")
for i in range(len(parts)):
ancestor = ".".join(parts[: i + 1])
sys.modules.pop(ancestor, None)

def _get_or_import_module(self, location: str, full_module_name: str) -> Any:
"""
Get module from cache or import it (thread-safe).

Uses a location-aware cache to prevent conflicts when different locations
have modules with the same name.
The cache is keyed by location alone, which already contains a unique
function_id assigned by the FE catalog.
"""
cache_key = (location, full_module_name)
# Reject module names that would shadow server-critical modules
top_level_name = full_module_name.split(".")[0]
if top_level_name in ModuleUDFLoader._FORBIDDEN_MODULE_NAMES:
raise ImportError(
f"Module name '{full_module_name}' is not allowed for UDFs "
f"because it conflicts with a module used by the server. "
f"Please rename your module to avoid shadowing built-in or "
f"server-critical modules."
)

cache_key = location

# Use a per-(location, module) lock to prevent race conditions during import
import_lock = ModuleUDFLoader._get_import_lock(location, full_module_name)
# Use a per-module lock to prevent race conditions during import
import_lock = ModuleUDFLoader._get_import_lock(full_module_name)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Stale comment. This now uses a per-module lock (not per-(location, module) lock). Should be updated to:

# Use a per-module lock to prevent race conditions during import

with import_lock:
# Fast path: check location-aware cache first
# Fast path: check cache first
if cache_key in ModuleUDFLoader._module_cache:
cached_module = ModuleUDFLoader._module_cache[cache_key]
if cached_module is not None and (
Expand All @@ -935,25 +977,19 @@ def _get_or_import_module(self, location: str, full_module_name: str) -> Any:
else:
del ModuleUDFLoader._module_cache[cache_key]

# Before importing, clear any existing module with the same name in sys.modules
# that might have been loaded from a different location
if full_module_name in sys.modules:
existing_module = sys.modules[full_module_name]
existing_file = getattr(existing_module, "__file__", None)
# Check if the existing module is from a different location
if existing_file and not existing_file.startswith(location):
del sys.modules[full_module_name]
self._clear_modules_from_sys(full_module_name)

with temporary_sys_path(location):
try:
module = importlib.import_module(full_module_name)
# Store in location-aware cache
ModuleUDFLoader._module_cache[cache_key] = module
# Evict from sys.modules so future imports from a
# different location are not poisoned by this one.
self._clear_modules_from_sys(full_module_name)
return module
except Exception:
# Clean up any partially-imported modules
if full_module_name in sys.modules:
del sys.modules[full_module_name]
self._clear_modules_from_sys(full_module_name)
if cache_key in ModuleUDFLoader._module_cache:
del ModuleUDFLoader._module_cache[cache_key]
raise
Expand Down Expand Up @@ -2540,8 +2576,8 @@ def _clear_modules_from_location(self, location: str) -> list:
# This ensures no concurrent _get_or_import_module is in progress
# for this (location, module_name) pair.
for key in keys_to_remove:
loc, module_name = key
import_lock = ModuleUDFLoader._get_import_lock(loc, module_name)
_, module_name = key
import_lock = ModuleUDFLoader._get_import_lock(module_name)

with import_lock:
with ModuleUDFLoader._module_cache_lock:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !mid_forbidden_udaf_ok --
63

Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !pkg_isolation_1 --
6 2006

-- !pkg_isolation_2 --
6 1006

-- !pkg_isolation_3 --
1006 3006

-- !pkg_isolation_4 --
6 1006 2006 3006

Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !mid_forbidden_ok --
20

13 changes: 13 additions & 0 deletions regression-test/data/pythonudf_p0/test_pythonudf_pkg_isolation.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !pkg_isolation_1 --
15 105

-- !pkg_isolation_2 --
15 25

-- !pkg_isolation_3 --
25 205

-- !pkg_isolation_4 --
20 30 110 210

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !mid_forbidden_udtf_ok --
10 20
20 30
30 40

Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !pkg_isolation_1 --
1 201
2 202

-- !pkg_isolation_2 --
1 101
2 102

-- !pkg_isolation_3 --
101 301
102 302

-- !pkg_isolation_4 --
1 101 201 301
2 102 202 302

Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("test_pythonudaf_forbidden_module") {
// Test that top-level UDAF module names shadowing server-critical modules
// are rejected, while a packaged UDAF with a forbidden middle module name still works.

// Distribute the packaged UDAF zip to every BE node so the CREATE FUNCTION
// "file://" reference resolves on whichever backend executes the query.
def pyPath = """${context.file.parent}/udaf_scripts/python_udaf_forbidden_module.zip"""
scp_udf_file_to_all_be(pyPath)
def runtime_version = "3.8.10"
// Each case uses a top-level module name that collides with a module the
// Python UDF server itself imports; loading it must be rejected server-side.
def forbiddenCases = [
[name: "os", function: "py_forbidden_os_udaf", symbol: "os.ForbiddenUDAF"],
[name: "pathlib", function: "py_forbidden_pathlib_udaf", symbol: "pathlib.ForbiddenUDAF"],
[name: "pickle", function: "py_forbidden_pickle_udaf", symbol: "pickle.ForbiddenUDAF"],
[name: "datetime", function: "py_forbidden_datetime_udaf", symbol: "datetime.ForbiddenUDAF"],
]
log.info("Python Zip path: ${pyPath}".toString())

try {
// Create test table
sql """ DROP TABLE IF EXISTS udaf_forbidden_test """
sql """
CREATE TABLE udaf_forbidden_test (
id INT,
val INT
) DISTRIBUTED BY HASH(id) PROPERTIES("replication_num" = "1");
"""

sql """ INSERT INTO udaf_forbidden_test VALUES (1, 10), (2, 20), (3, 30); """

// Creating the function succeeds (the symbol is only resolved at execution
// time); the rejection is expected when the UDAF is first invoked.
forbiddenCases.each { forbiddenCase ->
sql """ DROP FUNCTION IF EXISTS ${forbiddenCase.function}(INT); """
sql """
CREATE AGGREGATE FUNCTION ${forbiddenCase.function}(INT)
RETURNS BIGINT
PROPERTIES (
"type" = "PYTHON_UDF",
"file" = "file://${pyPath}",
"symbol" = "${forbiddenCase.symbol}",
"runtime_version" = "${runtime_version}"
);
"""

// The server-side error message must contain this fragment; see the
// ImportError raised for forbidden module names in python_server.py.
test {
sql """ SELECT ${forbiddenCase.function}(val) FROM udaf_forbidden_test; """
exception "is not allowed for UDFs"
}
}

// Positive case: a forbidden name ("pathlib") is allowed as a NON-top-level
// module inside a safely named package, since only the top-level name is
// checked against the forbidden set.
sql """ DROP FUNCTION IF EXISTS py_mid_forbidden_udaf_ok(INT); """
sql """
CREATE AGGREGATE FUNCTION py_mid_forbidden_udaf_ok(INT)
RETURNS BIGINT
PROPERTIES (
"type" = "PYTHON_UDF",
"file" = "file://${pyPath}",
"symbol" = "safepkg_udaf.pathlib.SafePathlibUDAF",
"runtime_version" = "${runtime_version}"
);
"""

// Result checked against regression-test/data .out file (qt_ prefix).
qt_mid_forbidden_udaf_ok """ SELECT py_mid_forbidden_udaf_ok(val) AS result FROM udaf_forbidden_test; """

} finally {
// Best-effort cleanup: drop every function and the table even if an
// earlier assertion failed, so reruns start from a clean catalog.
forbiddenCases.each { forbiddenCase ->
try_sql("DROP FUNCTION IF EXISTS ${forbiddenCase.function}(INT);")
}
try_sql("DROP FUNCTION IF EXISTS py_mid_forbidden_udaf_ok(INT);")
try_sql("DROP TABLE IF EXISTS udaf_forbidden_test")
}
}
Loading
Loading