### constraints.py demo
This runs through several short demos of how to apply
constraints to the low-dimensional embedding of UMAP.

Dataset-independent constraints can be supplied to the constructor;
dataset-dependent ones, to the 'fit' or 'fit_transform' function.
The latter have a first argument that is always the index of the point.

In [None]:
from sklearn.datasets import load_iris

import numpy as np
import umap

# Data set shared by every demo cell below.
iris = load_iris()

# Settings for the baseline embedder: random initialisation, a single epoch.
umapper0_params = dict(
    n_neighbors=50,
    learning_rate=0.5,
    random_state=12345,
    min_dist=0.001,
    init="random",
    n_epochs=1,
)
umapper0 = umap.UMAP(**umapper0_params)
print("Generating an initial embedding...")
# emb0 is reused (and edited in place) as the 'init' embedding of later demos.
emb0 = umapper0.fit_transform(iris.data)

### demo 1- and 2-d pin mask, data_constrain=array

print("Pinning embeddings of pts 13 and 14 to [-5,0] and [5,0]")
# Pin embeddings of two data points (13 and 14) to left and right of origin.
# The 2-D mask has the same shape as the embedding; rows 13 and 14 are zeroed.
# NOTE(review): presumably a zero mask entry freezes that coordinate and a one
# leaves it free -- confirm against the data_constrain handling in umap.
pin_mask = np.ones_like(emb0)
pin_mask[13] = 0.0
pin_mask[14] = 0.0
emb0[13] = [-5.0, 0]
emb0[14] = [+5.0, 0]
print("Specify 'init' embedding for umapper1")
umapper1 = umap.UMAP(
    n_neighbors=50, learning_rate=0.5, random_state=12346, min_dist=0.001,
    init=emb0, n_epochs=2,
)
print("Embed with pin_mask[13] and pin_mask[14] zero-vectors")
emb1 = umapper1.fit_transform(iris.data, data_constrain=pin_mask)
print("emb0[11:15]\n",emb0[11:15])
print("emb1[11:15]\n",emb1[11:15])

# Now demo pinning points 13 and 14 with a one-dimensional pin-mask
# (one entry per point instead of one entry per coordinate).
pin_mask1d = np.ones(emb0.shape[0], dtype=np.int32)
pin_mask1d[13] = 0
pin_mask1d[14] = 0   # was a second write to index 13, leaving 14 unpinned
# emb0 is edited in place; umapper1 was constructed with init=emb0.
# NOTE(review): this relies on the estimator keeping a reference to emb0
# rather than a copy -- confirm.
emb0[13] = [-4.0, 0]
emb0[14] = [+4.0, 0]
print("Embed with pin_mask1d[13] and pin_mask1d[14] zero-values")
# Pass the 1-D mask (the original passed the 2-D pin_mask again by mistake).
emb1 = umapper1.fit_transform(iris.data, data_constrain=pin_mask1d)
print("emb0[11:15]\n",emb0[11:15])
print("emb1[11:15]\n",emb1[11:15])
print("\nGoodbye")

### demo output_constrain and data_constrain
A UMAP constraint function `y_bounder` restrains *y*-values of any point to -5..+5.
This is independent of data set, so it is an `output_constrain` argument
to the UMAP constructor.

A `data_constrain` function is used to pin points 13 and 14 to specific positions.

#### Tech note:
- numba jit does **not** imbue @jitclass with sufficient support for callable objects.
- So instead define `mk_FOO` functions with state in a local variable, that jit-compile
  and return a dynamically jitted function.

- This is not the nicest, but hey, it works.  See `mk_bound_y_values` for an example.

In [None]:
print("Pinning embeddings of pts 13 and 14 to [-2,0] and [2,0]")
# Pin the embeddings of two data points (13 and 14) to the left and right of
# the origin via a custom constraint: +inf entries mean "no-op", any other
# value is the position the coordinate gets fixed to.
import umap.constraints as con
import numba
infs0 = np.full_like(emb0, np.float32(np.inf), dtype=np.float32)
# Fill in the target positions for the pinned rows.
for pinned_idx, pinned_xy in ((13, [-2.0, 0]), (14, [+2.0, 0])):
    infs0[pinned_idx, :] = pinned_xy
@numba.njit()
def constraint_idx_pt0(idx,pt):
    # Per-point data constraint ('idx_pt' signature: point index, point).
    # Delegates to con.freeinf_pt with the module-level table infs0, whose
    # rows 13 and 14 hold finite targets and whose other entries are +inf.
    # NOTE(review): presumably +inf entries leave a coordinate untouched and
    # finite entries overwrite it in place -- confirm in umap.constraints.
    con.freeinf_pt(idx,pt, infs0)
# this function DOES depend on idx of pt
    
# data_constrain dictionary: key names the call signature, value is the
# jitted constraint function.
constraints = dict(idx_pt=constraint_idx_pt0)
# optional: set up the values to agree
#emb0[13] = [-2.0, 0]
#emb0[14] = [+2.0, 0]
# Here is the "move all points" version of con.freeinf,
# applied to the whole initial embedding at once.
con.freeinf_pts(emb0, infs0)

# Also demo a non-indexed (UMAP constructor) constraint,
# that is independent of the iris.data.
# Here, let's constrain 'y' to be within -5.0, +5.0
# without, we got:
# emb2[11:15]
#  [[ 4.804111   3.430179 ]
#  [ 4.2598176  7.352637 ]     # <-- 'y' is big here
#  [-2.         0.       ]
#  [ 2.         0.       ]]
# With 'output_constrain':
# emb2[11:15]
# [[ 7.0536     2.275853 ]
# [ 4.8699265  1.9713866]      # y constrained
# [-2.         0.       ]
# [ 2.         0.       ]]
# So we pass illegal range for x, and legal range for y
def mk_bound_y_values(lo, hi):
    """Build a jitted single-point constraint that bounds only the y value.

    The per-dimension bounds are captured in local arrays: the x bounds are
    deliberately an illegal range (low=+999 > high=-999), the y bounds are
    [lo, hi].  The returned function takes just `pt` (no point index) and
    forwards to con.dimlohi_pt.
    """
    lower = np.asarray((+999., lo), dtype=np.float32)
    upper = np.asarray((-999., hi), dtype=np.float32)

    @numba.njit()
    def bound_y_values(pt):
        # this function does NOT depend on an 'idx' arg
        return con.dimlohi_pt(pt, lower, upper)

    return bound_y_values
y_bounder = mk_bound_y_values(-5.0,+5.0)

# CHECK: x is unaffected, y range is bounded ...
# y_bounder modifies pt in place; each case is (label, start, expected).
pt = np.array([1.,2.], dtype=np.float32)
for label, start, expected in (
    ("pt0", (1., 2.), (1., 2.)),
    ("pt1", (10., 10.), (10., 5.)),
    ("pt2", (-10., -10.), (-10., -5.)),
):
    pt[0] = start[0]
    pt[1] = start[1]
    print(label, pt)
    y_bounder(pt)
    print(label, pt)
    assert pt[0] == expected[0]
    assert pt[1] == expected[1]

