In [18]:
import jax
from jax import numpy as jnp

x = jnp.array([1, 2, 3, 4, 5])
mask = x <= 2
jnp.where(mask, x, 1)

Array([1, 2, 1, 1, 1], dtype=int32)

In [59]:
jnp.nanmax(jnp.array([jnp.nan, jnp.nan]))

Array(nan, dtype=float32)

In [58]:
size = 4
n = 3
weight = jnp.eye(n, n) + 1e-6
fitness = jnp.array([[0, 2, 7],
                     [2, 2, 2],
                     [4, 1, 0],
                     [2, 5, 1]])
ideal = jnp.nanmin(fitness, axis=0)
offset_fitness = fitness-ideal

print(offset_fitness)

asf = jnp.repeat(offset_fitness, n, axis=0).reshape(size, n, n) / weight

print(asf)

asf = jnp.max(asf, axis=2)

ex_idx = jnp.argmin(asf, axis=0)

print(asf)

print(ex_idx)


[[0 1 7]
 [2 1 2]
 [4 0 0]
 [2 4 1]]
[[[0.0000000e+00 1.0000000e+06 7.0000000e+06]
  [0.0000000e+00 9.9999905e-01 7.0000000e+06]
  [0.0000000e+00 1.0000000e+06 6.9999933e+00]]

 [[1.9999981e+00 1.0000000e+06 2.0000000e+06]
  [2.0000000e+06 9.9999905e-01 2.0000000e+06]
  [2.0000000e+06 1.0000000e+06 1.9999981e+00]]

 [[3.9999962e+00 0.0000000e+00 0.0000000e+00]
  [4.0000000e+06 0.0000000e+00 0.0000000e+00]
  [4.0000000e+06 0.0000000e+00 0.0000000e+00]]

 [[1.9999981e+00 4.0000000e+06 1.0000000e+06]
  [2.0000000e+06 3.9999962e+00 1.0000000e+06]
  [2.0000000e+06 4.0000000e+06 9.9999905e-01]]]
[[7.0000000e+06 7.0000000e+06 1.0000000e+06]
 [2.0000000e+06 2.0000000e+06 2.0000000e+06]
 [3.9999962e+00 4.0000000e+06 4.0000000e+06]
 [4.0000000e+06 2.0000000e+06 4.0000000e+06]]
[2 1 0]


In [99]:
import jax
from functools import partial
from jax import numpy as jnp

