diff --git a/python/pyspark/errors/utils.py b/python/pyspark/errors/utils.py index cd30463802840..9155bfb54abe8 100644 --- a/python/pyspark/errors/utils.py +++ b/python/pyspark/errors/utils.py @@ -21,6 +21,7 @@ import os import threading from typing import Any, Callable, Dict, Match, TypeVar, Type, Optional, TYPE_CHECKING +import pyspark from pyspark.errors.error_classes import ERROR_CLASSES_MAP if TYPE_CHECKING: @@ -164,9 +165,29 @@ def _capture_call_site(spark_session: "SparkSession", depth: int) -> str: The call site information is used to enhance error messages with the exact location in the user code that led to the error. """ - stack = list(reversed(inspect.stack())) + # Filtering out PySpark code and keeping user code only + pyspark_root = os.path.dirname(pyspark.__file__) + stack = [ + frame_info for frame_info in inspect.stack() if pyspark_root not in frame_info.filename + ] + selected_frames = stack[:depth] - call_sites = [f"{frame.filename}:{frame.lineno}" for frame in selected_frames] + + # We try import here since IPython is not a required dependency + try: + from IPython import get_ipython + + ipython = get_ipython() + except ImportError: + ipython = None + + # Identifying the cell is useful when the error is generated from IPython Notebook + if ipython: + call_sites = [ + f"line {frame.lineno} in cell [{ipython.execution_count}]" for frame in selected_frames + ] + else: + call_sites = [f"{frame.filename}:{frame.lineno}" for frame in selected_frames] call_sites_str = "\n".join(call_sites) return call_sites_str