Skip to content

Commit

Permalink
ARROW-9377: [Java] Support unsigned dictionary indices
Browse files Browse the repository at this point in the history
  • Loading branch information
liyafan82 committed Jul 22, 2020
1 parent c09a82a commit d7dd662
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 2 deletions.
Expand Up @@ -315,7 +315,7 @@ public void setUnsafeWithPossibleTruncate(int index, long value) {

/**
 * Returns the element at {@code index} widened to a {@code long} and interpreted as unsigned.
 *
 * <p>The value read through {@code get(index)} is masked with {@code 0xffL}, so only its low
 * 8 bits survive. A stored bit pattern of {@code 0xff} is therefore reported as 255 rather than
 * as a sign-extended negative number.
 *
 * @param index position of the element to read
 * @return the element's low 8 bits as a non-negative {@code long} in the range 0..255
 */
@Override
public long getValueAsLong(int index) {
  // The earlier unmasked "return this.get(index);" is dropped: it made this statement unreachable.
  return this.get(index) & 0xffL;
}


Expand Down
Expand Up @@ -286,7 +286,7 @@ public void setUnsafeWithPossibleTruncate(int index, long value) {

/**
 * Returns the element at {@code index} widened to a {@code long} and interpreted as unsigned.
 *
 * <p>The value read through {@code get(index)} is masked with {@code 0xffffffffL}, so only its
 * low 32 bits survive. A stored bit pattern of {@code 0xffffffff} is therefore reported as
 * 4294967295 rather than as a sign-extended negative number.
 *
 * @param index position of the element to read
 * @return the element's low 32 bits as a non-negative {@code long}
 */
@Override
public long getValueAsLong(int index) {
  // The earlier unmasked "return this.get(index);" is dropped: it made this statement unreachable.
  return this.get(index) & 0xffffffffL;
}

private class TransferImpl implements TransferPair {
Expand Down
Expand Up @@ -20,6 +20,7 @@
import static org.apache.arrow.vector.TestUtils.newVarBinaryVector;
import static org.apache.arrow.vector.TestUtils.newVarCharVector;
import static org.apache.arrow.vector.testing.ValueVectorDataPopulator.setVector;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
Expand All @@ -29,6 +30,7 @@
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.function.ToIntBiFunction;

import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
Expand Down Expand Up @@ -878,6 +880,103 @@ public void testEncodeStructSubFieldWithCertainColumns() {
}
}

/**
 * Encodes a fixed VarChar vector ("1", "3", "5", "7", "9") with {@code dictionary}, verifies the
 * produced indices, then decodes the indices and verifies that the original strings come back.
 *
 * <p>The expected indices are 1, 3, 5, 7, 9; the callers visible in this file pass a dictionary
 * whose entry {@code i} is the string form of {@code i}.
 *
 * @param dictionary the dictionary used for both encoding and decoding
 * @param valGetter reads the index stored at the given position of the encoded vector as an int
 */
private void testDictionary(Dictionary dictionary, ToIntBiFunction<ValueVector, Integer> valGetter) {
  try (VarCharVector vector = new VarCharVector("vector", allocator)) {
    setVector(vector, "1", "3", "5", "7", "9");
    try (ValueVector encodedVector = DictionaryEncoder.encode(vector, dictionary)) {

      // verify encoded result; expected value goes first so failure messages read correctly
      assertEquals(vector.getValueCount(), encodedVector.getValueCount());
      assertEquals(1, valGetter.applyAsInt(encodedVector, 0));
      assertEquals(3, valGetter.applyAsInt(encodedVector, 1));
      assertEquals(5, valGetter.applyAsInt(encodedVector, 2));
      assertEquals(7, valGetter.applyAsInt(encodedVector, 3));
      assertEquals(9, valGetter.applyAsInt(encodedVector, 4));

      try (ValueVector decodedVector = DictionaryEncoder.decode(encodedVector, dictionary)) {
        // verify decoded result; cast once after the type check instead of on every access
        assertTrue(decodedVector instanceof VarCharVector);
        VarCharVector decodedStrings = (VarCharVector) decodedVector;
        assertEquals(vector.getValueCount(), decodedStrings.getValueCount());
        assertArrayEquals("1".getBytes(), decodedStrings.get(0));
        assertArrayEquals("3".getBytes(), decodedStrings.get(1));
        assertArrayEquals("5".getBytes(), decodedStrings.get(2));
        assertArrayEquals("7".getBytes(), decodedStrings.get(3));
        assertArrayEquals("9".getBytes(), decodedStrings.get(4));
      }
    }
  }
}

/**
 * Runs the shared encode/decode check with a dictionary whose index type is
 * {@code ArrowType.Int(8, false)}; indices are read back through {@link UInt1Vector}.
 */
@Test
public void testDictionaryUInt1() {
  try (VarCharVector dictValues = new VarCharVector("dict vector", allocator)) {
    setVector(dictValues, "0", "1", "2", "3", "4", "5", "6", "7", "8", "9");

    final DictionaryEncoding encoding =
        new DictionaryEncoding(/*id=*/10L, /*ordered=*/false, /*indexType=*/ new ArrowType.Int(8, false));
    final Dictionary uint1Dictionary = new Dictionary(dictValues, encoding);
    final ToIntBiFunction<ValueVector, Integer> indexReader = (vec, pos) -> ((UInt1Vector) vec).get(pos);

    testDictionary(uint1Dictionary, indexReader);
  }
}

/**
 * Runs the shared encode/decode check with a dictionary whose index type is
 * {@code ArrowType.Int(16, false)}; indices are read back through {@link UInt2Vector}.
 */
@Test
public void testDictionaryUInt2() {
  try (VarCharVector dictValues = new VarCharVector("dict vector", allocator)) {
    setVector(dictValues, "0", "1", "2", "3", "4", "5", "6", "7", "8", "9");

    final DictionaryEncoding encoding =
        new DictionaryEncoding(/*id=*/20L, /*ordered=*/false, /*indexType=*/ new ArrowType.Int(16, false));
    final Dictionary uint2Dictionary = new Dictionary(dictValues, encoding);
    final ToIntBiFunction<ValueVector, Integer> indexReader = (vec, pos) -> ((UInt2Vector) vec).get(pos);

    testDictionary(uint2Dictionary, indexReader);
  }
}

/**
 * Runs the shared encode/decode check with a dictionary whose index type is
 * {@code ArrowType.Int(32, false)}; indices are read back through {@link UInt4Vector}.
 */
@Test
public void testDictionaryUInt4() {
  try (VarCharVector dictValues = new VarCharVector("dict vector", allocator)) {
    setVector(dictValues, "0", "1", "2", "3", "4", "5", "6", "7", "8", "9");

    final DictionaryEncoding encoding =
        new DictionaryEncoding(/*id=*/30L, /*ordered=*/false, /*indexType=*/ new ArrowType.Int(32, false));
    final Dictionary uint4Dictionary = new Dictionary(dictValues, encoding);
    final ToIntBiFunction<ValueVector, Integer> indexReader = (vec, pos) -> ((UInt4Vector) vec).get(pos);

    testDictionary(uint4Dictionary, indexReader);
  }
}

/**
 * Runs the shared encode/decode check with a dictionary whose index type is
 * {@code ArrowType.Int(64, false)}; indices are read back through {@link UInt8Vector}
 * and narrowed to {@code int} for comparison.
 */
@Test
public void testDictionaryUInt8() {
  try (VarCharVector dictValues = new VarCharVector("dict vector", allocator)) {
    setVector(dictValues, "0", "1", "2", "3", "4", "5", "6", "7", "8", "9");

    final DictionaryEncoding encoding =
        new DictionaryEncoding(/*id=*/40L, /*ordered=*/false, /*indexType=*/ new ArrowType.Int(64, false));
    final Dictionary uint8Dictionary = new Dictionary(dictValues, encoding);
    final ToIntBiFunction<ValueVector, Integer> indexReader =
        (vec, pos) -> (int) ((UInt8Vector) vec).get(pos);

    testDictionary(uint8Dictionary, indexReader);
  }
}

/**
 * Encodes the string "255" against a 256-entry dictionary with an unsigned 8-bit index type and
 * checks that the index reads back as 255 and decodes to the original string.
 */
@Test
public void testDictionaryUIntOverflow() {
  // 256 entries: the largest index (255) fits in UInt1 but not in a signed TinyInt.
  final int entryCount = 256;
  final int dataBytes = entryCount * 3;

  try (VarCharVector dictValues = new VarCharVector("dict vector", allocator)) {
    dictValues.allocateNew(dataBytes, entryCount);
    for (int entry = 0; entry < entryCount; entry++) {
      dictValues.set(entry, String.valueOf(entry).getBytes());
    }
    dictValues.setValueCount(entryCount);

    final DictionaryEncoding encoding =
        new DictionaryEncoding(/*id=*/10L, /*ordered=*/false, /*indexType=*/ new ArrowType.Int(8, false));
    final Dictionary uint1Dictionary = new Dictionary(dictValues, encoding);

    try (VarCharVector plain = new VarCharVector("vector", allocator)) {
      setVector(plain, "255");

      try (UInt1Vector encoded = (UInt1Vector) DictionaryEncoder.encode(plain, uint1Dictionary)) {
        // the single encoded index must come back as unsigned 255
        assertEquals(1, encoded.getValueCount());
        assertEquals(255, encoded.getValueAsLong(0));

        try (VarCharVector decoded = (VarCharVector) DictionaryEncoder.decode(encoded, uint1Dictionary)) {
          assertEquals(1, decoded.getValueCount());
          assertArrayEquals("255".getBytes(), decoded.get(0));
        }
      }
    }
  }
}

private int[] convertListToIntArray(JsonStringArrayList list) {
int[] values = new int[list.size()];
for (int i = 0; i < list.size(); i++) {
Expand Down
Expand Up @@ -2977,4 +2977,47 @@ public void testEmptyBufBehavior() {
assertEquals(0, vector.getOffsetBuffer().capacity());
}
}

/**
 * Writes {@code 0xff} into a {@link UInt1Vector} through both truncating setters and checks that
 * {@code getValueAsLong} reports 255 for each slot.
 */
@Test
public void testSetGetUInt1() {
  final long allBitsSet = 0xffL;
  final int slots = 2;

  try (UInt1Vector uint1 = new UInt1Vector("vector", allocator)) {
    uint1.allocateNew(slots);

    uint1.setWithPossibleTruncate(0, allBitsSet);
    uint1.setUnsafeWithPossibleTruncate(1, allBitsSet);
    uint1.setValueCount(slots);

    for (int slot = 0; slot < slots; slot++) {
      assertEquals(255L, uint1.getValueAsLong(slot));
    }
  }
}

/**
 * Writes {@code 0xffff} into a {@link UInt2Vector} through both truncating setters and checks
 * that {@code getValueAsLong} reports 65535 for each slot.
 */
@Test
public void testSetGetUInt2() {
  final long allBitsSet = 0xffffL;
  final int slots = 2;

  try (UInt2Vector uint2 = new UInt2Vector("vector", allocator)) {
    uint2.allocateNew(slots);

    uint2.setWithPossibleTruncate(0, allBitsSet);
    uint2.setUnsafeWithPossibleTruncate(1, allBitsSet);
    uint2.setValueCount(slots);

    for (int slot = 0; slot < slots; slot++) {
      assertEquals(65535L, uint2.getValueAsLong(slot));
    }
  }
}

/**
 * Writes {@code 0xffffffff} into a {@link UInt4Vector} through both truncating setters and checks
 * that {@code getValueAsLong} reports 2^32 - 1 for each slot.
 */
@Test
public void testSetGetUInt4() {
  final long allBitsSet = 0xffffffffL;
  final long maxUnsigned32 = (1L << 32) - 1L;
  final int slots = 2;

  try (UInt4Vector uint4 = new UInt4Vector("vector", allocator)) {
    uint4.allocateNew(slots);

    uint4.setWithPossibleTruncate(0, allBitsSet);
    uint4.setUnsafeWithPossibleTruncate(1, allBitsSet);
    uint4.setValueCount(slots);

    for (int slot = 0; slot < slots; slot++) {
      assertEquals(maxUnsigned32, uint4.getValueAsLong(slot));
    }
  }
}
}
@@ -0,0 +1,207 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.vector.ipc;

import static org.apache.arrow.vector.testing.ValueVectorDataPopulator.setVector;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.function.ToIntBiFunction;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.UInt1Vector;
import org.apache.arrow.vector.UInt2Vector;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.UInt8Vector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Round-trip tests for dictionary-encoded vectors whose indices are unsigned integers.
 *
 * <p>The class is parameterized on a single flag: each test runs once through the stream
 * writer/reader pair and once through the file writer/reader pair.
 */
@RunWith(Parameterized.class)
public class TestUIntDictionaryRoundTrip {

  /** Index values every test writes and expects to read back, in order. */
  private static final int[] EXPECTED_INDICES = {1, 3, 5, 7, 9};

  /** {@code true} selects the stream format, {@code false} the file format. */
  private final boolean streamMode;

  private BufferAllocator allocator;

  private VarCharVector dictionaryVector;

  private DictionaryProvider.MapDictionaryProvider dictionaryProvider;

  public TestUIntDictionaryRoundTrip(boolean streamMode) {
    this.streamMode = streamMode;
  }

  /** Supplies the two parameter sets: stream mode on, then stream mode off. */
  @Parameterized.Parameters(name = "stream mode = {0}")
  public static Collection<Object[]> getRepeat() {
    return Arrays.asList(
        new Object[]{true},
        new Object[]{false}
    );
  }

  /** Creates the allocator, the ten-entry dictionary vector ("0".."9") and an empty provider. */
  @Before
  public void init() {
    allocator = new RootAllocator(Long.MAX_VALUE);
    dictionaryVector = new VarCharVector("dict vector", allocator);
    setVector(dictionaryVector, "0", "1", "2", "3", "4", "5", "6", "7", "8", "9");

    dictionaryProvider = new DictionaryProvider.MapDictionaryProvider();
  }

  /** Releases the dictionary vector first, then the allocator that backs it. */
  @After
  public void terminate() throws Exception {
    dictionaryVector.close();
    allocator.close();
  }

  /** Builds the writer matching the current mode, targeting {@code sink}. */
  private ArrowWriter openWriter(VectorSchemaRoot root, ByteArrayOutputStream sink) {
    if (streamMode) {
      return new ArrowStreamWriter(root, dictionaryProvider, sink);
    }
    return new ArrowFileWriter(root, dictionaryProvider, Channels.newChannel(sink));
  }

  /** Builds the reader matching the current mode over the serialized {@code data}. */
  private ArrowReader openReader(byte[] data) {
    if (streamMode) {
      return new ArrowStreamReader(new ByteArrayInputStream(data), allocator);
    }
    return new ArrowFileReader(
        new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(data)), allocator);
  }

  /**
   * Serializes {@code encodedVector} as a single-column, single-batch payload.
   *
   * @param encodedVector the dictionary-encoded index vector to write
   * @return the bytes produced by the writer
   */
  private byte[] writeData(FieldVector encodedVector) throws IOException {
    final ByteArrayOutputStream sink = new ByteArrayOutputStream();
    final VectorSchemaRoot root =
        new VectorSchemaRoot(
            Arrays.asList(encodedVector.getField()),
            Arrays.asList(encodedVector),
            encodedVector.getValueCount());

    try (ArrowWriter writer = openWriter(root, sink)) {
      writer.start();
      writer.writeBatch();
      writer.end();

      return sink.toByteArray();
    }
  }

  /**
   * Deserializes {@code data} and checks the schema, the five index values and the dictionary.
   *
   * @param data the serialized payload
   * @param expectedField the field the single column must equal
   * @param valGetter reads the index at a given position of the loaded vector as an int
   * @param dictionaryID id under which the dictionary must be found
   */
  private void readData(
      byte[] data,
      Field expectedField,
      ToIntBiFunction<ValueVector, Integer> valGetter,
      long dictionaryID) throws IOException {
    try (ArrowReader reader = openReader(data)) {

      // schema: exactly one field, equal to the one that was written
      final Schema loadedSchema = reader.getVectorSchemaRoot().getSchema();
      assertEquals(1, loadedSchema.getFields().size());
      assertEquals(expectedField, loadedSchema.getFields().get(0));

      // data: one batch containing one vector with the expected indices
      assertTrue(reader.loadNextBatch());
      final VectorSchemaRoot loadedRoot = reader.getVectorSchemaRoot();

      assertEquals(1, loadedRoot.getFieldVectors().size());
      final ValueVector indices = loadedRoot.getVector(0);
      assertEquals(EXPECTED_INDICES.length, indices.getValueCount());

      for (int pos = 0; pos < EXPECTED_INDICES.length; pos++) {
        assertEquals(EXPECTED_INDICES[pos], valGetter.applyAsInt(indices, pos));
      }

      // dictionary: a single VarChar dictionary holding "0".."9"
      final Map<Long, Dictionary> loadedDictionaries = reader.getDictionaryVectors();
      assertEquals(1, loadedDictionaries.size());
      final Dictionary loadedDictionary = loadedDictionaries.get(dictionaryID);
      assertNotNull(loadedDictionary);

      assertTrue(loadedDictionary.getVector() instanceof VarCharVector);
      final VarCharVector loadedValues = (VarCharVector) loadedDictionary.getVector();
      assertEquals(10, loadedValues.getValueCount());
      for (int entry = 0; entry < loadedValues.getValueCount(); entry++) {
        assertArrayEquals(String.valueOf(entry).getBytes(), loadedValues.get(entry));
      }
    }
  }

  /**
   * Registers a dictionary with an unsigned index type of the given width and creates an empty
   * vector for a field encoded with it. The bit width doubles as the dictionary id.
   *
   * @param bitWidth width in bits of the unsigned index type
   * @return a new, empty vector created from the encoded field
   */
  private ValueVector createEncodedVector(int bitWidth) {
    final DictionaryEncoding encoding =
        new DictionaryEncoding(bitWidth, false, new ArrowType.Int(bitWidth, false));
    dictionaryProvider.put(new Dictionary(dictionaryVector, encoding));

    final FieldType encodedType = new FieldType(true, encoding.getIndexType(), encoding, null);
    return new Field("encoded", encodedType, null).createVector(allocator);
  }

  @Test
  public void testUInt1RoundTrip() throws IOException {
    try (UInt1Vector indices = (UInt1Vector) createEncodedVector(8)) {
      setVector(indices, (byte) 1, (byte) 3, (byte) 5, (byte) 7, (byte) 9);
      final byte[] serialized = writeData(indices);
      readData(serialized, indices.getField(), (vec, pos) -> ((UInt1Vector) vec).get(pos), 8L);
    }
  }

  @Test
  public void testUInt2RoundTrip() throws IOException {
    try (UInt2Vector indices = (UInt2Vector) createEncodedVector(16)) {
      setVector(indices, (char) 1, (char) 3, (char) 5, (char) 7, (char) 9);
      final byte[] serialized = writeData(indices);
      readData(serialized, indices.getField(), (vec, pos) -> ((UInt2Vector) vec).get(pos), 16L);
    }
  }

  @Test
  public void testUInt4RoundTrip() throws IOException {
    try (UInt4Vector indices = (UInt4Vector) createEncodedVector(32)) {
      setVector(indices, 1, 3, 5, 7, 9);
      final byte[] serialized = writeData(indices);
      readData(serialized, indices.getField(), (vec, pos) -> ((UInt4Vector) vec).get(pos), 32L);
    }
  }

  @Test
  public void testUInt8RoundTrip() throws IOException {
    try (UInt8Vector indices = (UInt8Vector) createEncodedVector(64)) {
      setVector(indices, 1L, 3L, 5L, 7L, 9L);
      final byte[] serialized = writeData(indices);
      readData(serialized, indices.getField(), (vec, pos) -> (int) ((UInt8Vector) vec).get(pos), 64L);
    }
  }
}

0 comments on commit d7dd662

Please sign in to comment.