## Merge k sorted lists in O(n log k) time (n = total number of elements)

In [17]:
import math

In [18]:
class tree:
    """Array-backed storage for a 1-indexed binary heap.

    ``n[0]`` holds a sentinel and the heap nodes live in ``n[1..size]``; the
    children of node ``i`` are at ``2*i`` and ``2*i + 1``. Each node is a tuple
    whose first element is the key the heap operations compare; later elements
    are payload. The constructor only copies the nodes -- it does not establish
    the heap property, so run ``minheap().build`` on the result if ``n`` is not
    already a heap.
    """

    def __init__(self, n: list):
        # Sentinel key is -inf so no real key can ever compare below it; a
        # finite sentinel (e.g. -100) lets smaller keys be sifted into slot 0.
        # ``[...] + n`` builds a new list, so the caller's list is not mutated.
        self.n = [(-math.inf, 0, 0)] + n
        self.size = len(n)

class minheap:
    """Min-heap operations on a 1-indexed ``tree``.

    The heap occupies ``A.n[1..A.size]``; ``A.n[0]`` is a sentinel that these
    methods never read or move. Nodes are tuples ordered by their first element
    only; remaining elements are carried along as payload. Every method mutates
    ``A`` in place and also returns it (``extract_min`` returns ``(A, node)``).
    """

    def _left(self, i: int) -> int:
        """Index of the left child of node ``i``."""
        return 2 * i

    def _right(self, i: int) -> int:
        """Index of the right child of node ``i``."""
        return 2 * i + 1

    def _parent(self, i: int) -> int:
        """Index of the parent of node ``i`` (0 for the root)."""
        return i // 2

    def maintenance(self, A: "tree", i: int) -> "tree":
        """Float node ``i`` down until neither child has a smaller key.

        Assumes the subtrees rooted at ``i``'s children already satisfy the
        min-heap property.
        """
        while True:
            smallest = i
            l = self._left(i)
            r = self._right(i)
            if l <= A.size and A.n[l][0] < A.n[smallest][0]:
                smallest = l
            if r <= A.size and A.n[r][0] < A.n[smallest][0]:
                smallest = r
            if smallest == i:
                return A
            A.n[i], A.n[smallest] = A.n[smallest], A.n[i]
            i = smallest

    def build(self, A: "tree") -> "tree":
        """Rearrange ``A.n[1..A.size]`` into a min-heap, bottom-up."""
        # Nodes after size // 2 are leaves and are already trivial heaps.
        for i in range(A.size // 2, 0, -1):
            self.maintenance(A, i)
        return A

    def decrease(self, A: "tree", i: int, k: tuple) -> "tree":
        """Replace node ``i`` with ``k`` if ``k[0]`` is not larger, then sift up.

        If ``k[0]`` exceeds the current key the heap is left unchanged (raising
        a key would need a sift-down). Always returns ``A``.
        """
        if k[0] > A.n[i][0]:
            return A
        A.n[i] = k
        # Stop at the root (i == 1) so the sentinel in A.n[0] is never compared
        # with or swapped into the heap, whatever key is being inserted.
        while i > 1:
            parent = self._parent(i)
            if not (A.n[i][0] < A.n[parent][0]):
                break
            A.n[i], A.n[parent] = A.n[parent], A.n[i]
            i = parent
        return A

    def insert(self, A: "tree", k: tuple) -> "tree":
        """Add node ``k`` (keyed by ``k[0]``) to the heap."""
        A.size += 1
        # Placeholder with an infinite key; decrease() overwrites it with k
        # and sifts k up to its position.
        A.n.append((math.inf,))
        return self.decrease(A, A.size, k)

    def extract_min(self, A: "tree") -> tuple:
        """Remove and return the node with the smallest key.

        Returns ``(A, node)``. Raises IndexError if the heap is empty.
        """
        if A.size < 1:
            raise IndexError("extract_min from an empty heap")
        # Swap the minimum to the end and drop it with an O(1) pop -- slicing
        # (A.n[:-1]) would copy the whole list on every extraction.
        A.n[1], A.n[-1] = A.n[-1], A.n[1]
        smallest = A.n.pop()
        A.size -= 1
        self.maintenance(A, 1)
        return A, smallest

In [19]:
def merge(A: list) -> list:
    """Merge k individually sorted lists into one sorted list.

    Keeps at most one ``(value, list_index)`` node per input list in a
    ``minheap``-managed ``tree``. Each of the n elements is inserted and
    extracted exactly once on a heap of at most k nodes, i.e. O(lg k) per
    operation for a binary heap and O(n lg k) overall.

    Args:
        A: a list of k lists, each expected to be sorted ascending. Empty
            lists are allowed. The input lists are not modified.

    Returns:
        A new list containing every element of every list in ascending order.
    """
    min_heap = tree([])
    heap = minheap()
    result = []
    # next_pos[k] is the index of the next element of A[k] not yet pushed.
    # Advancing an index is O(1); list.pop(0) would cost O(len(A[k])) per
    # element and would also empty the caller's lists.
    next_pos = [0] * len(A)

    # Seed the heap with the head of every non-empty list.
    for k, lst in enumerate(A):
        if lst:
            heap.insert(min_heap, (lst[0], k))
            next_pos[k] = 1

    while min_heap.size > 0:
        min_heap, smallest = heap.extract_min(min_heap)
        value, k = smallest[0], smallest[1]
        result.append(value)
        # Refill from the list that supplied the minimum, if it has more.
        # The heap always holds the smallest remaining element of every
        # non-exhausted list, so the next overall minimum is always in it.
        if next_pos[k] < len(A[k]):
            heap.insert(min_heap, (A[k][next_pos[k]], k))
            next_pos[k] += 1
    return result

In [20]:
# Demo input: three ascending lists of different lengths, with values
# repeated across lists; the cell displays the merged result.
A = [
    [0, 3, 4, 7, 7, 11, 24, 25],
    [3, 4, 6, 8, 9, 14, 18, 26, 69, 89],
    [2, 4, 7, 20],
]
merge(A)

[0, 2, 3, 3, 4, 4, 4, 6, 7, 7, 7, 8, 9, 11, 14, 18, 20, 24, 25, 26, 69, 89]