## Merge Sort

In [33]:
from typing import List, Optional

In [34]:
def merge(left: List[int], right: List[int], asc: bool) -> List[int]:
    """Merge two already-sorted lists into a new sorted list.

    The merge is stable: when two values compare equal, the one from ``left``
    is emitted first. Merge sort relies on this to preserve the relative order
    of equal elements.

    Args:
        left: A list sorted in the direction given by ``asc``.
        right: A list sorted in the direction given by ``asc``.
        asc: True if both inputs are ascending, False if both are descending.

    Returns:
        A new list containing every element of ``left`` and ``right``.
        Neither input is modified.
    """
    n, m = len(left), len(right)
    merged: List[int] = []
    lp = rp = 0
    while lp < n and rp < m:
        # Take from ``right`` only when it is strictly ahead of ``left``, so
        # ties go to ``left`` and the merge stays stable.
        right_first = right[rp] < left[lp] if asc else right[rp] > left[lp]
        if right_first:
            merged.append(right[rp])
            rp += 1
        else:
            merged.append(left[lp])
            lp += 1
    # At most one of these remainders is non-empty, and it is already sorted.
    merged.extend(left[lp:])
    merged.extend(right[rp:])
    return merged

In [35]:
def merge_sort(arr: List[int], asc: bool = True) -> List[int]:
    """Return a new list with the elements of ``arr`` sorted by recursive merge sort.

    Args:
        arr: Values to sort; not modified.
        asc: Sort ascending when True (default), descending when False.

    Returns:
        A new sorted list. An empty input yields ``[]``.
    """
    if not arr:
        # Without this guard, sort(0, -1) splits into sort(0, -1) forever and
        # ends in RecursionError.
        return []

    def sort(left: int, right: int) -> List[int]:
        # Sort the inclusive index range arr[left..right] (requires left <= right).
        if left == right:
            return [arr[left]]
        mid = (right - left) // 2 + left
        return merge(sort(left, mid), sort(mid + 1, right), asc)

    return sort(0, len(arr) - 1)

In [36]:
# Descending sort of a small sample.
merge_sort([3, 1, 2, 8, 4, 5], asc=False)

[8, 5, 4, 3, 2, 1]

In [37]:
# Ascending sort (the default direction).
merge_sort([3, 1, 2, 8, 4, 5])

[1, 2, 3, 4, 5, 8]

In [38]:
def merge_sort_iterative(arr: List[int], asc: bool = True) -> List[int]:
    """Sort ``arr`` with top-down merge sort driven by an explicit stack.

    This produces the same splits as ``merge_sort`` but avoids Python
    recursion. Each stack entry is an inclusive index range
    ``(left, right)``. A range is merged once both of its halves have been
    sorted.

    Args:
        arr: Values to sort; not modified.
        asc: Sort ascending when True (default), descending when False.

    Returns:
        A new sorted list. An empty input yields ``[]``.
    """
    if not arr:
        # Without this guard, the range (0, -1) keeps pushing copies of itself
        # and the loop never terminates.
        return []

    stack = [(0, len(arr) - 1)]
    # Sorted results for ranges whose parent has not consumed them yet.
    done = {}
    while stack:
        left, right = stack[-1]
        if left == right:
            stack.pop()
            done[(left, right)] = [arr[left]]
            continue

        mid = (right - left) // 2 + left
        lower, upper = (left, mid), (mid + 1, right)
        if lower in done and upper in done:
            stack.pop()
            # pop() both releases the children's lists and fetches them.
            done[(left, right)] = merge(done.pop(lower), done.pop(upper), asc)
        else:
            # Push the upper half first so the lower half is processed first.
            if upper not in done:
                stack.append(upper)
            if lower not in done:
                stack.append(lower)
    return done[(0, len(arr) - 1)]

In [39]:
# Ascending sort without recursion.
merge_sort_iterative([3, 1, 2, 8, 4, 5])

[1, 2, 3, 4, 5, 8]

