Skip to content

Commit

Permalink
[SPARK-46823][CONNECT][PYTHON] LocalDataToArrowConversion should ch…
Browse files Browse the repository at this point in the history
…eck the nullability

### What changes were proposed in this pull request?
`LocalDataToArrowConversion` should check the nullability

### Why are the changes needed?
this check was missing

### Does this PR introduce _any_ user-facing change?
yes

```
        data = [("asd", None)]
        schema = StructType(
            [
                StructField("name", StringType(), nullable=True),
                StructField("age", IntegerType(), nullable=False),
            ]
        )
```

before:
```
In [3]: df = spark.createDataFrame([("asd", None)], schema)

In [4]: df
Out[4]: 24/01/24 12:08:28 ERROR ErrorUtils: Spark Connect RPC error during: analyze. UserId: ruifeng.zheng. SessionId: cd692bb1-d503-4043-a9db-d29cb5c16517.
java.lang.IllegalStateException: Value at index is null
        at org.apache.arrow.vector.IntVector.get(IntVector.java:107)
        at org.apache.spark.sql.vectorized.ArrowColumnVector$IntAccessor.getInt(ArrowColumnVector.java:338)
        at org.apache.spark.sql.vectorized.ArrowColumnVector.getInt(ArrowColumnVector.java:88)
        at org.apache.spark.sql.vectorized.ColumnarBatchRow.getInt(ColumnarBatchRow.java:109)
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
        at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
        at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
        at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
        at scala.collection.immutable.List.prependedAll(List.scala:153)
        at scala.collection.immutable.List$.from(List.scala:684)
        at scala.collection.immutable.List$.from(List.scala:681)
        at scala.collection.SeqFactory$Delegate.from(Factory.scala:306)
        at scala.collection.immutable.Seq$.from(Seq.scala:42)
        at scala.collection.IterableOnceOps.toSeq(IterableOnce.scala:1326)
        at scala.collection.IterableOnceOps.toSeq$(IterableOnce.scala:1326)
        at scala.collection.AbstractIterator.toSeq(Iterator.scala:1300)
        at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformLocalRelation(SparkConnectPlanner.scala:1239)
        at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:139)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.process(SparkConnectAnalyzeHandler.scala:59)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1(SparkConnectAnalyzeHandler.scala:43)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1$adapted(SparkConnectAnalyzeHandler.scala:42)
        at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:289)
        at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:918)
        at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:289)
        at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
        at org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:80)
        at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:182)
        at org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:79)
        at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:288)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.handle(SparkConnectAnalyzeHandler.scala:42)
        at org.apache.spark.sql.connect.service.SparkConnectService.analyzePlan(SparkConnectService.scala:95)
        at org.apache.spark.connect.proto.SparkConnectServiceGrpc$MethodHandlers.invoke(SparkConnectServiceGrpc.java:907)
        at org.sparkproject.connect.grpc.io.grpc.stub.ServerCalls$UnaryServerCallHandler$UnaryServerCallListener.onHalfClose(ServerCalls.java:182)
        at org.sparkproject.connect.grpc.io.grpc.internal.ServerCallImpl$ServerStreamListenerImpl.halfClosed(ServerCallImpl.java:351)
        at org.sparkproject.connect.grpc.io.grpc.internal.ServerImpl$JumpToApplicationThreadServerStreamListener$1HalfClosed.runInContext(ServerImpl.java:860)
        at org.sparkproject.connect.grpc.io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
        at org.sparkproject.connect.grpc.io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
        at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
        at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
        at java.base/java.lang.Thread.run(Thread.java:833)
24/01/24 12:08:28 ERROR ErrorUtils: Spark Connect RPC error during: analyze. UserId: ruifeng.zheng. SessionId: cd692bb1-d503-4043-a9db-d29cb5c16517.
java.lang.IllegalStateException: Value at index is null
        at org.apache.arrow.vector.IntVector.get(IntVector.java:107)
        at org.apache.spark.sql.vectorized.ArrowColumnVector$IntAccessor.getInt(ArrowColumnVector.java:338)
        at org.apache.spark.sql.vectorized.ArrowColumnVector.getInt(ArrowColumnVector.java:88)
        at org.apache.spark.sql.vectorized.ColumnarBatchRow.getInt(ColumnarBatchRow.java:109)
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
        at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
        at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
        at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
        at scala.collection.immutable.List.prependedAll(List.scala:153)
        at scala.collection.immutable.List$.from(List.scala:684)
        at scala.collection.immutable.List$.from(List.scala:681)
        at scala.collection.SeqFactory$Delegate.from(Factory.scala:306)
        at scala.collection.immutable.Seq$.from(Seq.scala:42)
        at scala.collection.IterableOnceOps.toSeq(IterableOnce.scala:1326)
        at scala.collection.IterableOnceOps.toSeq$(IterableOnce.scala:1326)
        at scala.collection.AbstractIterator.toSeq(Iterator.scala:1300)
        at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformLocalRelation(SparkConnectPlanner.scala:1239)
        at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:139)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.process(SparkConnectAnalyzeHandler.scala:59)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1(SparkConnectAnalyzeHandler.scala:43)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1$adapted(SparkConnectAnalyzeHandler.scala:42)
        at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:289)
        at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:918)
        at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:289)
        at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
        at org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:80)
        at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:182)
        at org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:79)
        at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:288)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.handle(SparkConnectAnalyzeHandler.scala:42)
        at org.apache.spark.sql.connect.service.SparkConnectService.analyzePlan(SparkConnectService.scala:95)
        at org.apache.spark.connect.proto.SparkConnectServiceGrpc$MethodHandlers.invoke(SparkConnectServiceGrpc.java:907)
        at org.sparkproject.connect.grpc.io.grpc.stub.ServerCalls$UnaryServerCallHandler$UnaryServerCallListener.onHalfClose(ServerCalls.java:182)
        at org.sparkproject.connect.grpc.io.grpc.internal.ServerCallImpl$ServerStreamListenerImpl.halfClosed(ServerCallImpl.java:351)
        at org.sparkproject.connect.grpc.io.grpc.internal.ServerImpl$JumpToApplicationThreadServerStreamListener$1HalfClosed.runInContext(ServerImpl.java:860)
        at org.sparkproject.connect.grpc.io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
        at org.sparkproject.connect.grpc.io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
        at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
        at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
        at java.base/java.lang.Thread.run(Thread.java:833)
24/01/24 12:08:28 ERROR ErrorUtils: Spark Connect RPC error during: analyze. UserId: ruifeng.zheng. SessionId: cd692bb1-d503-4043-a9db-d29cb5c16517.
java.lang.IllegalStateException: Value at index is null
        at org.apache.arrow.vector.IntVector.get(IntVector.java:107)
        at org.apache.spark.sql.vectorized.ArrowColumnVector$IntAccessor.getInt(ArrowColumnVector.java:338)
        at org.apache.spark.sql.vectorized.ArrowColumnVector.getInt(ArrowColumnVector.java:88)
        at org.apache.spark.sql.vectorized.ColumnarBatchRow.getInt(ColumnarBatchRow.java:109)
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
        at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
        at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
        at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
        at scala.collection.immutable.List.prependedAll(List.scala:153)
        at scala.collection.immutable.List$.from(List.scala:684)
        at scala.collection.immutable.List$.from(List.scala:681)
        at scala.collection.SeqFactory$Delegate.from(Factory.scala:306)
        at scala.collection.immutable.Seq$.from(Seq.scala:42)
        at scala.collection.IterableOnceOps.toSeq(IterableOnce.scala:1326)
        at scala.collection.IterableOnceOps.toSeq$(IterableOnce.scala:1326)
        at scala.collection.AbstractIterator.toSeq(Iterator.scala:1300)
        at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformLocalRelation(SparkConnectPlanner.scala:1239)
        at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:139)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.process(SparkConnectAnalyzeHandler.scala:59)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1(SparkConnectAnalyzeHandler.scala:43)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1$adapted(SparkConnectAnalyzeHandler.scala:42)
        at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:289)
        at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:918)
        at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:289)
        at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
        at org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:80)
        at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:182)
        at org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:79)
        at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:288)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.handle(SparkConnectAnalyzeHandler.scala:42)
        at org.apache.spark.sql.connect.service.SparkConnectService.analyzePlan(SparkConnectService.scala:95)
        at org.apache.spark.connect.proto.SparkConnectServiceGrpc$MethodHandlers.invoke(SparkConnectServiceGrpc.java:907)
        at org.sparkproject.connect.grpc.io.grpc.stub.ServerCalls$UnaryServerCallHandler$UnaryServerCallListener.onHalfClose(ServerCalls.java:182)
        at org.sparkproject.connect.grpc.io.grpc.internal.ServerCallImpl$ServerStreamListenerImpl.halfClosed(ServerCallImpl.java:351)
        at org.sparkproject.connect.grpc.io.grpc.internal.ServerImpl$JumpToApplicationThreadServerStreamListener$1HalfClosed.runInContext(ServerImpl.java:860)
        at org.sparkproject.connect.grpc.io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
        at org.sparkproject.connect.grpc.io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
        at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
        at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
        at java.base/java.lang.Thread.run(Thread.java:833)
---------------------------------------------------------------------------
SparkConnectGrpcException                 Traceback (most recent call last)
File ~/.dev/miniconda3/envs/spark_dev_311/lib/python3.11/site-packages/IPython/core/formatters.py:708, in PlainTextFormatter.__call__(self, obj)
    701 stream = StringIO()
    702 printer = pretty.RepresentationPrinter(stream, self.verbose,
    703     self.max_width, self.newline,
    704     max_seq_length=self.max_seq_length,
    705     singleton_pprinters=self.singleton_printers,
    706     type_pprinters=self.type_printers,
    707     deferred_pprinters=self.deferred_printers)
--> 708 printer.pretty(obj)
    709 printer.flush()
    710 return stream.getvalue()

File ~/.dev/miniconda3/envs/spark_dev_311/lib/python3.11/site-packages/IPython/lib/pretty.py:410, in RepresentationPrinter.pretty(self, obj)
    407                         return meth(obj, self, cycle)
    408                 if cls is not object \
    409                         and callable(cls.__dict__.get('__repr__')):
--> 410                     return _repr_pprint(obj, self, cycle)
    412     return _default_pprint(obj, self, cycle)
    413 finally:

File ~/.dev/miniconda3/envs/spark_dev_311/lib/python3.11/site-packages/IPython/lib/pretty.py:778, in _repr_pprint(obj, p, cycle)
    776 """A pprint that just redirects to the normal repr function."""
    777 # Find newlines and replace them with p.break_()
--> 778 output = repr(obj)
    779 lines = output.splitlines()
    780 with p.group():

File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:141, in DataFrame.__repr__(self)
    135     if repl_eager_eval_enabled == "true":
    136         return self._show_string(
    137             n=int(cast(str, repl_eager_eval_max_num_rows)),
    138             truncate=int(cast(str, repl_eager_eval_truncate)),
    139             vertical=False,
    140         )
--> 141 return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))

File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:238, in DataFrame.dtypes(self)
    236 property
    237 def dtypes(self) -> List[Tuple[str, str]]:
--> 238     return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]

File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:1786, in DataFrame.schema(self)
   1783 property
   1784 def schema(self) -> StructType:
   1785     query = self._plan.to_proto(self._session.client)
-> 1786     return self._session.client.schema(query)

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:921, in SparkConnectClient.schema(self, plan)
    917 """
    918 Return schema for given plan.
    919 """
    920 logger.info(f"Schema for plan: {self._proto_to_string(plan)}")
--> 921 schema = self._analyze(method="schema", plan=plan).schema
    922 assert schema is not None
    923 # Server side should populate the struct field which is the schema.

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1107, in SparkConnectClient._analyze(self, method, **kwargs)
   1105     raise SparkConnectException("Invalid state during retry exception handling.")
   1106 except Exception as error:
-> 1107     self._handle_error(error)

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1525, in SparkConnectClient._handle_error(self, error)
   1523 self.thread_local.inside_error_handling = True
   1524 if isinstance(error, grpc.RpcError):
-> 1525     self._handle_rpc_error(error)
   1526 elif isinstance(error, ValueError):
   1527     if "Cannot invoke RPC" in str(error) and "closed" in str(error):

File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1595, in SparkConnectClient._handle_rpc_error(self, rpc_error)
   1592             info = error_details_pb2.ErrorInfo()
   1593             d.Unpack(info)
-> 1595             raise convert_exception(
   1596                 info,
   1597                 status.message,
   1598                 self._fetch_enriched_error(info),
   1599                 self._display_server_stack_trace(),
   1600             ) from None
   1602     raise SparkConnectGrpcException(status.message) from None
   1603 else:

SparkConnectGrpcException: (java.lang.IllegalStateException) Value at index is null

JVM stacktrace:
java.lang.IllegalStateException
        at org.apache.arrow.vector.IntVector.get(IntVector.java:107)
        at org.apache.spark.sql.vectorized.ArrowColumnVector$IntAccessor.getInt(ArrowColumnVector.java:338)
        at org.apache.spark.sql.vectorized.ArrowColumnVector.getInt(ArrowColumnVector.java:88)
        at org.apache.spark.sql.vectorized.ColumnarBatchRow.getInt(ColumnarBatchRow.java:109)
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(:-1)
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(:-1)
        at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
        at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
        at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
        at scala.collection.immutable.List.prependedAll(List.scala:153)
        at scala.collection.immutable.List$.from(List.scala:684)
        at scala.collection.immutable.List$.from(List.scala:681)
        at scala.collection.SeqFactory$Delegate.from(Factory.scala:306)
        at scala.collection.immutable.Seq$.from(Seq.scala:42)
        at scala.collection.IterableOnceOps.toSeq(IterableOnce.scala:1326)
        at scala.collection.IterableOnceOps.toSeq$(IterableOnce.scala:1326)
        at scala.collection.AbstractIterator.toSeq(Iterator.scala:1300)
        at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformLocalRelation(SparkConnectPlanner.scala:1239)
        at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:139)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.process(SparkConnectAnalyzeHandler.scala:59)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1(SparkConnectAnalyzeHandler.scala:43)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1$adapted(SparkConnectAnalyzeHandler.scala:42)
        at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:289)
        at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:918)
        at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:289)
        at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
        at org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:80)
        at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:182)
        at org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:79)
        at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:288)
        at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.handle(SparkConnectAnalyzeHandler.scala:42)
        at org.apache.spark.sql.connect.service.SparkConnectService.analyzePlan(SparkConnectService.scala:95)
        at org.apache.spark.connect.proto.SparkConnectServiceGrpc$MethodHandlers.invoke(SparkConnectServiceGrpc.java:907)
        at org.sparkproject.connect.grpc.io.grpc.stub.ServerCalls$UnaryServerCallHandler$UnaryServerCallListener.onHalfClose(ServerCalls.java:182)
        at org.sparkproject.connect.grpc.io.grpc.internal.ServerCallImpl$ServerStreamListenerImpl.halfClosed(ServerCallImpl.java:351)
        at org.sparkproject.connect.grpc.io.grpc.internal.ServerImpl$JumpToApplicationThreadServerStreamListener$1HalfClosed.runInContext(ServerImpl.java:860)
        at org.sparkproject.connect.grpc.io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
        at org.sparkproject.connect.grpc.io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
        at java.lang.Thread.run(Thread.java:833)

```

