Skip to content

Commit

Permalink
Merge 4db51e8 into a149430
Browse files Browse the repository at this point in the history
  • Loading branch information
mschmidt87 committed May 12, 2020
2 parents a149430 + 4db51e8 commit 9a2d0dd
Show file tree
Hide file tree
Showing 16 changed files with 866 additions and 265 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ install:
- pip install pytest pytest-cov coveralls
- pip install flake8 black
- pip install .[all]
- pip install mypy
script:
- pytest --cov
- black --check .
- flake8 --config=.flake8 .
- mypy gp
after_success:
- coveralls
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ python-gp
[![Python3.8](https://img.shields.io/badge/python-3.8-red.svg)](https://www.python.org/)
[![GPL license](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/old-licenses/gpl-3.0.html)
[![Build Status](https://api.travis-ci.org/Happy-Algorithms-League/python-gp.svg?branch=master)](https://travis-ci.org/Happy-Algorithms-League/python-gp)

[![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/)

Cartesian Genetic Programming (CGP) in Python.

Expand Down
106 changes: 64 additions & 42 deletions gp/cartesian_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,66 @@

try:
import sympy
from sympy.core import expr as sympy_expr

sympy_available = True
except ModuleNotFoundError:
sympy = None
sympy_available = False

try:
import torch # noqa: F401

torch_available = True
except ModuleNotFoundError:
torch = None
torch_available = False


from .node import InputNode, OutputNode, Parameter
from .node import Node, InputNode, OutputNode, Parameter
from .genome import Genome
from typing import Callable, Dict, List, Optional, Set


class CartesianGraph:
"""Class representing a particular Cartesian graph defined by a
Genome.
"""

def __init__(self, genome):
def __init__(self, genome: Genome) -> None:
"""Init function.
Parameters
----------
genome: Genome
Genome defining graph connectivity and node operations.
"""
self._n_outputs = None
self._n_inputs = None
self._n_columns = None
self._n_rows = None
self._nodes = None
self._gnome = None
self._n_outputs: int
self._n_inputs: int
self._n_columns: int
self._n_rows: int
self._nodes: List

self.parse_genome(genome)
self._genome = genome

def __repr__(self):
def __repr__(self) -> str:
return "CartesianGraph(" + str(self._nodes) + ")"

def print_active_nodes(self):
def print_active_nodes(self) -> str:
"""Print a representation of all active nodes in the graph.
"""
return "CartesianGraph(" + str([node for node in self._nodes if node._active]) + ")"

def pretty_str(self):
def pretty_str(self) -> str:
"""Print a pretty representation of the Cartesian graph.
"""
n_characters = 24

def pretty_node_str(node):
def pretty_node_str(node: Node) -> str:
s = node.pretty_str(n_characters)
assert len(s) == n_characters
return s

def empty_node_str():
def empty_node_str() -> str:
return " " * n_characters

s = "\n"
Expand Down Expand Up @@ -89,7 +95,7 @@ def empty_node_str():

return s

def parse_genome(self, genome):
def parse_genome(self, genome: Genome) -> None:
if genome.dna is None:
raise RuntimeError("dna not initialized")

Expand Down Expand Up @@ -117,22 +123,22 @@ def parse_genome(self, genome):

self._determine_active_nodes()

def _hidden_column_idx(self, idx):
def _hidden_column_idx(self, idx: int) -> int:
return (idx - self._n_inputs) // self._n_rows

@property
def input_nodes(self):
def input_nodes(self) -> List[Node]:
return self._nodes[: self._n_inputs]

@property
def hidden_nodes(self):
def hidden_nodes(self) -> List[Node]:
return self._nodes[self._n_inputs : -self._n_outputs]

@property
def output_nodes(self):
def output_nodes(self) -> List[Node]:
return self._nodes[-self._n_outputs :]

def _determine_active_nodes(self):
def _determine_active_nodes(self) -> Dict[int, Set[Node]]:
"""Determine the active nodes in the graph.
Starting from the output nodes, we work backward through the
Expand All @@ -145,7 +151,7 @@ def _determine_active_nodes(self):
Returns
-------
Dict[Set]
Dict[int, Set[Node]]
Dictionary mapping colum indices to sets of active nodes.
"""
Expand All @@ -167,7 +173,7 @@ def _determine_active_nodes(self):

return active_nodes_by_hidden_column_idx

def determine_active_regions(self):
def determine_active_regions(self) -> List[int]:
"""Determine the active regions in the computational graph.
Returns
Expand All @@ -183,7 +189,7 @@ def determine_active_regions(self):

return active_regions

def __call__(self, x):
def __call__(self, x: List[float]) -> List[float]:
# store values of x in input nodes
for i, xi in enumerate(x):
assert isinstance(self._nodes[i], InputNode)
Expand All @@ -197,16 +203,16 @@ def __call__(self, x):

return [node._output for node in self.output_nodes]

def __getitem__(self, key):
def __getitem__(self, key: int) -> Node:
return self._nodes[key]

def to_str(self):
def to_str(self) -> str:

self._format_output_str_of_all_nodes()
out_str = ", ".join(node.output_str for node in self.output_nodes)
return f"[{out_str}]"

def _format_output_str_of_all_nodes(self):
def _format_output_str_of_all_nodes(self) -> None:

for i, node in enumerate(self.input_nodes):
node.format_output_str(self)
Expand All @@ -216,17 +222,25 @@ def _format_output_str_of_all_nodes(self):
for node in active_nodes[hidden_column_idx]:
node.format_output_str(self)

def _fill_parameter_values(self, func_str, parameter_names_to_values):
def _fill_parameter_values(
self, func_str: str, parameter_names_to_values: Optional[Dict[str, float]] = None
) -> str:
g = re.findall("<p[0-9]+>", func_str)
if g and parameter_names_to_values is None:
raise ValueError("parameter node found but no value dict provided")
for parameter_name in g:
func_str = func_str.replace(
parameter_name, str(parameter_names_to_values[parameter_name])
)
if len(g) != 0:
if parameter_names_to_values is None:
raise ValueError("parameter node found but no value dict provided")
elif parameter_names_to_values is not None:
for parameter_name in g:
func_str = func_str.replace(
parameter_name, str(parameter_names_to_values[parameter_name])
)
else:
pass
return func_str

def to_func(self, parameter_names_to_values=None):
def to_func(
self, parameter_names_to_values: Optional[Dict[str, float]] = None
) -> Callable[[List[float]], List[float]]:
"""Compile the function(s) represented by the graph.
Generates a definition of the function in Python code and
Expand Down Expand Up @@ -259,7 +273,9 @@ def _format_output_str_numpy_of_all_nodes(self):
for node in active_nodes[hidden_column_idx]:
node.format_output_str_numpy(self)

def to_numpy(self, parameter_names_to_values=None):
def to_numpy(
self, parameter_names_to_values: Optional[Dict[str, float]] = None
) -> Callable[[np.ndarray], np.ndarray]:
"""Compile the function(s) represented by the graph to NumPy
expression(s).
Expand Down Expand Up @@ -293,7 +309,9 @@ def _f(x):

return locals()["_f"]

def to_torch(self, parameter_names_to_values=None):
def to_torch(
self, parameter_names_to_values: Optional[Dict[str, float]] = None
) -> torch.nn.Module:
"""Compile the function(s) represented by the graph to a Torch class.
Generates a definition of the Torch class in Python code and
Expand All @@ -304,7 +322,7 @@ def to_torch(self, parameter_names_to_values=None):
torch.nn.Module
Instance of the PyTorch class.
"""
if torch is None:
if not torch_available:
raise ModuleNotFoundError("No module named 'torch' (extra requirement)")

for i, node in enumerate(self.input_nodes):
Expand Down Expand Up @@ -350,7 +368,11 @@ def forward(self, x):

return locals()["_c"]

def to_sympy(self, simplify=True, parameter_names_to_values=None):
def to_sympy(
self,
simplify: Optional[bool] = True,
parameter_names_to_values: Optional[Dict[str, float]] = None,
) -> List[sympy_expr.Expr]:
"""Compile the function(s) represented by the graph to a SymPy expression.
Generates one SymPy expression for each output node.
Expand All @@ -363,13 +385,13 @@ def to_sympy(self, simplify=True, parameter_names_to_values=None):
Returns
----------
List[sympy.core.Expr]
List[sympy.core.expr.Expr]
List of SymPy expressions.
"""
if sympy is None:
if not sympy_available:
raise ModuleNotFoundError("No module named 'sympy' (extra requirement)")

def _validate_sympy_expr(expr):
def _validate_sympy_expr(expr: sympy_expr.Expr) -> sympy_expr.Expr:
"""Helper function that raises an exception upon encountering a SymPy
expression that can not be evaluated.
Expand Down
Loading

0 comments on commit 9a2d0dd

Please sign in to comment.