From 45daec4bfbdefffea791ede37a6e3cea9db743c6 Mon Sep 17 00:00:00 2001
From: Tian Gao
Date: Fri, 21 Nov 2025 13:46:50 -0800
Subject: [PATCH] Respect session timezone in udf

---
 .../scala/org/apache/spark/api/python/PythonRunner.scala | 4 ++++
 python/pyspark/sql/types.py                              | 5 ++---
 python/pyspark/worker.py                                 | 9 +++++++--
 3 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 7f1dc7fc86fc..1ed9110a1afb 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -199,6 +199,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     conf.get(PYTHON_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE)
   protected val hideTraceback: Boolean = false
   protected val simplifiedTraceback: Boolean = false
+  protected val sessionLocalTimeZone = conf.getOption("spark.sql.session.timeZone")

   // All the Python functions should have the same exec, version and envvars.
   protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars
@@ -282,6 +283,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     if (simplifiedTraceback) {
       envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
     }
+    if (sessionLocalTimeZone.isDefined) {
+      envVars.put("SPARK_SESSION_LOCAL_TIMEZONE", sessionLocalTimeZone.get)
+    }
     // SPARK-30299 this could be wrong with standalone mode when executor
     // cores might not be correct because it defaults to all cores on the box.
     val execCores = execCoresProp.map(_.toInt).getOrElse(conf.get(EXECUTOR_CORES))
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index fc534e48a7ae..cdd7cb495fc2 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -452,9 +452,8 @@ def needConversion(self) -> bool:

     def toInternal(self, dt: datetime.datetime) -> int:
         if dt is not None:
-            seconds = (
-                calendar.timegm(dt.utctimetuple()) if dt.tzinfo else time.mktime(dt.timetuple())
-            )
+            tzinfo = dt.tzinfo if dt.tzinfo else self.tz_info
+            seconds = calendar.timegm(dt.utctimetuple()) if tzinfo else time.mktime(dt.timetuple())
             return int(seconds) * 1000000 + dt.microsecond

     def fromInternal(self, ts: int) -> datetime.datetime:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 94e3b2728d08..c3852da877a8 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,6 +27,7 @@
 import inspect
 import itertools
 import json
+import zoneinfo
 from typing import Any, Callable, Iterable, Iterator, Optional, Tuple

 from pyspark.accumulators import (
@@ -3304,8 +3305,12 @@ def main(infile, outfile):
         sys.exit(-1)
     start_faulthandler_periodic_traceback()

-    # Use the local timezone to convert the timestamp
-    tz = datetime.datetime.now().astimezone().tzinfo
+    tzname = os.environ.get("SPARK_SESSION_LOCAL_TIMEZONE", None)
+    if tzname:
+        tz = zoneinfo.ZoneInfo(tzname)
+    else:
+        # Use the local timezone to convert the timestamp
+        tz = datetime.datetime.now().astimezone().tzinfo
     TimestampType.tz_info = tz

     check_python_version(infile)
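
Below is a minimal standalone sketch of the worker-side behavior this patch
introduces, for reviewers who want to trace it without a running cluster. The
helper names (resolve_session_tz, to_internal) are hypothetical; the env var
name and the conversion arithmetic mirror the diff above. It assumes Python
3.9+ for zoneinfo and an available tz database.

import calendar
import datetime
import os
import time
import zoneinfo


def resolve_session_tz():
    # Same decision as the patched worker.py main(): prefer the session
    # timezone forwarded by the JVM, else fall back to the worker's local tz.
    tzname = os.environ.get("SPARK_SESSION_LOCAL_TIMEZONE", None)
    if tzname:
        return zoneinfo.ZoneInfo(tzname)
    return datetime.datetime.now().astimezone().tzinfo


def to_internal(dt, tz_info):
    # Same arithmetic as the patched TimestampType.toInternal; tz_info plays
    # the role of the class attribute the worker sets. Note that for a naive
    # dt, utctimetuple() returns the fields unchanged, so timegm() reads them
    # as UTC.
    tzinfo = dt.tzinfo if dt.tzinfo else tz_info
    seconds = calendar.timegm(dt.utctimetuple()) if tzinfo else time.mktime(dt.timetuple())
    return int(seconds) * 1000000 + dt.microsecond


if __name__ == "__main__":
    os.environ["SPARK_SESSION_LOCAL_TIMEZONE"] = "America/New_York"
    tz = resolve_session_tz()
    aware = datetime.datetime(2025, 11, 21, 13, 46, 50, tzinfo=tz)
    print(to_internal(aware, tz))   # epoch microseconds for 13:46:50 EST
    naive = datetime.datetime(2025, 11, 21, 13, 46, 50)
    print(to_internal(naive, tz))   # naive path: fields interpreted as UTC

On the JVM side, the value forwarded through SPARK_SESSION_LOCAL_TIMEZONE is
whatever spark.sql.session.timeZone holds, e.g. set from PySpark with
spark.conf.set("spark.sql.session.timeZone", "America/New_York").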