@partial(jax.jit, static_argnames=('pop_size', 'n_objs'))
def run(pop_size, n_objs, ref, population, next_generation, self_fitness, fitness, rank):
    merged_pop = jnp.concatenate([population, next_generation], axis=0)
    merged_fitness = jnp.concatenate([self_fitness, fitness], axis=0)

    order = jnp.argsort(rank)
    rank = rank[order]
    ranked_pop = merged_pop[order]
    ranked_fitness = merged_fitness[order]
    last_rank = rank[pop_size]
    ranked_fitness = jnp.where(jnp.repeat((rank <= last_rank)[:, None], n_objs, axis=1), ranked_fitness, jnp.nan)
    
    # Normalize
    ideal = jnp.nanmin(ranked_fitness, axis=0)
    offset_fitness = ranked_fitness - ideal
    weight = jnp.eye(n_objs, n_objs) + 1e-6
    weighted = jnp.repeat(offset_fitness, n_objs, axis=0).reshape(len(offset_fitness), n_objs, n_objs) / weight
    asf = jnp.nanmax(weighted, axis=2)
    ex_idx =jnp.argmin(asf, axis=0)
    extreme = offset_fitness[ex_idx]
    
    def extreme_point(val):
        extreme = val[0]
        plane = jnp.linalg.solve(extreme, jnp.ones(n_objs))
        intercept = 1/ plane
        return intercept
    
    def worst_point(val):
        return jnp.nanmax(ranked_fitness, axis=0)
    
    nadir_point = jax.lax.cond(jnp.linalg.matrix_rank(extreme) == n_objs,
                                extreme_point, worst_point,
                                (extreme, offset_fitness))
    normalized_fitness = offset_fitness / nadir_point
    
    # Associate
    def perpendicular_distance(x, y):
        y_norm = jnp.linalg.norm(y, axis=1)
        proj_len = x @ y.T / y_norm
        unit_vec = y / y_norm[:, None]
        proj_vec = jnp.reshape(proj_len, (proj_len.size, 1)) * jnp.tile(unit_vec, (len(x), 1))
        prep_vec = jnp.repeat(x, len(y), axis=0) - proj_vec
        dist = jnp.reshape(jnp.linalg.norm(prep_vec, axis=1), (len(x), len(y)))
        return dist
    
    dist = perpendicular_distance(ranked_fitness, ref)
    pi = jnp.nanargmin(dist, axis=1)
    d = dist[jnp.arange(len(normalized_fitness)), pi]
    
    # Niche
    def niche_loop(val):
        def nope(val):
            idx, i, rho, j = val
            rho = rho.at[j].set(pop_size)
            return idx, i, rho, j
        
        def have(val):
            def zero(val):
                idx, i, rho, j = val
                idx = idx.at[i].set(jnp.nanargmin(jnp.where(pi == j, d, jnp.nan)))
                rho = rho.at[j].add(1)
                return idx, i+1, rho, j
            
            def already(val):
                idx, i, rho, j = val
                key = jax.random.PRNGKey(i * j)
                temp = jax.random.randint(key, (1, len(ranked_pop)), 0, pop_size)
                temp = temp + (pi == j) * pop_size
                idx = idx.at[i].set(jnp.argmax(temp))
                rho = rho.at[j].add(1)
                return idx, i+1, rho, j
            
            return jax.lax.cond(rho[val[3]], already, zero, val)
        
        idx, i, rho = val
        j = jnp.argmin(rho)
        idx, i, rho, j = jax.lax.cond(jnp.sum(pi == j), have, nope, (idx, i, rho, j))
        return idx, i, rho
    
    survivor_idx = jnp.arange(pop_size)
    rho = jnp.bincount(jnp.where(rank < last_rank, pi, len(ref)), length=len(ref))
    pi = jnp.where(rank == last_rank, pi, -1)
    d = jnp.where(rank == last_rank, d, jnp.nan)
    survivor_idx, _, _ = jax.lax.while_loop(lambda val: val[1] < pop_size,
                                      niche_loop,
                                      (survivor_idx, jnp.sum(rho), rho))
    # return survivor_idx, survivor_idx
    return ranked_pop[survivor_idx], ranked_fitness[survivor_idx]

pop_size = 4
n_objs = 3
ref = jnp.array([[1, 0, 0],
                 [0, 1, 0],
                 [0, 0, 1],
                 [1, 1, 1]])
population = jnp.array([[0, 0],
                        [1, 1],
                        [2, 2],
                        [3, 3]])
next_generation = jnp.array([[4, 4],
                             [5, 5],
                             [6, 6],
                             [7, 7]])
self_fitness = jnp.array([[1, 2, 3],
                          [7, 4, 2],
                          [5, 4, 8],
                          [2, 9, 6]])
fitness = jnp.array([[5, 6, 7],
                     [3, 7, 1],
                     [4, 3, 7],
                     [8, 8, 6]])
rank = jnp.array([0, 1, 2, 1, 2, 0, 1, 2])

run(pop_size, n_objs, ref, population, next_generation, self_fitness, fitness, rank)

(Array([[0, 0],
        [5, 5],
        [1, 1],
        [6, 6]], dtype=int32),
 Array([[1., 2., 3.],
        [3., 7., 1.],
        [7., 4., 2.],
        [4., 3., 7.]], dtype=float32, weak_type=True))

In [90]:
@jax.jit
def run(x):
    k = jnp.min(x)
    mask = x == k
    x = jnp.where(mask, x, jnp.nan)
    return x
    
x = jnp.array([[3, 3, 3],
               [2, 2, 2],
               [0, 0, 0],
               [1, 1, 1]])
run(x)

Array([[nan, nan, nan],
       [nan, nan, nan],
       [ 0.,  0.,  0.],
       [nan, nan, nan]], dtype=float32, weak_type=True)

In [130]:
mat = jnp.array([[1., 2., 4.],
                 [1.0000001, 2, 4.],
                 [3, 3, 3]])

jnp.linalg.matrix_rank(mat), jnp.linalg.det(mat)

(Array(2, dtype=int32), Array(-9.536743e-07, dtype=float32))

In [137]:
x = jnp.array([[1, 1, 1],
               [1, 0, 0]])
