Skip to content

Commit

Permalink
Merge pull request #178 from jakobj/enh/simplify-node
Browse files Browse the repository at this point in the history
Enh/simplify node
  • Loading branch information
jakobj committed Jul 16, 2020
2 parents bd7e12d + f2c94b2 commit 29aa640
Show file tree
Hide file tree
Showing 11 changed files with 463 additions and 274 deletions.
12 changes: 10 additions & 2 deletions cgp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from .__version__ import __version__
from .genome import Genome
from .cartesian_graph import CartesianGraph
from .node import Add, ConstantFloat, Div, Mul, Parameter, Pow, Sub
from .node import OperatorNode
from .node_impl import (
Add,
ConstantFloat,
Div,
Mul,
Parameter,
Pow,
Sub,
)
from .population import Population

from .hl_api import evolve

from . import utils
from . import ea
from . import node_factories as node_factories
from . import local_search

from .individual import IndividualSingleGenome, IndividualMultiGenome
26 changes: 20 additions & 6 deletions cgp/cartesian_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import copy
import math # noqa: F401
import numpy as np # noqa: F401
import re

Expand All @@ -20,7 +21,8 @@

from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING

from .node import Node, InputNode, OutputNode, Parameter
from .node import Node, OperatorNode
from .node_input_output import InputNode, OutputNode

if TYPE_CHECKING:
from .genome import Genome
Expand Down Expand Up @@ -225,7 +227,7 @@ def _format_output_str_of_all_nodes(self) -> None:
node.format_output_str(self)

def _fill_parameter_values(self, func_str: str) -> str:
g = re.findall("<p[0-9]+>", func_str)
g = re.findall("<[a-z]+[0-9]+>", func_str)
if len(g) != 0:
for parameter_name in g:
func_str = func_str.replace(
Expand Down Expand Up @@ -322,9 +324,11 @@ def to_torch(self) -> "torch.nn.Module":
for hidden_column_idx in sorted(active_nodes_by_hidden_column_idx):
for node in active_nodes_by_hidden_column_idx[hidden_column_idx]:
node.format_output_str_torch(self)
if isinstance(node, Parameter):
node.format_parameter_str()
all_parameter_str.append(node.parameter_str)
if isinstance(node, OperatorNode):
if len(node._parameter_names) > 0:
node.format_parameter_str()
all_parameter_str.append(node.parameter_str)

forward_str = ", ".join(node.output_str for node in self.output_nodes)
class_str = """\
class _C(torch.nn.Module):
Expand Down Expand Up @@ -357,6 +361,16 @@ def forward(self, x):

return locals()["_c"]

def _format_output_str_sympy_of_all_nodes(self):

for i, node in enumerate(self.input_nodes):
node.format_output_str_sympy(self)

active_nodes = self._determine_active_nodes()
for hidden_column_idx in sorted(active_nodes):
for node in active_nodes[hidden_column_idx]:
node.format_output_str_sympy(self)

def to_sympy(self, simplify: Optional[bool] = True,) -> List["sympy_expr.Expr"]:
"""Compile the function(s) represented by the graph to a SymPy expression.
Expand All @@ -376,7 +390,7 @@ def to_sympy(self, simplify: Optional[bool] = True,) -> List["sympy_expr.Expr"]:
if not sympy_available:
raise ModuleNotFoundError("No module named 'sympy' (extra requirement)")

self._format_output_str_of_all_nodes()
self._format_output_str_sympy_of_all_nodes()

sympy_exprs = []
for output_node in self.output_nodes:
Expand Down
14 changes: 7 additions & 7 deletions cgp/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Dict, Generator, List, Optional, Tuple, Type

from .cartesian_graph import CartesianGraph
from .node import Node, Parameter
from .node import Node, OperatorNode
from .primitives import Primitives


Expand Down Expand Up @@ -488,9 +488,9 @@ def _initialize_unkown_parameters(self) -> None:
for region_idx, region in self.iter_hidden_regions():
node_id = region[0]
node_type = self._primitives[node_id]
parameter_name = f"<p{region_idx}>"
if (
issubclass(node_type, Parameter)
and parameter_name not in self._parameter_names_to_values
):
self._parameter_names_to_values[parameter_name] = node_type.initial_value()
assert issubclass(node_type, OperatorNode)
for parameter_name in node_type._parameter_names:
parameter_name_with_idx = "<" + parameter_name[1:-1] + str(region_idx) + ">"
self._parameter_names_to_values[parameter_name_with_idx] = node_type.initial_value(
parameter_name_with_idx
)
Loading

0 comments on commit 29aa640

Please sign in to comment.