Skip to content

Commit

Permalink
DRILL-4735: ConvertCountToDirectScan rule enhancements
Browse files Browse the repository at this point in the history
1. ConvertCountToDirectScan rule will be applicable for 2 or more COUNT aggregates.
To achieve this DynamicPojoRecordReader was added which accepts any number of columns,
on the contrary with PojoRecordReader which depends on class fields.
AbstractPojoRecordReader class was added to factor out common logic for these two readers.

2. ConvertCountToDirectScan will distinguish between missing, directory and implicit columns.
For missing columns count will be set 0, for implicit to the total records count
since implicit columns are based on files and there is no data without a file.
If directory column will be encountered, rule won't be applied.
CountsCollector class was introduced to encapsulate counts collection logic.

3. MetadataDirectGroupScan class was introduced to indicate to the user when metadata was used
during calculation and for which files it was applied.

DRILL-4735: Changes after code review.

close #900
  • Loading branch information
arina-ielchiieva authored and jinfengni committed Aug 15, 2017
1 parent 5c57b50 commit 8b56423
Show file tree
Hide file tree
Showing 21 changed files with 1,026 additions and 561 deletions.
Expand Up @@ -58,7 +58,7 @@
import org.apache.drill.exec.record.TypedFieldId;
import org.apache.drill.exec.record.VectorContainer;
import org.apache.drill.exec.record.VectorWrapper;
import org.apache.drill.exec.store.ImplicitColumnExplorer;
import org.apache.drill.exec.store.ColumnExplorer;
import org.apache.drill.exec.vector.AllocationHelper;
import org.apache.drill.exec.vector.FixedWidthVector;
import org.apache.drill.exec.vector.ValueVector;
Expand Down Expand Up @@ -500,7 +500,7 @@ protected boolean setupNewSchema() throws SchemaChangeException {
}

private boolean isImplicitFileColumn(ValueVector vvIn) {
return ImplicitColumnExplorer.initImplicitFileColumns(context.getOptions()).get(vvIn.getField().getName()) != null;
return ColumnExplorer.initImplicitFileColumns(context.getOptions()).get(vvIn.getField().getName()) != null;
}

