Define functions

In [8]:
from dataclasses import dataclass
from typing import Union

@dataclass
class Record:
    """An item to be sorted: ``key`` is the sort key, ``id`` tags the item.

    The sorting functions in this notebook compare records by ``key`` only;
    ``id`` is carried along so the relative order of equal-key records can be
    inspected after sorting.
    """
    key: int
    # The notebook builds records with both letter ids ('A', 'B', ...) and
    # integer indices, so either type is accepted here.
    id: Union[int, str]

In [9]:
def merge_sort(arr):
    """Return a new list holding the items of ``arr`` sorted by ``.key``.

    The input is split in half, each half is sorted recursively, and the two
    sorted halves are combined with ``merge``. The input sequence is never
    modified.

    Args:
        arr: Sequence of objects exposing a ``.key`` attribute.

    Returns:
        A new list with the same items in non-decreasing ``.key`` order.
    """
    if len(arr) <= 1:
        # Copy instead of returning ``arr`` itself so the result never aliases
        # the caller's list, matching the longer-input path which always
        # builds a fresh list.
        return list(arr)

    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])

    return merge(left, right)


def merge(left, right):
    """Combine two key-sorted sequences into one key-sorted list.

    Args:
        left: Items already ordered by ``.key``.
        right: Items already ordered by ``.key``.

    Returns:
        A new list containing every item of ``left`` and ``right`` ordered by
        ``.key``. When keys tie, the item from ``left`` is emitted first.
    """
    merged = []
    li, ri = 0, 0
    n_left, n_right = len(left), len(right)

    while li < n_left and ri < n_right:
        # "<=" (not "<") is what keeps the merge stable: on a tie the element
        # from the left half wins.
        take_left = left[li].key <= right[ri].key
        if take_left:
            merged.append(left[li])
            li += 1
        else:
            merged.append(right[ri])
            ri += 1

    # At most one of these tails is non-empty.
    merged += left[li:]
    merged += right[ri:]
    return merged

In [10]:
def quick_sort(arr, low=0, high=None):
    """Sort ``arr[low..high]`` (inclusive) in place by ``.key`` using quicksort.

    Args:
        arr: Mutable sequence of objects exposing a ``.key`` attribute.
        low: Index of the first element of the range to sort.
        high: Index of the last element of the range to sort; ``None`` means
            the last index of ``arr``.

    Returns:
        None. ``arr`` is reordered in place.
    """
    if high is None:
        high = len(arr) - 1

    # Recurse only into the smaller side and keep looping on the larger one.
    # ``partition`` sends every key <= pivot to the left, so long runs of equal
    # keys (or already-sorted input) give very lopsided splits; recursing on
    # both sides would then nest one call per element and can raise
    # RecursionError. Handling the smaller side recursively bounds the depth
    # at O(log n) without changing the final arrangement, because the two
    # sides are disjoint ranges sorted independently.
    while low < high:
        pivot_index = partition(arr, low, high)
        if pivot_index - low < high - pivot_index:
            quick_sort(arr, low, pivot_index - 1)
            low = pivot_index + 1
        else:
            quick_sort(arr, pivot_index + 1, high)
            high = pivot_index - 1


def partition(arr, low, high):
    """Partition ``arr[low..high]`` in place around the key of ``arr[high]``.

    Args:
        arr: Mutable sequence of objects exposing a ``.key`` attribute.
        low: Index of the first element of the range.
        high: Index of the last element of the range; its key is the pivot.

    Returns:
        The final index of the pivot element. Afterwards every element left of
        that index within the range has key <= pivot key, and every element to
        its right has a larger key.
    """
    pivot_key = arr[high].key
    boundary = low  # next slot for an element whose key is <= pivot_key

    scan = low
    while scan < high:
        if arr[scan].key <= pivot_key:
            # This swap can move a record past another record with the same
            # key, which is why the sort is not stable.
            arr[boundary], arr[scan] = arr[scan], arr[boundary]
            boundary += 1
        scan += 1

    # Drop the pivot between the "<= pivot" and "> pivot" groups.
    arr[boundary], arr[high] = arr[high], arr[boundary]
    return boundary

