In [1]:
import jax.numpy as jnp

from pipeline import Pipeline
from algorithm.neat import *

from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
from utils import Act, Agg

if __name__ == "__main__":
    # Build the configuration bottom-up (innermost component first) so each
    # piece can be read on its own; construction order matches the nested form.

    # Node genes: a default activation/aggregation plus the option sets
    # handed to the gene alongside those defaults.
    node_gene_cfg = DefaultNodeGene(
        activation_default=Act.sigmoid,
        activation_options=(Act.sigmoid, Act.relu, Act.tanh, Act.identity, Act.inv),
        aggregation_default=Agg.sum,
        aggregation_options=(Agg.sum, Agg.mean, Agg.max, Agg.product),
    )

    # Structural mutation settings for adding nodes and connections.
    mutation_cfg = DefaultMutation(node_add=0.03, conn_add=0.03)

    # 16 inputs / 4 outputs: the evaluation code further down feeds the
    # flattened 4x4 board and masks the scores with a 4-entry action mask.
    genome_cfg = DefaultGenome(
        num_inputs=16,
        num_outputs=4,
        max_nodes=100,
        max_conns=1000,
        node_gene=node_gene_cfg,
        mutation=mutation_cfg,
    )

    species_cfg = DefaultSpecies(
        genome=genome_cfg,
        pop_size=10000,
        species_size=100,
        survival_threshold=0.01,
    )

    neat_algorithm = NEAT(species=species_cfg)
    game_2048 = Jumanji_2048(max_step=1000)

    pipeline = Pipeline(
        algorithm=neat_algorithm,
        problem=game_2048,
        generation_limit=10000,
        fitness_target=13000,
    )

    # Create the initial state, then let the pipeline run until it terminates
    # on its own; `best` is what auto_run hands back next to the final state.
    state = pipeline.setup()
    state, best = pipeline.auto_run(state)

initializing
initializing finished
start compile
compile finished, cost time: 18.307454s
Generation: 1.0, Cost time: 4551.03ms
 	node counts: max: 21, min: 21, mean: 21.00
 	conn counts: max: 20, min: 20, mean: 20.00
 	species: 1, [10000]
 	fitness: valid cnt: 10000, max: 10124.0000, min: 44.0000, mean: 1758.1263, std: 1212.6823
Generation: 2.0, Cost time: 4636.33ms
 	node counts: max: 22, min: 21, mean: 21.03
 	conn counts: max: 22, min: 20, mean: 20.05
 	species: 1, [10000]
 	fitness: valid cnt: 10000, max: 11000.0000, min: 48.0000, mean: 1870.1300, std: 1263.3086
Generation: 3.0, Cost time: 6271.12ms
 	node counts: max: 23, min: 21, mean: 21.03
 	conn counts: max: 22, min: 20, mean: 20.05
 	species: 1, [10000]
 	fitness: valid cnt: 10000, max: 14624.0000, min: 28.0000, mean: 1943.9924, std: 1293.7146

Fitness limit reached!


In [3]:
# Keep a handle on the genome object owned by the pipeline's algorithm; the
# following cells call its `transform` and `forward` methods directly.
genome = pipeline.algorithm.genome

In [4]:
# Pre-process the best individual once so it can be reused for every forward
# pass. `best` is unpacked into positional arguments -- presumably the
# (nodes, conns) arrays returned by `pipeline.auto_run`; TODO confirm.
transformed = genome.transform(state, *best)

In [5]:
def policy(board):
    """Run the evolved network on one observation and return its raw output.

    Args:
        board: Network input, forwarded unchanged to ``genome.forward``.
            The evaluation loop in this notebook passes the game board
            flattened with ``reshape(-1)``.

    Returns:
        Whatever ``genome.forward`` returns for this input; the caller
        treats it as one score per action and applies the action mask itself.

    Note:
        ``genome``, ``state`` and ``transformed`` are module-level names that
        are looked up when the function is called (or traced by ``jax.jit``),
        not when it is defined. Rebinding any of them beforehand changes what
        is passed to ``genome.forward``.
    """
    return genome.forward(state, transformed, board)

In [11]:
import jax
import jumanji

# Play one full game of 2048 with the evolved policy and report the return.
env = jumanji.make("Game2048-v1")
key = jax.random.PRNGKey(48)  # fixed seed so the rollout is reproducible
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_policy = jax.jit(policy)

# The environment state gets its own name on purpose: `policy` reads the
# module-level `state` (the NEAT state returned by `pipeline.auto_run`) when
# it is traced, so rebinding `state` here would hand the wrong object to
# `genome.forward` and would also break re-running the `genome.transform` cell.
env_state, timestep = jit_reset(key)
total_reward = 0
while True:
    board, action_mask = timestep["observation"]
    # The network takes the board as a flat vector.
    action_scores = jit_policy(board.reshape(-1))
    # Illegal moves get -inf so argmax can only pick a legal action.
    masked_scores = jnp.where(action_mask, action_scores, -jnp.inf)
    action = jnp.argmax(masked_scores)
    env_state, timestep = jit_step(env_state, action)
    # The game is over once no action is legal in the new observation.
    done = jnp.all(~timestep["observation"][1])
    print(timestep)
    total_reward += timestep["reward"]
    if done:
        break
print(total_reward)

TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0]], dtype=int32), action_mask=Array([False,  True,  True,  True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})
TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 2],
       [0, 1, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True,  True,  True,  True], dtype=bool)), extras={'highest_tile': Array(4, dtype=int32)})
TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 2],
       [0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0]], dtype=int32), action_mask=Array([False,  True,  True,  True], dtype=bool)), extr