Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions python/pyspark/accumulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@

if TYPE_CHECKING:
from pyspark._typing import SupportsIAdd
import socketserver.BaseRequestHandler # type: ignore[import-not-found]
from socketserver import BaseRequestHandler


__all__ = ["Accumulator", "AccumulatorParam"]

T = TypeVar("T")
U = TypeVar("U", bound="SupportsIAdd")
U = TypeVar("U", bound=Union["SupportsIAdd", int, float, complex])

pickleSer = CPickleSerializer()

Expand Down Expand Up @@ -240,14 +240,14 @@ def zero(self, value: U) -> U:
return self.zero_value

def addInPlace(self, value1: U, value2: U) -> U:
value1 += value2 # type: ignore[operator]
value1 += value2 # type: ignore[operator, assignment]
return value1


# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) # type: ignore[type-var]
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) # type: ignore[type-var]
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) # type: ignore[type-var]
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)


class UpdateRequestHandler(socketserver.StreamRequestHandler):
Expand All @@ -256,10 +256,12 @@ class UpdateRequestHandler(socketserver.StreamRequestHandler):
server is shutdown.
"""

server: Union["AccumulatorTCPServer", "AccumulatorUnixServer"]

def handle(self) -> None:
from pyspark.accumulators import _accumulatorRegistry

auth_token = self.server.auth_token # type: ignore[attr-defined]
auth_token = self.server.auth_token

def poll(func: Callable[[], bool]) -> None:
poller = None
Expand All @@ -269,7 +271,7 @@ def poll(func: Callable[[], bool]) -> None:
poller = select.poll()
poller.register(self.rfile, select.POLLIN)

while not self.server.server_shutdown: # type: ignore[attr-defined]
while not self.server.server_shutdown:
# Poll every 1 second for new data -- don't block in case of shutdown.
if poller is not None:
r = []
Expand Down Expand Up @@ -302,6 +304,7 @@ def accum_updates() -> bool:
return False

def authenticate_and_accum_updates() -> bool:
assert auth_token is not None
received_token: Union[bytes, str] = self.rfile.read(len(auth_token))
if isinstance(received_token, bytes):
received_token = received_token.decode("utf-8")
Expand Down Expand Up @@ -329,7 +332,7 @@ class AccumulatorTCPServer(socketserver.TCPServer):
def __init__(
self,
server_address: Tuple[str, int],
RequestHandlerClass: Type["socketserver.BaseRequestHandler"],
RequestHandlerClass: Type["BaseRequestHandler"],
auth_token: str,
):
super().__init__(server_address, RequestHandlerClass)
Expand All @@ -348,25 +351,22 @@ def shutdown(self) -> None:
class AccumulatorUnixServer(socketserver.UnixStreamServer):
server_shutdown = False

def __init__(
self, socket_path: str, RequestHandlerClass: Type[socketserver.BaseRequestHandler]
):
def __init__(self, socket_path: str, RequestHandlerClass: Type["BaseRequestHandler"]):
super().__init__(socket_path, RequestHandlerClass)
self.auth_token = None

def shutdown(self) -> None:
self.server_shutdown = True
super().shutdown()
self.server_close()
if os.path.exists(self.server_address): # type: ignore[arg-type]
os.remove(self.server_address) # type: ignore[arg-type]
assert isinstance(self.server_address, str)
if os.path.exists(self.server_address):
os.remove(self.server_address)

else:

class AccumulatorUnixServer(socketserver.TCPServer): # type: ignore[no-redef]
def __init__(
self, socket_path: str, RequestHandlerClass: Type[socketserver.BaseRequestHandler]
):
def __init__(self, socket_path: str, RequestHandlerClass: Type["BaseRequestHandler"]):
raise NotImplementedError(
"Unix Domain Sockets are not supported on this platform. "
"Please disable it by setting spark.python.unix.domain.socket.enabled to false."
Expand All @@ -377,13 +377,14 @@ def _start_update_server(
auth_token: str, is_unix_domain_sock: bool, socket_path: Optional[str] = None
) -> Union[AccumulatorTCPServer, AccumulatorUnixServer]:
"""Start a TCP or Unix Domain Socket server for accumulator updates."""
server: Union[AccumulatorTCPServer, AccumulatorUnixServer]
if is_unix_domain_sock:
assert socket_path is not None
if os.path.exists(socket_path):
os.remove(socket_path)
server = AccumulatorUnixServer(socket_path, UpdateRequestHandler)
else:
server = AccumulatorTCPServer(("localhost", 0), UpdateRequestHandler, auth_token) # type: ignore[assignment]
server = AccumulatorTCPServer(("localhost", 0), UpdateRequestHandler, auth_token)

thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
Expand Down