In [1]:
# Load the profila IPython extension; it provides the %%profila cell magic used further down.
%load_ext profila

> **Note:** loading the `profila` extension can impact Numba's performance. Make sure to disable it once you're done profiling!

In [2]:
# %load_ext snakeviz

In [3]:
import zarr
from scipy.spatial.distance import squareform
import numpy as np
import anjl

## Large

In [12]:
# Load the stored pairwise distances and expand them into a square matrix.
large = zarr.load("../data/large/dist.zarr.zip")
large_D = squareform(large)

# Pick 1000 distinct sample indices at random and apply the same selection to
# rows and columns, so the result is a symmetric 1000 x 1000 submatrix.
# The generator is seeded so that a fresh "Restart & Run All" times and
# profiles exactly the same input on every run.
shuffle = np.random.default_rng(42).choice(large_D.shape[0], size=1000, replace=False)
large_D_shuffled = large_D.take(shuffle, axis=0).take(shuffle, axis=1)
large_D_shuffled

array([[ 0., 38., 28., ..., 29., 26., 24.],
       [38.,  0., 28., ..., 35., 32., 24.],
       [28., 28.,  0., ..., 23., 28., 12.],
       ...,
       [29., 35., 23., ...,  0., 29., 19.],
       [26., 32., 28., ..., 29.,  0., 24.],
       [24., 24., 12., ..., 19., 24.,  0.]], dtype=float32)

In [13]:
%%time
# Wall-clock timing of a single anjl.canonical_nj call on the shuffled
# submatrix; the returned array is kept in large_Z.
large_Z = anjl.canonical_nj(large_D_shuffled)

CPU times: user 233 ms, sys: 105 μs, total: 233 ms
Wall time: 231 ms


In [14]:
# %%snakeviz
# anjl.canonical_nj(large_D_shuffled)

In [15]:
%%profila
# Profile the same call as the %%time cell above; the per-line sample
# percentages are reported in the cell output.
anjl.canonical_nj(large_D_shuffled)

array([[6.0400000e+02, 8.3400000e+02, 9.9699396e-01, 3.0060119e-03,
        2.0000000e+00],
       [9.8000000e+01, 1.1700000e+02, 0.0000000e+00, 0.0000000e+00,
        2.0000000e+00],
       [7.3600000e+02, 1.0010000e+03, 0.0000000e+00, 0.0000000e+00,
        3.0000000e+00],
       ...,
       [1.9780000e+03, 1.9830000e+03, 2.6133865e-01, 2.8120428e-01,
        1.0300000e+02],
       [1.9950000e+03, 1.9960000e+03, 7.0536315e-02, 2.6785880e-02,
        5.1200000e+02],
       [1.9920000e+03, 1.9970000e+03, 5.9355550e-02, 5.9355550e-02,
        1.0000000e+03]], dtype=float32)

**Elapsed:** 6.006 seconds

**Total samples:** 429 (4.2% non-Numba samples, 0.0% bad samples)

/home/aliman/github/alimanfoo/anjl/anjl/_canonical.py (lines 136 to 219):

```
  0.2% |     Z[iteration, 4] = leaves_i + leaves_j
       | 
       |     if n_remaining > 2:
       |         # Update data structures.
       |         _canonical_update(
       |             D=D,
       |             U=U,
       |             index_to_id=index_to_id,
       |             obsolete=obsolete,
       |             node=node,
       |             i_min=i_min,
       |             j_min=j_min,
       |             d_ij=d_ij,
       |         )
       | 
       | 
       | @numba.njit
       | def _canonical_search(
       |     D: np.ndarray, U: np.ndarray, obsolete: np.ndarray, n: int
       | ) -> tuple[int, int]:
       |     # Search for the closest pair of neighbouring nodes to join.
       |     q_min = numba.float32(np.inf)
       |     i_min = -1
       |     j_min = -1
       |     coefficient = numba.float32(n - 2)
       |     m = D.shape[0]
  0.9% |     for i in range(m):
  0.5% |         if obsolete[i]:
       |             continue
       |         u_i = U[i]
 13.8% |         for j in range(i):
 17.7% |             if obsolete[j]:
       |                 continue
       |             u_j = U[j]
       |             d = D[i, j]
 32.9% |             q = coefficient * d - u_i - u_j
 25.9% |             if q < q_min:
       |                 q_min = q
       |                 i_min = i
       |                 j_min = j
       |     return i_min, j_min
       | 
       | 
       | @numba.njit
       | def _canonical_update(
       |     D: np.ndarray,
       |     U: np.ndarray,
       |     index_to_id: np.ndarray,
       |     obsolete: np.ndarray,
       |     node: int,
       |     i_min: int,
       |     j_min: int,
       |     d_ij: float,
       | ) -> None:
       |     # Here we obsolete the row and column corresponding to the node at j_min, and we
       |     # reuse the row and column at i_min for the new node.
       |     obsolete[j_min] = True
       |     index_to_id[i_min] = node
       | 
       |     # Subtract out the distances for the nodes that have just been joined.
  0.2% |     U -= D[i_min]
       |     U -= D[j_min]
       | 
       |     # Initialize divergence for the new node.
       |     u_new = np.float32(0)
       | 
       |     # Update distances and divergence.
  0.5% |     for k in range(D.shape[0]):
  0.5% |         if obsolete[k] or k == i_min or k == j_min:
       |             continue
       | 
       |         # Distance from k to the new node.
       |         d_ik = D[i_min, k]
       |         d_jk = D[j_min, k]
  0.2% |         d_k = 0.5 * (d_ik + d_jk - d_ij)
  0.9% |         D[i_min, k] = d_k
  1.2% |         D[k, i_min] = d_k
  0.2% |         U[k] += d_k
       | 
       |         # Accumulate divergence for the new node.
       |         u_new += d_k
       | 
       |     # Assign divergence for the new node.
  0.2% |     U[i_min] = u_new
```
