In [1]:
import jax
from jax import numpy as jnp

In [143]:
jnp.linalg.norm(jnp.array([1, 2, jnp.nan]))


x = jnp.array([jnp.nan, jnp.nan])
idx = jnp.nanargmax(x)
x[idx]

Array(nan, dtype=float32)

In [90]:
@jax.jit
def run(x):
    k = jnp.min(x)
    mask = x == k
    x = jnp.where(mask, x, jnp.nan)
    return x
    
x = jnp.array([[3, 3, 3],
               [2, 2, 2],
               [0, 0, 0],
               [1, 1, 1]])
run(x)

Array([[nan, nan, nan],
       [nan, nan, nan],
       [ 0.,  0.,  0.],
       [nan, nan, nan]], dtype=float32, weak_type=True)

In [130]:
mat = jnp.array([[1., 2., 4.],
                 [1.0000001, 2, 4.],
                 [3, 3, 3]])

jnp.linalg.matrix_rank(mat), jnp.linalg.det(mat)

(Array(2, dtype=int32), Array(-9.536743e-07, dtype=float32))

In [137]:
x = jnp.array([[1, 1, 1],
               [1, 0, 0]])
y = jnp.array([[0, 0, 2],
                [1, 1, 1]])
y_norm = jnp.linalg.norm(y, axis=1)

proj_len = x @ y.T / y_norm
unit_vec = y / y_norm[:, None]
proj_vec = jnp.reshape(proj_len, (proj_len.size, 1)) * jnp.tile(unit_vec, (len(x), 1))
h = jnp.repeat(x, len(y), axis=0) - proj_vec
d = jnp.reshape(jnp.linalg.norm(h, axis=1), (len(x), len(y)))

proj_len, unit_vec, proj_vec, h, d

(Array([[1.        , 1.7320508 ],
        [0.        , 0.57735026]], dtype=float32),
 Array([[0.        , 0.        , 1.        ],
        [0.57735026, 0.57735026, 0.57735026]], dtype=float32),
 Array([[0.        , 0.        , 1.        ],
        [0.99999994, 0.99999994, 0.99999994],
        [0.        , 0.        , 0.        ],
        [0.3333333 , 0.3333333 , 0.3333333 ]], dtype=float32),
 Array([[ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00],
        [ 5.9604645e-08,  5.9604645e-08,  5.9604645e-08],
        [ 1.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 6.6666669e-01, -3.3333331e-01, -3.3333331e-01]], dtype=float32),
 Array([[1.4142135e+00, 1.0323827e-07],
        [1.0000000e+00, 8.1649655e-01]], dtype=float32))

In [147]:
@jax.jit
def run(size, x):
    size = 4
    def while_body(val):
        ans, i = val
        return ans+x[i], i+1
    
    return jax.lax.while_loop(lambda val: x[val[1]] < size,
                       while_body,
                       (0, 0))

run(10, jnp.arange(5))

(Array(6, dtype=int32), Array(4, dtype=int32, weak_type=True))

In [144]:
@jax.jit
def run(x, y):
    def perpendicular_distance(x, y):
        y_norm = jnp.linalg.norm(y, axis=1)
        proj_len = x @ y.T / y_norm
        unit_vec = y / y_norm[:, None]
        proj_vec = jnp.reshape(proj_len, (proj_len.size, 1)) * jnp.tile(unit_vec, (len(x), 1))
        prep_vec = jnp.repeat(x, len(y), axis=0) - proj_vec
        dist = jnp.reshape(jnp.linalg.norm(prep_vec, axis=1), (len(x), len(y)))
        return dist
    
    dist = perpendicular_distance(x, y)
    pi = jnp.argmin(dist, axis=1)
    d = dist[jnp.arange(len(x)), pi]
    return dist, pi, d


x = jnp.array([[1, 1, 1],
               [1, 0, 0]])
y = jnp.array([[0, 0, 2],
               [1, 1, 1],
               [3, 0, 0],
               [1, 1, 0]])
run(x, y)

(Array([[1.4142135e+00, 1.0323827e-07, 1.4142135e+00, 1.0000000e+00],
        [1.0000000e+00, 8.1649655e-01, 0.0000000e+00, 7.0710677e-01]],      dtype=float32),
 Array([1, 2], dtype=int32),
 Array([1.0323827e-07, 0.0000000e+00], dtype=float32))

In [215]:
@jax.jit
def run(x, mask):
    key = jax.random.PRNGKey(x[0] * x[-1])
    temp = jax.random.randint(key, (1, len(x)), 0, len(x))
    temp = temp + mask * len(x)
    return jnp.argmax(temp)

x = jnp.array([1, 2, 3, 4, 5, 6, 7])
mask = x % 2 != 0

mask, run(x, mask)

(Array([ True, False,  True, False,  True, False,  True], dtype=bool),
 Array(0, dtype=int32))

In [179]:
import jax.random as random
key = random.PRNGKey(0) # generate a random key
random.randint(key, shape=(1, ), minval=0, maxval=10),  random.randint(key, shape=(3,), minval=0, maxval=10),  random.randint(key, shape=(3,), minval=0, maxval=10)


(Array([2], dtype=int32),
 Array([8, 1, 7], dtype=int32),
 Array([8, 1, 7], dtype=int32))

In [None]:
# weight = jnp.full((self.n_objs, self.n_objs), 1e-6) + jnp.eye(self.n_objs)
# asf = weight @ merged_fitness.T
# extreme = asf