diff --git a/cgp/cartesian_graph.py b/cgp/cartesian_graph.py index e60d90d4..34b779f6 100644 --- a/cgp/cartesian_graph.py +++ b/cgp/cartesian_graph.py @@ -309,7 +309,7 @@ def _f(x): def to_torch( self, parameter_names_to_values: Optional[Dict[str, float]] = None - ) -> torch.nn.Module: + ) -> "torch.nn.Module": """Compile the function(s) represented by the graph to a Torch class. Generates a definition of the Torch class in Python code and @@ -335,7 +335,7 @@ def to_torch( node.format_parameter_str() all_parameter_str.append(node.parameter_str) forward_str = ", ".join(node.output_str for node in self.output_nodes) - class_str = f"""\ + class_str = """\ class _C(torch.nn.Module): def __init__(self): @@ -370,7 +370,7 @@ def to_sympy( self, simplify: Optional[bool] = True, parameter_names_to_values: Optional[Dict[str, float]] = None, - ) -> List[sympy_expr.Expr]: + ) -> List["sympy_expr.Expr"]: """Compile the function(s) represented by the graph to a SymPy expression. Generates one SymPy expression for each output node. diff --git a/cgp/local_search/gradient_based.py b/cgp/local_search/gradient_based.py index 10e4f4d1..4b3fdbf7 100644 --- a/cgp/local_search/gradient_based.py +++ b/cgp/local_search/gradient_based.py @@ -2,7 +2,7 @@ try: import torch # noqa: F401 - from torch.optim.optimizer import Optimizer + from torch.optim.optimizer import Optimizer # noqa: F401 torch_available = True except ModuleNotFoundError: @@ -16,10 +16,10 @@ def gradient_based( individual: Individual, - objective: Callable[[torch.nn.Module], torch.Tensor], + objective: Callable[["torch.nn.Module"], "torch.Tensor"], lr: float, gradient_steps: int, - optimizer: Optional[Optimizer] = None, + optimizer: Optional["Optimizer"] = None, clip_value: Optional[float] = None, ) -> None: """Perform a local search for numeric leaf values for an individual diff --git a/setup.py b/setup.py index 220eaefa..35c96488 100644 --- a/setup.py +++ b/setup.py @@ -13,8 +13,8 @@ def read_extra_requirements(): extra_requirements = {} extra_requirements["all"] = [] with open("./extra-requirements.txt") as f: - for l in f: - req = l.replace("\n", " ") + for dep in f: + req = dep.replace("\n", " ") extra_requirements[req] = [req] extra_requirements["all"].append(req) diff --git a/test/test_individual.py b/test/test_individual.py index d1cb7596..ef907b3f 100644 --- a/test/test_individual.py +++ b/test/test_individual.py @@ -1,7 +1,19 @@ import math import pickle import pytest -import torch + +try: + import sympy # noqa: F401 + + sympy_available = True +except ModuleNotFoundError: + sympy_available = False +try: + import torch # noqa: F401 + + torch_available = True +except ModuleNotFoundError: + torch_available = False import cgp from cgp.individual import Individual @@ -56,6 +68,7 @@ def test_individual_with_parameter_python(): assert y[0] == pytest.approx(x[0] + c) +@pytest.mark.skipif(not torch_available, reason="torch is not installed.") def test_individual_with_parameter_torch(): primitives = (cgp.Add, cgp.Parameter) @@ -96,6 +109,7 @@ def test_individual_with_parameter_torch(): assert y[1, 0].item() == pytest.approx(x[1, 0].item() + c) +@pytest.mark.skipif(not sympy_available, reason="sympy is not installed.") def test_individual_with_parameter_sympy(): primitives = (cgp.Add, cgp.Parameter) @@ -134,6 +148,7 @@ def test_individual_with_parameter_sympy(): assert y == pytest.approx(x[0] + c) +@pytest.mark.skipif(not torch_available, reason="torch is not installed.") def test_to_and_from_torch_plus_backprop(): primitives = (cgp.Mul, cgp.Parameter) genome = cgp.Genome(1, 1, 2, 2, 1, primitives)