Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import java.nio.file.SimpleFileVisitor;
import java.nio.file.StandardCopyOption;
import java.nio.file.attribute.BasicFileAttributes;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
Expand Down Expand Up @@ -767,4 +769,58 @@ public static void checkState(boolean check, String msg, Object... args) {
throw new IllegalStateException(String.format(msg, args));
}
}

private static final HexFormat LOWERCASE_HEX = HexFormat.of();

/**
 * Computes the digest of the input bytes using the given algorithm
 * and returns the result as a lowercase hex string.
 *
 * @param algorithm JCA digest algorithm name (e.g. "MD5", "SHA-256")
 * @param input bytes to digest
 * @return lowercase hex encoding of the digest
 * @throws RuntimeException if the named algorithm is not available
 */
public static String digestToHexString(String algorithm, byte[] input) {
  try {
    MessageDigest digest = MessageDigest.getInstance(algorithm);
    byte[] raw = digest.digest(input);
    return LOWERCASE_HEX.formatHex(raw);
  } catch (NoSuchAlgorithmException e) {
    // Callers pass MD5/SHA-* names that every JCA provider must support,
    // so this path should be unreachable; surface it unchecked if it happens.
    throw new RuntimeException(e);
  }
}

/**
 * Computes the digest of the UTF-8 encoding of the input string using the
 * given algorithm and returns the result as a lowercase hex string.
 *
 * @param algorithm JCA digest algorithm name (e.g. "MD5", "SHA-256")
 * @param input string whose UTF-8 bytes are digested
 * @return lowercase hex encoding of the digest
 * @throws RuntimeException if the named algorithm is not available
 */
public static String digestToHexString(String algorithm, String input) {
  byte[] utf8Bytes = input.getBytes(StandardCharsets.UTF_8);
  try {
    byte[] raw = MessageDigest.getInstance(algorithm).digest(utf8Bytes);
    return HexFormat.of().formatHex(raw);
  } catch (NoSuchAlgorithmException e) {
    throw new RuntimeException(e);
  }
}

/**
 * Convenience wrapper: MD5 digest of the input bytes as a lowercase hex string.
 *
 * @param input bytes to digest
 * @return lowercase hex MD5 digest
 */
public static String md5Hex(byte[] input) {
  final String algorithm = "MD5";
  return digestToHexString(algorithm, input);
}

/**
 * Convenience wrapper: MD5 digest of the input string's UTF-8 bytes
 * as a lowercase hex string.
 *
 * @param input string whose UTF-8 bytes are digested
 * @return lowercase hex MD5 digest
 */
public static String md5Hex(String input) {
  final String algorithm = "MD5";
  return digestToHexString(algorithm, input);
}

/**
 * Convenience wrapper: SHA-256 digest of the input bytes as a lowercase hex string.
 *
 * @param input bytes to digest
 * @return lowercase hex SHA-256 digest
 */
public static String sha256Hex(byte[] input) {
  final String algorithm = "SHA-256";
  return digestToHexString(algorithm, input);
}

/**
 * Convenience wrapper: SHA-256 digest of the input string's UTF-8 bytes
 * as a lowercase hex string.
 *
 * @param input string whose UTF-8 bytes are digested
 * @return lowercase hex SHA-256 digest
 */
public static String sha256Hex(String input) {
  final String algorithm = "SHA-256";
  return digestToHexString(algorithm, input);
}
}
5 changes: 5 additions & 0 deletions scalastyle-config.xml
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,11 @@ This file is divided into 3 sections:
<customMessage>Please use Apache Log4j 2 instead.</customMessage>
</check>

<!-- Forbid commons-codec digest helpers; equivalent digest utilities live in
     org.apache.spark.network.util.JavaUtils (digestToHexString / md5Hex / sha256Hex). -->
<check level="error" class="org.scalastyle.scalariform.IllegalImportsChecker" enabled="true">
<parameters><parameter name="illegalImports"><![CDATA[org.apache.commons.codec.digest]]></parameter></parameters>
<customMessage>Please use org.apache.spark.network.util.JavaUtils.digestToHexString instead.</customMessage>
</check>


<!-- ================================================================================ -->
<!-- rules we'd like to enforce, but haven't cleaned up the codebase yet -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ import java.util.zip.CRC32

import scala.annotation.tailrec

import org.apache.commons.codec.digest.DigestUtils
import org.apache.commons.codec.digest.MessageDigestAlgorithms

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.network.util.JavaUtils.{digestToHexString, md5Hex, sha256Hex}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
Expand Down Expand Up @@ -72,11 +71,11 @@ case class Md5(child: Expression)
override def contextIndependentFoldable: Boolean = child.contextIndependentFoldable

protected override def nullSafeEval(input: Any): Any =
UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[Array[Byte]]))
UTF8String.fromString(md5Hex(input.asInstanceOf[Array[Byte]]))

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c =>
s"UTF8String.fromString(${classOf[DigestUtils].getName}.md5Hex($c))")
s"UTF8String.fromString(${classOf[JavaUtils].getName}.md5Hex($c))")
}

