### dev-code for umap/constrain_clust.py


### We were asked for a clusterer.
Every time the cluster membership changes, you create a clustering function.

-API **mk_clusters(clusters, springs, *, mindist=None, mults=None, wm=None)**
  - input as list-of-lists `clusters[c][i]`: (i) cluster(`c`=0,1,...), (ii) points (indices `idx`) in cluster.
    - also internally remember map of index to clusters: `clusts[idx] = list-of-cluster(`c`)`
    - also internally calculate cluster size `cl_n[c]`
    - and cluster average position `cl_avg[c]`
    - opt. store both as numpy sparse matrix + sparse transpose ?
  - input spring force `springs[c]` and opt. learning rate lr=1.0
  - input optional de-neighboring `mults[c]` (default 1.0) and modify umap weight matrix `wm`
  - input optional mindist for short-range cluster repulsive force (Morse function variant?)
  - output jit func (`idx`, `pts`)
    - for `c` in `clusts[idx]`:  # all clusters c of idx
      - `delta` of pt moved toward `cl_avg[c]` with `springs[c]`
      - quick-update `cl_avg[c]` by `(delta / cl_n[c])`
      - opt. every 100'th call, recalculate exact cluster averages instead of `delta`-update
      - opt. short-range repulsion to other neighbors of `idx` in cluster `c`

In [4]:
import numba
import numpy as np

In [5]:
@numba.njit() #["f8(f4,f8,f8)","f8(f8,f8,f8)"])
def smoothstep(x, edge0, edge1):
    """ smoothstep smoothly interpolates between 0 and 1
        as x transitions from edge0 to edge1.

    pre-condition edge0 != edge1
    
    This is useful in cases where a threshold function with a smooth transition is desired.
    smoothstep is equivalent to:
    ```
    genType t;  /* Or genDType t; */
    t = clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0);
    return t * t * (3.0 - 2.0 * t);
    ```
    
    """
    t = min(max( float(x-edge0)/float(edge1-edge0), 0.0), 1.0)
    return t * t * (3.0 - 2.0*t)

for i in range(20):
    x = (i-5) * 0.2
    y1 = smoothstep(x,0.0,1.0)
    print(f"{x=:.2f} y1=smoothstep(x,0,1) {y1:.4f} {type(y1)=}")
for i in range(20):
    x = (i-5) * 0.2
    y2 = smoothstep(x,1.0,0.0)
    print(f"{x=:.2f} y2=smoothstep(x,1,0)={y2:.2f} {type(y2)=}")

x=-1.00 y1=smoothstep(x,0,1) 0.0000 type(y1)=<class 'float'>
x=-0.80 y1=smoothstep(x,0,1) 0.0000 type(y1)=<class 'float'>
x=-0.60 y1=smoothstep(x,0,1) 0.0000 type(y1)=<class 'float'>
x=-0.40 y1=smoothstep(x,0,1) 0.0000 type(y1)=<class 'float'>
x=-0.20 y1=smoothstep(x,0,1) 0.0000 type(y1)=<class 'float'>
x=0.00 y1=smoothstep(x,0,1) 0.0000 type(y1)=<class 'float'>
x=0.20 y1=smoothstep(x,0,1) 0.1040 type(y1)=<class 'float'>
x=0.40 y1=smoothstep(x,0,1) 0.3520 type(y1)=<class 'float'>
x=0.60 y1=smoothstep(x,0,1) 0.6480 type(y1)=<class 'float'>
x=0.80 y1=smoothstep(x,0,1) 0.8960 type(y1)=<class 'float'>
x=1.00 y1=smoothstep(x,0,1) 1.0000 type(y1)=<class 'float'>
x=1.20 y1=smoothstep(x,0,1) 1.0000 type(y1)=<class 'float'>
x=1.40 y1=smoothstep(x,0,1) 1.0000 type(y1)=<class 'float'>
x=1.60 y1=smoothstep(x,0,1) 1.0000 type(y1)=<class 'float'>
x=1.80 y1=smoothstep(x,0,1) 1.0000 type(y1)=<class 'float'>
x=2.00 y1=smoothstep(x,0,1) 1.0000 type(y1)=<class 'float'>
x=2.20 y1=smoothstep(x,0,1) 1.0000 

In [6]:
#
# spring force is 'spring * extension'  (normal physics)
#
# So a spring constant of 1.0 applied for exactly 1.0 seconds
# will move you from wherever exactly to spring equilibrium.
#
# Here 'lr' plays role of time, while 'mindist' applies a radius
# to the target position.
# I.e., spring=lr=1.0 would move towards target, but stop mindist away.
#
def spring_mv_v0(pt, target, spring, lr=1.0, mindist=0.0, *, verbose=0):
    """ spring force for time lr moves vector pt toward target without
        overshoot, stopping mindist away.
        
        return the delta vector (all inputs const)
    """
    # springs -> force gradient -> non-overshooting 'delta' movement
    vec = target - pt      # vector toward cluster center
    vecnorm = np.linalg.norm(vec)
    #grad = np_springs[c] * vecnorm
    ## Let's make the spring force never over-shoot
    ##   v.0: by simply capping movement to "90% of the way to cluster avg"
    #pct_move = min(0.9, (lr * grad)/vecnorm)
    #delta = vec * pct_move
    #
    # --- now add mindist ---
    #      
    #      P-------------------|---A  P=pts[idx], A=cl_avg
    # PA = vecnorm             ^
    #      0          (vecnorm - mindist)
    #
    bar = max(vecnorm - mindist, 0.0) # max move dist (along vec/vecnorm)
    #frac_bar = bar / vecnorm          # max move fraction along vec
    # 'lr' ~ "time the force acts", converting force into a distance
    mv = lr * bar * spring                  # force move dist (possibly overshoots)
    # at mv==bar, we should have vec+delta stop mindist from cl_avg
    #             begin with no 'underhoot'
    frac_mv = min(mv, bar) / vecnorm
    delta = frac_mv * vec
    # pt' = pt + delta   =  pt + frac_mv*(target-pt)
    #                    =  (1-frac_mv)*pt + frac_mv*target
    if verbose>0:
        print(f" mv={100.*frac_mv:.2f}%")
    if verbose>1:
        print(f" {vec=}")
        print(f" {vecnorm=} {mindist=}")
        print(f" {bar=}")
        print(f" {mv=}")
        print(f" {frac_mv=}")
        print(f" {delta=}")
    return delta

# Equivalent terse version
def spring_mv_v0_plain_dev(pt, target, spring, lr=1.0, mindist=0.0):
    vec = target - pt      # vector toward cluster center
    vecnorm = np.linalg.norm(vec)
    if mindist > 0.0:
        bar = max(vecnorm - mindist, 0.0) # max move dist (along vec/vecnorm)
        #mv = lr * bar * spring                  # force move dist (possibly overshoots)
        #frac_mv = min(mv, bar) / vecnorm
        #delta = frac_mv * vec
        # equiv.
        #delta = (min(lr * bar * spring, bar) / vecnorm) * vec
        # or
        delta = (min(lr * spring, 1.0) * bar / vecnorm) * vec
        #
        # or
        #delta = (min(lr * spring, 1.0) * max(vecnorm - mindist, 0.0) / vecnorm) * vec
        # or
        #delta = (min(lr * spring, 1.0) * max(1.0 - mindist/vecnorm), 0.0) * vec
        #  vecnorm does not "cancel out"
    else: # mindist == 0.0
        #bar = vecnorm
        #mv = lr * bar * spring
        #frac_mv = min(mv, bar) / vecnorm
        #delta = frac_mv * vec
        # equiv.
        #delta = (min(vecnorm, lr * vecnorm * spring) / vecnorm) * vec
        # or
        delta = min(1.0, lr*spring) * vec
        # Notice that vecnorm "cancels out" when mindist==0.0
    return delta

