Skip to content

Commit

Permalink
Merge pull request #110 from jakobj/maint/ordered-primitives
Browse files Browse the repository at this point in the history
Replace dictionary in Primitives with tuple and turn Primitives into frozen dataclass
  • Loading branch information
mschmidt87 committed May 22, 2020
2 parents e8794be + 9a00cf6 commit f25563e
Show file tree
Hide file tree
Showing 13 changed files with 127 additions and 140 deletions.
15 changes: 8 additions & 7 deletions gp/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
n_columns: int,
n_rows: int,
levels_back: int,
primitives: List[Type[Node]],
primitives: Tuple[Type[Node], ...],
) -> None:
"""Init function.
Expand All @@ -39,8 +39,8 @@ def __init__(
levels_back : int
Number of previous columns that an entry in the genome can be
connected with.
primitives : List[Type[Node]]
List of primitives that the genome can refer to.
primitives : Tuple[Type[Node], ...]
Tuple of primitives that the genome can refer to.
"""
if n_inputs <= 0:
raise ValueError("n_inputs must be strictly positive")
Expand Down Expand Up @@ -136,7 +136,7 @@ def _create_random_hidden_region(
) -> List[int]:

region = []
node_id = self._primitives.sample(rng)
node_id = self._primitives.sample_allele(rng)
region.append(node_id)
region += list(rng.choice(permissible_inputs, self._primitives.max_arity))

Expand Down Expand Up @@ -232,7 +232,7 @@ def _validate_dna(self, dna: List[int]) -> None:

for region_idx, hidden_region in self.iter_hidden_regions(dna):

if hidden_region[0] not in self._primitives.alleles:
if not self._primitives.is_valid_allele(hidden_region[0]):
raise ValueError("function gene for hidden node has invalid value")

input_genes = hidden_region[1:]
Expand Down Expand Up @@ -404,7 +404,7 @@ def _mutate_hidden_region(
silent_mutation = region_idx not in active_regions

if self._is_function_gene(gene_idx):
self._dna[gene_idx] = self._primitives.sample(rng)
self._dna[gene_idx] = self._primitives.sample_allele(rng)
return silent_mutation

else:
Expand All @@ -425,13 +425,14 @@ def clone(self) -> "Genome":
-------
gp.Genome
"""

new = Genome(
self._n_inputs,
self._n_outputs,
self._n_columns,
self._n_rows,
self._levels_back,
list(self._primitives),
tuple(self._primitives),
)
new.dna = self._dna.copy()
return new
83 changes: 39 additions & 44 deletions gp/primitives.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,76 @@
from dataclasses import dataclass, field
import numpy as np

from typing import Iterator, List, Tuple, Type
from typing import Iterator, Tuple, Type

from .node import Node


@dataclass(frozen=True)
class Primitives:
"""Class collecting primitives of the Cartesian Genetic Programming framework.
"""

_max_arity = 0
_primitives: dict = {}

def __init__(self, primitives: List[Type[Node]]) -> None:
"""Init function.
"""Convenience class to manage primitives, i.e., Node classes.
Parameters
----------
primitives : List[Type[Node]]
List of primitives.
"""
for i in range(len(primitives)):
if not isinstance(primitives[i], type):
raise TypeError(f"expected class but received {type(primitives[i])}")
if not issubclass(primitives[i], Node):
raise TypeError(f"expected subclass of Node but received {primitives[i].__name__}")
"""

self._primitives = {}
for i in range(len(primitives)):
self._primitives[i] = primitives[i]
_primitives: Tuple[Type[Node], ...]
_max_arity: int = field(init=False)

# hide primitives dict behind MappingProxyType to make sure it
# is not changed after construction
# unfortunately not supported by pickle, necessary for
# multiprocessing; another way to implement this?
# self._primitives = types.MappingProxyType(self._primitives)
def __post_init__(self):
self._check_types()
self.__dict__[
"_max_arity"
] = self._determine_max_arity() # avoid using use __setattr_ since dataclass is frozen

self._determine_max_arity()
def _check_types(self):
if not isinstance(self._primitives, tuple):
raise TypeError(f"expected tuple but received {type(self._primitives)}")

def __iter__(self) -> Iterator:
return iter([self[i] for i in range(len(self._primitives))])
for i in range(len(self._primitives)):
if not isinstance(self._primitives[i], type):
raise TypeError(
f"expected class but received instance of {type(self._primitives[i])}"
)
if not issubclass(self._primitives[i], Node):
raise TypeError(
f"expected subclass of Node but received class {self._primitives[i].__name__}"
)

def _determine_max_arity(self) -> None:
def _determine_max_arity(self) -> int:

arity = 1 # minimal possible arity (output nodes need one input)

for idx, p in self._primitives.items():
for p in self._primitives:
if arity < p._arity:
arity = p._arity

self._max_arity = arity
return arity

def __iter__(self) -> Iterator[Type[Node]]:
return iter(self._primitives)

def sample(self, rng: np.random.RandomState) -> int:
"""Sample a random primitive.
def sample_allele(self, rng: np.random.RandomState) -> int:
"""Sample a random primitive index.
Parameters
----------
rng : numpy.RandomState
Random number generator instance to use for crossover.
Random number generator instance.
Returns
-------
int
Index of the sample primitive
Index of the sampled primitive.
"""
return rng.choice(self.alleles)
return rng.randint(len(self._primitives))

def __getitem__(self, key: int) -> Type[Node]:
if key < 0 or key >= len(self._primitives):
raise IndexError("primitive index out of bounds")
return self._primitives[key]

@property
def max_arity(self) -> int:
return self._max_arity

@property
def alleles(self) -> Tuple:
return tuple(self._primitives.keys())

def __len__(self):
return len(self._primitives)
def is_valid_allele(self, allele: int) -> bool:
return (allele >= 0) and (allele < len(self._primitives))
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def read_extra_requirements():
extra_requirements[req] = [req]
extra_requirements["all"].append(req)

extra_requirements[":python_version == '3.6'"] = ["dataclasses"]
return extra_requirements


Expand Down
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def genome_params():
"n_columns": 3,
"n_rows": 3,
"levels_back": 2,
"primitives": [gp.Add, gp.Sub, gp.ConstantFloat],
"primitives": (gp.Add, gp.Sub, gp.ConstantFloat),
}


