!pip install numpy

A randomized divide-and-conquer algorithm for selection

For any number v, imagine splitting list S
into three categories: elements smaller than v, those equal to v (there might be duplicates),
and those greater than v. Call these SL, Sv, and SR respectively. For instance, if the array
S: 2 36 5 21 8 13 11 20 5 4 1
is split on v = 5, the three subarrays generated are

SL: 2 4 1
Sv: 5 5
SR: 36 21 8 13 11 20

The search can instantly be narrowed down to one of these sublists. If we want, say, the
eighth-smallest element of S, we know it must be the third-smallest element of SR since
|SL| + |Sv| = 5. That is, selection(S; 8) = selection(SR; 3). More generally, by checking k
against the sizes of the subarrays, we can quickly determine which of them holds the desired
element:

selection(S; k) =
    selection(SL; k)                 if k ≤ |SL|
    v                                if |SL| < k ≤ |SL| + |Sv|
    selection(SR; k - |SL| - |Sv|)   if k > |SL| + |Sv|
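
The case analysis above can be traced on the example array from the text. The following is a minimal sketch, not part of the original notebook: the variable names `target` and `new_k` are made up for illustration, and list comprehensions stand in for the three-way split.

```python
# Worked example from the text: split S on v = 5, then locate the 8th smallest.
S = [2, 36, 5, 21, 8, 13, 11, 20, 5, 4, 1]
v, k = 5, 8

SL = [x for x in S if x < v]    # elements smaller than v
Sv = [x for x in S if x == v]   # elements equal to v
SR = [x for x in S if x > v]    # elements greater than v

# Compare k against the sublist sizes to decide where the answer lives.
if k <= len(SL):
    target, new_k = "SL", k
elif k <= len(SL) + len(Sv):
    target, new_k = "Sv", None  # the answer is v itself
else:
    target, new_k = "SR", k - len(SL) - len(Sv)

print(SL, Sv, SR)      # [2, 4, 1] [5, 5] [36, 21, 8, 13, 11, 20]
print(target, new_k)   # SR 3
```

Here |SL| + |Sv| = 5 < 8, so the search continues in SR with k reduced to 3, matching selection(S; 8) = selection(SR; 3).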


In [3]:
import numpy as np

In [12]:
def partition(arr, p):
    """
    Partition array `arr` into 3 subarrays:
    - elements < p
    - elements == p
    - elements > p
    """
    L, V, R = [], [], []
    for x in arr:
        if x < p:
            L.append(x)
        elif x == p:
            V.append(x)
        else:
            R.append(x)
    return L, V, R



In [14]:
def select(arr, k):
    """
    Return the k-th smallest element of array `arr` (k is 1-indexed).
    """
    n = len(arr)
    # Reject an empty array or an out-of-range k up front
    if not 1 <= k <= n:
        raise ValueError(f"k must be between 1 and {n}, got {k}")
    if n == 1:
        return arr[0]

    # Pick a random pivot element
    p = np.random.choice(arr)

    # Partition the array into 3 subarrays
    L, V, R = partition(arr, p)

    if k <= len(L):
        # The k-th element is in L
        return select(L, k)
    elif k <= len(L) + len(V):
        # The k-th element is equal to the pivot
        return p
    else:
        # The k-th element is in R; skip the len(L) + len(V) smaller elements
        return select(R, k - len(L) - len(V))
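
The same narrowing step can also be written as a loop instead of a recursion. The sketch below is an illustrative variant, not the notebook's implementation: the name `select_iterative` is hypothetical, it uses the standard library's `random.choice` in place of `np.random.choice`, and it inlines the three-way split so that it runs on its own.

```python
import random

def select_iterative(arr, k):
    """Return the k-th smallest element (1-indexed) of `arr` without recursion."""
    items = list(arr)
    if not 1 <= k <= len(items):
        raise ValueError("k out of range")
    while True:
        # Pick a random pivot and split the current candidates three ways
        p = random.choice(items)
        L = [x for x in items if x < p]
        V = [x for x in items if x == p]
        R = [x for x in items if x > p]
        if k <= len(L):
            items = L                 # answer is among the smaller elements
        elif k <= len(L) + len(V):
            return p                  # answer equals the pivot
        else:
            k -= len(L) + len(V)      # discard L and V, adjust the rank
            items = R

print(select_iterative([5, 8, 9, 5, 0, 0, 1, 7, 6, 9], 3))
```

Because V always contains at least the pivot, the candidate list shrinks on every pass that does not return, so the loop terminates.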


Print the k-th smallest element of the array S for k = 3 and S = [5,8,9,5,0,0,1,7,6,9].
Expected output:
select 3
array [5, 8, 9, 5, 0, 0, 1, 7, 6, 9]
sorted array [0, 0, 1, 5, 5, 6, 7, 8, 9, 9]
The third smallest element is 1, which is the third element in the sorted array

In [18]:
S = [5,8,9,5,0,0,1,7,6,9]
k = 3
print("select", k)
print("array", S)
print("sorted array", sorted(S))
print("The third smallest element is", select(S, k))
assert select(S, k) == sorted(S)[k-1]

select 3
array [5, 8, 9, 5, 0, 0, 1, 7, 6, 9]
sorted array [0, 0, 1, 5, 5, 6, 7, 8, 9, 9]
The third smallest element is 1
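
As an independent cross-check, NumPy's `np.partition` can produce the same value: it rearranges a copy of the array so that the element at the requested index is the one that would be there after a full sort. This is a sketch for comparison only; the variable name `kth` is chosen here for illustration.

```python
import numpy as np

S = [5, 8, 9, 5, 0, 0, 1, 7, 6, 9]
k = 3

# np.partition puts the element with sorted index k-1 at position k-1
kth = np.partition(np.asarray(S), k - 1)[k - 1]
print(kth)  # 1
```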


Test the code on larger values of n and k, for example n = 1000 and k = 500.

In [36]:
# Test the selection function for various n and k values
for n in [10, 100, 1000]:
    for k in [1, n//2, n]:
        # generate a random array of n integers drawn from [0, n)
        arr = np.random.randint(0, n, size=n)
        # get the k-th smallest element using the selection function
        k_smallest = select(arr, k)
        # sort the array
        sorted_arr = np.sort(arr)
        # check the result against the sorted array (k is 1-indexed)
        assert k_smallest == sorted_arr[k - 1]
        # print the k-th smallest element, the original array, and the sorted array
        print(f"k={k} n={n}: {k_smallest} \n\t{arr} \n\t{sorted_arr}")

k=1 n=10: 0 
	[5 8 9 0 2 5 8 5 3 9] 
	[0 2 3 5 5 5 8 8 9 9]
k=5 n=10: 4 
	[4 6 9 6 2 1 8 4 0 6] 
	[0 1 2 4 4 6 6 6 8 9]
k=10 n=10: 9 
	[5 5 2 0 7 2 7 3 6 9] 
	[0 2 2 3 5 5 6 7 7 9]
k=1 n=100: 1 
	[96 80 26 39 27 41 67 63 64  9 10 87 34 42 49 54 67 49 22 38 64 10 13 43
 21 22 39 68 30 97 83 49 30 67 52 34 46 32 36 44 55 87 46 61 68  5 48 27
 92 87 47 18 98 57 98 73 96 42 19 99 31 82 66 23 91 77 47 47 87 44 36 68
 28 81 95 69 25 70 11  9  1 61 50 39 64 49 59 99 60 49 37 98 70 29 17 54
 44 67 77 18] 
	[ 1  5  9  9 10 10 11 13 17 18 18 19 21 22 22 23 25 26 27 27 28 29 30 30
 31 32 34 34 36 36 37 38 39 39 39 41 42 42 43 44 44 44 46 46 47 47 47 48
 49 49 49 49 49 50 52 54 54 55 57 59 60 61 61 63 64 64 64 66 67 67 67 67
 68 68 68 69 70 70 73 77 77 80 81 82 83 87 87 87 87 91 92 95 96 96 97 98
 98 98 99 99]
k=50 n=100: 51 
	[12 42 23 71 65 16 93 80 63 21 34 93 53 95 10 93 41 31 67 96 95 47 93 83
 61 51 41 50 28 38 68 14 25  5 10 65 48 26 10  1 93 54 12 86 11 44 35 91
  2 95 82 74 25 90  8 24 12