private List<NamedExpression> getExpressionList() {
Expand Down
@@ -1,4 +1,4 @@
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
Expand All @@ -18,15 +18,21 @@

package org.apache.drill.exec.planner.physical;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.collect.ImmutableMap;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
import org.apache.calcite.rel.type.RelRecordType;
Expand All @@ -35,37 +41,41 @@
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.exec.physical.base.GroupScan;
import org.apache.drill.exec.physical.base.ScanStats;
import org.apache.drill.exec.planner.logical.DrillAggregateRel;
import org.apache.drill.exec.planner.logical.DrillProjectRel;
import org.apache.drill.exec.planner.logical.DrillScanRel;
import org.apache.drill.exec.planner.logical.RelOptHelper;
import org.apache.drill.exec.store.direct.DirectGroupScan;
import org.apache.drill.exec.store.pojo.PojoRecordReader;
import org.apache.drill.exec.store.ColumnExplorer;

import com.google.common.collect.Lists;
import org.apache.drill.exec.store.direct.MetadataDirectGroupScan;
import org.apache.drill.exec.store.pojo.DynamicPojoRecordReader;

/**
* This rule will convert
* " select count(*) as mycount from table "
* or " select count( not-nullable-expr) as mycount from table "
* into
*
* <p>
* This rule will convert <b>" select count(*) as mycount from table "</b>
* or <b>" select count(not-nullable-expr) as mycount from table "</b> into
* <pre>
* Project(mycount)
* \
* DirectGroupScan ( PojoRecordReader ( rowCount ))
*
* or
* " select count(column) as mycount from table "
* into
*</pre>
* or <b>" select count(column) as mycount from table "</b> into
* <pre>
* Project(mycount)
* \
* DirectGroupScan (PojoRecordReader (columnValueCount))
*</pre>
* Rule can be applied if query contains multiple count expressions.
* <b>" select count(column1), count(column2), count(*) from table "</b>
* </p>
*
* <p>
* Currently, only parquet group scan has the exact row count and column value count,
* obtained from parquet row group info. This will save the cost to
* scan the whole parquet files.
* </p>
*/

public class ConvertCountToDirectScan extends Prule {

public static final RelOptRule AGG_ON_PROJ_ON_SCAN = new ConvertCountToDirectScan(
Expand All @@ -77,6 +87,8 @@ public class ConvertCountToDirectScan extends Prule {
RelOptHelper.some(DrillAggregateRel.class,
RelOptHelper.any(DrillScanRel.class)), "Agg_on_scan");

private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ConvertCountToDirectScan.class);

/** Creates a SplunkPushDownRule. */
protected ConvertCountToDirectScan(RelOptRuleOperand rule, String id) {
super(rule, "ConvertCountToDirectScan:" + id);
Expand All @@ -85,40 +97,85 @@ protected ConvertCountToDirectScan(RelOptRuleOperand rule, String id) {
@Override
public void onMatch(RelOptRuleCall call) {
final DrillAggregateRel agg = (DrillAggregateRel) call.rel(0);
final DrillScanRel scan = (DrillScanRel) call.rel(call.rels.length -1);
final DrillProjectRel proj = call.rels.length == 3 ? (DrillProjectRel) call.rel(1) : null;
final DrillScanRel scan = (DrillScanRel) call.rel(call.rels.length - 1);
final DrillProjectRel project = call.rels.length == 3 ? (DrillProjectRel) call.rel(1) : null;

final GroupScan oldGrpScan = scan.getGroupScan();
final PlannerSettings settings = PrelUtil.getPlannerSettings(call.getPlanner());

// Only apply the rule when :
// Only apply the rule when:
// 1) scan knows the exact row count in getSize() call,
// 2) No GroupBY key,
// 3) only one agg function (Check if it's count(*) below).
// 4) No distinct agg call.
// 3) No distinct agg call.
if (!(oldGrpScan.getScanStats(settings).getGroupScanProperty().hasExactRowCount()
&& agg.getGroupCount() == 0
&& agg.getAggCallList().size() == 1
&& !agg.containsDistinctCall())) {
return;
}

AggregateCall aggCall = agg.getAggCallList().get(0);
Map<String, Long> result = collectCounts(settings, agg, scan, project);
logger.trace("Calculated the following aggregate counts: ", result);
// if could not determine the counts, rule won't be applied
if (result.isEmpty()) {
return;
}

final RelDataType scanRowType = constructDataType(agg, result.keySet());

final DynamicPojoRecordReader<Long> reader = new DynamicPojoRecordReader<>(
buildSchema(scanRowType.getFieldNames()),
Collections.singletonList((List<Long>) new ArrayList<>(result.values())));

if (aggCall.getAggregation().getName().equals("COUNT") ) {
final ScanStats scanStats = new ScanStats(ScanStats.GroupScanProperty.EXACT_ROW_COUNT, 1, 1, scanRowType.getFieldCount());
final GroupScan directScan = new MetadataDirectGroupScan(reader, oldGrpScan.getFiles(), scanStats);

final ScanPrel newScan = ScanPrel.create(scan,
scan.getTraitSet().plus(Prel.DRILL_PHYSICAL).plus(DrillDistributionTrait.SINGLETON), directScan,
scanRowType);

final ProjectPrel newProject = new ProjectPrel(agg.getCluster(), agg.getTraitSet().plus(Prel.DRILL_PHYSICAL)
.plus(DrillDistributionTrait.SINGLETON), newScan, prepareFieldExpressions(scanRowType), agg.getRowType());

call.transformTo(newProject);
}

/**
* Collects counts for each aggregation call.
* Will return empty result map if was not able to determine count for at least one aggregation call,
*
* For each aggregate call will determine if count can be calculated. Collects counts only for COUNT function.
* For star, not null expressions and implicit columns sets count to total record number.
* For other cases obtains counts from group scan operator. Also count can not be calculated for parition columns.
*
* @param agg aggregate relational expression
* @param scan scan relational expression
* @param project project relational expression
* @return result map where key is count column name, value is count value
*/
private Map<String, Long> collectCounts(PlannerSettings settings, DrillAggregateRel agg, DrillScanRel scan, DrillProjectRel project) {
final Set<String> implicitColumnsNames = ColumnExplorer.initImplicitFileColumns(settings.getOptions()).keySet();
final GroupScan oldGrpScan = scan.getGroupScan();
final long totalRecordCount = oldGrpScan.getScanStats(settings).getRecordCount();
final LinkedHashMap<String, Long> result = new LinkedHashMap<>();

for (int i = 0; i < agg.getAggCallList().size(); i++) {
AggregateCall aggCall = agg.getAggCallList().get(i);
//for (AggregateCall aggCall : agg.getAggCallList()) {
long cnt;

// rule can be applied only for count function, return empty counts
if (!"count".equalsIgnoreCase(aggCall.getAggregation().getName()) ) {
return ImmutableMap.of();
}

if (containsStarOrNotNullInput(aggCall, agg)) {
cnt = totalRecordCount;

long cnt = 0;
// count(*) == > empty arg ==> rowCount
// count(Not-null-input) ==> rowCount
if (aggCall.getArgList().isEmpty() ||
(aggCall.getArgList().size() == 1 &&
! agg.getInput().getRowType().getFieldList().get(aggCall.getArgList().get(0).intValue()).getType().isNullable())) {
cnt = (long) oldGrpScan.getScanStats(settings).getRecordCount();
} else if (aggCall.getArgList().size() == 1) {
// count(columnName) ==> Agg ( Scan )) ==> columnValueCount
// count(columnName) ==> Agg ( Scan )) ==> columnValueCount
int index = aggCall.getArgList().get(0);

if (proj != null) {
if (project != null) {
// project in the middle of Agg and Scan : Only when input of AggCall is a RexInputRef in Project, we find the index of Scan's field.
// For instance,
// Agg - count($0)
Expand All @@ -127,67 +184,108 @@ public void onMatch(RelOptRuleCall call) {
// \
// Scan (col1, col2).
// return count of "col2" in Scan's metadata, if found.

if (proj.getProjects().get(index) instanceof RexInputRef) {
index = ((RexInputRef) proj.getProjects().get(index)).getIndex();
} else {
return; // do not apply for all other cases.
if (!(project.getProjects().get(index) instanceof RexInputRef)) {
return ImmutableMap.of(); // do not apply for all other cases.
}

index = ((RexInputRef) project.getProjects().get(index)).getIndex();
}

String columnName = scan.getRowType().getFieldNames().get(index).toLowerCase();

cnt = oldGrpScan.getColumnValueCount(SchemaPath.getSimplePath(columnName));
if (cnt == GroupScan.NO_COLUMN_STATS) {
// if column stats are not available don't apply this rule
return;
// for implicit column count will the same as total record count
if (implicitColumnsNames.contains(columnName)) {
cnt = totalRecordCount;
} else {
SchemaPath simplePath = SchemaPath.getSimplePath(columnName);

if (ColumnExplorer.isPartitionColumn(settings.getOptions(), simplePath)) {
return ImmutableMap.of();
}

cnt = oldGrpScan.getColumnValueCount(simplePath);
if (cnt == GroupScan.NO_COLUMN_STATS) {
// if column stats is not available don't apply this rule, return empty counts
return ImmutableMap.of();
}
}
} else {
return; // do nothing.
return ImmutableMap.of();
}

RelDataType scanRowType = getCountDirectScanRowType(agg.getCluster().getTypeFactory());

final ScanPrel newScan = ScanPrel.create(scan,
scan.getTraitSet().plus(Prel.DRILL_PHYSICAL).plus(DrillDistributionTrait.SINGLETON), getCountDirectScan(cnt),
scanRowType);

List<RexNode> exprs = Lists.newArrayList();
exprs.add(RexInputRef.of(0, scanRowType));

final ProjectPrel newProj = new ProjectPrel(agg.getCluster(), agg.getTraitSet().plus(Prel.DRILL_PHYSICAL)
.plus(DrillDistributionTrait.SINGLETON), newScan, exprs, agg.getRowType());

call.transformTo(newProj);
String name = "count" + i + "$" + (aggCall.getName() == null ? aggCall.toString() : aggCall.getName());
result.put(name, cnt);
}

return ImmutableMap.copyOf(result);
}

/**
* Class to represent the count aggregate result.
* Checks if aggregate call contains star or non-null expression:
* <pre>
* count(*) == > empty arg ==> rowCount
* count(Not-null-input) ==> rowCount
* </pre>
*
* @param aggregateCall aggregate call
* @param aggregate aggregate relation expression
* @return true of aggregate call contains star or non-null expression
*/
public static class CountQueryResult {
public long count;

public CountQueryResult(long cnt) {
this.count = cnt;
}
private boolean containsStarOrNotNullInput(AggregateCall aggregateCall, DrillAggregateRel aggregate) {
return aggregateCall.getArgList().isEmpty() ||
(aggregateCall.getArgList().size() == 1 &&
!aggregate.getInput().getRowType().getFieldList().get(aggregateCall.getArgList().get(0)).getType().isNullable());
}

private RelDataType getCountDirectScanRowType(RelDataTypeFactory typeFactory) {
List<RelDataTypeField> fields = Lists.newArrayList();
fields.add(new RelDataTypeFieldImpl("count", 0, typeFactory.createSqlType(SqlTypeName.BIGINT)));

/**
* For each aggregate call creates field based on its name with bigint type.
* Constructs record type for created fields.
*
* @param aggregateRel aggregate relation expression
* @param fieldNames field names
* @return record type
*/
private RelDataType constructDataType(DrillAggregateRel aggregateRel, Collection<String> fieldNames) {
List<RelDataTypeField> fields = new ArrayList<>();
Iterator<String> filedNamesIterator = fieldNames.iterator();
int fieldIndex = 0;
while (filedNamesIterator.hasNext()) {
RelDataTypeField field = new RelDataTypeFieldImpl(
filedNamesIterator.next(),
fieldIndex++,
aggregateRel.getCluster().getTypeFactory().createSqlType(SqlTypeName.BIGINT));
fields.add(field);
}
return new RelRecordType(fields);
}

private GroupScan getCountDirectScan(long cnt) {
CountQueryResult res = new CountQueryResult(cnt);

PojoRecordReader<CountQueryResult> reader = new PojoRecordReader<CountQueryResult>(CountQueryResult.class,
Collections.singleton(res).iterator());
/**
* Builds schema based on given field names.
* Type for each schema is set to long.class.
*
* @param fieldNames field names
* @return schema
*/
private LinkedHashMap<String, Class<?>> buildSchema(List<String> fieldNames) {
LinkedHashMap<String, Class<?>> schema = new LinkedHashMap<>();
for (String fieldName: fieldNames) {
schema.put(fieldName, long.class);
}
return schema;
}

return new DirectGroupScan(reader);
/**
* For each field creates row expression.
*
* @param rowType row type
* @return list of row expressions
*/
private List<RexNode> prepareFieldExpressions(RelDataType rowType) {
List<RexNode> expressions = new ArrayList<>();
for (int i = 0; i < rowType.getFieldCount(); i++) {
expressions.add(RexInputRef.of(i, rowType));
}
return expressions;
}

}

0 comments on commit 8b56423

Please sign in to comment.