    Problem Statement

    We have a list of points on the plane.  Find the K closest points to the origin (0, 0).

    (Here, the distance between two points on a plane is the Euclidean distance.)

    You may return the answer in any order. The answer is guaranteed to be unique (except for the order of its elements).



    Example 1:

    Input: points = [[1,3],[-2,2]], K = 1
    Output: [[-2,2]]
    Explanation: 
    The distance between (1, 3) and the origin is sqrt(10).
    The distance between (-2, 2) and the origin is sqrt(8).
    Since sqrt(8) < sqrt(10), (-2, 2) is closer to the origin.
    We only want the closest K = 1 points from the origin, so the answer is just [[-2,2]].

    Example 2:

    Input: points = [[3,3],[5,-1],[-2,4]], K = 2
    Output: [[3,3],[-2,4]]
    (The answer [[-2,4],[3,3]] would also be accepted.)



    Note:

        1 <= K <= points.length <= 10000
        -10000 < points[i][0] < 10000
        -10000 < points[i][1] < 10000

# Heap - O(N * logK) runtime, O(K) space

In [1]:
from typing import List
from heapq import heappush, heappop

class Solution:
    def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
        """Return the K points nearest the origin using a bounded max-heap.

        ``heapq`` only provides a min-heap, so each entry stores the negated
        squared distance: the heap root is then the farthest point kept so far,
        and it is evicted as soon as more than K points have been seen.

        The kept points are returned ordered from farthest to nearest.
        """
        heap = []
        for index, point in enumerate(points):
            squared = point[0] ** 2 + point[1] ** 2
            heappush(heap, (-squared, point))
            # Once more than K points are stored, drop the farthest one.
            if index >= K:
                heappop(heap)

        # Draining the heap yields the largest distance first.
        return [heappop(heap)[1] for _ in range(len(heap))]

# Sort - O(N * logN) runtime, O(N) space

In [2]:
from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
        """Return the K points nearest the origin by sorting on squared distance.

        Squared distance orders points exactly like Euclidean distance, so no
        square root is needed. ``sorted`` is used rather than ``list.sort`` so
        the caller's ``points`` list is not reordered as a side effect.

        Args:
            points: ``[x, y]`` coordinate pairs; left unmodified.
            K: number of points to return.

        Returns:
            A new list with the K closest points, nearest first (ties keep
            their input order because the sort is stable).
        """
        return sorted(points, key=lambda p: p[0] ** 2 + p[1] ** 2)[:K]

# Divide and Conquer (Quickselect) - O(N) average, O(N^2) worst runtime, O(N) space

In [3]:
from typing import List
from random import randint

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        """Return the ``k`` points nearest the origin via randomized quickselect.

        ``points`` is reordered in place so that its first ``k`` entries are
        the ``k`` closest points (in no particular order); a copy of that
        prefix is returned.

        Args:
            points: ``[x, y]`` coordinate pairs; reordered in place.
            k: number of points to return.

        Returns:
            A new list holding the ``k`` closest points.
        """
        if k >= len(points):
            # Every point qualifies; no selection work is needed.
            return points[:k]

        def dist(p):
            # Squared distance orders points exactly like Euclidean distance.
            return p[0] ** 2 + p[1] ** 2

        def partition(lo, hi):
            """Three-way partition points[lo..hi] around a random pivot.

            Returns ``(lt, gt)`` such that points[lo:lt] are strictly closer
            than the pivot, points[lt:gt + 1] are at the pivot's distance, and
            points[gt + 1:hi + 1] are strictly farther.
            """
            pivot = dist(points[randint(lo, hi)])
            lt, i, gt = lo, lo, hi
            while i <= gt:
                d = dist(points[i])
                if d < pivot:
                    points[lt], points[i] = points[i], points[lt]
                    lt += 1
                    i += 1
                elif d > pivot:
                    points[i], points[gt] = points[gt], points[i]
                    gt -= 1
                else:
                    i += 1
            return lt, gt

        # Iterate rather than recurse, and keep ties together with the pivot,
        # so inputs with many equal distances neither exhaust Python's
        # recursion limit nor shrink the range by only one element per pass.
        lo, hi, need = 0, len(points) - 1, k
        while lo < hi:
            lt, gt = partition(lo, hi)
            closer = lt - lo           # points strictly closer than the pivot
            not_farther = gt - lo + 1  # points at most as far as the pivot
            if need < closer:
                hi = lt - 1
            elif need > not_farther:
                need -= not_farther
                lo = gt + 1
            else:
                # The first `need` slots of points[lo:] complete the answer.
                break
        return points[:k]

In [4]:
# Example 2 from the problem statement, run against the most recently
# defined Solution; [[3, 3], [-2, 4]] in either order is accepted.
instance = Solution()
instance.kClosest(
    [[3, 3], [5, -1], [-2, 4]],
    2,
)

[[3, 3], [-2, 4]]