def spring_mv_v0_plain(pt, target, spring, lr=1.0, mindist=0.0):
    vec = target - pt      # vector toward cluster center
    if mindist > 0.0:
        vecnorm = np.linalg.norm(vec)
        bar = max(vecnorm - mindist, 0.0) # max move dist (along vec/vecnorm)
        #delta = (min(lr * spring, 1.0) * bar / vecnorm) * vec
        delta = (min(bar, bar * lr * spring) / vecnorm) * vec
    else: # mindist == 0.0
        delta = min(1.0, lr*spring) * vec
    return delta


nfail = 0
def test_mv(pt, target, spring=1.0, lr=1.0, mindist=0.0, *, gives=None):
    global nfail
    delta = spring_mv_v0(pt,target,spring, lr, mindist)
    if not np.allclose(delta, gives):
        nfail += 1
        delta = spring_mv_v0(pt,target,spring, lr, mindist, verbose=2)
        print(f"Failed: spring_mv_v0({pt},{target}, {spring},{lr}, {mindist}) ~= {gives}")
        print(f"        spring_mv_v0 --> {delta}")
        print(f"               EXPECTED: {gives}")
        print(f"            discrepancy: {gives-delta}")
    delta_plain = spring_mv_v0_plain(pt, target, spring, lr, mindist)
    if not np.allclose(delta_plain, gives):
        nfail += 1
        #delta = spring_mv_v0(pt,target,spring, lr, mindist, verbose=2)
        print(f"Failed: spring_mv_v0_plain({pt},{target}, {spring},{lr}, {mindist}) ~= {gives}")
        print(f"        spring_mv_v0_plain --> {delta_plain}")
        print(f"                  EXPECTED: {gives}")
        print(f"               discrepancy: {gives-delta_plain}")
    return

print(*(1,2,3))
test_mv(0.0,10.0, 1.0,1.0, 0.0, gives=10.0)
test_mv(0.0,10.0, 1.0,1.0, 1.0, gives=9.0)
test_mv(0.0,10.0, 1.0,1.0, 9.0, gives=1.0)
test_mv(0.0,10.0, 1.0,1.0, 10.0, gives=0.0)
test_mv(0.0,10.0, 1.0,1.0, 11.0, gives=0.0)
eps = 1e-3
test_mv(0.0,10.0, 1+eps,1.0, 0.0, gives=10.0)
test_mv(0.0,10.0, 1-eps,1.0, 0.0, gives=10*(1-eps))
test_mv(0.0,10.0, 1.0,1+eps, 0.0, gives=10.0)
test_mv(0.0,10.0, 1.0,1-eps, 0.0, gives=10*(1-eps))
test_mv(0.0,10.0, 1.0,1.0, 10.0-eps, gives=eps)
test_mv(0.0,10.0, 1.0,1.0, 10.0+eps, gives=0.0)
test_mv(0.0,10+eps, 1.0,1.0, 0.0, gives=10+eps)
test_mv(0.0,10+eps, 1.0,1.0, eps, gives=10)
test_mv(0.0,10, 1.0,1.0, eps, gives=10-eps)

test_mv(1.0,10.0, 1.0,1.0, 0.0, gives=9.0)
test_mv(1.0,10.0, 1.0,1.0, 1.0, gives=8.0)
test_mv(0.0,-10.0, 1.0,1.0, 0.0, gives=-10.0)
test_mv(0.0,-10.0, 1.0,1.0, 1.0, gives=-9.0)
test_mv(10.0,-10.0, 1.0,1.0, 0.0, gives=-20.0)
test_mv(10.0,-10.0, 1.0,1.0, 1.0, gives=-19.0)

print("Goodbye")

1 2 3
Goodbye


In [7]:
print("Let's demo python clustering internals (debug version)")
#
# --------------------- Inputs -----------------
#
n_samples = 8
cluster_lists = [[1,2,3], [2,7,6]]
#           avg    2.0      5.0
# each of the 2 clusters has a spring constant
springs = np.array([0.8, 2.0])
lr = 1.0   # learning rate (gradient multiplier) for spring forces

pts = np.ndarray((n_samples,2), dtype=np.float32)
for i in range(n_samples):
    pts[i,:] = (float(i),float(i%3))



