150 changes: 150 additions & 0 deletions clickhouse-core/src/main/scala/com/clickhouse/spark/Utils.scala
@@ -31,6 +31,7 @@ import scala.annotation.tailrec
import scala.collection.mutable
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}
import scala.collection.JavaConverters.asScalaSetConverter

object Utils extends Logging {

@@ -186,4 +187,153 @@ object Utils extends Logging {
def setTesting(): Unit = System.setProperty(IS_TESTING, "true")

def isTesting: Boolean = System.getProperty(IS_TESTING) == "true"

object RuntimeDetector {

def detectRuntime(): Option[String] =
RuntimeDetector.detectViaStackTrace()
.orElse(RuntimeDetector.detectViaClassLoader())
.orElse(RuntimeDetector.detectViaThreadNames())

/**
* Examines class names on the current stack trace for platform-specific signatures
*/
def detectViaStackTrace(): Option[String] = {
val stackTrace = Thread.currentThread().getStackTrace
val stackClasses = stackTrace.map(_.getClassName.toLowerCase)

// Check for platform-specific classes in stack
if (
stackClasses.exists(c =>
c.contains("com.databricks.logging") ||
c.contains("databricks.spark") ||
c.contains("com.databricks.backend")
)
) {
Some("Databricks")
} else if (
stackClasses.exists { c =>
c.contains("com.amazonaws.services.glue") ||
c.contains("aws.glue") ||
c.contains("awsglue")
}
) {
Some("Glue")
} else if (
stackClasses.exists(c =>
c.contains("com.amazon.emr") ||
c.contains("amazon.emrfs")
)
) {
Some("EMR")
} else if (
stackClasses.exists(c =>
c.contains("com.google.cloud.dataproc") ||
c.contains("dataproc")
)
) {
Some("Dataproc")
} else if (
stackClasses.exists(c =>
c.contains("com.microsoft.azure.synapse") ||
c.contains("synapse.spark")
)
) {
Some("Synapse")
} else {
None
}
}
Comment on lines +201 to +246
The detectViaStackTrace() method doesn't check for HDInsight, but HDInsight is included in the detectViaClassLoader() method. Consider adding a check for HDInsight in stack traces for consistency.
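As an illustration of that suggestion, a minimal sketch of an extra branch placed before the final else of detectViaStackTrace(); the package prefixes are assumptions mirrored from the class names used in detectViaClassLoader() below, not verified signatures:

} else if (
  stackClasses.exists(c =>
    c.contains("com.microsoft.azure.hdinsight") || // assumed prefix, mirrored from detectViaClassLoader()
      c.contains("hdinsight")
  )
) {
  Some("HDInsight")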


/**
* More comprehensive check using ClassLoader to find platform-specific classes
*/
def detectViaClassLoader(): Option[String] = {
val classLoader = Thread.currentThread().getContextClassLoader

case class PlatformSignature(name: String, classNames: Seq[String])

val platformSignatures = Seq(
PlatformSignature(
"Databricks",
Seq(
"com.databricks.spark.util.DatabricksLogging",
"com.databricks.backend.daemon.driver.DriverLocal",
"com.databricks.dbutils_v1.DBUtilsHolder",
"com.databricks.spark.util.FrameProfiler"
)
),
PlatformSignature(
"Glue",
Seq(
"com.amazonaws.services.glue.GlueContext",
"com.amazonaws.services.glue.util.GlueArgParser",
"com.amazonaws.services.glue.DynamicFrame"
)
),
PlatformSignature(
"EMR",
Seq(
"com.amazon.ws.emr.hadoop.fs.EmrFileSystem",
"com.amazon.emr.kinesis.client.KinesisConnector",
"com.amazon.emr.cloudwatch.CloudWatchSink"
)
),
PlatformSignature(
"Dataproc",
Seq(
"com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem",
"com.google.cloud.dataproc.DataprocHadoopConfiguration",
"com.google.cloud.spark.bigquery.BigQueryConnector"
)
),
PlatformSignature(
"Synapse",
Seq(
"com.microsoft.azure.synapse.ml.core.env.SynapseEnv",
"com.microsoft.azure.synapse.ml.logging.SynapseMLLogging"
)
),
PlatformSignature(
"HDInsight",
Seq(
"com.microsoft.azure.hdinsight.spark.common.SparkBatchJob",
"com.microsoft.hdinsight.spark.common.HttpFutureCallback"
)
)
)

// Try to load platform-specific classes
def classExists(className: String): Boolean =
try {
Class.forName(className, false, classLoader)
true
} catch {
case _: ClassNotFoundException => false
}

platformSignatures.collectFirst {
case PlatformSignature(name, classes) if classes.exists(classExists) => name
}
}

/**
* Check running threads for platform-specific thread names
*/
def detectViaThreadNames(): Option[String] = {
val threadNames = Thread.getAllStackTraces.keySet().asScala.map(_.getName.toLowerCase)

if (threadNames.exists(_.contains("databricks"))) {
Some("Databricks")
} else if (threadNames.exists(t => t.contains("glue") || t.contains("awsglue"))) {
Some("Glue")
} else if (threadNames.exists(_.contains("emr"))) {
Some("EMR")
} else if (threadNames.exists(_.contains("dataproc"))) {
Some("Dataproc")
} else {
None
}
Comment on lines +323 to +336
The detectViaThreadNames() method doesn't check for Synapse in thread names, while the other detection methods do. Consider adding a check for Synapse thread names for consistency.

Comment on lines +323 to +336
The detectViaThreadNames() method doesn't check for HDInsight, but HDInsight is included in the detectViaClassLoader() method. Consider adding a check for HDInsight thread names for consistency.
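A sketch addressing both comments, as extra branches before the final else of detectViaThreadNames(); the matched substrings ("synapse", "hdinsight") are assumptions about how those platforms name their threads, not verified behaviour:

} else if (threadNames.exists(_.contains("synapse"))) {
  Some("Synapse")
} else if (threadNames.exists(_.contains("hdinsight"))) {
  Some("HDInsight")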

}
}
}
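For context, a hypothetical usage sketch (not part of the diff) showing how the detector composes the three strategies; on a plain local or self-hosted Spark installation all of them fall through to None:

// hypothetical spark-shell session, for illustration only
Utils.RuntimeDetector.detectRuntime() match {
  case Some(env) => println(s"Detected managed platform: $env") // e.g. "Databricks"
  case None      => println("No managed platform detected")
}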
@@ -25,10 +25,10 @@ import com.clickhouse.spark.format.{
NamesAndTypes,
SimpleOutput
}
import com.clickhouse.spark.Utils.RuntimeDetector.detectRuntime
import com.clickhouse.spark.spec.NodeSpec
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.node.ObjectNode
import com.clickhouse.spark.format._

import java.io.InputStream
import java.util.UUID
@@ -42,23 +42,42 @@ class NodeClient(val nodeSpec: NodeSpec) extends AutoCloseable with Logging {
// TODO: add configurable timeout
private val timeout: Int = 30000

- private lazy val userAgent = {
+ private lazy val userAgent: String = {
val title = getClass.getPackage.getImplementationTitle
val version = getClass.getPackage.getImplementationVersion
- if (version != null && title != null) {
-   val versions = version.split("_")
-   if (versions.length < 3) {
-     "Spark-ClickHouse-Connector"
-   } else {
-     val sparkVersion = versions(0)
-     val scalaVersion = versions(1)
-     val connectorVersion = versions(2)
-     s"${title}/${connectorVersion} (fv:spark/${sparkVersion}, lv:scala/${scalaVersion})"
-   }
+ buildUserAgent(title, version)
}

private def buildUserAgent(title: String, version: String): String =
(Option(title), Option(version)) match {
case (Some(t), Some(v)) =>
parseVersionString(v) match {
case Some((spark, scala, connector)) =>
val runtimeSuffix = getRuntimeEnvironmentSuffix()
s"$t/$connector (fv:spark/$spark, lv:scala/$scala$runtimeSuffix)"
case None => "Spark-ClickHouse-Connector"
}
case _ => "Spark-ClickHouse-Connector"
}

private def parseVersionString(version: String): Option[(String, String, String)] =
version.split("_") match {
case Array(spark, scala, connector, _*) => Some((spark, scala, connector))
case _ => None
}

private def getRuntimeEnvironmentSuffix(): String =
if (shouldInferRuntime()) {
detectRuntime()
.filter(_.nonEmpty)
.fold("")(env => s", env:$env")
} else {
"Spark-ClickHouse-Connector"
""
}
}

private def shouldInferRuntime(): Boolean =
nodeSpec.infer_runtime_env.equalsIgnoreCase("true") || nodeSpec.infer_runtime_env == "1"

private val node: ClickHouseNode = ClickHouseNode.builder()
.options(nodeSpec.options)
.host(nodeSpec.host)
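To illustrate the resulting header, a worked example with made-up title and version values (the real values come from the jar manifest's Implementation-Title and Implementation-Version):

// hypothetical values, for illustration only
buildUserAgent("Spark-ClickHouse-Connector", "3.5_2.12_0.8.0")
// => "Spark-ClickHouse-Connector/0.8.0 (fv:spark/3.5, lv:scala/2.12, env:Databricks)"  when a Databricks runtime is detected
// => "Spark-ClickHouse-Connector/0.8.0 (fv:spark/3.5, lv:scala/2.12)"                  when inference is disabled or nothing is detected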
@@ -35,6 +35,7 @@ case class NodeSpec(
@JsonProperty("username") username: String = "default",
@JsonProperty("password") password: String = "",
@JsonProperty("database") database: String = "default",
@JsonProperty("infer_runtime_env") infer_runtime_env: String = "true",
@JsonProperty("options") options: util.Map[String, String] = Collections.emptyMap()
) extends Nodes with ToJson with Serializable {
@JsonProperty("host") def host: String = findHost(_host)
@@ -65,6 +65,7 @@ trait ClickHouseHelper extends Logging {
username = options.getOrDefault(CATALOG_PROP_USER, "default"),
password = options.getOrDefault(CATALOG_PROP_PASSWORD, ""),
database = options.getOrDefault(CATALOG_PROP_DATABASE, "default"),
infer_runtime_env = options.getOrDefault(CATALOG_INFER_RUNTIME_ENV, "true"),
options = new JHashMap(clientOpts.asJava)
)
}
@@ -29,10 +29,12 @@ object Constants {
final val CATALOG_PROP_PASSWORD = "password"
final val CATALOG_PROP_DATABASE = "database"
final val CATALOG_PROP_TZ = "timezone" // server(default), client, UTC+3, Asia/Shanghai, etc.
final val CATALOG_INFER_RUNTIME_ENV = "infer_runtime_env"
final val CATALOG_PROP_OPTION_PREFIX = "option."
final val CATALOG_PROP_IGNORE_OPTIONS = Seq(
DATABASE.getKey, COMPRESS.getKey, DECOMPRESS.getKey, FORMAT.getKey, RETRY.getKey,
- USE_SERVER_TIME_ZONE.getKey, USE_SERVER_TIME_ZONE_FOR_DATES.getKey, SERVER_TIME_ZONE.getKey, USE_TIME_ZONE.getKey)
+ USE_SERVER_TIME_ZONE.getKey, USE_SERVER_TIME_ZONE_FOR_DATES.getKey, SERVER_TIME_ZONE.getKey, USE_TIME_ZONE.getKey,
+ CATALOG_INFER_RUNTIME_ENV)

//////////////////////////////////////////////////////////
////////// clickhouse datasource read properties /////////
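For users, the new flag is read as a catalog property. A minimal configuration sketch; the catalog name "clickhouse" and the catalog class are assumptions for illustration, not taken from the PR:

// hypothetical Spark session setup
val spark = org.apache.spark.sql.SparkSession.builder()
  .config("spark.sql.catalog.clickhouse", "com.clickhouse.spark.ClickHouseCatalog") // assumed catalog class
  .config("spark.sql.catalog.clickhouse.infer_runtime_env", "false") // "true"/"1" enable detection; anything else disables it
  .getOrCreate()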
@@ -66,6 +66,7 @@ trait ClickHouseHelper extends Logging {
username = options.getOrDefault(CATALOG_PROP_USER, "default"),
password = options.getOrDefault(CATALOG_PROP_PASSWORD, ""),
database = options.getOrDefault(CATALOG_PROP_DATABASE, "default"),
infer_runtime_env = options.getOrDefault(CATALOG_INFER_RUNTIME_ENV, "true"),
options = new JHashMap(clientOpts.asJava)
)
}
@@ -29,10 +29,12 @@ object Constants {
final val CATALOG_PROP_PASSWORD = "password"
final val CATALOG_PROP_DATABASE = "database"
final val CATALOG_PROP_TZ = "timezone" // server(default), client, UTC+3, Asia/Shanghai, etc.
final val CATALOG_INFER_RUNTIME_ENV = "infer_runtime_env"
final val CATALOG_PROP_OPTION_PREFIX = "option."
final val CATALOG_PROP_IGNORE_OPTIONS = Seq(
DATABASE.getKey, COMPRESS.getKey, DECOMPRESS.getKey, FORMAT.getKey, RETRY.getKey,
- USE_SERVER_TIME_ZONE.getKey, USE_SERVER_TIME_ZONE_FOR_DATES.getKey, SERVER_TIME_ZONE.getKey, USE_TIME_ZONE.getKey)
+ USE_SERVER_TIME_ZONE.getKey, USE_SERVER_TIME_ZONE_FOR_DATES.getKey, SERVER_TIME_ZONE.getKey, USE_TIME_ZONE.getKey,
+ CATALOG_INFER_RUNTIME_ENV)

//////////////////////////////////////////////////////////
////////// clickhouse datasource read properties /////////
@@ -76,6 +76,7 @@ trait ClickHouseHelper extends Logging {
username = options.getOrDefault(CATALOG_PROP_USER, "default"),
password = options.getOrDefault(CATALOG_PROP_PASSWORD, ""),
database = options.getOrDefault(CATALOG_PROP_DATABASE, "default"),
infer_runtime_env = options.getOrDefault(CATALOG_INFER_RUNTIME_ENV, "true"),
options = new JHashMap(clientOpts.asJava)
)
}
@@ -29,10 +29,12 @@ object Constants {
final val CATALOG_PROP_PASSWORD = "password"
final val CATALOG_PROP_DATABASE = "database"
final val CATALOG_PROP_TZ = "timezone" // server(default), client, UTC+3, Asia/Shanghai, etc.
final val CATALOG_INFER_RUNTIME_ENV = "infer_runtime_env"
final val CATALOG_PROP_OPTION_PREFIX = "option."
final val CATALOG_PROP_IGNORE_OPTIONS = Seq(
DATABASE.getKey, COMPRESS.getKey, DECOMPRESS.getKey, FORMAT.getKey, RETRY.getKey,
- USE_SERVER_TIME_ZONE.getKey, USE_SERVER_TIME_ZONE_FOR_DATES.getKey, SERVER_TIME_ZONE.getKey, USE_TIME_ZONE.getKey)
+ USE_SERVER_TIME_ZONE.getKey, USE_SERVER_TIME_ZONE_FOR_DATES.getKey, SERVER_TIME_ZONE.getKey, USE_TIME_ZONE.getKey,
+ CATALOG_INFER_RUNTIME_ENV)

//////////////////////////////////////////////////////////
////////// clickhouse datasource read properties /////////