In [1]:
# Load the profila extension, which provides the %%profila cell magic used
# further below to sample-profile Numba-compiled code.
%load_ext profila

> **Note:** loading the `profila` extension can slow down Numba-compiled code. This includes the `%%time` measurements below, which run with it loaded. Make sure to disable it once you're done profiling!

In [2]:
# %load_ext snakeviz

In [3]:
import zarr
from scipy.spatial.distance import squareform
import numpy as np
import anjl

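## Large dataset

Time and profile `anjl.canonical_nj` against `anjl.rapid_nj` on a random 1,500-sample subset of the large distance matrix.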

In [4]:
# Load the stored pairwise distances for the large dataset and expand them into
# a full square distance matrix (the displayed output below is square with a
# zero diagonal).
large = zarr.load("../data/large/dist.zarr.zip")
large_D = squareform(large)

# Benchmark on a random subset of samples, taken in random order. The generator
# is seeded so the subset, and hence every timing/profile below, is
# reproducible on Restart & Run All.
N_SAMPLES = 1500
SEED = 42
rng = np.random.default_rng(SEED)
shuffle = rng.choice(large_D.shape[0], size=N_SAMPLES, replace=False)

# Select the same rows and columns so the submatrix stays symmetric.
large_D_shuffled = large_D[np.ix_(shuffle, shuffle)]
assert large_D_shuffled.shape == (N_SAMPLES, N_SAMPLES)
large_D_shuffled

array([[ 0., 19., 19., ..., 22., 28., 12.],
       [19.,  0., 20., ..., 21., 25., 15.],
       [19., 20.,  0., ..., 21., 23., 17.],
       ...,
       [22., 21., 21., ...,  0., 22., 18.],
       [28., 25., 23., ..., 22.,  0., 22.],
       [12., 15., 17., ..., 18., 22.,  0.]], dtype=float32)

In [5]:
%%time
# Baseline: time `anjl.canonical_nj` (the canonical neighbour-joining
# implementation) on the shuffled subset.
# N.B. the result is stored in `large_Z`, which the next cell overwrites with
# the rapid_nj result.
large_Z = anjl.canonical_nj(large_D_shuffled)



CPU times: user 2.56 s, sys: 33.5 ms, total: 2.6 s
Wall time: 2.59 s


In [6]:
%%time
# Time `anjl.rapid_nj` on the same input. With `gc=100`, every 100 iterations
# the row-sorted node/distance arrays are compacted: clustered nodes are
# dropped and the arrays are truncated to the number of remaining nodes.
# Overwrites `large_Z` from the previous cell.
large_Z = anjl.rapid_nj(large_D_shuffled, gc=100)

  u_max = _rapid_update(
  u_max = _rapid_update(


CPU times: user 11.5 s, sys: 39.3 ms, total: 11.5 s
Wall time: 11.5 s


In [7]:
# %%snakeviz
# anjl.canonical_nj(large_D_shuffled)
# anjl.rapid_nj(large_D_shuffled, gc=100)

In [8]:
# %%snakeviz
# anjl.rapid_nj(large_D_shuffled, gc=100)

In [11]:
%%profila
# Line-level sampling profile of rapid_nj. This assumes the %%time cell above
# has already run, so the Numba JIT compilation should not be included in the
# profile (TODO confirm).
# The returned Z array (one row per internal node) is displayed because it is
# the cell's last expression.
anjl.rapid_nj(large_D_shuffled, gc=100)

array([[5.7700000e+02, 8.4600000e+02, 0.0000000e+00, 0.0000000e+00,
        2.0000000e+00],
       [1.2670000e+03, 1.5000000e+03, 0.0000000e+00, 0.0000000e+00,
        3.0000000e+00],
       [6.2000000e+01, 1.5010000e+03, 2.6737968e-03, 9.9732620e-01,
        4.0000000e+00],
       ...,
       [2.9800000e+03, 2.9910000e+03, 2.2690046e-01, 1.1441791e-01,
        1.3700000e+02],
       [2.9950000e+03, 2.9960000e+03, 8.2687199e-02, 4.4224620e-02,
        1.4600000e+03],
       [2.9490000e+03, 2.9970000e+03, 3.0133229e-01, 3.0133229e-01,
        1.5000000e+03]], dtype=float32)

**Elapsed:** 23.558 seconds

**Total samples:** 1985 (14.7% non-Numba samples, 0.0% bad samples)

/home/aliman/github/alimanfoo/anjl/anjl/_rapid.py (lines 2 to 481):

```
  0.1% | from collections.abc import Mapping
       | import numpy as np
       | import numba
       | import time
       | 
       | 
       | def rapid_nj(
       |     D: np.ndarray,
       |     disallow_negative_distances: bool = True,
       |     progress: Callable | None = None,
       |     progress_options: Mapping = {},
       |     diagnostics=False,
       |     gc=100,
       |     # contiguate=False,
       | ) -> np.ndarray:
       |     """TODO"""
       | 
       |     # Make a copy of distance matrix D because we will overwrite it during the
       |     # algorithm.
       |     D = np.array(D, copy=True, order="C", dtype=np.float32)
       | 
       |     # Initialize the "divergence" array, containing sum of distances to other nodes.
       |     U = np.sum(D, axis=1, dtype=np.float32)
       |     u_max = U.max()
       | 
       |     # Set diagonal to inf to avoid self comparison sorting first.
       |     for i in range(D.shape[0]):
       |         D[i, i] = np.inf
       | 
       |     # Obtain node identifiers to sort the distance matrix row-wise.
       |     nodes_sorted = np.argsort(D, axis=1)
       |     assert D.shape == nodes_sorted.shape
       | 
       |     # Make another copy of the distance matrix sorted.
       |     D_sorted = np.take_along_axis(D, nodes_sorted, axis=1)
       | 
       |     # Number of original observations.
       |     n_original = D.shape[0]
       | 
       |     # Expected number of new (internal) nodes that will be created.
       |     n_internal = n_original - 1
       | 
       |     # Total number of nodes in the tree, including internal nodes.
       |     n_nodes = n_original + n_internal
       | 
       |     # Map row indices to node IDs.
       |     index_to_id = np.arange(n_original)
       | 
       |     # Map node IDs to row indices.
       |     id_to_index = np.full(shape=n_nodes, fill_value=-1)
       |     id_to_index[:n_original] = np.arange(n_original)
       | 
       |     # Initialise output. This is similar to the output that scipy hierarchical
       |     # clustering functions return, where each row contains data for one internal node
       |     # in the tree, except that each row here contains:
       |     # - left child node ID
       |     # - right child node ID
       |     # - distance to left child node
       |     # - distance to right child node
       |     # - total number of leaves
       |     Z = np.zeros(shape=(n_internal, 5), dtype=np.float32)
       | 
       |     # Keep track of which nodes have been clustered and are now "obsolete". N.B., this
       |     # is different from canonical implementation because we index here by node ID.
       |     clustered = np.zeros(shape=n_nodes - 1, dtype=bool)
       | 
       |     # Convenience to also keep track of which rows are no longer in use.
       |     obsolete = np.zeros(shape=n_original, dtype=bool)
       | 
       |     # Support wrapping the iterator in a progress bar.
       |     iterator = range(n_internal)
       |     if progress:
       |         iterator = progress(iterator, **progress_options)
       | 
       |     # Record iteration timings.
       |     timings = []
       |     searches = []
       |     visits = []
       | 
       |     # Begin iterating.
       |     for iteration in iterator:
       |         # print("")
       |         # print("iteration", iteration)
       |         # print("D\n", D)
       |         # print("D_sorted\n", D_sorted)
       |         # print("nodes_sorted\n", nodes_sorted)
       |         # print("U", U)
       |         # print("index_to_id", index_to_id)
       |         # print("id_to_index", id_to_index)
       | 
       |         # Number of nodes remaining in this iteration.
       |         n_remaining = n_original - iteration
       | 
       |         # Garbage collection.
       |         if gc and iteration > 0 and iteration % gc == 0:
       |             nodes_sorted, D_sorted = _rapid_gc(
       |                 nodes_sorted=nodes_sorted,
       |                 D_sorted=D_sorted,
       |                 index_to_id=index_to_id,
       |                 clustered=clustered,
       |                 obsolete=obsolete,
       |                 n_remaining=n_remaining,
       |                 # contiguate=contiguate,
       |             )
       | 
       |         before = time.time()
       | 
       |         # Perform one iteration of the neighbour-joining algorithm.
       |         u_max, searched, visited = _rapid_iteration(
       |             iteration=iteration,
       |             D=D,
       |             D_sorted=D_sorted,
       |             U=U,
       |             nodes_sorted=nodes_sorted,
       |             index_to_id=index_to_id,
       |             id_to_index=id_to_index,
       |             clustered=clustered,
       |             obsolete=obsolete,
       |             Z=Z,
       |             n_original=n_original,
       |             disallow_negative_distances=disallow_negative_distances,
       |             u_max=u_max,
       |         )
       | 
       |         duration = time.time() - before
       |         timings.append(duration)
       |         searches.append(searched)
       |         visits.append(visited)
       | 
       |     if diagnostics:
       |         return Z, np.array(timings), np.array(searches), np.array(visits)
       | 
       |     return Z
       | 
       | 
       | @numba.njit
       | def _rapid_gc(
       |     nodes_sorted: np.ndarray,
       |     D_sorted: np.ndarray,
       |     index_to_id: np.ndarray,
       |     clustered: np.ndarray,
       |     obsolete: np.ndarray,
       |     n_remaining: int,
       |     # contiguate: bool,
       | ):
       |     for i in range(nodes_sorted.shape[0]):
       |         if obsolete[i]:
       |             continue
  0.4% |         id_i = index_to_id[i]
       |         j_new = 0
  1.0% |         for j in range(nodes_sorted.shape[1]):
       |             id_j = nodes_sorted[i, j]
  2.2% |             if clustered[id_j]:
       |                 continue
       |             if id_i == id_j:
       |                 continue
  1.2% |             nodes_sorted[i, j_new] = id_j
  0.8% |             D_sorted[i, j_new] = D_sorted[i, j]
  0.3% |             j_new += 1
       |     nodes_sorted = nodes_sorted[:, :n_remaining]
       |     D_sorted = D_sorted[:, :n_remaining]
       |     # if contiguate:
       |     #     nodes_sorted = nodes_sorted.copy()
       |     #     D_sorted = D_sorted.copy()
       |     return nodes_sorted, D_sorted
       | 
       | 
       | @numba.njit
       | def _rapid_iteration(
       |     iteration: int,
       |     D: np.ndarray,
       |     D_sorted: np.ndarray,
       |     U: np.ndarray,
       |     nodes_sorted: np.ndarray,
       |     index_to_id: np.ndarray,
       |     id_to_index: np.ndarray,
       |     clustered: np.ndarray,
       |     obsolete: np.ndarray,
       |     Z: np.ndarray,
       |     n_original: int,
       |     disallow_negative_distances: bool,
       |     u_max: np.float32,
       | ) -> np.float32:
       |     # This will be the identifier for the new node to be created in this iteration.
       |     node = iteration + n_original
       | 
       |     # Number of nodes remaining in this iteration.
       |     n_remaining = n_original - iteration
       | 
       |     if n_remaining > 2:
       |         # Search for the closest pair of nodes to join.
       |         i_min, j_min, searched, visited = _rapid_search(
       |             D_sorted=D_sorted,
       |             U=U,
       |             nodes_sorted=nodes_sorted,
       |             clustered=clustered,
       |             obsolete=obsolete,
       |             index_to_id=index_to_id,
       |             id_to_index=id_to_index,
       |             n_remaining=n_remaining,
       |             u_max=u_max,
       |         )
       |         assert i_min >= 0
       |         assert j_min >= 0
       |         assert i_min != j_min
       | 
       |         # Get IDs for the nodes to be joined.
       |         child_i = index_to_id[i_min]
       |         child_j = index_to_id[j_min]
       | 
       |         # Calculate distances to the new internal node.
       |         d_ij = D[i_min, j_min]
       |         d_i = 0.5 * (d_ij + (1 / (n_remaining - 2)) * (U[i_min] - U[j_min]))
       |         d_j = 0.5 * (d_ij + (1 / (n_remaining - 2)) * (U[j_min] - U[i_min]))
       | 
       |     else:
       |         # Termination. Join the two remaining nodes, placing the final node at the
       |         # midpoint.
       |         child_i, child_j = np.nonzero(~clustered)[0]
       |         i_min = id_to_index[child_i]
       |         j_min = id_to_index[child_j]
       |         d_ij = D[i_min, j_min]
       |         d_i = d_ij / 2
       |         d_j = d_ij / 2
       |         searched = 0
       |         visited = 0
       | 
       |     # Sanity checks.
       |     assert child_i >= 0
       |     assert child_j >= 0
       |     assert child_i != child_j
       | 
       |     # print("i_min", i_min, "j_min", j_min, "child_i", child_i, "child_j", child_j)
       | 
       |     # Handle possibility of negative distances.
       |     if disallow_negative_distances:
       |         d_i = max(0, d_i)
       |         d_j = max(0, d_j)
       | 
       |     # Stabilise ordering for easier comparisons.
       |     if child_i > child_j:
       |         child_i, child_j = child_j, child_i
       |         i_min, j_min = j_min, i_min
       |         d_i, d_j = d_j, d_i
       | 
       |     # Get number of leaves.
       |     if child_i < n_original:
       |         leaves_i = 1
       |     else:
       |         leaves_i = Z[child_i - n_original, 4]
       |     if child_j < n_original:
       |         leaves_j = 1
       |     else:
       |         leaves_j = Z[child_j - n_original, 4]
       | 
       |     # Store new node data.
       |     Z[iteration, 0] = child_i
       |     Z[iteration, 1] = child_j
       |     Z[iteration, 2] = d_i
       |     Z[iteration, 3] = d_j
       |     Z[iteration, 4] = leaves_i + leaves_j
       | 
       |     if n_remaining > 2:
       |         # Update data structures.
  0.1% |         u_max = _rapid_update(
       |             D=D,
       |             D_sorted=D_sorted,
       |             U=U,
       |             nodes_sorted=nodes_sorted,
       |             index_to_id=index_to_id,
       |             id_to_index=id_to_index,
       |             clustered=clustered,
       |             obsolete=obsolete,
       |             node=node,
       |             child_i=child_i,
       |             child_j=child_j,
       |             i_min=i_min,
       |             j_min=j_min,
       |             d_ij=d_ij,
       |         )
       | 
       |     return u_max, searched, visited
       | 
       | 
       | @numba.njit
       | def _rapid_search(
       |     D_sorted: np.ndarray,
       |     U: np.ndarray,
       |     nodes_sorted: np.ndarray,
       |     clustered: np.ndarray,
       |     obsolete: np.ndarray,
       |     index_to_id: np.ndarray,
       |     id_to_index: np.ndarray,
       |     n_remaining: int,
       |     u_max: np.float32,
       | ) -> tuple[int, int, int]:
       |     # Initialize working variables.
       |     q_min = numba.float32(np.inf)
       |     threshold = numba.float32(np.inf)
       |     i_min = -1
       |     j_min = -1
       |     searched = 0
       |     visited = 0
       |     coefficient = numba.float32(n_remaining - 2)
       |     m = nodes_sorted.shape[0]
       |     n = nodes_sorted.shape[1]
  0.3% |     assert m == D_sorted.shape[0]
       |     assert n == D_sorted.shape[1]
       | 
       |     # indices_available = np.nonzero(~obsolete)[0]
       |     # # np.random.shuffle(indices_available)
       |     # for i in indices_available:
       | 
       |     # Search all values up to threshold.
  0.2% |     for i in range(m):
       |         # Skip if row is no longer in use.
  0.3% |         if obsolete[i]:
       |             continue
       | 
       |         # Obtain identifier for this row.
       |         node_i = index_to_id[i]
       | 
       |         # Obtain divergence for node corresponding to this row.
  0.4% |         u_i = U[i]
       | 
       |         # Search the row up to threshold.
  8.3% |         for s in range(n):
  2.2% |             visited += 1
       | 
       |             # Obtain node identifier for the current item.
  1.1% |             node_j = nodes_sorted[i, s]
       | 
       |             # Break at end of nodes.
  5.4% |             if node_j < 0:
       |                 break
       | 
       |             # Skip if this node is already clustered.
  5.8% |             if clustered[node_j]:
       |                 continue
       | 
       |             # TODO needed?
       |             if node_i == node_j:
       |                 continue
       | 
       |             # Access distance.
       |             d = D_sorted[i, s]
       | 
       |             # Partially calculate q.
  5.2% |             q_partial = coefficient * d - u_i
       | 
       |             # Limit search. Because the row is sorted, if we are already above this
       |             # threshold then we know there is no need to search remaining nodes in the
       |             # row.
 10.7% |             if q_partial > threshold:
       |                 break
       | 
       |             # Fully calculate q.
  2.2% |             j = id_to_index[node_j]
  5.2% |             u_j = U[j]
 11.7% |             q = q_partial - u_j
  0.4% |             searched += 1
       | 
  5.2% |             if q < q_min:
       |                 q_min = q
  0.3% |                 threshold = q_min + u_max
       |                 i_min = i
       |                 j_min = j
       | 
       |     return i_min, j_min, searched, visited
       | 
       | 
       | @numba.njit
       | def _rapid_update(
       |     D: np.ndarray,
       |     D_sorted: np.ndarray,
       |     U: np.ndarray,
       |     nodes_sorted: np.ndarray,
       |     index_to_id: np.ndarray,
       |     id_to_index: np.ndarray,
       |     clustered: np.ndarray,
       |     obsolete: np.ndarray,
       |     node: int,
       |     child_i: int,
       |     child_j: int,
       |     i_min: int,
       |     j_min: int,
       |     d_ij: float,
       | ) -> np.float32:
       |     # Update data structures. Here we obsolete the row corresponding to the node at
       |     # j_min, and we reuse the row at i_min for the new node.
       |     clustered[child_i] = True
       |     clustered[child_j] = True
       |     # index_to_id[j_min] = -1
       |     # id_to_index[child_i] = -1
       |     # id_to_index[child_j] = -1
       | 
       |     # Assign the new node to row at i_min.
       |     index_to_id[i_min] = node
       |     id_to_index[node] = i_min
       | 
       |     # Obsolete the data corresponding to the node at j_min.
       |     obsolete[j_min] = True
       | 
       |     # Subtract out the distances for the nodes that have just been joined.
       |     for i in range(U.shape[0]):
  0.1% |         if i != i_min and i != j_min and not obsolete[i]:
  1.0% |             U[i] -= D[i, i_min]
  0.4% |             U[i] -= D[i, j_min]
       | 
       |     # Initialize divergence for the new node.
       |     u_new = np.float32(0)
       | 
       |     # Find new max.
       |     u_max = np.float32(0)
       | 
       |     # Update distances and divergence.
  0.1% |     for k in range(D.shape[0]):
  0.1% |         if obsolete[k]:
       |             continue
       | 
       |         if k == i_min or k == j_min:
       |             continue
       | 
       |         # Distance from k to the new node.
       |         d_ik = D[k, i_min]
       |         d_jk = D[k, j_min]
  0.4% |         d_k = 0.5 * (d_ik + d_jk - d_ij)
  0.1% |         D[i_min, k] = d_k
       |         D[k, i_min] = d_k
       |         u_k = U[k] + d_k
  0.1% |         U[k] = u_k
       | 
       |         # Record new max.
       |         if u_k > u_max:
       |             u_max = u_k
       | 
       |         # Accumulate divergence for the new node.
       |         u_new += d_k
       | 
       |         # Distance from k to the obsolete node.
       |         D[k, j_min] = np.inf
       | 
       |     # Store divergence for the new node.
       |     U[i_min] = u_new
       | 
       |     # Record new max.
       |     if u_new > u_max:
       |         u_max = u_new
       | 
       |     # Finish up obsoleting data for j_min.
       |     # U[j_min] = np.nan
       |     # D[j_min, :] = np.inf
       |     # D[:, j_min] = np.inf
       |     # D_sorted[j_min] = np.inf
       |     # nodes_sorted[j_min] = -1
       | 
       |     # First cut down to just the active nodes.
  0.5% |     active = ~obsolete
  0.5% |     distances_new = D[i_min, active]
  0.7% |     nodes_active = index_to_id[active]
       | 
       |     # Now sort the new distances.
       |     indices_sorted = np.argsort(distances_new)
  0.2% |     nodes_sorted_new = nodes_active[indices_sorted]
  0.2% |     distances_sorted_new = distances_new[indices_sorted]
       | 
       |     # # Update the sorted distances and indices for the new node.
       |     # distances_new = D[i_min]
       |     # sorted_indices_new = np.argsort(distances_new)
       |     # sorted_ids_new = np.take(index_to_id, sorted_indices_new)
       | 
       |     # # Remove any clustered nodes.
       |     # clustered_new = np.take(clustered, sorted_ids_new)
       |     # sorted_ids_new = sorted_ids_new[~clustered_new]
       | 
       |     p = nodes_sorted_new.shape[0]
       |     assert p == distances_new.shape[0]
       |     nodes_sorted[i_min, :p] = nodes_sorted_new
  0.1% |     nodes_sorted[i_min, p:] = -1
  0.1% |     D_sorted[i_min, :p] = distances_sorted_new
```

/home/aliman/.cache/pypoetry/virtualenvs/anjl-Dyqcv450-py3.10/lib/python3.10/site-packages/numba/misc/quicksort.py (lines 43 to 197):

```
  0.1% |                 return np.arange(A.size)
       | 
       |         @wrap
  0.4% |         def GET(A, idx_or_val):
  1.9% |             return A[idx_or_val]
       | 
       |     else:
       |         @wrap
       |         def make_res(A):
       |             return A
       | 
       |         @wrap
       |         def GET(A, idx_or_val):
       |             return idx_or_val
       | 
       |     def default_lt(a, b):
       |         """
       |         Trivial comparison function between two keys.
       |         """
       |         return a < b
       | 
       |     LT = wrap(lt if lt is not None else default_lt)
       | 
       |     @wrap
       |     def insertion_sort(A, R, low, high):
       |         """
       |         Insertion sort A[low:high + 1]. Note the inclusive bounds.
  0.1% |         """
       |         assert low >= 0
       |         if high <= low:
       |             return
       | 
       |         for i in range(low + 1, high + 1):
       |             k = R[i]
  0.2% |             v = GET(A, k)
       |             # Insert v into A[low:i]
       |             j = i
  0.4% |             while j > low and LT(v, GET(A, R[j - 1])):
       |                 # Make place for moving A[i] downwards
  0.1% |                 R[j] = R[j - 1]
       |                 j -= 1
  0.1% |             R[j] = k
       | 
       |     @wrap
       |     def partition(A, R, low, high):
       |         """
       |         Partition A[low:high + 1] around a chosen pivot.  The pivot's index
       |         is returned.
  0.2% |         """
       |         assert low >= 0
       |         assert high > low
       | 
       |         mid = (low + high) >> 1
       |         # NOTE: the pattern of swaps below for the pivot choice and the
       |         # partitioning gives good results (i.e. regular O(n log n))
       |         # on sorted, reverse-sorted, and uniform arrays.  Subtle changes
       |         # risk breaking this property.
       | 
       |         # median of three {low, middle, high}
  0.1% |         if LT(GET(A, R[mid]), GET(A, R[low])):
       |             R[low], R[mid] = R[mid], R[low]
  0.1% |         if LT(GET(A, R[high]), GET(A, R[mid])):
       |             R[high], R[mid] = R[mid], R[high]
       |         if LT(GET(A, R[mid]), GET(A, R[low])):
       |             R[low], R[mid] = R[mid], R[low]
  0.1% |         pivot = GET(A, R[mid])
       | 
       |         # Temporarily stash the pivot at the end
       |         R[high], R[mid] = R[mid], R[high]
       |         i = low
       |         j = high - 1
       |         while True:
  1.4% |             while i < high and LT(GET(A, R[i]), pivot):
       |                 i += 1
  1.5% |             while j >= low and LT(pivot, GET(A, R[j])):
       |                 j -= 1
       |             if i >= j:
       |                 break
  0.3% |             R[i], R[j] = R[j], R[i]
  0.1% |             i += 1
       |             j -= 1
       |         # Put the pivot back in its final place (all items before `i`
       |         # are smaller than the pivot, all items at/after `i` are larger)
  0.1% |         R[i], R[high] = R[high], R[i]
       |         return i
       | 
       |     @wrap
       |     def partition3(A, low, high):
       |         """
       |         Three-way partition [low, high) around a chosen pivot.
       |         A tuple (lt, gt) is returned such that:
       |             - all elements in [low, lt) are < pivot
       |             - all elements in [lt, gt] are == pivot
       |             - all elements in (gt, high] are > pivot
       |         """
       |         mid = (low + high) >> 1
       |         # median of three {low, middle, high}
       |         if LT(A[mid], A[low]):
       |             A[low], A[mid] = A[mid], A[low]
       |         if LT(A[high], A[mid]):
       |             A[high], A[mid] = A[mid], A[high]
       |         if LT(A[mid], A[low]):
       |             A[low], A[mid] = A[mid], A[low]
       |         pivot = A[mid]
       | 
       |         A[low], A[mid] = A[mid], A[low]
       |         lt = low
       |         gt = high
       |         i = low + 1
       |         while i <= gt:
       |             if LT(A[i], pivot):
       |                 A[lt], A[i] = A[i], A[lt]
       |                 lt += 1
       |                 i += 1
       |             elif LT(pivot, A[i]):
       |                 A[gt], A[i] = A[i], A[gt]
       |                 gt -= 1
       |             else:
       |                 i += 1
       |         return lt, gt
       | 
       |     @wrap
       |     def run_quicksort1(A):
       |         R = make_res(A)
       | 
       |         if len(A) < 2:
  0.1% |             return R
       | 
       |         stack = [Partition(zero, zero)] * MAX_STACK
  0.1% |         stack[0] = Partition(zero, len(A) - 1)
       |         n = 1
       | 
       |         while n > 0:
       |             n -= 1
       |             low, high = stack[n]
       |             # Partition until it becomes more efficient to do an insertion sort
       |             while high - low >= SMALL_QUICKSORT:
       |                 assert n < MAX_STACK
  0.1% |                 i = partition(A, R, low, high)
       |                 # Push largest partition on the stack
       |                 if high - i > i - low:
       |                     # Right is larger
       |                     if high > i:
  0.1% |                         stack[n] = Partition(i + 1, high)
       |                         n += 1
       |                     high = i - 1
       |                 else:
       |                     if i > low:
       |                         stack[n] = Partition(low, i - 1)
       |                         n += 1
       |                     low = i + 1
       | 
  0.1% |             insertion_sort(A, R, low, high)
       | 
  0.1% |         return R
```

/home/aliman/.cache/pypoetry/virtualenvs/anjl-Dyqcv450-py3.10/lib/python3.10/site-packages/numba/np/numpy_support.py (lines 739 to 739):

```
  2.8% |     return a < b or (np.isnan(b) and not np.isnan(a))
```

/home/aliman/.cache/pypoetry/virtualenvs/anjl-Dyqcv450-py3.10/lib/python3.10/site-packages/numba/np/arrayobj.py (lines 4260 to 4822):

```
  0.2% |         return intrin_alloc(allocsize, align)
       |     return impl
       | 
       | 
       | def _call_allocator(arrtype, size, align):
       |     """Trampoline to call the intrinsic used for allocation
       |     """
       |     return arrtype._allocate(size, align)
       | 
       | 
       | @intrinsic
       | def intrin_alloc(typingctx, allocsize, align):
       |     """Intrinsic to call into the allocator for Array
       |     """
       |     def codegen(context, builder, signature, args):
       |         [allocsize, align] = args
       |         meminfo = context.nrt.meminfo_alloc_aligned(builder, allocsize, align)
       |         return meminfo
       | 
       |     mip = types.MemInfoPointer(types.voidptr)    # return untyped pointer
       |     sig = signature(mip, allocsize, align)
       |     return sig, codegen
       | 
       | 
       | def _parse_shape(context, builder, ty, val):
       |     """
       |     Parse the shape argument to an array constructor.
       |     """
       |     def safecast_intp(context, builder, src_t, src):
       |         """Cast src to intp only if value can be maintained"""
       |         intp_t = context.get_value_type(types.intp)
       |         intp_width = intp_t.width
       |         intp_ir = ir.IntType(intp_width)
       |         maxval = Constant(intp_ir, ((1 << intp_width - 1) - 1))
       |         if src_t.width < intp_width:
       |             res = builder.sext(src, intp_ir)
       |         elif src_t.width >= intp_width:
       |             is_larger = builder.icmp_signed(">", src, maxval)
       |             with builder.if_then(is_larger, likely=False):
       |                 context.call_conv.return_user_exc(
       |                     builder, ValueError,
       |                     ("Cannot safely convert value to intp",)
       |                 )
       |             if src_t.width > intp_width:
       |                 res = builder.trunc(src, intp_ir)
       |             else:
       |                 res = src
       |         return res
       | 
       |     if isinstance(ty, types.Integer):
       |         ndim = 1
       |         passed_shapes = [context.cast(builder, val, ty, types.intp)]
       |     else:
       |         assert isinstance(ty, types.BaseTuple)
       |         ndim = ty.count
       |         passed_shapes = cgutils.unpack_tuple(builder, val, count=ndim)
       | 
       |     shapes = []
       |     for s in passed_shapes:
       |         shapes.append(safecast_intp(context, builder, s.type, s))
       | 
       |     zero = context.get_constant_generic(builder, types.intp, 0)
       |     for dim in range(ndim):
       |         is_neg = builder.icmp_signed('<', shapes[dim], zero)
       |         with cgutils.if_unlikely(builder, is_neg):
       |             context.call_conv.return_user_exc(
       |                 builder, ValueError, ("negative dimensions not allowed",)
       |             )
       | 
       |     return shapes
       | 
       | 
       | def _parse_empty_args(context, builder, sig, args):
       |     """
       |     Parse the arguments of a np.empty(), np.zeros() or np.ones() call.
       |     """
       |     arrshapetype = sig.args[0]
       |     arrshape = args[0]
       |     arrtype = sig.return_type
       |     return arrtype, _parse_shape(context, builder, arrshapetype, arrshape)
       | 
       | 
       | def _parse_empty_like_args(context, builder, sig, args):
       |     """
       |     Parse the arguments of a np.empty_like(), np.zeros_like() or
       |     np.ones_like() call.
       |     """
       |     arytype = sig.args[0]
       |     if isinstance(arytype, types.Array):
       |         ary = make_array(arytype)(context, builder, value=args[0])
       |         shapes = cgutils.unpack_tuple(builder, ary.shape, count=arytype.ndim)
       |         return sig.return_type, shapes
       |     else:
       |         return sig.return_type, ()
       | 
       | 
       | def _check_const_str_dtype(fname, dtype):
       |     if isinstance(dtype, types.UnicodeType):
       |         msg = f"If np.{fname} dtype is a string it must be a string constant."
       |         raise errors.TypingError(msg)
       | 
       | 
       | @intrinsic
       | def numpy_empty_nd(tyctx, ty_shape, ty_dtype, ty_retty_ref):
       |     ty_retty = ty_retty_ref.instance_type
       |     sig = ty_retty(ty_shape, ty_dtype, ty_retty_ref)
       | 
       |     def codegen(cgctx, builder, sig, llargs):
       |         arrtype, shapes = _parse_empty_args(cgctx, builder, sig, llargs)
       |         ary = _empty_nd_impl(cgctx, builder, arrtype, shapes)
       |         return ary._getvalue()
       |     return sig, codegen
       | 
       | 
       | @overload(np.empty)
       | def ol_np_empty(shape, dtype=float):
       |     _check_const_str_dtype("empty", dtype)
       |     if (dtype is float or
       |         (isinstance(dtype, types.Function) and dtype.typing_key is float) or
       |             is_nonelike(dtype)): #default
       |         nb_dtype = types.double
       |     else:
       |         nb_dtype = ty_parse_dtype(dtype)
       | 
       |     ndim = ty_parse_shape(shape)
       |     if nb_dtype is not None and ndim is not None:
       |         retty = types.Array(dtype=nb_dtype, ndim=ndim, layout='C')
       | 
  0.1% |         def impl(shape, dtype=float):
       |             return numpy_empty_nd(shape, dtype, retty)
       |         return impl
       |     else:
       |         msg = f"Cannot parse input types to function np.empty({shape}, {dtype})"
       |         raise errors.TypingError(msg)
       | 
       | 
       | @intrinsic
       | def numpy_empty_like_nd(tyctx, ty_prototype, ty_dtype, ty_retty_ref):
       |     ty_retty = ty_retty_ref.instance_type
       |     sig = ty_retty(ty_prototype, ty_dtype, ty_retty_ref)
       | 
       |     def codegen(cgctx, builder, sig, llargs):
       |         arrtype, shapes = _parse_empty_like_args(cgctx, builder, sig, llargs)
       |         ary = _empty_nd_impl(cgctx, builder, arrtype, shapes)
       |         return ary._getvalue()
       |     return sig, codegen
       | 
       | 
       | @overload(np.empty_like)
       | def ol_np_empty_like(arr, dtype=None):
       |     _check_const_str_dtype("empty_like", dtype)
       |     if not is_nonelike(dtype):
       |         nb_dtype = ty_parse_dtype(dtype)
       |     elif isinstance(arr, types.Array):
       |         nb_dtype = arr.dtype
       |     else:
       |         nb_dtype = arr
       |     if nb_dtype is not None:
       |         if isinstance(arr, types.Array):
       |             layout = arr.layout if arr.layout != 'A' else 'C'
       |             retty = arr.copy(dtype=nb_dtype, layout=layout, readonly=False)
       |         else:
       |             retty = types.Array(nb_dtype, 0, 'C')
       |     else:
       |         msg = ("Cannot parse input types to function "
       |                f"np.empty_like({arr}, {dtype})")
       |         raise errors.TypingError(msg)
       | 
       |     def impl(arr, dtype=None):
       |         return numpy_empty_like_nd(arr, dtype, retty)
       |     return impl
       | 
       | 
       | @intrinsic
       | def _zero_fill_array_method(tyctx, self):
       |     sig = types.none(self)
       | 
       |     def codegen(cgctx, builder, sig, llargs):
       |         ary = make_array(sig.args[0])(cgctx, builder, llargs[0])
       |         cgutils.memset(builder, ary.data, builder.mul(ary.itemsize, ary.nitems),
       |                        0)
       |     return sig, codegen
       | 
       | 
       | @overload_method(types.Array, '_zero_fill')
       | def ol_array_zero_fill(self):
       |     """Adds a `._zero_fill` method to zero fill an array using memset."""
       |     def impl(self):
       |         _zero_fill_array_method(self)
       |     return impl
       | 
       | 
       | @overload(np.zeros)
       | def ol_np_zeros(shape, dtype=float):
       |     _check_const_str_dtype("zeros", dtype)
       | 
       |     def impl(shape, dtype=float):
       |         arr = np.empty(shape, dtype=dtype)
       |         arr._zero_fill()
       |         return arr
       |     return impl
       | 
       | 
       | @overload(np.zeros_like)
       | def ol_np_zeros_like(a, dtype=None):
       |     _check_const_str_dtype("zeros_like", dtype)
       | 
       |     # NumPy uses 'a' as the arg name for the array-like
       |     def impl(a, dtype=None):
       |         arr = np.empty_like(a, dtype=dtype)
       |         arr._zero_fill()
       |         return arr
       |     return impl
       | 
       | 
       | @overload(np.ones_like)
       | def ol_np_ones_like(a, dtype=None):
       |     _check_const_str_dtype("ones_like", dtype)
       | 
       |     # NumPy uses 'a' as the arg name for the array-like
       |     def impl(a, dtype=None):
       |         arr = np.empty_like(a, dtype=dtype)
       |         arr_flat = arr.flat
       |         for idx in range(len(arr_flat)):
       |             arr_flat[idx] = 1
       |         return arr
       |     return impl
       | 
       | 
       | @overload(np.full)
       | def impl_np_full(shape, fill_value, dtype=None):
       |     _check_const_str_dtype("full", dtype)
       |     if not is_nonelike(dtype):
       |         nb_dtype = ty_parse_dtype(dtype)
       |     else:
       |         nb_dtype = fill_value
       | 
       |     def full(shape, fill_value, dtype=None):
       |         arr = np.empty(shape, nb_dtype)
       |         arr_flat = arr.flat
       |         for idx in range(len(arr_flat)):
       |             arr_flat[idx] = fill_value
       |         return arr
       |     return full
       | 
       | 
       | @overload(np.full_like)
       | def impl_np_full_like(a, fill_value, dtype=None):
       |     _check_const_str_dtype("full_like", dtype)
       | 
       |     def full_like(a, fill_value, dtype=None):
       |         arr = np.empty_like(a, dtype)
       |         arr_flat = arr.flat
       |         for idx in range(len(arr_flat)):
       |             arr_flat[idx] = fill_value
       |         return arr
       | 
       |     return full_like
       | 
       | 
       | @overload(np.ones)
       | def ol_np_ones(shape, dtype=None):
       |     # for some reason the NumPy default for dtype is None in the source but
       |     # ends up as np.float64 by definition.
       |     _check_const_str_dtype("ones", dtype)
       | 
       |     def impl(shape, dtype=None):
       |         arr = np.empty(shape, dtype=dtype)
       |         arr_flat = arr.flat
       |         for idx in range(len(arr_flat)):
       |             arr_flat[idx] = 1
       |         return arr
       |     return impl
       | 
       | 
       | @overload(np.identity)
       | def impl_np_identity(n, dtype=None):
       |     _check_const_str_dtype("identity", dtype)
       |     if not is_nonelike(dtype):
       |         nb_dtype = ty_parse_dtype(dtype)
       |     else:
       |         nb_dtype = types.double
       | 
       |     def identity(n, dtype=None):
       |         arr = np.zeros((n, n), nb_dtype)
       |         for i in range(n):
       |             arr[i, i] = 1
       |         return arr
       |     return identity
       | 
       | 
       | def _eye_none_handler(N, M):
       |     pass
       | 
       | 
       | @extending.overload(_eye_none_handler)
       | def _eye_none_handler_impl(N, M):
       |     if isinstance(M, types.NoneType):
       |         def impl(N, M):
       |             return N
       |     else:
       |         def impl(N, M):
       |             return M
       |     return impl
       | 
       | 
       | @extending.overload(np.eye)
       | def numpy_eye(N, M=None, k=0, dtype=float):
       | 
       |     if dtype is None or isinstance(dtype, types.NoneType):
       |         dt = np.dtype(float)
       |     elif isinstance(dtype, (types.DTypeSpec, types.Number)):
       |         # dtype or instance of dtype
       |         dt = as_dtype(getattr(dtype, 'dtype', dtype))
       |     else:
       |         dt = np.dtype(dtype)
       | 
       |     def impl(N, M=None, k=0, dtype=float):
       |         _M = _eye_none_handler(N, M)
       |         arr = np.zeros((N, _M), dt)
       |         if k >= 0:
       |             d = min(N, _M - k)
       |             for i in range(d):
       |                 arr[i, i + k] = 1
       |         else:
       |             d = min(N + k, _M)
       |             for i in range(d):
       |                 arr[i - k, i] = 1
       |         return arr
       |     return impl
       | 
       | 
       | @overload(np.diag)
       | def impl_np_diag(v, k=0):
       |     if not type_can_asarray(v):
       |         raise errors.TypingError('The argument "v" must be array-like')
       | 
       |     if isinstance(v, types.Array):
       |         if v.ndim not in (1, 2):
       |             raise errors.NumbaTypeError("Input must be 1- or 2-d.")
       | 
       |         def diag_impl(v, k=0):
       |             if v.ndim == 1:
       |                 s = v.shape
       |                 n = s[0] + abs(k)
       |                 ret = np.zeros((n, n), v.dtype)
       |                 if k >= 0:
       |                     for i in range(n - k):
       |                         ret[i, k + i] = v[i]
       |                 else:
       |                     for i in range(n + k):
       |                         ret[i - k, i] = v[i]
       |                 return ret
       |             else:  # 2-d
       |                 rows, cols = v.shape
       |                 if k < 0:
       |                     rows = rows + k
       |                 if k > 0:
       |                     cols = cols - k
       |                 n = max(min(rows, cols), 0)
       |                 ret = np.empty(n, v.dtype)
       |                 if k >= 0:
       |                     for i in range(n):
       |                         ret[i] = v[i, k + i]
       |                 else:
       |                     for i in range(n):
       |                         ret[i] = v[i - k, i]
       |                 return ret
       |         return diag_impl
       | 
       | 
       | @overload(np.indices)
       | def numpy_indices(dimensions):
       |     if not isinstance(dimensions, types.UniTuple):
       |         msg = 'The argument "dimensions" must be a tuple of integers'
       |         raise errors.TypingError(msg)
       | 
       |     if not isinstance(dimensions.dtype, types.Integer):
       |         msg = 'The argument "dimensions" must be a tuple of integers'
       |         raise errors.TypingError(msg)
       | 
       |     N = len(dimensions)
       |     shape = (1,) * N
       | 
       |     def impl(dimensions):
       |         res = np.empty((N,) + dimensions, dtype=np.int64)
       |         i = 0
       |         for dim in dimensions:
       |             idx = np.arange(dim, dtype=np.int64).reshape(
       |                 tuple_setitem(shape, i, dim)
       |             )
       |             res[i] = idx
       |             i += 1
       | 
       |         return res
       | 
       |     return impl
       | 
       | 
       | @overload(np.diagflat)
       | def numpy_diagflat(v, k=0):
       |     if not type_can_asarray(v):
       |         msg = 'The argument "v" must be array-like'
       |         raise errors.TypingError(msg)
       | 
       |     if not isinstance(k, (int, types.Integer)):
       |         msg = 'The argument "k" must be an integer'
       |         raise errors.TypingError(msg)
       | 
       |     def impl(v, k=0):
       |         v = np.asarray(v)
       |         v = v.ravel()
       |         s = len(v)
       |         abs_k = abs(k)
       |         n = s + abs_k
       |         res = np.zeros((n, n), v.dtype)
       |         i = np.maximum(0, -k)
       |         j = np.maximum(0, k)
       |         for t in range(s):
       |             res[i + t, j + t] = v[t]
       | 
       |         return res
       | 
       |     return impl
       | 
       | 
       | @overload(np.take)
       | @overload_method(types.Array, 'take')
       | def numpy_take(a, indices):
       | 
       |     if isinstance(a, types.Array) and isinstance(indices, types.Integer):
       |         def take_impl(a, indices):
       |             if indices > (a.size - 1) or indices < -a.size:
       |                 raise IndexError("Index out of bounds")
       |             return a.ravel()[indices]
       |         return take_impl
       | 
       |     if all(isinstance(arg, types.Array) for arg in [a, indices]):
       |         F_order = indices.layout == 'F'
       | 
       |         def take_impl(a, indices):
       |             ret = np.empty(indices.size, dtype=a.dtype)
       |             if F_order:
       |                 walker = indices.copy()  # get C order
       |             else:
       |                 walker = indices
       |             it = np.nditer(walker)
       |             i = 0
       |             flat = a.ravel()
       |             for x in it:
       |                 if x > (a.size - 1) or x < -a.size:
       |                     raise IndexError("Index out of bounds")
       |                 ret[i] = flat[x]
       |                 i = i + 1
       |             return ret.reshape(indices.shape)
       |         return take_impl
       | 
       |     if isinstance(a, types.Array) and \
       |             isinstance(indices, (types.List, types.BaseTuple)):
       |         def take_impl(a, indices):
       |             convert = np.array(indices)
       |             ret = np.empty(convert.size, dtype=a.dtype)
       |             it = np.nditer(convert)
       |             i = 0
       |             flat = a.ravel()
       |             for x in it:
       |                 if x > (a.size - 1) or x < -a.size:
       |                     raise IndexError("Index out of bounds")
       |                 ret[i] = flat[x]
       |                 i = i + 1
       |             return ret.reshape(convert.shape)
       |         return take_impl
       | 
       | 
       | def _arange_dtype(*args):
       |     bounds = [a for a in args if not isinstance(a, types.NoneType)]
       | 
       |     if any(isinstance(a, types.Complex) for a in bounds):
       |         dtype = types.complex128
       |     elif any(isinstance(a, types.Float) for a in bounds):
       |         dtype = types.float64
       |     else:
       |         # `np.arange(10).dtype` is always `np.dtype(int)`, aka `np.int_`, which
       |         # in all released versions of numpy corresponds to the C `long` type.
       |         # Windows 64 is broken by default here because Numba (as of 0.47) does
       |         # not differentiate between Python and NumPy integers, so a `typeof(1)`
       |         # on w64 is `int64`, i.e. `intp`. This means an arange(<some int>) will
       |         # be typed as arange(int64) and the following will yield int64 opposed
       |         # to int32. Example: without a load of analysis to work out of the args
       |         # were wrapped in NumPy int*() calls it's not possible to detect the
       |         # difference between `np.arange(10)` and `np.arange(np.int64(10)`.
       |         NPY_TY = getattr(types, "int%s" % (8 * np.dtype(int).itemsize))
       | 
       |         # unliteral these types such that `max` works.
       |         unliteral_bounds = [types.unliteral(x) for x in bounds]
       |         dtype = max(unliteral_bounds + [NPY_TY,])
       | 
       |     return dtype
       | 
       | 
       | @overload(np.arange)
       | def np_arange(start, / ,stop=None, step=None, dtype=None):
       |     if isinstance(stop, types.Optional):
       |         stop = stop.type
       |     if isinstance(step, types.Optional):
       |         step = step.type
       |     if isinstance(dtype, types.Optional):
       |         dtype = dtype.type
       | 
       |     if stop is None:
       |         stop = types.none
       |     if step is None:
       |         step = types.none
       |     if dtype is None:
       |         dtype = types.none
       | 
       |     if (not isinstance(start, types.Number) or
       |         not isinstance(stop, (types.NoneType, types.Number)) or
       |         not isinstance(step, (types.NoneType, types.Number)) or
       |             not isinstance(dtype, (types.NoneType, types.DTypeSpec))):
       | 
       |         return
       | 
       |     if isinstance(dtype, types.NoneType):
       |         true_dtype = _arange_dtype(start, stop, step)
       |     else:
       |         true_dtype = dtype.dtype
       | 
       |     use_complex = any([isinstance(x, types.Complex)
       |                        for x in (start, stop, step)])
       | 
       |     start_value = getattr(start, "literal_value", None)
       |     stop_value = getattr(stop, "literal_value", None)
       |     step_value = getattr(step, "literal_value", None)
       | 
       |     def impl(start, /, stop=None, step=None, dtype=None):
       |         # Allow for improved performance if given literal arguments.
       |         lit_start = start_value if start_value is not None else start
       |         lit_stop = stop_value if stop_value is not None else stop
       |         lit_step = step_value if step_value is not None else step
       | 
       |         _step = lit_step if lit_step is not None else 1
       |         if lit_stop is None:
       |             _start, _stop = 0, lit_start
       |         else:
       |             _start, _stop = lit_start, lit_stop
       | 
       |         if _step == 0:
       |             raise ValueError("Maximum allowed size exceeded")
       | 
       |         nitems_c = (_stop - _start) / _step
       |         nitems_r = int(math.ceil(nitems_c.real))
       | 
       |         # Binary operator needed for compiler branch pruning.
       |         if use_complex is True:
       |             nitems_i = int(math.ceil(nitems_c.imag))
       |             nitems = max(min(nitems_i, nitems_r), 0)
       |         else:
       |             nitems = max(nitems_r, 0)
       |         arr = np.empty(nitems, true_dtype)
       |         val = _start
  0.1% |         for i in range(nitems):
  0.1% |             arr[i] = val + (i * _step)
```


In [None]:
%%profila
anjl.canonical_nj(large_D_shuffled)