# con.freeinf_pts(emb0, infs0) must already have moved the pinned points.
assert np.all(emb0[13] == [-2.0,0])
assert np.all(emb0[14] == [+2.0,0])
print("Specify 'init' embedding for umapper2")
# output_constrain allowed keys: pt grad epoch_pt final_pt
umapper2 = umap.UMAP(
    n_neighbors=50, learning_rate=0.5, random_state=12346, min_dist=0.001,
    output_constrain = { 'pt': y_bounder }, # any pt, independent of dataset
    init=emb0, n_epochs=4,
)
# No pin mask is involved in this cell, so the progress message now names the
# constraints that are actually in use.
print("Embed with 'idx_pt' data_constrain (pts 13, 14) and 'pt' output_constrain (y bound)")
# ... whereas data_constrain depends on the dataset (point number is important)
# data_constrain allowed keys: idx_pt, idx_ipts, idx_grad
emb2 = umapper2.fit_transform(iris.data, data_constrain=constraints)
assert np.all(emb2[:,1] >= -5.0) and np.all(emb2[:,1] <= 5.0) # output_constrain
print("emb0[11:15]\n",emb0[11:15])
print("emb2[11:15]\n",emb2[11:15])
print("\nGoodbye")

### user-defined constraint
You can invent your own constraints.  Here we initialize and keep points
13 and 14 on the x=y line.  We chose to do it here with a gradient-style constraint.

In [None]:
print("grad constraint 13 and 14 on line y=x")
# umap.constraints offers little help for this one,
# so the numba constraint functions are defined right here.
@numba.njit()
def y_eq_x_pt(idx, pt):
    # Replace every coordinate of pt (in place) by the mean of its
    # coordinates, i.e. move pt onto the all-coordinates-equal line.
    # idx is accepted for the 'idx_pt' signature but not used.
    mean_coord = pt.sum() / pt.shape[0]
    pt[:] = mean_coord
@numba.njit()
def y_eq_x_grad(idx, pt, grad):
    # Gradient-style constraint ('idx_grad' signature): overwrites grad in
    # place, via y_eq_x_pt, so that all its components equal their mean.
    # pt is left untouched.
    # NOTE(review): idx is ignored, so the projection is applied to whatever
    # point this is called for, not only to points 13 and 14 -- confirm that
    # this is the intended demo behaviour.
    # if we cannot assume pt satisfies constraints:
    #y_eq_x_pt(idx, pt)  # put pt onto 45-degree line
    # now tangent plane projection
    y_eq_x_pt(idx, grad) # gradient also lies on the 45-degree line

# Keep embeddings on the all-coords-equal line through a gradient constraint.
constraints = {
    'idx_grad': y_eq_x_grad,
}
# init 13 and 14 to 45-degree line
emb0[13,:] = [-1.0, -1.0]
emb0[14,:] = [+1.0, +1.0]
umapper3 = umap.UMAP(
    n_neighbors=50, learning_rate=0.5, random_state=12346, min_dist=0.001,
    init=emb0, n_epochs=2,
)
# No pin mask is used in this cell; the message now describes the constraint
# that is really passed to fit_transform.
print("Embed with 'idx_grad' data_constrain projecting gradients onto y=x")
emb3 = umapper3.fit_transform(iris.data, data_constrain=constraints)
print("emb0[11:15]\n",emb0[11:15])
print("emb3[11:15]\n",emb3[11:15])
# Points that started on the line must still have x == y.
np.testing.assert_allclose(emb3[13,0], emb3[13,1])
np.testing.assert_allclose(emb3[14,0], emb3[14,1])
print("\nGoodbye")

### output_constrain to box, and data_constrain with some springs and some pins
(Actually, pinning could also be done with a spring constant of $\infty$.)

In [None]:
import umap.constraints2 as con
import numba
print("spring force constraint 13 and 14 pulled towards (0,3), (0,-3)")
# A "soft" constraint: there is no point-projection step, and rather than
# projecting onto a tangent space the gradients are modified by a simple
# user-supplied force.

# Anchored point indices, one spring constant and one anchor position each.
pin_idx = np.array([13,14], dtype=np.int32)
springs = np.array([0.1, 0.01], dtype=np.float32)   # note np==inf **would** have projection constraint
pin_pos = np.array([[0,3], [0,-3]], dtype=np.float32)
for label, arr in (
    ("pin_idx (anchors)         ", pin_idx),
    ("springs (force constants) ", springs),
    ("pin_pos (anchor positions)", pin_pos),
):
    print(label, numba.typeof(arr), "\n", arr)
# pin point 12, but not via an infinite force spring... for show
emb0[12,:] = [0.5, 0.5]
my_pinned = np.array([12], dtype=np.int32)
#@numba.njit
#def pin12_grad(idx,pt):
#    con.pinindexed_grad(idx,pt,grad,  my_pinned)
@numba.njit
def my_springs_and_pins(idx, pt, grad):
    # Gradient constraint ('idx_grad' signature) that chains two helpers from
    # the constraints module: a spring force for the points listed in pin_idx
    # (anchors pin_pos, constants springs), then a pin for the points listed
    # in my_pinned.
    # NOTE(review): both helpers presumably modify grad in place and act only
    # when idx is in their index array -- confirm in umap.constraints2.
    # pt is unconstrained
    con.springindexed_grad(idx,pt, grad, pin_idx, pin_pos, springs)
    # this is equivalent to an infinite force spring, but just for show...
    con.pinindexed_grad(idx,pt, grad, my_pinned)
# Note: we can only supply one function per dictionary key for constraints,
# which is why both helpers are chained inside a single function.
# This is unfortunate.  A list/tuple might be OK in upstream numba versions

# second constraint (every pt inside simple box): per-dimension bounds
my_los = np.array([-5.0, -5.0], dtype=np.float32)  # x and y low bound <- -5.0
my_his = np.array([+5.0, +5.0], dtype=np.float32)  # x and y high bound <- +5.0
#original:
#@numba.njit
#def my_box(idx, pt):  # 'idx_pt' argument list
#    con.dimlohi_pt(pt, my_los, my_his)
# Note: idx is not needed -- this can now be supplied as an
# 'output_constrain' value, to the UMAP constructor.
@numba.njit
def my_box2(pt):
    # Index-free point constraint ('pt' signature, usable as output_constrain):
    # forwards pt with the module-level bound arrays my_los / my_his to
    # con.dimlohi_pt.
    # NOTE(review): presumably dimlohi_pt clamps pt in place to [my_los, my_his]
    # per dimension; its return value is not propagated here -- confirm.
    con.dimlohi_pt(pt, my_los, my_his)
# data_constrain: springs on points 13 and 14 plus a pin on point 12,
# all expressed as one gradient constraint.
constraints = {
    #'idx_pt':   my_box, # better: this does not depend on 'idx',
    #  so it's better to give 'output_constrain=my_box2' in UMAP constructor
    'idx_grad': my_springs_and_pins,
}
# init 13 and 14 "anywhere"
emb0[13,:] = [0, +5]
emb0[14,:] = [0, -5]
print("emb0[11:15] before my_box\n",emb0[11:15])
umapper4 = umap.UMAP(
    n_neighbors=50, learning_rate=0.5, random_state=12347, min_dist=0.001,
    output_constrain = { 'pt': my_box2 },
    init=emb0, n_epochs=4,
)
# No pin mask is used in this cell; the message now describes the constraints
# that are really in effect.
print("Embed with 'idx_grad' springs/pins (data_constrain) inside a box (output_constrain)")
emb4 = umapper4.fit_transform(iris.data, data_constrain=constraints)
print("emb0[11:15]\n",emb0[11:15])
print("emb4[11:15]\n",emb4[11:15])
# Report how far the spring-attached points ended up from their anchors.
print("pin_pos[0]", pin_pos[0], "distance:",np.linalg.norm(emb4[13] - pin_pos[0]))
print("pin_pos[1]", pin_pos[1], "distance:",np.linalg.norm(emb4[14] - pin_pos[1]))
print("\nGoodbye")