y = jnp.array([[0, 0, 2],
                [1, 1, 1]])
y_norm = jnp.linalg.norm(y, axis=1)

proj_len = x @ y.T / y_norm
unit_vec = y / y_norm[:, None]
proj_vec = jnp.reshape(proj_len, (proj_len.size, 1)) * jnp.tile(unit_vec, (len(x), 1))
h = jnp.repeat(x, len(y), axis=0) - proj_vec
d = jnp.reshape(jnp.linalg.norm(h, axis=1), (len(x), len(y)))

proj_len, unit_vec, proj_vec, h, d

(Array([[1.        , 1.7320508 ],
        [0.        , 0.57735026]], dtype=float32),
 Array([[0.        , 0.        , 1.        ],
        [0.57735026, 0.57735026, 0.57735026]], dtype=float32),
 Array([[0.        , 0.        , 1.        ],
        [0.99999994, 0.99999994, 0.99999994],
        [0.        , 0.        , 0.        ],
        [0.3333333 , 0.3333333 , 0.3333333 ]], dtype=float32),
 Array([[ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00],
        [ 5.9604645e-08,  5.9604645e-08,  5.9604645e-08],
        [ 1.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 6.6666669e-01, -3.3333331e-01, -3.3333331e-01]], dtype=float32),
 Array([[1.4142135e+00, 1.0323827e-07],
        [1.0000000e+00, 8.1649655e-01]], dtype=float32))

In [147]:
@jax.jit
def run(size, x):
    size = 4
    def while_body(val):
        ans, i = val
        return ans+x[i], i+1
    
    return jax.lax.while_loop(lambda val: x[val[1]] < size,
                       while_body,
                       (0, 0))

run(10, jnp.arange(5))

(Array(6, dtype=int32), Array(4, dtype=int32, weak_type=True))

In [144]:
@jax.jit
def run(x, y):
    def perpendicular_distance(x, y):
        y_norm = jnp.linalg.norm(y, axis=1)
        proj_len = x @ y.T / y_norm
        unit_vec = y / y_norm[:, None]
        proj_vec = jnp.reshape(proj_len, (proj_len.size, 1)) * jnp.tile(unit_vec, (len(x), 1))
        prep_vec = jnp.repeat(x, len(y), axis=0) - proj_vec
        dist = jnp.reshape(jnp.linalg.norm(prep_vec, axis=1), (len(x), len(y)))
        return dist
    
    dist = perpendicular_distance(x, y)
    pi = jnp.argmin(dist, axis=1)
    d = dist[jnp.arange(len(x)), pi]
    return dist, pi, d


x = jnp.array([[1, 1, 1],
               [1, 0, 0]])
y = jnp.array([[0, 0, 2],
               [1, 1, 1],
               [3, 0, 0],
               [1, 1, 0]])
run(x, y)

(Array([[1.4142135e+00, 1.0323827e-07, 1.4142135e+00, 1.0000000e+00],
        [1.0000000e+00, 8.1649655e-01, 0.0000000e+00, 7.0710677e-01]],      dtype=float32),
 Array([1, 2], dtype=int32),
 Array([1.0323827e-07, 0.0000000e+00], dtype=float32))

In [215]:
@jax.jit
def run(x, mask):
    key = jax.random.PRNGKey(x[0] * x[-1])
    temp = jax.random.randint(key, (1, len(x)), 0, len(x))
    temp = temp + mask * len(x)
    return jnp.argmax(temp)

x = jnp.array([1, 2, 3, 4, 5, 6, 7])
mask = x % 2 != 0

mask, run(x, mask)

(Array([ True, False,  True, False,  True, False,  True], dtype=bool),
 Array(0, dtype=int32))

In [179]:
import jax.random as random
key = random.PRNGKey(0) # generate a random key
random.randint(key, shape=(1, ), minval=0, maxval=10),  random.randint(key, shape=(3,), minval=0, maxval=10),  random.randint(key, shape=(3,), minval=0, maxval=10)


(Array([2], dtype=int32),
 Array([8, 1, 7], dtype=int32),
 Array([8, 1, 7], dtype=int32))

In [None]:
# weight = jnp.full((self.n_objs, self.n_objs), 1e-6) + jnp.eye(self.n_objs)
# asf = weight @ merged_fitness.T
# extreme = asf