In [14]:
import numpy as np
import time
from threading import Thread, Lock

In [15]:
# Benchmark input: AMOUNT random integers drawn from [0, AMOUNT).
# Seeding NumPy's legacy global RNG makes `vector` identical on every run.
SEED = 42
AMOUNT = 5 * 10**6

np.random.seed(SEED)
vector = np.random.randint(low=0, high=AMOUNT, size=AMOUNT)

In [16]:
class ParallelMergeSort:
    """Top-down merge sort that sorts the two halves of each split concurrently.

    The input is split in half, both halves are sorted recursively, and the
    results are combined with ``merge``. For every split above ``max_depth``,
    the left half is sorted in a new ``threading.Thread`` while the calling
    thread sorts the right half. Once ``depth >= max_depth``, the remaining
    slice is handed to ``leaf_sort`` instead.

    NOTE: on a CPython build with the GIL, threads do not execute Python
    bytecode in parallel, so the Python-level recursion and merging gain
    nothing from threading. Any speed-up over ``merge_sort`` most likely comes
    from the default ``leaf_sort`` being the C-implemented builtin ``sorted``.
    Pass ``leaf_sort=merge_sort`` for an apples-to-apples comparison.

    Args:
        max_depth: recursion depth at which to stop splitting and call
            ``leaf_sort``. ``None`` means never stop. One thread is then
            spawned per split all the way down, i.e. roughly one thread per
            element, which is impractical for large inputs.
        leaf_sort: callable that takes a slice and returns it sorted.
            Defaults to the builtin ``sorted``.
    """

    def __init__(self, max_depth=None, leaf_sort=sorted):
        self.max_depth = max_depth
        self.leaf_sort = leaf_sort

    def sort(self, arr):
        """Return ``arr`` sorted.

        Inputs of length <= 1 are returned unchanged (the same object).
        Otherwise the result is whatever ``merge`` or ``leaf_sort`` returns,
        which is a ``list`` with the defaults.
        """
        return self._parallel_merge_sort(arr, depth=0)

    def _parallel_merge_sort(self, arr, depth):
        """Recursively sort ``arr``; ``depth`` is the current recursion level."""
        if len(arr) <= 1:
            return arr

        if self.max_depth is not None and depth >= self.max_depth:
            return self.leaf_sort(arr)

        mid = len(arr) // 2
        left = arr[:mid]
        right = arr[mid:]

        # Thread.join() neither returns a value nor re-raises the target's
        # exception, so the worker records one or the other here.
        left_outcome = {}

        def sort_left():
            try:
                left_outcome["result"] = self._parallel_merge_sort(left, depth + 1)
            except BaseException as exc:
                left_outcome["error"] = exc

        worker = Thread(target=sort_left)
        worker.start()
        try:
            # Sort the right half in the current thread rather than spawning a
            # second thread that this one would only sit and wait on.
            right_result = self._parallel_merge_sort(right, depth + 1)
        finally:
            # Always join, so no worker outlives this call even on error.
            worker.join()

        if "error" in left_outcome:
            raise left_outcome["error"]
        return merge(left_outcome["result"], right_result)
    

def merge_sort(arr):
    """Sort ``arr`` with single-threaded, top-down recursive merge sort.

    Sequences with fewer than two elements are returned as-is (same object,
    no copy). Longer inputs are halved by slicing, each half is sorted
    recursively, and the two results are combined with ``merge``.
    """
    n = len(arr)
    if n < 2:
        return arr
    split = n // 2
    front, back = merge_sort(arr[:split]), merge_sort(arr[split:])
    return merge(front, back)

def merge(left, right):
    """Merge two individually sorted sequences into one sorted list.

    The merge is stable: when an element of ``left`` compares equal to one of
    ``right``, the ``left`` element is emitted first. That tie-breaking is what
    makes merge sort a stable sort. Only ``<`` is used for comparisons, as
    with the builtin ``sorted``.

    Args:
        left, right: sorted, indexable, sliceable sequences (lists, or slices
            of a 1-D NumPy array).

    Returns:
        A new ``list`` containing every element of both inputs in sorted order.
    """
    merged = []
    i = j = 0
    n_left, n_right = len(left), len(right)
    while i < n_left and j < n_right:
        # Take from ``right`` only when strictly smaller; ties go to ``left``.
        if right[j] < left[i]:
            merged.append(right[j])
            j += 1
        else:
            merged.append(left[i])
            i += 1
    # At most one input still has elements left, and they are already ordered.
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged

In [17]:
# Benchmark: pure-Python merge sort vs. the threaded ParallelMergeSort.
# NOTE: ParallelMergeSort hands every slice at max_depth to the builtin
# `sorted`, so its speed-up most likely reflects C-level sorting at the
# leaves rather than thread parallelism (CPython's GIL serializes bytecode).
merge_parallel = ParallelMergeSort(max_depth=4)


def _time_call(fn, *args):
    """Call fn(*args) once and return (result, elapsed seconds).

    time.perf_counter() is the clock intended for measuring intervals: it is
    monotonic and high-resolution, unlike time.time(), which follows the
    adjustable system clock.
    """
    started = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - started


usual_sort, sequential_seconds = _time_call(merge_sort, vector)
print(sequential_seconds)

sorted_vector, parallel_seconds = _time_call(merge_parallel.sort, vector)
print(parallel_seconds)
print(np.array_equal(sorted_vector, usual_sort))

# Both implementations share `merge`, so agreeing with each other is not
# proof of correctness; check against an independent reference as well.
assert np.array_equal(usual_sort, np.sort(vector)), "merge_sort result is not sorted correctly"

19.37580132484436
6.163708209991455
True
