Skip to content

Commit

Permalink
[SPARK-31831] Use subclasses for mock in HiveSessionImplSuite
Browse files Browse the repository at this point in the history
  • Loading branch information
Frank Yin committed Jul 12, 2020
1 parent 578b90c commit 9067f21
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hive.service.cli.operation;

import org.apache.hive.service.cli.HiveSQLException;
import org.apache.hive.service.cli.OperationHandle;
import org.apache.hive.service.cli.session.HiveSession;
import org.apache.hive.service.rpc.thrift.THandleIdentifier;
import org.apache.hive.service.rpc.thrift.TOperationHandle;
import org.apache.hive.service.rpc.thrift.TOperationType;

import java.nio.ByteBuffer;
import java.util.UUID;

public class GetCatalogsOperationMock extends GetCatalogsOperation {
protected GetCatalogsOperationMock(HiveSession parentSession) {
super(parentSession);
}

@Override
public void runInternal() throws HiveSQLException {

}

public OperationHandle getHandle() {
UUID uuid = UUID.randomUUID();
THandleIdentifier tHandleIdentifier = new THandleIdentifier();
tHandleIdentifier.setGuid(getByteBufferFromUUID(uuid));
tHandleIdentifier.setSecret(getByteBufferFromUUID(uuid));
return new OperationHandle(new TOperationHandle(tHandleIdentifier, TOperationType.GET_TYPE_INFO, false));
}

private byte[] getByteBufferFromUUID(UUID uuid) {
ByteBuffer bb = ByteBuffer.wrap(new byte[16]);
bb.putLong(uuid.getMostSignificantBits());
bb.putLong(uuid.getLeastSignificantBits());
return bb.array();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hive.service.cli.operation;

import org.apache.hive.service.cli.HiveSQLException;
import org.apache.hive.service.cli.OperationHandle;
import org.apache.hive.service.cli.session.HiveSession;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.Set;

public class OperationManagerMock extends OperationManager {
private Set<OperationHandle> calledHandles = new HashSet<>();

@Override
public GetCatalogsOperation newGetCatalogsOperation(HiveSession parentSession) {
GetCatalogsOperationMock operation = new GetCatalogsOperationMock(parentSession);
try {
Method m = OperationManager.class.getDeclaredMethod("addOperation", Operation.class);
m.setAccessible(true);
m.invoke(this, operation);
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
return operation;
}

@Override
public void closeOperation(OperationHandle opHandle) throws HiveSQLException {
calledHandles.add(opHandle);
throw new RuntimeException();
}

public Set<OperationHandle> getCalledHandles() {
return calledHandles;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,20 @@ package org.apache.spark.sql.hive.thriftserver
import scala.collection.JavaConverters._

import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hive.service.cli.OperationHandle
import org.apache.hive.service.cli.operation.{GetCatalogsOperation, OperationManager}
import org.apache.hive.service.cli.operation.OperationManagerMock
import org.apache.hive.service.cli.session.{HiveSessionImpl, SessionManager}
import org.mockito.Mockito.{mock, verify, when}
import org.mockito.invocation.InvocationOnMock

import org.apache.spark.SparkFunSuite

class HiveSessionImplSuite extends SparkFunSuite {
private var session: HiveSessionImpl = _
private var operationManager: OperationManager = _
private var operationManager: OperationManagerMock = _

override def beforeAll() {
super.beforeAll()

// mock the instance first - we observed weird classloader issue on creating mock, so
// would like to avoid any cases classloader gets switched
val sessionManager = mock(classOf[SessionManager])
operationManager = mock(classOf[OperationManager])
val sessionManager = new SessionManager(null)
operationManager = new OperationManagerMock()

session = new HiveSessionImpl(
ThriftserverShimUtils.testedProtocolVersions.head,
Expand All @@ -48,13 +43,6 @@ class HiveSessionImplSuite extends SparkFunSuite {
)
session.setSessionManager(sessionManager)
session.setOperationManager(operationManager)
when(operationManager.newGetCatalogsOperation(session)).thenAnswer(
(_: InvocationOnMock) => {
val operation = mock(classOf[GetCatalogsOperation])
when(operation.getHandle).thenReturn(mock(classOf[OperationHandle]))
operation
}
)

session.open(Map.empty[String, String].asJava)
}
Expand All @@ -63,14 +51,9 @@ class HiveSessionImplSuite extends SparkFunSuite {
val operationHandle1 = session.getCatalogs
val operationHandle2 = session.getCatalogs

when(operationManager.closeOperation(operationHandle1))
.thenThrow(classOf[NullPointerException])
when(operationManager.closeOperation(operationHandle2))
.thenThrow(classOf[NullPointerException])

session.close()

verify(operationManager).closeOperation(operationHandle1)
verify(operationManager).closeOperation(operationHandle2)
assert(operationManager.getCalledHandles.contains(operationHandle1))
assert(operationManager.getCalledHandles.contains(operationHandle2))
}
}

0 comments on commit 9067f21

Please sign in to comment.