## Faster Top-k With Numpy
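- Using `np.argpartition` is generally faster than `np.argsort` for selecting the top-k values.
- This is because `np.argpartition` only partially sorts the array. It guarantees that the element at the k-th position ends up where a full sort would put it, with smaller elements before it and larger elements after it. It does not order the elements within either side. This runs in O(n) on average (introselect), compared with O(n log n) for a full sort.
- After partitioning, only the k selected values need to be sorted, which costs O(k log k).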
- This approach is conceptually similar to PyTorch’s `torch.topk`.

In [None]:
import os
import sys

sys.path.append(os.path.abspath(os.path.join('..')))

import numpy as np
import torch

from utils import measure_exec_time

In [3]:
def naive_topk(arr: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
    """Return the k largest entries along the last axis using a full argsort.

    Baseline implementation: sorts every element, O(n log n) per row.

    Args:
        arr: Input array; top-k is taken over its last axis.
        k: Number of elements to select. Values larger than the axis length
            return every element.

    Returns:
        ``(values, indices)``, both with last dimension ``min(k, n)``, ordered
        from largest to smallest. ``indices`` index into the last axis of
        ``arr``. The order among tied values is unspecified.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    # Sort ascending and reverse rather than sorting ``-arr``. Negation wraps
    # around for unsigned integers (picking the wrong elements) and is a
    # TypeError for bool arrays. Reversing also ranks NaN as the largest
    # value, matching faster_topk and torch.topk.
    indices = np.argsort(arr, axis=-1)[..., ::-1][..., :k]
    return np.take_along_axis(arr, indices, axis=-1), indices


def faster_topk(arr: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
    """Return the k largest entries along the last axis using argpartition.

    ``np.argpartition`` moves the k largest elements to the end of each row
    in O(n) average time without ordering them. Only those k elements are
    then sorted (O(k log k)), instead of sorting the whole row.

    Args:
        arr: Input array; top-k is taken over its last axis.
        k: Number of elements to select. Values larger than the axis length
            are clamped to it.

    Returns:
        ``(values, indices)``, both with last dimension ``min(k, n)``, ordered
        from largest to smallest. ``indices`` index into the last axis of
        ``arr``. The order among tied values is unspecified.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    # argpartition rejects kth outside [-n, n), so clamp to the axis length.
    k = min(k, arr.shape[-1])
    if k == 0:
        # Slicing with ``[..., -0:]`` would select the WHOLE axis, so the
        # empty result has to be built explicitly.
        empty_indices = np.empty(arr.shape[:-1] + (0,), dtype=np.intp)
        return np.take_along_axis(arr, empty_indices, axis=-1), empty_indices

    # The last k positions after partitioning hold the k largest (unordered).
    partition_indices = np.argpartition(arr, -k, axis=-1)[..., -k:]
    partition_values = np.take_along_axis(arr, partition_indices, axis=-1)
    # Sort just the k survivors, then reverse for descending order.
    sort_order = np.argsort(partition_values, axis=-1)[..., ::-1]
    topk_values = np.take_along_axis(partition_values, sort_order, axis=-1)
    topk_indices = np.take_along_axis(partition_indices, sort_order, axis=-1)

    return topk_values, topk_indices

In [4]:
# Benchmark the full-sort baseline against the argpartition version on the
# same random input. ``arr`` and ``k`` are reused by the torch check below.
arr = np.random.randn(100, 1000)
k = 100
for label, topk_fn in (("Naive topk", naive_topk), ("Faster topk", faster_topk)):
    print(label)
    measure_exec_time(topk_fn, arr, k)

Naive topk
Mean: 4.095 ms, Std: 0.126 ms
Faster topk
Mean: 1.249 ms, Std: 0.059 ms


In [16]:
# Compare with torch.topk result: both the selected values and their
# positions must match. The index check assumes no tied values, which holds
# almost surely for continuous random input such as ``randn``.
np_values, np_indices = faster_topk(arr, k)
t = torch.from_numpy(arr)
torch_values, torch_indices = torch.topk(t, k=k)
np.array_equal(np_values, torch_values.numpy()) and np.array_equal(np_indices, torch_indices.numpy())

True