Fix issues with non-installed extra requirements
Stringify annotations that use optional dependencies
Mark tests relying on optional dependencies as skipped
Demote f-string to normal string
Fix flake8 issue in setup.py
mschmidt87 committed May 26, 2020
1 parent 3bee89f commit c4519ec
Showing 4 changed files with 24 additions and 9 deletions.
6 changes: 3 additions & 3 deletions cgp/cartesian_graph.py
@@ -309,7 +309,7 @@ def _f(x):
 
     def to_torch(
         self, parameter_names_to_values: Optional[Dict[str, float]] = None
-    ) -> torch.nn.Module:
+    ) -> "torch.nn.Module":
         """Compile the function(s) represented by the graph to a Torch class.
 
         Generates a definition of the Torch class in Python code and
@@ -335,7 +335,7 @@ def to_torch(
             node.format_parameter_str()
             all_parameter_str.append(node.parameter_str)
         forward_str = ", ".join(node.output_str for node in self.output_nodes)
-        class_str = f"""\
+        class_str = """\
 class _C(torch.nn.Module):
 
     def __init__(self):
@@ -370,7 +370,7 @@ def to_sympy(
         self,
         simplify: Optional[bool] = True,
         parameter_names_to_values: Optional[Dict[str, float]] = None,
-    ) -> List[sympy_expr.Expr]:
+    ) -> List["sympy_expr.Expr"]:
         """Compile the function(s) represented by the graph to a SymPy expression.
 
         Generates one SymPy expression for each output node.
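Note: quoting the annotations ("torch.nn.Module" instead of torch.nn.Module) turns them into PEP 484 forward references. Python stores a quoted annotation as a plain string and never evaluates it at definition time, so cartesian_graph.py can be imported even when torch or sympy is missing. A minimal sketch of the pattern (the free-standing to_torch function is illustrative; the real methods live on CartesianGraph):

    try:
        import torch  # noqa: F401

        torch_available = True
    except ModuleNotFoundError:
        torch_available = False


    def to_torch() -> "torch.nn.Module":
        # The quoted annotation stays a string, so no NameError is raised
        # at import time even though the torch import above may have failed.
        if not torch_available:
            raise ModuleNotFoundError("No module named 'torch'")
        ...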
6 changes: 3 additions & 3 deletions cgp/local_search/gradient_based.py
@@ -2,7 +2,7 @@
 
 try:
     import torch  # noqa: F401
-    from torch.optim.optimizer import Optimizer
+    from torch.optim.optimizer import Optimizer  # noqa: F401
 
     torch_available = True
 except ModuleNotFoundError:
@@ -16,10 +16,10 @@
 
 def gradient_based(
     individual: Individual,
-    objective: Callable[[torch.nn.Module], torch.Tensor],
+    objective: Callable[["torch.nn.Module"], "torch.Tensor"],
     lr: float,
     gradient_steps: int,
-    optimizer: Optional[Optimizer] = None,
+    optimizer: Optional["Optimizer"] = None,
     clip_value: Optional[float] = None,
 ) -> None:
     """Perform a local search for numeric leaf values for an individual
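With the annotations stringified, Optimizer no longer appears outside string literals, so flake8 (pyflakes) reports the import as unused (F401); the added noqa comment keeps the guarded import while silencing that warning. Assuming torch is installed, a call could look like the following sketch (the objective and the individual ind are illustrative, not from this commit):

    import torch

    # Illustrative objective: mean squared error of the graph's output
    # against a constant target of 2.0.
    def objective(f: "torch.nn.Module") -> "torch.Tensor":
        x = torch.ones((4, 1))
        return ((f(x) - 2.0) ** 2).mean()

    # ind is assumed to be a cgp Individual obtained elsewhere.
    gradient_based(ind, objective, lr=1e-3, gradient_steps=5)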
4 changes: 2 additions & 2 deletions setup.py
@@ -13,8 +13,8 @@ def read_extra_requirements():
     extra_requirements = {}
     extra_requirements["all"] = []
     with open("./extra-requirements.txt") as f:
-        for l in f:
-            req = l.replace("\n", " ")
+        for dep in f:
+            req = dep.replace("\n", " ")
             extra_requirements[req] = [req]
             extra_requirements["all"].append(req)
 
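The flake8 issue is E741 ("ambiguous variable name 'l'"): a lone l is easily misread as 1 or I, so the loop variable becomes dep; the behavior is unchanged. For a hypothetical extra-requirements.txt containing the lines torch and sympy, read_extra_requirements() would return:

    # Hypothetical result, assuming every line ends with a newline; note the
    # trailing space that replace("\n", " ") leaves on each requirement.
    {
        "all": ["torch ", "sympy "],
        "torch ": ["torch "],
        "sympy ": ["sympy "],
    }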
17 changes: 16 additions & 1 deletion test/test_individual.py
@@ -1,7 +1,19 @@
 import math
 import pickle
 import pytest
-import torch
 
+try:
+    import sympy  # noqa: F401
+
+    sympy_available = True
+except ModuleNotFoundError:
+    sympy_available = False
+try:
+    import torch  # noqa: F401
+
+    torch_available = True
+except ModuleNotFoundError:
+    torch_available = False
+
 import cgp
 from cgp.individual import Individual
@@ -56,6 +68,7 @@ def test_individual_with_parameter_python():
     assert y[0] == pytest.approx(x[0] + c)
 
 
+@pytest.mark.skipif(not torch_available, reason="torch is not installed.")
 def test_individual_with_parameter_torch():
 
     primitives = (cgp.Add, cgp.Parameter)
@@ -96,6 +109,7 @@ def test_individual_with_parameter_torch():
     assert y[1, 0].item() == pytest.approx(x[1, 0].item() + c)
 
 
+@pytest.mark.skipif(not sympy_available, reason="sympy is not installed.")
 def test_individual_with_parameter_sympy():
 
     primitives = (cgp.Add, cgp.Parameter)
@@ -134,6 +148,7 @@ def test_individual_with_parameter_sympy():
     assert y == pytest.approx(x[0] + c)
 
 
+@pytest.mark.skipif(not torch_available, reason="torch is not installed.")
 def test_to_and_from_torch_plus_backprop():
     primitives = (cgp.Mul, cgp.Parameter)
     genome = cgp.Genome(1, 1, 2, 2, 1, primitives)
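The skipif markers make pytest report these tests as skipped instead of failing at collection when torch or sympy is absent. A one-line alternative would be pytest's importorskip, sketched below; it is not what this commit uses, and at module level it would skip the whole file rather than only the dependent tests, which is why per-test markers fit better here:

    import pytest

    # Not this commit's approach: skips the *entire* module when torch is
    # missing, instead of only the torch-dependent tests.
    torch = pytest.importorskip("torch")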