#
# ---------------------- Functions --------------
#
def do_clustering_ipts(idx, pts, cluster_lists, springs, *, lr=1.0, mindist=0.01):
    assert len(pts.shape) == 2
    n_samples = pts.shape[0]
    n_clust = len(cluster_lists)
    assert n_clust > 0
    assert len(springs) == len(cluster_lists)
    np_cluster_lists = [np.array(clist,np.int32) for clist in cluster_lists]
    np_springs = np.array(springs)
    #
    # cluster membership as bool array[cluster index][point index]
    #
    clusters = np.full((n_clust, n_samples), False, dtype=bool)
    for (c,members) in enumerate(cluster_lists):
        for m in members:
            clusters[c][m] = True
    
    #print(f"{clusters=}")
    verbose=0
    if verbose>0: print(f"{clusters[:,idx]=}")
    idx_in = clusters[:,idx]
    #idx_in = np.equal( clusters[:,idx], True )
    #print(f"{idx=} in clusters {idx_in=}")
    #idx_cl = np.argwhere(idx_in) # size n_nonzero x idx_in_size ?
    idx_cl = np.argwhere(idx_in).flatten()
    if verbose>0: print(f"{idx_cl=}")
    #
    verbose=1
    if len(idx_cl) < 1:
        return

    elif len(idx_cl) == 1:
        # Separate out an easy case (idx in single cluster)
        c = idx_cl[0]
        if verbose>0: print("\npush pt",idx,"towards single cluster center",c)
        #cluster_pts0 = np.argwhere(clusters[c,:]).flatten()
        #print(f"{cluster_pts0=}")
        #cluster_pts  = cluster_lists[c]
        cluster_pts  = np_cluster_lists[c] # perhaps faster/vectorizable
        if len(cluster_pts) <= 1:
            print("noop: cluster empty or of size 1")
            return
        if verbose>1: print(f"{cluster_pts=}")

        # cl_avg ~ cluster center
        # cleverer version:
        cl_avg = np.mean(pts[cluster_pts,:], axis=0)
        if verbose>0: print(f"{cl_avg=}")
        if True: # long-handed
            cl_avg0 = np.zeros((pts.shape[1]), dtype=np.float32)
            for pt in cluster_pts:
                cl_avg0 += pts[pt]
            assert len(cluster_pts) > 1
            cl_avg0 /= len(cluster_pts)
            #if verbose>0: print(f"{cl_avg0=}")
            assert np.allclose(cl_avg, cl_avg0)


        # springs -> force gradient -> non-overshooting 'delta' movement
        if False: # v.0
            # springs -> force gradient -> non-overshooting 'delta' movement
            vec = cl_avg - pts[idx]      # vector toward cluster center
            vecnorm = np.linalg.norm(vec)
            #grad = np_springs[c] * vecnorm
            ## Let's make the spring force never over-shoot
            ##   v.0: by simply capping movement to "90% of the way to cluster avg"
            #pct_move = min(0.9, (lr * grad)/vecnorm)
            #delta = vec * pct_move
            #
            # --- now add mindist ---
            #      
            #      P-------------------|---A  P=pts[idx], A=cl_avg
            # PA = vecnorm             ^
            #      0          (vecnorm - mindist)
            #
            bar = max(vecnorm - mindist, 0.0)
            # 'lr' ~ "time the force acts", converting force into a distance
            mv = lr * bar * np.springs[c]     # force move dist (possibly overshoots)
            # at mv==bar, we should have vec+delta stop mindist from cl_avg
            #             begin with no 'underhoot'
            #frac_bar = bar / vecnorm # max move (dist along vec/vecnorm)
            frac_mv = min(mv, frac_bar) / vecnorm
            delta = vecbar * frac_mv
        elif True:
            delta = spring_mv_v0(pts[idx], cl_avg, springs[c], lr, mindist)
        else:
            # EXPERIMENTAL (can we "smooth" over the discontinuities)
            vec = cl_avg - pts[idx]      # vector toward cluster center
            vecnorm = np.linalg.norm(vec)
            if vecnorm <= mindist: # no-op - pt is already within mindist of cl_avg
                pct_move = 0.0
                delta = vec * 0.0
            else:
                grad = np_springs[c] * vecnorm
                gradnorm = lr * grad
                # Let's make the spring force never over-shoot
                #   v.0: by simply capping movement to "90% of the way to cluster avg"
                pct_move0 = min(0.9, lr * grad)
                #delta = vec * pct_move
                #   v.1: mindist  lr*grad *= fn going from 0.0 to 1.0 as
                #                         lr*grad goes from vecnorm-0.5*mindist to vecnorm-mindist
                edge0 = max(vecnorm - 2*mindist, 0.0)
                edge1 = vecnorm - mindist
                if gradnorm <= edge0:
                    pct_move = 1.0
                else:
                    mvlen = edge0 + (edge1-edge0) * smoothstep( gradnorm, edge0, edge1 )
                    print(f"{mvlen=}")
                    pct_move = (mvlen / vecnorm)
                print(f"{lr*grad=:.4f} {edge0=:.4f} {edge1=:.4f}")
                print(f"{mindist=} pct_move v.0 {int(pct_move0*100.):.2f}% v.1 {int(pct_move*100.):.2f}") 

                delta = vec * pct_move

        #if verbose>0: print(f"{vec=} {vecnorm=} {grad=} {lr*grad=} {delta=}")
        # perhaps modify the force from pure-parabolic to something
        # with both long-range attraction and some mindist-repulsion.
        #   (like Morse curve?)
        # or just "if vecnorm < mindist: NOOP"
        # (or some smooth-ish 90% that fades to zero as vecnorm < ~mindist?)

        newpos = pts[idx,:] + delta
        if verbose>0: print(f"pt {idx} {pts[idx,:]} --> newpos {newpos}")
        
        pts[idx,:] = newpos

        #
        # Is a quick approx update of cl_avg feasible?
        cl_avg_new = cl_avg + delta/len(cluster_pts)
        if verbose>0: print(f"cluster {c} : {cl_avg_new=}")
        #
        # NO! but YES if idx is only in a single cluster
        #  Above simplicity does not work, because ALL cluster centers in which
        #  pts[idx] participates must be updated,
        #  (not just current cluster 'c')
        # BUT OK for disjoint cluster memberships (pt never in 2 clusters)

        return

    assert( len(idx_cl) > 1 )
    # Note: if point is in two clusters, pt moves first toward one,
    #       then toward next, in SAME pattern.
    # Better might be to 1st determine NET gradient direction from
    # weighted sum of individual gradient forces,
    # with movement "capped" at some weighted avg of cluster centers?
    #
    # Skip for now (assume multi-cluster membership is rare-ish)
    #
    cl_avgs = []
    if verbose>0: print("\npush pt",idx,"towards cluster centers",idx_cl)
    for c in idx_cl:
        cluster_pts  = np_cluster_lists[c] # perhaps faster/vectorizable
        if len(cluster_pts) <= 1:
            cl_avgs.append(None)
            print("noop: cluster empty or of size 1")
            continue
        if verbose>2: print(f"{cluster_pts=}")

        # cl_avg ~ cluster center
        #cl_avg = np.zeros((pts.shape[1]), dtype=np.float32)
        #for pt in cluster_pts:
        #    #if verbose: print("sum += ",pts[pt])
        #    cl_avg += pts[pt]
        #if len(cluster_pts) > 1:
        #    cl_avg /= len(cluster_pts)
        cl_avg = np.mean(pts[cluster_pts,:], axis=0)
        if verbose>1: print(f"{cl_avg=}")
        cl_avgs.append(cl_avg)

    tot_grad = np.zeros(pts.shape[1],dtype=np.float32)
    eqm_pos  = np.zeros(pts.shape[1],dtype=np.float32)
    sum_springs = np.sum(np_springs) #0.0
    if verbose>1: print(f"{sum_springs=}")
    #grad = None # we don't use this variable here
    for (i,c) in enumerate(idx_cl):
        # equilibrium target is spring-weighted center-of-mass
        eqm_pos += np_springs[c] * cl_avgs[i]
        # springs -> force gradient -> non-overshooting 'delta' movement
        vec = cl_avgs[i] - pts[idx]      # vector toward cluster center
        #vecnorm = np.linalg.norm(vec)
        #grad = springs[c] * vecnorm     # gradient, in dirn vec/vecnorm
        #tot_grad += (spring[c] * vecnorm) * (vec/vecnorm)
        #  so we vecnorm cancels, leaving simply:
        tot_grad += np_springs[c] * vec

    eqm_pos /= sum_springs
    if verbose>0: print(f"{eqm_pos=}")
    if verbose>1: print(f"{tot_grad=}")

    if True: # assert
        # tot_grad should automatically be in direction of eqm_pos
        # if math is correct:
        grad_dirn = tot_grad / np.linalg.norm(tot_grad)
        eqm_dirn  = eqm_pos - pts[idx]
        eqm_dirn /= np.linalg.norm(eqm_dirn)
        assert( np.allclose(grad_dirn, eqm_dirn) )

    # at this point we have net gradient and terminal eqm_pos
    # we proceed as before, so that the force is toward eqm_pos
    # but the point never overshoots.
    if True: # original.
        #
        #  Here we work directly with eqm_pos and totgrad
        #  and ignore mindist.
        #
        vec = eqm_pos - pts[idx]      # vector toward spring-weighted equilibrium
        vecnorm = np.linalg.norm(vec)
        gradnorm = np.linalg.norm(tot_grad)
        if verbose>2: print(f"{gradnorm=}")
        #   v.0: by simply capping movement to "100% of the way to cluster avg"
        #       N.B. for cluster avg 100%, without mindist
        frac_mv = min(1.0, (lr * gradnorm) / vecnorm)
        delta = vec * frac_mv
        if verbose>1: print(f"{vecnorm=} {gradnorm=}")
        elif verbose>2: print(f"{vec=} {vecnorm=} {gradnorm=} {frac_mv=} {delta=}")
        #
        # cf
        # without minidst (i.e. bar = vecnorm)
        #vec = eqm_pos - pts[idx]      # vector toward spring-weighted equilibrium
        #vecnorm = np.linalg.norm(vec)
        #if vecnorm > 1e-5:
        #    #mv = lr * np.linalg.norm(tot_grad)  # lr : force --> distance
        #    #mv = min(mv, vecnorm)
        #    #pts[idx,:] += (mv/vecnorm) * vec
        #    #...aka
        #    #mv = lr * np.linalg.norm(tot_grad)  # lr : force --> distance
        #    #mv = min(1.0, mv/vecnorm)
        #    #pts[idx,:] += mv * vec
        #    #...aka
        #    pts[idx,:] += min(1.0, mv/vecnorm) * vec

    #else: # use same function to calc delta
    #    eqm_spring  (develop in next cell)
    #    delta = spring_mv_v0(pts[idx], cl_avg, , lr, mindist)

    # perhaps modify the force from pure-parabolic to something
    # with both long-range attraction and some mindist-repulsion.
    #   (like Morse curve?)
    # or just "if vecnorm < mindist: NOOP"
    # (or some smooth-ish 90% that fades to zero as vecnorm < ~mindist?)

    newpos = pts[idx,:] + delta
    if verbose>0: print(f"pt {idx} {pts[idx,:]} --> newpos {newpos}")
    
    pts[idx,:] = newpos
    if verbose>0: print(f"final {pts[idx,:]=}")

    #
    # Is a quick approx update of cl_avg feasible?
    #cl_avg_new = cl_avg + delta/len(cluster_pts)
    #if verbose>0: print(f"cluster {c} : {cl_avg_new=}")
    #
    # NO!
    #  Above simplicity does not work, because ALL cluster centers in which
    #  pts[idx] participates must be updated,
    #  (not just current cluster 'c')
    # BUT OK for disjoint cluster memberships (pt never in 2 clusters)
    return
    
