[SPARK-33045][SQL] Support built-in function like_all and fix StackOverflowError issue

### What changes were proposed in this pull request?
Spark already supports the `LIKE ALL` syntax, but it throws a `StackOverflowError` when there are many patterns (more than 14378 elements). We should implement a built-in function for `LIKE ALL` to fix this issue.

Why can the stack overflow happen with the current approach?
The current approach uses `reduceLeft` to connect each `Like(e, p)` with `And`, so the expression tree, and hence the evaluation call depth, grows linearly with the number of patterns, eventually causing a `StackOverflowError`.

Why does the fix in this PR avoid the error?
This PR adds built-in `LikeAll`/`NotLikeAll` expressions that hold all patterns in a single node and match them in a loop, so the call depth no longer depends on the number of patterns.
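
For illustration, a rough sketch of the two expression-tree shapes (Catalyst internals; `e` and the pattern list are placeholders):

```scala
import org.apache.spark.sql.catalyst.expressions.{And, Expression, Like, LikeAll, Literal}
import org.apache.spark.unsafe.types.UTF8String

// Before: reduceLeft builds one And node per pattern, e.g.
// And(And(Like(e, p1), Like(e, p2)), Like(e, p3)). Evaluating or
// transforming the root recurses once per node, so tens of thousands
// of patterns can exhaust the JVM thread stack.
def oldPlan(e: Expression, pats: Seq[String]): Expression =
  pats.map(p => Like(e, Literal(p), '\\')).reduceLeft(And)

// After: one flat node that loops over the compiled patterns, so the
// call depth stays constant no matter how many patterns there are.
def newPlan(e: Expression, pats: Seq[String]): Expression =
  LikeAll(e, pats.map(UTF8String.fromString))
```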

### Why are the changes needed?
1. Fix the `StackOverflowError` issue.
2. Support built-in function `like_all`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Jenkins tests.

Closes #29999 from beliefer/SPARK-33045-like_all.

Lead-authored-by: gengjiaan <gengjiaan@360.cn>
Co-authored-by: beliefer <beliefer@163.com>
Co-authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
2 people authored and cloud-fan committed Nov 19, 2020
1 parent 21b1350 commit 3695e99
Showing 6 changed files with 145 additions and 1 deletion.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
 * A collection of implicit conversions that create a DSL for constructing catalyst data structures.
@@ -102,6 +103,10 @@ package object dsl {
    def like(other: Expression, escapeChar: Char = '\\'): Expression =
      Like(expr, other, escapeChar)
    def rlike(other: Expression): Expression = RLike(expr, other)
    def likeAll(others: Expression*): Expression =
      LikeAll(expr, others.map(_.eval(EmptyRow).asInstanceOf[UTF8String]))
    def notLikeAll(others: Expression*): Expression =
      NotLikeAll(expr, others.map(_.eval(EmptyRow).asInstanceOf[UTF8String]))
    def contains(other: Expression): Expression = Contains(expr, other)
    def startsWith(other: Expression): Expression = StartsWith(expr, other)
    def endsWith(other: Expression): Expression = EndsWith(expr, other)
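
A hedged usage sketch of the new DSL helpers (attribute name illustrative; plain strings are implicitly lifted to `Literal`, and each pattern must be foldable because `likeAll`/`notLikeAll` eagerly evaluate it to a `UTF8String`):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._

// 'company resolves to an attribute via the DSL's Symbol conversions.
val cond    = 'company.likeAll("%oo%", "%go%")     // company LIKE ALL ('%oo%', '%go%')
val negated = 'company.notLikeAll("%oo%", "%go%")  // company NOT LIKE ALL ('%oo%', '%go%')
```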
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Locale
import java.util.regex.{Matcher, MatchResult, Pattern}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.commons.text.StringEscapeUtils

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -178,6 +180,88 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
}
}

/**
 * Optimized version of LIKE ALL, when all pattern values are literal.
 */
abstract class LikeAllBase extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  protected def patterns: Seq[UTF8String]

  protected def isNotLikeAll: Boolean

  override def inputTypes: Seq[DataType] = StringType :: Nil

  override def dataType: DataType = BooleanType

  override def nullable: Boolean = true

  private lazy val hasNull: Boolean = patterns.contains(null)

  private lazy val cache = patterns.filterNot(_ == null)
    .map(s => Pattern.compile(StringUtils.escapeLikeRegex(s.toString, '\\')))

  private lazy val matchFunc = if (isNotLikeAll) {
    (p: Pattern, inputValue: String) => !p.matcher(inputValue).matches()
  } else {
    (p: Pattern, inputValue: String) => p.matcher(inputValue).matches()
  }

  override def eval(input: InternalRow): Any = {
    val exprValue = child.eval(input)
    if (exprValue == null) {
      null
    } else {
      if (cache.forall(matchFunc(_, exprValue.toString))) {
        if (hasNull) null else true
      } else {
        false
      }
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    val patternClass = classOf[Pattern].getName
    val javaDataType = CodeGenerator.javaType(child.dataType)
    val pattern = ctx.freshName("pattern")
    val valueArg = ctx.freshName("valueArg")
    val patternCache = ctx.addReferenceObj("patternCache", cache.asJava)

    val checkNotMatchCode = if (isNotLikeAll) {
      s"$pattern.matcher($valueArg.toString()).matches()"
    } else {
      s"!$pattern.matcher($valueArg.toString()).matches()"
    }

    ev.copy(code =
      code"""
           |${eval.code}
           |boolean ${ev.isNull} = false;
           |boolean ${ev.value} = true;
           |if (${eval.isNull}) {
           |  ${ev.isNull} = true;
           |} else {
           |  $javaDataType $valueArg = ${eval.value};
           |  for ($patternClass $pattern: $patternCache) {
           |    if ($checkNotMatchCode) {
           |      ${ev.value} = false;
           |      break;
           |    }
           |  }
           |  if (${ev.value} && $hasNull) ${ev.isNull} = true;
           |}
         """.stripMargin)
  }
}

case class LikeAll(child: Expression, patterns: Seq[UTF8String]) extends LikeAllBase {
  override def isNotLikeAll: Boolean = false
}

case class NotLikeAll(child: Expression, patterns: Seq[UTF8String]) extends LikeAllBase {
  override def isNotLikeAll: Boolean = true
}

// scalastyle:off line.contains.tab
@ExpressionDescription(
usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.",
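
A sketch of the three-valued null semantics that `eval` above implements (mirrored by the tests further down; `u` is local shorthand):

```scala
import org.apache.spark.sql.catalyst.expressions.{EmptyRow, LikeAll, Literal}
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

def u(s: String): UTF8String = UTF8String.fromString(s)

LikeAll(Literal("foo"), Seq(u("%foo%"), u("%oo"))).eval(EmptyRow)         // true: every pattern matches
LikeAll(Literal("foo"), Seq(u("%foo%"), null)).eval(EmptyRow)             // null: non-null patterns match, one is null
LikeAll(Literal("foo"), Seq(u("%feo%"), null)).eval(EmptyRow)             // false: a non-null pattern misses
LikeAll(Literal.create(null, StringType), Seq(u("%foo%"))).eval(EmptyRow) // null: null input value
```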
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1406,7 +1406,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
      case Some(SqlBaseParser.ANY) | Some(SqlBaseParser.SOME) =>
        getLikeQuantifierExprs(ctx.expression).reduceLeft(Or)
      case Some(SqlBaseParser.ALL) =>
        validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx)
        val expressions = ctx.expression.asScala.map(expression)
        if (expressions.size > SQLConf.get.optimizerLikeAllConversionThreshold &&
            expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) {
          // With many pattern expressions, the reduceLeft And-chain would throw
          // StackOverflowError, so we use LikeAll or NotLikeAll instead.
          val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String])
          ctx.NOT match {
            case null => LikeAll(e, patterns)
            case _ => NotLikeAll(e, patterns)
          }
        } else {
          getLikeQuantifierExprs(ctx.expression).reduceLeft(And)
        }
      case _ =>
        val escapeChar = Option(ctx.escapeChar).map(string).map { str =>
          if (str.length != 1) {
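
A sketch of when the parser change kicks in (`spark`, `t`, and `col` are placeholders):

```scala
// With the default threshold of 200, LIKE ALL is parsed into a single
// LikeAll/NotLikeAll node once the pattern count exceeds the threshold
// and every pattern is a foldable string; otherwise the And-chain of
// Like expressions is kept.
val pats = Seq.fill(500)("'%oo%'").mkString(", ")
spark.sql(s"SELECT * FROM t WHERE col LIKE ALL ($pats)")      // => LikeAll
spark.sql(s"SELECT * FROM t WHERE col NOT LIKE ALL ($pats)")  // => NotLikeAll
spark.sql("SELECT * FROM t WHERE col LIKE ALL ('%oo%')")      // below threshold => And-chain
```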
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -216,6 +216,18 @@ object SQLConf {
"for using switch statements in InSet must be non-negative and less than or equal to 600")
.createWithDefault(400)

val OPTIMIZER_LIKE_ALL_CONVERSION_THRESHOLD =
buildConf("spark.sql.optimizer.likeAllConversionThreshold")
.internal()
.doc("Configure the maximum size of the pattern sequence in like all. Spark will convert " +
"the logical combination of like to avoid StackOverflowError. 200 is an empirical value " +
"that will not cause StackOverflowError.")
.version("3.1.0")
.intConf
.checkValue(threshold => threshold >= 0, "The maximum size of pattern sequence " +
"in like all must be non-negative")
.createWithDefault(200)

  val PLAN_CHANGE_LOG_LEVEL = buildConf("spark.sql.planChangeLog.level")
    .internal()
    .doc("Configures the log level for logging the change from the original plan to the new " +
@@ -3037,6 +3049,8 @@ class SQLConf extends Serializable with Logging {

  def optimizerInSetSwitchThreshold: Int = getConf(OPTIMIZER_INSET_SWITCH_THRESHOLD)

  def optimizerLikeAllConversionThreshold: Int = getConf(OPTIMIZER_LIKE_ALL_CONVERSION_THRESHOLD)

  def planChangeLogLevel: String = getConf(PLAN_CHANGE_LOG_LEVEL)

  def planChangeRules: Option[String] = getConf(PLAN_CHANGE_LOG_RULES)
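
For example (a sketch; the conf is internal but can still be set at runtime):

```scala
// Setting the threshold to 0 converts every LIKE ALL whose patterns are
// foldable strings; a very large value keeps the legacy And-chain plan.
spark.conf.set("spark.sql.optimizer.likeAllConversionThreshold", 0L)
```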
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -48,6 +48,30 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
    checkEvaluation(mkExpr(regex), expected, create_row(input)) // check row input
  }

  test("LIKE ALL") {
    checkEvaluation(Literal.create(null, StringType).likeAll("%foo%", "%oo"), null)
    checkEvaluation(Literal.create("foo", StringType).likeAll("%foo%", "%oo"), true)
    checkEvaluation(Literal.create("foo", StringType).likeAll("%foo%", "%bar%"), false)
    checkEvaluation(Literal.create("foo", StringType)
      .likeAll("%foo%", Literal.create(null, StringType)), null)
    checkEvaluation(Literal.create("foo", StringType)
      .likeAll(Literal.create(null, StringType), "%foo%"), null)
    checkEvaluation(Literal.create("foo", StringType)
      .likeAll("%feo%", Literal.create(null, StringType)), false)
    checkEvaluation(Literal.create("foo", StringType)
      .likeAll(Literal.create(null, StringType), "%feo%"), false)
    checkEvaluation(Literal.create("foo", StringType).notLikeAll("tee", "%yoo%"), true)
    checkEvaluation(Literal.create("foo", StringType).notLikeAll("%oo%", "%yoo%"), false)
    checkEvaluation(Literal.create("foo", StringType)
      .notLikeAll("%foo%", Literal.create(null, StringType)), false)
    checkEvaluation(Literal.create("foo", StringType)
      .notLikeAll(Literal.create(null, StringType), "%foo%"), false)
    checkEvaluation(Literal.create("foo", StringType)
      .notLikeAll("%yoo%", Literal.create(null, StringType)), null)
    checkEvaluation(Literal.create("foo", StringType)
      .notLikeAll(Literal.create(null, StringType), "%yoo%"), null)
  }

  test("LIKE Pattern") {

    // null handling
4 changes: 4 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/like-all.sql
@@ -1,3 +1,7 @@
-- test cases for like all
--CONFIG_DIM1 spark.sql.optimizer.likeAllConversionThreshold=0
--CONFIG_DIM1 spark.sql.optimizer.likeAllConversionThreshold=200

CREATE OR REPLACE TEMPORARY VIEW like_all_table AS SELECT * FROM (VALUES
('google', '%oo%'),
('facebook', '%oo%'),
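
The two `--CONFIG_DIM1` lines rerun the whole file under both threshold values, so every query is exercised through both the converted (`LikeAll`) and legacy (And-chain) code paths. A sketch of such a query against the view (column name assumed, since the view definition is truncated here):

```scala
spark.sql("SELECT company FROM like_all_table WHERE company LIKE ALL ('%oo%')")
```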
