In [1]:
import multiprocessing
import os

import neat

import torch
import numpy as np

import gym
import pybullet_envs
import neat
import pickle
import random
import signal
import matplotlib.pyplot as plt
import time

from pytorch_neat.recurrent_net import RecurrentNet
from pytorch_neat.multi_env_eval import MultiEnvEvaluator
from pytorch_neat.neat_reporter import LogReporter

# NOTE(review): batch_size is not referenced by run(); MultiEnvEvaluator is
# constructed without a batch_size argument. Confirm whether it should be passed.
batch_size: int = 4
# When True, the sequential (single-process) evaluator passes debug=True for
# every 100th genome.
DEBUG: bool = True

In [2]:
def make_env():
    """Create a fresh ``AntBulletEnv-v0`` environment.

    Passed to ``MultiEnvEvaluator`` as its ``make_env`` factory in ``run``.
    NOTE(review): the ``pybullet_envs`` import at the top of the notebook is
    presumably what registers this env id with gym.
    """
    env_id = 'AntBulletEnv-v0'
    env = gym.make(env_id)
    return env


def make_net(genome, config, bs):
    """Build a ``RecurrentNet`` from a NEAT genome and its config.

    ``bs`` is forwarded unchanged as the third argument of
    ``RecurrentNet.create`` (presumably the batch size; confirm against
    pytorch_neat).
    """
    net = RecurrentNet.create(genome, config, bs)
    return net


def activate_net(net, states):
    """Run ``net`` on ``states`` and return the result as a NumPy array.

    ``net.activate`` is expected to return a torch tensor; ``.numpy()``
    converts it without any further post-processing.
    """
    return net.activate(states).numpy()

def run(n_generations, n_processes):
    """Evolve NEAT controllers for AntBulletEnv-v0 and report the winner.

    Args:
        n_generations: Maximum number of generations to evolve; forwarded to
            ``Population.run`` as its generation limit.
        n_processes: Number of worker processes used to evaluate genomes.
            Values greater than 1 use a ``multiprocessing.Pool``; otherwise
            genomes are evaluated sequentially in this process.

    Returns:
        ``reporter.generation + 1`` from the ``StdOutReporter`` (presumably
        the count of generations run, since the reporter index is 0-based).

    Side effects:
        Creates ``saves/`` if missing and checkpoints there every generation,
        pickles the winning genome to ``winnerParaTorchAnt.pkl``, and prints
        progress plus the winner's final evaluation.
    """
    # The config path is resolved relative to the current working directory.
    # NOTE(review): the file is named for the humanoid task while make_env
    # builds AntBulletEnv-v0; confirm num_inputs/num_outputs match the Ant env.
    config_path = "data/config-humanoid"
    config = neat.Config(
        neat.DefaultGenome,
        neat.DefaultReproduction,
        neat.DefaultSpeciesSet,
        neat.DefaultStagnation,
        config_path,
    )

    evaluator = MultiEnvEvaluator(make_net, activate_net, make_env=make_env, max_env_steps=1000)

    pop = neat.Population(config)
    stats = neat.StatisticsReporter()
    pop.add_reporter(stats)
    reporter = neat.StdOutReporter(True)
    pop.add_reporter(reporter)
    # The Checkpointer writes into saves/; make sure the directory exists so
    # the first checkpoint does not fail on a fresh checkout.
    os.makedirs("saves", exist_ok=True)
    pop.add_reporter(neat.Checkpointer(1, None, "saves/robotic_pytorch_checkpoint_"))

    pool = None
    if n_processes > 1:
        pool = multiprocessing.Pool(processes=n_processes)

        def eval_genomes(genomes, config):
            # starmap returns results in input order, so they can be zipped
            # back onto the genomes they came from.
            fitnesses = pool.starmap(
                evaluator.eval_genome, ((genome, config) for _, genome in genomes)
            )
            for (_, genome), fitness in zip(genomes, fitnesses):
                genome.fitness = fitness

    else:

        def eval_genomes(genomes, config):
            for i, (_, genome) in enumerate(genomes):
                try:
                    genome.fitness = evaluator.eval_genome(
                        genome, config, debug=DEBUG and i % 100 == 0
                    )
                except Exception:
                    # Show which genome failed, then re-raise unchanged.
                    print(genome)
                    raise

    try:
        # Pass n_generations so evolution is bounded; without it neat keeps
        # running until the config's fitness threshold is reached.
        winner = pop.run(eval_genomes, n_generations)
    finally:
        # Release worker processes even if evolution raised, so they do not
        # linger in the notebook kernel.
        if pool is not None:
            pool.terminate()
            pool.join()

    with open('winnerParaTorchAnt.pkl', 'wb') as output:
        pickle.dump(winner, output, 1)

    print(winner)
    final_performance = evaluator.eval_genome(winner, config)
    print("Final performance: {}".format(final_performance))
    generations = reporter.generation + 1
    return generations

In [None]:
# Request 5 generations, evaluating genomes with 8 worker processes.
gens = run(n_generations=5, n_processes=8)
print(f'gens: {gens}')

In [3]:
# Reload a population checkpoint written by the Checkpointer configured in run().
# NOTE(review): restore_checkpoint deserializes the file with pickle, so only
# load checkpoints you produced yourself.
# NOTE(review): the generation number is hard-coded; confirm this file exists
# before a fresh Restart & Run All.
CHECKPOINT_PATH = "saves/robotic_pytorch_checkpoint_245"
popu = neat.Checkpointer.restore_checkpoint(CHECKPOINT_PATH)
print(popu)

<neat.population.Population object at 0x7f69b8b7c0b8>