def do_clustering_pts(pts, cluster_lists, springs, *, lr=1.0, mindist=0.01):
    assert len(pts.shape) == 2
    n_samples = pts.shape[0]
    # one round of clustering every point once
    for idx in range(n_samples):
        pt0 = pts[idx,:].copy()
        do_clustering_ipts(idx, pts, cluster_lists, springs, lr=lr, mindist=mindist)
        print(f"{idx=} {pt0=} --> {pts[idx,:]=}")
    return
        
do_clustering_pts(pts, cluster_lists, springs, lr=lr, mindist=0.2)
print("final positions (one epoch)")
for i in range(pts.shape[0]):
    print(f"{i=} pt = {pts[i,:]}")

pts_ref = pts.copy()
#

Let's demo python clustering internals (debug version)
idx=0 pt0=array([0., 0.], dtype=float32) --> pts[idx,:]=array([0., 0.], dtype=float32)

push pt 1 towards single cluster center 0
cl_avg=array([2., 1.], dtype=float32)
pt 1 [1. 1.] --> newpos [1.64 1.  ]
cluster 0 : cl_avg_new=array([2.2133334, 1.       ], dtype=float32)
idx=1 pt0=array([1., 1.], dtype=float32) --> pts[idx,:]=array([1.64, 1.  ], dtype=float32)

push pt 2 towards cluster centers [0 1]
eqm_pos=array([4.2038097, 1.       ], dtype=float32)
pt 2 [2. 2.] --> newpos [4.2038097 1.       ]
final pts[idx,:]=array([4.2038097, 1.       ], dtype=float32)
idx=2 pt0=array([2., 2.], dtype=float32) --> pts[idx,:]=array([4.2038097, 1.       ], dtype=float32)

