In [11]:
from __future__ import print_function
import os
import neat
import rubiks2
import numpy as np

env = None
difficulty = 14
def eval_genomes(genomes, config):
    global difficulty
    flag_increase = False
    for genome_id, genome in genomes:
        total_done = 0.0
        net = neat.nn.FeedForwardNetwork.create(genome, config)
        for i in range(100):
            state = env.reset(difficulty)
            
            local_done = 0.0
            for i in range(difficulty):
                action = np.argmax(net.activate(state))

                next_state, reward, done, info = env.step(int(action))
                state = next_state
                local_done += done

            total_done += local_done > 0.0
        genome.fitness = total_done/100.0
        if genome.fitness > 0.95:
            flag_increase = True
    if flag_increase:
        difficulty += 1
    print('Difficulty: ', difficulty)
    

def run(config_file):
    # Load configuration.
    config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
                         neat.DefaultSpeciesSet, neat.DefaultStagnation,
                         config_file)

    # Create the population, which is the top-level object for a NEAT run.
    p = neat.Population(config)

    # Add a stdout reporter to show progress in the terminal.
    p.add_reporter(neat.StdOutReporter(True))
    stats = neat.StatisticsReporter()
    p.add_reporter(stats)
    p.add_reporter(neat.Checkpointer(5))

    # Run for up to 300 generations.
    winner = p.run(eval_genomes, 1000)

    # Display the winning genome.
    print('\nBest genome:\n{!s}'.format(winner))

    # Show output of the most fit genome against training data.
    print('\nOutput:')
    winner_net = neat.nn.FeedForwardNetwork.create(winner, config)
    
    for i in range(6):
        env = rubiks2.RubiksEnv2(2, unsolved_reward = -1.0)
        env.step(int(i))
        state = env.get_observation()
        action = np.argmax(winner_net.activate(state))
        env.step(int(action))
        env.render()


if __name__ == '__main__':
    difficulty = 1
    env = rubiks2.RubiksEnv2(2, unsolved_reward = -1.0)
    # Determine path to configuration file. This path manipulation is
    # here so that the script will run successfully regardless of the
    # current working directory.
    local_dir = os.path.dirname('config-feedforward')
    config_path = os.path.join(local_dir, 'config-feedforward')
    run(config_path)


 ****** Running generation 0 ****** 

Difficulty:  1
Population's average fitness: 0.16440 stdev: 0.13781
Best fitness: 0.61000 - size: (6, 432) - species 1 - id 127
Average adjusted fitness: 0.164
Mean genetic distance 2.163, standard deviation 0.289
Population of 150 members in 1 species:
   ID   age  size  fitness  adj fit  stag
     1    0   150      0.6    0.164     0
Total extinctions: 0
Generation time: 4.266 sec

 ****** Running generation 1 ****** 

Difficulty:  1
Population's average fitness: 0.24580 stdev: 0.15179
Best fitness: 0.72000 - size: (6, 428) - species 1 - id 247
Average adjusted fitness: 0.246
Mean genetic distance 2.197, standard deviation 0.265
Population of 150 members in 1 species:
   ID   age  size  fitness  adj fit  stag
     1    1   150      0.7    0.246     0
Total extinctions: 0
Generation time: 4.425 sec (4.345 average)

 ****** Running generation 2 ****** 

Difficulty:  1
Population's average fitness: 0.30467 stdev: 0.14953
Best fitness: 0.67000 - siz

Difficulty:  2
Population's average fitness: 0.27207 stdev: 0.12909
Best fitness: 0.76000 - size: (6, 401) - species 1 - id 2405
Average adjusted fitness: 0.272
Mean genetic distance 1.891, standard deviation 0.302
Population of 150 members in 1 species:
   ID   age  size  fitness  adj fit  stag
     1   17   150      0.8    0.272    12
Total extinctions: 0
Generation time: 8.486 sec (8.624 average)

 ****** Running generation 18 ****** 

Difficulty:  2
Population's average fitness: 0.28253 stdev: 0.14339
Best fitness: 0.73000 - size: (6, 401) - species 1 - id 2405
Average adjusted fitness: 0.283
Mean genetic distance 1.896, standard deviation 0.239
Population of 150 members in 1 species:
   ID   age  size  fitness  adj fit  stag
     1   18   150      0.7    0.283    13
Total extinctions: 0
Generation time: 8.607 sec (8.702 average)

 ****** Running generation 19 ****** 

Difficulty:  2
Population's average fitness: 0.28460 stdev: 0.12609
Best fitness: 0.71000 - size: (6, 401) - speci

CompleteExtinctionException: 