Skip to content

Commit

Permalink
Simplify definition of nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakobj committed Jul 16, 2020
1 parent 36c25d5 commit 6a67d2b
Show file tree
Hide file tree
Showing 8 changed files with 419 additions and 196 deletions.
11 changes: 10 additions & 1 deletion cgp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from .__version__ import __version__
from .genome import Genome
from .cartesian_graph import CartesianGraph
from .node import Add, ConstantFloat, Div, Mul, Parameter, Pow, Sub
from .node import OperatorNode
from .node_impl import (
Add,
ConstantFloat,
Div,
Mul,
Parameter,
Pow,
Sub,
)
from .population import Population

from .hl_api import evolve
Expand Down
26 changes: 20 additions & 6 deletions cgp/cartesian_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import copy
import math # noqa: F401
import numpy as np # noqa: F401
import re

Expand All @@ -20,7 +21,8 @@

from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING

from .node import Node, InputNode, OutputNode, Parameter
from .node import Node, OperatorNode
from .node_input_output import InputNode, OutputNode

if TYPE_CHECKING:
from .genome import Genome
Expand Down Expand Up @@ -225,7 +227,7 @@ def _format_output_str_of_all_nodes(self) -> None:
node.format_output_str(self)

def _fill_parameter_values(self, func_str: str) -> str:
g = re.findall("<p[0-9]+>", func_str)
g = re.findall("<[a-z]+[0-9]+>", func_str)
if len(g) != 0:
for parameter_name in g:
func_str = func_str.replace(
Expand Down Expand Up @@ -322,9 +324,11 @@ def to_torch(self) -> "torch.nn.Module":
for hidden_column_idx in sorted(active_nodes_by_hidden_column_idx):
for node in active_nodes_by_hidden_column_idx[hidden_column_idx]:
node.format_output_str_torch(self)
if isinstance(node, Parameter):
node.format_parameter_str()
all_parameter_str.append(node.parameter_str)
if isinstance(node, OperatorNode):
if len(node._parameter_names) > 0:
node.format_parameter_str()
all_parameter_str.append(node.parameter_str)

forward_str = ", ".join(node.output_str for node in self.output_nodes)
class_str = """\
class _C(torch.nn.Module):
Expand Down Expand Up @@ -357,6 +361,16 @@ def forward(self, x):

return locals()["_c"]

def _format_output_str_sympy_of_all_nodes(self):

for i, node in enumerate(self.input_nodes):
node.format_output_str_sympy(self)

active_nodes = self._determine_active_nodes()
for hidden_column_idx in sorted(active_nodes):
for node in active_nodes[hidden_column_idx]:
node.format_output_str_sympy(self)

def to_sympy(self, simplify: Optional[bool] = True,) -> List["sympy_expr.Expr"]:
"""Compile the function(s) represented by the graph to a SymPy expression.
Expand All @@ -376,7 +390,7 @@ def to_sympy(self, simplify: Optional[bool] = True,) -> List["sympy_expr.Expr"]:
if not sympy_available:
raise ModuleNotFoundError("No module named 'sympy' (extra requirement)")

self._format_output_str_of_all_nodes()
self._format_output_str_sympy_of_all_nodes()

sympy_exprs = []
for output_node in self.output_nodes:
Expand Down
14 changes: 7 additions & 7 deletions cgp/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Dict, Generator, List, Optional, Tuple, Type

from .cartesian_graph import CartesianGraph
from .node import Node, Parameter
from .node import Node, OperatorNode
from .primitives import Primitives


Expand Down Expand Up @@ -488,9 +488,9 @@ def _initialize_unkown_parameters(self) -> None:
for region_idx, region in self.iter_hidden_regions():
node_id = region[0]
node_type = self._primitives[node_id]
parameter_name = f"<p{region_idx}>"
if (
issubclass(node_type, Parameter)
and parameter_name not in self._parameter_names_to_values
):
self._parameter_names_to_values[parameter_name] = node_type.initial_value()
assert issubclass(node_type, OperatorNode)
for parameter_name in node_type._parameter_names:
parameter_name_with_idx = "<" + parameter_name[1:-1] + str(region_idx) + ">"
self._parameter_names_to_values[parameter_name_with_idx] = node_type.initial_value(
parameter_name_with_idx
)
Loading

0 comments on commit 6a67d2b

Please sign in to comment.