Skip to content

Commit

Permalink
dev: fix tournament selection
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Dec 29, 2023
1 parent 3a791e8 commit d2340bf
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/evox/operators/selection/tournament.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

@partial(jit, static_argnums=[3, 4, 5])
def tournament_single_fit(key, pop, fit, n_round, tournament_func, tournament_size):
chosen = random.choice(key, n_round, shape=(n_round, tournament_size))
pop_size = fit.shape[0]
chosen = random.choice(key, pop_size, shape=(n_round, tournament_size))
candidates_fitness = fit[chosen, ...]
winner_indices = vmap(tournament_func)(candidates_fitness)
index = chosen[jnp.arange(n_round), winner_indices]
Expand All @@ -18,7 +19,8 @@ def tournament_single_fit(key, pop, fit, n_round, tournament_func, tournament_si

@partial(jit, static_argnums=[3, 4, 5])
def tournament_multi_fit(key, pop, fit, n_round, tournament_func, tournament_size):
chosen = random.choice(key, n_round, shape=(n_round, tournament_size))
pop_size = fit.shape[0]
chosen = random.choice(key, pop_size, shape=(n_round, tournament_size))
candidates_fitness = fit[chosen, ...]
winner_indices = vmap(jnp.lexsort)(jnp.transpose(candidates_fitness, (0, 2, 1)))
index = chosen[jnp.arange(n_round), winner_indices[:, 0]]
Expand Down

0 comments on commit d2340bf

Please sign in to comment.