In [1]:
%load_ext profila

> **Note:** loading the `profila` extension can impact Numba's performance. Make sure to disable it once you're done profiling!

In [2]:
# %load_ext snakeviz

In [3]:
import zarr
from scipy.spatial.distance import squareform
import numpy as np
import anjl

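## Large

Benchmark and profile `anjl.canonical_nj` and `anjl.rapid_nj` on a random 500-sample subset of the large distance matrix.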

In [17]:
# Load the stored distances and expand them into a square distance matrix.
large = zarr.load("../data/large/dist.zarr.zip")
large_D = squareform(large)

# Draw a random subset of 500 samples without replacement. A seeded generator
# keeps the subset (and therefore the timings and profiles that follow)
# reproducible across kernel restarts; the unseeded global RNG did not.
rng = np.random.default_rng(42)
shuffle = rng.choice(large_D.shape[0], size=500, replace=False)

# Apply the same selection to rows and columns so both axes stay aligned.
large_D_shuffled = large_D.take(shuffle, axis=0).take(shuffle, axis=1)
large_D_shuffled

array([[ 0., 27., 16., ..., 25., 21., 19.],
       [27.,  0., 23., ..., 22., 30., 22.],
       [16., 23.,  0., ..., 25., 19., 17.],
       ...,
       [25., 22., 25., ...,  0., 30., 24.],
       [21., 30., 19., ..., 30.,  0., 20.],
       [19., 22., 17., ..., 24., 20.,  0.]], dtype=float32)

In [18]:
%%time
# Time anjl.canonical_nj on the shuffled distance matrix.
# NOTE(review): the profila extension is loaded at the top of this notebook and
# may affect Numba performance, so treat this timing as indicative only.
large_Z = anjl.canonical_nj(large_D_shuffled)

CPU times: user 175 ms, sys: 0 ns, total: 175 ms
Wall time: 174 ms


In [19]:
%%time
# Time anjl.rapid_nj on the same shuffled distance matrix, with gc=100.
# Note: this rebinds large_Z, replacing the result of the canonical_nj cell.
large_Z = anjl.rapid_nj(large_D_shuffled, gc=100)

CPU times: user 534 ms, sys: 0 ns, total: 534 ms
Wall time: 533 ms


In [20]:
# %%snakeviz
# anjl.canonical_nj(large_D_shuffled)
# anjl.rapid_nj(large_D_shuffled, gc=100)

In [21]:
# %%snakeviz
# anjl.rapid_nj(large_D_shuffled, gc=100)

In [22]:
%%profila
# Profile anjl.canonical_nj with profila; the returned array is shown as the
# cell output, followed by the profiler's report.
anjl.canonical_nj(large_D_shuffled)

array([[1.8000000e+02, 4.5800000e+02, 3.3815260e+00, 6.6184740e+00,
        2.0000000e+00],
       [1.9900000e+02, 2.2400000e+02, 8.8531184e-01, 1.1146882e+00,
        2.0000000e+00],
       [3.4800000e+02, 4.4900000e+02, 9.1733873e-01, 1.0826613e+00,
        2.0000000e+00],
       ...,
       [9.9300000e+02, 9.9500000e+02, 1.0816996e-01, 8.3526894e-02,
        1.9900000e+02],
       [9.9400000e+02, 9.9600000e+02, 8.7903276e-02, 3.3321455e-02,
        4.0500000e+02],
       [9.9100000e+02, 9.9700000e+02, 6.6971444e-02, 6.6971444e-02,
        5.0000000e+02]], dtype=float32)

**Elapsed:** 9.595 seconds

**Total samples:** 675 (80.9% non-Numba samples, 0.0% bad samples)

/home/aliman/github/alimanfoo/anjl/anjl/_canonical.py (lines 188 to 196):

```
  1.5% |         for j in range(i):
       |             visited += 1
  5.3% |             if obsolete[j]:
       |                 continue
       |             u_j = U[j]
       |             d = D[i, j]
  7.4% |             q = coefficient * d - u_i - u_j
  0.9% |             searched += 1
  4.0% |             if q < q_min:
```


In [24]:
%%profila
# Profile anjl.rapid_nj (gc=100) with profila on the same input, for
# comparison with the canonical_nj profile.
anjl.rapid_nj(large_D_shuffled, gc=100)

array([[1.8000000e+02, 4.5800000e+02, 3.3815260e+00, 6.6184740e+00,
        2.0000000e+00],
       [1.9900000e+02, 2.2400000e+02, 8.8531184e-01, 1.1146882e+00,
        2.0000000e+00],
       [3.4800000e+02, 4.4900000e+02, 9.1733873e-01, 1.0826613e+00,
        2.0000000e+00],
       ...,
       [9.9300000e+02, 9.9500000e+02, 1.0816996e-01, 8.3526894e-02,
        1.9900000e+02],
       [9.9400000e+02, 9.9600000e+02, 8.7903276e-02, 3.3321455e-02,
        4.0500000e+02],
       [9.9100000e+02, 9.9700000e+02, 6.6971444e-02, 6.6971444e-02,
        5.0000000e+02]], dtype=float32)

**Elapsed:** 32.263 seconds

**Total samples:** 2538 (90.9% non-Numba samples, 0.0% bad samples)

/home/aliman/github/alimanfoo/anjl/anjl/_rapid.py (lines 272 to 352):

```
       |     # Initialize working variables.
       |     q_min = numba.float32(np.inf)
       |     threshold = numba.float32(np.inf)
       |     i_min = -1
       |     j_min = -1
       |     searched = 0
       |     visited = 0
       |     coefficient = numba.float32(n_remaining - 2)
       |     m = nodes_sorted.shape[0]
       |     # n = nodes_sorted.shape[1]
       | 
       |     # # First pass, scan down first values.
       |     # for i in range(m):
       |     #     if obsolete[i]:
       |     #         continue
       |     #     u_i = U[i]
       |     #     id_j = nodes_sorted[i, 0]
       |     #     if clustered[id_j]:
       |     #         continue
       |     #     j = id_to_index[id_j]
       |     #     u_j = U[j]
       |     #     d = D[i, j]
       |     #     q = coefficient * d - u_i - u_j
       |     #     if q < q_min:
       |     #         q_min = q
       |     #         threshold = q_min + u_max
       |     #         i_min = i
       |     #         j_min = j
       | 
       |     # indices_available = np.nonzero(~obsolete)[0]
       |     # # np.random.shuffle(indices_available)
       |     # for i in indices_available:
       | 
       |     # Search all values up to threshold.
       |     for i in range(m):
       |         # Skip if row is no longer in use.
  0.1% |         if obsolete[i]:
       |             continue
       | 
       |         # # Obtain identifier for this row.
       |         # id_i = index_to_id[i]
       | 
       |         # Obtain divergence for node corresponding to this row.
       |         u_i = U[i]
       | 
       |         # Search the row up to threshold.
  1.1% |         for node in nodes_sorted[i]:
       |             visited += 1
       | 
       |             # # Obtain node identifier for the current item.
       |             # id_j = nodes_sorted[i, s]
       | 
       |             # Skip if this node is already clustered.
  1.6% |             if clustered[node]:
       |                 continue
       | 
       |             # Break at end of nodes.
       |             if node < 0:
       |                 break
       | 
       |             # Obtain column index in the distance matrix.
       |             j = id_to_index[node]
       | 
       |             # Partially calculate q.
  0.6% |             d = D[i, j]
  3.5% |             q_partial = coefficient * d - u_i
       | 
       |             # Limit search. Because the row is sorted, if we are already above this
       |             # threshold then we know there is no need to search remaining nodes in the
       |             # row.
  0.6% |             if q_partial > threshold:
       |                 break
       | 
       |             # Fully calculate q.
  0.3% |             u_j = U[j]
  0.1% |             q = q_partial - u_j
  0.4% |             searched += 1
       | 
  0.4% |             if q < q_min:
       |                 q_min = q
  0.3% |                 threshold = q_min + u_max
```
