Skip to content

Commit

Permalink
SAMZA-2354: Improve UDF discovery in samza-sql. (#1192)
Browse files Browse the repository at this point in the history
* Improve UDF discovery in samza-sql.

Replace ConfigBasedUdfResolver with a reflection-based UDF resolver.

* Address review comments.

* Add TODO for the follow-up ticket in the comments.
  • Loading branch information
shanthoosh committed Oct 17, 2019
1 parent 493b275 commit e2928e1
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 27 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ project(":samza-sql_$scalaSuffix") {
compile "org.apache.avro:avro:$avroVersion"
compile "org.apache.calcite:calcite-core:$calciteVersion"
compile "org.slf4j:slf4j-api:$slf4jVersion"
compile "org.reflections:reflections:0.9.10"

testCompile "junit:junit:$junitVersion"
testCompile "org.mockito:mockito-core:$mockitoVersion"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import org.apache.samza.sql.fn.FlattenUdf;
import org.apache.samza.sql.fn.RegexMatchUdf;
import org.apache.samza.sql.impl.ConfigBasedIOResolverFactory;
import org.apache.samza.sql.impl.ConfigBasedUdfResolver;
import org.apache.samza.sql.interfaces.RelSchemaProvider;
import org.apache.samza.sql.interfaces.RelSchemaProviderFactory;
import org.apache.samza.sql.interfaces.SqlIOConfig;
Expand Down Expand Up @@ -328,11 +327,6 @@ Map<String, String> fetchSamzaSqlConfig(int execId) {
ConfigBasedIOResolverFactory.class.getName());

staticConfigs.put(SamzaSqlApplicationConfig.CFG_UDF_RESOLVER, "config");
String configUdfResolverDomain = String.format(SamzaSqlApplicationConfig.CFG_FMT_UDF_RESOLVER_DOMAIN, "config");
staticConfigs.put(configUdfResolverDomain + SamzaSqlApplicationConfig.CFG_FACTORY,
ConfigBasedUdfResolver.class.getName());
staticConfigs.put(configUdfResolverDomain + ConfigBasedUdfResolver.CFG_UDF_CLASSES,
Joiner.on(",").join(RegexMatchUdf.class.getName(), FlattenUdf.class.getName()));

staticConfigs.put("serializers.registry.string.class", StringSerdeFactory.class.getName());
staticConfigs.put("serializers.registry.avro.class", AvroSerDeFactory.class.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
package org.apache.samza.sql.interfaces;

import java.lang.reflect.Method;

import java.util.List;
import com.google.common.base.Objects;
import org.apache.samza.config.Config;
import org.apache.samza.sql.schema.SamzaSqlFieldType;

Expand Down Expand Up @@ -99,4 +99,19 @@ public boolean isDisableArgCheck() {
return disableArgCheck;
}

/**
 * Hashes the same fields that {@link #equals(Object)} compares: the name, the UDF method,
 * the argument types and the return type.
 *
 * @return hash code derived from name, udfMethod, arguments and returnType
 */
@Override
public int hashCode() {
  final int hash = Objects.hashCode(name, udfMethod, arguments, returnType);
  return hash;
}

/**
 * Two {@code UdfMetadata} instances are equal when their name, UDF method, argument types
 * and return type all match. No other fields take part in the comparison.
 *
 * @param o the object to compare against
 * @return {@code true} if {@code o} is a {@code UdfMetadata} with matching name, method,
 *         arguments and return type
 */
@Override
public boolean equals(Object o) {
  if (o == this) {
    return true;
  }
  if (o instanceof UdfMetadata) {
    final UdfMetadata other = (UdfMetadata) o;
    final boolean sameName = Objects.equal(name, other.name);
    return sameName
        && Objects.equal(udfMethod, other.udfMethod)
        && Objects.equal(arguments, other.arguments)
        && returnType == other.returnType;
  }
  return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import org.apache.samza.config.MapConfig;
import org.apache.samza.sql.dsl.SamzaSqlDslConverter;
import org.apache.samza.sql.dsl.SamzaSqlDslConverterFactory;
import org.apache.samza.sql.impl.ConfigBasedUdfResolver;
import org.apache.samza.sql.udf.ReflectionBasedUdfResolver;
import org.apache.samza.sql.interfaces.DslConverter;
import org.apache.samza.sql.interfaces.DslConverterFactory;
import org.apache.samza.sql.interfaces.RelSchemaProvider;
Expand Down Expand Up @@ -214,7 +214,8 @@ private UdfResolver createUdfResolver(Map<String, String> config) {
Properties props = new Properties();
props.putAll(domainConfig);
HashMap<String, String> udfConfig = getDomainProperties(config, CFG_UDF_CONFIG_DOMAIN, false);
return new ConfigBasedUdfResolver(props, new MapConfig(udfConfig));
// TODO: SAMZA-2355: Make the UDFResolver pluggable.
return new ReflectionBasedUdfResolver(new MapConfig(udfConfig));
}

private static HashMap<String, String> getDomainProperties(Map<String, String> props, String prefix,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.samza.sql.udf;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.apache.samza.SamzaException;
import org.apache.samza.config.Config;
import org.apache.samza.sql.interfaces.UdfMetadata;
import org.apache.samza.sql.interfaces.UdfResolver;
import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdf;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.reflections.Reflections;
import org.reflections.util.ConfigurationBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link UdfResolver} implementation that uses reflection to discover the classes
 * annotated with {@link SamzaSqlUdf} on the classpath. Validates that every discovered
 * class declares at least one method annotated with {@link SamzaSqlUdfMethod}, and
 * exposes one {@link UdfMetadata} per annotated method of every enabled UDF class.
 */
public class ReflectionBasedUdfResolver implements UdfResolver {

  private static final Logger LOG = LoggerFactory.getLogger(ReflectionBasedUdfResolver.class);

  /** Config key holding a comma-separated list of package prefixes to scan for UDF classes. */
  private static final String CONFIG_PACKAGE_PREFIX = "samza.sql.udf.resolver.package.prefix";

  private final Set<UdfMetadata> udfs = new HashSet<>();

  /**
   * Scans the configured packages and collects the metadata of every enabled UDF.
   *
   * @param udfConfig UDF configuration; the value of {@code samza.sql.udf.resolver.package.prefix}
   *                  (default {@code org.apache.samza}) selects the packages to scan, and the
   *                  subset under {@code <udfName>.} is handed to each UDF's metadata
   * @throws SamzaException if a class annotated with {@link SamzaSqlUdf} has no method annotated
   *                        with {@link SamzaSqlUdfMethod}
   */
  public ReflectionBasedUdfResolver(Config udfConfig) {
    // Searching the entire classpath to discover the classes annotated with SamzaSqlUdf is expensive.
    // To reduce the search space, the search is limited to the package prefixes defined in the configuration.
    String samzaSqlUdfPackagePrefix = udfConfig.getOrDefault(CONFIG_PACKAGE_PREFIX, "org.apache.samza");

    // 1. Build the reflections instance with appropriate configuration.
    ConfigurationBuilder configurationBuilder = new ConfigurationBuilder();
    configurationBuilder.forPackages(parsePackagePrefixes(samzaSqlUdfPackagePrefix));
    configurationBuilder.addClassLoader(Thread.currentThread().getContextClassLoader());
    Reflections reflections = new Reflections(configurationBuilder);

    // 2. Get all the types annotated with SamzaSqlUdf.
    Set<Class<?>> typesAnnotatedWithSamzaSqlUdf = reflections.getTypesAnnotatedWith(SamzaSqlUdf.class);

    for (Class<?> udfClass : typesAnnotatedWithSamzaSqlUdf) {
      // 3. Get all the methods that are annotated with SamzaSqlUdfMethod.
      List<Method> methodsAnnotatedWithSamzaSqlMethod =
          MethodUtils.getMethodsListWithAnnotation(udfClass, SamzaSqlUdfMethod.class);

      // The validation applies to every discovered class, whether or not the UDF is enabled.
      if (methodsAnnotatedWithSamzaSqlMethod.isEmpty()) {
        String msg = String.format("Udf class: %s doesn't have any methods annotated with: %s", udfClass.getName(), SamzaSqlUdfMethod.class.getName());
        LOG.error(msg);
        throw new SamzaException(msg);
      }

      SamzaSqlUdf sqlUdf = udfClass.getAnnotation(SamzaSqlUdf.class);
      // 4. If the udf is enabled, then add the udf information of the methods to the udfs set.
      if (sqlUdf.enabled()) {
        String udfName = sqlUdf.name();
        methodsAnnotatedWithSamzaSqlMethod.forEach(method -> {
          SamzaSqlUdfMethod samzaSqlUdfMethod = method.getAnnotation(SamzaSqlUdfMethod.class);
          List<SamzaSqlFieldType> params = Arrays.asList(samzaSqlUdfMethod.params());
          udfs.add(new UdfMetadata(udfName, sqlUdf.description(), method, udfConfig.subset(udfName + "."), params,
              samzaSqlUdfMethod.returns(), samzaSqlUdfMethod.disableArgumentCheck()));
        });
      }
    }
  }

  /**
   * Splits the comma-separated package prefix list and strips the whitespace around each
   * entry, so that a value such as {@code "com.foo, com.bar"} yields {@code com.bar} rather
   * than {@code " com.bar"} (which would match no package).
   */
  private static String[] parsePackagePrefixes(String commaSeparatedPrefixes) {
    return Arrays.stream(commaSeparatedPrefixes.split(","))
        .map(String::trim)
        .toArray(String[]::new);
  }

  /**
   * @return the metadata of the UDFs discovered when this resolver was constructed
   */
  @Override
  public Collection<UdfMetadata> getUdfs() {
    return udfs;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import java.util.stream.Collectors;
import org.apache.samza.SamzaException;
import org.apache.samza.config.MapConfig;
import org.apache.samza.sql.impl.ConfigBasedUdfResolver;
import org.apache.samza.sql.interfaces.SqlIOConfig;
import org.apache.samza.sql.util.JsonUtil;
import org.apache.samza.sql.util.SamzaSqlQueryParser;
Expand All @@ -45,8 +44,6 @@ public class TestSamzaSqlApplicationConfig {
public void testConfigInit() {
Map<String, String> config = SamzaSqlTestConfig.fetchStaticConfigsWithFactories(10);
config.put(SamzaSqlApplicationConfig.CFG_SQL_STMT, "Insert into testavro.COMPLEX1 select * from testavro.SIMPLE1");
String configUdfResolverDomain = String.format(SamzaSqlApplicationConfig.CFG_FMT_UDF_RESOLVER_DOMAIN, "config");
int numUdfs = config.get(configUdfResolverDomain + ConfigBasedUdfResolver.CFG_UDF_CLASSES).split(",").length;

List<String> sqlStmts = fetchSqlFromConfig(config);
List<SamzaSqlQueryParser.QueryInfo> queryInfo = fetchQueryInfo(sqlStmts);
Expand All @@ -55,7 +52,6 @@ public void testConfigInit() {
.collect(Collectors.toList()),
queryInfo.stream().map(SamzaSqlQueryParser.QueryInfo::getSink).collect(Collectors.toList()));

Assert.assertEquals(numUdfs + 1, samzaSqlApplicationConfig.getUdfMetadata().size());
Assert.assertEquals(1, samzaSqlApplicationConfig.getInputSystemStreamConfigBySource().size());
Assert.assertEquals(1, samzaSqlApplicationConfig.getOutputSystemStreamConfigsBySource().size());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.samza.sql.udf.impl;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.samza.config.Config;
import org.apache.samza.config.MapConfig;
import org.apache.samza.sql.interfaces.UdfMetadata;
import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udf.ReflectionBasedUdfResolver;
import org.junit.Assert;
import org.junit.Test;

import java.lang.reflect.Method;
import java.util.Collection;

/**
 * Unit tests for {@link ReflectionBasedUdfResolver}: verifies that the package prefix
 * configuration controls which UDFs are discovered.
 */
public class TestReflectionBasedUdfResolver {

  /** Builds a resolver restricted to the given package prefix and returns what it discovered. */
  private static Collection<UdfMetadata> resolveUdfs(String packagePrefix) {
    Config resolverConfig = new MapConfig(ImmutableMap.of("samza.sql.udf.resolver.package.prefix", packagePrefix));
    return new ReflectionBasedUdfResolver(resolverConfig).getUdfs();
  }

  @Test
  public void testShouldReturnNothingWhenNoUDFIsInPackagePrefix() {
    // No UDF classes live under this package, so the scan must come back empty.
    Collection<UdfMetadata> discovered = resolveUdfs("org.apache.samza.udf.blah.blah");

    Assert.assertEquals(0, discovered.size());
  }

  @Test
  public void testUDfResolverShouldReturnAllUDFInClassPath() throws NoSuchMethodException {
    Collection<UdfMetadata> discovered = resolveUdfs("org.apache.samza.sql.udf.impl");

    // Expected metadata for the single annotated method of the TestSamzaSqlUdf fixture.
    Method executeMethod = TestSamzaSqlUdf.class.getMethod("execute", String.class);
    UdfMetadata expected = new UdfMetadata(
        "TESTSAMZASQLUDF",
        "Test samza sql udf implementation",
        executeMethod,
        new MapConfig(),
        ImmutableList.of(SamzaSqlFieldType.STRING),
        SamzaSqlFieldType.STRING,
        true);

    Assert.assertFalse(discovered.isEmpty());
    Assert.assertTrue(discovered.contains(expected));
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.samza.sql.udf.impl;

import org.apache.samza.config.Config;
import org.apache.samza.context.Context;
import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdf;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.apache.samza.sql.udfs.ScalarUdf;

/**
 * Minimal scalar UDF fixture for tests: takes a single string argument and always
 * returns the same fixed string.
 */
@SamzaSqlUdf(name = "TestSamzaSqlUdf", description = "Test samza sql udf implementation")
public class TestSamzaSqlUdf implements ScalarUdf {

  /** The fixed value returned by {@link #execute(String)} regardless of its argument. */
  private static final String RESPONSE = "testResponse";

  /** Intentionally a no-op. */
  @Override
  public void init(Config udfConfig, Context context) {
  }

  /**
   * @param fieldName ignored
   * @return the constant string {@code "testResponse"}
   */
  @SamzaSqlUdfMethod(params = {SamzaSqlFieldType.STRING}, returns = SamzaSqlFieldType.STRING)
  public String execute(String fieldName) {
    return RESPONSE;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.apache.samza.sql.fn.GetNestedFieldUdf;
import org.apache.samza.sql.fn.RegexMatchUdf;
import org.apache.samza.sql.impl.ConfigBasedIOResolverFactory;
import org.apache.samza.sql.impl.ConfigBasedUdfResolver;
import org.apache.samza.sql.interfaces.SqlIOConfig;
import org.apache.samza.sql.runner.SamzaSqlApplicationConfig;
import org.apache.samza.sql.system.TestAvroSystemFactory;
Expand Down Expand Up @@ -95,13 +94,6 @@ public static Map<String, String> fetchStaticConfigsWithFactories(Map<String, St
RemoteStoreIOResolverTestFactory.class.getName());

staticConfigs.put(SamzaSqlApplicationConfig.CFG_UDF_RESOLVER, "config");
String configUdfResolverDomain = String.format(SamzaSqlApplicationConfig.CFG_FMT_UDF_RESOLVER_DOMAIN, "config");
staticConfigs.put(configUdfResolverDomain + SamzaSqlApplicationConfig.CFG_FACTORY,
ConfigBasedUdfResolver.class.getName());
staticConfigs.put(configUdfResolverDomain + ConfigBasedUdfResolver.CFG_UDF_CLASSES, Joiner.on(",")
.join(MyTestUdf.class.getName(), RegexMatchUdf.class.getName(), FlattenUdf.class.getName(),
MyTestArrayUdf.class.getName(), BuildOutputRecordUdf.class.getName(), MyTestPolyUdf.class.getName(),
MyTestObjUdf.class.getName(), GetNestedFieldUdf.class.getName()));

String avroSystemConfigPrefix =
String.format(ConfigBasedIOResolverFactory.CFG_FMT_SAMZA_PREFIX, SAMZA_SYSTEM_TEST_AVRO);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import org.apache.samza.sql.fn.FlattenUdf;
import org.apache.samza.sql.fn.RegexMatchUdf;
import org.apache.samza.sql.impl.ConfigBasedIOResolverFactory;
import org.apache.samza.sql.impl.ConfigBasedUdfResolver;
import org.apache.samza.sql.interfaces.SqlIOConfig;
import org.apache.samza.sql.runner.SamzaSqlApplicationConfig;
import org.apache.samza.sql.runner.SamzaSqlApplicationRunner;
Expand Down Expand Up @@ -125,11 +124,6 @@ public static Map<String, String> fetchSamzaSqlConfig() {
ConfigBasedIOResolverFactory.class.getName());

staticConfigs.put(SamzaSqlApplicationConfig.CFG_UDF_RESOLVER, "config");
String configUdfResolverDomain = String.format(SamzaSqlApplicationConfig.CFG_FMT_UDF_RESOLVER_DOMAIN, "config");
staticConfigs.put(configUdfResolverDomain + SamzaSqlApplicationConfig.CFG_FACTORY,
ConfigBasedUdfResolver.class.getName());
staticConfigs.put(configUdfResolverDomain + ConfigBasedUdfResolver.CFG_UDF_CLASSES,
Joiner.on(",").join(RegexMatchUdf.class.getName(), FlattenUdf.class.getName()));

staticConfigs.put("serializers.registry.string.class", StringSerdeFactory.class.getName());
staticConfigs.put("serializers.registry.avro.class", AvroSerDeFactory.class.getName());
Expand Down

0 comments on commit e2928e1

Please sign in to comment.