Skip to content

Commit

Permalink
fix: CometReader.loadVector should not overwrite dictionary ids (#476)
Browse files Browse the repository at this point in the history
* fix: CometReader.loadVector should not overwrite dictionary ids

* For review
  • Loading branch information
viirya committed May 28, 2024
1 parent 7ba5693 commit 479a97a
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 16 deletions.
73 changes: 73 additions & 0 deletions common/src/main/java/org/apache/arrow/c/CometSchemaImporter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.arrow.c;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.types.pojo.Field;

/** This is a simple wrapper around SchemaImporter to make it accessible from Java Arrow. */
public class CometSchemaImporter {
private final BufferAllocator allocator;
private final SchemaImporter importer;
private final CDataDictionaryProvider provider = new CDataDictionaryProvider();

public CometSchemaImporter(BufferAllocator allocator) {
this.allocator = allocator;
this.importer = new SchemaImporter(allocator);
}

public BufferAllocator getAllocator() {
return allocator;
}

public CDataDictionaryProvider getProvider() {
return provider;
}

public Field importField(ArrowSchema schema) {
try {
return importer.importField(schema, provider);
} finally {
schema.release();
schema.close();
}
}

/**
* Imports data from ArrowArray/ArrowSchema into a FieldVector. This is basically the same as Java
* Arrow `Data.importVector`. `Data.importVector` initiates `SchemaImporter` internally which is
* used to fill dictionary ids for dictionary encoded vectors. Every call to `importVector` will
* begin with dictionary ids starting from 0. So, separate calls to `importVector` will overwrite
* dictionary ids. To avoid this, we need to use the same `SchemaImporter` instance for all calls
* to `importVector`.
*/
public FieldVector importVector(ArrowArray array, ArrowSchema schema) {
Field field = importField(schema);
FieldVector vector = field.createVector(allocator);
Data.importIntoVector(allocator, array, vector, provider);

return vector;
}

public void close() {
provider.close();
}
}
13 changes: 13 additions & 0 deletions common/src/main/java/org/apache/comet/parquet/BatchReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.arrow.c.CometSchemaImporter;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
Expand Down Expand Up @@ -88,6 +91,7 @@
*/
public class BatchReader extends RecordReader<Void, ColumnarBatch> implements Closeable {
private static final Logger LOG = LoggerFactory.getLogger(FileReader.class);
protected static final BufferAllocator ALLOCATOR = new RootAllocator();

private Configuration conf;
private int capacity;
Expand All @@ -104,6 +108,7 @@ public class BatchReader extends RecordReader<Void, ColumnarBatch> implements Cl
private MessageType requestedSchema;
private CometVector[] vectors;
private AbstractColumnReader[] columnReaders;
private CometSchemaImporter importer;
private ColumnarBatch currentBatch;
private Future<Option<Throwable>> prefetchTask;
private LinkedBlockingQueue<Pair<PageReadStore, Long>> prefetchQueue;
Expand Down Expand Up @@ -515,6 +520,10 @@ public void close() throws IOException {
fileReader.close();
fileReader = null;
}
if (importer != null) {
importer.close();
importer = null;
}
}

@SuppressWarnings("deprecation")
Expand Down Expand Up @@ -552,6 +561,9 @@ private boolean loadNextRowGroupIfNecessary() throws Throwable {
numRowGroupsMetric.add(1);
}

if (importer != null) importer.close();
importer = new CometSchemaImporter(ALLOCATOR);

List<ColumnDescriptor> columns = requestedSchema.getColumns();
for (int i = 0; i < columns.size(); i++) {
if (missingColumns[i]) continue;
Expand All @@ -564,6 +576,7 @@ private boolean loadNextRowGroupIfNecessary() throws Throwable {
Utils.getColumnReader(
dataType,
columns.get(i),
importer,
capacity,
useDecimal128,
useLazyMaterialization,
Expand Down
23 changes: 11 additions & 12 deletions common/src/main/java/org/apache/comet/parquet/ColumnReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@

import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.CDataDictionaryProvider;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.c.CometSchemaImporter;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
Expand All @@ -53,7 +50,6 @@

public class ColumnReader extends AbstractColumnReader {
protected static final Logger LOG = LoggerFactory.getLogger(ColumnReader.class);
protected static final BufferAllocator ALLOCATOR = new RootAllocator();

/**
* The current Comet vector holding all the values read by this column reader. Owned by this
Expand Down Expand Up @@ -89,18 +85,19 @@ public class ColumnReader extends AbstractColumnReader {
*/
boolean hadNull;

/** Dictionary provider for this column. */
private final CDataDictionaryProvider dictionaryProvider = new CDataDictionaryProvider();
private final CometSchemaImporter importer;

public ColumnReader(
DataType type,
ColumnDescriptor descriptor,
CometSchemaImporter importer,
int batchSize,
boolean useDecimal128,
boolean useLegacyDateTimestamp) {
super(type, descriptor, useDecimal128, useLegacyDateTimestamp);
assert batchSize > 0 : "Batch size must be positive, found " + batchSize;
this.batchSize = batchSize;
this.importer = importer;
initNative();
}

Expand Down Expand Up @@ -164,7 +161,6 @@ public void close() {
currentVector.close();
currentVector = null;
}
dictionaryProvider.close();
super.close();
}

Expand Down Expand Up @@ -209,10 +205,11 @@ public CometDecodedVector loadVector() {

try (ArrowArray array = ArrowArray.wrap(addresses[0]);
ArrowSchema schema = ArrowSchema.wrap(addresses[1])) {
FieldVector vector = Data.importVector(ALLOCATOR, array, schema, dictionaryProvider);
FieldVector vector = importer.importVector(array, schema);

DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();

CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128, isUuid);
CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128);

// Update whether the current vector contains any null values. This is used in the following
// batch(s) to determine whether we can skip loading the native vector.
Expand All @@ -234,15 +231,17 @@ public CometDecodedVector loadVector() {

// We should already re-initiate `CometDictionary` here because `Data.importVector` API will
// release the previous dictionary vector and create a new one.
Dictionary arrowDictionary = dictionaryProvider.lookup(dictionaryEncoding.getId());
Dictionary arrowDictionary = importer.getProvider().lookup(dictionaryEncoding.getId());
CometPlainVector dictionaryVector =
new CometPlainVector(arrowDictionary.getVector(), useDecimal128, isUuid);
dictionary = new CometDictionary(dictionaryVector);

currentVector =
new CometDictionaryVector(
cometVector, dictionary, dictionaryProvider, useDecimal128, false, isUuid);
cometVector, dictionary, importer.getProvider(), useDecimal128, false, isUuid);

currentVector =
new CometDictionaryVector(cometVector, dictionary, importer.getProvider(), useDecimal128);
return currentVector;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.io.IOException;

import org.apache.arrow.c.CometSchemaImporter;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.page.PageReader;
import org.apache.spark.sql.types.DataType;
Expand All @@ -45,10 +46,11 @@ public class LazyColumnReader extends ColumnReader {
public LazyColumnReader(
DataType sparkReadType,
ColumnDescriptor descriptor,
CometSchemaImporter importer,
int batchSize,
boolean useDecimal128,
boolean useLegacyDateTimestamp) {
super(sparkReadType, descriptor, batchSize, useDecimal128, useLegacyDateTimestamp);
super(sparkReadType, descriptor, importer, batchSize, useDecimal128, useLegacyDateTimestamp);
this.batchSize = 0; // the batch size is set later in `readBatch`
this.vector = new CometLazyVector(sparkReadType, this, useDecimal128);
}
Expand Down
10 changes: 7 additions & 3 deletions common/src/main/java/org/apache/comet/parquet/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.comet.parquet;

import org.apache.arrow.c.CometSchemaImporter;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType;
Expand All @@ -28,26 +29,29 @@ public class Utils {
public static ColumnReader getColumnReader(
DataType type,
ColumnDescriptor descriptor,
CometSchemaImporter importer,
int batchSize,
boolean useDecimal128,
boolean useLazyMaterialization) {
// TODO: support `useLegacyDateTimestamp` for Iceberg
return getColumnReader(
type, descriptor, batchSize, useDecimal128, useLazyMaterialization, true);
type, descriptor, importer, batchSize, useDecimal128, useLazyMaterialization, true);
}

public static ColumnReader getColumnReader(
DataType type,
ColumnDescriptor descriptor,
CometSchemaImporter importer,
int batchSize,
boolean useDecimal128,
boolean useLazyMaterialization,
boolean useLegacyDateTimestamp) {
if (useLazyMaterialization && supportLazyMaterialization(type)) {
return new LazyColumnReader(
type, descriptor, batchSize, useDecimal128, useLegacyDateTimestamp);
type, descriptor, importer, batchSize, useDecimal128, useLegacyDateTimestamp);
} else {
return new ColumnReader(type, descriptor, batchSize, useDecimal128, useLegacyDateTimestamp);
return new ColumnReader(
type, descriptor, importer, batchSize, useDecimal128, useLegacyDateTimestamp);
}
}

Expand Down

0 comments on commit 479a97a

Please sign in to comment.