Skip to content

Commit

Permalink
DRILL-4372: (continued) Add option to disable/enable function output …
Browse files Browse the repository at this point in the history
…type inference
  • Loading branch information
hsuanyi committed Mar 17, 2016
1 parent 9ecf4a4 commit c9f8621
Show file tree
Hide file tree
Showing 16 changed files with 837 additions and 166 deletions.
Expand Up @@ -43,6 +43,7 @@
import org.apache.drill.exec.expr.fn.impl.hive.ObjectInspectorHelper;
import org.apache.drill.exec.planner.sql.DrillOperatorTable;
import org.apache.drill.exec.planner.sql.HiveUDFOperator;
import org.apache.drill.exec.planner.sql.HiveUDFOperatorNotInfer;
import org.apache.drill.exec.planner.sql.TypeInferenceUtils;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
Expand Down Expand Up @@ -84,7 +85,8 @@ public HiveFunctionRegistry(DrillConfig config) {
/**
 * Registers an operator in the given operator table for every Hive function
 * name known to this registry.
 *
 * The names are the union of the key sets of {@code methodsGenericUDF} and
 * {@code methodsUDF}; each operator is created with the upper-cased name.
 *
 * @param operatorTable table that receives the Hive UDF operators
 */
@Override
public void register(DrillOperatorTable operatorTable) {
// Sets.union yields each function name once even if it is present in both maps.
for (String name : Sets.union(methodsGenericUDF.asMap().keySet(), methodsUDF.asMap().keySet())) {
// NOTE(review): this view shows add(...) alongside addDefault(...)/addInference(...).
// The add(...) call looks like the pre-change line of the diff this text came from;
// confirm whether it should still be present.
operatorTable.add(name, new HiveUDFOperator(name.toUpperCase(), new HiveSqlReturnTypeInference()));
// Operator used for the "default" registration: HiveUDFOperatorNotInfer takes only the name.
operatorTable.addDefault(name, new HiveUDFOperatorNotInfer(name.toUpperCase()));
// Operator used for the "inference" registration: carries a HiveSqlReturnTypeInference.
operatorTable.addInference(name, new HiveUDFOperator(name.toUpperCase(), new HiveSqlReturnTypeInference()));
}
}

Expand Down
@@ -0,0 +1,44 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.planner.sql;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;

/**
 * A {@link HiveUDFOperator} that does not compute a specific return type.
 *
 * Both {@link #deriveType} and {@link #inferReturnType} ignore the call and
 * its operands and always answer with a nullable {@link SqlTypeName#ANY} type
 * built from the type factory at hand.
 */
public class HiveUDFOperatorNotInfer extends HiveUDFOperator {

  /**
   * Creates the operator, handing {@code DynamicReturnType.INSTANCE} to the
   * parent constructor as its return-type strategy.
   *
   * @param name operator name passed through to {@link HiveUDFOperator}
   */
  public HiveUDFOperatorNotInfer(String name) {
    super(name, DynamicReturnType.INSTANCE);
  }

  /**
   * Returns nullable ANY regardless of the scope or call being validated.
   *
   * @param validator supplies the type factory used to build the result
   * @param scope     unused
   * @param call      unused
   * @return a nullable {@code ANY} type
   */
  @Override
  public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) {
    return nullableAny(validator.getTypeFactory());
  }

  /**
   * Returns nullable ANY regardless of the operands in the binding.
   *
   * @param opBinding supplies the type factory used to build the result
   * @return a nullable {@code ANY} type
   */
  @Override
  public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
    return nullableAny(opBinding.getTypeFactory());
  }

  /** Builds the nullable ANY type shared by both type-resolution entry points. */
  private static RelDataType nullableAny(RelDataTypeFactory typeFactory) {
    final RelDataType anyType = typeFactory.createSqlType(SqlTypeName.ANY);
    return typeFactory.createTypeWithNullability(anyType, true);
  }
}
Expand Up @@ -23,10 +23,13 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.drill.common.scanner.persistence.AnnotatedClassDescriptor;
Expand All @@ -35,9 +38,11 @@
import org.apache.drill.exec.planner.logical.DrillConstExecutor;
import org.apache.drill.exec.planner.sql.DrillOperatorTable;
import org.apache.drill.exec.planner.sql.DrillSqlAggOperator;
import org.apache.drill.exec.planner.sql.DrillSqlAggOperatorNotInfer;
import org.apache.drill.exec.planner.sql.DrillSqlOperator;

import com.google.common.collect.ArrayListMultimap;
import org.apache.drill.exec.planner.sql.DrillSqlOperatorNotInfer;

/**
* Registry of Drill functions.
Expand Down Expand Up @@ -122,6 +127,13 @@ public List<DrillFuncHolder> getMethods(String name) {
}

/**
 * Registers this registry's functions with the given operator table in two
 * passes: first via {@link #registerForInference}, then via
 * {@link #registerForDefault}.
 *
 * @param operatorTable table that receives the operators from both passes
 */
public void register(final DrillOperatorTable operatorTable) {
  // Pass 1: operators added through operatorTable.addInference(...).
  registerForInference(operatorTable);

  // Pass 2: operators added through operatorTable.addDefault(...).
  registerForDefault(operatorTable);
}