In [40]:
# Descending sort without recursion.
merge_sort_iterative([3, 1, 2, 8, 4, 5], asc=False)

[8, 5, 4, 3, 2, 1]

In [41]:
class Node:
    """A singly linked list node holding an ``int`` value."""

    def __init__(self, val: int, next: "Optional[Node]" = None):
        # ``next`` shadows the builtin, but it is the keyword name callers use,
        # so it is kept for compatibility.
        self.val = val
        self.next = next

    def __str__(self) -> str:
        """Render this node and all following nodes as a list, e.g. ``[1, 2]``.

        NOTE(review): a cyclic list would make this loop forever.
        """
        values = []
        node = self
        while node:
            values.append(node.val)
            node = node.next
        return str(values)

    def __repr__(self) -> str:
        # Same list-style rendering as __str__, so notebook cells display nicely.
        return str(self)

In [42]:
def merge_linked_list(
    left: "Optional[Node]", right: "Optional[Node]", asc: bool
) -> "Optional[Node]":
    """Merge two already-sorted linked lists into one sorted list.

    The merge relinks the existing nodes in place and creates no new nodes,
    so both inputs are consumed. It is iterative, so stack usage does not
    grow with list length.

    On equal values the node from ``left`` comes first, which makes the
    merge stable.

    Args:
        left: Head of the first sorted list, or ``None``.
        right: Head of the second sorted list, or ``None``.
        asc: True if both lists are sorted ascending, False if descending.

    Returns:
        Head of the merged list, or ``None`` when both inputs are empty.
    """
    if not left:
        return right
    if not right:
        return left

    def right_first(a, b) -> bool:
        # Strict comparison: ties keep the ``left`` node first.
        return a.val > b.val if asc else a.val < b.val

    if right_first(left, right):
        head, right = right, right.next
    else:
        head, left = left, left.next
    tail = head

    while left and right:
        if right_first(left, right):
            tail.next, right = right, right.next
        else:
            tail.next, left = left, left.next
        tail = tail.next

    # Append whichever list still has nodes; it is already sorted.
    tail.next = left if left else right
    return head

In [43]:
# Merge two descending lists into one descending list.
merge_linked_list(
    Node(5, Node(3, Node(2))),
    Node(6, Node(4, Node(1))),
    asc=False,
)

[6, 5, 4, 3, 2, 1]

In [44]:
def merge_sort_linked_list(
    current: "Optional[Node]", n: Optional[int] = None, asc: bool = True
) -> "Optional[Node]":
    """Sort the first ``n`` nodes of a linked list with merge sort.

    Nodes are relinked in place. Only the first ``n`` nodes are sorted; the
    link from them to any later node is cut, so later nodes are not part of
    the result.

    Args:
        current: Head of the list, or ``None``.
        n: Number of nodes to sort. ``None`` (default) counts every node
            reachable from ``current``. It must not exceed the list length,
            otherwise the split walks off the end and raises AttributeError.
        asc: Sort ascending when True (default), descending when False.

    Returns:
        Head of the sorted list, or ``None`` when ``n < 1``.
    """
    if n is None:
        n = 0
        node = current
        while node:
            n += 1
            node = node.next

    if n < 1:
        return None

    if n == 1:
        # Detach the node so each sorted half is a properly terminated list.
        current.next = None
        return current

    # Find the head of the second half before the first half is sorted and
    # its links are cut.
    mid = n // 2
    right = current
    for _ in range(mid):
        right = right.next

    return merge_linked_list(
        merge_sort_linked_list(current, mid, asc),
        merge_sort_linked_list(right, n - mid, asc),
        asc,
    )

In [45]:
# Sort all seven nodes of a sample list in ascending order.
merge_sort_linked_list(
    Node(3, Node(1, Node(7, Node(4, Node(9, Node(0, Node(2))))))),
    7,
    asc=True,
)

[0, 1, 2, 3, 4, 7, 9]