### First we'll try out new constraint FUNCTION_ipts(idx, pts, ...)
Ugh: a new **_ipts** suffix is needed to help numba disambiguate this from FUNCTION_pts(idx,pt,...)

(I could not get numba to run-time dispatch this correctly)

Following cell copies an early simple test, but with 'idx_ipts' constraint
when calling umap `fit_transform`(...,data_constrain=dictionary)

In [None]:
print("Pinning embeddings of pts 13 and 14 to [-2.2,0] and [2.2,0]")
# Pin the embeddings of two data points (13 and 14) to the left and right of
# the origin via a custom constraint: +inf entries mean "no-op", any other
# value is the position the coordinate gets fixed to.
import umap.constraints as con
import numba
infs0 = np.full_like(emb0, np.float32(np.inf), dtype=np.float32)
# Fill in the target positions for the pinned rows.
for pinned_idx, pinned_xy in ((13, [-2.2, 0]), (14, [+2.2, 0])):
    infs0[pinned_idx, :] = pinned_xy
@numba.njit("f4[:](i8, f4[:,:])")
def constraint_idx_pt1(idx,pts):
    # 'idx_ipts' data constraint: receives the point index together with the
    # full 2-D point array and returns whatever con.freeinf_ipts returns
    # (declared as a float32 vector in the signature above).
    # Here 'pts' MUST be 2D (full point cloud)
    #   (probably equiv. to constraint_idx_pt0(idx, pts[idx,:])
    return con.freeinf_ipts(idx,pts, infs0)
# this function DOES depend on idx of pt
    
# NEW _ipts suffix ==> alternate call signature.  This is actually the most
# flexible form, since the full "tail_embedding" point cloud is always
# available.
constraints = dict(idx_ipts=constraint_idx_pt1)
# init with far away coords
emb0[[13, 14]] = [[-3.14, 3.14], [+3.14, 3.14]]
# optional: set up the values to agree
#emb0[13] = [-2.0, 0]
#emb0[14] = [+2.0, 0]
#assert np.all(emb0[13] == [-2.0,0])
#assert np.all(emb0[14] == [+2.0,0])
# Here is the "move all points" version of con.freeinf,
# applied to the whole initial embedding at once.
con.freeinf_pts(emb0, infs0)

# Also demo a non-indexed (UMAP constructor) constraint,
# that is independent of the iris.data.
# Here, let's constrain 'y' to be within -5.0, +5.0
# without, we got:
# emb2[11:15]
#  [[ 4.804111   3.430179 ]
#  [ 4.2598176  7.352637 ]     # <-- 'y' is big here
#  [-2.         0.       ]
#  [ 2.         0.       ]]
# With 'output_constrain':
# emb2[11:15]
# [[ 7.0536     2.275853 ]
# [ 4.8699265  1.9713866]      # y constrained
# [-2.         0.       ]
# [ 2.         0.       ]]
# So we pass illegal range for x, and legal range for y
def mk_bound_y_values(lo, hi):
    """Build a jitted single-point constraint that bounds only the y value.

    The per-dimension bounds are captured in local arrays: the x bounds are
    deliberately an illegal range (low=+999 > high=-999), the y bounds are
    [lo, hi].  The returned function takes just `pt` (no point index) and
    forwards to con.dimlohi_pt.
    """
    lower = np.asarray((+999., lo), dtype=np.float32)
    upper = np.asarray((-999., hi), dtype=np.float32)

    @numba.njit()
    def bound_y_values(pt):
        # this function does NOT depend on an 'idx' arg
        return con.dimlohi_pt(pt, lower, upper)

    return bound_y_values
y_bounder = mk_bound_y_values(-5.0,+5.0)

# CHECK: x is unaffected, y range is bounded ...
# y_bounder modifies pt in place; each case is (label, start, expected).
pt = np.array([1.,2.], dtype=np.float32)
for label, start, expected in (
    ("pt0", (1., 2.), (1., 2.)),
    ("pt1", (10., 10.), (10., 5.)),
    ("pt2", (-10., -10.), (-10., -5.)),
):
    pt[0] = start[0]
    pt[1] = start[1]
    print(label, pt)
    y_bounder(pt)
    print(label, pt)
    assert pt[0] == expected[0]
    assert pt[1] == expected[1]

# The estimator created here is umapper5 (the message used to say umapper2).
print("Specify 'init' embedding for umapper5")
umapper5 = umap.UMAP(
    n_neighbors=50, learning_rate=0.5, random_state=12346, min_dist=0.001,
    output_constrain = { 'pt': y_bounder }, # any pt, independent of dataset
    init=emb0, n_epochs=4,
)
# No pin mask is used in this cell; name the constraints actually in effect.
print("Embed with 'idx_ipts' data_constrain (pts 13, 14) and 'pt' output_constrain (y bound)")
# ... whereas data_constrain depends on the dataset (point number is important)
emb5 = umapper5.fit_transform(iris.data, data_constrain=constraints)
assert np.all(emb5[:,1] >= -5.0) and np.all(emb5[:,1] <= 5.0) # output_constrain
print("emb0[11:15]\n",emb0[11:15])
print("emb5[11:15]\n",emb5[11:15])
# The pinned points must sit at the finite targets stored in infs0.
assert np.allclose(emb5[13], [-2.2,0])
assert np.allclose(emb5[14], [+2.2,0])
print("\nGoodbye")

### We were asked for a clusterer.
Every time the cluster membership changes, you create a clustering function.

-API **mk_clusters(clusters, springs, *, mindist=None, mults=None, wm=None)**
  - input as list-of-lists `clusters[c][i]`: (i) cluster(`c`=0,1,...), (ii) points (indices `idx`) in cluster.
    - also internally remember map of index to clusters: `clusts[idx] = list-of-cluster(`c`)`
    - also internally calculate cluster size `cl_n[c]`
    - and cluster average position `cl_avg[c]`
    - opt. store both as numpy sparse matrix + sparse transpose ?
  - input spring force `springs[c]` and opt. learning rate lr=1.0
  - TODO: input optional de-neighboring `mults[c]` (default 1.0) and modify umap weight matrix `wm`
  - input optional mindist for short-range cluster repulsive force (Morse function variant?)
  - output jit func (`idx`, `pts`)
    - for `c` in `clusts[idx]`:  # all clusters c of idx
      - `delta` of pt moved toward `cl_avg[c]` with `springs[c]`
    - TODO (maybe)
      - quick-update `cl_avg[c]` by `(delta / cl_n[c])`
      - opt. every 100'th call, recalculate exact cluster averages instead of `delta`-update
      - opt. short-range repulsion to other neighbors of `idx` in cluster `c`

- **do** support a pt assigned to multiple clusters
  - calc eqm_pos and total spring force from all clusters
  - move there *without* applying `mindist`

