diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index b61c4b8d065f2..4cd649b07a5c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * A collection of implicit conversions that create a DSL for constructing catalyst data structures.
@@ -102,6 +103,10 @@ package object dsl {
     def like(other: Expression, escapeChar: Char = '\\'): Expression =
       Like(expr, other, escapeChar)
     def rlike(other: Expression): Expression = RLike(expr, other)
+    def likeAll(others: Expression*): Expression =
+      LikeAll(expr, others.map(_.eval(EmptyRow).asInstanceOf[UTF8String]))
+    def notLikeAll(others: Expression*): Expression =
+      NotLikeAll(expr, others.map(_.eval(EmptyRow).asInstanceOf[UTF8String]))
     def contains(other: Expression): Expression = Contains(expr, other)
     def startsWith(other: Expression): Expression = StartsWith(expr, other)
     def endsWith(other: Expression): Expression = EndsWith(expr, other)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index c9dd7c7acddde..b4d9921488d5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
 import java.util.Locale
 import java.util.regex.{Matcher, MatchResult, Pattern}
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.commons.text.StringEscapeUtils
 
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -178,6 +180,92 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
   }
 }
 
+/**
+ * Optimized version of LIKE ALL, when all pattern values are literal.
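+ *
+ * Evaluates to true if the input matches every pattern (none, for NOT LIKE ALL),
+ * to false as soon as one non-null pattern fails the check, and to null if the
+ * input is null or if all non-null patterns pass but the list contains a null.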
+ */
+abstract class LikeAllBase extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
+
+  protected def patterns: Seq[UTF8String]
+
+  protected def isNotLikeAll: Boolean
+
+  override def inputTypes: Seq[DataType] = StringType :: Nil
+
+  override def dataType: DataType = BooleanType
+
+  override def nullable: Boolean = true
+
+  private lazy val hasNull: Boolean = patterns.contains(null)
+
+  private lazy val cache = patterns.filterNot(_ == null)
+    .map(s => Pattern.compile(StringUtils.escapeLikeRegex(s.toString, '\\')))
+
+  private lazy val matchFunc = if (isNotLikeAll) {
+    (p: Pattern, inputValue: String) => !p.matcher(inputValue).matches()
+  } else {
+    (p: Pattern, inputValue: String) => p.matcher(inputValue).matches()
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val exprValue = child.eval(input)
+    if (exprValue == null) {
+      null
+    } else {
+      if (cache.forall(matchFunc(_, exprValue.toString))) {
+        if (hasNull) null else true
+      } else {
+        false
+      }
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val eval = child.genCode(ctx)
+    val patternClass = classOf[Pattern].getName
+    val javaDataType = CodeGenerator.javaType(child.dataType)
+    val pattern = ctx.freshName("pattern")
+    val valueArg = ctx.freshName("valueArg")
+    val patternCache = ctx.addReferenceObj("patternCache", cache.asJava)
+
+    val checkNotMatchCode = if (isNotLikeAll) {
+      s"$pattern.matcher($valueArg.toString()).matches()"
+    } else {
+      s"!$pattern.matcher($valueArg.toString()).matches()"
+    }
+
+    ev.copy(code =
+      code"""
+            |${eval.code}
+            |boolean ${ev.isNull} = false;
+            |boolean ${ev.value} = true;
+            |if (${eval.isNull}) {
+            |  ${ev.isNull} = true;
+            |} else {
+            |  $javaDataType $valueArg = ${eval.value};
+            |  for ($patternClass $pattern: $patternCache) {
+            |    if ($checkNotMatchCode) {
+            |      ${ev.value} = false;
+            |      break;
+            |    }
+            |  }
+            |  if (${ev.value} && $hasNull) ${ev.isNull} = true;
+            |}
+      """.stripMargin)
+  }
+}
+
+case class LikeAll(child: Expression, patterns: Seq[UTF8String]) extends LikeAllBase {
+  override def isNotLikeAll: Boolean = false
+}
+
+case class NotLikeAll(child: Expression, patterns: Seq[UTF8String]) extends LikeAllBase {
+  override def isNotLikeAll: Boolean = true
+}
+
 // scalastyle:off line.contains.tab
 @ExpressionDescription(
   usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index c3855fe088db6..79857a63a69b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1406,7 +1406,22 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
       case Some(SqlBaseParser.ANY) | Some(SqlBaseParser.SOME) =>
         getLikeQuantifierExprs(ctx.expression).reduceLeft(Or)
       case Some(SqlBaseParser.ALL) =>
-        getLikeQuantifierExprs(ctx.expression).reduceLeft(And)
+        validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx)
+        val expressions = ctx.expression.asScala.map(expression)
+        if (expressions.size > SQLConf.get.optimizerLikeAllConversionThreshold &&
+          expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) {
+          // Reducing a large number of pattern expressions into a deeply nested chain
+          // of And can throw a StackOverflowError, so we use LikeAll or NotLikeAll instead.
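+          // Evaluating the patterns eagerly is safe here because the branch above
+          // requires every expression to be foldable and of StringType.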
+          val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String])
+          ctx.NOT match {
+            case null => LikeAll(e, patterns)
+            case _ => NotLikeAll(e, patterns)
+          }
+        } else {
+          getLikeQuantifierExprs(ctx.expression).reduceLeft(And)
+        }
       case _ =>
         val escapeChar = Option(ctx.escapeChar).map(string).map { str =>
           if (str.length != 1) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 43014feecfd8e..fcf222c8fdab0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -216,6 +216,18 @@ object SQLConf {
       "for using switch statements in InSet must be non-negative and less than or equal to 600")
     .createWithDefault(400)
 
+  val OPTIMIZER_LIKE_ALL_CONVERSION_THRESHOLD =
+    buildConf("spark.sql.optimizer.likeAllConversionThreshold")
+      .internal()
+      .doc("The maximum number of LIKE ALL patterns that Spark combines with And. Above " +
+        "this threshold the patterns are converted into a single LikeAll or NotLikeAll " +
+        "expression to avoid a StackOverflowError. 200 is an empirically safe value.")
+      .version("3.1.0")
+      .intConf
+      .checkValue(threshold => threshold >= 0, "The maximum number of patterns " +
+        "in LIKE ALL must be non-negative")
+      .createWithDefault(200)
+
   val PLAN_CHANGE_LOG_LEVEL = buildConf("spark.sql.planChangeLog.level")
     .internal()
     .doc("Configures the log level for logging the change from the original plan to the new " +
@@ -3037,6 +3049,8 @@ class SQLConf extends Serializable with Logging {
 
   def optimizerInSetSwitchThreshold: Int = getConf(OPTIMIZER_INSET_SWITCH_THRESHOLD)
 
+  def optimizerLikeAllConversionThreshold: Int = getConf(OPTIMIZER_LIKE_ALL_CONVERSION_THRESHOLD)
+
   def planChangeLogLevel: String = getConf(PLAN_CHANGE_LOG_LEVEL)
 
   def planChangeRules: Option[String] = getConf(PLAN_CHANGE_LOG_RULES)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
index 77a32a735f76d..cc5ab5dc7b4e0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -48,6 +48,30 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(mkExpr(regex), expected, create_row(input))  // check row input
   }
 
+  test("LIKE ALL") {
+    checkEvaluation(Literal.create(null, StringType).likeAll("%foo%", "%oo"), null)
+    checkEvaluation(Literal.create("foo", StringType).likeAll("%foo%", "%oo"), true)
+    checkEvaluation(Literal.create("foo", StringType).likeAll("%foo%", "%bar%"), false)
+    checkEvaluation(Literal.create("foo", StringType)
+      .likeAll("%foo%", Literal.create(null, StringType)), null)
+    checkEvaluation(Literal.create("foo", StringType)
+      .likeAll(Literal.create(null, StringType), "%foo%"), null)
+    checkEvaluation(Literal.create("foo", StringType)
+      .likeAll("%feo%", Literal.create(null, StringType)), false)
+    checkEvaluation(Literal.create("foo", StringType)
+      .likeAll(Literal.create(null, StringType), "%feo%"), false)
+    checkEvaluation(Literal.create("foo", StringType).notLikeAll("tee", "%yoo%"), true)
+    checkEvaluation(Literal.create("foo", StringType).notLikeAll("%oo%", "%yoo%"), false)
+    checkEvaluation(Literal.create("foo", StringType)
+      .notLikeAll("%foo%", Literal.create(null, StringType)), false)
+    checkEvaluation(Literal.create("foo", StringType)
+      .notLikeAll(Literal.create(null, StringType), "%foo%"), false)
+    checkEvaluation(Literal.create("foo", StringType)
+      .notLikeAll("%yoo%", Literal.create(null, StringType)), null)
+    checkEvaluation(Literal.create("foo", StringType)
+      .notLikeAll(Literal.create(null, StringType), "%yoo%"), null)
+  }
+
   test("LIKE Pattern") {
     // null handling
diff --git a/sql/core/src/test/resources/sql-tests/inputs/like-all.sql b/sql/core/src/test/resources/sql-tests/inputs/like-all.sql
index a084dbef61a0c..f83277376e680 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/like-all.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/like-all.sql
@@ -1,3 +1,7 @@
+-- test cases for like all
+--CONFIG_DIM1 spark.sql.optimizer.likeAllConversionThreshold=0
+--CONFIG_DIM1 spark.sql.optimizer.likeAllConversionThreshold=200
+
 CREATE OR REPLACE TEMPORARY VIEW like_all_table AS SELECT * FROM (VALUES
   ('google', '%oo%'),
   ('facebook', '%oo%'),
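
For reference, a minimal standalone sketch of the three-valued semantics the new
expressions implement, assuming this patch is applied. The object name
LikeAllSemanticsDemo is invented for the example; only LikeAll, Literal, and
UTF8String come from the patched build.

// Demonstrates LikeAll's null handling: true only when every pattern matches,
// false as soon as one non-null pattern fails, and null when the input is null
// or when all non-null patterns match but the pattern list contains a null.
import org.apache.spark.sql.catalyst.expressions.{LikeAll, Literal}
import org.apache.spark.unsafe.types.UTF8String

object LikeAllSemanticsDemo {
  def main(args: Array[String]): Unit = {
    val input = Literal("foo")
    def p(s: String): UTF8String = UTF8String.fromString(s)

    // Both patterns match -> true.
    println(LikeAll(input, Seq(p("%foo%"), p("%oo"))).eval(null))  // true
    // "%bar%" fails -> false, even though the list also contains a null.
    println(LikeAll(input, Seq(p("%bar%"), null)).eval(null))      // false
    // All non-null patterns match, but a null pattern is present -> null.
    println(LikeAll(input, Seq(p("%foo%"), null)).eval(null))      // null
  }
}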