[Backport] Support sorting on complex columns in MSQ (#16437)
* Support sorting on complex columns in MSQ (#16322)

MSQ sorts columns in a highly specialized manner, using byte comparisons, so values are serialized in a byte-comparable format. This works well for primitive types and primitive arrays; complex types, however, cannot be given such a specialized serialization.
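
As intuition, here is a minimal sketch (not Druid's actual encoding) of why primitives can be made byte-comparable: a signed long written big-endian with its sign bit flipped can be ordered by an unsigned, memcmp-style byte comparison.

    import java.nio.ByteBuffer;

    final class ByteComparableLongs
    {
      // Flip the sign bit and write big-endian so that unsigned
      // lexicographic byte order matches numeric order.
      static byte[] encode(long value)
      {
        return ByteBuffer.allocate(Long.BYTES).putLong(value ^ Long.MIN_VALUE).array();
      }

      // Unsigned lexicographic comparison: the memcmp-style check that
      // byte-comparable sorting relies on.
      static int compareBytes(byte[] a, byte[] b)
      {
        for (int i = 0; i < Math.min(a.length, b.length); i++) {
          final int cmp = Integer.compare(a[i] & 0xFF, b[i] & 0xFF);
          if (cmp != 0) {
            return cmp;
          }
        }
        return Integer.compare(a.length, b.length);
      }

      public static void main(String[] args)
      {
        // -5 < 3 numerically, and the encodings compare the same way.
        System.out.println(compareBytes(encode(-5L), encode(3L)) < 0); // prints true
      }
    }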

This PR adds support for sorting complex columns by deserializing the value from the field and comparing it via the type strategy. This is much slower than byte comparison, but it is the only way to sort complex columns, whose serialization is arbitrary and not optimized for MSQ.
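
Conceptually, the complex-column comparison looks like the sketch below. The ComplexTypeStrategy interface is a hypothetical stand-in for Druid's type-strategy mechanism, not the exact API; it assumes ascending order and non-null values.

    import java.nio.ByteBuffer;
    import java.util.Comparator;

    // Hypothetical stand-in for a type strategy: it knows how to deserialize
    // a complex value and how to compare two deserialized values.
    interface ComplexTypeStrategy<T> extends Comparator<T>
    {
      T read(ByteBuffer buffer);
    }

    final class DeserializingFieldComparator<T> implements Comparator<ByteBuffer>
    {
      private final ComplexTypeStrategy<T> strategy;

      DeserializingFieldComparator(ComplexTypeStrategy<T> strategy)
      {
        this.strategy = strategy;
      }

      @Override
      public int compare(ByteBuffer left, ByteBuffer right)
      {
        // Complex fields cannot be compared byte by byte: deserialize both
        // sides, then delegate to the type's own comparator.
        return strategy.compare(strategy.read(left.duplicate()), strategy.read(right.duplicate()));
      }
    }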

Primitives and arrays are still compared via byte comparison, so this doesn't affect the performance of queries supported before the patch. If a sorting key mixes complex and primitive/primitive-array types, for example: longCol1 ASC, longCol2 ASC, complexCol1 DESC, complexCol2 DESC, stringCol1 DESC, longCol3 DESC, longCol4 ASC, the comparison proceeds as follows:

    longCol1, longCol2 (ASC) - Compared together via byte comparison, since both are byte-comparable and need to be sorted in ascending order
    complexCol1 (DESC) - Compared via deserialization; cannot be grouped with any other field
    complexCol2 (DESC) - Compared via deserialization; cannot be grouped with any other field, even though the preceding field is a complex column with the same order
    stringCol1, longCol3 (DESC) - Compared together via byte comparison, since both are byte-comparable and need to be sorted in descending order
    longCol4 (ASC) - Compared via byte comparison; could not be grouped with the previous fields since the direction differs

This way, we only deserialize fields where required; the sketch below illustrates the grouping.
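
As a rough sketch of this grouping logic (hypothetical names; not the actual MSQ field-comparison code): adjacent byte-comparable columns with the same direction merge into one run, while every complex column forms a run of its own.

    import java.util.ArrayList;
    import java.util.List;

    final class SortKeyRuns
    {
      // One sort-key column: whether its serialized form is byte-comparable
      // (primitives and primitive arrays) and whether it sorts ascending.
      static final class KeyPart
      {
        final String name;
        final boolean byteComparable;
        final boolean ascending;

        KeyPart(String name, boolean byteComparable, boolean ascending)
        {
          this.name = name;
          this.byteComparable = byteComparable;
          this.ascending = ascending;
        }
      }

      // Partitions the sort key into runs. Each run of adjacent
      // byte-comparable columns with the same direction can be compared with
      // a single memcmp-style comparison; each complex column is compared on
      // its own, by deserializing the values.
      static List<List<KeyPart>> partition(List<KeyPart> sortKey)
      {
        final List<List<KeyPart>> runs = new ArrayList<>();
        for (KeyPart part : sortKey) {
          final List<KeyPart> lastRun = runs.isEmpty() ? null : runs.get(runs.size() - 1);
          final boolean canExtend =
              lastRun != null
              && part.byteComparable
              && lastRun.get(lastRun.size() - 1).byteComparable
              && lastRun.get(lastRun.size() - 1).ascending == part.ascending;
          if (canExtend) {
            lastRun.add(part);
          } else {
            final List<KeyPart> newRun = new ArrayList<>();
            newRun.add(part);
            runs.add(newRun);
          }
        }
        return runs;
      }
    }

Applied to the example key above, this produces exactly the five groups listed: [longCol1, longCol2], [complexCol1], [complexCol2], [stringCol1, longCol3], [longCol4].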

* Fix conflicts

---------

Co-authored-by: Laksh Singla <lakshsingla@gmail.com>
adarshsanjeev and LakshSingla committed May 13, 2024
1 parent f3d207c commit 463fb3e
Showing 43 changed files with 3,158 additions and 494 deletions.
@@ -37,6 +37,8 @@
 import org.apache.druid.frame.read.FrameReader;
 import org.apache.druid.frame.testutil.FrameSequenceBuilder;
 import org.apache.druid.frame.write.FrameWriters;
+import org.apache.druid.guice.NestedDataModule;
+import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.NonnullPair;
 import org.apache.druid.java.util.common.StringUtils;
@@ -47,6 +49,7 @@
 import org.apache.druid.segment.column.ColumnHolder;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.nested.StructuredData;
 import org.apache.druid.timeline.SegmentId;
 import org.openjdk.jmh.annotations.Benchmark;
 import org.openjdk.jmh.annotations.BenchmarkMode;
@@ -82,6 +85,7 @@ public class FrameChannelMergerBenchmark
 {
   static {
     NullHandling.initializeForTests();
+    NestedDataModule.registerHandlersAndSerde();
   }

   private static final String KEY = "key";
@@ -99,6 +103,9 @@ public class FrameChannelMergerBenchmark
   @Param({"100"})
   private int rowLength;

+  @Param({"string", "nested"})
+  private String columnType;
+
   /**
    * Linked to {@link KeyGenerator}.
    */
@@ -121,13 +128,20 @@ enum KeyGenerator
    */
  RANDOM {
    @Override
-    public String generateKey(int rowNumber, int keyLength)
+    public Comparable generateKey(int rowNumber, int keyLength, String columnType)
    {
      final StringBuilder builder = new StringBuilder(keyLength);
      for (int i = 0; i < keyLength; i++) {
        builder.append((char) ('a' + ThreadLocalRandom.current().nextInt(26)));
      }
-      return builder.toString();
+      String str = builder.toString();
+      if ("string".equals(columnType)) {
+        return str;
+      } else if ("nested".equals(columnType)) {
+        return StructuredData.wrap(str);
+      } else {
+        throw new IAE("unsupported column type");
+      }
    }
  },

@@ -136,13 +150,20 @@ public String generateKey(int rowNumber, int keyLength)
    */
  SEQUENTIAL {
    @Override
-    public String generateKey(int rowNumber, int keyLength)
+    public Comparable generateKey(int rowNumber, int keyLength, String columnType)
    {
-      return StringUtils.format("%0" + keyLength + "d", rowNumber);
+      String str = StringUtils.format("%0" + keyLength + "d", rowNumber);
+      if ("string".equals(columnType)) {
+        return str;
+      } else if ("nested".equals(columnType)) {
+        return StructuredData.wrap(str);
+      } else {
+        throw new IAE("unsupported column type");
+      }
    }
  };

-  public abstract String generateKey(int rowNumber, int keyLength);
+  public abstract Comparable generateKey(int rowNumber, int keyLength, String columnType);
 }

 /**
@@ -176,13 +197,9 @@ public int getChannelNumber(int rowNumber, int numRows, int numChannels)
   public abstract int getChannelNumber(int rowNumber, int numRows, int numChannels);
 }

- private final RowSignature signature =
-     RowSignature.builder()
-                 .add(KEY, ColumnType.STRING)
-                 .add(VALUE, ColumnType.STRING)
-                 .build();
+ private RowSignature signature;
+ private FrameReader frameReader;

- private final FrameReader frameReader = FrameReader.create(signature);
  private final List<KeyColumn> sortKey = ImmutableList.of(new KeyColumn(KEY, KeyOrder.ASCENDING));

  private List<List<Frame>> channelFrames;
@@ -200,6 +217,14 @@ public int getChannelNumber(int rowNumber, int numRows, int numChannels)
  @Setup(Level.Trial)
  public void setupTrial()
  {
+    signature =
+        RowSignature.builder()
+                    .add(KEY, createKeyColumnTypeFromTypeString(columnType))
+                    .add(VALUE, ColumnType.STRING)
+                    .build();
+
+    frameReader = FrameReader.create(signature);
+
    exec = new FrameProcessorExecutor(
        MoreExecutors.listeningDecorator(
            Execs.singleThreaded(StringUtils.encodeForFormat(getClass().getSimpleName()))
@@ -211,14 +236,15 @@ public void setupTrial()
        ChannelDistribution.valueOf(StringUtils.toUpperCase(channelDistributionString));

    // Create channelRows which holds rows for each channel.
-    final List<List<NonnullPair<String, String>>> channelRows = new ArrayList<>();
+    final List<List<NonnullPair<Comparable, String>>> channelRows = new ArrayList<>();
    channelFrames = new ArrayList<>();
    for (int channelNumber = 0; channelNumber < numChannels; channelNumber++) {
      channelRows.add(new ArrayList<>());
      channelFrames.add(new ArrayList<>());
    }

-    // Create "valueString", a string full of spaces to pad out the row.
+    // Create "valueString", a string full of spaces to pad out the row. Nested columns wrap up strings with the
+    // corresponding keyLength, therefore the padding works out for the nested types as well.
    final StringBuilder valueStringBuilder = new StringBuilder();
    for (int i = 0; i < rowLength - keyLength; i++) {
      valueStringBuilder.append(' ');
@@ -227,20 +253,20 @@

    // Populate "channelRows".
    for (int rowNumber = 0; rowNumber < numRows; rowNumber++) {
-      final String keyString = keyGenerator.generateKey(rowNumber, keyLength);
-      final NonnullPair<String, String> row = new NonnullPair<>(keyString, valueString);
+      final Comparable keyObject = keyGenerator.generateKey(rowNumber, keyLength, columnType);
+      final NonnullPair<Comparable, String> row = new NonnullPair<>(keyObject, valueString);
      channelRows.get(channelDistribution.getChannelNumber(rowNumber, numRows, numChannels)).add(row);
    }

    // Sort each "channelRows".
-    for (List<NonnullPair<String, String>> rows : channelRows) {
+    for (List<NonnullPair<Comparable, String>> rows : channelRows) {
      rows.sort(Comparator.comparing(row -> row.lhs));
    }

    // Populate each "channelFrames".
    for (int channelNumber = 0; channelNumber < numChannels; channelNumber++) {
-      final List<NonnullPair<String, String>> rows = channelRows.get(channelNumber);
-      final RowBasedSegment<NonnullPair<String, String>> segment =
+      final List<NonnullPair<Comparable, String>> rows = channelRows.get(channelNumber);
+      final RowBasedSegment<NonnullPair<Comparable, String>> segment =
          new RowBasedSegment<>(
              SegmentId.dummy("__dummy"),
              Sequences.simple(rows),
@@ -350,4 +376,14 @@ public void mergeChannels(Blackhole blackhole)
      throw new ISE("Incorrect numRows[%s], expected[%s]", FutureUtils.getUncheckedImmediately(retVal), numRows);
    }
  }
+
+  private ColumnType createKeyColumnTypeFromTypeString(final String columnTypeString)
+  {
+    if ("string".equals(columnTypeString)) {
+      return ColumnType.STRING;
+    } else if ("nested".equals(columnTypeString)) {
+      return ColumnType.NESTED_DATA;
+    }
+    throw new IAE("Unsupported type [%s]", columnTypeString);
+  }
 }
codestyle/druid-forbidden-apis.txt: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ org.apache.calcite.sql.type.OperandTypes#NULLABLE_LITERAL @ Create an instance o
 org.apache.commons.io.FileUtils#getTempDirectory() @ Use org.junit.rules.TemporaryFolder for tests instead
 org.apache.commons.io.FileUtils#deleteDirectory(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils#deleteDirectory()
 org.apache.commons.io.FileUtils#forceMkdir(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils.mkdirp instead
+org.apache.datasketches.memory.Memory#wrap(byte[], int, int, java.nio.ByteOrder) @ The implementation isn't correct in datasketches-memory-2.2.0. Please refer to https://github.com/apache/datasketches-memory/issues/178. Use wrap(byte[]) and modify the offset by the callers instead
 java.lang.Class#getCanonicalName() @ Class.getCanonicalName can return null for anonymous types, use Class.getName instead.
 java.util.concurrent.Executors#newFixedThreadPool(int) @ Executor is non-daemon and can prevent JVM shutdown, use org.apache.druid.java.util.common.concurrent.Execs#multiThreaded(int, java.lang.String) instead.
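
The new forbidden-APIs entry suggests wrapping the whole array and applying the offset at each read. A minimal sketch of that call-site pattern (readLongAt is a hypothetical helper; it assumes the default byte order is acceptable):

    import org.apache.datasketches.memory.Memory;

    class MemoryWrapWorkaround
    {
      static long readLongAt(byte[] bytes, long offset)
      {
        // Forbidden: Memory.wrap(bytes, offset, length, byteOrder) is broken in
        // datasketches-memory-2.2.0 (see datasketches-memory issue #178).
        // Instead, wrap the entire array and push the offset into the getter:
        final Memory memory = Memory.wrap(bytes);
        return memory.getLong(offset);
      }
    }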

@@ -67,6 +67,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsCollector

 private ClusterByStatisticsCollectorImpl(
     final ClusterBy clusterBy,
+    final RowSignature rowSignature,
     final RowKeyReader keyReader,
     final KeyCollectorFactory<?, ?> keyCollectorFactory,
     final long maxRetainedBytes,
@@ -78,7 +79,7 @@ private ClusterByStatisticsCollectorImpl(
   this.keyReader = keyReader;
   this.keyCollectorFactory = keyCollectorFactory;
   this.maxRetainedBytes = maxRetainedBytes;
-  this.buckets = new TreeMap<>(clusterBy.bucketComparator());
+  this.buckets = new TreeMap<>(clusterBy.bucketComparator(rowSignature));
   this.maxBuckets = maxBuckets;
   this.checkHasMultipleValues = checkHasMultipleValues;
   this.hasMultipleValues = checkHasMultipleValues ? new boolean[clusterBy.getColumns().size()] : null;
@@ -98,10 +99,12 @@ public static ClusterByStatisticsCollector create(
 )
 {
   final RowKeyReader keyReader = clusterBy.keyReader(signature);
-  final KeyCollectorFactory<?, ?> keyCollectorFactory = KeyCollectors.makeStandardFactory(clusterBy, aggregate);
+  final KeyCollectorFactory<?, ?> keyCollectorFactory =
+      KeyCollectors.makeStandardFactory(clusterBy, aggregate, signature);

   return new ClusterByStatisticsCollectorImpl(
       clusterBy,
+      signature,
       keyReader,
       keyCollectorFactory,
       maxRetainedBytes,
@@ -23,6 +23,7 @@
 import org.apache.druid.collections.SerializablePair;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.RowKey;
+import org.apache.druid.segment.column.RowSignature;

 import java.util.Comparator;
 import java.util.stream.Collectors;
@@ -36,9 +37,9 @@ private DistinctKeyCollectorFactory(Comparator<RowKey> comparator)
   this.comparator = comparator;
 }

- static DistinctKeyCollectorFactory create(final ClusterBy clusterBy)
+ static DistinctKeyCollectorFactory create(final ClusterBy clusterBy, RowSignature rowSignature)
 {
-   return new DistinctKeyCollectorFactory(clusterBy.keyComparator());
+   return new DistinctKeyCollectorFactory(clusterBy.keyComparator(rowSignature));
 }

 @Override
@@ -20,6 +20,7 @@
 package org.apache.druid.msq.statistics;

 import org.apache.druid.frame.key.ClusterBy;
+import org.apache.druid.segment.column.RowSignature;

 public class KeyCollectors
 {
@@ -33,19 +34,20 @@ private KeyCollectors()
  */
 public static KeyCollectorFactory<?, ?> makeStandardFactory(
     final ClusterBy clusterBy,
-    final boolean aggregate
+    final boolean aggregate,
+    final RowSignature rowSignature
 )
 {
   final KeyCollectorFactory<?, ?> baseFactory;

   if (aggregate) {
-    baseFactory = DistinctKeyCollectorFactory.create(clusterBy);
+    baseFactory = DistinctKeyCollectorFactory.create(clusterBy, rowSignature);
   } else {
-    baseFactory = QuantilesSketchKeyCollectorFactory.create(clusterBy);
+    baseFactory = QuantilesSketchKeyCollectorFactory.create(clusterBy, rowSignature);
   }

   // Wrap in DelegateOrMinKeyCollectorFactory, so we are guaranteed to be able to downsample to a single key. This
   // is important because it allows us to better handle large numbers of small buckets.
-  return new DelegateOrMinKeyCollectorFactory<>(clusterBy.keyComparator(), baseFactory);
+  return new DelegateOrMinKeyCollectorFactory<>(clusterBy.keyComparator(rowSignature), baseFactory);
 }
}
@@ -27,6 +27,7 @@
 import org.apache.datasketches.quantiles.ItemsSketch;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.segment.column.RowSignature;

 import java.nio.ByteOrder;
 import java.util.Arrays;
@@ -46,9 +47,9 @@ private QuantilesSketchKeyCollectorFactory(final Comparator<byte[]> comparator)
   this.comparator = comparator;
 }

- static QuantilesSketchKeyCollectorFactory create(final ClusterBy clusterBy)
+ static QuantilesSketchKeyCollectorFactory create(final ClusterBy clusterBy, final RowSignature rowSignature)
 {
-   return new QuantilesSketchKeyCollectorFactory(clusterBy.byteKeyComparator());
+   return new QuantilesSketchKeyCollectorFactory(clusterBy.byteKeyComparator(rowSignature));
 }

 @Override