diff --git a/common/utils-java/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/utils-java/src/main/java/org/apache/spark/network/util/JavaUtils.java index b106ad001d93d..2cf4570488ee0 100644 --- a/common/utils-java/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/utils-java/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -31,6 +31,8 @@ import java.nio.file.SimpleFileVisitor; import java.nio.file.StandardCopyOption; import java.nio.file.attribute.BasicFileAttributes; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.util.*; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; @@ -767,4 +769,82 @@ public static void checkState(boolean check, String msg, Object... args) { throw new IllegalStateException(String.format(msg, args)); } } + + /** Shared formatter; {@code HexFormat.of()} produces lowercase hex without delimiters. */ + private static final HexFormat LOWERCASE_HEX = HexFormat.of(); + + /** + * Computes the digest of the input bytes using the given algorithm + * and returns the result as a lowercase hex string. + * + * @param algorithm digest algorithm name passed to {@code MessageDigest.getInstance} + * @param input the bytes to digest + * @return the digest bytes formatted as a lowercase hex string + * @throws IllegalArgumentException if no security provider supports the algorithm + */ + public static String digestToHexString(String algorithm, byte[] input) { + try { + return LOWERCASE_HEX.formatHex(MessageDigest.getInstance(algorithm).digest(input)); + } catch (NoSuchAlgorithmException e) { + // Unchecked like commons-codec DigestUtils, whose callers these helpers replace. + throw new IllegalArgumentException( + String.format("Unsupported digest algorithm: %s", algorithm), e); + } + } + + /** + * Computes the digest of the input string using the given algorithm + * and returns the result as a lowercase hex string. + * + * @param algorithm digest algorithm name passed to {@code MessageDigest.getInstance} + * @param input the string to digest; it is encoded as UTF-8 before hashing + * @return the digest bytes formatted as a lowercase hex string + * @throws IllegalArgumentException if no security provider supports the algorithm + */ + public static String digestToHexString(String algorithm, String input) { + return digestToHexString(algorithm, input.getBytes(StandardCharsets.UTF_8)); + } + + /** + * Computes the MD5 digest of the input bytes + * and returns the result as a lowercase hex string. + * + * @param input the bytes to digest + * @return the MD5 digest formatted as a lowercase hex string + */ + public static String md5Hex(byte[] input) { + return digestToHexString("MD5", input); + } + + /** + * Computes the MD5 digest of the input string + * and returns the result as a lowercase hex string. + * + * @param input the string to digest; it is encoded as UTF-8 before hashing + * @return the MD5 digest formatted as a lowercase hex string 
+ */ + public static String md5Hex(String input) { + return digestToHexString("MD5", input); + } + + /** + * Computes the SHA-256 digest of the input bytes + * and returns the result as a lowercase hex string. + * + * @param input the bytes to digest + * @return the SHA-256 digest formatted as a lowercase hex string + */ + public static String sha256Hex(byte[] input) { + return digestToHexString("SHA-256", input); + } + + /** + * Computes the SHA-256 digest of the input string + * and returns the result as a lowercase hex string. + * + * @param input the string to digest; it is encoded as UTF-8 before hashing + * @return the SHA-256 digest formatted as a lowercase hex string + */ + public static String sha256Hex(String input) { + return digestToHexString("SHA-256", input); + } } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 62992a4f2d679..928be52494776 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -624,6 +624,11 @@ This file is divided into 3 sections: Please use Apache Log4j 2 instead. + + + Please use org.apache.spark.network.util.JavaUtils.digestToHexString instead. + + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 88e22a91a64a7..0f27dee9dbc84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -23,9 +23,8 @@ import java.util.zip.CRC32 import scala.annotation.tailrec -import org.apache.commons.codec.digest.DigestUtils -import org.apache.commons.codec.digest.MessageDigestAlgorithms - +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.network.util.JavaUtils.{digestToHexString, md5Hex, sha256Hex} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -72,11 +71,11 @@ case class Md5(child: Expression) override def contextIndependentFoldable: Boolean = child.contextIndependentFoldable protected override def nullSafeEval(input: Any): Any = 
UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[Array[Byte]])) + UTF8String.fromString(md5Hex(input.asInstanceOf[Array[Byte]])) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => - s"UTF8String.fromString(${classOf[DigestUtils].getName}.md5Hex($c))") + s"UTF8String.fromString(${classOf[JavaUtils].getName}.md5Hex($c))") } override protected def withNewChildInternal(newChild: Expression): Md5 = copy(child = newChild) @@ -122,35 +121,29 @@ case class Sha2(left: Expression, right: Expression) val input = input1.asInstanceOf[Array[Byte]] bitLength match { case 224 => - UTF8String.fromString( - new DigestUtils(MessageDigestAlgorithms.SHA_224).digestAsHex(input)) + UTF8String.fromString(digestToHexString("SHA-224", input)) case 256 | 0 => - UTF8String.fromString(DigestUtils.sha256Hex(input)) + UTF8String.fromString(sha256Hex(input)) case 384 => - UTF8String.fromString(DigestUtils.sha384Hex(input)) + UTF8String.fromString(digestToHexString("SHA-384", input)) case 512 => - UTF8String.fromString(DigestUtils.sha512Hex(input)) + UTF8String.fromString(digestToHexString("SHA-512", input)) case _ => null } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val digestUtils = classOf[DigestUtils].getName - val messageDigestAlgorithms = classOf[MessageDigestAlgorithms].getName + val javaUtils = classOf[JavaUtils].getName nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" if ($eval2 == 224) { - ${ev.value} = UTF8String.fromString( - new $digestUtils($messageDigestAlgorithms.SHA_224).digestAsHex($eval1)); + ${ev.value} = UTF8String.fromString($javaUtils.digestToHexString("SHA-224", $eval1)); } else if ($eval2 == 256 || $eval2 == 0) { - ${ev.value} = - UTF8String.fromString($digestUtils.sha256Hex($eval1)); + ${ev.value} = UTF8String.fromString($javaUtils.sha256Hex($eval1)); } else if ($eval2 == 384) { - ${ev.value} = - UTF8String.fromString($digestUtils.sha384Hex($eval1)); + ${ev.value} = 
UTF8String.fromString($javaUtils.digestToHexString("SHA-384", $eval1)); } else if ($eval2 == 512) { - ${ev.value} = - UTF8String.fromString($digestUtils.sha512Hex($eval1)); + ${ev.value} = UTF8String.fromString($javaUtils.digestToHexString("SHA-512", $eval1)); } else { ${ev.isNull} = true; } @@ -186,11 +179,11 @@ case class Sha1(child: Expression) override def contextIndependentFoldable: Boolean = child.contextIndependentFoldable protected override def nullSafeEval(input: Any): Any = - UTF8String.fromString(DigestUtils.sha1Hex(input.asInstanceOf[Array[Byte]])) + UTF8String.fromString(digestToHexString("SHA-1", input.asInstanceOf[Array[Byte]])) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => - s"UTF8String.fromString(${classOf[DigestUtils].getName}.sha1Hex($c))" + s"""UTF8String.fromString(${classOf[JavaUtils].getName}.digestToHexString("SHA-1", $c))""" ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index c084b67d4d57b..c1d44e9458992 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -23,10 +23,10 @@ import java.time.{Duration, LocalTime, Period, ZoneId, ZoneOffset} import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions -import org.apache.commons.codec.digest.DigestUtils import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite +import org.apache.spark.network.util.JavaUtils.{digestToHexString, sha256Hex} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder} @@ -65,13 +65,13 @@ class 
HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(224)), "107c5072b799c4771f328304cfe1ebb375eb6ea7f35a3aa753836fad") checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(0)), - DigestUtils.sha256Hex("ABC")) + sha256Hex("ABC")) checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(256)), - DigestUtils.sha256Hex("ABC")) + sha256Hex("ABC")) checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), - DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6))) + digestToHexString("SHA-384", Array[Byte](1, 2, 3, 4, 5, 6))) checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(512)), - DigestUtils.sha512Hex("ABC")) + digestToHexString("SHA-512", "ABC")) // unsupported bit length checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null) // null input and valid bit length diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala index 8961a49c09ce1..a6ccf39886e22 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala @@ -26,9 +26,9 @@ import java.util.zip.CRC32 import com.google.protobuf.ByteString import io.grpc.{ManagedChannel, Server} import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder} -import org.apache.commons.codec.digest.DigestUtils.sha256Hex import org.apache.spark.connect.proto.AddArtifactsRequest +import org.apache.spark.network.util.JavaUtils.sha256Hex import org.apache.spark.sql.Artifact import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.test.ConnectFunSuite diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala index 44a2a7aa9a2f0..1cfc9f170e896 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala @@ -32,12 +32,12 @@ import scala.util.control.NonFatal import com.google.protobuf.ByteString import io.grpc.StatusRuntimeException import io.grpc.stub.StreamObserver -import org.apache.commons.codec.digest.DigestUtils.sha256Hex import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.AddArtifactsResponse import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary +import org.apache.spark.network.util.JavaUtils.sha256Hex import org.apache.spark.sql.Artifact import org.apache.spark.sql.Artifact.{newCacheArtifact, newIvyArtifacts} import org.apache.spark.util.{SparkFileUtils, SparkStringUtils, SparkThreadUtils} diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala index 275808942d37d..96c7ddaef829a 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala @@ -23,10 +23,10 @@ import scala.concurrent.duration._ import scala.jdk.CollectionConverters._ import io.grpc.stub.StreamObserver -import org.apache.commons.codec.digest.DigestUtils.sha256Hex import org.apache.spark.connect.proto import org.apache.spark.connect.proto.ArtifactStatusesResponse +import org.apache.spark.network.util.JavaUtils.sha256Hex import 
org.apache.spark.sql.connect.ResourceHelper import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.ThreadUtils