In [None]:
print("Let's demo python clustering internals (no-debug version)")
#
# --------------------- Inputs -----------------
# (this section used to appear twice, verbatim; one copy is enough)
#
n_samples = 8
cluster_lists = [[1,2,3], [2,7,6]]
#           avg    2.0      5.0     (mean member index == mean x, since x = i)
# each of the 2 clusters has a spring constant
springs = np.array([0.8, 2.0])
lr = 1.0   # learning rate (gradient multiplier) for spring forces

# Point i starts at (i, i mod 3).  np.empty states the "uninitialised, filled
# right below" intent explicitly, unlike the low-level np.ndarray constructor.
pts = np.empty((n_samples,2), dtype=np.float32)
for i in range(n_samples):
    pts[i,:] = (float(i),float(i%3))




#
# ---------------------- Functions --------------
#

# begin with 2 very simple helpers
@numba.njit()
def np_mean_axis_0(pts):
    """Return the float32 mean of the rows of the 2-D array `pts`.

    Rows are summed one at a time into a float32 accumulator and the sum is
    divided by the row count; an empty input trips the assertion.
    """
    n_rows = pts.shape[0]
    n_dims = pts.shape[1]
    centroid = np.zeros(n_dims, dtype=np.float32)
    for row in range(n_rows):
        centroid += pts[row, :]
    assert n_rows > 0
    centroid /= n_rows
    return centroid

@numba.njit() # eventually should just inline this wherever
def xnp_cluster_list(c, clusters):
    """Return the 1-D array of column indices that are set in row `c`."""
    membership_row = clusters[c, :]
    return np.nonzero(membership_row)[0]

@numba.njit()
def spring_mv_v0(pt, target, spring, lr=1.0, mindist=0.0):
    """Return the displacement that moves `pt` toward `target` without overshoot.

    pt, target : coordinate vectors of equal length.
    spring     : spring constant; the uncapped step is lr*spring times the
                 distance still available.
    lr         : learning rate (multiplier for the spring force).
    mindist    : if > 0, the move stops `mindist` short of `target`; a point
                 already within `mindist` does not move.

    The displacement is returned; `pt` itself is not modified.
    """
    vec = target - pt      # vector toward cluster center
    # mindist == 0.0: fraction of the way to target, capped at "all the way"
    frac = min(1.0, lr * spring)
    if mindist > 0.0:
        vecnorm = np.linalg.norm(vec)
        if vecnorm > 0.0:
            bar = max(vecnorm - mindist, 0.0) # max move dist (along vec/vecnorm)
            frac = min(bar, bar * lr * spring) / vecnorm
        else:
            # pt already sits on target: the original evaluated 0/0 here
            # (ZeroDivisionError under numba's default error model).
            frac = 0.0
    return frac * vec

#@numba.njit()
def do_clustering_ipts(idx, pts, cluster_lists, springs, *, lr=1.0, mindist=0.01):
    """Move pts[idx] (in place) toward the cluster(s) that contain point idx.

    idx           : index of the point to move.
    pts           : 2-D array of point coordinates, modified in place.
    cluster_lists : list of lists; cluster_lists[c] holds the point indices
                    that belong to cluster c.
    springs       : one spring constant per cluster.
    lr            : learning rate (multiplier for the spring forces).
    mindist       : single-cluster case only -- stop this far short of the
                    cluster centroid.

    Cases:
      * idx in no cluster         -> no-op.
      * idx in exactly one cluster -> non-overshooting spring move toward that
        cluster's centroid (see spring_mv_v0); clusters of size <= 1 are a no-op.
      * idx in several clusters   -> move toward the spring-weighted average of
        the cluster centroids ("equilibrium position"), with the step capped so
        the point never overshoots it.  mindist is deliberately NOT applied to
        this virtual equilibrium point.

    Always returns None.
    """
    assert len(pts.shape) == 2
    n_samples = pts.shape[0]
    n_clust = len(cluster_lists)
    assert n_clust > 0
    assert len(springs) == len(cluster_lists)
    np_springs = np.array(springs)
    #
    # cluster membership as bool array[cluster index][point index]
    #
    clusters = np.full((n_clust, n_samples), False, dtype=bool)
    for (c, members) in enumerate(cluster_lists):
        for m in members:
            clusters[c][m] = True

    # indices of all clusters that contain idx
    idx_cl = np.argwhere(clusters[:, idx]).flatten()
    if len(idx_cl) < 1:
        return

    if len(idx_cl) == 1:
        # Separate out an easy case (idx in single cluster)
        c = idx_cl[0]
        cluster_pts = np.array(cluster_lists[c], np.int32)
        if len(cluster_pts) <= 1:
            print("noop: cluster empty or of size 1")
            return

        cl_avg = np.mean(pts[cluster_pts, :], axis=0)   # cluster centroid
        delta = spring_mv_v0(pts[idx], cl_avg, springs[c], lr, mindist)
        pts[idx, :] += delta
        # A quick approx update of the centroid (cl_avg + delta/len(cluster_pts))
        # would be feasible, but is not used for now.
        return

    # idx is in several clusters.
    # Determine the NET gradient from the spring-weighted sum of the individual
    # spring forces, and the equilibrium position from the spring-weighted
    # average of the cluster centroids.  Notes:
    #   1. Each centroid gets its spring regardless of how populated the
    #      cluster is.
    #   2. The spring constant doubles as weighting factor -- this might not
    #      hold for other spring force models (only for kr^2 spring physics).
    tot_grad = np.zeros(pts.shape[1], dtype=np.float32)
    eqm_pos = np.zeros(pts.shape[1], dtype=np.float32)
    sum_springs = 0.0
    for c in idx_cl:
        # Every cluster visited here contains at least idx itself.
        cluster_pts = np.flatnonzero(clusters[c])
        cl_avg = np.mean(pts[cluster_pts, :], axis=0)
        spring = np_springs[c]
        sum_springs += spring
        eqm_pos += spring * cl_avg
        tot_grad += spring * (cl_avg - pts[idx, :])   # force toward centroid

    if sum_springs == 0.0:
        # All springs are zero: no force, and no well-defined equilibrium
        # (the division below would be 0/0).
        return
    eqm_pos /= sum_springs

    print(f"{eqm_pos=}")

    vec = eqm_pos - pts[idx, :]     # vector toward spring-weighted equilibrium
    vecnorm = np.linalg.norm(vec)
    gradnorm = np.linalg.norm(tot_grad)

    # Sanity check: tot_grad must point toward eqm_pos if the math is correct.
    # Skipped when the point is (almost) at equilibrium, where both vectors are
    # ~zero and normalising them would give NaN and a spurious failure.
    if vecnorm > 1e-5 and gradnorm > 0.0:
        assert np.allclose(tot_grad / gradnorm, vec / vecnorm)

    # Extremely simple non-overshooting step: gradnorm takes the role that
    # 'spring' plays in the single-cluster mover.
    delta = min(1.0, lr * gradnorm) * vec
    pts[idx, :] = pts[idx, :] + delta
    return

#@numba.njit()
def do_clustering_pts(pts, cluster_lists, springs, *, lr=1.0, mindist=0.01):
    """Run one clustering round: visit every point once, in index order.

    Each point is handed to do_clustering_ipts, which updates `pts` in place,
    so later points see the already-moved earlier ones.  Returns None.
    """
    assert len(pts.shape) == 2
    step_options = dict(lr=lr, mindist=mindist)
    for idx in range(pts.shape[0]):
        do_clustering_ipts(idx, pts, cluster_lists, springs, **step_options)
    return

