[Backport] Fix incorrect header names for certain export queries (#16096) #16167

Merged 1 commit on Mar 19, 2024
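This backport of #16096 fixes the header row written by MSQ export queries. Previously, `ExportResultsFrameProcessor` derived the exported file's column names from the frame's own row signature, so a query that aliased columns in its SELECT list wrote internal query column names into the header instead of the user-chosen ones. The fix threads the controller's `ColumnMappings` through `ExportResultsFrameProcessorFactory` into the processor: the processor builds its export signature from the mapped output names, resolves each output column to its frame column index when writing rows, and fails fast with a `DruidException` if the mappings are missing (as can happen mid rolling upgrade).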
@@ -1913,7 +1913,8 @@ private static QueryDefinition makeQueryDefinition(
             .processorFactory(new ExportResultsFrameProcessorFactory(
                 queryId,
                 exportStorageProvider,
-                resultFormat
+                resultFormat,
+                columnMappings
             ))
     );
     return builder.build();
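This controller-side change is the entry point: when `makeQueryDefinition` builds the export stage, it now hands the query's `ColumnMappings` to the processor factory, so the workers that write the export files know the user-facing name of every projected column.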
ExportResultsFrameProcessor.java

@@ -21,6 +21,8 @@
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 import it.unimi.dsi.fastutil.ints.IntSet;
+import it.unimi.dsi.fastutil.objects.Object2IntMap;
+import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
 import org.apache.druid.error.DruidException;
 import org.apache.druid.frame.Frame;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
@@ -35,13 +37,14 @@
 import org.apache.druid.java.util.common.granularity.Granularities;
 import org.apache.druid.java.util.common.guava.Sequence;
 import org.apache.druid.msq.counters.ChannelCounters;
-import org.apache.druid.msq.querykit.QueryKitUtils;
 import org.apache.druid.msq.util.SequenceUtils;
 import org.apache.druid.segment.BaseObjectColumnValueSelector;
 import org.apache.druid.segment.ColumnSelectorFactory;
 import org.apache.druid.segment.Cursor;
 import org.apache.druid.segment.VirtualColumns;
 import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.sql.calcite.planner.ColumnMapping;
+import org.apache.druid.sql.calcite.planner.ColumnMappings;
 import org.apache.druid.sql.http.ResultFormat;
 import org.apache.druid.storage.StorageConnector;
 
@@ -60,6 +63,8 @@ public class ExportResultsFrameProcessor implements FrameProcessor<Object>
   private final ObjectMapper jsonMapper;
   private final ChannelCounters channelCounter;
   final String exportFilePath;
+  private final Object2IntMap<String> outputColumnNameToFrameColumnNumberMap;
+  private final RowSignature exportRowSignature;
 
   public ExportResultsFrameProcessor(
       final ReadableFrameChannel inputChannel,
@@ -68,7 +73,8 @@ public ExportResultsFrameProcessor(
       final StorageConnector storageConnector,
       final ObjectMapper jsonMapper,
       final ChannelCounters channelCounter,
-      final String exportFilePath
+      final String exportFilePath,
+      final ColumnMappings columnMappings
   )
   {
     this.inputChannel = inputChannel;
@@ -78,6 +84,30 @@
     this.jsonMapper = jsonMapper;
     this.channelCounter = channelCounter;
     this.exportFilePath = exportFilePath;
+    this.outputColumnNameToFrameColumnNumberMap = new Object2IntOpenHashMap<>();
+    final RowSignature inputRowSignature = frameReader.signature();
+
+    if (columnMappings == null) {
+      // If the column mappings wasn't sent, fail the query to avoid inconsistency in file format.
+      throw DruidException.forPersona(DruidException.Persona.OPERATOR)
+                          .ofCategory(DruidException.Category.RUNTIME_FAILURE)
+                          .build("Received null columnMappings from controller. This might be due to an upgrade.");
+    }
+    for (final ColumnMapping columnMapping : columnMappings.getMappings()) {
+      this.outputColumnNameToFrameColumnNumberMap.put(
+          columnMapping.getOutputColumn(),
+          frameReader.signature().indexOf(columnMapping.getQueryColumn())
+      );
+    }
+    final RowSignature.Builder exportRowSignatureBuilder = RowSignature.builder();
+
+    for (String outputColumn : columnMappings.getOutputColumnNames()) {
+      exportRowSignatureBuilder.add(
+          outputColumn,
+          inputRowSignature.getColumnType(outputColumnNameToFrameColumnNumberMap.getInt(outputColumn)).orElse(null)
+      );
+    }
+    this.exportRowSignature = exportRowSignatureBuilder.build();
   }
 
   @Override
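As a minimal standalone sketch of what the new constructor computes, here is the same idea with plain JDK collections in place of Druid's `Object2IntOpenHashMap` and `RowSignature`; the internal names `d0`/`a0` and the aliases are illustrative, not taken from a real plan:

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Toy model of the constructor above: the frame carries internal query
// column names, while ColumnMappings pairs each of them with the output
// name the user selected.
public class ExportMappingSketch
{
  public static void main(String[] args)
  {
    // Frame signature: the position of each internal column in the frame.
    List<String> frameColumns = List.of("d0", "a0");

    // queryColumn -> outputColumn pairs, as ColumnMappings would report them.
    Map<String, String> mappings = new LinkedHashMap<>();
    mappings.put("d0", "table_dim");
    mappings.put("a0", "table_count");

    // Equivalent of outputColumnNameToFrameColumnNumberMap:
    // output name -> index of the backing frame column.
    Map<String, Integer> outputNameToFrameIndex = new LinkedHashMap<>();
    mappings.forEach(
        (queryColumn, outputColumn) ->
            outputNameToFrameIndex.put(outputColumn, frameColumns.indexOf(queryColumn))
    );

    // Writing one row: headers come from the output names, values from the
    // frame column each output name maps to.
    Object[] frameRow = {"abc", 1L};
    outputNameToFrameIndex.forEach(
        (outputColumn, frameIndex) ->
            System.out.println(outputColumn + " -> " + frameRow[frameIndex])
    );
    // Prints:
    //   table_dim -> abc
    //   table_count -> 1
  }
}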
@@ -109,8 +139,6 @@ public ReturnOrAwait<Object> runIncrementally(IntSet readableInputs) throws IOException
 
   private void exportFrame(final Frame frame) throws IOException
   {
-    final RowSignature exportRowSignature = createRowSignatureForExport(frameReader.signature());
-
     final Sequence<Cursor> cursorSequence =
         new FrameStorageAdapter(frame, frameReader, Intervals.ETERNITY)
             .makeCursors(null, Intervals.ETERNITY, VirtualColumns.EMPTY, Granularities.ALL, false, null);
@@ -135,7 +163,7 @@ private void exportFrame(final Frame frame) throws IOException
           //noinspection rawtypes
           @SuppressWarnings("rawtypes")
           final List<BaseObjectColumnValueSelector> selectors =
-              exportRowSignature
+              frameReader.signature()
                   .getColumnNames()
                   .stream()
                   .map(columnSelectorFactory::makeColumnValueSelector)
@@ -144,7 +172,9 @@
           while (!cursor.isDone()) {
             formatter.writeRowStart();
             for (int j = 0; j < exportRowSignature.size(); j++) {
-              formatter.writeRowField(exportRowSignature.getColumnName(j), selectors.get(j).getObject());
+              String columnName = exportRowSignature.getColumnName(j);
+              BaseObjectColumnValueSelector<?> selector = selectors.get(outputColumnNameToFrameColumnNumberMap.getInt(columnName));
+              formatter.writeRowField(columnName, selector.getObject());
             }
             channelCounter.incrementRowCount();
             formatter.writeRowEnd();
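Note that the value selectors are still created for every frame column in frame order; what changed is the per-row loop, which now resolves each output header to its mapped frame index before reading a value, so header and value stay paired even when the query renames or reorders columns.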
@@ -162,16 +192,6 @@
     }
   }
 
-  private static RowSignature createRowSignatureForExport(RowSignature inputRowSignature)
-  {
-    RowSignature.Builder exportRowSignatureBuilder = RowSignature.builder();
-    inputRowSignature.getColumnNames()
-                     .stream()
-                     .filter(name -> !QueryKitUtils.PARTITION_BOOST_COLUMN.equals(name))
-                     .forEach(name -> exportRowSignatureBuilder.add(name, inputRowSignature.getColumnType(name).orElse(null)));
-    return exportRowSignatureBuilder.build();
-  }
-
   @Override
   public void cleanup() throws IOException
   {
ExportResultsFrameProcessorFactory.java

@@ -20,6 +20,7 @@
 package org.apache.druid.msq.querykit.results;
 
 import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonTypeName;
 import org.apache.druid.error.DruidException;
@@ -41,6 +42,7 @@
 import org.apache.druid.msq.kernel.ProcessorsAndChannels;
 import org.apache.druid.msq.kernel.StageDefinition;
 import org.apache.druid.msq.querykit.BaseFrameProcessorFactory;
+import org.apache.druid.sql.calcite.planner.ColumnMappings;
 import org.apache.druid.sql.http.ResultFormat;
 import org.apache.druid.storage.ExportStorageProvider;
 import org.apache.druid.utils.CollectionUtils;
@@ -55,17 +57,20 @@ public class ExportResultsFrameProcessorFactory extends BaseFrameProcessorFactory
   private final String queryId;
   private final ExportStorageProvider exportStorageProvider;
   private final ResultFormat exportFormat;
+  private final ColumnMappings columnMappings;
 
   @JsonCreator
   public ExportResultsFrameProcessorFactory(
       @JsonProperty("queryId") String queryId,
       @JsonProperty("exportStorageProvider") ExportStorageProvider exportStorageProvider,
-      @JsonProperty("exportFormat") ResultFormat exportFormat
+      @JsonProperty("exportFormat") ResultFormat exportFormat,
+      @JsonProperty("columnMappings") @Nullable ColumnMappings columnMappings
   )
   {
     this.queryId = queryId;
     this.exportStorageProvider = exportStorageProvider;
     this.exportFormat = exportFormat;
+    this.columnMappings = columnMappings;
   }
 
   @JsonProperty("queryId")
@@ -87,6 +92,14 @@ public ExportStorageProvider getExportStorageProvider()
     return exportStorageProvider;
   }
 
+  @JsonProperty("columnMappings")
+  @JsonInclude(JsonInclude.Include.NON_NULL)
+  @Nullable
+  public ColumnMappings getColumnMappings()
+  {
+    return columnMappings;
+  }
+
   @Override
   public ProcessorsAndChannels<Object, Long> makeProcessors(
       StageDefinition stageDefinition,
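The `@JsonInclude(JsonInclude.Include.NON_NULL)` on the new property means a null mapping is omitted from the serialized factory spec entirely, rather than written as `"columnMappings": null`. A self-contained demo of that Jackson behavior (the DTO below is hypothetical, not a Druid class):

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;

public class NonNullIncludeDemo
{
  @JsonProperty("exportFormat")
  public String exportFormat = "csv";

  // With NON_NULL, Jackson drops this key from the output while the value is null.
  @JsonProperty("columnMappings")
  @JsonInclude(JsonInclude.Include.NON_NULL)
  public String columnMappings = null;

  public static void main(String[] args) throws Exception
  {
    // Prints {"exportFormat":"csv"} -- no "columnMappings" key at all.
    System.out.println(new ObjectMapper().writeValueAsString(new NonNullIncludeDemo()));
  }
}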
@@ -122,7 +135,8 @@ public ProcessorsAndChannels<Object, Long> makeProcessors(
             exportStorageProvider.get(),
             frameContext.jsonMapper(),
             channelCounter,
-            getExportFilePath(queryId, workerNumber, readableInput.getStagePartition().getPartitionNumber(), exportFormat)
+            getExportFilePath(queryId, workerNumber, readableInput.getStagePartition().getPartitionNumber(), exportFormat),
+            columnMappings
         )
     );
 
MSQExportTest.java

@@ -26,15 +26,14 @@
 import org.apache.druid.msq.util.MultiStageQueryContext;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.column.RowSignature;
-import org.apache.druid.sql.calcite.export.TestExportStorageConnector;
-import org.apache.druid.sql.http.ResultFormat;
 import org.junit.Assert;
 import org.junit.Test;
 
-import java.io.ByteArrayOutputStream;
+import java.io.BufferedReader;
 import java.io.File;
 import java.io.IOException;
-import java.nio.charset.Charset;
+import java.io.InputStreamReader;
+import java.nio.file.Files;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -46,14 +45,13 @@ public class MSQExportTest extends MSQTestBase
   @Test
   public void testExport() throws IOException
   {
-    TestExportStorageConnector storageConnector = (TestExportStorageConnector) exportStorageConnectorProvider.get();
-
     RowSignature rowSignature = RowSignature.builder()
                                             .add("__time", ColumnType.LONG)
                                             .add("dim1", ColumnType.STRING)
                                             .add("cnt", ColumnType.LONG).build();
 
-    final String sql = StringUtils.format("insert into extern(%s()) as csv select cnt, dim1 from foo", TestExportStorageConnector.TYPE_NAME);
+    File exportDir = temporaryFolder.newFolder("export/");
+    final String sql = StringUtils.format("insert into extern(local(exportPath=>'%s')) as csv select cnt, dim1 as dim from foo", exportDir.getAbsolutePath());
 
     testIngestQuery().setSql(sql)
                      .setExpectedDataSource("foo1")
@@ -63,11 +61,47 @@ public void testExport() throws IOException
                      .setExpectedResultRows(ImmutableList.of())
                      .verifyResults();
 
-    List<Object[]> objects = expectedFooFileContents();
     Assert.assertEquals(
-        convertResultsToString(objects),
-        new String(storageConnector.getByteArrayOutputStream().toByteArray(), Charset.defaultCharset())
+        1,
+        Objects.requireNonNull(new File(exportDir.getAbsolutePath()).listFiles()).length
+    );
+
+    File resultFile = new File(exportDir, "query-test-query-worker0-partition0.csv");
+    List<String> results = readResultsFromFile(resultFile);
+    Assert.assertEquals(
+        expectedFooFileContents(true),
+        results
     );
   }
 
+  @Test
+  public void testExport2() throws IOException
+  {
+    RowSignature rowSignature = RowSignature.builder()
+                                            .add("dim1", ColumnType.STRING)
+                                            .add("cnt", ColumnType.LONG).build();
+
+    File exportDir = temporaryFolder.newFolder("export/");
+    final String sql = StringUtils.format("insert into extern(local(exportPath=>'%s')) as csv select dim1 as table_dim, count(*) as table_count from foo where dim1 = 'abc' group by 1", exportDir.getAbsolutePath());
+
+    testIngestQuery().setSql(sql)
+                     .setExpectedDataSource("foo1")
+                     .setQueryContext(DEFAULT_MSQ_CONTEXT)
+                     .setExpectedRowSignature(rowSignature)
+                     .setExpectedSegment(ImmutableSet.of())
+                     .setExpectedResultRows(ImmutableList.of())
+                     .verifyResults();
+
+    Assert.assertEquals(
+        1,
+        Objects.requireNonNull(new File(exportDir.getAbsolutePath()).listFiles()).length
+    );
+
+    File resultFile = new File(exportDir, "query-test-query-worker0-partition0.csv");
+    List<String> results = readResultsFromFile(resultFile);
+    Assert.assertEquals(
+        expectedFoo2FileContents(true),
+        results
+    );
+  }
+
   @Test
@@ -95,36 +129,48 @@ public void testNumberOfRowsPerFile() throws IOException
                      .verifyResults();
 
     Assert.assertEquals(
-        expectedFooFileContents().size(),
+        expectedFooFileContents(false).size(),
         Objects.requireNonNull(new File(exportDir.getAbsolutePath()).listFiles()).length
     );
   }
 
-  private List<Object[]> expectedFooFileContents()
+  private List<String> expectedFooFileContents(boolean withHeader)
   {
-    return new ArrayList<>(ImmutableList.of(
-        new Object[]{"1", null},
-        new Object[]{"1", 10.1},
-        new Object[]{"1", 2},
-        new Object[]{"1", 1},
-        new Object[]{"1", "def"},
-        new Object[]{"1", "abc"}
-    ));
+    ArrayList<String> expectedResults = new ArrayList<>();
+    if (withHeader) {
+      expectedResults.add("cnt,dim");
+    }
+    expectedResults.addAll(ImmutableList.of(
+        "1,",
+        "1,10.1",
+        "1,2",
+        "1,1",
+        "1,def",
+        "1,abc"
+        )
+    );
+    return expectedResults;
+  }
+
+  private List<String> expectedFoo2FileContents(boolean withHeader)
+  {
+    ArrayList<String> expectedResults = new ArrayList<>();
+    if (withHeader) {
+      expectedResults.add("table_dim,table_count");
+    }
+    expectedResults.addAll(ImmutableList.of("abc,1"));
+    return expectedResults;
   }
 
-  private String convertResultsToString(List<Object[]> expectedRows) throws IOException
+  private List<String> readResultsFromFile(File resultFile) throws IOException
   {
-    ByteArrayOutputStream expectedResult = new ByteArrayOutputStream();
-    ResultFormat.Writer formatter = ResultFormat.CSV.createFormatter(expectedResult, objectMapper);
-    formatter.writeResponseStart();
-    for (Object[] row : expectedRows) {
-      formatter.writeRowStart();
-      for (Object object : row) {
-        formatter.writeRowField("", object);
+    List<String> results = new ArrayList<>();
+    try (BufferedReader br = new BufferedReader(new InputStreamReader(Files.newInputStream(resultFile.toPath()), StringUtils.UTF8_STRING))) {
+      String line;
+      while (!(line = br.readLine()).isEmpty()) {
+        results.add(line);
      }
-      formatter.writeRowEnd();
+      return results;
    }
-    formatter.writeResponseEnd();
-    return new String(expectedResult.toByteArray(), Charset.defaultCharset());
   }
 }
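The second test pins down the renamed-column case the PR title describes: the exported CSV for `select dim1 as table_dim, count(*) as table_count ...` is expected to start with the header row `table_dim,table_count` followed by the single data row `abc,1`, rather than with the query's internal column names.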