In [67]:
from math import log, comb, ceil
from typing import Tuple

Say we select a pivot by drawing `p` random candidate pivots and picking their "median". This requires `p^2` queries. What is the probability that the resulting pivot lies in the interval [1/L, (L-1)/L]?

```
Prob[median ranks below 1/L or above (L-1)/L]
= 2 * Prob[at least p/2 samples rank below 1/L]
≈ 2 * (p choose p/2) * (1/L)^(p/2) * ((L-1)/L)^(p - p/2)    (leading term: exactly p/2 samples below 1/L)
```

In [1]:
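def probability_bad(p, L):
    """
    Approximate probability that the pivot does not lie in [1/L, (L-1)/L].

    Uses the binomial term for exactly p/2 of the p samples ranking below 1/L,
    doubled to cover the symmetric case above (L-1)/L.
    """
    return (
        2
        * (1 / L) ** (p / 2)
        * ((L - 1) / L) ** (p - p / 2)
        # p may be a float during optimization; int() truncates it here.
        * comb(int(p), int(p) // 2)
    )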
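
The formula above keeps only the "exactly p/2 samples in one tail" term. As a rough sanity check, the sketch below compares that term with the full binomial tail. The helper names `leading_term` and `exact_tail` are hypothetical, and it assumes an even integer `p` and that the median counts as bad whenever at least `p/2` samples land in the same tail.

```python
from math import comb


def leading_term(p: int, L: float) -> float:
    """2 * P[exactly p//2 of p samples rank below 1/L] (integer p only)."""
    k = p // 2
    q = 1 / L
    return 2 * comb(p, k) * q**k * (1 - q) ** (p - k)


def exact_tail(p: int, L: float) -> float:
    """2 * P[at least p//2 of p samples rank below 1/L] (assumed 'bad' event)."""
    k = p // 2
    q = 1 / L
    return 2 * sum(comb(p, j) * q**j * (1 - q) ** (p - j) for j in range(k, p + 1))


print(leading_term(8, 6), exact_tail(8, 6))
```

For `p = 8, L = 6` the leading term is roughly 0.052 and the full tail roughly 0.061, so under these assumptions the single term somewhat understates the failure probability.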

In [54]:
def cost_terms(p, L, N=500) -> Tuple[float, float]:
    """
    Returns a tuple of two terms (pivot-selection queries, compare() queries after each pivot is chosen). Summed, they give the total cost of compute_nearest_neighbors. (This is assuming the )
    """
    return (p**2 * log(N, L / (L - 1)), L * N)

What this is saying is: once we have a good pivot, each step keeps at most an (L-1)/L fraction of the remaining points, so the compare() queries against the pivot sum to
N (1 + (L-1)/L + ((L-1)/L)^2 + ...) = LN.
There are about log(N, L/(L-1)) steps, and each step needs an additional p^2 queries to compute its pivot, for p^2 * log(N, L/(L-1)) more queries in total.
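
The geometric-series argument can be checked numerically. This is a minimal sketch, assuming the worst case in which every step keeps exactly an (L-1)/L fraction of the points and the recursion stops once at most one point remains; `simulate_worst_case` is a hypothetical helper, not part of the notebook.

```python
from math import log


def simulate_worst_case(N: int, L: float):
    """Count steps and compare() queries when each step keeps (L-1)/L of the points."""
    remaining = float(N)
    steps = 0
    compares = 0.0
    while remaining > 1:
        compares += remaining          # compare every remaining point to the pivot
        remaining *= (L - 1) / L       # worst-case shrink allowed by a good pivot
        steps += 1
    return steps, compares


steps, compares = simulate_worst_case(500, 6)
print(steps, log(500, 6 / 5))   # simulated step count vs. log(N, L/(L-1))
print(compares, 6 * 500)        # simulated compare() total vs. the L*N bound
```

The simulated total stays just below `L * N` because the series is truncated after finitely many steps.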

In [70]:
def cost(p, L, N=500):
    return sum(cost_terms(p, L, N))

Note: this doesn't depend on the size of the nearest neighbor set, because we're assuming that set is much smaller than N.
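
To see how the two terms trade off, here is a small sketch that restates `cost_terms` so it runs on its own and evaluates it at `p = 8, L = 6, N = 500`. The values are chosen only for illustration.

```python
from math import log


def cost_terms(p, L, N=500):
    # Same two terms as in the notebook: pivot selection, then compare() queries.
    return (p**2 * log(N, L / (L - 1)), L * N)


pivot_cost, partition_cost = cost_terms(8, 6)
# Roughly 2181.5 pivot-selection queries plus 3000 compare() queries.
print(pivot_cost, partition_cost, pivot_cost + partition_cost)
```

Raising `L` shrinks the first term (fewer steps) but grows the second linearly, which is why both variables need to be tuned together.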

Below is code which, given a budget `MAX_NUM_COMPARE_QUERIES`, finds the `p, L` that minimize `probability_bad` (the probability that a given pivot does not lie in the range [1/L, (L-1)/L]).

In [73]:
import numpy as np
from scipy.optimize import minimize

MAX_NUM_COMPARE_QUERIES = 5000


def objective_function(variables):
    p, L = variables
    return probability_bad(p, L)


def constraint_function(variables):
    # SLSQP "ineq" constraints are satisfied when the returned value is >= 0.
    p, L = variables
    M = MAX_NUM_COMPARE_QUERIES  # total compare() query budget
    return M - cost(p, L, 500)


constraint1 = {"type": "ineq", "fun": constraint_function}

constraint2 = {"type": "ineq", "fun": lambda variables: variables[1] - 2}  # L >= 2

constraint3 = {"type": "ineq", "fun": lambda variables: variables[0] - 3}  # p >= 3

# Initial guess for p and L
initial_guess = np.array([5, 5])

# Run the optimization
result = minimize(
    objective_function,
    initial_guess,
    method="SLSQP",
    constraints=[constraint1, constraint2, constraint3],
)

optimized_p, optimized_L = result.x
# Rounding both values up can push the rounded pair over the query budget.
round_p = ceil(optimized_p)
round_L = ceil(optimized_L)
print(f"Compare query budget: {MAX_NUM_COMPARE_QUERIES}")
print(f"Optimized values: p = {optimized_p}, L = {optimized_L}")
print(f"Rounded values: p={round_p}, L={round_L}")
print(f"Probability a pivot is bad= {result.fun}")
print(
    f"Resulting number of compare() queries (using round values) = {cost(round_p, round_L)}"
)

Compare query budget: 5000
Optimized values: p = 7.999620841276833, L = 5.797058705852599
Rounded values: p=8, L=6
Probability a pivot is bad= 0.029073469661271032
Resulting number of compare() queries (using round values) = 5181.502425127429
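
Since rounding up takes the cost above the 5000-query budget, one alternative is to search integer values directly. The following is a minimal sketch, not the notebook's method: `best_integer_choice` and its search bounds (`max_p`, `max_L`) are assumptions, and `probability_bad` and `cost` are restated from the cells above so the block runs on its own.

```python
from math import comb, log


def probability_bad(p, L):
    # Same formula as the notebook's probability_bad.
    return 2 * (1 / L) ** (p / 2) * ((L - 1) / L) ** (p - p / 2) * comb(int(p), int(p) // 2)


def cost(p, L, N=500):
    # Pivot-selection queries plus compare() queries, as in cost_terms.
    return p**2 * log(N, L / (L - 1)) + L * N


def best_integer_choice(budget, N=500, max_p=70, max_L=10):
    """Exhaustively search integer p >= 3, L >= 2 whose cost fits the budget."""
    best = None
    for p in range(3, max_p + 1):
        for L in range(2, max_L + 1):
            if cost(p, L, N) > budget:
                continue  # skip pairs that exceed the query budget
            candidate = (probability_bad(p, L), p, L)
            if best is None or candidate < best:
                best = candidate
    return best


prob, p, L = best_integer_choice(5000)
print(f"p={p}, L={L}, probability_bad={prob}, cost={cost(p, L)}")
```

The default bounds come from the budget itself: with `N = 500`, `p**2` alone exceeds 5000 past `p = 70`, and `L * N` reaches 5000 at `L = 10`.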
