In [1]:
import sys
sys.path.append('/home/jovyan/')
import numba
import copy
# from tests.utils.generate_genome import generate_genome_with_hidden_units


In [2]:
from itertools import product
import itertools
from neat.gene import NodeGene, ConnectionGene
from neat.genome import Genome
from config_files.configuration_utils import create_configuration



def generate_genome_with_hidden_units(n_input, n_output, n_hidden=3):
    # nodes
    node_genes = {}
    for i in range(n_output + n_hidden):
        node_genes = add_node(node_genes, key=i)

    # connections
    # input to hidden
    connection_genes = {}
    input_hidden_tuples = list(product(list(range(-1, -n_input-1, -1)),
                                       list(range(n_output, n_output+n_hidden))))
    for tuple_ in input_hidden_tuples:
        connection_genes = add_connection(connection_genes, key=tuple_)

    # hidden to output
    hidden_output_tuples = list(product(list(range(n_output, n_output + n_hidden)),
                                        list(range(0, n_output))))
    for tuple_ in hidden_output_tuples:
        connection_genes = add_connection(connection_genes, key=tuple_)

    # initialize genome
    genome = Genome(key=1)
    genome.node_genes = node_genes
    genome.connection_genes = connection_genes
    return genome



def add_node(node_genes, key):
    node_i = NodeGene(key=key)
    node_i.random_initialization()
    node_genes[key] = node_i
    return node_genes


def add_connection(connection_genes, key):
    connection_i = ConnectionGene(key=key)
    connection_i.random_initialization()
    connection_genes[key] = connection_i
    return connection_genes

In [3]:
def test_implementation():
    connections = [(1, 2), (2, 3), (3, 1)]
    if not exist_cycle_numba(connections):
        print(f'Error 1')
    

#     def test_exists_cycle_positive_b(self):
    connections = [(1, 2), (2, 3), (1, 4), (4, 3), (3, 1)]
    if not exist_cycle_numba(connections):
        print(f'Error 2')
#         self.assertEqual(True, exist_cycle_numba(connections))

#     def test_exists_cycle_when_negative_a(self):
    connections = [(1, 2), (2, 3), (3, 4)]
    if exist_cycle_numba(connections):
        print(f'Error 3')
#         self.assertEqual(False, exist_cycle_numba(connections))

#     def test_exists_cycle_when_negative_b(self):
    connections = [(1, 2), (2, 3), (1, 4), (4, 3), (3, 5)]
    if exist_cycle_numba(connections):
        print(f'Error 4')
#         self.assertEqual(False, exist_cycle_numba(connections))

#     def test_exists_cycle_when_negative_c(self):
    connections = [(-1, 1), (-1, 2), (-1, 3), (-2, 1), (-2, 2), (-2, 3), (1, 0), (2, 0), (3, 0), (1, 2)]
    if exist_cycle_numba(connections):
        print(f'Error 5')
#         self.assertEqual(False, exist_cycle_numba(connections))

#     def test_self_recursive(self):
    connections = [(1, 1)]
    if not exist_cycle_numba(connections):
        print(f'Error 6')
#         self.assertEqual(True, exist_cycle_numba(connections))

#     def test_self_recursive_b(self):
    connections = [(-1, 1), (-1, 2), (-1, 3), (-2, 1),  (-2, 2), (-2, 3), (1, 0), (2, 0), (3, 0), (2, 2)]
    if not exist_cycle_numba(connections):
        print(f'Error 7')
#         self.assertEqual(True, exist_cycle_numba(connections))

In [14]:
# @numba.jit(nopython=True)
# @numba.jit
def _go_through_graph(node_in, graph, past=[]):
    if node_in in graph.keys():
        for node in graph[node_in]:
            if node in past:
                return True
            else:
                past_copy = copy.deepcopy(past)
                past_copy.append(node_in)
                if _go_through_graph(node_in=node, graph=graph, past=past_copy):
                    return True
    else:
        return False
        

def exist_cycle_numba(connections: list) -> bool:
    # change data structure
    con = _get_connections_per_node(connections)
#     print(con)
    for node_in, nodes_out in con.items():
        if _go_through_graph(node_in, graph=con, past=[]):
            return True
    return False


def _get_connections_per_node(connections: list, inverse_order=False):
    '''
    :param connections: eg. ((-1, 1), (1, 2), (2, 3), (2, 4))
    :param inverse_order: whether it follows the input to output direction or the output to input direction
    :return: {-1: [1], 1: [2], 2: [3, 4]
    '''
    con = {}
    for connection in connections:
        input_node_key, output_node_key = connection
        if inverse_order:
            output_node_key, input_node_key = connection
        if input_node_key in con:
            con[input_node_key].append(output_node_key)
        else:
            con[input_node_key] = [output_node_key]
    return con

In [15]:
%timeit test_implementation()

137 µs ± 1.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [16]:
def _remove_connection_that_introduces_cycles(genome: Genome, possible_connection_set: set) -> set:
    connections_to_remove = []
    for connection in possible_connection_set:
        connections = list(genome.connection_genes.keys()) + [connection]

        if exist_cycle_numba(connections=connections):
            connections_to_remove.append(connection)
#     logger.debug(f'connections that introduce cycles: {connections_to_remove}')
    possible_connection_set -= set(connections_to_remove)
    return possible_connection_set

In [17]:
config = create_configuration(filename='/regression-miso.json')
out_node_key = 5
genome = generate_genome_with_hidden_units(n_input=5, n_output=10, n_hidden=100)

# all nodes
possible_input_keys_set = set(genome.node_genes.keys()).union(set(genome.get_input_nodes_keys()))

# no connection between two output nodes
possible_input_keys_set -= set(genome.get_output_nodes_keys())

if config.feed_forward:
    # avoid self-recurrency
    possible_input_keys_set -= {out_node_key}
    # pass

# REMOVE POSSIBLE CONNECTIONS
possible_connection_set = set(itertools.product(list(possible_input_keys_set), [out_node_key]))

# remove already existing connections: don't duplicate connections
possible_connection_set -= set(genome.connection_genes.keys())

In [19]:
%timeit _remove_connection_that_introduces_cycles(genome, possible_connection_set)

172 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
