In [119]:
from jax import random
import jax.numpy as jnp

In [120]:
def z(X, Y):
    return -X * jnp.sin(jnp.sqrt(jnp.abs(X))) - Y * jnp.sin(jnp.sqrt(jnp.abs(Y)))

def r1(X, Y):
     return jnp.power(Y - jnp.power(X, 2), 2) + jnp.power(1 - X, 2)
    
def rd(X, Y):
    return 1 + r1(X, Y)

In [121]:
def w37(X, Y):
    return 2 * z(X, Y) / rd(X, Y) + jnp.exp(jnp.sin(r1(X, Y)))

In [122]:
n_solutions = 100
key = random.PRNGKey(123)

solutions = random.uniform(key, 
                           shape=(n_solutions, 2), 
                           minval=0, 
                           maxval=10000
                          )

In [123]:
function_values = w37(solutions[:, 0], solutions[:, 1])
ranking = jnp.argsort(function_values)
ranked_solutions = solutions[ranking]

In [124]:
def select(solutions, function, select_n=50):
    """Roulette selection for minimization objective."""
    function_values = function(solutions[:, 0], solutions[:, 1])
    ranking = jnp.argsort(function_values) # increasing order
    ranked_solutions = solutions[ranking]

    normalized_vals = 1 - function_values / function_values.max()
    probabilities = normalized_vals / normalized_vals.sum()

    best_solutions = random.choice(key,
                                   ranked_solutions,
                                   shape=(select_n, 1),
                                   replace=False,
                                   p=probabilities,
                                  )
    best_solutions = jnp.squeeze(best_solutions)
    
    return best_solutions

In [125]:
select(solutions, w37)

Array([[7658.21   , 3157.3057 ],
       [3895.564  , 6025.1465 ],
       [ 572.1009 , 5664.7397 ],
       [ 108.31356,  158.4649 ],
       [6918.415  , 2089.776  ],
       [1080.3342 , 1526.9185 ],
       [3775.5789 , 2050.4963 ],
       [4155.166  , 9554.985  ],
       [2019.0657 , 4163.854  ],
       [4952.711  , 2390.3489 ],
       [9506.117  ,  593.81964],
       [4439.764  ,  590.1802 ],
       [8011.371  , 7690.2485 ],
       [4779.3867 , 7992.337  ],
       [5446.888  , 4574.9355 ],
       [7156.154  , 9973.806  ],
       [3705.207  , 2204.379  ],
       [1493.9058 ,  125.29135],
       [8044.618  , 9929.033  ],
       [9290.986  , 3344.8755 ],
       [6191.2896 , 6918.5493 ],
       [ 616.00806, 5295.392  ],
       [9331.724  , 4821.674  ],
       [5311.6963 , 8604.162  ],
       [9714.916  , 4077.934  ],
       [4061.9863 , 9442.0205 ],
       [3634.882  , 6205.1416 ],
       [2520.8604 , 7921.7803 ],
       [5731.71   , 6610.215  ],
       [1334.0557 , 9797.383  ],
       [58

In [126]:
best_solutions = select(solutions, w37)
random.permutation(key,
                   best_solutions,
                   independent=True
                  )

Array([[ 616.00806,  779.8314 ],
       [8618.488  , 8604.162  ],
       [7156.154  , 3212.092  ],
       [3669.1057 , 3256.5762 ],
       [4155.166  , 3798.0188 ],
       [4952.711  , 3344.8755 ],
       [8044.618  , 7690.2485 ],
       [4249.8086 ,  125.29135],
       [9331.724  , 2390.3489 ],
       [3552.196  , 4574.9355 ],
       [5311.6963 , 1253.2473 ],
       [ 108.31356, 9715.362  ],
       [5964.1646 , 1905.8455 ],
       [ 975.2405 , 3157.3057 ],
       [6046.6494 , 3128.83   ],
       [6356.5884 , 6488.6377 ],
       [3705.207  , 1190.0997 ],
       [2152.09   ,  158.4649 ],
       [6918.415  , 1526.9185 ],
       [8011.371  , 2050.4963 ],
       [2019.0657 , 7992.337  ],
       [7847.7026 , 9929.033  ],
       [5893.5786 , 2204.379  ],
       [4439.764  , 6540.786  ],
       [9506.117  , 5349.991  ],
       [1080.3342 , 2701.3076 ],
       [1493.9058 , 2089.776  ],
       [8911.884  , 4821.674  ],
       [7658.21   , 7279.841  ],
       [6191.2896 , 5357.268  ],
       [97

In [127]:
def crossover(solutions, generation_size=100):
    permutations = jnp.array(jnp.meshgrid(solutions[:, 0], solutions[:, 1])).T.reshape(-1, 2)
    new_generation = random.choice(key,
                                   permutations,
                                   shape=(100, 1),
                                   replace=False,
                                  )
    new_generation = jnp.squeeze(new_generation)
    return new_generation

In [128]:
solutions[:, 0].shape

(100,)

In [129]:
permutations = jnp.array(jnp.meshgrid(solutions[:, 0], solutions[:, 1])).T.reshape(-1, 2)
new_generation = random.choice(key,
                               permutations,
                               shape=(100, 1),
                               replace=False,
                              )
new_generation = jnp.squeeze(new_generation)

In [130]:
def mutate(solutions, probability=0.01):
    """Implements mutations to create a new generation."""
    probabilities = random.uniform(key, shape=(solutions.shape[0], 2))
    mutated_solutions = solutions * -1
    new_solutions = jnp.where(probabilities > probability, 
                                  solutions, 
                                  mutated_solutions)
    return new_solutions

In [143]:
def evolution(objective,
              generation_size=1000,
              limit_generations=5, 
              mutation_p=0.01,
              select_n=50,
              move_to_next=2, 
              seed=123
             ):
    # populate
    key = random.PRNGKey(seed)
    solutions = random.uniform(key, 
                               shape=(generation_size, 2), 
                               minval=0, 
                               maxval=10000
                          )
    
    for generation in range(limit_generations):
        selected_solutions = select(solutions, objective)
        
        # best_solution = objective(selected_solutions[0, 0], selected_solutions[0, 1])
        # if best_solution >= fitness_limit:
        #     break
        
        best_parents = selected_solutions[:move_to_next]
        
        offspring = crossover(selected_solutions, generation_size - move_to_next)
        new_generation = mutate(offspring, mutation_p)
        
        solutions = jnp.append(new_generation, best_parents, axis=0)
        
    function_values = objective(solutions[:, 0], solutions[:, 1])
    ranking = jnp.argsort(function_values) # increasing order
    ranked_solutions = solutions[ranking] 
    
    return ranked_solutions[0]

In [142]:
evolution(w37)

(102, 2)
(102, 2)
(102, 2)
(102, 2)
(102, 2)


Array([ 757.7789, 4222.5063], dtype=float32)