Reflect review comments
HeartSaVioR committed Jul 16, 2021
1 parent b9c0357 commit a4fa37b
Showing 10 changed files with 40 additions and 49 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/sql/functions.py
@@ -2336,7 +2336,7 @@ def check_string_field(field, fieldName):
 def session_window(timeColumn, gapDuration):
     """
     Generates session window given a timestamp specifying column.
-    Session window is the one of dynamic windows, which means the length of window is vary
+    Session window is one of dynamic windows, which means the length of window is varying
     according to the given inputs. The length of session window is defined as "the timestamp
     of latest input of the session + gap duration", so when the new inputs are bound to the
     current session window, the end time of session window can be expanded according to the new
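Note: a minimal batch usage sketch of the session_window API described in this docstring (Spark 3.2+). The DataFrame, column names, and gap value below are illustrative, not part of this commit.

# Rows of the same user whose timestamps fall within the 5-minute gap of the
# previous row in that session are merged into one session window.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
events = spark.createDataFrame(
    [("alice", "2021-07-16 10:00:00"), ("alice", "2021-07-16 10:03:00")],
    ["user", "timestamp"],
).withColumn("timestamp", F.col("timestamp").cast("timestamp"))

sessionized = events.groupBy(
    F.session_window("timestamp", "5 minutes"), "user"
).count()
sessionized.show(truncate=False)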
@@ -4001,12 +4001,12 @@ object SessionWindowing extends Rule[LogicalPlan] {
           case s: SessionWindow => sessionAttr
         }

-        // For backwards compatibility we add a filter to filter out nulls
+        // As same as tumbling window, we add a filter to filter out nulls.
         val filterExpr = IsNotNull(session.timeColumn)

         replacedPlan.withNewChildren(
-          Filter(filterExpr,
-            Project(sessionStruct +: child.output, child)) :: Nil)
+          Project(sessionStruct +: child.output,
+            Filter(filterExpr, child)) :: Nil)
       } else if (numWindowExpr > 1) {
         throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
       } else {
@@ -17,13 +17,17 @@

 package org.apache.spark.sql.catalyst.expressions

-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
-import org.apache.spark.sql.catalyst.util.{DateTimeConstants, IntervalUtils}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String

+/**
+ * Represent the session window.
+ *
+ * @param timeColumn the start time of session window
+ * @param gapDuration the duration of session gap, meaning the session will close if there is
+ *                    no new element appeared within "the last element in session + gap".
+ */
 case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends UnaryExpression
   with ImplicitCastInputTypes
   with Unevaluable
@@ -34,7 +38,7 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends Unar
   //////////////////////////

   def this(timeColumn: Expression, gapDuration: Expression) = {
-    this(timeColumn, SessionWindow.parseExpression(gapDuration))
+    this(timeColumn, TimeWindow.parseExpression(gapDuration))
   }

   override def child: Expression = timeColumn
@@ -64,40 +64,10 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends Unar
 object SessionWindow {
   val marker = "spark.sessionWindow"

-  /**
-   * Parses the interval string for a valid time duration. CalendarInterval expects interval
-   * strings to start with the string `interval`. For usability, we prepend `interval` to the string
-   * if the user omitted it.
-   *
-   * @param interval The interval string
-   * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond
-   *         precision.
-   */
-  private def getIntervalInMicroSeconds(interval: String): Long = {
-    val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
-    if (cal.months != 0) {
-      throw new IllegalArgumentException(
-        s"Intervals greater than a month is not supported ($interval).")
-    }
-    cal.days * DateTimeConstants.MICROS_PER_DAY + cal.microseconds
-  }
-
-  /**
-   * Parses the duration expression to generate the long value for the original constructor so
-   * that we can use `window` in SQL.
-   */
-  private def parseExpression(expr: Expression): Long = expr match {
-    case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString)
-    case IntegerLiteral(i) => i.toLong
-    case NonNullLiteral(l, LongType) => l.toString.toLong
-    case _ => throw new AnalysisException("The duration and time inputs to window must be " +
-      "an integer, long or string literal.")
-  }
-
   def apply(
       timeColumn: Expression,
       gapDuration: String): SessionWindow = {
     SessionWindow(timeColumn,
-      getIntervalInMicroSeconds(gapDuration))
+      TimeWindow.getIntervalInMicroSeconds(gapDuration))
   }
 }
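Note: a standalone sketch (plain Python, not Spark code) that makes the "last element in session + gap" rule from the new scaladoc concrete; the timestamps and gap value are made up for illustration.

# A session stays open while each new event arrives within `gap` of the latest
# event already in the session; the session end is "latest event + gap".
def sessionize(timestamps, gap):
    sessions = []
    for ts in sorted(timestamps):
        if sessions and ts < sessions[-1][-1] + gap:
            sessions[-1].append(ts)        # falls inside the current session window
        else:
            sessions.append([ts])          # gap exceeded: start a new session
    return [(s[0], s[-1] + gap) for s in sessions]   # [start, latest event + gap)

# Events at 10, 13, 25 with a gap of 5 yield sessions [10, 18) and [25, 30).
print(sessionize([10, 13, 25], 5))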
@@ -109,7 +109,7 @@ object TimeWindow {
    * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond
    *         precision.
    */
-  private def getIntervalInMicroSeconds(interval: String): Long = {
+  def getIntervalInMicroSeconds(interval: String): Long = {
     val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
     if (cal.months != 0) {
       throw new IllegalArgumentException(
@@ -122,7 +122,7 @@ object TimeWindow {
    * Parses the duration expression to generate the long value for the original constructor so
    * that we can use `window` in SQL.
    */
-  private def parseExpression(expr: Expression): Long = expr match {
+  def parseExpression(expr: Expression): Long = expr match {
     case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString)
     case IntegerLiteral(i) => i.toLong
     case NonNullLiteral(l, LongType) => l.toString.toLong
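Note: a rough plain-Python mirror of the day/time arithmetic in the now-shared TimeWindow.getIntervalInMicroSeconds shown above (a sketch only; Spark parses the string via IntervalUtils and rejects month-based intervals, as the code does).

MICROS_PER_SECOND = 1_000_000
MICROS_PER_DAY = 24 * 60 * 60 * MICROS_PER_SECOND

def interval_to_micros(days=0, seconds=0, months=0):
    # Mirrors: months must be zero, result = days * MICROS_PER_DAY + microseconds
    if months != 0:
        raise ValueError("Intervals greater than a month are not supported")
    return days * MICROS_PER_DAY + seconds * MICROS_PER_SECOND

# "5 minutes" corresponds to 300,000,000 microseconds; this is the gapDuration
# value the SessionWindow expression ends up carrying.
print(interval_to_micros(seconds=5 * 60))   # 300000000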
@@ -1616,6 +1616,7 @@ object SQLConf {
       .doc("When true, streaming session window sorts and merge sessions in local partition " +
         "prior to shuffle. This is to reduce the rows to shuffle, but only beneficial when " +
         "there're lots of rows in a batch being assigned to same sessions.")
+      .version("3.2.0")
       .booleanConf
       .createWithDefault(false)

@@ -324,8 +324,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           throw QueryCompilationErrors.groupAggPandasUDFUnsupportedByStreamingAggError()
         }

-        val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
-
         val sessionWindowOption = namedGroupingExpressions.find { p =>
           p.metadata.contains(SessionWindow.marker)
         }
@@ -114,6 +114,8 @@ object AggUtils {
         resultExpressions = partialResultExpressions,
         child = child)

+    // If we have session window expression in aggregation, we add MergingSessionExec to
+    // merge sessions with calculating aggregation values.
     val interExec: SparkPlan = mayAppendMergingSessionExec(groupingExpressions,
       aggregateExpressions, partialAggregate)

@@ -144,6 +146,9 @@ object AggUtils {
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {

+    // If we have session window expression in aggregation, we add UpdatingSessionsExec to
+    // calculate sessions for input rows and update rows' session column, so that further
+    // aggregations can aggregate input rows for the same session.
     val maySessionChild = mayAppendUpdatingSessionExec(groupingExpressions, child)

     val distinctAttributes = normalizedNamedDistinctExpressions.map(_.toAttribute)
@@ -394,9 +399,7 @@ object AggUtils {

     val groupingAttributes = groupingExpressions.map(_.toAttribute)

-    // we don't do partial aggregate here, because it requires additional shuffle
-    // and there will be less rows which have same session start
-    // here doing partial merge is to have aggregated columns with default value for each row
+    // Here doing partial merge is to have aggregated columns with default value for each row.
     val partialAggregate: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
       val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
@@ -26,7 +26,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
@@ -65,6 +65,14 @@ case class AggregateInPandasExec(
     case None => groupingExpressions
   }

+  override def requiredChildDistribution: Seq[Distribution] = {
+    if (groupingExpressions.isEmpty) {
+      AllTuples :: Nil
+    } else {
+      ClusteredDistribution(groupingExpressions) :: Nil
+    }
+  }
+
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = sessionWindowOption match {
     case Some(sessionExpression) =>
       Seq((groupingWithoutSessionExpressions ++ Seq(sessionExpression))
@@ -116,6 +124,10 @@ case class AggregateInPandasExec(

     // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
     inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
+      // If we have session window expression in aggregation, we wrap iterator with
+      // UpdatingSessionIterator to calculate sessions for input rows and update
+      // rows' session column, so that further aggregations can aggregate input rows
+      // for the same session.
       val newIter: Iterator[InternalRow] = mayAppendUpdatingSessionIterator(iter)
       val prunedProj = UnsafeProjection.create(allInputs.toSeq, child.output)

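Note: AggregateInPandasExec plans grouped-aggregate pandas UDFs, and this change teaches it about session windows. Below is a hedged sketch of the kind of query that would exercise it, assuming the combination is supported as the diff suggests; the data and column names are illustrative.

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf

spark = SparkSession.builder.getOrCreate()
events = spark.createDataFrame(
    [("alice", "2021-07-16 10:00:00", 1.0), ("alice", "2021-07-16 10:03:00", 3.0)],
    ["user", "timestamp", "value"],
).withColumn("timestamp", F.col("timestamp").cast("timestamp"))

@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:     # grouped-aggregate (Series-to-scalar) pandas UDF
    return v.mean()

result = events.groupBy(
    F.session_window("timestamp", "5 minutes"), "user"
).agg(mean_udf(F.col("value")).alias("avg_value"))
result.show(truncate=False)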
5 changes: 3 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3633,13 +3633,14 @@ object functions {
   /**
    * Generates session window given a timestamp specifying column.
    *
-   * Session window is the one of dynamic windows, which means the length of window is vary
+   * Session window is one of dynamic windows, which means the length of window is varying
    * according to the given inputs. The length of session window is defined as "the timestamp
    * of latest input of the session + gap duration", so when the new inputs are bound to the
    * current session window, the end time of session window can be expanded according to the new
    * inputs.
    *
-   * Windows can support microsecond precision. Windows in the order of months are not supported.
+   * Windows can support microsecond precision. gapDuration in the order of months are not
+   * supported.
    *
    * For a streaming query, you may use the function `current_timestamp` to generate windows on
    * processing time.
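Note: a minimal streaming sketch for the session_window function documented above, shown in PySpark for consistency with the earlier examples. The rate source, watermark delay, and gap value are illustrative assumptions; Complete output mode matches what the test suite below uses.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
# The rate source provides `timestamp` and `value` columns; derive a fake user id.
stream = (spark.readStream.format("rate").load()
    .withColumn("user", (F.col("value") % 3).cast("string")))

session_counts = (stream
    .withWatermark("timestamp", "10 minutes")
    .groupBy(F.session_window("timestamp", "5 minutes"), "user")
    .count())

query = (session_counts.writeStream
    .outputMode("complete")
    .format("console")
    .start())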
@@ -87,6 +87,8 @@ class StreamingSessionWindowSuite extends StreamTest
         "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
         "numEvents")

+    sessionUpdates.explain()
+
     testStream(sessionUpdates, OutputMode.Complete())(
       AddData(inputData,
         ("hello world spark streaming", 40L),