-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ARROW-6013: [Java] Support range searcher
For a sorted vector, the range searcher finds the first/last occurrence of a particular element. The search is based on binary search, which takes O(logn) time. Closes #4925 from liyafan82/fly_0723_range and squashes the following commits: 4690f69 <liyafan82> Support range searcher Authored-by: liyafan82 <fan_li_ya@foxmail.com> Signed-off-by: Pindikura Ravindra <ravindra@dremio.com>
- Loading branch information
Showing
3 changed files
with
324 additions
and
9 deletions.
There are no files selected for viewing
138 changes: 138 additions & 0 deletions
138
java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorRangeSearcher.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.arrow.algorithm.search; | ||
|
||
import org.apache.arrow.algorithm.sort.VectorValueComparator; | ||
import org.apache.arrow.vector.ValueVector; | ||
|
||
/** | ||
* Search for the range of a particular element in the target vector. | ||
*/ | ||
public class VectorRangeSearcher { | ||
|
||
/** | ||
* Result returned when a search fails. | ||
*/ | ||
public static final int SEARCH_FAIL_RESULT = -1; | ||
|
||
/** | ||
* Search for the first occurrence of an element. | ||
* The search is based on the binary search algorithm. So the target vector must be sorted. | ||
* @param targetVector the vector from which to perform the search. | ||
* @param comparator the criterion for the comparison. | ||
* @param keyVector the vector containing the element to search. | ||
* @param keyIndex the index of the search key in the key vector. | ||
* @param <V> the vector type. | ||
* @return the index of the first matched element if any, and -1 otherwise. | ||
*/ | ||
public static <V extends ValueVector> int getFirstMatch( | ||
V targetVector, VectorValueComparator<V> comparator, V keyVector, int keyIndex) { | ||
comparator.attachVectors(keyVector, targetVector); | ||
|
||
int low = 0; | ||
int high = targetVector.getValueCount() - 1; | ||
|
||
while (low <= high) { | ||
int mid = low + (high - low) / 2; | ||
int result = comparator.compare(keyIndex, mid); | ||
if (result < 0) { | ||
// the key is smaller | ||
high = mid - 1; | ||
} else if (result > 0) { | ||
// the key is larger | ||
low = mid + 1; | ||
} else { | ||
// the key equals the mid value, find the lower bound by going left-ward. | ||
|
||
// compare with the left neighbour | ||
int left = mid - 1; | ||
if (left == -1) { | ||
// this is the first value in the vector | ||
return mid; | ||
} else { | ||
int leftResult = comparator.compare(keyIndex, left); | ||
if (leftResult > 0) { | ||
// the key is greater than the left neighbour, and equal to the current one | ||
// we find it | ||
return mid; | ||
} else if (leftResult == 0) { | ||
// the left neighbour is also equal, continue to go left | ||
high = mid - 1; | ||
} else { | ||
// the key is larger than the left neighbour, this is not possible | ||
throw new IllegalStateException("The target vector is not sorted "); | ||
} | ||
} | ||
} | ||
} | ||
return SEARCH_FAIL_RESULT; | ||
} | ||
|
||
/** | ||
* Search for the last occurrence of an element. | ||
* The search is based on the binary search algorithm. So the target vector must be sorted. | ||
* @param targetVector the vector from which to perform the search. | ||
* @param comparator the criterion for the comparison. | ||
* @param keyVector the vector containing the element to search. | ||
* @param keyIndex the index of the search key in the key vector. | ||
* @param <V> the vector type. | ||
* @return the index of the last matched element if any, and -1 otherwise. | ||
*/ | ||
public static <V extends ValueVector> int getLastMatch( | ||
V targetVector, VectorValueComparator<V> comparator, V keyVector, int keyIndex) { | ||
comparator.attachVectors(keyVector, targetVector); | ||
|
||
int low = 0; | ||
int high = targetVector.getValueCount() - 1; | ||
|
||
while (low <= high) { | ||
int mid = low + (high - low) / 2; | ||
int result = comparator.compare(keyIndex, mid); | ||
if (result < 0) { | ||
// the key is smaller | ||
high = mid - 1; | ||
} else if (result > 0) { | ||
// the key is larger | ||
low = mid + 1; | ||
} else { | ||
// the key equals the mid value, find the upper bound by going right-ward. | ||
|
||
// compare with the right neighbour | ||
int right = mid + 1; | ||
if (right == targetVector.getValueCount()) { | ||
// this is the last value in the vector | ||
return mid; | ||
} else { | ||
int rightResult = comparator.compare(keyIndex, right); | ||
if (rightResult < 0) { | ||
// the key is smaller than the right neighbour, and equal to the current one | ||
// we find it | ||
return mid; | ||
} else if (rightResult == 0) { | ||
// the right neighbour is also equal, continue to go right | ||
low = mid + 1; | ||
} else { | ||
// the key is smaller than the right neighbour, this is not possible | ||
throw new IllegalStateException("The target vector is not sorted "); | ||
} | ||
} | ||
} | ||
} | ||
return SEARCH_FAIL_RESULT; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
178 changes: 178 additions & 0 deletions
178
java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorRangeSearcher.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.arrow.algorithm.search; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
|
||
import org.apache.arrow.algorithm.sort.DefaultVectorComparators; | ||
import org.apache.arrow.algorithm.sort.VectorValueComparator; | ||
import org.apache.arrow.memory.BufferAllocator; | ||
import org.apache.arrow.memory.RootAllocator; | ||
import org.apache.arrow.vector.IntVector; | ||
|
||
import org.junit.After; | ||
import org.junit.Before; | ||
import org.junit.Test; | ||
|
||
/** | ||
* Test cases for {@link VectorRangeSearcher}. | ||
*/ | ||
public class TestVectorRangeSearcher { | ||
|
||
private BufferAllocator allocator; | ||
|
||
@Before | ||
public void prepare() { | ||
allocator = new RootAllocator(1024 * 1024); | ||
} | ||
|
||
@After | ||
public void shutdown() { | ||
allocator.close(); | ||
} | ||
|
||
@Test | ||
public void testGetLowerBounds() { | ||
final int maxValue = 100; | ||
final int repeat = 5; | ||
try (IntVector intVector = new IntVector("int vec", allocator)) { | ||
// allocate vector | ||
intVector.allocateNew(maxValue * repeat); | ||
intVector.setValueCount(maxValue * repeat); | ||
|
||
// prepare data in sorted order | ||
// each value is repeated some times | ||
for (int i = 0; i < maxValue; i++) { | ||
for (int j = 0; j < repeat; j++) { | ||
if (i == 0) { | ||
intVector.setNull(i * repeat + j); | ||
} else { | ||
intVector.set(i * repeat + j, i); | ||
} | ||
} | ||
} | ||
|
||
// do search | ||
VectorValueComparator<IntVector> comparator = DefaultVectorComparators.createDefaultComparator(intVector); | ||
for (int i = 0; i < maxValue; i++) { | ||
int result = VectorRangeSearcher.getFirstMatch(intVector, comparator, intVector, i * repeat); | ||
assertEquals(i * repeat, result); | ||
} | ||
} | ||
} | ||
|
||
@Test | ||
public void testGetLowerBoundsNegative() { | ||
final int maxValue = 100; | ||
final int repeat = 5; | ||
try (IntVector intVector = new IntVector("int vec", allocator); | ||
IntVector negVector = new IntVector("neg vec", allocator)) { | ||
// allocate vector | ||
intVector.allocateNew(maxValue * repeat); | ||
intVector.setValueCount(maxValue * repeat); | ||
|
||
negVector.allocateNew(maxValue); | ||
negVector.setValueCount(maxValue); | ||
|
||
// prepare data in sorted order | ||
// each value is repeated some times | ||
for (int i = 0; i < maxValue; i++) { | ||
for (int j = 0; j < repeat; j++) { | ||
if (i == 0) { | ||
intVector.setNull(i * repeat + j); | ||
} else { | ||
intVector.set(i * repeat + j, i); | ||
} | ||
} | ||
negVector.set(i, maxValue + i); | ||
} | ||
|
||
// do search | ||
VectorValueComparator<IntVector> comparator = DefaultVectorComparators.createDefaultComparator(intVector); | ||
for (int i = 0; i < maxValue; i++) { | ||
int result = VectorRangeSearcher.getFirstMatch(intVector, comparator, negVector, i); | ||
assertEquals(-1, result); | ||
} | ||
} | ||
} | ||
|
||
@Test | ||
public void testGetUpperBounds() { | ||
final int maxValue = 100; | ||
final int repeat = 5; | ||
try (IntVector intVector = new IntVector("int vec", allocator)) { | ||
// allocate vector | ||
intVector.allocateNew(maxValue * repeat); | ||
intVector.setValueCount(maxValue * repeat); | ||
|
||
// prepare data in sorted order | ||
// each value is repeated some times | ||
for (int i = 0; i < maxValue; i++) { | ||
for (int j = 0; j < repeat; j++) { | ||
if (i == 0) { | ||
intVector.setNull(i * repeat + j); | ||
} else { | ||
intVector.set(i * repeat + j, i); | ||
} | ||
} | ||
} | ||
|
||
// do search | ||
VectorValueComparator<IntVector> comparator = DefaultVectorComparators.createDefaultComparator(intVector); | ||
for (int i = 0; i < maxValue; i++) { | ||
int result = VectorRangeSearcher.getLastMatch(intVector, comparator, intVector, i * repeat); | ||
assertEquals((i + 1) * repeat - 1, result); | ||
} | ||
} | ||
} | ||
|
||
@Test | ||
public void testGetUpperBoundsNegative() { | ||
final int maxValue = 100; | ||
final int repeat = 5; | ||
try (IntVector intVector = new IntVector("int vec", allocator); | ||
IntVector negVector = new IntVector("neg vec", allocator)) { | ||
// allocate vector | ||
intVector.allocateNew(maxValue * repeat); | ||
intVector.setValueCount(maxValue * repeat); | ||
|
||
negVector.allocateNew(maxValue); | ||
negVector.setValueCount(maxValue); | ||
|
||
// prepare data in sorted order | ||
// each value is repeated some times | ||
for (int i = 0; i < maxValue; i++) { | ||
for (int j = 0; j < repeat; j++) { | ||
if (i == 0) { | ||
intVector.setNull(i * repeat + j); | ||
} else { | ||
intVector.set(i * repeat + j, i); | ||
} | ||
} | ||
negVector.set(i, maxValue + i); | ||
} | ||
|
||
// do search | ||
VectorValueComparator<IntVector> comparator = DefaultVectorComparators.createDefaultComparator(intVector); | ||
for (int i = 0; i < maxValue; i++) { | ||
int result = VectorRangeSearcher.getLastMatch(intVector, comparator, negVector, i); | ||
assertEquals(-1, result); | ||
} | ||
} | ||
} | ||
} |