print(f"{type(pts)=} {pts.shape=} {pts.dtype=}")
# One epoch over all points, with a non-zero mindist.
do_clustering_pts(pts, cluster_lists, springs, lr=lr, mindist=0.2)
print("final positions (one epoch)")
for i, row in enumerate(pts):
    print(f"{i=} pt = {row}")

# Comparison against a stored reference is switched off; this cell only
# records the result as the reference.
compare_with_reference = False #pts_ref is not None
if not compare_with_reference:
    pts_ref = pts.copy()
else:
    assert( np.allclose(pts, pts_ref) )
    print("Good: matched pts_ref")
print("Goodbye")

In [None]:
import numpy as np
import numba
# Banner for the jit version of the clustering demo.
for banner_line in (
    "Let's demo python clustering internals (jit version)",
    "This cell is the precursor a new file umap/constrain_clust.py\n",
):
    print(banner_line)

#
# ---------------------- Functions --------------
# ouch.  these must be jittable now
# numba cannot handle a python list of numpy arrays.
# Let's break things apart to find numba-ready code blocks
#
@numba.njit()
def np_mean_axis_0(pts):
    """Return the float32 mean of the rows of the 2-D array `pts`.

    Rows are summed one at a time into a float32 accumulator and the sum is
    divided by the row count; an empty input trips the assertion.
    """
    n_rows = pts.shape[0]
    n_dims = pts.shape[1]
    centroid = np.zeros(n_dims, dtype=np.float32)
    for row in range(n_rows):
        centroid += pts[row, :]
    assert n_rows > 0
    centroid /= n_rows
    return centroid

@numba.njit()
def xdo_clustering_ipts_toward0(idx, pts, target, lr, spring, maxfrac=0.9):
    """ Non-overshooting move of pts[idx,:] towards `target`, in place.

        idx    : row of `pts` to move.
        pts    : 2-D point array, modified in place.
        target : position to move toward (e.g. a cluster centroid).
        lr     : time step / learning rate.
        spring : spring constant "force ~ spring * displacement".
        maxfrac: under-relaxation; at most this fraction of the way to
                 `target` is covered (0.9 = "don't go all the way").

        Nothing happens when lr*spring <= 1e-5.
    """
    if lr*spring > 1e-5:
        vec = target - pts[idx,:]
        # The step is lr * (spring * |vec|) along vec, i.e. the fraction
        # lr*spring of vec.  Using the fraction directly avoids the 0/0 that
        # (lr*spring*|vec|)/|vec| produced when the point already sits on
        # target.
        fracsz = min(maxfrac, lr * spring) # maxfrac<1 => undershoot
        pts[idx,:] += fracsz * vec
    return
    #
    # Compare with the multi-cluster case, where tot_grad is calculated up
    # front as a spring-weighted sum of force vectors and the move is
    #     vec      = eqm_pos - pts[idx]
    #     pct_move = min(1.0, (lr * norm(tot_grad)) / norm(vec))
    #     delta    = vec * pct_move
    # (there norm(tot_grad) is NOT spring * norm(vec)).  That can be
    # reproduced here with an "effective spring constant"
    #     spring_eff = norm(tot_grad) / norm(target - pts[idx])
    #

#if False: # for reference
#    def spring_mv_v0_plain(pt, target, spring, lr=1.0, mindist=0.0):
#        vec = target - pt      # vector toward cluster center
#        vecnorm = np.linalg.norm(vec)
#        if mindist > 0.0:
#            bar = max(vecnorm - mindist, 0.0) # max move dist (along vec/vecnorm)
#            delta = (min(lr * spring, 1.0) * bar / vecnorm) * vec
#        else: # mindist == 0.0
#            delta = min(1.0, lr*spring) * vec
#        return delta

@numba.njit()
def xdo_clustering_ipts_toward(idx, pts, target, lr, spring, mindist=0.0):
    """ Non-overshooting in-place move of pts[idx,:] towards `target`.

        lr: time step (1.0 will move exactly to equilibrium posn if spring==1)

        spring : spring constant "force ~ spring * displacement" (parabolic potential)

        mindist: to "not go all the way" towards target, but stop mindist away.

        Patterned after the python helper "spring_mv_v0"; points closer than
        1e-5 to the target are left where they are.
    """
    offset = target - pts[idx, :]       # vector toward the target
    dist = np.linalg.norm(offset)
    if not (dist > 1e-5):
        return
    # Distance that may be covered at most (along offset/dist).
    room = max(dist - mindist, 0.0)
    step = min(room, room * lr * spring)
    pts[idx, :] += (step / dist) * offset
    return



@numba.njit() # eventually should just inline this wherever
def xnp_cluster_list(c, clusters):
    """Return the 1-D array of column indices that are set in row `c`."""
    membership_row = clusters[c, :]
    return np.nonzero(membership_row)[0]

@numba.njit(
    locals={'tot_grad': numba.float32[:],
            'eqm_pos' : numba.float32[:],
            'sum_springs' : numba.float32,
            'n_springs'   : numba.int64,
           }
)
def xdo_clustering_mult(idx, idx_cl, clusters, springs, pts):
    """ Return target+grad info for idx w/ springs to multiple clusters.

        idx      : index of the point of interest (row of `pts`).
        idx_cl   : iterable of cluster indices to take into account.
        clusters : membership matrix, indexed [cluster, point].
        springs  : per-cluster spring constants, indexed by cluster.
        pts      : 2-D point array (read only here).

        Returns the tuple (n_springs, eqm_pos, tot_grad):
          n_springs : number of non-empty clusters that contributed,
          eqm_pos   : spring-weighted average of their centroids,
          tot_grad  : sum over clusters of spring * (centroid - pts[idx]).

        NOTE(review): eqm_pos is divided by sum_springs unconditionally, so an
        empty idx_cl (or all-zero springs) divides by zero -- callers
        presumably guarantee at least one positive spring; confirm.
    """
    tot_grad = np.zeros(pts.shape[1],dtype=np.float32)
    eqm_pos  = np.zeros(pts.shape[1],dtype=np.float32)
    sum_springs = 0.0 #np.sum(np_springs)
    n_springs = 0

    # calculate equilibrium spring-weight target position
    #           and total spring-weighted gradient
    # every cluster exerts a force independent of cluster size
    for cc in idx_cl:
        cluster_pts  = xnp_cluster_list(cc, clusters) # perhaps faster/vectorizable
        # Note: 1. Each cluster centroid gets its spring regardless
        #          of how populated the cluster is.
        #       2. Size 1 cluster get included
        #       3. Spring constant doubles as weighting factor -- this
        #          might not hold for other spring force models!
        if len(cluster_pts) > 0:
            cl_avg = np_mean_axis_0( pts[cluster_pts,:] )
            #print(f"{cc=} {cluster_pts=}{cl_avg=}")
            # Now update sums for equilibrium posn and total gradient
            # springs -> force gradient -> non-overshooting 'delta' movement
            spring = springs[cc]
            sum_springs += spring
            n_springs += 1
            eqm_pos += spring * cl_avg
            vec = cl_avg - pts[idx,:]      # vector toward cluster center
            tot_grad += spring * vec       # only for kr^2 spring physics
    #
    # normalise the weighted centroid sum into the equilibrium position
    eqm_pos /= sum_springs

    return (n_springs, eqm_pos, tot_grad)

