In [1]:
# Sample input: k sorted integer lists that the cells below merge into one.
lists = [[1, 4, 5], [1, 3, 4], [2, 6]]
# Expected merged output: [1, 1, 2, 3, 4, 4, 5, 6]

In [2]:
class ListNode:
    """Node of a singly linked list.

    Attributes:
        val: the value stored in this node.
        next: the following node, or None when nothing follows.
    """

    def __init__(self, num):
        # A new node always starts detached; callers link it via `next`.
        self.val, self.next = num, None

def build_linked_list(l):
    """Build a singly linked list holding the values of `l` in order.

    Args:
        l: sequence of values; a falsy value (empty list, None) is allowed.

    Returns:
        The first ListNode of the new chain, or None when `l` is falsy.
    """
    if not l:
        return None
    values = iter(l)
    # The first value becomes the head; every later value is appended to the tail.
    first = last = ListNode(next(values))
    for value in values:
        last.next = ListNode(value)
        last = last.next
    return first

def linked_list_to_list(head):
    """Collect the values of a linked list into a Python list.

    Args:
        head: first node of the chain (anything exposing `val` and `next`),
            or None for an empty list.

    Returns:
        The node values in traversal order; [] when `head` is falsy.
    """
    values = []
    cursor = head
    # Follow the `next` links until the chain ends.
    while cursor:
        values.append(cursor.val)
        cursor = cursor.next
    return values

# Turn every input list into its own linked list; heads[i] is the first node
# built from lists[i] (None when lists[i] is empty).
heads = [build_linked_list(values) for values in lists]

    

In [3]:
def mergeTwoLists(head1, head2):
    """Merge two sorted linked lists into one sorted list.

    The existing nodes are relinked in place; no value nodes are copied, so
    both inputs are consumed by the call.

    Args:
        head1: head of the first sorted list, or None.
        head2: head of the second sorted list, or None.

    Returns:
        Head of the merged list, or None when both inputs are empty.
    """
    sentinel = ListNode(0)  # placeholder so the first append needs no special case
    last = sentinel
    left, right = head1, head2
    while left and right:
        # On ties the node from the first list is taken first.
        if left.val <= right.val:
            last.next, left = left, left.next
        else:
            last.next, right = right, right.next
        last = last.next
    # At most one side still has nodes; attach the remainder as-is.
    last.next = left if left else right
    return sentinel.next
    

In [4]:
def merge_k_lists_naive(heads):
    """Merge k sorted linked lists by folding them together one at a time.

    Each list is merged into the running result with mergeTwoLists, so the
    input nodes are relinked in place.

    Args:
        heads: list of list heads (entries may be None); may be empty or None.

    Returns:
        Head of the merged list, or None when there is nothing to merge.
    """
    merged = None
    if heads:
        for current in heads:
            merged = mergeTwoLists(merged, current)
    return merged

# Merging relinks the input nodes in place, so merge freshly built lists and
# leave the shared `heads` intact for the cells that follow.
# With n = total number of nodes and k = number of lists, the merge itself is
# time O(n * k), extra space O(1).
merged_head_naive = merge_k_lists_naive([build_linked_list(values) for values in lists])
linked_list_to_list(merged_head_naive)


[1, 1, 2, 3, 4, 4, 5, 6]

In [5]:
import heapq

def merge_k_lists_heap(heads):
    """Merge k sorted linked lists using a min-heap of the current front nodes.

    Heap entries are (value, list index, node) tuples; the list index breaks
    ties between equal values so the nodes themselves are never compared.
    The input nodes are relinked in place.

    Args:
        heads: list of list heads (entries may be None); may be empty or None.

    Returns:
        Head of the merged list, or None when there is nothing to merge.
    """
    if not heads:
        return None

    # Seed the heap with the first node of every non-empty list.
    frontier = [(node.val, idx, node) for idx, node in enumerate(heads) if node]
    heapq.heapify(frontier)

    sentinel = ListNode(0)
    last = sentinel
    while frontier:
        _, idx, smallest = heapq.heappop(frontier)
        last.next = smallest
        last = smallest
        # Replace the popped node with its successor from the same list.
        successor = smallest.next
        if successor:
            heapq.heappush(frontier, (successor.val, idx, successor))
    return sentinel.next

# Rebuild the inputs from `lists`: an earlier merge leaves the entries of
# `heads` pointing into one shared chain, and feeding those aliased heads to
# the heap merge links a node to itself, so the loop never terminates.
# With n = total number of nodes and k = number of lists:
# time O(n * log k), space O(k) for the heap.
merged_head_heap = merge_k_lists_heap([build_linked_list(values) for values in lists])
linked_list_to_list(merged_head_heap)
    

[1, 1, 2, 3, 4, 4, 5, 6]

In [6]:
def merge_k_lists_divide_conquer(heads):
    """Merge k sorted linked lists by recursively merging the two halves.

    Each recursive call receives a sliced copy of its half of `heads`; the
    linked-list nodes themselves are relinked in place by mergeTwoLists.

    Args:
        heads: list of list heads (entries may be None); may be empty or None.

    Returns:
        Head of the merged list, or None when there is nothing to merge.
    """
    if not heads:
        return None
    if len(heads) == 1:
        # A single list is already merged.
        return heads[0]
    half = len(heads) // 2
    return mergeTwoLists(
        merge_k_lists_divide_conquer(heads[:half]),
        merge_k_lists_divide_conquer(heads[half:]),
    )

# Rebuild the inputs from `lists`: merging relinks nodes in place, so reusing
# `heads` after an earlier merge would pass heads that alias one shared chain.
# With n = total number of nodes and k = number of lists: time O(n * log k).
# Extra space is O(k): the slices alive along one recursion path have sizes
# k/2 + k/4 + ... < k, plus O(log k) stack frames.
merged_head_divide_conquer = merge_k_lists_divide_conquer(
    [build_linked_list(values) for values in lists]
)
linked_list_to_list(merged_head_divide_conquer)


[1, 1, 2, 3, 4, 4, 5, 6]

In [7]:
def merge_k_lists_divide_conquer_efficient(heads):
    """Merge k sorted linked lists by divide and conquer over index ranges.

    Recursion works on (first, last) index pairs into `heads` rather than on
    sliced copies. The linked-list nodes are relinked in place by
    mergeTwoLists.

    Args:
        heads: list of list heads (entries may be None); may be empty or None.

    Returns:
        Head of the merged list, or None when there is nothing to merge.
    """
    if not heads:
        return None

    def merge_range(first, last):
        # Merge heads[first..last], both ends inclusive.
        if first == last:
            return heads[first]
        middle = (first + last) // 2
        lower = merge_range(first, middle)
        upper = merge_range(middle + 1, last)
        return mergeTwoLists(lower, upper)

    return merge_range(0, len(heads) - 1)

# Rebuild the inputs from `lists`: merging relinks nodes in place, so reusing
# `heads` after an earlier merge would pass heads that alias one shared chain.
# With n = total number of nodes and k = number of lists:
# time O(n * log k), space O(log k) for the recursion stack.
merged_head_divide_conquer_efficient = merge_k_lists_divide_conquer_efficient(
    [build_linked_list(values) for values in lists]
)
linked_list_to_list(merged_head_divide_conquer_efficient)


[1, 1, 2, 3, 4, 4, 5, 6]