Problem Statement. <br/>
Given an n x n matrix where each of the rows and columns is sorted in ascending order, find the kth smallest element in the matrix. <br/>
Note that it is the kth smallest element in the sorted order, not the kth distinct element. <br/>

Example: <br/>
matrix = [ <br/>
   [ 1,  5,  9], <br/>
   [10, 11, 13], <br/>
   [12, 13, 15] <br/>
], <br/>
k = 8, <br/>

return 13.

# Min-heap - O((N + K) * log N) runtime, O(N) space

In [3]:
from typing import List
from heapq import heappush, heappop

class Solution:
    """k-th smallest value of a row- and column-sorted square matrix, via a min-heap."""

    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        """Return the k-th smallest value (1-based, duplicates counted) in ``matrix``.

        The heap keeps one frontier cell per row as ``(value, row, col)``.
        After the k - 1 smallest cells have been popped, the heap top holds
        the k-th smallest value.
        """
        size = len(matrix)
        frontier = []

        # Seed the heap with the first column of at most k rows.
        for row in range(min(size, k)):
            heappush(frontier, (matrix[row][0], row, 0))

        popped = 0
        while popped < k - 1:
            _, row, col = heappop(frontier)
            next_col = col + 1
            # Advance along the same row while it still has unvisited cells.
            if next_col < size:
                heappush(frontier, (matrix[row][next_col], row, next_col))
            popped += 1

        return frontier[0][0]

# Binary Search - O(N * log(Max−Min)) runtime, O(1) space, where Max and Min are the maximum and minimum numbers in the matrix

In [7]:
from typing import List, Tuple

class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        
        n = len(matrix)
        start, end = matrix[0][0], matrix[n - 1][n - 1]
        
        def countLessEqual(mid: int) -> Tuple[int, int, int]:
            count, n = 0, len(matrix)
            smaller, larger = matrix[0][0], matrix[n - 1][n - 1]
            row, col = n - 1, 0

            while row >= 0 and col < n:
                if matrix[row][col] > mid:
                    # As matrix[row][col] is bigger than the mid, let's keep track of the
                    # smallest number greater than the mid
                    larger = min(larger, matrix[row][col])
                    row -= 1
                else:
                    # As matrix[row][col] is less than or equal to the mid, let's keep track of 
                    # the biggest number less than or equal to the mid
                    smaller = max(smaller, matrix[row][col])
                    count += row + 1
                    col += 1

            return count, smaller, larger
        
        while start < end:
            mid = start + (end - start) / 2
            count, smaller, larger = countLessEqual(mid)

            if count == k:
                return smaller
            if count < k:
                start = larger  # search higher
            else:
                end = smaller  # search lower

        return start

In [8]:
# Sanity check on a 3 x 3 matrix: the sorted values begin 1, 3, 5, ..., so k = 3 gives 5.
instance = Solution()
instance.kthSmallest(matrix=[[1, 3, 5], [6, 7, 12], [11, 14, 14]], k=3)

5