@numba.njit("void(i8, f4[:,:], boolean[:,:], f4[:], f8, f8)")
def xdo_clustering_ipts(idx, pts, clusters, springs, lr=1.0, mindist=0.01):
    """ Move pts[idx,:] (in place) toward the cluster(s) that idx belongs to.

        idx      : index of the point to move
        pts      : float32 [n_samples, dim] point cloud, modified in place
        clusters : boolean [n_clust, n_samples]; clusters[c, i] is True
                   iff point i is in cluster c
        springs  : float32 vector, one spring constant per cluster
        lr       : time step multiplying the spring force
        mindist  : stop this far short of the centroid (single-cluster case
                   only; the multi-cluster case moves with mindist 0.0)

        Points belonging to no cluster are left untouched.
    """
    idx_in = clusters[:,idx]
    idx_cl = np.argwhere(idx_in).flatten()   # ids of clusters containing idx

    if len(idx_cl) < 1:
        return

    if len(idx_cl) == 1:
        # Separate out an easy case (idx in single cluster)
        c = idx_cl[0]
        cluster_pts  = xnp_cluster_list(c, clusters)
        # pts[idx] in a cluster of size 1 is already at the cluster centroid
        if len(cluster_pts) > 1:
            centroid = np_mean_axis_0( pts[cluster_pts,:] )  # cluster centroid
            xdo_clustering_ipts_toward(idx, pts, centroid,
                                       lr, springs[c], mindist=mindist)
        return

    # idx attracted to multiple clusters... len(idx_cl) > 1 ... rare?
    (n_springs, eqm_pos, tot_grad) = xdo_clustering_mult(idx, idx_cl, clusters, springs, pts)

    if n_springs==0:  # all clusters empty? No-op
        return

    # Multi-cluster effective spring constant, so the "_toward" code can be
    # re-used:  |total force| / |displacement to equilibrium|.
    gap = np.linalg.norm(eqm_pos-pts[idx])
    # Guard the division: a point already sitting at the equilibrium has
    # gap == 0.  "_toward" is a no-op at or below this same 1e-5 threshold,
    # so skipping the call does not change any movement.
    if gap > 1e-5:
        eqm_spring = np.linalg.norm(tot_grad) / gap
        xdo_clustering_ipts_toward(idx, pts, eqm_pos, lr, eqm_spring, mindist=0.0)

    return


#@numba.njit()
def xdo_clustering_pts(pts, cluster2d, springs, lr=1.0, mindist=0.01):
    """Run one clustering round: apply xdo_clustering_ipts to every row of `pts`.

    Rows are visited once each, in increasing index order; `pts` must be 2-D.
    All other arguments are forwarded unchanged to xdo_clustering_ipts.
    """
    assert len(pts.shape) == 2
    row_count = pts.shape[0]
    for row in range(row_count):
        xdo_clustering_ipts(row, pts, cluster2d, springs, lr, mindist)

#
# -------------- python (non-numba) --------------
# These show how to convert python args to expected numba-compliant types
#

# for completeness, since this might also become a 'mk_FOO' jit-fn-generator
def do_clustering_ipts_py(idx, pts, cluster_lists, springs, lr, mindist):
    """ python-ish front-end: convert args, then move one point via the jit fn.

        idx: index of the point (row of pts) to move
        pts: array[n_samples,dim]  sample ~ "idx"; MODIFIED IN PLACE
        cluster_lists: list-of-lists; cluster_lists[c] holds the sample
                 indices belonging to cluster c
        springs: one spring constant per cluster (python list or np array);
                 must be finite and non-negative
        lr, mindist: python numbers

        Returns None.  Does nothing when there are no samples or no clusters.
        Invalid inputs trip an AssertionError.
    """
    idx = int(idx)
    lr = float(lr)
    mindist = float(mindist)
    assert len(pts.shape) == 2
    assert len(cluster_lists) == len(springs)
    n_samples = pts.shape[0]
    assert idx < pts.shape[0]

    n_clust = len(cluster_lists)

    if n_samples==0 or n_clust==0:
        return

    # numba-friendly membership matrix:
    # clusters[c,idx] is True IFF idx is in cluster c
    clusters = np.zeros((n_clust, n_samples), dtype=bool)
    for (c,members) in enumerate(cluster_lists):
        for m in members:
            clusters[c, m] = True

    springs = np.array(springs, dtype=np.float32)  # don't need python float64 default
    # no negative or inf or nan springs
    assert np.all(springs >= 0.0)
    # NOTE: comparing with `== np.nan` is always False, so NaN/inf have to be
    #       detected with isfinite.
    assert np.all(np.isfinite(springs))

    # pts is to be modified -- do not create a copy!

    # invoke jit fn
    xdo_clustering_ipts( idx, pts, clusters, springs, lr=lr, mindist=mindist )

    return

# python "frontend-to-jit" demo for cell output
def do_clustering_pts_py(pts, cluster_lists, springs, lr, mindist):
    """ python-ish front-end: convert args, then run one clustering round.

        pts: array[n_samples,dim]  sample ~ "idx"; MODIFIED IN PLACE
        cluster_lists: list-of-lists; cluster_lists[c] holds the sample
                 indices belonging to cluster c
        springs: one spring constant per cluster (python list or np array);
                 must be finite and non-negative
        lr, mindist: python numbers

        Returns None.  Does nothing when there are no samples or no clusters.
        Invalid inputs trip an AssertionError.
    """
    lr = float(lr)
    mindist = float(mindist)
    assert len(pts.shape) == 2
    assert len(cluster_lists) == len(springs)
    n_samples = pts.shape[0]

    n_clust = len(cluster_lists)

    if n_samples==0 or n_clust==0:
        return

    # numba-friendly membership matrix:
    # clusters[c,idx] is True IFF idx is in cluster c
    clusters = np.zeros((n_clust, n_samples), dtype=bool)
    for (c,members) in enumerate(cluster_lists):
        for m in members:
            clusters[c, m] = True

    springs = np.array(springs, dtype=np.float32)  # don't need python float64 default
    # no negative or inf or nan springs
    assert np.all(springs >= 0.0)
    # NOTE: comparing with `== np.nan` is always False, so NaN/inf have to be
    #       detected with isfinite.
    assert np.all(np.isfinite(springs))

    # pts is to be modified -- do not create a copy!

    # invoke the all-points driver (one pass over every sample)
    xdo_clustering_pts( pts, clusters, springs, lr=lr, mindist=mindist )

    return

#
# --------------------- Inputs -----------------
#

n_samples = 8
# two overlapping clusters (sample 2 is in both);
# their mean sample indices are 2.0 and 5.0 respectively
cluster_lists = [[1,2,3], [2,7,6]]
# one spring constant per cluster
springs = np.array([0.8, 2.0], dtype=np.float32)
lr = 1.0   # learning rate (gradient multiplier) for spring forces

# sample i starts at (i, i mod 3)
pts = np.empty((n_samples, 2), dtype=np.float32)
for i in range(n_samples):
    pts[i, 0] = float(i)
    pts[i, 1] = float(i % 3)


