GH-37841: [Java] Dictionary decoding not using the compression factory from the ArrowReader (#38371)

### Rationale for this change

This PR addresses #37841: dictionary batches were written without the codec configured on the `ArrowWriter` and decoded without the compression factory supplied to the `ArrowReader`, so compressed IPC files containing dictionaries could not round-trip.

### What changes are included in this PR?

Adds compression-aware write and read paths for dictionary data: `ArrowWriter` now reuses its codec when serializing dictionary batches, and `ArrowReader` forwards its compression factory to the `VectorLoader` used for dictionary decoding.
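
For context, here is a minimal round-trip sketch in the spirit of the new tests (imports as in the updated test file below; the eight-argument `ArrowFileWriter` constructor and the reader-side factory are the paths exercised by this fix):

```java
BufferAllocator allocator = new RootAllocator();
List<Field> fields = new ArrayList<>();
fields.add(new Field("col", FieldType.notNullable(new ArrowType.Utf8()), new ArrayList<>()));
try (VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(fields), allocator)) {
  GenerateSampleData.generateTestData(root.getVector(0), 10);
  root.setRowCount(10);

  // Dictionaries registered on the provider are now written with the
  // same codec as the record batches.
  DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();

  ByteArrayOutputStream out = new ByteArrayOutputStream();
  try (ArrowFileWriter writer = new ArrowFileWriter(root, provider, Channels.newChannel(out),
      new HashMap<>(), IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE,
      CompressionUtil.CodecType.ZSTD, Optional.of(7))) {
    writer.start();
    writer.writeBatch();
    writer.end();
  }

  // Reading back with the same factory now also decompresses any
  // dictionary batches, not just the record batches.
  try (ArrowFileReader reader = new ArrowFileReader(
      new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
      allocator, CommonsCompressionFactory.INSTANCE)) {
    while (reader.loadNextBatch()) {
      VectorSchemaRoot loaded = reader.getVectorSchemaRoot();
      // consume `loaded` here
    }
  }
}
allocator.close();
```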

### Are these changes tested?

Yes.

### Are there any user-facing changes?

No.
* Closes: #37841

Lead-authored-by: Vibhatha Lakmal Abeykoon <vibhatha@gmail.com>
Co-authored-by: vibhatha <vibhatha@gmail.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
vibhatha and vibhatha committed Feb 1, 2024
1 parent 87b515e commit f9b7ac2
Showing 3 changed files with 201 additions and 30 deletions.
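The three changed files are the compression round-trip tests (`TestArrowReaderWriterWithCompression`), the dictionary-batch load path in `ArrowReader`, and the dictionary-batch write path in `ArrowWriter`.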
TestArrowReaderWriterWithCompression.java
@@ -18,7 +18,9 @@
package org.apache.arrow.compression;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -27,63 +29,223 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.GenerateSampleData;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.compression.NoCompressionCodec;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.junit.After;
import org.junit.Assert;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

public class TestArrowReaderWriterWithCompression {

private BufferAllocator allocator;
private ByteArrayOutputStream out;
private VectorSchemaRoot root;

@BeforeEach
public void setup() {
if (allocator == null) {
allocator = new RootAllocator(Integer.MAX_VALUE);
}
out = new ByteArrayOutputStream();
root = null;
}

@After
public void tearDown() {
if (root != null) {
root.close();
}
if (allocator != null) {
allocator.close();
}
if (out != null) {
out.reset();
}
}

private void createAndWriteArrowFile(DictionaryProvider provider,
CompressionUtil.CodecType codecType) throws IOException {
List<Field> fields = new ArrayList<>();
fields.add(new Field("col", FieldType.notNullable(new ArrowType.Utf8()), new ArrayList<>()));
root = VectorSchemaRoot.create(new Schema(fields), allocator);

final int rowCount = 10;
GenerateSampleData.generateTestData(root.getVector(0), rowCount);
root.setRowCount(rowCount);

try (final ArrowFileWriter writer = new ArrowFileWriter(root, provider, Channels.newChannel(out),
new HashMap<>(), IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, codecType, Optional.of(7))) {
writer.start();
writer.writeBatch();
writer.end();
}
}

private void createAndWriteArrowStream(DictionaryProvider provider,
CompressionUtil.CodecType codecType) throws IOException {
List<Field> fields = new ArrayList<>();
fields.add(new Field("col", FieldType.notNullable(new ArrowType.Utf8()), new ArrayList<>()));
root = VectorSchemaRoot.create(new Schema(fields), allocator);

final int rowCount = 10;
GenerateSampleData.generateTestData(root.getVector(0), rowCount);
root.setRowCount(rowCount);

try (final ArrowStreamWriter writer = new ArrowStreamWriter(root, provider, Channels.newChannel(out),
IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, codecType, Optional.of(7))) {
writer.start();
writer.writeBatch();
writer.end();
}
}

private Dictionary createDictionary(VarCharVector dictionaryVector) {
setVector(dictionaryVector,
"foo".getBytes(StandardCharsets.UTF_8),
"bar".getBytes(StandardCharsets.UTF_8),
"baz".getBytes(StandardCharsets.UTF_8));

return new Dictionary(dictionaryVector,
new DictionaryEncoding(/*id=*/1L, /*ordered=*/false, /*indexType=*/null));
}

@Test
public void testArrowFileZstdRoundTrip() throws Exception {
createAndWriteArrowFile(null, CompressionUtil.CodecType.ZSTD);
// with compression
try (ArrowFileReader reader =
new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
CommonsCompressionFactory.INSTANCE)) {
Assertions.assertEquals(1, reader.getRecordBlocks().size());
Assertions.assertTrue(reader.loadNextBatch());
Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assertions.assertFalse(reader.loadNextBatch());
}
// without compression
try (ArrowFileReader reader =
new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
NoCompressionCodec.Factory.INSTANCE)) {
Assertions.assertEquals(1, reader.getRecordBlocks().size());
Exception exception = Assert.assertThrows(IllegalArgumentException.class,
reader::loadNextBatch);
Assertions.assertEquals("Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
exception.getMessage());
}
}

@Test
public void testArrowStreamZstdRoundTrip() throws Exception {
createAndWriteArrowStream(null, CompressionUtil.CodecType.ZSTD);
// with compression
try (ArrowStreamReader reader =
new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
CommonsCompressionFactory.INSTANCE)) {
Assert.assertTrue(reader.loadNextBatch());
Assert.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assert.assertFalse(reader.loadNextBatch());
}
// without compression
try (ArrowStreamReader reader =
new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
NoCompressionCodec.Factory.INSTANCE)) {
Exception exception = Assert.assertThrows(IllegalArgumentException.class,
reader::loadNextBatch);
Assert.assertEquals(
"Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
exception.getMessage()
);
}
}

@Test
public void testArrowFileZstdRoundTripWithDictionary() throws Exception {
VarCharVector dictionaryVector = (VarCharVector)
FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector("f1_file", allocator, null);
Dictionary dictionary = createDictionary(dictionaryVector);
DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();
provider.put(dictionary);

createAndWriteArrowFile(provider, CompressionUtil.CodecType.ZSTD);

// with compression
try (ArrowFileReader reader =
new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
CommonsCompressionFactory.INSTANCE)) {
Assertions.assertEquals(1, reader.getRecordBlocks().size());
Assertions.assertTrue(reader.loadNextBatch());
Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assertions.assertFalse(reader.loadNextBatch());
}
// without compression
try (ArrowFileReader reader =
new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
NoCompressionCodec.Factory.INSTANCE)) {
Assertions.assertEquals(1, reader.getRecordBlocks().size());
Exception exception = Assert.assertThrows(IllegalArgumentException.class,
reader::loadNextBatch);
Assertions.assertEquals("Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
exception.getMessage());
}
dictionaryVector.close();
}

@Test
public void testArrowStreamZstdRoundTripWithDictionary() throws Exception {
VarCharVector dictionaryVector = (VarCharVector)
FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector("f1_stream", allocator, null);
Dictionary dictionary = createDictionary(dictionaryVector);
DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();
provider.put(dictionary);

createAndWriteArrowStream(provider, CompressionUtil.CodecType.ZSTD);

// with compression
try (ArrowStreamReader reader =
new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
CommonsCompressionFactory.INSTANCE)) {
Assertions.assertTrue(reader.loadNextBatch());
Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assertions.assertFalse(reader.loadNextBatch());
}
// without compression
try (ArrowStreamReader reader =
new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
NoCompressionCodec.Factory.INSTANCE)) {
Exception exception = Assert.assertThrows(IllegalArgumentException.class,
reader::loadNextBatch);
Assertions.assertEquals("Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
exception.getMessage());
}
dictionaryVector.close();
}

public static void setVector(VarCharVector vector, byte[]... values) {
final int length = values.length;
vector.allocateNewSafe();
for (int i = 0; i < length; i++) {
if (values[i] != null) {
vector.set(i, values[i]);
}
}
vector.setValueCount(length);
}

}
ArrowReader.java
@@ -251,7 +251,7 @@ private void load(ArrowDictionaryBatch dictionaryBatch, FieldVector vector) {
VectorSchemaRoot root = new VectorSchemaRoot(
Collections.singletonList(vector.getField()),
Collections.singletonList(vector), 0);
VectorLoader loader = new VectorLoader(root, this.compressionFactory);
try {
loader.load(dictionaryBatch.getDictionary());
} finally {
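
The reader-side fix is this single line: the one-argument `VectorLoader` constructor presumably falls back to the no-compression codec factory, so dictionary batches read from a compressed file failed to decode even when the `ArrowReader` was given a `CommonsCompressionFactory`. A before/after sketch of the changed line:

```java
// Before: the one-argument constructor (presumably defaulting to the
// no-compression codec factory) could not decode compressed buffers.
VectorLoader loader = new VectorLoader(root);

// After: the factory supplied to the ArrowReader is forwarded, so
// dictionary batches are decompressed like ordinary record batches.
VectorLoader loader = new VectorLoader(root, this.compressionFactory);
```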
ArrowWriter.java
@@ -61,9 +61,14 @@ public abstract class ArrowWriter implements AutoCloseable {
private final DictionaryProvider dictionaryProvider;
private final Set<Long> dictionaryIdsUsed = new HashSet<>();

private final CompressionCodec.Factory compressionFactory;
private final CompressionUtil.CodecType codecType;
private final Optional<Integer> compressionLevel;
private boolean started = false;
private boolean ended = false;

private final CompressionCodec codec;

protected IpcOption option;

protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
@@ -89,16 +94,19 @@ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, IpcOption option,
protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, IpcOption option,
CompressionCodec.Factory compressionFactory, CompressionUtil.CodecType codecType,
Optional<Integer> compressionLevel) {
this.out = new WriteChannel(out);
this.option = option;
this.dictionaryProvider = provider;

this.compressionFactory = compressionFactory;
this.codecType = codecType;
this.compressionLevel = compressionLevel;
this.codec = this.compressionLevel.isPresent() ?
this.compressionFactory.createCodec(this.codecType, this.compressionLevel.get()) :
this.compressionFactory.createCodec(this.codecType);
this.unloader = new VectorUnloader(root, /*includeNullCount*/ true, codec,
/*alignBuffers*/ true);

List<Field> fields = new ArrayList<>(root.getSchema().getFields().size());

MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option.metadataVersion);
@@ -133,7 +141,8 @@ protected void writeDictionaryBatch(Dictionary dictionary) throws IOException {
Collections.singletonList(vector.getField()),
Collections.singletonList(vector),
count);
VectorUnloader unloader = new VectorUnloader(dictRoot, /*includeNullCount*/ true, this.codec,
/*alignBuffers*/ true);
ArrowRecordBatch batch = unloader.getRecordBatch();
ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, false);
try {
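
On the writer side, the codec built from `codecType` and `compressionLevel` in the constructor is now kept in a field and reused here, where `writeDictionaryBatch` previously created a bare `VectorUnloader` that wrote dictionary batches uncompressed. A before/after sketch of the changed line:

```java
// Before: dictionary batches were unloaded without a codec, so they
// were written uncompressed even when record batches were compressed.
VectorUnloader unloader = new VectorUnloader(dictRoot);

// After: the codec created once in the constructor is reused, giving
// dictionary batches the same compression as the record batches.
VectorUnloader unloader = new VectorUnloader(dictRoot, /*includeNullCount*/ true, this.codec,
    /*alignBuffers*/ true);
```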
