In [31]:
# Definition for singly-linked list.
class ListNode:
    """Node of a singly-linked list holding a value ``val`` and a link ``next``."""

    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next  # following node, or None at the tail

    def __repr__(self):
        """Render the chain starting at this node as ``'v1 -> v2 -> ... -> None'``.

        The chain is walked iteratively instead of recursing once per node, so
        long lists do not exhaust the interpreter stack. If the chain loops back
        to a node that was already rendered, the output ends with ``'...'``
        instead of looping forever.
        """
        parts = []
        seen = set()
        node = self
        while isinstance(node, ListNode):
            if id(node) in seen:
                # Cycle detected: stop rather than render the loop again.
                parts.append('...')
                return ' -> '.join(parts)
            seen.add(id(node))
            parts.append(f'{node.val}')
            node = node.next
        # The terminating link (normally None) is formatted as-is.
        parts.append(f'{node}')
        return ' -> '.join(parts)

    def __str__(self):
        return self.__repr__()
    
from typing import Optional
import sys
# Raise the recursion limit so that recursive helpers operating on long linked
# lists do not stop with RecursionError.
# NOTE(review): a limit this high can let very deep recursion overflow the C
# stack and crash the interpreter instead of raising RecursionError — confirm
# it is still needed.
sys.setrecursionlimit(10**6)


class Solution:
    """Sort a singly-linked list of ``ListNode`` in ascending order with top-down merge sort."""

    def merge(self, left: Optional[ListNode], right: Optional[ListNode]) -> Optional[ListNode]:
        """Merge two ascending linked lists into one ascending list.

        Nodes are relinked in place (only a temporary sentinel node is
        allocated), so both input lists are consumed. On equal values the node
        from ``left`` is taken first, which keeps the merge, and therefore
        ``sortList``, stable.

        Args:
            left: head of the first sorted list, or None.
            right: head of the second sorted list, or None.

        Returns:
            Head of the merged list, or None if both inputs are empty.
        """
        sentinel = ListNode()  # dummy head so appending the first node needs no special case
        tail = sentinel
        while left is not None and right is not None:
            # `<=` (not `<`) prefers `left` on ties, which makes the merge stable.
            if left.val <= right.val:
                tail.next = left
                left = left.next
            else:
                tail.next = right
                right = right.next
            tail = tail.next
        # At most one list still has nodes; it is already sorted, so splice it on whole.
        tail.next = left if left is not None else right
        return sentinel.next

    def sort(self, node: Optional[ListNode], n: int) -> Optional[ListNode]:
        """Sort the ``n`` nodes starting at ``node`` and return the new head.

        The sorted run is detached from whatever followed its ``n``-th node,
        because each single-node base case clears its ``next`` link.

        Args:
            node: head of the run to sort; it must have at least ``n`` nodes.
            n: number of nodes to sort.

        Returns:
            Head of the sorted run, or None when ``n <= 0``.
        """
        if n <= 0:
            return None
        if n == 1:
            node.next = None  # cut this node off from the rest of the list
            return node

        mid = n // 2
        # Find the start of the second half *before* recursing: sorting the
        # first half relinks its nodes, so walking from `node` afterwards would be wrong.
        mid_node = node
        for _ in range(mid):
            mid_node = mid_node.next

        left = self.sort(node, mid)
        right = self.sort(mid_node, n - mid)
        return self.merge(left, right)

    def sortList(self, head: Optional[ListNode]) -> Optional[ListNode]:
        """Sort the list starting at ``head`` in ascending order and return its new head.

        The sort is stable and relinks nodes in place. ``head`` itself may end
        up anywhere in the result. It runs in O(n log n) time: each recursion
        level of ``sort`` walks and merges O(n) nodes over O(log n) levels.
        """
        length = 0
        node = head
        while node is not None:
            length += 1
            node = node.next
        return self.sort(head, length)

In [32]:
def _build_list(values):
    """Build a singly-linked list holding *values* in order; return its head (None if empty)."""
    head = None
    for value in reversed(values):
        head = ListNode(value, head)
    return head


def _to_values(node):
    """Collect the values of the linked list starting at *node* into a Python list."""
    values = []
    while node is not None:
        values.append(node.val)
        node = node.next
    return values


# test 1
head = _build_list([4, 2, 1, 3])
print(head)
result = Solution().sortList(head)
print(result)
assert _to_values(result) == [1, 2, 3, 4]
print()


# test 2
head = _build_list([-1, 5, 3, 4, 0])
print(head)
result = Solution().sortList(head)
print(result)
assert _to_values(result) == [-1, 0, 3, 4, 5]
print()

# test 3
head = None
print(head)
result = Solution().sortList(head)
print(result)
assert result is None
print()

# test 4
left = _build_list([1])
print(left)
result = Solution().merge(left, None)
print(result)
assert _to_values(result) == [1]
print()

# test 5
left = _build_list([1])
right = _build_list([2])
print(left, right)
result = Solution().merge(left, right)
print(result)
assert _to_values(result) == [1, 2]
print()

# test 6
left = _build_list([2])
right = _build_list([1])
print(left, right)
result = Solution().merge(left, right)
print(result)
assert _to_values(result) == [1, 2]
print()

# test 7
left = _build_list([1, 3, 5, 7])
right = _build_list([2, 4, 6, 8])
print(left, right)
result = Solution().merge(left, right)
print(result)
assert _to_values(result) == [1, 2, 3, 4, 5, 6, 7, 8]
print()

4 -> 2 -> 1 -> 3 -> None
1 -> 2 -> 3 -> 4 -> None

-1 -> 5 -> 3 -> 4 -> 0 -> None
-1 -> 0 -> 3 -> 4 -> 5 -> None

None
None

1 -> None
1 -> None

1 -> None 2 -> None
1 -> 2 -> None

2 -> None 1 -> None
1 -> 2 -> None

1 -> 3 -> 5 -> 7 -> None 2 -> 4 -> 6 -> 8 -> None
1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 -> 8 -> None