push pt 3 towards single cluster center 0
cl_avg=array([2.9479363, 0.6666667], dtype=float32)
pt 3 [3. 0.] --> newpos [2.9708064  0.37381905]
cluster 0 : cl_avg_new=array([2.938205  , 0.79127306], dtype=float32)
idx=3 pt0=array([3., 0.], dtype=float32) --> pts[idx,:]=array([2.9

In [8]:
print("Let's demo python clustering internals (no-debug version)")
#
# --------------------- Inputs -----------------
#
n_samples = 8
cluster_lists = [[1,2,3], [2,7,6]]
#           avg    2.0      5.0
# each of the 2 clusters has a spring constant
springs = np.array([0.8, 2.0])
lr = 1.0   # learning rate (gradient multiplier) for spring forces

pts = np.ndarray((n_samples,2), dtype=np.float32)
for i in range(n_samples):
    pts[i,:] = (float(i),float(i%3))

n_samples = 8
cluster_lists = [[1,2,3], [2,7,6]]
#           avg    2.0      5.0
# each of the 2 clusters has a spring constant
springs = np.array([0.8, 2.0])
lr = 1.0   # learning rate (gradient multiplier) for spring forces

pts = np.ndarray((n_samples,2), dtype=np.float32)
for i in range(n_samples):
    pts[i,:] = (float(i),float(i%3))




#
# ---------------------- Functions --------------
#

# begin with 2 very simple helpers
@numba.njit()
def np_mean_axis_0(pts):
    cl_avg = np.zeros((pts.shape[1]), dtype=np.float32)
    for pt in range(pts.shape[0]):
        cl_avg += pts[pt,:]
    assert pts.shape[0] > 0
    cl_avg /= pts.shape[0]
    return cl_avg

@numba.njit() # eventually should just inline this wherever
def xnp_cluster_list(c, clusters):
    return np.argwhere(clusters[c,:]).flatten()



#@numba.njit()
def do_clustering_ipts(idx, pts, cluster_lists, springs, *, lr=1.0, mindist=0.01):
    assert len(pts.shape) == 2
    n_samples = pts.shape[0]
    n_clust = len(cluster_lists)
    assert n_clust > 0
    assert len(springs) == len(cluster_lists)
    np_cluster_lists = [np.array(clist,np.int32) for clist in cluster_lists]
    np_springs = np.array(springs)
    #
    # cluster membership as bool array[cluster index][point index]
    #
    clusters = np.full((n_clust, n_samples), False, dtype=bool)
    for (c,members) in enumerate(cluster_lists):
        for m in members:
            clusters[c][m] = True
    
    idx_in = clusters[:,idx]
    idx_cl = np.argwhere(idx_in).flatten()
    #
    if len(idx_cl) < 1:
        return

    elif len(idx_cl) == 1:
        # Separate out an easy case (idx in single cluster)
        c = idx_cl[0]
        cluster_pts  = np_cluster_lists[c] # perhaps faster/vectorizable
        if len(cluster_pts) <= 1:
            print("noop: cluster empty or of size 1")
            return

        # cl_avg ~ cluster center
        cl_avg = np.mean(pts[cluster_pts,:], axis=0)
        cl_avg2 = np_mean_axis_0( pts[cluster_pts,:] )  # cluster centroid
        #print(f"{idx=} {cl_avg=}\n{cl_avg2=}")

        delta = spring_mv_v0(pts[idx], cl_avg, springs[c], lr, mindist)
        pts[idx,:] += delta

        # A quick approx update of cl_avg feasible, but don't use it for now
        #cl_avg_new = cl_avg + delta/len(cluster_pts)
        # It's trickier for len(idx_cl) > 1 (more updates needed!)

        return

    assert( len(idx_cl) > 1 )
    # Easy: if point is in two clusters, pt moves first toward one,
    #       then toward next, in SAME pattern.
    #
    # Actual: 1st determine NET gradient direction from
    #         weighted sum of individual gradient forces,
    #         with movement not to overshoot weighted "equilibrium"
    #         position derived from weighted avg of cluster centers
    #
    tot_grad = np.zeros(pts.shape[1],dtype=np.float32)
    eqm_pos  = np.zeros(pts.shape[1],dtype=np.float32)
    sum_springs = 0.0 #np.sum(np_springs)
    n_springs = 0
    if False: # first way, 2 loops
        cl_avgs = []
        for c in idx_cl:
            cluster_pts  = np_cluster_lists[c] # perhaps faster/vectorizable
            if len(cluster_pts) < 1:
                cl_avgs.append(None)
                print("noop: cluster empty or of size 1")
                continue
            # cl_avg ~ cluster center
            cl_avg = np.mean(pts[cluster_pts,:], axis=0)
            cl_avgs.append(cl_avg)
            n_springs += 1

        tot_grad = np.zeros(pts.shape[1],dtype=np.float32)
        eqm_pos  = np.zeros(pts.shape[1],dtype=np.float32)
        sum_springs = np.sum(np_springs) # correct if all clusters non-empty

        # calculate equilibrium spring-weight target position
        #           and total spring-weighted gradient
        # every cluster exerts a force ind't of cluster size
        for (i,c) in enumerate(idx_cl):
            if cl_avgs[i] is not None: # i.e. the cluster is non-empty, has a centroid
                # springs -> force gradient -> non-overshooting 'delta' movement
                eqm_pos += np_springs[c] * cl_avgs[i]
                vec = cl_avgs[i] - pts[idx,:]      # vector toward cluster center
                tot_grad += np_springs[c] * vec

        eqm_pos /= sum_springs
    
    else: # shorter way: combine loops

        # calculate equilibrium spring-weight target position
        #           and total spring-weighted gradient
        # every cluster exerts a force ind't of cluster size
        for c in idx_cl:
            cluster_pts  = xnp_cluster_list(c, clusters) # perhaps faster/vectorizable
            # Note: 1. Each cluster centroid gets its spring regardless
            #          of how populated the cluster is.
            #       2. Size 1 cluster get included
            #       3. Spring constant doubles as weighting factor -- this
            #          might not hold for other spring force models!
            if len(cluster_pts) > 0:
                cl_avg = np_mean_axis_0( pts[cluster_pts,:] )
                #print(f"{c=} {cl_avg=}")
                # Now update sums for equilibrium posn and total gradient
                # springs -> summed force gradient vectors
                spring = np_springs[c]
                sum_springs += spring
                n_springs += 1
                eqm_pos += spring * cl_avg
                vec = cl_avg - pts[idx,:]      # vector toward cluster center
                tot_grad += spring * vec       # only for kr^2 spring physics

        eqm_pos /= sum_springs

    print(f"{eqm_pos=}")

    if True: # assert
        # tot_grad should automatically be in direction of eqm_pos
        # if math is correct:
        grad_dirn = tot_grad / np.linalg.norm(tot_grad)
        eqm_dirn  = eqm_pos - pts[idx,:]
        eqm_dirn /= np.linalg.norm(eqm_dirn)
        assert( np.allclose(grad_dirn, eqm_dirn) )

    # I don't think mindist should apply to "virtual" avg-of-cluster-centroids
    use_mindist_for_eqm = False
    
    if False and not use_mindist_for_eqm: # orig (no mindist)
        #vec = eqm_pos - pts[idx]      # vector toward spring-weighted equilibrium
        #vecnorm = np.linalg.norm(vec)
        eqm_spring = np.linalg.norm(tot_grad) / np.linalg.norm(eqm_pos-pts[idx])
        vec = eqm_pos - pts[idx]      # vector toward spring-weighted equilibrium
        vecnorm = np.linalg.norm(vec)
        print(f"{eqm_spring=} {eqm_spring*vecnorm=}")

        # at this point we have net gradient and terminal eqm_pos
        # we proceed as before, so that the force is toward eqm_pos
        # but the point never overshoots.
        #
        # Can we use an "effective spring constant" for eqm_pos?
        # or do we need a 'gradnorm'-based updater?
        #
        vec = eqm_pos - pts[idx]      # vector toward spring-weighted equilibrium
        vecnorm = np.linalg.norm(vec)
        gradnorm = np.linalg.norm(tot_grad)
        # spring-physics model:
        #    gradnorm = spring_effx * (displacement=vecnorm)
        spring_effx = gradnorm / vecnorm
        print(f"{vecnorm=} {gradnorm=} {spring_effx=}")
        #   v.0: by simply capping movement to "100% of the way to cluster avg"
        #        for movement to eqm posn, can move all the way?
        pct_move = min(1.0, (lr * gradnorm)/vecnorm)
        #   v.1: even for a move to eqm_pos, use a mindist2 guaranteed "small"
        #mindist2 = min( mindist, 0.1 )
        #pct_move = smoothstep( lr * gradnorm, vecnorm - 0.5*mindist2, vecnorm - mindist2 )
        # above is WRONG
        delta = vec * pct_move
        
        #def spring_mv_v0_plain(pt, target, spring, lr=1.0, mindist=0.0):
        #   vec = target - pt      # vector toward cluster center
        #    if mindist > 0.0:
        #        vecnorm = np.linalg.norm(vec)
        #        bar = max(vecnorm - mindist, 0.0) # max move dist (along vec/vecnorm)
        #        delta = (min(lr * spring, 1.0) * bar / vecnorm) * vec
        #    else: # mindist == 0.0
        #        delta = min(1.0, lr*spring) * vec
        #    return delta
        # with above cancellation, for mindist==0.0, we have just:
        
    elif True and not use_mindist_for_eqm: # orig (no mindist) -- extremely simple!
        # almost too simple to be true
        vec = eqm_pos - pts[idx]      # vector toward spring-weighted equilibrium
        gradnorm = np.linalg.norm(tot_grad)
        delta = min(1.0, lr * gradnorm) * vec
        #                     ^^^^^^^^ replaces 'spring'
        
    else: # with min_dist : now the "target" is not eqm_pos, but a point mindist away from it
        # In principle:
        #eqm_spring = np.linalg.norm(tot_grad) / max(1.e-6, (np.linalg.norm(eqm_pos-pts[idx]) - mindist))
        # But for targets that are "midway" between several clusters, maybe it is best to NOT
        # use mindist
        #delta = spring_mv_v0(pts[idx], eqm_pos, eqm_spring, lr, mindist)
        #
        # or long-hand:
        vec = eqm_pos - pts[idx]            # vector toward spring-weighted equilibrium
        vecnorm = np.linalg.norm(vec)
        bar = max(vecnorm - mindist, 0.0)   # max move dist (along vec/vecnorm)
        gradnorm = np.linalg.norm(tot_grad) # this replaces 'spring' in spring_mv_v0_plain
        delta = (min(lr * gradnorm, 1.0) * bar / vecnorm) * vec


    if True:
        pts[idx,:] = pts[idx,:] + delta
    else: # verbose
        newpos = pts[idx,:] + delta
        print(f"pt {idx} {int(pct_move*100.)}% {pts[idx,:]} --> newpos {newpos}")
        pts[idx,:] = newpos
        print(f"final {pts[idx,:]=}")


    return

#@numba.njit()
def do_clustering_pts(pts, cluster_lists, springs, *, lr=1.0, mindist=0.01):
    assert len(pts.shape) == 2
    n_samples = pts.shape[0]
    # one round of clustering every point once
    for idx in range(n_samples):
        #pt0 = pts[idx,:].copy()
        do_clustering_ipts(idx, pts, cluster_lists, springs, lr=lr, mindist=mindist)
        #print(f"{idx=} {pt0=} --> {pts[idx,:]=}")
    return

print(f"{type(pts)=} {pts.shape=} {pts.dtype=}")
do_clustering_pts(pts, cluster_lists, springs, lr=lr, mindist=0.2)
print("final positions (one epoch)")
for i in range(pts.shape[0]):
    print(f"{i=} pt = {pts[i,:]}")

if pts_ref is not None:
    assert( np.allclose(pts, pts_ref) )
    print("Good: matched pts_ref")
else:
    pts_ref = pts.copy()
print("Goodbye")

Let's demo python clustering internals (no-debug version)
type(pts)=<class 'numpy.ndarray'> pts.shape=(8, 2) pts.dtype=dtype('float32')
eqm_pos=array([4.2038097, 1.       ], dtype=float32)
final positions (one epoch)
i=0 pt = [0. 0.]
i=1 pt = [1.64 1.  ]
i=2 pt = [4.2038097 1.       ]
i=3 pt = [2.9708064  0.37381905]
i=4 pt = [4. 1.]
i=5 pt = [5. 2.]
i=6 pt = [5.808576  0.4808495]
i=7 pt = [5.8691216 0.8527701]
Good: matched pts_ref
Goodbye


In [9]:
import numpy as np
import numba
print("Let's demo python clustering internals (jit version)")
print("This cell is the precursor a new file umap/constrain_clust.py\n")

#
# ---------------------- Functions --------------
# ouch.  these must be jittable now
# numba cannot handle a python list of numpy arrays.
# Let's break things apart to find numba-ready code blocks
#
@numba.njit()
def np_mean_axis_0(pts):
    cl_avg = np.zeros((pts.shape[1]), dtype=np.float32)
    for pt in range(pts.shape[0]):
        cl_avg += pts[pt,:]
    assert pts.shape[0] > 0
    cl_avg /= pts.shape[0]
    return cl_avg

@numba.njit() # approx. void(i8, f4[:,:], i8[:], f8, f8)
#def xdo_clustering_single(idx, pts, cluster_pts, lr, spring):
def xdo_clustering_ipts_toward0(idx, pts, target, lr, spring, maxfrac=0.9):
    """ Non-overshooting move of pts[idx,:] towards cofm(pts[cluster_pts]).
    
        In this version, maxfrac=0.9 is an under-relaxation "don't go all the way".
    """
    if lr*spring > 1e-5:
        vec = target - pts[idx,:]
        vecsz = np.linalg.norm(vec)         # stepsz is distance ~ spring force * lr
        stepsz = lr * (spring * vecsz)
        fracsz = min(maxfrac, stepsz/vecsz) # maxfrac<1 => undershoot
        pts[idx,:] += fracsz * vec
    # v.1  attempt to smooth the transition
    #vec = cl_avg - pts[idx]      # vector toward cluster center
    #vecnorm = np.linalg.norm(vec)
    #if vecnorm <= mindist: # no-op - pt is already within mindist of cl_avg
    #    delta = vec * 0.0
    #else:
    #    grad = np_springs[c] * vecnorm
    #    gradnorm = lr * grad
    #    edge0 = max(vecnorm - 2*mindist, 0.0)
    #    edge1 = vecnorm - mindist
    #    if gradnorm <= edge0:
    #        delta = vec
    #    else:
    #        mvlen = edge0 + (edge1-edge0) * smoothstep( gradnorm, edge0, edge1 )
    #        delta = vec * (mvlen/vecnorm)
    #delta = vec * pct_move
    return
    #
    # compare with multi-cluster case, where we calc tot_grad up front
    # as a spring-weighted sum of force vectors
    #
    #if True:
    #    # at this point we have net gradient and terminal eqm_pos
    #    # we proceed as before, so that the force is toward eqm_pos
    #    # but the point never overshoots.
    #    #   v.0: by simply capping movement to "100% of the way to cluster avg"
    #    #        for movement to eqm posn, can move all the way?
    #    eqm_pos = target
    #    vec = eqm_pos - pts[idx]
    #    vecnorm = np.linalg.norm(vec)
    #    gradnorm = np.linalg.norm(tot_grad) # gradnorm is NOT spring * vecnorm
    #    pct_move = min(1.0, (lr * gradnorm) / vecnorm)
    #    delta = vec * pct_move
    #
    # this can be reproduced by an "effective spring const"
    #   gradnorm = np.linalg.norm(tot_grad)
    # == ?
    #   spring_eff * np.linalg.norm(target-pts[idx])
    # IFF
    #   spring_eff = np.linalg.norm(target-pts[idx]) / np.linalg.norm(tot_grad)
    #

#if False: # for reference
#    def spring_mv_v0_plain(pt, target, spring, lr=1.0, mindist=0.0):
#        vec = target - pt      # vector toward cluster center
#        vecnorm = np.linalg.norm(vec)
#        if mindist > 0.0:
#            bar = max(vecnorm - mindist, 0.0) # max move dist (along vec/vecnorm)
#            delta = (min(lr * spring, 1.0) * bar / vecnorm) * vec
#        else: # mindist == 0.0
#            delta = min(1.0, lr*spring) * vec
#        return delta

@numba.njit() # approx. void(i8, f4[:,:], i8[:], f8, f8)
def xdo_clustering_ipts_toward(idx, pts, target, lr, spring, mindist=0.0):
    """ Non-overshooting move of pts[idx,:] towards cofm(pts[cluster_pts]).
    
        lr: time step (1.0 will move exactly to equilibrium posn if spring==1)
        
        spring : spring constant "force ~ spring * displacement" (parabolic potential)
        
        mindist: to "not go all the way" towards target, but stop mindist away.
    """
    # Using maxfrac:
    #if lr*spring > 1e-5:
    #    vec = target - pts[idx,:]
    #    vecsz = np.linalg.norm(vec)         # stepsz is distance ~ spring force * lr
    #    stepsz = lr * (spring * vecsz)
    #    fracsz = min(maxfrac, stepsz/vecsz) # maxfrac<1 => undershoot
    #    pts[idx,:] += fracsz * vec
    #return
    
    # mindist python move fn:
    #def spring_mv_v0_plain(pt, target, spring, lr=1.0, mindist=0.0):
    #   vec = target - pt      # vector toward cluster center
    #    if mindist > 0.0:
    #        vecnorm = np.linalg.norm(vec)
    #        bar = max(vecnorm - mindist, 0.0) # max move dist (along vec/vecnorm)
    #        delta = (min(lr * spring, 1.0) * bar / vecnorm) * vec
    #  or just delta = (min(bar, bar * lr * spring) / vecnorm) * vec
    #    else: # mindist == 0.0
    #        delta = min(1.0, lr*spring) * vec
    #    return delta
    
    # movement with mindist (patterned after python helper "spring_mv_v0")
    vec = target - pts[idx,:]      # vector toward cluster center
    vecnorm = np.linalg.norm(vec)
    if vecnorm > 1e-5:
        bar = max(vecnorm - mindist, 0.0) # max move dist (along vec/vecnorm)
        #delta = (min(lr * spring, 1.0) * bar / vecnorm) * vec
        mv = min(bar, bar * lr * spring)
        pts[idx,:] += (mv/vecnorm) * vec
    return



@numba.njit() # eventually should just inline this wherever
def xnp_cluster_list(c, clusters):
    return np.argwhere(clusters[c,:]).flatten()

@numba.njit(
    locals={'tot_grad': numba.float32[:],
            'eqm_pos' : numba.float32[:],
            'sum_springs' : numba.float32,
            'n_springs'   : numba.int64,
           }
)
def xdo_clustering_mult(idx, idx_cl, clusters, springs, pts):
    """ return target+grad info for idx w/ springs to  multiple clusters. """
    tot_grad = np.zeros(pts.shape[1],dtype=np.float32)
    eqm_pos  = np.zeros(pts.shape[1],dtype=np.float32)
    sum_springs = 0.0 #np.sum(np_springs)
    n_springs = 0

    # calculate equilibrium spring-weight target position
    #           and total spring-weighted gradient
    # every cluster exerts a force ind't of cluster size
    for cc in idx_cl:
        cluster_pts  = xnp_cluster_list(cc, clusters) # perhaps faster/vectorizable
        # Note: 1. Each cluster centroid gets its spring regardless
        #          of how populated the cluster is.
        #       2. Size 1 cluster get included
        #       3. Spring constant doubles as weighting factor -- this
        #          might not hold for other spring force models!
        if len(cluster_pts) > 0:
            cl_avg = np_mean_axis_0( pts[cluster_pts,:] )
            #print(f"{cc=} {cluster_pts=}{cl_avg=}")
            # Now update sums for equilibrium posn and total gradient
            # springs -> force gradient -> non-overshooting 'delta' movement
            spring = springs[cc]
            sum_springs += spring
            n_springs += 1
            eqm_pos += spring * cl_avg
            vec = cl_avg - pts[idx,:]      # vector toward cluster center
            tot_grad += spring * vec       # only for kr^2 spring physics
    #
    eqm_pos /= sum_springs

    return (n_springs, eqm_pos, tot_grad)

#@numba.njit("void(i8, f4[:,:], boolean[:,:], f4[:], f8, f8)")
def xdo_clustering_ipts(idx, pts, clusters, springs, lr=1.0, mindist=0.01):
    #assert len(pts.shape) == 2
    n_samples = pts.shape[0]
    n_clust = springs.shape[0] # it is a vector, one per cluster
    #  otherwise n_clust = np.max(clusters) + 1
    #n_clust = len(np_cluster_lists)
    #assert n_clust > 0
    #assert len(springs) == len(np_cluster_lists)
    #np_cluster_lists = [np.array(clist,np.int32) for clist in cluster_lists]
    #np_springs = np.array(springs)
    idx_in = clusters[:,idx]
    idx_cl = np.argwhere(idx_in).flatten()

    if len(idx_cl) < 1:
        return

    elif len(idx_cl) == 1:
        # Separate out an easy case (idx in single cluster)
        c = idx_cl[0]
        cluster_pts  = xnp_cluster_list(c, clusters)
        # pts[idx] in a cluster of size 1 is already at the cluster centroid
        if len(cluster_pts) > 1:
            centroid = np_mean_axis_0( pts[cluster_pts,:] )  # cluster centroid
            pt0 = pts[idx,:].copy()
            xdo_clustering_ipts_toward(idx, pts, centroid,
                                       lr, springs[c], mindist=mindist)
            print(f"{idx=} {c=} {centroid=}\n{pt0=} --> {pts[idx]}")
        return

    #else: idx attracted to multiple clusters... len(idx_cl) > 1 ... rare?
    # jit test:
    (n_springs, eqm_pos, tot_grad) = xdo_clustering_mult(idx, idx_cl, clusters, springs, pts)
    
    if n_springs==0:  # all clusters empty? No-op
        return

    if True:
        # multi-cluster effective spring constant, to re-use "_toward" code...
        #  This way re-uses existing code but redoes some vector calcs
        eqm_spring = np.linalg.norm(tot_grad) / np.linalg.norm(eqm_pos-pts[idx])
        #xdo_clustering_ipts_toward0(idx, pts, eqm_pos, lr, eqm_spring, maxfrac=1.0)
        xdo_clustering_ipts_toward(idx, pts, eqm_pos, lr, eqm_spring, mindist=0.0)
    elif False:
        # Actually, with mindist 0 the calc is very simple, so longhand:
        
        # The calculation with mindist:
        #vec = target - pt      # vector toward cluster center
        #vecnorm = np.linalg.norm(vec)
        #if vecnorm > 1e-5:
        #    bar = max(vecnorm - mindist, 0.0) # max move dist (along vec/vecnorm)
        #    mv = min(bar, lr * bar * spring)
        #    pts[idx,:] += (mv/vecnorm) * vec
        #return
        
        # without minidst (i.e. bar = vecnorm)
        vec = eqm_pos - pts[idx]      # vector toward spring-weighted equilibrium
        vecnorm = np.linalg.norm(vec)
        if vecnorm > 1e-5:
            mv = lr * np.linalg.norm(tot_grad)  # lr : force --> distance
            mv = min(mv, vecnorm)
            pts[idx,:] += (mv/vecnorm) * vec
            #
            #pts[idx,:] += ((min(vecnorm, lr*np.linalg.norm(tot_grad)) / vecnorm) * vec
            #
            #pts[idx,:] += (min(1.0, lr*np.linalg.norm(tot_grad))) * vec
    else: # "longhand" (so simple for mindist=0.0)
        vec = eqm_pos - pts[idx]      # vector toward spring-weighted equilibrium
        mv = lr * np.linalg.norm(tot_grad)
        #pts[idx,:] += min(1.0, mv) * vec
        if mv < 1.0:
            pts[idx,:] += mv * vec
        else:
            pts[idx,:] = eqm_pos
    
    return


#@numba.njit()
def xdo_clustering_pts(pts, cluster2d, springs, lr=1.0, mindist=0.01):
    assert len(pts.shape) == 2
    n_samples = pts.shape[0]
    # one round of clustering every point once
    for idx in range(n_samples):
        xdo_clustering_ipts(idx, pts, cluster2d, springs, lr, mindist)
    return

#
# -------------- python (non-numba) --------------
# These show how to convert python args to expected numba-compliant types
#

# for completeness, since this might also become a 'mk_FOO' jit-fn-generator
def do_clustering_ipts_py(idx, pts, cluster_lists, springs, lr, mindist):
    """ python-ish front-end, with appropriate setup for just internals.

        pts: array[n_samples,dim]  sample ~ "idx"
        clusters: list-of-lists ~ (cluster, idx)
        springs: python list of numbers (np array[:]
        lr, mindist: python numbers
    """
    idx = int(idx)
    lr = float(lr)
    mindist = float(mindist)
    assert len(pts.shape) == 2
    assert len(cluster_lists) == len(springs)
    n_samples = pts.shape[0]
    assert idx < pts.shape[0]
    
    #np_cluster_lists = [np.array(clist,np.int32) for clist in cluster_lists]
    n_clust = len(cluster_lists)

    if n_samples==0 or n_clust==0:
        return
    
    #
    # generalization:  clusters[c,idx] is True IFF idx is in cluster c
    clusters = np.full((n_clust, n_samples), False, dtype=bool)
    for (c,members) in enumerate(cluster_lists):
        for m in members:
            clusters[c][m] = True
            
    springs = np.array(springs, dtype=np.float32)  # don't need python float64 default
    # no negative or inf or nan springs
    assert np.all(springs >= 0.0)  # actually nan is also NOT >= 0 so elision...
    assert np.count_nonzero((springs == np.inf) | (springs == np.nan)) == 0
    
    # pts is to be modified -- do not create a copy!
    
    # invoke jit fn (or create and return it)
    xdo_clustering_pts( pts, clusters, springs, lr=lr, mindist=mindist )

    return

# python "frontend-to-jit" demo for cell output
def do_clustering_pts_py(pts, cluster_lists, springs, lr, mindist):
    """ python-ish front-end, with appropriate setup for just internals.

        pts: array[n_samples,dim]  sample ~ "idx"
        clusters: list-of-lists ~ (cluster, idx)
        springs: python list of numbers
        lr, mindist: python numbers
    """
    lr = float(lr)
    mindist = float(mindist)
    assert len(pts.shape) == 2
    assert len(cluster_lists) == len(springs)
    n_samples = pts.shape[0]
    
    #np_cluster_lists = [np.array(clist,np.int32) for clist in cluster_lists]
    n_clust = len(cluster_lists)
    
    if n_samples==0 or n_clust==0:
        return
    
    clusters = np.full((n_clust, n_samples), False, dtype=bool)
    for (c,members) in enumerate(cluster_lists):
        for m in members:
            clusters[c][m] = True
            
    springs = np.array(springs, dtype=np.float32)  # don't need python float64 default
    # no negative or inf or nan springs
    assert np.all(springs >= 0.0)  # actually nan is also NOT >= 0 so elision...
    assert np.count_nonzero((springs == np.inf) | (springs == np.nan)) == 0
    
    # pts is to be modified -- do not create a copy!
    
    if False:
        print(f"{numba.typeof(pts)=} {pts.shape=}")
        print(f"{numba.typeof(clusters)=}")
        print(f"{numba.typeof(springs)=}")
        print(f"{numba.typeof(lr)=}")
        print(f"{numba.typeof(mindist)=}")
    
    # invoke jit fn (or create and return it)
    xdo_clustering_pts( pts, clusters, springs, lr=lr, mindist=mindist )
    
    return

#
# --------------------- Inputs -----------------
#

n_samples = 8
cluster_lists = [[1,2,3], [2,7,6]]
#           avg    2.0      5.0
# each of the 2 clusters has a spring constant
springs = np.array([0.8, 2.0], dtype=np.float32)
lr = 1.0   # learning rate (gradient multiplier) for spring forces

pts = np.ndarray((n_samples,2), dtype=np.float32)
for i in range(n_samples):
    pts[i,:] = (float(i),float(i%3))


#np_cluster_lists = [np.array(clist,np.int32) for clist in cluster_lists]
#n_clust = len(np_cluster_lists)
## numba-friendly: single full-sized array
##  TODO: compressed data,indptr,indices versions cluster and its transpose
#
## clusters[ c, idx ], for c in [0,n_clust] and idx in [0,n_samples),
##         is True IFF idx is in cluster c
#clusters = np.full((n_clust, n_samples), False, dtype=bool)
#for (c,members) in enumerate(np_cluster_lists):
#    for m in members:
#        clusters[c][m] = True
#print(f"{numba.typeof(clusters)=}")

#
# ------------------------ python/jit test -------------
#
mindist=0.2 # support TBD (need to rething equations)

#xdo_clustering_pts(pts, clusters, springs, lr=lr, mindist=0.2)
do_clustering_pts_py(pts, cluster_lists, springs, lr=lr, mindist=0.2)

print("final positions (one epoch)")
for i in range(pts.shape[0]):
    print(f"{i=} pt = {pts[i,:]}")

try:
    pts_ref
    assert np.allclose(pts, pts_ref)
    print("Good: matched pts_ref")
except NameError:
    pass

print("Goodbye!")


Let's demo python clustering internals (jit version)
This cell is the precursor a new file umap/constrain_clust.py

idx=1 c=0 centroid=array([2., 1.], dtype=float32)
pt0=array([1., 1.], dtype=float32) --> [1.64 1.  ]
idx=3 c=0 centroid=array([2.9479363, 0.6666667], dtype=float32)
pt0=array([3., 0.], dtype=float32) --> [2.9708064  0.37381905]
idx=6 c=1 centroid=array([5.7346034, 0.6666667], dtype=float32)
pt0=array([6., 0.], dtype=float32) --> [5.808576   0.48084953]
idx=7 c=1 centroid=array([5.6707954 , 0.82694983], dtype=float32)
pt0=array([7., 1.], dtype=float32) --> [5.8691216 0.8527701]
final positions (one epoch)
i=0 pt = [0. 0.]
i=1 pt = [1.64 1.  ]
i=2 pt = [4.2038097 1.       ]
i=3 pt = [2.9708064  0.37381905]
i=4 pt = [4. 1.]
i=5 pt = [5. 2.]
i=6 pt = [5.808576   0.48084953]
i=7 pt = [5.8691216 0.8527701]
Good: matched pts_ref
Goodbye!


## FINALLY have some *mk_FOO*
### jitted cluster-funcs

#### What's the point?
*umap-constraints, short-n-sweet* jitted "lambda functions" can be used
as `data_constrain=` or `output_constrain=` args to umap euclidean
embedding calls (`fit`, `fit_transform`)

The umap embedding should *pull together* our **user-specified clusters**,
with whatever effects these have on next-nearest members, etc. .

In [10]:
# Now try above cell as part of umap-constraints...
# This call can run standalone:
#   It invokes a jit-function mk_FOO,
#   and then invokes it
#   (umap.UMAP not involved)
#
#   I found the mk_FOO marks some local array refs 'readonly array'.
#   It was easiest to simply remove the overly constraining specs
#   and let @numba.njit() autogenerate the required signatures
#   in umap/constrain_clust.py
#
import numpy as np
import numba

from umap.constrain_clust import mk_clustering_pts, mk_clustering_ipts

#
# ------- Inputs (python) -----------------------------
n_samples = 8
cluster_lists = [[1,2,3], [2,7,6]]
#           avg    2.0      5.0
# each of the 2 clusters has a spring constant
springs = np.array([0.8, 2.0], dtype=np.float32)
lr = 1.0   # learning rate (gradient multiplier) for spring forces

pts = np.ndarray((n_samples,2), dtype=np.float32)
for i in range(n_samples):
    pts[i,:] = (float(i),float(i%3))

mindist = 0.2

#
# ------------ create & call jit constraint -----------
# Now that we have the python-ish inputs set up,
# create a jit constraint function of simplified signature
#
mkdo_pts = mk_clustering_pts(pts, cluster_lists, springs, lr=lr, mindist=mindist)

# and just invoke it as
mkdo_pts(pts)  # simplified call signature, other args are now mk_FOO locals
#    without re-supplying all the args like
#        do_clustering_pts_py(pts, cluster_lists, springs, lr=lr, mindist=0.2)

#
# ------------- output --------------------------------
# check we got the same "output" (in-place modification of our pts array)
#
print("final positions (one epoch)")
for i in range(pts.shape[0]):
    print(f"{i=} pt = {pts[i,:]}")

try:
    pts_ref
    assert np.allclose(pts, pts_ref)
    print("Good: matched pts_ref")
except NameError:
    pass

print("Goodbye!")

final positions (one epoch)
i=0 pt = [0. 0.]
i=1 pt = [1.64 1.  ]
i=2 pt = [4.2038097 1.       ]
i=3 pt = [2.9708064  0.37381905]
i=4 pt = [4. 1.]
i=5 pt = [5. 2.]
i=6 pt = [5.808576   0.48084953]
i=7 pt = [5.8691216 0.8527701]
Good: matched pts_ref
Goodbye!
