# Nearest Neighbours Example

## Server's Data Setup
The server owns the coordinates of points of interest, such as restaurants and shops. The coordinates are kept in a `LookupTable`.

In [1]:
from concrete import fhe
import numpy


# Database of points of interest: each row is one (x, y) coordinate pair.
points_array = numpy.array(
    [
        [2, 3], [1, 5], [3, 2], [5, 2], [1, 1],
        [9, 4], [13, 2], [14, 13], [9, 8], [8, 0],
        [2, 10], [3, 8], [8, 12], [4, 10], [7, 7],
    ]
)
# Number of points of interest (one per row).
N_PTS = len(points_array)
# Row-major flattening interleaves the coordinates as [x0, y0, x1, y1, ...]:
# point i's x is at index 2*i and its y at index 2*i + 1.
points = fhe.LookupTable(points_array.flatten())


def get_point(index):
    """Look up the (x, y) coordinates of point ``index``.

    The lookup table stores coordinates interleaved, so the x coordinate sits
    at ``2 * index`` and the y coordinate immediately after it.

    Returns:
      A 2-tuple ``(x, y)`` read from ``points``.
    """
    base = 2 * index
    return points[base], points[base + 1]


def all_distances(x, y):
    """Manhattan (L1) distance from (x, y) to every point of interest.

    Returns one distance per point, in the same order as the table.
    """
    x_slots = 2 * numpy.arange(N_PTS)   # 0, 2, 4, ...: x coordinates
    y_slots = x_slots + 1               # 1, 3, 5, ...: y coordinates
    dx = abs(points[x_slots] - x)
    dy = abs(points[y_slots] - y)
    return dx + dy

We use a partial sort built from compare-and-swap operations to find the $K$ nearest points to a given point (the code below uses $K = 2$). Since we are interested in the indices of the points, not just their distances, we work on (index, distance) pairs, effectively implementing `numpy.argpartition`.

In [2]:
# TLUs: each fhe.univariate wrapper turns a plain Python function into a table
# lookup on an encrypted value. In this notebook they are only applied to
# scalars (inside swap), so plain `if`/`else` expressions are sufficient.
relu = fhe.univariate(lambda x: x if x > 0 else 0)
is_positive = fhe.univariate(lambda x: 1 if x > 0 else 0)
# Not a ReLU: decodes a value packed as 2 * delta + flag, with flag in {0, 1}.
# Returns delta when flag == 1 (x odd) and 0 when flag == 0 (x even), so one
# TLU both reads the flag and selects the (possibly negative) index delta.
# Python's floor division/modulo keep this exact for negative delta too,
# e.g. x = -3 (delta = -2, flag = 1): -3 % 2 == 1 and (-3 - 1) // 2 == -2.
arg_selection = fhe.univariate(lambda x: (x - 1) // 2 if x % 2 else 0)

In [3]:
def swap(this_idx, this_dist, that_idx, that_dist):
    """Compare-and-swap two (index, distance) pairs by distance.

    The pairs stay in place when this_dist <= that_dist and are exchanged
    when this_dist > that_dist, so equal distances keep their order.

    Returns:
      fhe.array of [idx_min, dist_min, idx_max, dist_max]: the pair with the
      smaller distance first, followed by the other pair.
    """
    gap = this_dist - that_dist
    must_swap = is_positive(gap)
    # A single TLU reads the swap flag from the low bit and yields the index
    # offset (this_idx - that_idx) when swapping, otherwise 0.
    idx_shift = arg_selection(2 * (this_idx - that_idx) + must_swap)
    # Distance offset: the positive gap when swapping, otherwise 0.
    dist_shift = relu(gap)
    return fhe.array([
        this_idx - idx_shift,
        this_dist - dist_shift,
        that_idx + idx_shift,
        that_dist + dist_shift,
    ])


# Number of nearest points returned by `knn`. It is read when the function is
# traced, so changing it requires recompiling the circuit.
K_NEAREST = 2


@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def knn(x, y):
    """Return the coordinates of the K_NEAREST points of interest closest to (x, y).

    Distances come from ``all_distances``. The function runs K_NEAREST passes
    of a selection sort built from ``swap``: pass ``k`` compare-and-swaps
    position ``k`` against every later position, which leaves the k-th
    smallest distance (and its index) at position ``k``. The result holds one
    ``(x, y)`` row per neighbour, nearest first.
    """
    dist = all_distances(x, y)
    # Indices start as plaintext ints; swap() replaces them with encrypted values.
    idx = list(range(N_PTS))
    # Asking for more neighbours than there are points returns all the points.
    k_nearest = min(K_NEAREST, N_PTS)
    for k in range(k_nearest):
        for i in range(k + 1, N_PTS):
            idx[k], dist[k], idx[i], dist[i] = swap(idx[k], dist[k], idx[i], dist[i])
    return fhe.array([get_point(idx[j]) for j in range(k_nearest)])


# Inputset used to calibrate the circuit at compile time. Concrete infers the
# bit-width of every intermediate value (distances, distance differences,
# packed index deltas, ...) from the values it observes on this set, so it
# must be representative of every query. Four hand-picked samples can
# under-size data-dependent intermediates and make other queries overflow.
# NOTE(review): assumes queries are 4-bit coordinates in [0, 15] (the largest
# value in the original samples); widen the ranges if the domain changes.
# The domain is small, so it is covered exhaustively.
inputset = [(x, y) for x in range(16) for y in range(16)]

circuit = knn.compile(inputset)


## Client
The client generates its keys and then simply invokes the server's nearest-neighbours circuit.

In [4]:
%%timeit -r 1 -n 1
# Time a single generation of the circuit's client keys (1 run, 1 loop).
circuit.client.keys.generate()

57.9 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [5]:
def nearest(x, y):
    """Query the compiled ``knn`` circuit for the points nearest to (x, y).

    Encrypts the query, evaluates the circuit on the ciphertexts and returns
    the decrypted coordinates of the nearest points of interest.
    """
    enc_x, enc_y = circuit.encrypt(x, y)
    # In a real deployment only these ciphertexts would be sent to the server.
    enc_neighbours = circuit.run(enc_x, enc_y)
    return circuit.decrypt(enc_neighbours)

## Benchmarks

In [6]:
%%timeit -r 1 -n 1
# Time one full encrypt -> run -> decrypt round trip for the query (4, 3).
nearest(4, 3)

21.7 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
