diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 6babba0cd6d1..c52b78b27dae 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2079,7 +2079,7 @@ private[spark] object Utils
 
   val CONNECT_EXECUTE_THREAD_PREFIX = "SparkConnectExecuteThread"
 
-  private val threadInfoOrdering = Ordering.fromLessThan {
+  private[spark] val threadInfoOrdering = Ordering.fromLessThan {
     (threadTrace1: ThreadInfo, threadTrace2: ThreadInfo) => {
       def priority(ti: ThreadInfo): Int = ti.getThreadName match {
         case name if name.startsWith(TASK_THREAD_NAME_PREFIX) => 100
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index d600260e9df2..61952c401853 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -38,8 +38,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.hadoop.fs.audit.CommonAuditContext.currentAuditContext
 import org.apache.hadoop.ipc.{CallerContext => HadoopCallerContext}
 import org.apache.logging.log4j.Level
-import org.mockito.Mockito.doReturn
-import org.scalatest.PrivateMethodTester
+import org.mockito.Mockito.when
 import org.scalatestplus.mockito.MockitoSugar.mock
 
 import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext}
@@ -51,7 +50,7 @@ import org.apache.spark.scheduler.SparkListener
 import org.apache.spark.util.collection.Utils.createArray
 import org.apache.spark.util.io.ChunkedByteBufferInputStream
 
-class UtilsSuite extends SparkFunSuite with ResetSystemProperties with PrivateMethodTester {
+class UtilsSuite extends SparkFunSuite with ResetSystemProperties {
 
   test("timeConversion") {
     // Test -1
@@ -1132,37 +1131,35 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with PrivateMe
 
   test("ThreadInfoOrdering") {
     val task1T = mock[ThreadInfo]
-    doReturn(11L).when(task1T).getThreadId
-    doReturn("Executor task launch worker for task 1.0 in stage 1.0 (TID 11)")
-      .when(task1T).getThreadName
-    doReturn("Executor task launch worker for task 1.0 in stage 1.0 (TID 11)")
-      .when(task1T).toString
+    when(task1T.getThreadId).thenReturn(11L)
+    when(task1T.getThreadName)
+      .thenReturn("Executor task launch worker for task 1.0 in stage 1.0 (TID 11)")
+    when(task1T.toString)
+      .thenReturn("Executor task launch worker for task 1.0 in stage 1.0 (TID 11)")
 
     val task2T = mock[ThreadInfo]
-    doReturn(12L).when(task2T).getThreadId
-    doReturn("Executor task launch worker for task 2.0 in stage 1.0 (TID 22)")
-      .when(task2T).getThreadName
-    doReturn("Executor task launch worker for task 2.0 in stage 1.0 (TID 22)")
-      .when(task2T).toString
+    when(task2T.getThreadId).thenReturn(12L)
+    when(task2T.getThreadName)
+      .thenReturn("Executor task launch worker for task 2.0 in stage 1.0 (TID 22)")
+    when(task2T.toString)
+      .thenReturn("Executor task launch worker for task 2.0 in stage 1.0 (TID 22)")
 
     val connectExecuteOp1T = mock[ThreadInfo]
-    doReturn(21L).when(connectExecuteOp1T).getThreadId
-    doReturn("SparkConnectExecuteThread_opId=16148fb4-4189-43c3-b8d4-8b3b6ddd41c7")
-      .when(connectExecuteOp1T).getThreadName
-    doReturn("SparkConnectExecuteThread_opId=16148fb4-4189-43c3-b8d4-8b3b6ddd41c7")
-      .when(connectExecuteOp1T).toString
+    when(connectExecuteOp1T.getThreadId).thenReturn(21L)
+    when(connectExecuteOp1T.getThreadName)
+      .thenReturn("SparkConnectExecuteThread_opId=16148fb4-4189-43c3-b8d4-8b3b6ddd41c7")
+    when(connectExecuteOp1T.toString)
+      .thenReturn("SparkConnectExecuteThread_opId=16148fb4-4189-43c3-b8d4-8b3b6ddd41c7")
 
     val connectExecuteOp2T = mock[ThreadInfo]
-    doReturn(22L).when(connectExecuteOp2T).getThreadId
-    doReturn("SparkConnectExecuteThread_opId=4e4d1cac-ffde-46c1-b7c2-808b726cb47e")
-      .when(connectExecuteOp2T).getThreadName
-    doReturn("SparkConnectExecuteThread_opId=4e4d1cac-ffde-46c1-b7c2-808b726cb47e")
-      .when(connectExecuteOp2T).toString
-
-    val threadInfoOrderingMethod =
-      PrivateMethod[Ordering[ThreadInfo]](Symbol("threadInfoOrdering"))
+    when(connectExecuteOp2T.getThreadId).thenReturn(22L)
+    when(connectExecuteOp2T.getThreadName)
+      .thenReturn("SparkConnectExecuteThread_opId=4e4d1cac-ffde-46c1-b7c2-808b726cb47e")
+    when(connectExecuteOp2T.toString)
+      .thenReturn("SparkConnectExecuteThread_opId=4e4d1cac-ffde-46c1-b7c2-808b726cb47e")
+
     val sorted = Seq(connectExecuteOp1T, connectExecuteOp2T, task1T, task2T)
-      .sorted(Utils.invokePrivate(threadInfoOrderingMethod()))
+      .sorted(Utils.threadInfoOrdering)
     assert(sorted === Seq(task1T, task2T, connectExecuteOp1T, connectExecuteOp2T))
   }