diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 0b5512b61925c..85f1b3565c696 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -809,7 +809,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"WithField({self._structExpr}, {self._fieldName}, {self._valueExpr})" + return f"update_field({self._structExpr}, {self._fieldName}, {self._valueExpr})" class DropField(Expression): @@ -833,7 +833,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"DropField({self._structExpr}, {self._fieldName})" + return f"drop_field({self._structExpr}, {self._fieldName})" class UnresolvedExtractValue(Expression): @@ -857,7 +857,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"UnresolvedExtractValue({str(self._child)}, {str(self._extraction)})" + return f"{self._child}['{self._extraction}']" class UnresolvedRegex(Expression): diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 1972dd2804d98..5f1991973d27d 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -283,6 +283,77 @@ def test_expr_str_representation(self): when_cond = sf.when(expression, sf.lit(None)) self.assertEqual(str(when_cond), "Column<'CASE WHEN foo THEN NULL END'>") + def test_col_field_ops_representation(self): + # SPARK-49894: Test string representation of columns + c = sf.col("c") + + # getField + self.assertEqual(str(c.x), "Column<'c['x']'>") + self.assertEqual(str(c.x.y), "Column<'c['x']['y']'>") + self.assertEqual(str(c.x.y.z), "Column<'c['x']['y']['z']'>") + + self.assertEqual(str(c["x"]), "Column<'c['x']'>") + self.assertEqual(str(c["x"]["y"]), "Column<'c['x']['y']'>") + self.assertEqual(str(c["x"]["y"]["z"]), "Column<'c['x']['y']['z']'>") + + self.assertEqual(str(c.getField("x")), "Column<'c['x']'>") + self.assertEqual( + str(c.getField("x").getField("y")), + "Column<'c['x']['y']'>", + ) + self.assertEqual( + str(c.getField("x").getField("y").getField("z")), + "Column<'c['x']['y']['z']'>", + ) + + self.assertEqual(str(c.getItem("x")), "Column<'c['x']'>") + self.assertEqual( + str(c.getItem("x").getItem("y")), + "Column<'c['x']['y']'>", + ) + self.assertEqual( + str(c.getItem("x").getItem("y").getItem("z")), + "Column<'c['x']['y']['z']'>", + ) + + self.assertEqual( + str(c.x["y"].getItem("z")), + "Column<'c['x']['y']['z']'>", + ) + self.assertEqual( + str(c["x"].getField("y").getItem("z")), + "Column<'c['x']['y']['z']'>", + ) + self.assertEqual( + str(c.getField("x").getItem("y").z), + "Column<'c['x']['y']['z']'>", + ) + self.assertEqual( + str(c["x"].y.getField("z")), + "Column<'c['x']['y']['z']'>", + ) + + # WithField + self.assertEqual( + str(c.withField("x", sf.col("y"))), + "Column<'update_field(c, x, y)'>", + ) + self.assertEqual( + str(c.withField("x", sf.col("y")).withField("x", sf.col("z"))), + "Column<'update_field(update_field(c, x, y), x, z)'>", + ) + + # DropFields + self.assertEqual(str(c.dropFields("x")), "Column<'drop_field(c, x)'>") + self.assertEqual( + str(c.dropFields("x", "y")), + "Column<'drop_field(drop_field(c, x), y)'>", + ) + self.assertEqual( + str(c.dropFields("x", "y", "z")), + "Column<'drop_field(drop_field(drop_field(c, x), y), z)'>", + ) + def test_lit_time_representation(self): dt = datetime.date(2021, 3, 4) self.assertEqual(str(sf.lit(dt)), "Column<'2021-03-04'>")