Skip to content

Commit

Permalink
[WIP] experimental way to access expression methods
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Sep 17, 2023
1 parent 6c92e1c commit ea78532
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pysr/__init__.py
Expand Up @@ -4,6 +4,7 @@
from .export_torch import sympy2torch
from .feynman_problems import FeynmanProblem, Problem
from .julia_helpers import install
from .pysr_expression import Expression
from .sr import PySRRegressor
from .version import __version__

Expand All @@ -15,6 +16,7 @@
"Problem",
"install",
"PySRRegressor",
"Expression",
"best",
"best_callable",
"best_row",
Expand Down
75 changes: 75 additions & 0 deletions pysr/pysr_expression.py
@@ -0,0 +1,75 @@
from numbers import Number
from typing import List, Optional

from .sr import PySRRegressor


class Expression:
"""A wrapper around `SymbolicRegression.Node`"""

def __init__(
self,
equation,
*,
model: PySRRegressor = None,
options=None,
variable_names: Optional[List[str]] = None,
):
super().__init__()
# exactly one of model and options is None:
assert (model is None) != (
options is None
), "Pass exactly one of model and options"

self.equation = equation
self.options = model.sr_options_ if options is None else options
self.variable_names = (
variable_names
if variable_names is not None
else (model.feature_names_in_ if model is not None else None)
)

from julia import Main, SymbolicRegression

self.julia_ = Main
self.backend_ = SymbolicRegression

@classmethod
def from_string(
cls,
s: str,
*,
model: PySRRegressor = None,
options=None,
variable_names: Optional[List[str]] = None,
):
self = cls(None, model=model, options=options, variable_names=variable_names)

for i, variable in enumerate(self.variable_names):
self.julia_.eval(f"{variable} = Node(feature={i + 1})")

self.julia_.last_options = self.options
self.julia_.eval("SymbolicRegression.@extend_operators last_options")

equation = self.julia_.eval(s)

if isinstance(equation, Number):
equation = self.julia_.eval(f"Node(val={equation})")

self.equation = equation

return self

def __repr__(self):
variable_names = (
list(self.variable_names) if self.variable_names is not None else None
)
return self.backend_.string_tree(
self.equation, self.options, variable_names=variable_names
)

def __call__(self, X):
return self.equation(X.T, self.options)

def compute_complexity(self) -> int:
return int(self.backend_.compute_complexity(self.equation, self.options))

0 comments on commit ea78532

Please sign in to comment.