Skip to content

Commit

Permalink
[SPARK-31831] Use subclasses for mock in HiveSessionImplSuite
Browse files Browse the repository at this point in the history
  • Loading branch information
Frank Yin committed Jul 12, 2020
1 parent 578b90c commit 1eb2027
Showing 1 changed file with 63 additions and 23 deletions.
Expand Up @@ -16,28 +16,30 @@
*/
package org.apache.spark.sql.hive.thriftserver

import java.lang.reflect.InvocationTargetException
import java.nio.ByteBuffer
import java.util.UUID

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hive.service.cli.OperationHandle
import org.apache.hive.service.cli.operation.{GetCatalogsOperation, OperationManager}
import org.apache.hive.service.cli.session.{HiveSessionImpl, SessionManager}
import org.mockito.Mockito.{mock, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.apache.hive.service.cli.operation.{GetCatalogsOperation, Operation, OperationManager}
import org.apache.hive.service.cli.session.{HiveSession, HiveSessionImpl, SessionManager}
import org.apache.hive.service.rpc.thrift.{THandleIdentifier, TOperationHandle, TOperationType}

import org.apache.spark.SparkFunSuite

class HiveSessionImplSuite extends SparkFunSuite {
private var session: HiveSessionImpl = _
private var operationManager: OperationManager = _
private var operationManager: OperationManagerMock = _

override def beforeAll() {
super.beforeAll()

// mock the instance first - we observed weird classloader issue on creating mock, so
// would like to avoid any cases classloader gets switched
val sessionManager = mock(classOf[SessionManager])
operationManager = mock(classOf[OperationManager])
val sessionManager = new SessionManager(null)
operationManager = new OperationManagerMock()

session = new HiveSessionImpl(
ThriftserverShimUtils.testedProtocolVersions.head,
Expand All @@ -48,13 +50,6 @@ class HiveSessionImplSuite extends SparkFunSuite {
)
session.setSessionManager(sessionManager)
session.setOperationManager(operationManager)
when(operationManager.newGetCatalogsOperation(session)).thenAnswer(
(_: InvocationOnMock) => {
val operation = mock(classOf[GetCatalogsOperation])
when(operation.getHandle).thenReturn(mock(classOf[OperationHandle]))
operation
}
)

session.open(Map.empty[String, String].asJava)
}
Expand All @@ -63,14 +58,59 @@ class HiveSessionImplSuite extends SparkFunSuite {
val operationHandle1 = session.getCatalogs
val operationHandle2 = session.getCatalogs

when(operationManager.closeOperation(operationHandle1))
.thenThrow(classOf[NullPointerException])
when(operationManager.closeOperation(operationHandle2))
.thenThrow(classOf[NullPointerException])

session.close()

verify(operationManager).closeOperation(operationHandle1)
verify(operationManager).closeOperation(operationHandle2)
assert(operationManager.getCalledHandles.contains(operationHandle1))
assert(operationManager.getCalledHandles.contains(operationHandle2))
}
}

class GetCatalogsOperationMock(parentSession: HiveSession)
extends GetCatalogsOperation(parentSession) {

override def runInternal(): Unit = {}

override def getHandle: OperationHandle = {
val uuid: UUID = UUID.randomUUID()
val tHandleIdentifier: THandleIdentifier = new THandleIdentifier()
tHandleIdentifier.setGuid(getByteBufferFromUUID(uuid))
tHandleIdentifier.setSecret(getByteBufferFromUUID(uuid))
val tOperationHandle: TOperationHandle = new TOperationHandle()
tOperationHandle.setOperationId(tHandleIdentifier)
tOperationHandle.setOperationType(TOperationType.GET_TYPE_INFO)
tOperationHandle.setHasResultSetIsSet(false)
new OperationHandle(tOperationHandle)
}

private def getByteBufferFromUUID(uuid: UUID): Array[Byte] = {
val bb: ByteBuffer = ByteBuffer.wrap(new Array[Byte](16))
bb.putLong(uuid.getMostSignificantBits)
bb.putLong(uuid.getLeastSignificantBits)
bb.array
}
}

class OperationManagerMock extends OperationManager {
private val calledHandles: mutable.Set[OperationHandle] = new mutable.HashSet[OperationHandle]()

override def newGetCatalogsOperation(parentSession: HiveSession): GetCatalogsOperation = {
val operation = new GetCatalogsOperationMock(parentSession)
try {
val m = classOf[OperationManager].getDeclaredMethod("addOperation", classOf[Operation])
m.setAccessible(true)
m.invoke(this, operation)
} catch {
case e@(_: NoSuchMethodException | _: IllegalAccessException |
_: InvocationTargetException) =>
throw new RuntimeException(e)
}
operation
}

override def closeOperation(opHandle: OperationHandle): Unit = {
calledHandles.add(opHandle)
throw new RuntimeException
}

def getCalledHandles: mutable.Set[OperationHandle] = calledHandles
}

0 comments on commit 1eb2027

Please sign in to comment.