Commit
init
empty
zhengruifeng committed Dec 7, 2023
1 parent 2ddb6be commit 3ea9be1
Showing 14 changed files with 133 additions and 66 deletions.
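
For context: this commit reworks the Spark Connect WithColumnsRenamed relation so that the order of renames survives the wire protocol (SPARK-46261, referenced in the test change below). A minimal sketch of why ordering is observable, assuming a running Spark Connect server at a hypothetical endpoint:

# Hypothetical repro; "sc://localhost:15002" is an assumed endpoint.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
df = spark.range(10)

# Renames apply in order: id -> a, then a -> b, so the final column is "b".
print(df.withColumnsRenamed({"id": "a", "a": "b"}).columns)  # ['b']

# Reversed order: a -> b is a no-op (no column "a" yet), then id -> a.
print(df.withColumnsRenamed({"a": "b", "id": "a"}).columns)  # ['a']

With an unordered map field in the protocol, both calls could reach the server with the same arbitrary ordering, breaking the second case.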
@@ -2276,7 +2276,13 @@ class Dataset[T] private[sql] (
    sparkSession.newDataFrame { builder =>
      builder.getWithColumnsRenamedBuilder
        .setInput(plan.getRoot)
-        .putAllRenameColumnsMap(colsMap)
+        .addAllRenames(colsMap.asScala.toSeq.map { case (colName, newColName) =>
+          proto.WithColumnsRenamed.Rename
+            .newBuilder()
+            .setColName(colName)
+            .setNewColName(newColName)
+            .build()
+        }.asJava)
    }
  }
@@ -776,13 +776,23 @@ message WithColumnsRenamed {
  Relation input = 1;

-  // (Required)
+  // (Optional)
  //
  // Renaming column names of input relation from A to B where A is the map key
  // and B is the map value. This is a no-op if the schema doesn't contain any A.
  // It does not require all input relation column names to be present as keys.
  // Duplicated B are not allowed.
-  map<string, string> rename_columns_map = 2;
+  map<string, string> rename_columns_map = 2 [deprecated=true];
+
+  repeated Rename renames = 3;
+
+  message Rename {
+    // (Required) The existing column name.
+    string col_name = 1;
+
+    // (Required) The new column name.
+    string new_col_name = 2;
+  }
}

// Adding columns or replacing the existing columns that have the same names.
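
As a sanity check on the new shape, here is a minimal sketch of building the message with the generated Python bindings touched by this commit (pyspark.sql.connect.proto); unlike a proto map, a repeated field preserves append order:

from pyspark.sql.connect.proto import relations_pb2

msg = relations_pb2.WithColumnsRenamed()
for old, new in [("id", "a"), ("a", "b")]:
    rename = msg.renames.add()  # appends a WithColumnsRenamed.Rename
    rename.col_name = old
    rename.new_col_name = new

print([(r.col_name, r.new_col_name) for r in msg.renames])
# [('id', 'a'), ('a', 'b')] -- insertion order is preserved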
@@ -11,9 +11,12 @@
        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
      }
    },
-    "renameColumnsMap": {
-      "b": "bravo",
-      "id": "nid"
-    }
+    "renames": [{
+      "colName": "b",
+      "newColName": "bravo"
+    }, {
+      "colName": "id",
+      "newColName": "nid"
+    }]
  }
}
@@ -11,9 +11,12 @@
        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
      }
    },
-    "renameColumnsMap": {
-      "a": "alpha",
-      "b": "beta"
-    }
+    "renames": [{
+      "colName": "a",
+      "newColName": "alpha"
+    }, {
+      "colName": "b",
+      "newColName": "beta"
+    }]
  }
}
@@ -11,8 +11,9 @@
        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
      }
    },
-    "renameColumnsMap": {
-      "id": "nid"
-    }
+    "renames": [{
+      "colName": "id",
+      "newColName": "nid"
+    }]
  }
}
@@ -981,10 +981,21 @@ class SparkConnectPlanner(
  }

  private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
-    Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .withColumnsRenamed(rel.getRenameColumnsMapMap)
-      .logicalPlan
+    if (rel.getRenamesCount > 0) {
+      val (colNames, newColNames) = rel.getRenamesList.asScala.toSeq.map { rename =>
+        (rename.getColName, rename.getNewColName)
+      }.unzip
+      Dataset
+        .ofRows(session, transformRelation(rel.getInput))
+        .withColumnsRenamed(colNames, newColNames)
+        .logicalPlan
+    } else {
+      // for backward compatibility
+      Dataset
+        .ofRows(session, transformRelation(rel.getInput))
+        .withColumnsRenamed(rel.getRenameColumnsMapMap)
+        .logicalPlan
+    }
  }

  private def transformWithColumns(rel: proto.WithColumns): LogicalPlan = {
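
Why the fallback branch matters, as a short sketch: an older client that still populates the deprecated map field is served unchanged (hypothetical construction, for illustration only):

from pyspark.sql.connect.proto import relations_pb2

legacy = relations_pb2.WithColumnsRenamed()
legacy.rename_columns_map["id"] = "nid"  # deprecated field, still honored

# No renames entries are set, so getRenamesCount is 0 on the server and the
# planner above takes the backward-compatibility branch.
assert len(legacy.renames) == 0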
8 changes: 6 additions & 2 deletions python/pyspark/sql/connect/plan.py
@@ -1265,8 +1265,12 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        plan = self._create_proto_relation()
        plan.with_columns_renamed.input.CopyFrom(self._child.plan(session))
-        for k, v in self._colsMap.items():
-            plan.with_columns_renamed.rename_columns_map[k] = v
+        if len(self._colsMap) > 0:
+            for k, v in self._colsMap.items():
+                rename = proto.WithColumnsRenamed.Rename()
+                rename.col_name = k
+                rename.new_col_name = v
+                plan.with_columns_renamed.renames.append(rename)
        return plan
80 changes: 42 additions & 38 deletions python/pyspark/sql/connect/proto/relations_pb2.py


34 changes: 32 additions & 2 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -2788,35 +2788,65 @@ class WithColumnsRenamed(google.protobuf.message.Message):
            self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
        ) -> None: ...

+    class Rename(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        COL_NAME_FIELD_NUMBER: builtins.int
+        NEW_COL_NAME_FIELD_NUMBER: builtins.int
+        col_name: builtins.str
+        """(Required) The existing column name."""
+        new_col_name: builtins.str
+        """(Required) The new column name."""
+        def __init__(
+            self,
+            *,
+            col_name: builtins.str = ...,
+            new_col_name: builtins.str = ...,
+        ) -> None: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "col_name", b"col_name", "new_col_name", b"new_col_name"
+            ],
+        ) -> None: ...
+
    INPUT_FIELD_NUMBER: builtins.int
    RENAME_COLUMNS_MAP_FIELD_NUMBER: builtins.int
+    RENAMES_FIELD_NUMBER: builtins.int
    @property
    def input(self) -> global___Relation:
        """(Required) The input relation."""
    @property
    def rename_columns_map(
        self,
    ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]:
-        """(Required)
+        """(Optional)
        Renaming column names of input relation from A to B where A is the map key
        and B is the map value. This is a no-op if the schema doesn't contain any A.
        It does not require all input relation column names to be present as keys.
        Duplicated B are not allowed.
        """
+    @property
+    def renames(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        global___WithColumnsRenamed.Rename
+    ]: ...
    def __init__(
        self,
        *,
        input: global___Relation | None = ...,
        rename_columns_map: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
+        renames: collections.abc.Iterable[global___WithColumnsRenamed.Rename] | None = ...,
    ) -> None: ...
    def HasField(
        self, field_name: typing_extensions.Literal["input", b"input"]
    ) -> builtins.bool: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
-            "input", b"input", "rename_columns_map", b"rename_columns_map"
+            "input", b"input", "rename_columns_map", b"rename_columns_map", "renames", b"renames"
        ],
    ) -> None: ...
5 changes: 0 additions & 5 deletions python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -77,11 +77,6 @@ def test_to_pandas_from_mixed_dataframe(self):
    def test_toDF_with_string(self):
        super().test_toDF_with_string()

-    # TODO(SPARK-46261): Python Client withColumnsRenamed should respect the dict ordering
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_ordering_of_with_columns_renamed(self):
-        super().test_ordering_of_with_columns_renamed()
-

if __name__ == "__main__":
    import unittest
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2926,7 +2926,7 @@ class Dataset[T] private[sql](
    withColumnsRenamed(colNames, newColNames)
  }

-  private def withColumnsRenamed(
+  private[spark] def withColumnsRenamed(
      colNames: Seq[String],
      newColNames: Seq[String]): DataFrame = withOrigin {
    require(colNames.size == newColNames.size,
