75 changes: 30 additions & 45 deletions R/pkg/R/functions.R
@@ -3965,19 +3965,11 @@ setMethod("row_number",
#' yields unresolved \code{a.b.c}
#' @return Column object wrapping JVM UnresolvedNamedLambdaVariable
#' @keywords internal
unresolved_named_lambda_var <- function(...) {
jc <- newJObject(
"org.apache.spark.sql.Column",
newJObject(
"org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable",
lapply(list(...), function(x) {
handledCallJStatic(
"org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable",
"freshVarName",
x)
})
)
)
unresolved_named_lambda_var <- function(name) {
jc <- handledCallJStatic(
"org.apache.spark.sql.api.python.PythonSQLUtils",
"unresolvedNamedLambdaVariable",
name)
column(jc)
}

@@ -3990,7 +3982,6 @@ unresolved_named_lambda_var <- function(...) {
#' @return JVM \code{LambdaFunction} object
#' @keywords internal
create_lambda <- function(fun) {
as_jexpr <- function(x) callJMethod(x@jc, "expr")

# Process function arguments
parameters <- formals(fun)
@@ -4011,22 +4002,18 @@ create_lambda <- function(fun) {
stopifnot(class(result) == "Column")

# Convert both Columns to Scala expressions
jexpr <- as_jexpr(result)

jargs <- handledCallJStatic(
"org.apache.spark.api.python.PythonUtils",
"toSeq",
handledCallJStatic(
"java.util.Arrays", "asList", lapply(args, as_jexpr)
)
handledCallJStatic("java.util.Arrays", "asList", lapply(args, function(x) { x@jc }))
)

# Create Scala LambdaFunction
newJObject(
"org.apache.spark.sql.catalyst.expressions.LambdaFunction",
jexpr,
jargs,
FALSE
handledCallJStatic(
"org.apache.spark.sql.api.python.PythonSQLUtils",
"lambdaFunction",
result@jc,
jargs
)
}

@@ -4039,20 +4026,18 @@ create_lambda <- function(fun) {
#' @return a \code{Column} representing name applied to cols with funs
#' @keywords internal
invoke_higher_order_function <- function(name, cols, funs) {
as_jexpr <- function(x) {
as_col <- function(x) {
if (class(x) == "character") {
x <- column(x)
}
callJMethod(x@jc, "expr")
x@jc
}

jexpr <- do.call(newJObject, c(
paste("org.apache.spark.sql.catalyst.expressions", name, sep = "."),
lapply(cols, as_jexpr),
lapply(funs, create_lambda)
))

column(newJObject("org.apache.spark.sql.Column", jexpr))
jcol <- handledCallJStatic(
"org.apache.spark.sql.api.python.PythonSQLUtils",
"fn",
name,
c(lapply(cols, as_col), lapply(funs, create_lambda))) # check varargs invocation
column(jcol)
}

#' @details
@@ -4068,7 +4053,7 @@ setMethod("array_aggregate",
signature(x = "characterOrColumn", initialValue = "Column", merge = "function"),
function(x, initialValue, merge, finish = NULL) {
invoke_higher_order_function(
"ArrayAggregate",
"aggregate",
cols = list(x, initialValue),
funs = if (is.null(finish)) {
list(merge)
@@ -4129,7 +4114,7 @@ setMethod("array_exists",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"ArrayExists",
"exists",
cols = list(x),
funs = list(f)
)
@@ -4145,7 +4130,7 @@ setMethod("array_filter",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"ArrayFilter",
"filter",
cols = list(x),
funs = list(f)
)
@@ -4161,7 +4146,7 @@ setMethod("array_forall",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"ArrayForAll",
"forall",
cols = list(x),
funs = list(f)
)
@@ -4291,7 +4276,7 @@ setMethod("array_sort",
column(callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc))
} else {
invoke_higher_order_function(
"ArraySort",
"array_sort",
cols = list(x),
funs = list(comparator)
)
@@ -4309,7 +4294,7 @@ setMethod("array_transform",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"ArrayTransform",
"transform",
cols = list(x),
funs = list(f)
)
@@ -4374,7 +4359,7 @@ setMethod("arrays_zip_with",
signature(x = "characterOrColumn", y = "characterOrColumn", f = "function"),
function(x, y, f) {
invoke_higher_order_function(
"ZipWith",
"zip_with",
cols = list(x, y),
funs = list(f)
)
@@ -4447,7 +4432,7 @@ setMethod("map_filter",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"MapFilter",
"map_filter",
cols = list(x),
funs = list(f))
})
@@ -4504,7 +4489,7 @@ setMethod("transform_keys",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"TransformKeys",
"transform_keys",
cols = list(x),
funs = list(f)
)
@@ -4521,7 +4506,7 @@ setMethod("transform_values",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"TransformValues",
"transform_values",
cols = list(x),
funs = list(f)
)
@@ -4552,7 +4537,7 @@ setMethod("map_zip_with",
signature(x = "characterOrColumn", y = "characterOrColumn", f = "function"),
function(x, y, f) {
invoke_higher_order_function(
"MapZipWith",
"map_zip_with",
cols = list(x, y),
funs = list(f)
)
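For context on the R changes above: the lowercase names now passed to invoke_higher_order_function ("aggregate", "exists", "filter", "forall", "array_sort", "transform", "zip_with", "map_filter", "transform_keys", "transform_values", "map_zip_with") are registered SQL function names rather than Catalyst expression class names. A rough Scala sketch of a few corresponding calls through org.apache.spark.sql.functions, with an illustrative array column xs (not part of this change):

import org.apache.spark.sql.functions._

// Rough DSL equivalents of some of the renamed higher-order functions.
val doubled = transform(col("xs"), x => x * 2)                   // "transform" (was ArrayTransform)
val evens   = filter(col("xs"), x => x % 2 === 0)                // "filter"    (was ArrayFilter)
val hasNeg  = exists(col("xs"), x => x < 0)                      // "exists"    (was ArrayExists)
val total   = aggregate(col("xs"), lit(0), (acc, x) => acc + x)  // "aggregate" (was ArrayAggregate)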
@@ -41,7 +41,7 @@ object functions {
def from_avro(
data: Column,
jsonFormatSchema: String): Column = {
new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty))
Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty))
}

/**
@@ -62,7 +62,7 @@ object functions {
data: Column,
jsonFormatSchema: String,
options: java.util.Map[String, String]): Column = {
new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap))
Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap))
}

/**
@@ -74,7 +74,7 @@
*/
@Experimental
def to_avro(data: Column): Column = {
new Column(CatalystDataToAvro(data.expr, None))
Column(CatalystDataToAvro(data.expr, None))
}

/**
@@ -87,6 +87,6 @@
*/
@Experimental
def to_avro(data: Column, jsonFormatSchema: String): Column = {
new Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema)))
Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema)))
}
}
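Usage of these Avro helpers is unchanged by the constructor swap; a minimal sketch, assuming a DataFrame df with a binary value column and an illustrative schema string:

import org.apache.spark.sql.avro.functions.{from_avro, to_avro}
import org.apache.spark.sql.functions.col

// Illustrative Avro schema; real schemas come from the application or a schema registry.
val schema = """{"type":"record","name":"User","fields":[{"name":"name","type":"string"}]}"""

val decoded = df.select(from_avro(col("value"), schema).as("user"))  // binary -> struct
val encoded = decoded.select(to_avro(col("user")).as("value"))       // struct -> binary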
@@ -300,6 +300,12 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.register"),

// Typed Column
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.TypedColumn.*"),
ProblemFilters.exclude[IncompatibleResultTypeProblem](
"org.apache.spark.sql.TypedColumn.expr"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.TypedColumn$"),

// Datasource V2 partition transforms
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"),
@@ -70,7 +70,7 @@ object functions {
messageName: String,
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]): Column = {
new Column(
Column(
ProtobufDataToCatalyst(
data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap
)
@@ -93,7 +93,7 @@
@Experimental
def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
new Column(ProtobufDataToCatalyst(data.expr, messageName, Some(fileContent)))
Column(ProtobufDataToCatalyst(data.expr, messageName, Some(fileContent)))
}

/**
@@ -112,7 +112,7 @@
@Experimental
def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte])
: Column = {
new Column(ProtobufDataToCatalyst(data.expr, messageName, Some(binaryFileDescriptorSet)))
Column(ProtobufDataToCatalyst(data.expr, messageName, Some(binaryFileDescriptorSet)))
}

/**
@@ -132,7 +132,7 @@
*/
@Experimental
def from_protobuf(data: Column, messageClassName: String): Column = {
new Column(ProtobufDataToCatalyst(data.expr, messageClassName))
Column(ProtobufDataToCatalyst(data.expr, messageClassName))
}

/**
@@ -156,7 +156,7 @@
data: Column,
messageClassName: String,
options: java.util.Map[String, String]): Column = {
new Column(ProtobufDataToCatalyst(data.expr, messageClassName, None, options.asScala.toMap))
Column(ProtobufDataToCatalyst(data.expr, messageClassName, None, options.asScala.toMap))
}

/**
@@ -194,7 +194,7 @@
@Experimental
def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte])
: Column = {
new Column(CatalystDataToProtobuf(data.expr, messageName, Some(binaryFileDescriptorSet)))
Column(CatalystDataToProtobuf(data.expr, messageName, Some(binaryFileDescriptorSet)))
}
/**
* Converts a column into binary of protobuf format. The Protobuf definition is provided
@@ -216,7 +216,7 @@
descFilePath: String,
options: java.util.Map[String, String]): Column = {
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
new Column(
Column(
CatalystDataToProtobuf(data.expr, messageName, Some(fileContent), options.asScala.toMap)
)
}
@@ -242,7 +242,7 @@
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]
): Column = {
new Column(
Column(
CatalystDataToProtobuf(
data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap
)
@@ -266,7 +266,7 @@
*/
@Experimental
def to_protobuf(data: Column, messageClassName: String): Column = {
new Column(CatalystDataToProtobuf(data.expr, messageClassName))
Column(CatalystDataToProtobuf(data.expr, messageClassName))
}

/**
@@ -288,6 +288,6 @@
@Experimental
def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String])
: Column = {
new Column(CatalystDataToProtobuf(data.expr, messageClassName, None, options.asScala.toMap))
Column(CatalystDataToProtobuf(data.expr, messageClassName, None, options.asScala.toMap))
}
}
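The same applies to the Protobuf helpers; a minimal usage sketch, where the message name Person and the descriptor-set path are placeholders and df is assumed to have a binary value column:

import org.apache.spark.sql.protobuf.functions.{from_protobuf, to_protobuf}
import org.apache.spark.sql.functions.col

// Placeholder message name and descriptor-set path.
val descPath = "/tmp/person.desc"
val decoded  = df.select(from_protobuf(col("value"), "Person", descPath).as("person"))
val encoded  = decoded.select(to_protobuf(col("person"), "Person", descPath).as("value"))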
@@ -28,7 +28,6 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -195,7 +194,7 @@ class StringIndexer @Since("1.4.0") (
} else {
// We don't count for NaN values. Because `StringIndexerAggregator` only processes strings,
// we replace NaNs with null in advance.
new Column(If(col.isNaN.expr, Literal(null), col.expr)).cast(StringType)
when(!isnan(col), col).cast(StringType)
}
}
}
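The StringIndexer rewrite relies on when() without an otherwise() clause evaluating to NULL where the predicate is false, so NaN labels still map to null before indexing, matching the old If(isNaN, null, col) expression. A minimal sketch of the pattern (the column name label is illustrative):

import org.apache.spark.sql.functions.{col, isnan, when}
import org.apache.spark.sql.types.StringType

// when(...) with no otherwise(...) yields NULL for rows where !isnan(col) is false,
// so NaN inputs become null before the cast, the same result as If(isNaN, null, col).
val cleaned = when(!isnan(col("label")), col("label")).cast(StringType)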
@@ -257,7 +257,7 @@ private[ml] class SummaryBuilderImpl(
mutableAggBufferOffset = 0,
inputAggBufferOffset = 0)

new Column(agg.toAggregateExpression())
Column(agg.toAggregateExpression())
}
}

5 changes: 4 additions & 1 deletion project/MimaExcludes.scala
@@ -107,7 +107,10 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.JobWaiter.cancel"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.cancel"),
// SPARK-48901: Add clusterBy() to DataStreamWriter.
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy")
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy"),
// SPARK-49022: Use Column API
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.TypedColumn.this"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.WindowSpec.this")
)

// Default exclude rules
6 changes: 2 additions & 4 deletions python/pyspark/pandas/internal.py
@@ -915,10 +915,8 @@ def attach_distributed_column(sdf: PySparkDataFrame, column_name: str) -> PySparkDataFrame:
if is_remote():
return sdf.select(F.monotonically_increasing_id().alias(column_name), *scols)
jvm = sdf.sparkSession._jvm
tag = jvm.org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS()
jexpr = F.monotonically_increasing_id()._jc.expr()
jexpr.setTagValue(tag, "distributed_index")
return sdf.select(PySparkColumn(jvm.Column(jexpr)).alias(column_name), *scols)
jcol = jvm.PythonSQLUtils.distributedIndex()
return sdf.select(PySparkColumn(jcol).alias(column_name), *scols)

@staticmethod
def attach_distributed_sequence_column(