#np_cluster_lists = [np.array(clist,np.int32) for clist in cluster_lists]
#n_clust = len(np_cluster_lists)
## numba-friendly: single full-sized array
##  TODO: compressed data,indptr,indices versions cluster and its transpose
#
## clusters[ c, idx ], for c in [0,n_clust] and idx in [0,n_samples),
##         is True IFF idx is in cluster c
#clusters = np.full((n_clust, n_samples), False, dtype=bool)
#for (c,members) in enumerate(np_cluster_lists):
#    for m in members:
#        clusters[c][m] = True
#print(f"{numba.typeof(clusters)=}")

#
# ------------------------ python/jit test -------------
#
mindist = 0.2  # support TBD (need to rethink equations)

#xdo_clustering_pts(pts, clusters, springs, lr=lr, mindist=0.2)
# pass the variable rather than repeating the literal, so the two cannot drift
do_clustering_pts_py(pts, cluster_lists, springs, lr=lr, mindist=mindist)

print("final positions (one epoch)")
for i in range(pts.shape[0]):
    print(f"{i=} pt = {pts[i,:]}")

# Compare against reference positions, if an earlier cell defined `pts_ref`.
# Only the name probe sits inside `try`, so a NameError raised by the
# comparison itself is no longer silently swallowed.
try:
    pts_ref
except NameError:
    pass  # no reference available: nothing to compare
else:
    assert np.allclose(pts, pts_ref)
    print("Good: matched pts_ref")

print("Goodbye!")


## FINALLY have some *mk_FOO*
### jitted cluster-funcs

#### What's the point?
*umap-constraints, short-n-sweet* jitted "lambda functions" can be used
as `data_constrain=` or `output_constrain=` args to umap euclidean
embedding calls (`fit`, `fit_transform`)

The umap embedding should *pull together* our **user-specified clusters**,
with whatever effects these have on next-nearest members, etc.

In [None]:
# Now try above cell as part of umap-constraints...
# This call can run standalone:
#   It invokes a jit-function mk_FOO,
#   and then invokes it
#   (umap.UMAP not involved)
#
#   I found the mk_FOO marks some local array refs 'readonly array'.
#   It was easiest to simply remove the overly constraining specs
#   and let @numba.njit() autogenerate the required signatures
#   in umap/constrain_clust.py
#
import numpy as np
import numba

from umap.constrain_clust import mk_clustering_pts, mk_clustering_ipts

#
# ------- Inputs (python) -----------------------------
n_samples = 8
# two overlapping clusters (sample 2 is in both);
# their mean sample indices are 2.0 and 5.0 respectively
cluster_lists = [[1,2,3], [2,7,6]]
# one spring constant per cluster
springs = np.array([0.8, 2.0], dtype=np.float32)
lr = 1.0   # learning rate (gradient multiplier) for spring forces

# sample i starts at (i, i mod 3)
pts = np.empty((n_samples, 2), dtype=np.float32)
for i in range(n_samples):
    pts[i, 0] = float(i)
    pts[i, 1] = float(i % 3)

mindist = 0.2

#
# ------------ create & call jit constraint -----------
# Now that we have the python-ish inputs set up,
# create a jit constraint function of simplified signature
#
mkdo_pts = mk_clustering_pts(pts, cluster_lists, springs, lr=lr, mindist=mindist)

# and just invoke it as
mkdo_pts(pts)  # simplified call signature, other args are now mk_FOO locals
#    without re-supplying all the args like
#        do_clustering_pts_py(pts, cluster_lists, springs, lr=lr, mindist=0.2)

#
# ------------- output --------------------------------
# check we got the same "output" (in-place modification of our pts array)
#
print("final positions (one epoch)")
for i in range(pts.shape[0]):
    print(f"{i=} pt = {pts[i,:]}")

# Compare against reference positions, if an earlier cell defined `pts_ref`.
# Only the name probe sits inside `try`, so a NameError raised by the
# comparison itself is no longer silently swallowed.
try:
    pts_ref
except NameError:
    pass  # no reference available: nothing to compare
else:
    assert np.allclose(pts, pts_ref)
    print("Good: matched pts_ref")

print("Goodbye!")

In [None]:
import numpy as np
import numba

from umap.constrain_clust import mk_clustering_pts, mk_clustering_ipts


#
# WIP:  This cell should use mk_FOO clustering for the iris data set,
#       actually running a umap embedding showing user-clustering via
#       the "additional springs" approach.
# Based on plots/ behaviors,
# I expect the mk_FOO calls may have some behavioral modifiers
# like:
#    - actually supporting the mindist "target radius" clustering arg
#    - weakening connections to next-nearest-neighbors (alpha?)
#    - ...
#
cluster_lists = [[0, 50, 100, 149]]  # a single cluster
print("# let's cluster pts", cluster_lists, "with a big force")
for idx in cluster_lists[0]:
    print(f"  cluster {idx=} {emb0[idx]=}")
cl_pts = emb0[cluster_lists[0],:]
# centroid computed twice: numpy's mean vs. the helper from umap.constrain_clust
cl_avg = np.mean(cl_pts, axis=0)
cl_avg2 = umap.constrain_clust.np_mean_axis_0(cl_pts)
print(f" emb0 cluster centroid  @ {cl_avg}")
print(f" emb0 cluster centroid' @ {cl_avg2}")

# set things so cluster pts move "all the way" to their target
springs = [1.0]
lr = 1.0
mindist = 0.0

print("\nmove individual pts via python do_clustering_ipts_py")
t = emb0.copy()   # work on a copy so emb0 keeps the "before" positions
for idx in [0, 1]:
    do_clustering_ipts_py(idx, t, cluster_lists, springs, lr, mindist)
    print(f" cluster {idx=} ({emb0[idx,0]:.2f},{emb0[idx,1]:.2f}) to ({t[idx,0]:.2f},{t[idx,1]:.2f})")

print("\nmove all pts in cluster via python do_clustering_pts_py")
t = emb0.copy()
# Use the all-points front-end that the message announces (this section used to
# loop over do_clustering_ipts_py instead).  One pass visits every point in
# index order and leaves points outside any cluster untouched, so the cluster
# members end up where the per-member loop put them.
do_clustering_pts_py(t, cluster_lists, springs, lr, mindist)
for idx in cluster_lists[0]:
    print(f" cluster {idx=} ({emb0[idx,0]:.2f},{emb0[idx,1]:.2f}) to ({t[idx,0]:.2f},{t[idx,1]:.2f})")

print("\nmove all pts via mk_clustering_pts constraint-generator")
all_points_clusterer = mk_clustering_pts(
    iris.data, # for nsamples=shape[0], we'll actually call using lo-D embedding features
    cluster_lists,
    springs,
    lr=1.0, # timestep: enough for spring 1.0 to get all the way
    mindist=0.0,   # really go all the way
)
t = emb0.copy()   # work on a copy so emb0 keeps the "before" positions
all_points_clusterer(t)
for idx in cluster_lists[0]:
    print(f" cluster {idx=} ({emb0[idx,0]:.2f},{emb0[idx,1]:.2f}) to ({t[idx,0]:.2f},{t[idx,1]:.2f})")
print("  note that after 1st pt moved, cluster center changed")
print("  so pts 2,3,4 in cluster moved to slightly different locations")
# f-prefix added: the message used to print the literal text "{mindist}"
print(f"  trying another 2 'epochs' to converge better (mindist is {mindist})")
for i in range(2):
    all_points_clusterer(t)