public void registerForInference(DrillOperatorTable operatorTable) {
final Map<String, DrillSqlOperator.DrillSqlOperatorBuilder> map = Maps.newHashMap();
final Map<String, DrillSqlAggOperator.DrillSqlAggOperatorBuilder> mapAgg = Maps.newHashMap();
for (Entry<String, Collection<DrillFuncHolder>> function : registeredFunctions.asMap().entrySet()) {
final ArrayListMultimap<Pair<Integer, Integer>, DrillFuncHolder> functions = ArrayListMultimap.create();
final ArrayListMultimap<Integer, DrillFuncHolder> aggregateFunctions = ArrayListMultimap.create();
Expand All @@ -146,20 +158,79 @@ public void register(DrillOperatorTable operatorTable) {
}
}
for (Entry<Pair<Integer, Integer>, Collection<DrillFuncHolder>> entry : functions.asMap().entrySet()) {
final DrillSqlOperator drillSqlOperator;
final Pair<Integer, Integer> range = entry.getKey();
final int max = range.getRight();
final int min = range.getLeft();
drillSqlOperator = new DrillSqlOperator(
name,
Lists.newArrayList(entry.getValue()),
min,
max,
isDeterministic);
operatorTable.add(name, drillSqlOperator);
if(map.containsKey(name)) {
final DrillSqlOperator.DrillSqlOperatorBuilder drillSqlOperatorBuilder = map.get(name);
drillSqlOperatorBuilder
.addFunctions(entry.getValue())
.setArgumentCount(min, max)
.setDeterministic(isDeterministic);
} else {
final DrillSqlOperator.DrillSqlOperatorBuilder drillSqlOperatorBuilder = new DrillSqlOperator.DrillSqlOperatorBuilder();
drillSqlOperatorBuilder
.setName(name)
.addFunctions(entry.getValue())
.setArgumentCount(min, max)
.setDeterministic(isDeterministic);

map.put(name, drillSqlOperatorBuilder);
}
}
for (Entry<Integer, Collection<DrillFuncHolder>> entry : aggregateFunctions.asMap().entrySet()) {
operatorTable.add(name, new DrillSqlAggOperator(name, Lists.newArrayList(entry.getValue()), entry.getKey()));
if(mapAgg.containsKey(name)) {
final DrillSqlAggOperator.DrillSqlAggOperatorBuilder drillSqlAggOperatorBuilder = mapAgg.get(name);
drillSqlAggOperatorBuilder
.addFunctions(entry.getValue())
.setArgumentCount(entry.getKey(), entry.getKey());
} else {
final DrillSqlAggOperator.DrillSqlAggOperatorBuilder drillSqlAggOperatorBuilder = new DrillSqlAggOperator.DrillSqlAggOperatorBuilder();
drillSqlAggOperatorBuilder
.setName(name)
.addFunctions(entry.getValue())
.setArgumentCount(entry.getKey(), entry.getKey());

mapAgg.put(name, drillSqlAggOperatorBuilder);
}
}
}

for(final Entry<String, DrillSqlOperator.DrillSqlOperatorBuilder> entry : map.entrySet()) {
operatorTable.addInference(
entry.getKey(),
entry.getValue().build());
}

for(final Entry<String, DrillSqlAggOperator.DrillSqlAggOperatorBuilder> entry : mapAgg.entrySet()) {
operatorTable.addInference(
entry.getKey(),
entry.getValue().build());
}
}

/**
 * Adds "default" (non-inferring) operators to the operator table.
 *
 * For each registered function name, the holders are scanned in order and an
 * operator is created only for the first holder seen with a given parameter
 * count; later holders with the same count are skipped. Aggregating holders
 * produce a {@link DrillSqlAggOperatorNotInfer}; all others produce a
 * {@link DrillSqlOperatorNotInfer} carrying the holder's return type and a
 * determinism flag. Operators are created with the upper-cased name but are
 * added to the table under the original registry key.
 *
 * @param operatorTable table that receives the operators via {@code addDefault}
 */
public void registerForDefault(DrillOperatorTable operatorTable) {
  for (Entry<String, Collection<DrillFuncHolder>> registered : registeredFunctions.asMap().entrySet()) {
    final String key = registered.getKey();
    final String upperName = key.toUpperCase();
    final Set<Integer> seenParamCounts = Sets.newHashSet();

    for (DrillFuncHolder holder : registered.getValue()) {
      final int paramCount = holder.getParamCount();
      if (!seenParamCounts.add(paramCount)) {
        // An operator for this name and parameter count was already added.
        continue;
      }

      final SqlOperator operator;
      if (holder.isAggregating()) {
        operator = new DrillSqlAggOperatorNotInfer(upperName, paramCount);
      } else {
        // Prevent Drill from folding constant functions whose types cannot be
        // materialized into literals: such functions are reported as
        // non-deterministic whatever the holder itself says.
        final boolean deterministic =
            !DrillConstExecutor.NON_REDUCIBLE_TYPES.contains(holder.getReturnType().getMinorType())
                && holder.isDeterministic();
        operator = new DrillSqlOperatorNotInfer(upperName, paramCount, holder.getReturnType(), deterministic);
      }
      operatorTable.addDefault(key, operator);
    }
  }
}
Expand Down
Expand Up @@ -86,7 +86,7 @@ public QueryContext(final UserSession session, final DrillbitContext drillbitCon
executionControls = new ExecutionControls(queryOptions, drillbitContext.getEndpoint());
plannerSettings = new PlannerSettings(queryOptions, getFunctionRegistry());
plannerSettings.setNumEndPoints(drillbitContext.getBits().size());
table = new DrillOperatorTable(getFunctionRegistry());
table = new DrillOperatorTable(getFunctionRegistry(), drillbitContext.getOptionManager());

queryContextInfo = Utilities.createQueryContextInfo(session.getDefaultSchemaName());
contextInformation = new ContextInformation(session.getCredentials(), queryContextInfo);
Expand Down

0 comments on commit c9f8621

Please sign in to comment.