# Kth Smallest Element in a Sorted Matrix

**Problem**:
Given an `n x n` matrix where each row and each column is sorted in ascending order, return the kth smallest element in the matrix. Note that this means the kth smallest element in sorted order, not the kth distinct element. The solution must have a memory complexity better than `O(n^2)`.

**Examples**:

1. **Input**:
   matrix = `[[1,5,9],[10,11,13],[12,13,15]]`, `k = 8`
   
   **Output**: `13`
   
   **Explanation**:
   The elements in the matrix are `[1,5,9,10,11,12,13,13,15]`, and the 8th smallest number is `13`.

2. **Input**:
   matrix = `[[-5]]`, `k = 1`
   
   **Output**: `-5`
   
   **Explanation**:
   The only element in the matrix is `-5`, which is also the smallest.

**Constraints**:
- `n == matrix.length == matrix[i].length`
- `1 <= n <= 300`
- `-10^9 <= matrix[i][j] <= 10^9`
- All the rows and columns of the matrix are sorted in non-decreasing order.
- `1 <= k <= n^2`
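
To make the "sorted order, not distinct" rule concrete, here is a minimal brute-force reference. The name `kth_smallest_reference` is introduced here for illustration only. It flattens the whole matrix, so it uses `O(n^2)` extra memory and does not meet the memory requirement; it is meant only as an oracle for checking other solutions on small inputs.

```python
from typing import List


def kth_smallest_reference(matrix: List[List[int]], k: int) -> int:
    """Flatten, sort, and index. Correct but uses O(n^2) extra memory."""
    flat = sorted(value for row in matrix for value in row)
    return flat[k - 1]  # k is 1-based


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]

# Duplicates occupy separate ranks: the 7th and 8th smallest are both 13.
ranks = {k: kth_smallest_reference(matrix, k) for k in (7, 8, 9)}
print(ranks)  # {7: 13, 8: 13, 9: 15}
```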


In [1]:
from typing import List
import heapq

In [11]:
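class Solution1:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        # Brute force: push every element onto one min heap, then pop k times.
        # The heap holds all n^2 elements, so this uses O(n^2) memory.
        heap = []
        for i in range(len(matrix)):
            for j in range(len(matrix[0])):
                heapq.heappush(heap, matrix[i][j])
        for i in range(k):
            ret = heapq.heappop(heap)
        return ret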


a = [[1,5,9],[10,11,13],[12,13,15]]
s1 = Solution1()
s1.kthSmallest(a, 8)

13

In [15]:
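'''
    Merge sorting: k-way merge of the sorted rows (with a min heap)
    Time complexity: O(n + k log n). Heapifying the first column is O(n);
    each of the k pops/pushes works on a heap of at most n entries.
'''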

class Solution2:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        ret = 0
        heap = [(num[0], x, 0) for x, num in enumerate(matrix)]  # (value, x, y): x = row index, y = column index within that row
        heapq.heapify(heap)
        for i in range(k):
            ret, x, y = heapq.heappop(heap)
            if y < len(matrix[x])-1:
                heapq.heappush(heap, (matrix[x][y+1], x, y+1))
        return ret

s2 = Solution2()
s2.kthSmallest(a, 8)

13
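
The same row-merging idea can also be sketched with the standard library's `heapq.merge`, which lazily merges already-sorted iterables while keeping only one pending entry per row. The function name `kth_smallest_merge` is an illustrative assumption, not part of the notebook above.

```python
import heapq
from itertools import islice
from typing import List


def kth_smallest_merge(matrix: List[List[int]], k: int) -> int:
    # Each row is already sorted, so the rows can be merged lazily.
    merged = heapq.merge(*matrix)
    # Skip the first k - 1 values and take the next one (k is 1-based).
    return next(islice(merged, k - 1, None))


print(kth_smallest_merge([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
```

Because the merge is lazy, only the first `k` values are ever produced, which mirrors the explicit pop/push loop in `Solution2`.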

In [43]:
'''
    Binary search on the value range.
    Time complexity: O(n log(r - l)), where l and r are the smallest and largest values in the matrix.
    Each find() call is O(n), and the binary search makes O(log(r - l)) calls.
    This is really fast. Pay extra attention to how binary search is implemented.
    More details on
    https://leetcode.cn/problems/kth-smallest-element-in-a-sorted-matrix/solutions/311472/you-xu-ju-zhen-zhong-di-kxiao-de-yuan-su-by-leetco/
'''

class Solution3:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        def find(mid):
            # Count the elements <= mid by walking from the bottom-left corner.
            ret = 0
            i, j = len(matrix) - 1, 0
            while j <= len(matrix[0]) - 1 and i >= 0:
                if matrix[i][j] <= mid:
                    ret += i+1
                    j += 1
                else:
                    i -= 1

            return ret

        def binary():
            low, high = matrix[0][0], matrix[-1][-1]

            while low<high:
                mid = (low + high) // 2
                num = find(mid)
                if num < k:
                    low = mid + 1  # the answer is strictly greater than mid; plain `low = mid` can loop forever
                else:
                    high = mid

            return low

        return binary()

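'''
    num is the number of elements in the matrix that are <not greater than> mid. If num < k, the answer must be strictly greater than mid; it cannot equal mid.
    So low has to be set to mid+1, to [ensure] that mid increases by at least 1 in the next iteration.
'''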

a = [[1,2],[1,3]]
s3 = Solution3()
s3.kthSmallest(a, 1)

1
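
The counting step and the `mid + 1` update discussed above can be sketched as standalone functions. The names `count_le` and `kth_smallest_bsearch` are illustrative assumptions; the logic follows the bottom-left staircase walk and the value-range binary search from this cell, checked here against a sorted flatten on small inputs.

```python
from typing import List


def count_le(matrix: List[List[int]], target: int) -> int:
    """Count elements <= target by walking from the bottom-left corner."""
    n = len(matrix)
    count = 0
    row, col = n - 1, 0
    while row >= 0 and col < n:
        if matrix[row][col] <= target:
            # Everything above in this column is also <= target.
            count += row + 1
            col += 1
        else:
            row -= 1
    return count


def kth_smallest_bsearch(matrix: List[List[int]], k: int) -> int:
    low, high = matrix[0][0], matrix[-1][-1]
    while low < high:
        mid = (low + high) // 2
        if count_le(matrix, mid) < k:
            low = mid + 1  # fewer than k values are <= mid, so the answer is > mid
        else:
            high = mid     # at least k values are <= mid, so the answer is <= mid
    return low


example = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(count_le(example, 12), count_le(example, 13))  # 6 8
print(kth_smallest_bsearch(example, 8))              # 13
print(kth_smallest_bsearch([[1, 2], [1, 3]], 4))     # 3
```

The last call is a case where updating with `low = mid` would never terminate: with `low = 2` and `high = 3`, `mid` stays at 2 and the count stays at 3, which is below `k = 4`.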
