In [None]:
!pip install numba

In [None]:
import numpy as np
import time
from numba import cuda, njit, int32

@cuda.jit
def counting_sort_parallel(arr, exp, output, count):
    """Run one counting-sort pass on the decimal digit selected by ``exp``.

    Each thread handles one element of ``arr``. The kernel (1) histograms the
    digits into ``count``, (2) turns the histogram into inclusive running
    totals, and (3) scatters every element into ``output``.

    Args:
        arr: 1-D integer array to read from.
        exp: Power of ten selecting the digit, i.e. ``(arr[i] // exp) % 10``.
        output: 1-D array with at least ``arr.size`` elements that receives
            the scattered values.
        count: 10-element integer array. It must be all zeros on entry and is
            used as scratch space; its contents are consumed by the scatter.

    Limitations:
        * ``cuda.syncthreads()`` only synchronizes the threads of a single
          block. The three phases are therefore only ordered correctly when
          the kernel is launched with one block; with several blocks, other
          blocks may read ``count`` before the running totals are complete.
        * Output slots are handed out in the order in which the atomic
          operations happen to resolve, so elements that share a digit are
          not guaranteed to keep their relative order (the pass is not
          stable).
    """
    tid = cuda.grid(1)
    in_bounds = tid < arr.size

    # Phase 1: histogram of the selected digit.
    if in_bounds:
        cuda.atomic.add(count, (arr[tid] // exp) % 10, 1)

    # Block-level barrier: every thread of this block has counted its digit.
    cuda.syncthreads()

    # Phase 2: a single thread converts the histogram into inclusive totals,
    # so count[d] becomes the number of elements with digit <= d.
    if tid == 0:
        for d in range(1, 10):
            count[d] += count[d - 1]

    # Block-level barrier: the running totals are visible before scattering.
    cuda.syncthreads()

    # Phase 3: scatter. The atomic returns the value stored *before* the
    # decrement, so each thread claims a distinct slot. Reading count[digit]
    # and decrementing it in two separate steps would let several threads
    # with the same digit pick the same slot and overwrite each other.
    if in_bounds:
        digit = (arr[tid] // exp) % 10
        slot = cuda.atomic.sub(count, digit, 1) - 1
        output[slot] = arr[tid]

# Decimal radix: one counting pass per base-10 digit.
_RADIX = 10
# Launch width; also the compile-time size of the shared digit buffer.
_THREADS_PER_BLOCK = 256


@cuda.jit
def _block_digit_histogram(arr, exp, block_hist):
    """Count, for every thread block, how many of its elements have each digit.

    On exit ``block_hist[d, b]`` holds the number of elements handled by block
    ``b`` whose digit ``(value // exp) % 10`` equals ``d``. Must be launched
    with at least ``_RADIX`` threads per block.
    """
    local_hist = cuda.shared.array(shape=_RADIX, dtype=int32)
    tx = cuda.threadIdx.x
    bx = cuda.blockIdx.x
    tid = tx + bx * cuda.blockDim.x

    # The first _RADIX threads clear the per-block histogram.
    if tx < _RADIX:
        local_hist[tx] = 0
    cuda.syncthreads()

    if tid < arr.size:
        cuda.atomic.add(local_hist, (arr[tid] // exp) % _RADIX, 1)
    cuda.syncthreads()

    # Publish this block's column; every entry of block_hist is written by
    # exactly one thread, so the array needs no prior initialization.
    if tx < _RADIX:
        block_hist[tx, bx] = local_hist[tx]


@cuda.jit
def _stable_scatter(arr, exp, block_starts, output):
    """Scatter ``arr`` into ``output`` ordered by digit, preserving input order.

    ``block_starts[d, b]`` is the output index of the first element of block
    ``b`` that has digit ``d``. Within a block, an element's rank is the number
    of lower-numbered threads holding the same digit, which keeps the pass
    stable. Must be launched with at most ``_THREADS_PER_BLOCK`` threads per
    block.
    """
    digits = cuda.shared.array(shape=_THREADS_PER_BLOCK, dtype=int32)
    tx = cuda.threadIdx.x
    bx = cuda.blockIdx.x
    tid = tx + bx * cuda.blockDim.x
    in_bounds = tid < arr.size

    if in_bounds:
        digits[tx] = (arr[tid] // exp) % _RADIX
    cuda.syncthreads()

    if in_bounds:
        digit = digits[tx]
        # Only entries j < tx are read; they belong to in-bounds threads
        # because the out-of-bounds threads are the last ones of the grid.
        rank = 0
        for j in range(tx):
            if digits[j] == digit:
                rank += 1
        output[block_starts[digit, bx] + rank] = arr[tid]


def radix_sort_cuda(arr):
    """Sort a 1-D array of non-negative integers with an LSD radix sort on the GPU.

    Every decimal digit is processed by one stable counting pass made of a
    per-block histogram kernel, an exclusive prefix sum computed on the host,
    and a scatter kernel. Splitting the pass into separate launches gives the
    grid-wide ordering between phases that a barrier inside a single kernel
    (which only covers one block) cannot provide.

    Args:
        arr: 1-D integer array, either a NumPy array or a CUDA device array.
            Values are expected to be non-negative; the number of passes is
            derived from the maximum value. The input is never modified.

    Returns:
        A new host ``numpy.ndarray`` with the elements in ascending order.
        An empty input yields an empty array.
    """
    if cuda.is_cuda_array(arr):
        d_input = cuda.as_cuda_array(arr)
        host_values = d_input.copy_to_host()
    else:
        host_values = np.ascontiguousarray(arr)
        d_input = cuda.to_device(host_values)

    n = host_values.size
    if n == 0:
        return host_values.copy()

    # Plain Python int so that ``max_val // exp`` cannot overflow a fixed-width
    # NumPy scalar once ``exp`` grows past the array's dtype range.
    max_val = int(host_values.max())

    blocks_per_grid = (n + _THREADS_PER_BLOCK - 1) // _THREADS_PER_BLOCK
    d_block_hist = cuda.device_array((_RADIX, blocks_per_grid), dtype=np.int32)

    # Two private buffers used alternately, so the caller's array is only
    # ever read (by the first pass) and never written.
    scratch = (
        cuda.device_array(n, dtype=d_input.dtype),
        cuda.device_array(n, dtype=d_input.dtype),
    )

    src = d_input
    exp = 1
    pass_no = 0
    while max_val // exp > 0:
        dst = scratch[pass_no % 2]

        _block_digit_histogram[blocks_per_grid, _THREADS_PER_BLOCK](
            src, exp, d_block_hist
        )

        # Exclusive scan in digit-major, block-minor order: all elements with
        # digit 0 come first (block 0, block 1, ...), then digit 1, and so on.
        hist = d_block_hist.copy_to_host()
        flat = hist.ravel()
        starts = np.cumsum(flat, dtype=np.int64) - flat
        d_starts = cuda.to_device(starts.reshape(hist.shape))

        _stable_scatter[blocks_per_grid, _THREADS_PER_BLOCK](
            src, exp, d_starts, dst
        )

        src = dst
        exp *= 10
        pass_no += 1

    return src.copy_to_host()

def test_radix_sort_parallel(sizes=None):
    """Benchmark ``radix_sort_cuda`` on random arrays and print a timing table.

    For every size a random ``int32`` array with values in ``[0, size)`` is
    generated, copied to the device, sorted and timed. The result is compared
    with ``np.sort`` outside the timed region; sizes whose output does not
    match are reported after the table instead of aborting the run.

    Args:
        sizes: Iterable of array lengths to benchmark. Defaults to powers of
            ten from 100 to 10,000,000.
    """
    if sizes is None:
        sizes = [100, 1000, 10000, 100000, 1000000, 10000000]

    # Warm-up call: the first launch of a Numba kernel triggers JIT
    # compilation, which would otherwise be charged to the first measurement.
    radix_sort_cuda(cuda.to_device(np.arange(16, dtype=np.int32)))

    results = []
    mismatched_sizes = []

    for size in sizes:
        arr = np.random.randint(0, size, size).astype(np.int32)
        arr_device = cuda.to_device(arr)

        # perf_counter is monotonic and high resolution, unlike wall-clock time.
        start_time = time.perf_counter()
        sorted_arr = radix_sort_cuda(arr_device)
        elapsed = time.perf_counter() - start_time

        results.append((size, elapsed))

        # Validate outside the timed region so the check does not skew timings.
        if not np.array_equal(np.asarray(sorted_arr), np.sort(arr)):
            mismatched_sizes.append(size)

    print("Tamanho do Array | Tempo de Execução (s)")
    for size, duration in results:
        print(f"{size:<15} | {duration:.6f}")

    if mismatched_sizes:
        print(f"AVISO: resultado não ordenado para os tamanhos: {mismatched_sizes}")

# Run the benchmark when this cell is executed; prints the timing table.
test_radix_sort_parallel()
