Skip to content

Commit

Permalink
ARROW-7400: [Java] Avoid the worst case for quick sort
Browse files Browse the repository at this point in the history
This issue is in response of a discussion in: #5540 (comment).

The quick sort algorithm can degenerate to an O(n^2) algorithm, if the pivot is selected poorly. This is an important problem, as the worst case can happen, if the input vector is alrady sorted, which is frequently encountered in practice.

After some investigation, we solve the problem with a simple but effective approach: take 3 samples and choose the median (with at most 3 comparisons) as the pivot. This sorts the vector which is already sorted in O(nlogn) time.

Closes #6039 from liyafan82/fly_1213_sort and squashes the following commits:

0943b06 <liyafan82>  Make tests more readable
7cdf0a6 <liyafan82>  Fix the bug of choosing pivot and add more tests
e6ab2bf <liyafan82>  Apply insertion sort when the range is small
1167176 <liyafan82>  Avoids the worst case for quick sort

Authored-by: liyafan82 <fan_li_ya@foxmail.com>
Signed-off-by: Micah Kornfield <emkornfield@gmail.com>
  • Loading branch information
liyafan82 authored and emkornfield committed Feb 18, 2020
1 parent 389d38b commit 38504e3
Show file tree
Hide file tree
Showing 7 changed files with 513 additions and 15 deletions.
6 changes: 6 additions & 0 deletions java/algorithm/pom.xml
Expand Up @@ -27,6 +27,12 @@
<version>${project.version}</version>
<classifier>${arrow.vector.classifier}</classifier>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-memory</artifactId>
Expand Down
Expand Up @@ -26,18 +26,25 @@
*/
public class FixedWidthInPlaceVectorSorter<V extends BaseFixedWidthVector> implements InPlaceVectorSorter<V> {

private VectorValueComparator<V> comparator;
/**
* If the number of items is smaller than this threshold, we will use another algorithm to sort the data.
*/
public static final int CHANGE_ALGORITHM_THRESHOLD = 15;

static final int STOP_CHOOSING_PIVOT_THRESHOLD = 3;

VectorValueComparator<V> comparator;

/**
* The vector to sort.
*/
private V vec;
V vec;

/**
* The buffer to hold the pivot.
* It always has length 1.
*/
private V pivotBuffer;
V pivotBuffer;

@Override
public void sortInPlace(V vec, VectorValueComparator<V> comparator) {
Expand All @@ -64,6 +71,12 @@ private void quickSort() {
int high = rangeStack.pop();
int low = rangeStack.pop();
if (low < high) {
if (high - low < CHANGE_ALGORITHM_THRESHOLD) {
// switch to insertion sort
InsertionSorter.insertionSort(vec, low, high, comparator, pivotBuffer);
continue;
}

int mid = partition(low, high);

// push the larger part to stack first,
Expand All @@ -86,8 +99,55 @@ private void quickSort() {
}
}

/**
* Select the pivot as the median of 3 samples.
*/
void choosePivot(int low, int high) {
// we need at least 3 items
if (high - low + 1 < STOP_CHOOSING_PIVOT_THRESHOLD) {
pivotBuffer.copyFrom(low, 0, vec);
return;
}

comparator.attachVector(vec);
int mid = low + (high - low) / 2;

// find the median by at most 3 comparisons
int medianIdx;
if (comparator.compare(low, mid) < 0) {
if (comparator.compare(mid, high) < 0) {
medianIdx = mid;
} else {
if (comparator.compare(low, high) < 0) {
medianIdx = high;
} else {
medianIdx = low;
}
}
} else {
if (comparator.compare(mid, high) > 0) {
medianIdx = mid;
} else {
if (comparator.compare(low, high) < 0) {
medianIdx = low;
} else {
medianIdx = high;
}
}
}

// move the pivot to the low position, if necessary
if (medianIdx != low) {
pivotBuffer.copyFrom(medianIdx, 0, vec);
vec.copyFrom(low, medianIdx, vec);
vec.copyFrom(0, low, pivotBuffer);
}

comparator.attachVectors(vec, pivotBuffer);
}

private int partition(int low, int high) {
pivotBuffer.copyFrom(low, 0, vec);
choosePivot(low, high);

while (low < high) {
while (low < high && comparator.compare(high, 0) >= 0) {
Expand Down
Expand Up @@ -28,6 +28,11 @@
*/
public class IndexSorter<V extends ValueVector> {

/**
* If the number of items is smaller than this threshold, we will use another algorithm to sort the data.
*/
public static final int CHANGE_ALGORITHM_THRESHOLD = 15;

/**
* Comparator for vector indices.
*/
Expand Down Expand Up @@ -68,6 +73,11 @@ private void quickSort() {
int low = rangeStack.pop();

if (low < high) {
if (high - low < CHANGE_ALGORITHM_THRESHOLD) {
InsertionSorter.insertionSort(indices, low, high, comparator);
continue;
}

int mid = partition(low, high, indices, comparator);

// push the larger part to stack first,
Expand All @@ -90,6 +100,53 @@ private void quickSort() {
}
}

/**
* Select the pivot as the median of 3 samples.
*/
static <T extends ValueVector> int choosePivot(
int low, int high, IntVector indices, VectorValueComparator<T> comparator) {
// we need at least 3 items
if (high - low + 1 < FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD) {
return indices.get(low);
}

int mid = low + (high - low) / 2;

// find the median by at most 3 comparisons
int medianIdx;
if (comparator.compare(indices.get(low), indices.get(mid)) < 0) {
if (comparator.compare(indices.get(mid), indices.get(high)) < 0) {
medianIdx = mid;
} else {
if (comparator.compare(indices.get(low), indices.get(high)) < 0) {
medianIdx = high;
} else {
medianIdx = low;
}
}
} else {
if (comparator.compare(indices.get(mid), indices.get(high)) > 0) {
medianIdx = mid;
} else {
if (comparator.compare(indices.get(low), indices.get(high)) < 0) {
medianIdx = low;
} else {
medianIdx = high;
}
}
}

// move the pivot to the low position, if necessary
if (medianIdx != low) {
int tmp = indices.get(medianIdx);
indices.set(medianIdx, indices.get(low));
indices.set(low, tmp);
return tmp;
} else {
return indices.get(low);
}
}

/**
* Partition a range of values in a vector into two parts, with elements in one part smaller than
* elements from the other part. The partition is based on the element indices, so it does
Expand All @@ -103,7 +160,7 @@ private void quickSort() {
*/
public static <T extends ValueVector> int partition(
int low, int high, IntVector indices, VectorValueComparator<T> comparator) {
int pivotIndex = indices.get(low);
int pivotIndex = choosePivot(low, high, indices, comparator);

while (low < high) {
while (low < high && comparator.compare(indices.get(high), pivotIndex) >= 0) {
Expand Down
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.algorithm.sort;

import org.apache.arrow.vector.BaseFixedWidthVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.ValueVector;

/**
* Insertion sorter.
*/
class InsertionSorter {

/**
* Sorts the range of a vector by insertion sort.
*
* @param vector the vector to be sorted.
* @param startIdx the start index of the range (inclusive).
* @param endIdx the end index of the range (inclusive).
* @param buffer an extra buffer with capacity 1 to hold the current key.
* @param comparator the criteria for vector element comparison.
* @param <V> the vector type.
*/
static <V extends BaseFixedWidthVector> void insertionSort(
V vector, int startIdx, int endIdx, VectorValueComparator<V> comparator, V buffer) {
comparator.attachVectors(vector, buffer);
for (int i = startIdx; i <= endIdx; i++) {
buffer.copyFrom(i, 0, vector);
int j = i - 1;
while (j >= startIdx && comparator.compare(j, 0) > 0) {
vector.copyFrom(j, j + 1, vector);
j = j - 1;
}
vector.copyFrom(0, j + 1, buffer);
}
}

/**
* Sorts the range of vector indices by insertion sort.
*
* @param indices the vector indices.
* @param startIdx the start index of the range (inclusive).
* @param endIdx the end index of the range (inclusive).
* @param comparator the criteria for vector element comparison.
* @param <V> the vector type.
*/
static <V extends ValueVector> void insertionSort(
IntVector indices, int startIdx, int endIdx, VectorValueComparator<V> comparator) {
for (int i = startIdx; i <= endIdx; i++) {
int key = indices.get(i);
int j = i - 1;
while (j >= startIdx && comparator.compare(indices.get(j), key) > 0) {
indices.set(j + 1, indices.get(j));
j = j - 1;
}
indices.set(j + 1, key);
}
}
}
Expand Up @@ -17,11 +17,13 @@

package org.apache.arrow.algorithm.sort;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.testing.ValueVectorDataPopulator;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
Expand Down Expand Up @@ -114,4 +116,97 @@ public void testSortLargeIncreasingInt() {
}
}
}

@Test
public void testChoosePivot() {
final int vectorLength = 100;
try (IntVector vec = new IntVector("", allocator)) {
vec.allocateNew(vectorLength);

// the vector is sorted, so the pivot should be in the middle
for (int i = 0; i < vectorLength; i++) {
vec.set(i, i * 100);
}
vec.setValueCount(vectorLength);

FixedWidthInPlaceVectorSorter sorter = new FixedWidthInPlaceVectorSorter();
VectorValueComparator<IntVector> comparator = DefaultVectorComparators.createDefaultComparator(vec);

try (IntVector pivotBuffer = (IntVector) vec.getField().createVector(allocator)) {
// setup internal data structures
pivotBuffer.allocateNew(1);
sorter.pivotBuffer = pivotBuffer;
sorter.comparator = comparator;
sorter.vec = vec;
comparator.attachVectors(vec, pivotBuffer);

int low = 5;
int high = 6;
int pivotValue = vec.get(low);
assertTrue(high - low + 1 < FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD);

// the range is small enough, so the pivot is simply selected as the low value
sorter.choosePivot(low, high);
assertEquals(pivotValue, vec.get(low));

low = 30;
high = 80;
pivotValue = vec.get((low + high) / 2);
assertTrue(high - low + 1 >= FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD);

// the range is large enough, so the median is selected as the pivot
sorter.choosePivot(low, high);
assertEquals(pivotValue, vec.get(low));
}
}
}

/**
* Evaluates choosing pivot for all possible permutations of 3 numbers.
*/
@Test
public void testChoosePivotAllPermutes() {
try (IntVector vec = new IntVector("", allocator)) {
vec.allocateNew(3);

FixedWidthInPlaceVectorSorter sorter = new FixedWidthInPlaceVectorSorter();
VectorValueComparator<IntVector> comparator = DefaultVectorComparators.createDefaultComparator(vec);

try (IntVector pivotBuffer = (IntVector) vec.getField().createVector(allocator)) {
// setup internal data structures
pivotBuffer.allocateNew(1);
sorter.pivotBuffer = pivotBuffer;
sorter.comparator = comparator;
sorter.vec = vec;
comparator.attachVectors(vec, pivotBuffer);

int low = 0;
int high = 2;

ValueVectorDataPopulator.setVector(vec, 11, 22, 33);
sorter.choosePivot(low, high);
assertEquals(22, vec.get(0));

ValueVectorDataPopulator.setVector(vec, 11, 33, 22);
sorter.choosePivot(low, high);
assertEquals(22, vec.get(0));

ValueVectorDataPopulator.setVector(vec, 22, 11, 33);
sorter.choosePivot(low, high);
assertEquals(22, vec.get(0));

ValueVectorDataPopulator.setVector(vec, 22, 33, 11);
sorter.choosePivot(low, high);
assertEquals(22, vec.get(0));

ValueVectorDataPopulator.setVector(vec, 33, 11, 22);
sorter.choosePivot(low, high);
assertEquals(22, vec.get(0));

ValueVectorDataPopulator.setVector(vec, 33, 22, 11);
sorter.choosePivot(low, high);
assertEquals(22, vec.get(0));
}
}
}
}

0 comments on commit 38504e3

Please sign in to comment.