In [11]:
# Small fixture: three records share key 2, so the order of A, B and D after
# sorting shows whether equal keys kept their original relative order.
data = [
    Record(key=2, id='A'),
    Record(key=2, id='B'),
    Record(key=1, id='C'),
    Record(key=2, id='D'),
]

Test Case 1: small example with duplicate keys

In [12]:
# merge_sort builds and returns a new list; `data` itself is not modified.
# The records with key 2 (A, B, D) should keep their original relative order.
sorted_merge = merge_sort(data)
for r in sorted_merge:
    print(r)

Record(key=1, id='C')
Record(key=2, id='A')
Record(key=2, id='B')
Record(key=2, id='D')


In [13]:
# quick_sort reorders its argument in place, so sort a shallow copy and leave
# `data` intact. The swaps in partition may reorder records with equal keys.
data_qs = data.copy()
quick_sort(data_qs)
for r in data_qs:
    print(r)

Record(key=1, id='C')
Record(key=2, id='B')
Record(key=2, id='A')
Record(key=2, id='D')


Test Case 2: large random dataset

In [14]:
import random

def generate_dataset(n=1_000_000, key_range=1000, seed=None):
    """Build ``n`` records with random keys and sequential ids.

    Args:
        n: Number of records to generate.
        key_range: Upper bound for keys; each key is drawn with
            ``randint(0, key_range)``, so both endpoints are possible.
        seed: Optional seed. When given, keys come from a private
            ``random.Random(seed)`` so the same seed always yields the same
            dataset. When ``None``, the module-level ``random`` generator is
            used.

    Returns:
        A list of ``Record`` objects whose ``id`` is the position ``0..n-1``
        at which the record was generated.
    """
    # Fall back to the shared module-level generator when no seed is given so
    # existing callers (and any earlier random.seed(...) call) behave as before.
    rng = random if seed is None else random.Random(seed)
    return [Record(rng.randint(0, key_range), i) for i in range(n)]

In [15]:
def is_sorted(arr):
    """Return True when ``arr`` is in non-decreasing ``.key`` order.

    Empty and single-element sequences count as sorted.
    """
    for idx in range(len(arr) - 1):
        in_order = arr[idx].key <= arr[idx + 1].key
        if not in_order:
            # Stop at the first adjacent pair that is out of order.
            return False
    return True

In [16]:
def is_stable(arr):
    """Check that, for each key, ids never decrease in order of appearance.

    The check compares each record's ``id`` with the ``id`` of the previous
    record that had the same ``key``. It is therefore a stability test only
    when the unsorted input had ids increasing with position.

    Returns:
        False as soon as a record's id is smaller than the id of the previous
        record with the same key; True otherwise.
    """
    previous_id = {}
    for record in arr:
        key = record.key
        if key in previous_id:
            if record.id < previous_id[key]:
                return False
        # Remember the most recent id seen for this key.
        previous_id[key] = record.id
    return True

In [None]:
# Seed the global RNG so the generated dataset, and therefore the results
# printed below, are reproducible on Restart & Run All.
random.seed(42)

# Use a distinct name so the small `data` list from Test Case 1 is not
# overwritten; re-running the earlier cells would otherwise sort and print
# the whole large dataset.
large_data = generate_dataset()

# Merge Sort test
merge_sorted = merge_sort(large_data)
print(is_sorted(merge_sorted))   # Expected: True
print(is_stable(merge_sorted))   # Expected: True

# Quick Sort test (sorts in place, so work on a copy)
quick_sorted = large_data.copy()
quick_sort(quick_sorted)
print(is_sorted(quick_sorted))   # Expected: True
print(is_stable(quick_sorted))   # Expected: False

True
True
True
False