after:
```
---------------------------------------------------------------------------
PySparkValueError                         Traceback (most recent call last)
Cell In[3], line 1
----> 1 df = spark.createDataFrame([("asd", None)], schema)

File ~/Dev/spark/python/pyspark/sql/connect/session.py:538, in SparkSession.createDataFrame(self, data, schema)
    533     from pyspark.sql.connect.conversion import LocalDataToArrowConversion
    535     # Spark Connect will try its best to build the Arrow table with the
    536     # inferred schema in the client side, and then rename the columns and
    537     # cast the datatypes in the server side.
--> 538     _table = LocalDataToArrowConversion.convert(_data, _schema)
    540 # TODO: Beside the validation on number of columns, we should also check
    541 # whether the Arrow Schema is compatible with the user provided Schema.
    542 if _num_cols is not None and _num_cols != _table.shape[1]:

File ~/Dev/spark/python/pyspark/sql/connect/conversion.py:351, in LocalDataToArrowConversion.convert(data, schema)
    342             raise PySparkValueError(
    343                 error_class="AXIS_LENGTH_MISMATCH",
    344                 message_parameters={
   (...)
    347                 },
    348             )
    350         for i in range(len(column_names)):
--> 351             pylist[i].append(column_convs[i](item[i]))
    353 pa_schema = to_arrow_schema(
    354     StructType(
    355         [
   (...)
    361     )
    362 )
    364 return pa.Table.from_arrays(pylist, schema=pa_schema)

File ~/Dev/spark/python/pyspark/sql/connect/conversion.py:297, in LocalDataToArrowConversion._create_converter.<locals>.convert_other(value)
    295 def convert_other(value: Any) -> Any:
    296     if value is None:
--> 297         raise PySparkValueError(f"input for {dataType} must not be None")
    298     return value

PySparkValueError: input for IntegerType() must not be None
```

