# ANSWER ONE - using quick selection and quick sort

In [27]:
def get_distance(point_one, point_two):
    distance = ((point_one[0] - point_two[0])**2 + (point_one[1] - point_two[1])**2)**0.5
    return distance
     
    
def quick_select(nums, start, end, k):
    left = start
    right = end
    pivot = nums[(start + end) // 2]
            
    while left <= right:
        while left <= right and nums[left] < pivot:
            left += 1
        while left <= right and nums[right] > pivot:
            right -= 1
        if left <= right:
            nums[left], nums[right] = nums[right], nums[left]
            left += 1
            right -= 1
                
    if k <= right:
        quick_select(nums, start, right, k)
    if k >= left:
        quick_select(nums, left, end, k)     
        
    return nums[k]
    

def quick_sort(nums, start, end):
    if start >= end:
        return
    
    left = start
    right = end
    pivot = nums[(start + end) // 2]
    
    while left <= right:
        while left <= right and nums[left] < pivot:
            left += 1
        while left <= right and nums[right] > pivot:
            right -= 1
        if left <= right:
            nums[left], nums[right] = nums[right], nums[left]
            left += 1
            right -= 1
            
    quick_sort(nums, start, right)
    quick_sort(nums, left, end)
    
    return nums
    
    
def kClosest(points, origin, k):
    # get distance
    hash_point_distance = {}
    for point in points:
        distance = get_distance(origin, point)
        if distance not in hash_point_distance:
            hash_point_distance[distance] = point
        
    # find kth distance
    distances = list(hash_point_distance.keys())
    kth_num = quick_select(distances, 0, len(distances) - 1, len(distances) - k)
        
    # get kth_points_hash_map
    kth_points_hash_map = {}
    for distance, point in hash_point_distance.items():
        if distance <= kth_num:
            kth_points_hash_map[distance] = point
            
    # sort kth_points_hash_map
    kth_points_list = list(kth_points_hash_map.keys())
    quick_sort(kth_points_list, 0, len(kth_points_hash_map) - 1)
    
    # print out
    result = []
    for point in kth_points_list:
        result.append(kth_points_hash_map[point])
    
    return result

In [26]:
points = [[4,6],[4,7],[4,4],[2,5],[1,1]]
origin = [0, 0]
k = 3 
kClosest(points, origin, k)

[[1, 1], [2, 5], [4, 4]]

In [None]:
# note: the time complexity is O(n) + O(klogk)

# ANSWER TWO - using min-heap

In [37]:
import heapq

def get_distance(point_one, point_two):
    distance = ((point_one[0] - point_two[0])**2 + (point_one[1] - point_two[1])**2)**0.5
    return distance

def kClosest(points, origin, k):
    # create point_distance_hash
    point_distance_hash = {}
    for point in points:
        distance = get_distance(origin, point)
        point_distance_hash[distance] = point
        
    # put distance list into a min-heap
    heap = []
    distance_list = list(point_distance_hash.keys())
    for i in distance_list:
        heapq.heappush(heap, i)
        
    # take k elements from min-head and get final result
    result = []
    while k > 0:
        min_distance = heapq.heappop(heap)
        result.append(point_distance_hash[min_distance])
        k -= 1
        
    return result

In [38]:
points = [[4,6],[4,7],[4,4],[2,5],[1,1]]
origin = [0, 0]
k = 3 
kClosest(points, origin, k)

[[1, 1], [2, 5], [4, 4]]

In [None]:
# note: the time complexity is O(n) + O(nlogk)

# ANSWER THREE - using max-heap (an on-line method)

In [60]:
import heapq

def get_distance(point_one, point_two):
    distance = ((point_one[0] - point_two[0])**2 + (point_one[1] - point_two[1])**2)**0.5
    return distance

def kClosest(points, origin, k):
    heap = []
    heap_size = 0
    point_distance_hash = {}
    
    for point in points:
        distance = get_distance(origin, point)
        point_distance_hash[distance] = point
        
        if heap_size < k:
            heapq.heappush(heap, -distance)
            heap_size += 1
            
        elif heap_size >= k:
            heap_root_value = heap[0]
            
            if -distance > heap_root_value:
                heapq.heappop(heap)
                heapq.heappush(heap, -distance)
    
    result = []
    while heap:
        point = point_distance_hash[-heapq.heappop(heap)]
        result.append(point)
        
    result.reverse()                
    return result

In [61]:
points = [[4,6],[4,7],[4,4],[2,5],[1,1]]
origin = [0, 0]
k = 3 
kClosest(points, origin, k)

[[1, 1], [2, 5], [4, 4]]

In [None]:
# note: the time complexity is O(nlogk)