Skip to content

Commit

Permalink
Replace getSparkVersion with spark.version, improve version check (#142)
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoMi committed Feb 21, 2023
1 parent 562f4a4 commit f525680
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 37 deletions.
39 changes: 4 additions & 35 deletions src/main/scala/uk/co/gresearch/spark/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,46 +42,15 @@ package object spark extends Logging with SparkVersion {
"_" * (existing.map(_.takeWhile(_ == '_').length).reduceOption(_ max _).getOrElse(0) + 1)
}

/**
 * Detects the Spark version by inspecting the classpath.
 *
 * Looks up the `pom.properties` resource of the `spark-sql` Maven artifact for the running
 * Scala binary version (major.minor of `Properties.releaseVersion`) via the system class loader
 * and returns its `version` property, provided the `groupId` and `artifactId` properties
 * identify that artifact.
 *
 * Falls back to the Spark version that this package is compiled for
 * (`SparkCompatVersionString` followed by `.x`) when the Scala release version is not available,
 * the resource does not exist or cannot be read, or the properties do not match.
 */
private[spark] lazy val getSparkVersion: String = {
  val scalaCompatVersionOpt = Properties.releaseVersion.map(_.split("\\.").take(2).mkString("."))
  scalaCompatVersionOpt.flatMap { scalaCompatVersion =>
    val artifactId = s"spark-sql_$scalaCompatVersion"
    val propFilePath = s"META-INF/maven/org.apache.spark/$artifactId/pom.properties"
    Option(ClassLoader.getSystemClassLoader.getResourceAsStream(propFilePath)).flatMap { in =>
      val loaded = try {
        val props = new java.util.Properties()
        props.load(in)
        Some(props)
      } catch {
        case _: IOException => None
      } finally {
        // always release the resource stream; a failure to close must not
        // prevent the fallback version from being used
        try in.close() catch { case _: IOException => () }
      }

      loaded.flatMap { props =>
        val ver = Option(props.getProperty("version"))
        val group = Option(props.getProperty("groupId"))
        val artifact = Option(props.getProperty("artifactId"))

        // only trust the version if the properties file really describes spark-sql
        ver.filter(_ => group.contains("org.apache.spark") && artifact.contains(artifactId))
      }
    }
  }.getOrElse(s"$SparkCompatVersionString.x")
}

// https://issues.apache.org/jira/browse/SPARK-40588
// Returns true when adaptive query execution is enabled for the session of `ds`
// (SQLConf.ADAPTIVE_EXECUTION_ENABLED, falling back to its default value, or to true
// if it has none) and the Spark version is one of the versions listed below.
// NOTE(review): this is a diff rendering without +/- markers, so two variants of the
// version check follow each other — confirm which side is current before editing.
private[spark] def writePartitionedByRequiresCaching[T](ds: Dataset[T]): Boolean = {
ds.sparkSession.conf.get(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key,
SQLConf.ADAPTIVE_EXECUTION_ENABLED.defaultValue.getOrElse(true).toString
// NOTE(review): presumably the removed side of the diff — version check based on
// getSparkVersion, which may report the "<major>.<minor>.x" fallback value.
).equalsIgnoreCase("true") && Some(getSparkVersion).exists(ver =>
ver.startsWith("3.0.") || ver.startsWith("3.1.") ||
ver.equals("3.2.x") || ver.startsWith("3.2.0") || ver.startsWith("3.2.1") || ver.startsWith("3.2.2") ||
ver.equals("3.3.x") || ver.startsWith("3.3.0") || ver.startsWith("3.3.1")
// NOTE(review): presumably the added side of the diff — version check based on the
// version reported by the Dataset's SparkSession.
).equalsIgnoreCase("true") && Some(ds.sparkSession.version).exists(ver =>
// patterns ending in "." match any version with that prefix; the other patterns match
// the version exactly or when it is followed by a "-" separated suffix
Set("3.0.", "3.1.", "3.2.0", "3.2.1" ,"3.2.2", "3.3.0", "3.3.1").exists(pat =>
if (pat.endsWith(".")) { ver.startsWith(pat) } else { ver.equals(pat) || ver.startsWith(pat + "-") }
)
)
}

Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class SparkSuite extends AnyFunSuite with SparkTestSession with SparkVersion {
val emptyDataFrame: DataFrame = spark.createDataFrame(Seq.empty[Value])

// Checks that the Spark version starts with "<compat major>.<compat minor>." and that
// SparkCompatVersion / SparkCompatVersionString agree with the major and minor constants.
test("Get Spark version") {
// NOTE(review): the next two assertions are presumably the removed (getSparkVersion) and
// added (spark.version) sides of the same diff line — confirm which one is current.
assert(getSparkVersion.startsWith(s"$SparkCompatMajorVersion.$SparkCompatMinorVersion."))
assert(spark.version.startsWith(s"$SparkCompatMajorVersion.$SparkCompatMinorVersion."))
assert(SparkCompatVersion === (SparkCompatMajorVersion, SparkCompatMinorVersion))
assert(SparkCompatVersionString === s"$SparkCompatMajorVersion.$SparkCompatMinorVersion")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession {

test("write partitionedBy requires caching with AQE enabled") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
Some(getSparkVersion)
Some(spark.version)
.map(version => Set("3.0.", "3.1.", "3.2.0", "3.2.1", "3.2.2", "3.3.0", "3.3.1").exists(version.startsWith))
.foreach(expected => assert(writePartitionedByRequiresCaching(values) === expected))
}
Expand Down

0 comments on commit f525680

Please sign in to comment.