From 1642e928478c8c20bae5203ecf2e4d659aca7692 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Wed, 24 Jan 2024 00:43:41 -0800
Subject: [PATCH] [SPARK-46823][CONNECT][PYTHON] `LocalDataToArrowConversion` should check the nullability

### What changes were proposed in this pull request?

Make `LocalDataToArrowConversion` check the nullability declared in the user-provided schema when converting local data to Arrow on the client side.

### Why are the changes needed?

This check was missing: a `None` value for a field declared `nullable=False` was packed into the Arrow batch as-is and only failed later on the server, with an obscure `IllegalStateException` raised while reading the Arrow vector.
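Roughly, the fix threads each field's `nullable` flag into the client-side converters so that a `None` for a non-nullable field is rejected before anything is written into an Arrow batch. A minimal sketch of the idea (simplified, hypothetical names; not the exact code in `conversion.py`):

```
from typing import Any, Callable

from pyspark.errors import PySparkValueError
from pyspark.sql.types import DataType


def null_checking_converter(dataType: DataType, nullable: bool) -> Callable[[Any], Any]:
    # Illustration only: reject None for non-nullable fields on the client,
    # instead of letting the server fail while reading the Arrow vector.
    def convert(value: Any) -> Any:
        if value is None and not nullable:
            raise PySparkValueError(f"input for {dataType} must not be None")
        return value

    return convert
```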
### Does this PR introduce _any_ user-facing change?

Yes.

```
data = [("asd", None)]
schema = StructType(
    [
        StructField("name", StringType(), nullable=True),
        StructField("age", IntegerType(), nullable=False),
    ]
)
```

before:

```
In [3]: df = spark.createDataFrame([("asd", None)], schema)

In [4]: df
Out[4]: 24/01/24 12:08:28 ERROR ErrorUtils: Spark Connect RPC error during: analyze. UserId: ruifeng.zheng. SessionId: cd692bb1-d503-4043-a9db-d29cb5c16517.
java.lang.IllegalStateException: Value at index is null
    at org.apache.arrow.vector.IntVector.get(IntVector.java:107)
    at org.apache.spark.sql.vectorized.ArrowColumnVector$IntAccessor.getInt(ArrowColumnVector.java:338)
    at org.apache.spark.sql.vectorized.ArrowColumnVector.getInt(ArrowColumnVector.java:88)
    at org.apache.spark.sql.vectorized.ColumnarBatchRow.getInt(ColumnarBatchRow.java:109)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
    at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
    at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
    at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
    at scala.collection.immutable.List.prependedAll(List.scala:153)
    at scala.collection.immutable.List$.from(List.scala:684)
    at scala.collection.immutable.List$.from(List.scala:681)
    at scala.collection.SeqFactory$Delegate.from(Factory.scala:306)
    at scala.collection.immutable.Seq$.from(Seq.scala:42)
    at scala.collection.IterableOnceOps.toSeq(IterableOnce.scala:1326)
    at scala.collection.IterableOnceOps.toSeq$(IterableOnce.scala:1326)
    at scala.collection.AbstractIterator.toSeq(Iterator.scala:1300)
    at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformLocalRelation(SparkConnectPlanner.scala:1239)
    at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:139)
    at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.process(SparkConnectAnalyzeHandler.scala:59)
    at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1(SparkConnectAnalyzeHandler.scala:43)
    at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1$adapted(SparkConnectAnalyzeHandler.scala:42)
    at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:289)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:918)
    at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:289)
    at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
    at org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:80)
    at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:182)
    at org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:79)
    at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:288)
    at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.handle(SparkConnectAnalyzeHandler.scala:42)
    at org.apache.spark.sql.connect.service.SparkConnectService.analyzePlan(SparkConnectService.scala:95)
    at org.apache.spark.connect.proto.SparkConnectServiceGrpc$MethodHandlers.invoke(SparkConnectServiceGrpc.java:907)
    at org.sparkproject.connect.grpc.io.grpc.stub.ServerCalls$UnaryServerCallHandler$UnaryServerCallListener.onHalfClose(ServerCalls.java:182)
    at org.sparkproject.connect.grpc.io.grpc.internal.ServerCallImpl$ServerStreamListenerImpl.halfClosed(ServerCallImpl.java:351)
    at org.sparkproject.connect.grpc.io.grpc.internal.ServerImpl$JumpToApplicationThreadServerStreamListener$1HalfClosed.runInContext(ServerImpl.java:860)
    at org.sparkproject.connect.grpc.io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
    at org.sparkproject.connect.grpc.io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
    at java.base/java.lang.Thread.run(Thread.java:833)

[... the same ERROR and stack trace is logged two more times ...]

---------------------------------------------------------------------------
SparkConnectGrpcException                 Traceback (most recent call last)
File ~/.dev/miniconda3/envs/spark_dev_311/lib/python3.11/site-packages/IPython/core/formatters.py:708, in PlainTextFormatter.__call__(self, obj)
    701 stream = StringIO()
    702 printer = pretty.RepresentationPrinter(stream, self.verbose,
    703     self.max_width, self.newline,
    704     max_seq_length=self.max_seq_length,
    705     singleton_pprinters=self.singleton_printers,
    706     type_pprinters=self.type_printers,
    707     deferred_pprinters=self.deferred_printers)
--> 708 printer.pretty(obj)
    709 printer.flush()
    710 return stream.getvalue()

File ~/.dev/miniconda3/envs/spark_dev_311/lib/python3.11/site-packages/IPython/lib/pretty.py:410, in RepresentationPrinter.pretty(self, obj)
    407     return meth(obj, self, cycle)
    408 if cls is not object \
    409         and callable(cls.__dict__.get('__repr__')):
--> 410     return _repr_pprint(obj, self, cycle)
    412 return _default_pprint(obj, self, cycle)
    413 finally:

File ~/.dev/miniconda3/envs/spark_dev_311/lib/python3.11/site-packages/IPython/lib/pretty.py:778, in _repr_pprint(obj, p, cycle)
    776 """A pprint that just redirects to the normal repr function."""
    777 # Find newlines and replace them with p.break_()
--> 778 output = repr(obj)
    779 lines = output.splitlines()
    780 with p.group():

File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:141, in DataFrame.__repr__(self)
    135 if repl_eager_eval_enabled == "true":
    136     return self._show_string(
    137         n=int(cast(str, repl_eager_eval_max_num_rows)),
    138         truncate=int(cast(str, repl_eager_eval_truncate)),
    139         vertical=False,
    140     )
--> 141 return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))

File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:238, in DataFrame.dtypes(self)
    236 @property
    237 def dtypes(self) -> List[Tuple[str, str]]:
--> 238     return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]

File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:1786, in DataFrame.schema(self)
   1783 @property
   1784 def schema(self) -> StructType:
   1785     query = self._plan.to_proto(self._session.client)
-> 1786 return self._session.client.schema(query)

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:921, in SparkConnectClient.schema(self, plan)
    917 """
    918 Return schema for given plan.
    919 """
    920 logger.info(f"Schema for plan: {self._proto_to_string(plan)}")
--> 921 schema = self._analyze(method="schema", plan=plan).schema
    922 assert schema is not None
    923 # Server side should populate the struct field which is the schema.

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1107, in SparkConnectClient._analyze(self, method, **kwargs)
   1105     raise SparkConnectException("Invalid state during retry exception handling.")
   1106 except Exception as error:
-> 1107     self._handle_error(error)

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1525, in SparkConnectClient._handle_error(self, error)
   1523 self.thread_local.inside_error_handling = True
   1524 if isinstance(error, grpc.RpcError):
-> 1525     self._handle_rpc_error(error)
   1526 elif isinstance(error, ValueError):
   1527     if "Cannot invoke RPC" in str(error) and "closed" in str(error):

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1595, in SparkConnectClient._handle_rpc_error(self, rpc_error)
   1592 info = error_details_pb2.ErrorInfo()
   1593 d.Unpack(info)
-> 1595 raise convert_exception(
   1596     info,
   1597     status.message,
   1598     self._fetch_enriched_error(info),
   1599     self._display_server_stack_trace(),
   1600 ) from None
   1602 raise SparkConnectGrpcException(status.message) from None
   1603 else:

SparkConnectGrpcException: (java.lang.IllegalStateException) Value at index is null

JVM stacktrace:
java.lang.IllegalStateException
    at org.apache.arrow.vector.IntVector.get(IntVector.java:107)
    at org.apache.spark.sql.vectorized.ArrowColumnVector$IntAccessor.getInt(ArrowColumnVector.java:338)
    at org.apache.spark.sql.vectorized.ArrowColumnVector.getInt(ArrowColumnVector.java:88)
    at org.apache.spark.sql.vectorized.ColumnarBatchRow.getInt(ColumnarBatchRow.java:109)
    [... the remaining frames are the same server-side stack trace shown above ...]
```

after:

```
---------------------------------------------------------------------------
PySparkValueError                         Traceback (most recent call last)
Cell In[3], line 1
----> 1 df = spark.createDataFrame([("asd", None)], schema)

File ~/Dev/spark/python/pyspark/sql/connect/session.py:538, in SparkSession.createDataFrame(self, data, schema)
    533 from pyspark.sql.connect.conversion import LocalDataToArrowConversion
    535 # Spark Connect will try its best to build the Arrow table with the
    536 # inferred schema in the client side, and then rename the columns and
    537 # cast the datatypes in the server side.
--> 538 _table = LocalDataToArrowConversion.convert(_data, _schema)
    540 # TODO: Beside the validation on number of columns, we should also check
    541 # whether the Arrow Schema is compatible with the user provided Schema.
    542 if _num_cols is not None and _num_cols != _table.shape[1]:

File ~/Dev/spark/python/pyspark/sql/connect/conversion.py:351, in LocalDataToArrowConversion.convert(data, schema)
    342     raise PySparkValueError(
    343         error_class="AXIS_LENGTH_MISMATCH",
    344         message_parameters={
    (...)
    347         },
    348     )
    350 for i in range(len(column_names)):
--> 351     pylist[i].append(column_convs[i](item[i]))
    353 pa_schema = to_arrow_schema(
    354     StructType(
    355         [
    (...)
    361     )
    362 )
    364 return pa.Table.from_arrays(pylist, schema=pa_schema)

File ~/Dev/spark/python/pyspark/sql/connect/conversion.py:297, in LocalDataToArrowConversion._create_converter.<locals>.convert_other(value)
    295 def convert_other(value: Any) -> Any:
    296     if value is None:
--> 297         raise PySparkValueError(f"input for {dataType} must not be None")
    298     return value

PySparkValueError: input for IntegerType() must not be None
```
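With the check in place, a user who hits this error can either declare the field nullable or supply a concrete value. A hedged usage sketch (it assumes an existing Spark Connect session bound to the name `spark`):

```
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

# Either declare the field nullable so that None is accepted ...
nullable_schema = StructType(
    [
        StructField("name", StringType(), nullable=True),
        StructField("age", IntegerType(), nullable=True),
    ]
)
df = spark.createDataFrame([("asd", None)], nullable_schema)

# ... or keep the non-nullable schema and provide a real value.
strict_schema = StructType(
    [
        StructField("name", StringType(), nullable=True),
        StructField("age", IntegerType(), nullable=False),
    ]
)
df = spark.createDataFrame([("asd", 42)], strict_schema)
```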
### How was this patch tested?

Added a unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #44861 from zhengruifeng/connect_check_nullable.

Authored-by: Ruifeng Zheng
Signed-off-by: Dongjoon Hyun
---
 python/pyspark/sql/connect/conversion.py      | 78 ++++++++++++++++---
 .../sql/tests/connect/test_connect_basic.py   | 12 +++
 2 files changed, 80 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index fb5a2d4b17b10..c86ee9c75fec9 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -61,14 +61,23 @@ class LocalDataToArrowConversion:
     """
 
     @staticmethod
-    def _need_converter(dataType: DataType) -> bool:
-        if isinstance(dataType, NullType):
+    def _need_converter(
+        dataType: DataType,
+        nullable: bool = True,
+    ) -> bool:
+        if not nullable:
+            # always check the nullability
+            return True
+        elif isinstance(dataType, NullType):
+            # always check the nullability
             return True
         elif isinstance(dataType, StructType):
             # Struct maybe rows, should convert to dict.
             return True
         elif isinstance(dataType, ArrayType):
-            return LocalDataToArrowConversion._need_converter(dataType.elementType)
+            return LocalDataToArrowConversion._need_converter(
+                dataType.elementType, dataType.containsNull
+            )
         elif isinstance(dataType, MapType):
             # Different from PySpark, here always needs conversion,
             # since an Arrow Map requires a list of tuples.
@@ -90,26 +99,41 @@ def _need_converter(dataType: DataType) -> bool:
         return False
 
     @staticmethod
-    def _create_converter(dataType: DataType) -> Callable:
+    def _create_converter(
+        dataType: DataType,
+        nullable: bool = True,
+    ) -> Callable:
         assert dataType is not None and isinstance(dataType, DataType)
+        assert isinstance(nullable, bool)
 
-        if not LocalDataToArrowConversion._need_converter(dataType):
+        if not LocalDataToArrowConversion._need_converter(dataType, nullable):
             return lambda value: value
 
         if isinstance(dataType, NullType):
-            return lambda value: None
+
+            def convert_null(value: Any) -> Any:
+                if value is not None:
+                    raise PySparkValueError(f"input for {dataType} must be None, but got {value}")
+                return None
+
+            return convert_null
 
         elif isinstance(dataType, StructType):
             field_names = dataType.fieldNames()
             dedup_field_names = _dedup_names(dataType.names)
 
             field_convs = [
-                LocalDataToArrowConversion._create_converter(field.dataType)
+                LocalDataToArrowConversion._create_converter(
+                    field.dataType,
+                    field.nullable,
+                )
                 for field in dataType.fields
             ]
 
             def convert_struct(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, (tuple, dict)) or hasattr(
@@ -143,10 +167,15 @@ def convert_struct(value: Any) -> Any:
             return convert_struct
 
         elif isinstance(dataType, ArrayType):
-            element_conv = LocalDataToArrowConversion._create_converter(dataType.elementType)
+            element_conv = LocalDataToArrowConversion._create_converter(
+                dataType.elementType,
+                dataType.containsNull,
+            )
 
             def convert_array(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, (list, array.array))
@@ -156,10 +185,15 @@ def convert_array(value: Any) -> Any:
 
         elif isinstance(dataType, MapType):
             key_conv = LocalDataToArrowConversion._create_converter(dataType.keyType)
-            value_conv = LocalDataToArrowConversion._create_converter(dataType.valueType)
+            value_conv = LocalDataToArrowConversion._create_converter(
+                dataType.valueType,
+                dataType.valueContainsNull,
+            )
 
             def convert_map(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, dict)
@@ -176,6 +210,8 @@ def convert_map(value: Any) -> Any:
 
             def convert_binary(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, (bytes, bytearray))
@@ -187,6 +223,8 @@ def convert_binary(value: Any) -> Any:
 
             def convert_timestamp(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, datetime.datetime)
@@ -198,6 +236,8 @@ def convert_timestamp(value: Any) -> Any:
 
             def convert_timestamp_ntz(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, datetime.datetime) and value.tzinfo is None
@@ -209,6 +249,8 @@ def convert_timestamp_ntz(value: Any) -> Any:
 
             def convert_decimal(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, decimal.Decimal)
@@ -220,6 +262,8 @@ def convert_decimal(value: Any) -> Any:
 
             def convert_string(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     if isinstance(value, bool):
@@ -238,12 +282,22 @@ def convert_string(value: Any) -> Any:
 
             def convert_udt(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     return conv(udt.serialize(value))
 
             return convert_udt
+        elif not nullable:
+
+            def convert_other(value: Any) -> Any:
+                if value is None:
+                    raise PySparkValueError(f"input for {dataType} must not be None")
+                return value
+
+            return convert_other
 
         else:
             return lambda value: value
 
@@ -256,7 +310,11 @@ def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
         column_names = schema.fieldNames()
 
         column_convs = [
-            LocalDataToArrowConversion._create_converter(field.dataType) for field in schema.fields
+            LocalDataToArrowConversion._create_converter(
+                field.dataType,
+                field.nullable,
+            )
+            for field in schema.fields
         ]
 
         pylist: List[List] = [[] for _ in range(len(column_names))]
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index fbc1debe75116..08b0a0be2dcf8 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1125,6 +1125,18 @@ def test_create_df_from_objects(self):
         self.assertEqual(cdf.schema, sdf.schema)
         self.assertEqual(cdf.collect(), sdf.collect())
 
+    def test_create_df_nullability(self):
+        data = [("asd", None)]
+        schema = StructType(
+            [
+                StructField("name", StringType(), nullable=True),
+                StructField("age", IntegerType(), nullable=False),
+            ]
+        )
+
+        with self.assertRaises(PySparkValueError):
+            self.spark.createDataFrame(data, schema)
+
     def test_simple_explain_string(self):
         df = self.connect.read.table(self.tbl_name).limit(10)
         result = df._explain_string()