In [6]:
import jax.numpy as jnp

from pipeline import Pipeline
from algorithm.neat import *
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse

from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
from tensorneat.utils import Act, Agg

# Build the NEAT configuration bottom-up (genes -> genome -> species -> pipeline)
# so each layer of hyper-parameters can be read on its own.

# Node genes: one of four activations, always sum-aggregated.
node_gene_cfg = NodeGeneWithoutResponse(
    activation_default=Act.sigmoid,
    activation_options=(Act.sigmoid, Act.relu, Act.tanh, Act.identity),
    aggregation_default=Agg.sum,
    aggregation_options=(Agg.sum,),
    activation_replace_rate=0.02,
    aggregation_replace_rate=0.02,
    bias_mutate_rate=0.03,
    bias_init_std=0.5,
    bias_mutate_power=0.2,
    bias_replace_rate=0.01,
)

# Connection genes: weight mutation settings.
conn_gene_cfg = DefaultConnGene(
    weight_mutate_rate=0.015,
    weight_replace_rate=0.003,
    weight_mutate_power=0.5,
)

# Structural mutation settings.
mutation_cfg = DefaultMutation(node_add=0.1, conn_add=0.2, conn_delete=0.2)

# 16 inputs matches the flattened 4x4 board fed to the network by the rollout
# cells; 4 outputs matches the 4-entry action mask of the environment.
genome_cfg = DefaultGenome(
    num_inputs=16,
    num_outputs=4,
    max_nodes=100,
    max_conns=1000,
    node_gene=node_gene_cfg,
    conn_gene=conn_gene_cfg,
    mutation=mutation_cfg,
)

species_cfg = DefaultSpecies(
    genome=genome_cfg,
    pop_size=1000,
    species_size=5,
    survival_threshold=0.1,
    max_stagnation=7,
    genome_elitism=3,
    compatibility_threshold=1.2,
)

pipeline = Pipeline(
    algorithm=NEAT(species=species_cfg),
    problem=Jumanji_2048(max_step=10000, repeat_times=5),
    generation_limit=100,
    fitness_target=13000,
    save_path="2048.pkl",
)

# Only the initial state is created here; no training is run in this cell.
state = pipeline.setup()

initializing
initializing finished


In [7]:
import numpy as np

# Load the saved genome arrays. np.load on an .npz returns a lazy NpzFile that
# keeps the archive open, so read both arrays inside a `with` block to make
# sure the file handle is released once they are in memory.
with np.load('2048.npz') as data:
    nodes, conns = data['nodes'], data['conns']

In [8]:
# Fetch the genome definition that the pipeline was configured with, then
# convert the loaded (nodes, conns) arrays into the representation that
# `genome.forward` is later called with.
neat_algorithm = pipeline.algorithm
genome = neat_algorithm.species.genome
transformed = genome.transform(state, nodes, conns)

In [9]:
# Snapshot the algorithm state under a dedicated name. Later cells rebind the
# global `state` (to the Jumanji environment state and to the problem state)
# before the jitted policy is first traced, so looking `state` up lazily inside
# the function would hand `genome.forward` the wrong object.
neat_state = state


def policy(board):
    """Score the available moves for one board using the loaded genome.

    Args:
        board: network input; the rollout cell passes the board flattened with
            ``reshape(-1)``.

    Returns:
        The output of ``genome.forward`` for this input (one score per action;
        the genome is configured with ``num_outputs=4``).
    """
    return genome.forward(neat_state, transformed, board)

In [14]:
import jax, jumanji

env = jumanji.make("Game2048-v1")
key = jax.random.PRNGKey(0)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_policy = jax.jit(policy)

# Keep the environment state under its own name: rebinding the global `state`
# here would replace the algorithm state that `policy` reads when it is traced.
env_state, timestep = jit_reset(key)
total_reward = 0
while True:
    board, action_mask = timestep["observation"]
    action_scores = jit_policy(board.reshape(-1))
    # Invalid moves get -inf so argmax can only pick a legal action.
    score_with_mask = jnp.where(action_mask, action_scores, -jnp.inf)
    action = jnp.argmax(score_with_mask)
    env_state, timestep = jit_step(env_state, action)
    # The episode is treated as over once no action is legal any more.
    done = jnp.all(~timestep["observation"][1])
    print(timestep)
    total_reward += timestep["reward"]
    if done:
        break
print(total_reward)

TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 1, 0],
       [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True,  True,  True,  True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})
TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0],
       [1, 0, 1, 0]], dtype=int32), action_mask=Array([ True,  True,  True,  True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})
TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 1],
       [1, 1, 1, 0]], dtype=int32), action_mask=Array([ True,  True,  True,  True], dtype=bool)), extr

In [17]:
import jax.random
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048


def random_policy(state, params, obs):
    """Baseline policy: pseudo-random scores for the four actions.

    The PRNG key is seeded with the sum of the observation, so the scores are
    a deterministic function of ``obs``. ``state`` and ``params`` are accepted
    to match the policy call signature but are not used.

    Returns:
        An array of shape ``(4,)`` drawn from a standard normal distribution.
    """
    seed = obs.sum()
    return jax.random.normal(jax.random.key(seed), shape=(4,))

# Stand-alone problem instance used to score `random_policy` as a baseline.
problem = Jumanji_2048(
    max_step=10000,
    repeat_times=10,
    guarantee_invalid_action=True,
)
state = problem.setup()


def _evaluate_random_policy(state, randkey):
    """Evaluate `random_policy` on `problem`; the policy takes no parameters, hence None."""
    return problem.evaluate(state, randkey, random_policy, None)


jit_evaluate = jax.jit(_evaluate_random_policy)

In [24]:

# `randkey` was only ever defined by a later cell, so this cell raised
# NameError on a fresh kernel. Define it here so the notebook runs top to bottom.
# NOTE(review): the seed behind the recorded output is unknown; 14 mirrors the
# key used by the rollout cell further down. Confirm or adjust.
randkey = jax.random.PRNGKey(14)
reward = jit_evaluate(state, randkey)
print(reward)

1193.2001


In [34]:
randkey = jax.random.PRNGKey(14)
# Separate name so the jitted NEAT policy bound to `jit_policy` is not shadowed.
jit_random_policy = jax.jit(random_policy)
total_reward = 0
# Reuse the already-jitted reset, and keep the environment state apart from the
# global `state`, which at this point holds the problem state `jit_evaluate` needs.
env_state, timestep = jit_reset(randkey)
while True:
    board, action_mask = timestep["observation"]
    action_scores = jit_random_policy(None, None, board.reshape(-1))
    # Invalid moves get -inf so argmax can only pick a legal action.
    score_with_mask = jnp.where(action_mask, action_scores, -jnp.inf)
    action = jnp.argmax(score_with_mask)
    env_state, timestep = jit_step(env_state, action)
    # The episode is treated as over once no action is legal any more.
    done = jnp.all(~timestep["observation"][1])
    print(timestep)
    total_reward += timestep["reward"]
    if done:
        break
print(total_reward)

TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 1, 0, 1]], dtype=int32), action_mask=Array([ True,  True, False,  True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})
TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],
       [1, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True,  True,  True,  True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})
TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0],
       [1, 1, 1, 1]], dtype=int32), action_mask=Array([ True,  True, False,  True], dtype=bool)), extr