Skip to content

Commit

Permalink
fix review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Hongbin Ma (Mahone) <mahongbin@apache.org>
  • Loading branch information
binmahone committed May 22, 2024
1 parent e900d50 commit 7237cb6
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,7 @@ spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.suites

import org.apache.spark.sql.catalyst.expressions.JsonExpressionsSuite
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.utils.RapidsTestsTrait
import org.apache.spark.sql.rapids.utils.{RapidsJsonConfTrait, RapidsTestsTrait}

class RapidsJsonExpressionsSuite extends JsonExpressionsSuite with RapidsTestsTrait {
override def beforeAll(): Unit = {
super.beforeAll()
SQLConf.get.setConfString("spark.rapids.sql.expression.JsonTuple", "true")
SQLConf.get.setConfString("spark.rapids.sql.expression.GetJsonObject", "true")
SQLConf.get.setConfString("spark.rapids.sql.expression.JsonToStructs", "true")
SQLConf.get.setConfString("spark.rapids.sql.expression.StructsToJson", "true")
}

override def afterAll(): Unit = {
super.afterAll()
SQLConf.get.unsetConf("spark.rapids.sql.expression.JsonTuple")
SQLConf.get.unsetConf("spark.rapids.sql.expression.GetJsonObject")
SQLConf.get.unsetConf("spark.rapids.sql.expression.JsonToStructs")
SQLConf.get.unsetConf("spark.rapids.sql.expression.StructsToJson")
}
}
class RapidsJsonExpressionsSuite
extends JsonExpressionsSuite with RapidsTestsTrait with RapidsJsonConfTrait {}
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,7 @@ spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.suites

import org.apache.spark.sql.JsonFunctionsSuite
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.utils.RapidsSQLTestsTrait
import org.apache.spark.sql.rapids.utils.{RapidsJsonConfTrait, RapidsSQLTestsTrait}

class RapidsJsonFunctionsSuite extends JsonFunctionsSuite with RapidsSQLTestsTrait {
override def beforeAll(): Unit = {
super.beforeAll()
SQLConf.get.setConfString("spark.rapids.sql.expression.JsonTuple", "true")
SQLConf.get.setConfString("spark.rapids.sql.expression.GetJsonObject", "true")
SQLConf.get.setConfString("spark.rapids.sql.expression.JsonToStructs", "true")
SQLConf.get.setConfString("spark.rapids.sql.expression.StructsToJson", "true")
}

override def afterAll(): Unit = {
super.afterAll()
SQLConf.get.unsetConf("spark.rapids.sql.expression.JsonTuple")
SQLConf.get.unsetConf("spark.rapids.sql.expression.GetJsonObject")
SQLConf.get.unsetConf("spark.rapids.sql.expression.JsonToStructs")
SQLConf.get.unsetConf("spark.rapids.sql.expression.StructsToJson")
}
}
class RapidsJsonFunctionsSuite
extends JsonFunctionsSuite with RapidsSQLTestsTrait with RapidsJsonConfTrait {}
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,13 @@ import org.apache.spark.sql.execution.datasources.{InMemoryFileIndex, NoopCache}
import org.apache.spark.sql.execution.datasources.json.JsonSuite
import org.apache.spark.sql.execution.datasources.v2.json.JsonScanBuilder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.utils.RapidsSQLTestsBaseTrait
import org.apache.spark.sql.rapids.utils.{RapidsJsonConfTrait, RapidsSQLTestsBaseTrait}
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class RapidsJsonSuite extends JsonSuite with RapidsSQLTestsBaseTrait {

override def beforeAll(): Unit = {
super.beforeAll()
SQLConf.get.setConfString("spark.rapids.sql.expression.JsonTuple", "true")
SQLConf.get.setConfString("spark.rapids.sql.expression.GetJsonObject", "true")
SQLConf.get.setConfString("spark.rapids.sql.expression.JsonToStructs", "true")
SQLConf.get.setConfString("spark.rapids.sql.expression.StructsToJson", "true")
}

override def afterAll(): Unit = {
super.afterAll()
SQLConf.get.unsetConf("spark.rapids.sql.expression.JsonTuple")
SQLConf.get.unsetConf("spark.rapids.sql.expression.GetJsonObject")
SQLConf.get.unsetConf("spark.rapids.sql.expression.JsonToStructs")
SQLConf.get.unsetConf("spark.rapids.sql.expression.StructsToJson")
}

class RapidsJsonSuite
extends JsonSuite with RapidsSQLTestsBaseTrait with RapidsJsonConfTrait {
/** Returns full path to the given file in the resource folder */
override protected def testFile(fileName: String): String = {
getWorkspaceFilePath("sql", "core", "src", "test", "resources").toString + "/" + fileName
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.utils

import org.scalatest.{BeforeAndAfterAll, Suite}

import org.apache.spark.sql.internal.SQLConf

trait RapidsJsonConfTrait extends BeforeAndAfterAll { this: Suite =>
override def beforeAll(): Unit = {
super.beforeAll()
SQLConf.get.setConfString("spark.rapids.sql.expression.JsonTuple", true.toString)
SQLConf.get.setConfString("spark.rapids.sql.expression.GetJsonObject", true.toString)
SQLConf.get.setConfString("spark.rapids.sql.expression.JsonToStructs", true.toString)
SQLConf.get.setConfString("spark.rapids.sql.expression.StructsToJson", true.toString)
}

override def afterAll(): Unit = {
SQLConf.get.unsetConf("spark.rapids.sql.expression.JsonTuple")
SQLConf.get.unsetConf("spark.rapids.sql.expression.GetJsonObject")
SQLConf.get.unsetConf("spark.rapids.sql.expression.JsonToStructs")
SQLConf.get.unsetConf("spark.rapids.sql.expression.StructsToJson")
super.afterAll()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,10 @@ trait RapidsTestsTrait extends RapidsTestsCommonTrait {
}

def rapidsCheckExpression(origExpr: Expression, expected: Any, inputRow: InternalRow): Unit = {
// many of of the expressions in RAPIDS do not support
// vectorized parameters. (e.g. regexp_replace)
// So we downgrade all expression
// evaluation to use scalar parameters.
// In a follow-up issue we'll take care of the expressions
// those already support vectorized paramters.
// many of the expressions in RAPIDS do not support vectorized parameters(e.g. regexp_replace).
// So we downgrade all expression evaluation to use scalar parameters.
// In a follow-up issue (https://github.com/NVIDIA/spark-rapids/issues/10859),
// we'll take care of the expressions those already support vectorized parameters.
val expression = origExpr.transformUp {
case BoundReference(ordinal, dataType, _) =>
Literal(inputRow.asInstanceOf[GenericInternalRow].get(ordinal, dataType), dataType)
Expand Down

0 comments on commit 7237cb6

Please sign in to comment.