### How was this patch tested?
added ut

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #44861 from zhengruifeng/connect_check_nullable.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
  • Loading branch information
zhengruifeng authored and dongjoon-hyun committed Jan 24, 2024
1 parent 7004dd9 commit 1642e92
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 10 deletions.
78 changes: 68 additions & 10 deletions python/pyspark/sql/connect/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,23 @@ class LocalDataToArrowConversion:
"""

@staticmethod
def _need_converter(dataType: DataType) -> bool:
if isinstance(dataType, NullType):
def _need_converter(
dataType: DataType,
nullable: bool = True,
) -> bool:
if not nullable:
# always check the nullability
return True
elif isinstance(dataType, NullType):
# always check the nullability
return True
elif isinstance(dataType, StructType):
# Struct maybe rows, should convert to dict.
return True
elif isinstance(dataType, ArrayType):
return LocalDataToArrowConversion._need_converter(dataType.elementType)
return LocalDataToArrowConversion._need_converter(
dataType.elementType, dataType.containsNull
)
elif isinstance(dataType, MapType):
# Different from PySpark, here always needs conversion,
# since an Arrow Map requires a list of tuples.
Expand All @@ -90,26 +99,41 @@ def _need_converter(dataType: DataType) -> bool:
return False

@staticmethod
def _create_converter(dataType: DataType) -> Callable:
def _create_converter(
dataType: DataType,
nullable: bool = True,
) -> Callable:
assert dataType is not None and isinstance(dataType, DataType)
assert isinstance(nullable, bool)

if not LocalDataToArrowConversion._need_converter(dataType):
if not LocalDataToArrowConversion._need_converter(dataType, nullable):
return lambda value: value

if isinstance(dataType, NullType):
return lambda value: None

def convert_null(value: Any) -> Any:
if value is not None:
raise PySparkValueError(f"input for {dataType} must be None, but got {value}")
return None

return convert_null

elif isinstance(dataType, StructType):
field_names = dataType.fieldNames()
dedup_field_names = _dedup_names(dataType.names)

field_convs = [
LocalDataToArrowConversion._create_converter(field.dataType)
LocalDataToArrowConversion._create_converter(
field.dataType,
field.nullable,
)
for field in dataType.fields
]

def convert_struct(value: Any) -> Any:
if value is None:
if not nullable:
raise PySparkValueError(f"input for {dataType} must not be None")
return None
else:
assert isinstance(value, (tuple, dict)) or hasattr(
Expand Down Expand Up @@ -143,10 +167,15 @@ def convert_struct(value: Any) -> Any:
return convert_struct

elif isinstance(dataType, ArrayType):
element_conv = LocalDataToArrowConversion._create_converter(dataType.elementType)
element_conv = LocalDataToArrowConversion._create_converter(
dataType.elementType,
dataType.containsNull,
)

def convert_array(value: Any) -> Any:
if value is None:
if not nullable:
raise PySparkValueError(f"input for {dataType} must not be None")
return None
else:
assert isinstance(value, (list, array.array))
Expand All @@ -156,10 +185,15 @@ def convert_array(value: Any) -> Any:

elif isinstance(dataType, MapType):
key_conv = LocalDataToArrowConversion._create_converter(dataType.keyType)
value_conv = LocalDataToArrowConversion._create_converter(dataType.valueType)
value_conv = LocalDataToArrowConversion._create_converter(
dataType.valueType,
dataType.valueContainsNull,
)

def convert_map(value: Any) -> Any:
if value is None:
if not nullable:
raise PySparkValueError(f"input for {dataType} must not be None")
return None
else:
assert isinstance(value, dict)
Expand All @@ -176,6 +210,8 @@ def convert_map(value: Any) -> Any:

def convert_binary(value: Any) -> Any:
if value is None:
if not nullable:
raise PySparkValueError(f"input for {dataType} must not be None")
return None
else:
assert isinstance(value, (bytes, bytearray))
Expand All @@ -187,6 +223,8 @@ def convert_binary(value: Any) -> Any:

def convert_timestamp(value: Any) -> Any:
if value is None:
if not nullable:
raise PySparkValueError(f"input for {dataType} must not be None")
return None
else:
assert isinstance(value, datetime.datetime)
Expand All @@ -198,6 +236,8 @@ def convert_timestamp(value: Any) -> Any:

def convert_timestamp_ntz(value: Any) -> Any:
if value is None:
if not nullable:
raise PySparkValueError(f"input for {dataType} must not be None")
return None
else:
assert isinstance(value, datetime.datetime) and value.tzinfo is None
Expand All @@ -209,6 +249,8 @@ def convert_timestamp_ntz(value: Any) -> Any:

def convert_decimal(value: Any) -> Any:
if value is None:
if not nullable:
raise PySparkValueError(f"input for {dataType} must not be None")
return None
else:
assert isinstance(value, decimal.Decimal)
Expand All @@ -220,6 +262,8 @@ def convert_decimal(value: Any) -> Any:

def convert_string(value: Any) -> Any:
if value is None:
if not nullable:
raise PySparkValueError(f"input for {dataType} must not be None")
return None
else:
if isinstance(value, bool):
Expand All @@ -238,12 +282,22 @@ def convert_string(value: Any) -> Any:

def convert_udt(value: Any) -> Any:
if value is None:
if not nullable:
raise PySparkValueError(f"input for {dataType} must not be None")
return None
else:
return conv(udt.serialize(value))

return convert_udt

elif not nullable:

def convert_other(value: Any) -> Any:
if value is None:
raise PySparkValueError(f"input for {dataType} must not be None")
return value

return convert_other
else:
return lambda value: value

Expand All @@ -256,7 +310,11 @@ def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
column_names = schema.fieldNames()

column_convs = [
LocalDataToArrowConversion._create_converter(field.dataType) for field in schema.fields
LocalDataToArrowConversion._create_converter(
field.dataType,
field.nullable,
)
for field in schema.fields
]

pylist: List[List] = [[] for _ in range(len(column_names))]
Expand Down
12 changes: 12 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,18 @@ def test_create_df_from_objects(self):
self.assertEqual(cdf.schema, sdf.schema)
self.assertEqual(cdf.collect(), sdf.collect())

def test_create_df_nullability(self):
data = [("asd", None)]
schema = StructType(
[
StructField("name", StringType(), nullable=True),
StructField("age", IntegerType(), nullable=False),
]
)

with self.assertRaises(PySparkValueError):
self.spark.createDataFrame(data, schema)

def test_simple_explain_string(self):
df = self.connect.read.table(self.tbl_name).limit(10)
result = df._explain_string()
Expand Down

0 comments on commit 1642e92

Please sign in to comment.