Expand Down
32 changes: 16 additions & 16 deletions test/test_cartesian_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def test_direct_input_output():
params = {"n_inputs": 1, "n_outputs": 1, "n_columns": 3, "n_rows": 3, "levels_back": 2}
primitives = [gp.Add, gp.Sub]
primitives = (gp.Add, gp.Sub)
genome = gp.Genome(
params["n_inputs"],
params["n_outputs"],
Expand All @@ -29,7 +29,7 @@ def test_direct_input_output():


def test_to_func_simple():
primitives = [gp.Add]
primitives = (gp.Add,)
genome = gp.Genome(2, 1, 1, 1, 1, primitives)

genome.dna = [
Expand All @@ -54,7 +54,7 @@ def test_to_func_simple():

assert x[0] + x[1] == pytest.approx(y[0])

primitives = [gp.Sub]
primitives = (gp.Sub,)
genome = gp.Genome(2, 1, 1, 1, 1, primitives)

genome.dna = [
Expand All @@ -81,7 +81,7 @@ def test_to_func_simple():


def test_compile_two_columns():
primitives = [gp.Add, gp.Sub]
primitives = (gp.Add, gp.Sub)
genome = gp.Genome(2, 1, 2, 1, 1, primitives)

genome.dna = [
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_compile_two_columns():


def test_compile_two_columns_two_rows():
primitives = [gp.Add, gp.Sub]
primitives = (gp.Add, gp.Sub)
genome = gp.Genome(2, 2, 2, 2, 1, primitives)

genome.dna = [
Expand Down Expand Up @@ -153,7 +153,7 @@ def test_compile_two_columns_two_rows():
def test_compile_addsubmul():
params = {"n_inputs": 2, "n_outputs": 1, "n_columns": 2, "n_rows": 2, "levels_back": 1}

primitives = [gp.Add, gp.Sub, gp.Mul]
primitives = (gp.Add, gp.Sub, gp.Mul)
genome = gp.Genome(
params["n_inputs"],
params["n_outputs"],
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_compile_addsubmul():


def test_to_numpy():
primitives = [gp.Add, gp.Mul, gp.ConstantFloat]
primitives = (gp.Add, gp.Mul, gp.ConstantFloat)
genome = gp.Genome(1, 1, 2, 2, 1, primitives)
# f(x) = x ** 2 + 1.
genome.dna = [
Expand Down Expand Up @@ -229,7 +229,7 @@ def test_to_numpy():


batch_sizes = [1, 10]
primitives = [gp.Mul, gp.ConstantFloat]
primitives = (gp.Mul, gp.ConstantFloat)
genomes = [gp.Genome(1, 1, 2, 2, 1, primitives) for i in range(2)]
# Function: f(x) = 1*x
genomes[0].dna = [
Expand Down Expand Up @@ -347,7 +347,7 @@ def test_compile_torch_output_shape(genome, batch_size):
def test_to_sympy():
sympy = pytest.importorskip("sympy")

primitives = [gp.Add, gp.ConstantFloat]
primitives = (gp.Add, gp.ConstantFloat)
genome = gp.Genome(1, 1, 2, 2, 1, primitives)

genome.dna = [
Expand Down Expand Up @@ -386,7 +386,7 @@ def test_to_sympy():
def test_catch_invalid_sympy_expr():
pytest.importorskip("sympy")

primitives = [gp.Sub, gp.Div]
primitives = (gp.Sub, gp.Div)
genome = gp.Genome(1, 1, 2, 1, 1, primitives)

# x[0] / (x[0] - x[0])
Expand All @@ -413,7 +413,7 @@ def test_catch_invalid_sympy_expr():
def test_allow_powers_of_x_0():
pytest.importorskip("sympy")

primitives = [gp.Sub, gp.Mul]
primitives = (gp.Sub, gp.Mul)
genome = gp.Genome(1, 1, 2, 1, 1, primitives)

# x[0] ** 2
Expand All @@ -438,7 +438,7 @@ def test_allow_powers_of_x_0():
def test_input_dim_python(rng_seed):
rng = np.random.RandomState(rng_seed)

genome = gp.Genome(2, 1, 1, 1, 1, [gp.ConstantFloat])
genome = gp.Genome(2, 1, 1, 1, 1, (gp.ConstantFloat,))
genome.randomize(rng)
f = gp.CartesianGraph(genome).to_func()

Expand All @@ -457,7 +457,7 @@ def test_input_dim_python(rng_seed):
def test_input_dim_numpy(rng_seed):
rng = np.random.RandomState(rng_seed)

genome = gp.Genome(2, 1, 1, 1, 1, [gp.ConstantFloat])
genome = gp.Genome(2, 1, 1, 1, 1, (gp.ConstantFloat,))
genome.randomize(rng)
f = gp.CartesianGraph(genome).to_numpy()

Expand All @@ -482,7 +482,7 @@ def test_input_dim_torch(rng_seed):

rng = np.random.RandomState(rng_seed)

genome = gp.Genome(2, 1, 1, 1, 1, [gp.ConstantFloat])
genome = gp.Genome(2, 1, 1, 1, 1, (gp.ConstantFloat,))
genome.randomize(rng)
f = gp.CartesianGraph(genome).to_torch()

Expand All @@ -503,7 +503,7 @@ def test_input_dim_torch(rng_seed):


def test_pretty_str():
primitives = [gp.Sub, gp.Mul]
primitives = (gp.Sub, gp.Mul)
genome = gp.Genome(1, 1, 2, 1, 1, primitives)

# x[0] ** 2
Expand Down Expand Up @@ -534,7 +534,7 @@ def test_pretty_str():


def test_pretty_str_with_unequal_inputs_rows_outputs():
primitives = [gp.Add]
primitives = (gp.Add,)

# less rows than inputs/outputs
genome = gp.Genome(1, 1, 1, 2, 1, primitives)
Expand Down
Loading

0 comments on commit f25563e

Please sign in to comment.