Skip to content

Commit

Permalink
GH-15203: [Java] Implement writing compressed files (#15223)
Browse files Browse the repository at this point in the history
* Closes: #15203

Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
  • Loading branch information
lidavidm committed Jan 19, 2023
1 parent e4019ad commit 4c698fb
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,21 @@

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.stream.Stream;

import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
Expand All @@ -32,80 +42,80 @@
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.compression.NoCompressionCodec;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

/**
* Test cases for {@link CompressionCodec}s.
*/
@RunWith(Parameterized.class)
public class TestCompressionCodec {

private final CompressionCodec codec;

class TestCompressionCodec {
private BufferAllocator allocator;

private final int vectorLength;

@Before
public void init() {
@BeforeEach
void init() {
  // Give every test its own allocator; terminate() closing it will surface any buffer leaks.
  this.allocator = new RootAllocator(Integer.MAX_VALUE);
}

@After
public void terminate() {
@AfterEach
void terminate() {
  // Fails the test if any buffer allocated during the test is still outstanding.
  this.allocator.close();
}

public TestCompressionCodec(CompressionUtil.CodecType type, int vectorLength, CompressionCodec codec) {
this.codec = codec;
this.vectorLength = vectorLength;
}

@Parameterized.Parameters(name = "codec = {0}, length = {1}")
public static Collection<Object[]> getCodecs() {
List<Object[]> params = new ArrayList<>();
static Collection<Arguments> codecs() {
List<Arguments> params = new ArrayList<>();

int[] lengths = new int[] {10, 100, 1000};
for (int len : lengths) {
CompressionCodec dumbCodec = NoCompressionCodec.INSTANCE;
params.add(new Object[]{dumbCodec.getCodecType(), len, dumbCodec});
params.add(Arguments.arguments(len, dumbCodec));

CompressionCodec lz4Codec = new Lz4CompressionCodec();
params.add(new Object[]{lz4Codec.getCodecType(), len, lz4Codec});
params.add(Arguments.arguments(len, lz4Codec));

CompressionCodec zstdCodec = new ZstdCompressionCodec();
params.add(new Object[]{zstdCodec.getCodecType(), len, zstdCodec});

params.add(Arguments.arguments(len, zstdCodec));
}
return params;
}

private List<ArrowBuf> compressBuffers(List<ArrowBuf> inputBuffers) {
/** Compresses each input buffer in order with the given codec and returns the results. */
private List<ArrowBuf> compressBuffers(CompressionCodec codec, List<ArrowBuf> inputBuffers) {
  final List<ArrowBuf> compressed = new ArrayList<>(inputBuffers.size());
  for (final ArrowBuf input : inputBuffers) {
    compressed.add(codec.compress(allocator, input));
  }
  return compressed;
}

private List<ArrowBuf> deCompressBuffers(List<ArrowBuf> inputBuffers) {
/** Decompresses each input buffer in order with the given codec and returns the results. */
private List<ArrowBuf> deCompressBuffers(CompressionCodec codec, List<ArrowBuf> inputBuffers) {
  final List<ArrowBuf> decompressed = new ArrayList<>(inputBuffers.size());
  for (final ArrowBuf input : inputBuffers) {
    decompressed.add(codec.decompress(allocator, input));
  }
  return decompressed;
}

@Test
public void testCompressFixedWidthBuffers() throws Exception {
@ParameterizedTest
@MethodSource("codecs")
void testCompressFixedWidthBuffers(int vectorLength, CompressionCodec codec) throws Exception {
// prepare vector to compress
IntVector origVec = new IntVector("vec", allocator);
origVec.allocateNew(vectorLength);
Expand All @@ -121,8 +131,8 @@ public void testCompressFixedWidthBuffers() throws Exception {

// compress & decompress
List<ArrowBuf> origBuffers = origVec.getFieldBuffers();
List<ArrowBuf> compressedBuffers = compressBuffers(origBuffers);
List<ArrowBuf> decompressedBuffers = deCompressBuffers(compressedBuffers);
List<ArrowBuf> compressedBuffers = compressBuffers(codec, origBuffers);
List<ArrowBuf> decompressedBuffers = deCompressBuffers(codec, compressedBuffers);

assertEquals(2, decompressedBuffers.size());

Expand All @@ -144,8 +154,9 @@ public void testCompressFixedWidthBuffers() throws Exception {
AutoCloseables.close(decompressedBuffers);
}

@Test
public void testCompressVariableWidthBuffers() throws Exception {
@ParameterizedTest
@MethodSource("codecs")
void testCompressVariableWidthBuffers(int vectorLength, CompressionCodec codec) throws Exception {
// prepare vector to compress
VarCharVector origVec = new VarCharVector("vec", allocator);
origVec.allocateNew();
Expand All @@ -161,8 +172,8 @@ public void testCompressVariableWidthBuffers() throws Exception {

// compress & decompress
List<ArrowBuf> origBuffers = origVec.getFieldBuffers();
List<ArrowBuf> compressedBuffers = compressBuffers(origBuffers);
List<ArrowBuf> decompressedBuffers = deCompressBuffers(compressedBuffers);
List<ArrowBuf> compressedBuffers = compressBuffers(codec, origBuffers);
List<ArrowBuf> decompressedBuffers = deCompressBuffers(codec, compressedBuffers);

assertEquals(3, decompressedBuffers.size());

Expand All @@ -184,8 +195,9 @@ public void testCompressVariableWidthBuffers() throws Exception {
AutoCloseables.close(decompressedBuffers);
}

@Test
public void testEmptyBuffer() throws Exception {
@ParameterizedTest
@MethodSource("codecs")
void testEmptyBuffer(int vectorLength, CompressionCodec codec) throws Exception {
final VarBinaryVector origVec = new VarBinaryVector("vec", allocator);

origVec.allocateNew(vectorLength);
Expand All @@ -194,8 +206,8 @@ public void testEmptyBuffer() throws Exception {
origVec.setValueCount(vectorLength);

final List<ArrowBuf> origBuffers = origVec.getFieldBuffers();
final List<ArrowBuf> compressedBuffers = compressBuffers(origBuffers);
final List<ArrowBuf> decompressedBuffers = deCompressBuffers(compressedBuffers);
final List<ArrowBuf> compressedBuffers = compressBuffers(codec, origBuffers);
final List<ArrowBuf> decompressedBuffers = deCompressBuffers(codec, compressedBuffers);

// orchestrate new vector
VarBinaryVector newVec = new VarBinaryVector("new vec", allocator);
Expand All @@ -210,4 +222,117 @@ public void testEmptyBuffer() throws Exception {
newVec.close();
AutoCloseables.close(decompressedBuffers);
}

/** Every codec type, including NO_COMPRESSION (the factory is chosen per type in withRoot). */
private static Stream<CompressionUtil.CodecType> codecTypes() {
  return Stream.of(CompressionUtil.CodecType.values());
}

/** Round-trips one batch through the streaming IPC format with the given compression codec. */
@ParameterizedTest
@MethodSource("codecTypes")
void testReadWriteStream(CompressionUtil.CodecType codec) throws Exception {
  withRoot(codec, (factory, root) -> {
    final ByteArrayOutputStream sink = new ByteArrayOutputStream();
    // Write a single (possibly compressed) batch to an in-memory stream.
    try (final ArrowStreamWriter streamWriter = new ArrowStreamWriter(
        root, new DictionaryProvider.MapDictionaryProvider(),
        Channels.newChannel(sink),
        IpcOption.DEFAULT, factory, codec)) {
      streamWriter.start();
      streamWriter.writeBatch();
      streamWriter.end();
    } catch (IOException e) {
      throw new RuntimeException(e);
    }

    // Read it back: exactly one batch, equal to the source root after decompression.
    try (ArrowStreamReader streamReader = new ArrowStreamReader(
        new ByteArrayReadableSeekableByteChannel(sink.toByteArray()), allocator, factory)) {
      assertTrue(streamReader.loadNextBatch());
      assertTrue(root.equals(streamReader.getVectorSchemaRoot()));
      assertFalse(streamReader.loadNextBatch());
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  });
}

/** Round-trips one batch through the file IPC format with the given compression codec. */
@ParameterizedTest
@MethodSource("codecTypes")
void testReadWriteFile(CompressionUtil.CodecType codec) throws Exception {
  withRoot(codec, (factory, root) -> {
    final ByteArrayOutputStream sink = new ByteArrayOutputStream();
    // Write a single (possibly compressed) batch as an in-memory Arrow file.
    try (final ArrowFileWriter fileWriter = new ArrowFileWriter(
        root, new DictionaryProvider.MapDictionaryProvider(),
        Channels.newChannel(sink),
        new HashMap<>(), IpcOption.DEFAULT, factory, codec)) {
      fileWriter.start();
      fileWriter.writeBatch();
      fileWriter.end();
    } catch (IOException e) {
      throw new RuntimeException(e);
    }

    // Read it back: exactly one batch, equal to the source root after decompression.
    try (ArrowFileReader fileReader = new ArrowFileReader(
        new ByteArrayReadableSeekableByteChannel(sink.toByteArray()), allocator, factory)) {
      assertTrue(fileReader.loadNextBatch());
      assertTrue(root.equals(fileReader.getVectorSchemaRoot()));
      assertFalse(fileReader.loadNextBatch());
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  });
}

/**
 * Unloading a vector should not free source buffers: the VectorSchemaRoot must
 * remain usable even after the record batch produced by the unloader is closed.
 */
@ParameterizedTest
@MethodSource("codecTypes")
void testUnloadCompressed(CompressionUtil.CodecType codec) {
  withRoot(codec, (factory, root) -> {
    // Sanity check: all source buffers are live before unloading.
    assertAllBuffersReferenced(root);

    // Unload (compressing buffers as needed) and immediately discard the batch.
    final VectorUnloader unloader = new VectorUnloader(
        root, /*includeNullCount*/ true, factory.createCodec(codec), /*alignBuffers*/ true);
    unloader.getRecordBatch().close();

    // Closing the batch must not have released the root's own buffers.
    assertAllBuffersReferenced(root);
  });
}

/** Asserts every buffer of every vector in the root still has a nonzero refcount. */
private static void assertAllBuffersReferenced(VectorSchemaRoot root) {
  root.getFieldVectors().forEach((vector) -> {
    Arrays.stream(vector.getBuffers(/*clear*/ false)).forEach((buf) -> {
      assertNotEquals(0, buf.getReferenceManager().getRefCount());
    });
  });
}

/**
 * Runs {@code testBody} against a 4-row root with one int column (incompressible-looking
 * values) and one string column (highly compressible values), selecting the codec factory
 * matching the requested codec type.
 */
void withRoot(CompressionUtil.CodecType codec, BiConsumer<CompressionCodec.Factory, VectorSchemaRoot> testBody) {
  final Schema schema = new Schema(Arrays.asList(
      Field.nullable("ints", new ArrowType.Int(32, true)),
      Field.nullable("strings", ArrowType.Utf8.INSTANCE)));
  // NO_COMPRESSION has its own no-op factory; real codecs come from commons-compress.
  CompressionCodec.Factory factory = codec == CompressionUtil.CodecType.NO_COMPRESSION ?
      NoCompressionCodec.Factory.INSTANCE : CommonsCompressionFactory.INSTANCE;
  try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
    final IntVector ints = (IntVector) root.getVector(0);
    final VarCharVector strings = (VarCharVector) root.getVector(1);
    // Random-looking values: doesn't get compressed.
    ints.setSafe(0, 0x4a3e);
    ints.setSafe(1, 0x8aba);
    ints.setSafe(2, 0x4362);
    ints.setSafe(3, 0x383f);
    // 512 identical space bytes: gets compressed. Built explicitly rather than by
    // repeated string doubling so the intended length is unambiguous.
    byte[] compressibleData = new byte[512];
    Arrays.fill(compressibleData, (byte) ' ');
    strings.setSafe(0, compressibleData);
    strings.setSafe(1, compressibleData);
    strings.setSafe(2, compressibleData);
    strings.setSafe(3, compressibleData);
    root.setRowCount(4);

    testBody.accept(factory, root);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public VectorUnloader(
VectorSchemaRoot root, boolean includeNullCount, CompressionCodec codec, boolean alignBuffers) {
this.root = root;
this.includeNullCount = includeNullCount;
this.codec = codec;
this.codec = codec == null ? NoCompressionCodec.INSTANCE : codec;
this.alignBuffers = alignBuffers;
}

Expand All @@ -83,8 +83,10 @@ public ArrowRecordBatch getRecordBatch() {
for (FieldVector vector : root.getFieldVectors()) {
appendNodes(vector, nodes, buffers);
}
// Do NOT retain buffers in ArrowRecordBatch constructor since we have already retained them.
return new ArrowRecordBatch(
root.getRowCount(), nodes, buffers, CompressionUtil.createBodyCompression(codec), alignBuffers);
root.getRowCount(), nodes, buffers, CompressionUtil.createBodyCompression(codec), alignBuffers,
/*retainBuffers*/ false);
}

private void appendNodes(FieldVector vector, List<ArrowFieldNode> nodes, List<ArrowBuf> buffers) {
Expand All @@ -97,6 +99,11 @@ private void appendNodes(FieldVector vector, List<ArrowFieldNode> nodes, List<Ar
vector.getField(), vector.getClass().getSimpleName(), fieldBuffers));
}
for (ArrowBuf buf : fieldBuffers) {
// If the codec is NoCompressionCodec, then it will return the input buffer unchanged. In that case,
// we need to retain it for ArrowRecordBatch. Otherwise, it will return a new buffer, and also close
// the input buffer. In that case, we need to retain the input buffer still to avoid modifying
// the source VectorSchemaRoot.
buf.getReferenceManager().retain();
buffers.add(codec.compress(vector.getAllocator(), buf));
}
for (FieldVector child : vector.getChildrenFromFields()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public ArrowBuf compress(BufferAllocator allocator, ArrowBuf uncompressedBuffer)
if (compressedLength > uncompressedLength) {
// compressed buffer is larger, send the raw buffer
compressedBuffer.close();
// XXX: this makes a copy of uncompressedBuffer
compressedBuffer = CompressionUtil.packageRawBuffer(allocator, uncompressedBuffer);
} else {
writeUncompressedLength(compressedBuffer, uncompressedLength);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowBlock;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
Expand Down Expand Up @@ -69,6 +71,13 @@ public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, Writa
this.metaData = metaData;
}

/**
 * Creates a file writer that writes record batches with the given body compression.
 *
 * @param root the root whose schema and data will be written
 * @param provider dictionary provider for any dictionary-encoded vectors
 * @param out channel the Arrow file is written to
 * @param metaData custom key/value metadata to store in the file footer schema
 * @param option IPC write options
 * @param compressionFactory factory used to instantiate the codec for {@code codecType}
 * @param codecType compression codec to apply to record batch body buffers
 */
public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out,
Map<String, String> metaData, IpcOption option, CompressionCodec.Factory compressionFactory,
CompressionUtil.CodecType codecType) {
super(root, provider, out, option, compressionFactory, codecType);
this.metaData = metaData;
}

@Override
protected void startInternal(WriteChannel out) throws IOException {
ArrowMagic.writeMagic(out, true);
Expand Down
Loading

0 comments on commit 4c698fb

Please sign in to comment.