for idx in cluster_lists[0]:
    print(f" cluster {idx=} ({emb0[idx,0]:.2f},{emb0[idx,1]:.2f}) to ({t[idx,0]:.2f},{t[idx,1]:.2f})")

print("\nmove individual pts via mk_clustering_ipts constraint-generator")
point_clusterer = mk_clustering_ipts(
    0,   # idx, unused
    iris.data, # for nsamples=shape[0], we'll actually call using lo-D embedding features
    cluster_lists,
    springs,
    lr=1.0, # timestep: enough for spring 1.0 to get all the way
    mindist=0.0,
)
t = emb0.copy()   # work on a copy so emb0 keeps the "before" positions
# Each point is moved exactly once, like the do_clustering_ipts_py demo.
# (Extra point_clusterer(0, t) / point_clusterer(1, t) calls used to precede
#  this loop, so the printed positions were the result of two moves.)
for idx in [0,1]:
    point_clusterer(idx, t)
    print(f" cluster {idx=} ({emb0[idx,0]:.2f},{emb0[idx,1]:.2f}) to ({t[idx,0]:.2f},{t[idx,1]:.2f})")

print("Goodbye")

In [None]:
# Build the `constraints` dict that is later passed as `data_constrain=`.
# The `if False:` branch is kept only as an illustration of the simpler form.
if False:
    # If I had only the clustering constraint(idx,pt), ...
    constraints = {
        'idx_ipts': point_clusterer
    }
else:
    #
    # ... but I have not supported LISTS of CONSTRAINTS, but I can package
    #     the point clusterer with another constraint like this:
    #
    #  (you can use print in numba, but only strings and values, no kwargs either)
    #
    @numba.njit("f4[:](i8, f4[:,:])")
    def constraint_idx_pt1(idx,pts):
        # Combined constraint: first the clustering move for point `idx`,
        # then con.freeinf_ipts on the same point.
        # NOTE(review): `con` and `infs0` come from earlier notebook cells;
        #               their exact semantics are not visible here -- confirm.
        # Here 'pts' MUST be 2D (full point cloud)
        #pt0 = pts[idx,:].copy()
        point_clusterer(idx, pts)
        #if idx==0:
        #    print(" idx0 (",pt0[0],",",pt0[1],") to (",pts[idx,0],",",pts[idx,1],")")
        return con.freeinf_ipts(idx,pts, infs0)
    # this function DOES depend on idx of pt

    constraints = {
        'idx_ipts': constraint_idx_pt1,  # NEW _ipts suffix ==> alt call signature
        # actually this is most flexible since the full "tail_embedding"
        # point cloud is actually always available.
    }

print("# let's begin with 11:15 an a square around the origin")
print("# their cluster will still feel pulls of neighbors")
# four distinct corners of a square centred on the origin
emb0[11] = [+3.14, +3.14]
emb0[12] = [+3.14, -3.14]
emb0[13] = [-3.14, +3.14] # init with far away coords
# was [+3.14, -3.14], a duplicate of point 12's corner, so no square was formed
emb0[14] = [-3.14, -3.14]
# optional: set up the values to agree
#emb0[13] = [-2.0, 0]
#emb0[14] = [+2.0, 0]
#assert np.all(emb0[13] == [-2.0,0])
#assert np.all(emb0[14] == [+2.0,0])
# Here is the "move all points" version of con.freeinf
con.freeinf_pts(emb0, infs0)

# Also demo a non-indexed (UMAP constructor) constraint,
# that is independent of the iris.data.
# Here, let's constrain 'y' to be within -5.0, +5.0
# without, we got:
# emb2[11:15]
#  [[ 4.804111   3.430179 ]
#  [ 4.2598176  7.352637 ]     # <-- 'y' is big here
#  [-2.         0.       ]
#  [ 2.         0.       ]]
# With 'output_constrain':
# emb2[11:15]
# [[ 7.0536     2.275853 ]
# [ 4.8699265  1.9713866]      # y constrained
# [-2.         0.       ]
# [ 2.         0.       ]]
# So we pass illegal range for x, and legal range for y
def mk_bound_y_values(lo, hi):
    """Build a jitted constraint taking only a point, bounding its second coordinate.

    The x bounds are deliberately inverted (low +999 above high -999); per the
    demo this "illegal" range leaves x alone, while y gets the range [lo, hi].
    NOTE(review): relies on how con.dimlohi_pt treats an inverted range -- confirm.
    """
    lower_bounds = np.array([+999., lo], dtype=np.float32)
    upper_bounds = np.array([-999., hi], dtype=np.float32)

    # the closure takes just the point: it does NOT depend on an 'idx' arg
    @numba.njit()
    def bound_y_values(pt):
        bounded = con.dimlohi_pt(pt, lower_bounds, upper_bounds)
        return bounded

    return bound_y_values
y_bounder = mk_bound_y_values(-5.0,+5.0)

# Sanity checks: x must come back untouched, y must land inside [-5, +5].
# case 0: already inside the y range -> unchanged
pt = np.array([1.,2.], dtype=np.float32)
print("pt0",pt)
y_bounder(pt)
print("pt0",pt)
assert pt[0] == 1.
assert pt[1] == 2.

# case 1: y above the range -> pulled down to +5
pt[0] = 10.
pt[1] = 10.
print("pt1",pt)
y_bounder(pt)
print("pt1",pt)
assert pt[0] == 10.
assert pt[1] == 5.

# case 2: y below the range -> pulled up to -5
pt[0] = -10.
pt[1] = -10.
print("pt2",pt)
y_bounder(pt)
print("pt2",pt)
assert pt[0] == -10.
assert pt[1] == -5.

# (progress messages used to be stale copies naming 'umapper2' and 'pin_mask')
print("Specify 'init' embedding for umapper6")
umapper6 = umap.UMAP(
    n_neighbors=50, learning_rate=0.5, random_state=12346, min_dist=0.001,
    output_constrain = { 'pt': y_bounder }, # any pt, ind't of dataset
    init=emb0, n_epochs=4,
)
print("Embed with data_constrain=constraints (output_constrain bounds y)")
# ... whereas data_constrain depends on the dataset (point number is important)
emb6 = umapper6.fit_transform(iris.data, data_constrain=constraints)
# output_constrain: every y value must lie within [-5, +5]
assert np.all(emb6[:,1] >= -5.0) and np.all(emb6[:,1] <= 5.0) # output_constrain
print("emb0[11:15]\n",emb0[11:15])
print("emb6[11:15]\n",emb6[11:15])
# points 13 and 14 are expected to end at fixed positions
# NOTE(review): the expected +/-2.2 values presumably come from `infs0`,
#               defined in an earlier cell -- confirm.
assert np.allclose(emb6[13], [-2.2,0])
assert np.allclose(emb6[14], [+2.2,0])

print("\nWhat happened to cluster-constrained umap pts?")
for idx in cluster_lists[0]:
    print(f" cluster {idx=} ({emb0[idx,0]:.2f},{emb0[idx,1]:.2f}) to ({emb6[idx,0]:.2f},{emb6[idx,1]:.2f})")
print("After printing the right thing, the cluster really does exist!")

print("\nGoodbye")