Skip to content

Commit

Permalink
HubSpot Backport: HBASE-27276 Reduce reflection overhead in Filter de…
Browse files Browse the repository at this point in the history
…serialization (apache#5488)

Signed-off-by: Nick Dimiduk <ndimiduk@apache.org>
Signed-off-by: Duo Zhang <zhangduo@apache.org>
  • Loading branch information
bbeaudreault committed Nov 10, 2023
1 parent d10611e commit 6402f72
Show file tree
Hide file tree
Showing 9 changed files with 505 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
import org.apache.hadoop.hbase.util.DynamicClassLoader;
import org.apache.hadoop.hbase.util.ExceptionUtil;
import org.apache.hadoop.hbase.util.Methods;
import org.apache.hadoop.hbase.util.ReflectedFunctionCache;
import org.apache.hadoop.hbase.util.VersionInfo;
import org.apache.hadoop.ipc.RemoteException;
import org.apache.yetus.audience.InterfaceAudience;
Expand Down Expand Up @@ -296,6 +297,23 @@ public static boolean isClassLoaderLoaded() {
return classLoaderLoaded;
}

private static final String PARSE_FROM = "parseFrom";

// We don't bother using the dynamic CLASS_LOADER above, because currently we can't support
// optimizing dynamically loaded classes. We can do it once we build for java9+, see the todo
// in ReflectedFunctionCache
private static final ReflectedFunctionCache<byte[], Filter> FILTERS =
new ReflectedFunctionCache<>(Filter.class, byte[].class, PARSE_FROM);
private static final ReflectedFunctionCache<byte[], ByteArrayComparable> COMPARATORS =
new ReflectedFunctionCache<>(ByteArrayComparable.class, byte[].class, PARSE_FROM);

private static volatile boolean ALLOW_FAST_REFLECTION_FALLTHROUGH = true;

// Visible for tests
public static void setAllowFastReflectionFallthrough(boolean val) {
ALLOW_FAST_REFLECTION_FALLTHROUGH = val;
}

/**
* Prepend the passed bytes with four bytes of magic, {@link ProtobufMagic#PB_MAGIC}, to flag what
* follows as a protobuf in hbase. Prepend these bytes to all content written to znodes, etc.
Expand Down Expand Up @@ -1496,13 +1514,23 @@ public static ComparatorProtos.Comparator toComparator(ByteArrayComparable compa
public static ByteArrayComparable toComparator(ComparatorProtos.Comparator proto)
throws IOException {
String type = proto.getName();
String funcName = "parseFrom";
byte[] value = proto.getSerializedComparator().toByteArray();

try {
ByteArrayComparable result = COMPARATORS.getAndCallByName(type, value);
if (result != null) {
return result;
}

if (!ALLOW_FAST_REFLECTION_FALLTHROUGH) {
throw new IllegalStateException("Failed to deserialize comparator " + type
+ " because fast reflection returned null and fallthrough is disabled");
}

Class<?> c = Class.forName(type, true, ClassLoaderHolder.CLASS_LOADER);
Method parseFrom = c.getMethod(funcName, byte[].class);
Method parseFrom = c.getMethod(PARSE_FROM, byte[].class);
if (parseFrom == null) {
throw new IOException("Unable to locate function: " + funcName + " in type: " + type);
throw new IOException("Unable to locate function: " + PARSE_FROM + " in type: " + type);
}
return (ByteArrayComparable) parseFrom.invoke(null, value);
} catch (Exception e) {
Expand All @@ -1519,12 +1547,22 @@ public static ByteArrayComparable toComparator(ComparatorProtos.Comparator proto
public static Filter toFilter(FilterProtos.Filter proto) throws IOException {
String type = proto.getName();
final byte[] value = proto.getSerializedFilter().toByteArray();
String funcName = "parseFrom";

try {
Filter result = FILTERS.getAndCallByName(type, value);
if (result != null) {
return result;
}

if (!ALLOW_FAST_REFLECTION_FALLTHROUGH) {
throw new IllegalStateException("Failed to deserialize comparator " + type
+ " because fast reflection returned null and fallthrough is disabled");
}

Class<?> c = Class.forName(type, true, ClassLoaderHolder.CLASS_LOADER);
Method parseFrom = c.getMethod(funcName, byte[].class);
Method parseFrom = c.getMethod(PARSE_FROM, byte[].class);
if (parseFrom == null) {
throw new IOException("Unable to locate function: " + funcName + " in type: " + type);
throw new IOException("Unable to locate function: " + PARSE_FROM + " in type: " + type);
}
return (Filter) parseFrom.invoke(c, value);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.apache.hadoop.hbase.client;

import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
Expand All @@ -25,7 +27,6 @@
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Base64;
Expand All @@ -34,7 +35,6 @@
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseClassTestRule;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.exceptions.DeserializationException;
import org.apache.hadoop.hbase.filter.Filter;
import org.apache.hadoop.hbase.filter.FilterList;
import org.apache.hadoop.hbase.filter.KeyOnlyFilter;
Expand All @@ -48,6 +48,8 @@
import org.junit.Test;
import org.junit.experimental.categories.Category;

import org.apache.hbase.thirdparty.com.google.common.base.Throwables;

import org.apache.hadoop.hbase.shaded.protobuf.ProtobufUtil;
import org.apache.hadoop.hbase.shaded.protobuf.generated.ClientProtos;

Expand Down Expand Up @@ -226,9 +228,9 @@ public void testDynamicFilter() throws Exception {
ProtobufUtil.toGet(getProto2);
fail("Should not be able to load the filter class");
} catch (IOException ioe) {
assertTrue(ioe.getCause() instanceof InvocationTargetException);
InvocationTargetException ite = (InvocationTargetException) ioe.getCause();
assertTrue(ite.getTargetException() instanceof DeserializationException);
// This test is deserializing a FilterList, and one of the sub-filters is not found.
// So the actual caused by is buried a few levels deep.
assertThat(Throwables.getRootCause(ioe), instanceOf(ClassNotFoundException.class));
}
FileOutputStream fos = new FileOutputStream(jarFile);
fos.write(Base64.getDecoder().decode(MOCK_FILTER_JAR));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.util;

import edu.umd.cs.findbugs.annotations.Nullable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.yetus.audience.InterfaceAudience;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Cache to hold resolved Functions of a specific signature, generated through reflection. These can
* be (relatively) costly to create, but then are much faster than typical Method.invoke calls when
* executing. The cache is built-up on demand as calls are made to new classes. The functions are
* cached for the lifetime of the process. If a function cannot be created (security reasons, method
* not found, etc), a fallback function is cached which always returns null. Callers to
* {@link #getAndCallByName(String, Object)} should have handling for null return values.
* <p>
* An instance is created for a specified baseClass (i.e. Filter), argClass (i.e. byte[]), and
* static methodName to call. These are used to resolve a Function which delegates to that static
* method, if it is found.
* @param <I> the input argument type for the resolved functions
* @param <R> the return type for the resolved functions
*/
@InterfaceAudience.Private
public final class ReflectedFunctionCache<I, R> {

private static final Logger LOG = LoggerFactory.getLogger(ReflectedFunctionCache.class);

private final ConcurrentMap<String, Function<I, ? extends R>> lambdasByClass =
new ConcurrentHashMap<>();
private final Class<R> baseClass;
private final Class<I> argClass;
private final String methodName;
private final ClassLoader classLoader;

public ReflectedFunctionCache(Class<R> baseClass, Class<I> argClass, String staticMethodName) {
this.classLoader = getClass().getClassLoader();
this.baseClass = baseClass;
this.argClass = argClass;
this.methodName = staticMethodName;
}

/**
* Get and execute the Function for the given className, passing the argument to the function and
* returning the result.
* @param className the full name of the class to lookup
* @param argument the argument to pass to the function, if found.
* @return null if a function is not found for classname, otherwise the result of the function.
*/
@Nullable
public R getAndCallByName(String className, I argument) {
// todo: if we ever make java9+ our lowest supported jdk version, we can
// handle generating these for newly loaded classes from our DynamicClassLoader using
// MethodHandles.privateLookupIn(). For now this is not possible, because we can't easily
// create a privileged lookup in a non-default ClassLoader. So while this cache loads
// over time, it will never load a custom filter from "hbase.dynamic.jars.dir".
Function<I, ? extends R> lambda =
ConcurrentMapUtils.computeIfAbsent(lambdasByClass, className, () -> loadFunction(className));

return lambda.apply(argument);
}

private Function<I, ? extends R> loadFunction(String className) {
long startTime = System.nanoTime();
try {
Class<?> clazz = Class.forName(className, false, classLoader);
if (!baseClass.isAssignableFrom(clazz)) {
LOG.debug("Requested class {} is not assignable to {}, skipping creation of function",
className, baseClass.getName());
return this::notFound;
}
return ReflectionUtils.getOneArgStaticMethodAsFunction(clazz, methodName, argClass,
(Class<? extends R>) clazz);
} catch (Throwable t) {
LOG.debug("Failed to create function for {}", className, t);
return this::notFound;
} finally {
LOG.debug("Populated cache for {} in {}ms", className,
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime));
}
}

/**
* In order to use computeIfAbsent, we can't store nulls in our cache. So we store a lambda which
* resolves to null. The contract is that getAndCallByName returns null in this case.
*/
private R notFound(I argument) {
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
import java.lang.invoke.CallSite;
import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo;
import java.lang.management.ThreadMXBean;
Expand All @@ -29,6 +34,7 @@
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.charset.Charset;
import java.util.function.Function;
import org.apache.yetus.audience.InterfaceAudience;
import org.slf4j.Logger;

Expand Down Expand Up @@ -208,6 +214,30 @@ private static String getTaskName(long id, String name) {
return id + " (" + name + ")";
}

/**
* Creates a Function which can be called to performantly execute a reflected static method. The
* creation of the Function itself may not be fast, but executing that method thereafter should be
* much faster than {@link #invokeMethod(Object, String, Object...)}.
* @param lookupClazz the class to find the static method in
* @param methodName the method name
* @param argumentClazz the type of the argument
* @param returnValueClass the type of the return value
* @return a function which when called executes the requested static method.
* @throws Throwable exception types from the underlying reflection
*/
public static <I, R> Function<I, R> getOneArgStaticMethodAsFunction(Class<?> lookupClazz,
String methodName, Class<I> argumentClazz, Class<R> returnValueClass) throws Throwable {
MethodHandles.Lookup lookup = MethodHandles.lookup();
MethodHandle methodHandle = lookup.findStatic(lookupClazz, methodName,
MethodType.methodType(returnValueClass, argumentClazz));
CallSite site =
LambdaMetafactory.metafactory(lookup, "apply", MethodType.methodType(Function.class),
methodHandle.type().generic(), methodHandle, methodHandle.type());

return (Function<I, R>) site.getTarget().invokeExact();

}

/**
* Get and invoke the target method from the given object with given parameters
* @param obj the object to get and invoke method from
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.hadoop.hbase.util;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

Expand Down Expand Up @@ -202,4 +203,20 @@ public static void addJarFilesToJar(File targetJar, String libPrefix, File... sr
public static String localDirPath(Configuration conf) {
return conf.get(ClassLoaderBase.LOCAL_DIR_KEY) + File.separator + "jars" + File.separator;
}

public static void deleteClass(String className, String testDir, Configuration conf)
throws Exception {
String jarFileName = className + ".jar";
File file = new File(testDir, jarFileName);
file.delete();
assertFalse("Should be deleted: " + file.getPath(), file.exists());

file = new File(conf.get("hbase.dynamic.jars.dir"), jarFileName);
file.delete();
assertFalse("Should be deleted: " + file.getPath(), file.exists());

file = new File(ClassLoaderTestHelper.localDirPath(conf), jarFileName);
file.delete();
assertFalse("Should be deleted: " + file.getPath(), file.exists());
}
}
Loading

0 comments on commit 6402f72

Please sign in to comment.