override protected def withNewChildInternal(newChild: Expression): Md5 = copy(child = newChild)
Expand Down Expand Up @@ -122,35 +121,29 @@ case class Sha2(left: Expression, right: Expression)
val input = input1.asInstanceOf[Array[Byte]]
bitLength match {
case 224 =>
UTF8String.fromString(
new DigestUtils(MessageDigestAlgorithms.SHA_224).digestAsHex(input))
UTF8String.fromString(digestToHexString("SHA-224", input))
case 256 | 0 =>
UTF8String.fromString(DigestUtils.sha256Hex(input))
UTF8String.fromString(sha256Hex(input))
case 384 =>
UTF8String.fromString(DigestUtils.sha384Hex(input))
UTF8String.fromString(digestToHexString("SHA-384", input))
case 512 =>
UTF8String.fromString(DigestUtils.sha512Hex(input))
UTF8String.fromString(digestToHexString("SHA-512", input))
case _ => null
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val digestUtils = classOf[DigestUtils].getName
val messageDigestAlgorithms = classOf[MessageDigestAlgorithms].getName
val javaUtils = classOf[JavaUtils].getName
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
if ($eval2 == 224) {
${ev.value} = UTF8String.fromString(
new $digestUtils($messageDigestAlgorithms.SHA_224).digestAsHex($eval1));
${ev.value} = UTF8String.fromString($javaUtils.digestToHexString("SHA-224", $eval1));
} else if ($eval2 == 256 || $eval2 == 0) {
${ev.value} =
UTF8String.fromString($digestUtils.sha256Hex($eval1));
${ev.value} = UTF8String.fromString($javaUtils.sha256Hex($eval1));
} else if ($eval2 == 384) {
${ev.value} =
UTF8String.fromString($digestUtils.sha384Hex($eval1));
${ev.value} = UTF8String.fromString($javaUtils.digestToHexString("SHA-384", $eval1));
} else if ($eval2 == 512) {
${ev.value} =
UTF8String.fromString($digestUtils.sha512Hex($eval1));
${ev.value} = UTF8String.fromString($javaUtils.digestToHexString("SHA-512", $eval1));
} else {
${ev.isNull} = true;
}
Expand Down Expand Up @@ -186,11 +179,11 @@ case class Sha1(child: Expression)
override def contextIndependentFoldable: Boolean = child.contextIndependentFoldable

protected override def nullSafeEval(input: Any): Any =
UTF8String.fromString(DigestUtils.sha1Hex(input.asInstanceOf[Array[Byte]]))
UTF8String.fromString(digestToHexString("SHA-1", input.asInstanceOf[Array[Byte]]))

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c =>
s"UTF8String.fromString(${classOf[DigestUtils].getName}.sha1Hex($c))"
s"""UTF8String.fromString(${classOf[JavaUtils].getName}.digestToHexString("SHA-1", $c))"""
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import java.time.{Duration, LocalTime, Period, ZoneId, ZoneOffset}
import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions

import org.apache.commons.codec.digest.DigestUtils
import org.scalatest.exceptions.TestFailedException

import org.apache.spark.SparkFunSuite
import org.apache.spark.network.util.JavaUtils.{digestToHexString, sha256Hex}
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder}
Expand Down Expand Up @@ -65,13 +65,13 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(224)),
"107c5072b799c4771f328304cfe1ebb375eb6ea7f35a3aa753836fad")
checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(0)),
DigestUtils.sha256Hex("ABC"))
sha256Hex("ABC"))
checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(256)),
DigestUtils.sha256Hex("ABC"))
sha256Hex("ABC"))
checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
digestToHexString("SHA-384", Array[Byte](1, 2, 3, 4, 5, 6)))
checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(512)),
DigestUtils.sha512Hex("ABC"))
digestToHexString("SHA-512", "ABC"))
// unsupported bit length
checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
// null input and valid bit length
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ import java.util.zip.CRC32
import com.google.protobuf.ByteString
import io.grpc.{ManagedChannel, Server}
import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder}
import org.apache.commons.codec.digest.DigestUtils.sha256Hex

import org.apache.spark.connect.proto.AddArtifactsRequest
import org.apache.spark.network.util.JavaUtils.sha256Hex
import org.apache.spark.sql.Artifact
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.test.ConnectFunSuite
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ import scala.util.control.NonFatal
import com.google.protobuf.ByteString
import io.grpc.StatusRuntimeException
import io.grpc.stub.StreamObserver
import org.apache.commons.codec.digest.DigestUtils.sha256Hex

import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.AddArtifactsResponse
import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
import org.apache.spark.network.util.JavaUtils.sha256Hex
import org.apache.spark.sql.Artifact
import org.apache.spark.sql.Artifact.{newCacheArtifact, newIvyArtifacts}
import org.apache.spark.util.{SparkFileUtils, SparkStringUtils, SparkThreadUtils}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import scala.concurrent.duration._
import scala.jdk.CollectionConverters._

import io.grpc.stub.StreamObserver
import org.apache.commons.codec.digest.DigestUtils.sha256Hex

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ArtifactStatusesResponse
import org.apache.spark.network.util.JavaUtils.sha256Hex
import org.apache.spark.sql.connect.ResourceHelper
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.ThreadUtils
Expand Down