GH-18547: [Java] Support re-emitting dictionaries in ArrowStreamWriter (#35920)

### Rationale for this change

This allows writing IPC streams where dictionary values change between record batches: when the contents of a dictionary change, `ArrowStreamWriter` now emits a replacement dictionary batch before the next record batch.
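In the IPC stream format, a dictionary batch that reuses an already-written dictionary id, with the delta flag unset, replaces the earlier dictionary on the reading side. Roughly, in the schematic notation used by the Arrow columnar format documentation, a stream with one replaced dictionary looks like this (the single dictionary id is illustrative):

```
<SCHEMA>
<DICTIONARY 1>    (initial values)
<RECORD BATCH 0>
<DICTIONARY 1>    (replacement: same id, isDelta = false)
<RECORD BATCH 1>
<EOS>
```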

### What changes are included in this PR?

* Add a new abstract method `void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)` to the base `ArrowWriter` class
* Move the existing logic that writes each dictionary only once into the `ArrowFileWriter` class
* Implement replacement dictionary writing in `ArrowStreamWriter` by keeping copies of previously written dictionaries

### Are these changes tested?

Yes, I've added a new unit test covering dictionary replacement.

### Are there any user-facing changes?

Yes, `ArrowStreamWriter` will now write replacement dictionaries when dictionary values change between batches.
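For illustration, here is a minimal, self-contained sketch of the new behaviour (allocator setup, names, and values are illustrative and not part of this PR; the pattern mirrors the new unit test):

```java
import java.io.ByteArrayOutputStream;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;

public class DictionaryReplacementSketch {
  public static void main(String[] args) throws Exception {
    try (BufferAllocator allocator = new RootAllocator();
         VarCharVector dictVector = new VarCharVector("dict", allocator);
         VarCharVector raw = new VarCharVector("col", allocator)) {
      // Register a dictionary under id 1
      dictVector.allocateNewSafe();
      dictVector.setSafe(0, "foo".getBytes(StandardCharsets.UTF_8));
      dictVector.setSafe(1, "bar".getBytes(StandardCharsets.UTF_8));
      dictVector.setValueCount(2);
      Dictionary dictionary = new Dictionary(dictVector, new DictionaryEncoding(1L, false, null));
      DictionaryProvider.MapDictionaryProvider provider =
          new DictionaryProvider.MapDictionaryProvider(dictionary);

      // Dictionary-encode a column
      raw.allocateNewSafe();
      raw.setSafe(0, "foo".getBytes(StandardCharsets.UTF_8));
      raw.setSafe(1, "bar".getBytes(StandardCharsets.UTF_8));
      raw.setValueCount(2);
      FieldVector encoded = (FieldVector) DictionaryEncoder.encode(raw, dictionary);

      ByteArrayOutputStream out = new ByteArrayOutputStream();
      try (VectorSchemaRoot root = new VectorSchemaRoot(
               Arrays.asList(encoded.getField()), Arrays.asList(encoded), encoded.getValueCount());
           ArrowStreamWriter writer = new ArrowStreamWriter(root, provider, Channels.newChannel(out))) {
        writer.start();
        writer.writeBatch(); // dictionary batch (id 1) + record batch 0

        // Change the dictionary contents under the same id; the existing
        // indices in the record batch remain valid
        dictVector.setSafe(2, "baz".getBytes(StandardCharsets.UTF_8));
        dictVector.setValueCount(3);

        writer.writeBatch(); // replacement dictionary batch (id 1) + record batch 1
        writer.end();
      }
    }
  }
}
```

Before each record batch, the stream writer compares every used dictionary against a copy of what it last wrote (via `VectorEqualsVisitor`) and only re-emits a dictionary batch when the contents differ.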

**This PR includes breaking changes to public APIs.**

`ArrowWriter` has a new abstract `ensureDictionariesWritten` method. This only affects users who inherit directly from `ArrowWriter` rather than using `ArrowFileWriter` or `ArrowStreamWriter`; a sketch of the new contract follows.
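For maintainers of such a subclass, a minimal sketch of the new contract (the class name is hypothetical; it reproduces the write-once behaviour that moved into `ArrowFileWriter`, and assumes the inherited no-op `startInternal`/`endInternal` are sufficient):

```java
import java.io.IOException;
import java.nio.channels.WritableByteChannel;
import java.util.Set;

import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowWriter;

public class WriteOnceDictionaryWriter extends ArrowWriter {
  private boolean dictionariesWritten = false;

  public WriteOnceDictionaryWriter(VectorSchemaRoot root, DictionaryProvider provider,
                                   WritableByteChannel out) {
    super(root, provider, out);
  }

  @Override
  protected void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)
      throws IOException {
    // Invoked by writeBatch() before every record batch; subclasses decide
    // here whether each dictionary needs to be (re-)emitted.
    if (dictionariesWritten) {
      return;
    }
    dictionariesWritten = true;
    for (long id : dictionaryIdsUsed) {
      writeDictionaryBatch(provider.lookup(id));
    }
  }
}
```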

There's also a behaviour change in `ArrowWriter`: previously, dictionaries were read from the `DictionaryProvider` when the writer was constructed. The set of dictionary ids is still determined from the schema at construction, but the dictionary contents are now not read until batches are written.
* Closes: #18547

Authored-by: Adam Reeve <adreeve@gmail.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
adamreeve committed Jun 9, 2023
1 parent 766e254 commit 8b2ab4d
Showing 5 changed files with 179 additions and 41 deletions.
EchoServer.java
@@ -135,6 +135,7 @@ public void run() throws IOException {
Preconditions.checkState(reader.bytesRead() == writer.bytesWritten());
LOGGER.debug(String.format("Echoed %d records", echoed));
reader.close(false);
writer.close();
}
}

ArrowFileWriter.java
@@ -23,11 +23,13 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowBlock;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
@@ -50,6 +52,7 @@ public class ArrowFileWriter extends ArrowWriter {
private final List<ArrowBlock> recordBlocks = new ArrayList<>();

private Map<String, String> metaData;
private boolean dictionariesWritten = false;

public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
super(root, provider, out);
@@ -123,6 +126,21 @@ protected void endInternal(WriteChannel out) throws IOException {
LOGGER.debug("magic written, now at {}", out.getCurrentPosition());
}

@Override
protected void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)
throws IOException {
if (dictionariesWritten) {
return;
}
dictionariesWritten = true;
// Write out all dictionaries required.
// Replacement dictionaries are not supported in the IPC file format.
for (long id : dictionaryIdsUsed) {
Dictionary dictionary = provider.lookup(id);
writeDictionaryBatch(dictionary);
}
}

@VisibleForTesting
public List<ArrowBlock> getRecordBlocks() {
return recordBlocks;
ArrowStreamWriter.java
@@ -21,11 +21,18 @@
import java.io.OutputStream;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compare.VectorEqualsVisitor;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
@@ -34,6 +41,7 @@
* Writer for the Arrow stream format to send ArrowRecordBatches over a WriteChannel.
*/
public class ArrowStreamWriter extends ArrowWriter {
private final Map<Long, FieldVector> previousDictionaries = new HashMap<>();

/**
* Construct an ArrowStreamWriter with an optional DictionaryProvider for the OutputStream.
@@ -121,4 +129,45 @@ public static void writeEndOfStream(WriteChannel out, IpcOption option) throws I
protected void endInternal(WriteChannel out) throws IOException {
writeEndOfStream(out, option);
}

@Override
protected void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)
throws IOException {
// write out any dictionaries that have changes
for (long id : dictionaryIdsUsed) {
Dictionary dictionary = provider.lookup(id);
FieldVector vector = dictionary.getVector();
if (previousDictionaries.containsKey(id) &&
VectorEqualsVisitor.vectorEquals(vector, previousDictionaries.get(id))) {
// Dictionary was previously written and hasn't changed
continue;
}
writeDictionaryBatch(dictionary);
// Store a copy of the vector in case it is later mutated
if (previousDictionaries.containsKey(id)) {
previousDictionaries.get(id).close();
}
previousDictionaries.put(id, copyVector(vector));
}
}

@Override
public void close() {
super.close();
try {
AutoCloseables.close(previousDictionaries.values());
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private static FieldVector copyVector(FieldVector source) {
FieldVector copy = source.getField().createVector(source.getAllocator());
copy.allocateNew();
for (int i = 0; i < source.getValueCount(); i++) {
copy.copyFromSafe(i, i, source);
}
copy.setValueCount(source.getValueCount());
return copy;
}
}
ArrowWriter.java
@@ -26,7 +26,6 @@
import java.util.Optional;
import java.util.Set;

import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
@@ -59,13 +58,12 @@ public abstract class ArrowWriter implements AutoCloseable {
protected final WriteChannel out;

private final VectorUnloader unloader;
private final List<ArrowDictionaryBatch> dictionaries;
private final DictionaryProvider dictionaryProvider;
private final Set<Long> dictionaryIdsUsed = new HashSet<>();

private boolean started = false;
private boolean ended = false;

private boolean dictWritten = false;

protected IpcOption option;

protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
@@ -99,31 +97,16 @@ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, Writab
/*alignBuffers*/ true);
this.out = new WriteChannel(out);
this.option = option;
this.dictionaryProvider = provider;

List<Field> fields = new ArrayList<>(root.getSchema().getFields().size());
Set<Long> dictionaryIdsUsed = new HashSet<>();

MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option.metadataVersion);
// Convert fields with dictionaries to have dictionary type
for (Field field : root.getSchema().getFields()) {
fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIdsUsed));
}

// Create a record batch for each dictionary
this.dictionaries = new ArrayList<>(dictionaryIdsUsed.size());
for (long id : dictionaryIdsUsed) {
Dictionary dictionary = provider.lookup(id);
FieldVector vector = dictionary.getVector();
int count = vector.getValueCount();
VectorSchemaRoot dictRoot = new VectorSchemaRoot(
Collections.singletonList(vector.getField()),
Collections.singletonList(vector),
count);
VectorUnloader unloader = new VectorUnloader(dictRoot);
ArrowRecordBatch batch = unloader.getRecordBatch();
this.dictionaries.add(new ArrowDictionaryBatch(id, batch));
}

this.schema = new Schema(fields, root.getSchema().getCustomMetadata());
}

@@ -136,12 +119,34 @@ public void start() throws IOException {
*/
public void writeBatch() throws IOException {
ensureStarted();
ensureDictionariesWritten();
ensureDictionariesWritten(dictionaryProvider, dictionaryIdsUsed);
try (ArrowRecordBatch batch = unloader.getRecordBatch()) {
writeRecordBatch(batch);
}
}

protected void writeDictionaryBatch(Dictionary dictionary) throws IOException {
FieldVector vector = dictionary.getVector();
long id = dictionary.getEncoding().getId();
int count = vector.getValueCount();
VectorSchemaRoot dictRoot = new VectorSchemaRoot(
Collections.singletonList(vector.getField()),
Collections.singletonList(vector),
count);
VectorUnloader unloader = new VectorUnloader(dictRoot);
ArrowRecordBatch batch = unloader.getRecordBatch();
ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, false);
try {
writeDictionaryBatch(dictionaryBatch);
} finally {
try {
dictionaryBatch.close();
} catch (Exception e) {
throw new RuntimeException("Error occurred while closing dictionary.", e);
}
}
}

protected ArrowBlock writeDictionaryBatch(ArrowDictionaryBatch batch) throws IOException {
ArrowBlock block = MessageSerializer.serialize(out, batch, option);
if (LOGGER.isDebugEnabled()) {
@@ -183,23 +188,8 @@ private void ensureStarted() throws IOException {
* Write dictionaries after the schema and before record batches. Dictionaries won't be
* written for an empty stream (one that only has schema data).
*/
private void ensureDictionariesWritten() throws IOException {
if (!dictWritten) {
dictWritten = true;
// write out any dictionaries
try {
for (ArrowDictionaryBatch batch : dictionaries) {
writeDictionaryBatch(batch);
}
} finally {
try {
AutoCloseables.close(dictionaries);
} catch (Exception e) {
throw new RuntimeException("Error occurred while closing dictionaries.", e);
}
}
}
}
protected abstract void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)
throws IOException;

private void ensureEnded() throws IOException {
if (!ended) {
@@ -219,9 +209,6 @@ public void close() {
try {
end();
out.close();
if (!dictWritten) {
AutoCloseables.close(dictionaries);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
TestArrowReaderWriter.java
@@ -86,6 +86,7 @@
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.apache.arrow.vector.util.DictionaryUtility;
import org.apache.arrow.vector.util.TransferPair;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -612,6 +613,88 @@ public void testDeltaDictionary() throws Exception {

}

// Tests that the ArrowStreamWriter re-emits dictionaries when they change
@Test
public void testWriteReadStreamWithDictionaryReplacement() throws Exception {
DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();
provider.put(dictionary1);

String[] batch0 = {"foo", "bar", "baz", "bar", "baz"};
String[] batch1 = {"foo", "aa", "bar", "bb", "baz", "cc"};

VarCharVector vector = newVarCharVector("varchar", allocator);
vector.allocateNewSafe();
for (int i = 0; i < batch0.length; ++i) {
vector.set(i, batch0[i].getBytes(StandardCharsets.UTF_8));
}
vector.setValueCount(batch0.length);
FieldVector encodedVector1 = (FieldVector) DictionaryEncoder.encode(vector, dictionary1);

List<Field> fields = Arrays.asList(encodedVector1.getField());
try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
try (VectorSchemaRoot root =
new VectorSchemaRoot(fields, Arrays.asList(encodedVector1), encodedVector1.getValueCount());
ArrowStreamWriter writer = new ArrowStreamWriter(root, provider, newChannel(out))) {
writer.start();

// Write batch with initial data and dictionary
writer.writeBatch();

// Create data for the next batch, using an extended dictionary with the same id
vector.reset();
for (int i = 0; i < batch1.length; ++i) {
vector.set(i, batch1[i].getBytes(StandardCharsets.UTF_8));
}
vector.setValueCount(batch1.length);

// Re-encode and move encoded data into the vector schema root
provider.put(dictionary3);
FieldVector encodedVector2 = (FieldVector) DictionaryEncoder.encode(vector, dictionary3);
TransferPair transferPair = encodedVector2.makeTransferPair(root.getVector(0));
transferPair.transfer();

// Write second batch
root.setRowCount(batch1.length);
writer.writeBatch();

writer.end();
}

try (ArrowStreamReader reader = new ArrowStreamReader(
new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator)) {
VectorSchemaRoot root = reader.getVectorSchemaRoot();

// Read and verify first batch
assertTrue(reader.loadNextBatch());
assertEquals(batch0.length, root.getRowCount());
FieldVector readEncoded1 = root.getVector(0);
long dictionaryId = readEncoded1.getField().getDictionary().getId();
try (VarCharVector decodedValues =
(VarCharVector) DictionaryEncoder.decode(readEncoded1, reader.lookup(dictionaryId))) {
for (int i = 0; i < batch0.length; ++i) {
assertEquals(batch0[i], new String(decodedValues.get(i), StandardCharsets.UTF_8));
}
}

// Read and verify second batch
assertTrue(reader.loadNextBatch());
assertEquals(batch1.length, root.getRowCount());
FieldVector readEncoded2 = root.getVector(0);
dictionaryId = readEncoded2.getField().getDictionary().getId();
try (VarCharVector decodedValues =
(VarCharVector) DictionaryEncoder.decode(readEncoded2, reader.lookup(dictionaryId))) {
for (int i = 0; i < batch1.length; ++i) {
assertEquals(batch1[i], new String(decodedValues.get(i), StandardCharsets.UTF_8));
}
}

assertFalse(reader.loadNextBatch());
}
}

vector.close();
}

private void serializeDictionaryBatch(
WriteChannel out,
Dictionary dictionary,