Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enh/simplify node #178

Merged
merged 3 commits into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions cgp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from .__version__ import __version__
from .genome import Genome
from .cartesian_graph import CartesianGraph
from .node import Add, ConstantFloat, Div, Mul, Parameter, Pow, Sub
from .node import OperatorNode
from .node_impl import (
Add,
ConstantFloat,
Div,
Mul,
Parameter,
Pow,
Sub,
)
from .population import Population

from .hl_api import evolve

from . import utils
from . import ea
from . import node_factories as node_factories
from . import local_search

from .individual import IndividualSingleGenome, IndividualMultiGenome
26 changes: 20 additions & 6 deletions cgp/cartesian_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import copy
import math # noqa: F401
import numpy as np # noqa: F401
import re

Expand All @@ -20,7 +21,8 @@

from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING

from .node import Node, InputNode, OutputNode, Parameter
from .node import Node, OperatorNode
from .node_input_output import InputNode, OutputNode

if TYPE_CHECKING:
from .genome import Genome
Expand Down Expand Up @@ -225,7 +227,7 @@ def _format_output_str_of_all_nodes(self) -> None:
node.format_output_str(self)

def _fill_parameter_values(self, func_str: str) -> str:
g = re.findall("<p[0-9]+>", func_str)
g = re.findall("<[a-z]+[0-9]+>", func_str)
if len(g) != 0:
for parameter_name in g:
func_str = func_str.replace(
Expand Down Expand Up @@ -322,9 +324,11 @@ def to_torch(self) -> "torch.nn.Module":
for hidden_column_idx in sorted(active_nodes_by_hidden_column_idx):
for node in active_nodes_by_hidden_column_idx[hidden_column_idx]:
node.format_output_str_torch(self)
if isinstance(node, Parameter):
node.format_parameter_str()
all_parameter_str.append(node.parameter_str)
if isinstance(node, OperatorNode):
if len(node._parameter_names) > 0:
node.format_parameter_str()
all_parameter_str.append(node.parameter_str)

forward_str = ", ".join(node.output_str for node in self.output_nodes)
class_str = """\
class _C(torch.nn.Module):
Expand Down Expand Up @@ -357,6 +361,16 @@ def forward(self, x):

return locals()["_c"]

def _format_output_str_sympy_of_all_nodes(self):

for i, node in enumerate(self.input_nodes):
node.format_output_str_sympy(self)

active_nodes = self._determine_active_nodes()
for hidden_column_idx in sorted(active_nodes):
for node in active_nodes[hidden_column_idx]:
node.format_output_str_sympy(self)

def to_sympy(self, simplify: Optional[bool] = True,) -> List["sympy_expr.Expr"]:
"""Compile the function(s) represented by the graph to a SymPy expression.

Expand All @@ -376,7 +390,7 @@ def to_sympy(self, simplify: Optional[bool] = True,) -> List["sympy_expr.Expr"]:
if not sympy_available:
raise ModuleNotFoundError("No module named 'sympy' (extra requirement)")

self._format_output_str_of_all_nodes()
self._format_output_str_sympy_of_all_nodes()

sympy_exprs = []
for output_node in self.output_nodes:
Expand Down
14 changes: 7 additions & 7 deletions cgp/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Dict, Generator, List, Optional, Tuple, Type

from .cartesian_graph import CartesianGraph
from .node import Node, Parameter
from .node import Node, OperatorNode
from .primitives import Primitives


Expand Down Expand Up @@ -488,9 +488,9 @@ def _initialize_unkown_parameters(self) -> None:
for region_idx, region in self.iter_hidden_regions():
node_id = region[0]
node_type = self._primitives[node_id]
parameter_name = f"<p{region_idx}>"
if (
issubclass(node_type, Parameter)
and parameter_name not in self._parameter_names_to_values
):
self._parameter_names_to_values[parameter_name] = node_type.initial_value()
assert issubclass(node_type, OperatorNode)
for parameter_name in node_type._parameter_names:
parameter_name_with_idx = "<" + parameter_name[1:-1] + str(region_idx) + ">"
self._parameter_names_to_values[parameter_name_with_idx] = node_type.initial_value(
parameter_name_with_idx
)
Loading