[SEDONA-568] Refactor TestBaseScala to allow customizable Spark configurations (#1455)

- Added default Spark configurations to the `TestBaseScala` trait.
- Provided a `defaultSparkConfig` method that returns the default configurations.
- Allowed subclasses to override `sparkConfig` to add or modify Spark configurations (see the sketch after this list).
- Updated initialization of Spark session and Spark context to use the provided configurations.
- Included comments to explain how to override configurations in subclasses.
- Ensured that default configurations are preserved if not overridden by subclasses.
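
For illustration, a minimal sketch (not part of this commit) of how a subclass could override `sparkConfig`; the class name, the extra settings, and the test body are hypothetical:

```scala
// Hypothetical suite: merges extra settings on top of defaultSparkConfig
// so the default entries are preserved (illustrative values).
class CustomConfigTest extends TestBaseScala {

  override def sparkConfig: Map[String, String] =
    defaultSparkConfig ++ Map(
      "spark.sql.shuffle.partitions" -> "4",
      "spark.kryoserializer.buffer.max" -> "128m")

  describe("custom Spark configuration") {
    it("is applied to the lazily created session") {
      assert(sparkSession.conf.get("spark.sql.shuffle.partitions") == "4")
    }
  }
}
```

Because `sparkSession` is a lazy val, the overridden `sparkConfig` is read before the session is built, so the merged map takes effect.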
zhangfengcdt committed Jun 7, 2024
1 parent 9f4be60 commit 94cd82b
Showing 1 changed file with 23 additions and 10 deletions.
@@ -26,7 +26,8 @@ import org.apache.sedona.common.Functions.{frechetDistance, hausdorffDistance}
 import org.apache.sedona.common.Predicates.dWithin
 import org.apache.sedona.common.sphere.{Haversine, Spheroid}
 import org.apache.sedona.spark.SedonaContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.locationtech.jts.geom._
 import org.scalatest.{BeforeAndAfterAll, FunSpec}
 
@@ -39,16 +40,27 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   Logger.getLogger("akka").setLevel(Level.WARN)
   Logger.getLogger("org.apache.sedona.core").setLevel(Level.WARN)
 
-  val warehouseLocation = System.getProperty("user.dir") + "/target/"
-  val sparkSession = SedonaContext.builder().
-    master("local[*]").appName("sedonasqlScalaTest")
-    .config("spark.sql.warehouse.dir", warehouseLocation)
-    // We need to be explicit about broadcasting in tests.
-    .config("sedona.join.autoBroadcastJoinThreshold", "-1")
-    .config("spark.kryoserializer.buffer.max", "64m")
-    .getOrCreate()
+  // Default Spark configurations
+  def defaultSparkConfig: Map[String, String] = Map(
+    "spark.sql.warehouse.dir" -> (System.getProperty("user.dir") + "/target/"),
+    "sedona.join.autoBroadcastJoinThreshold" -> "-1",
+    "spark.kryoserializer.buffer.max" -> "64m"
+  )
 
-  val sc = sparkSession.sparkContext
+  // Method to be overridden by subclasses to provide additional configurations
+  def sparkConfig: Map[String, String] = defaultSparkConfig
+
+  // Lazy initialization of Spark session using configurations
+  lazy val sparkSession: SparkSession = {
+    val builder = SedonaContext.builder()
+      .master("local[*]")
+      .appName("sedonasqlScalaTest")
+    sparkConfig.foreach { case (key, value) => builder.config(key, value) }
+    builder.getOrCreate()
+  }
+
+  // Lazy initialization of Spark context from the Spark session
+  lazy val sc: SparkContext = sparkSession.sparkContext
 
   val resourceFolder = System.getProperty("user.dir") + "/src/test/resources/"
   val mixedWkbGeometryInputLocation = resourceFolder + "county_small_wkb.tsv"
@@ -80,6 +92,7 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   private val factory = new GeometryFactory()
 
   override def beforeAll(): Unit = {
+    super.beforeAll()
     SedonaContext.create(sparkSession)
   }
 

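For reference, a minimal sketch (not part of the commit; the class name, query, and view name are illustrative) showing why the added `super.beforeAll()` call matters: a subclass that overrides `beforeAll` and chains to the trait still gets `SedonaContext.create(sparkSession)` executed before its own setup, so Sedona SQL functions are already registered:

```scala
// Hypothetical suite: chains beforeAll so the base trait's setup
// (SedonaContext.create on the lazy session) runs before any extra setup.
class SedonaSetupSmokeTest extends TestBaseScala {

  override def beforeAll(): Unit = {
    super.beforeAll() // runs TestBaseScala.beforeAll, which registers Sedona SQL functions
    sparkSession
      .sql("SELECT ST_Point(1.0, 2.0) AS geom")
      .createOrReplaceTempView("pts")
  }

  describe("base trait setup") {
    it("makes Sedona functions and the temp view available to tests") {
      assert(sparkSession.sql("SELECT ST_AsText(geom) FROM pts").count() == 1)
    }
  }
}
```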
