-
Notifications
You must be signed in to change notification settings - Fork 195
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
350 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row | ||
from .feynman_problems import Problem, FeynmanProblem | ||
from .export import sympy2jax | ||
from .export_jax import sympy2jax | ||
from .export_torch import sympy2torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
##### | ||
# From https://github.com/patrick-kidger/sympytorch | ||
# Copied here to allow PySR-specific tweaks | ||
##### | ||
|
||
import collections as co | ||
import functools as ft | ||
import sympy | ||
|
||
def _reduce(fn): | ||
def fn_(*args): | ||
return ft.reduce(fn, args) | ||
return fn_ | ||
|
||
torch_initialized = False | ||
torch = None | ||
_global_func_lookup = None | ||
_Node = None | ||
SingleSymPyModule = None | ||
|
||
def _initialize_torch(): | ||
global torch_initialized | ||
global torch | ||
global _global_func_lookup | ||
global _Node | ||
global SingleSymPyModule | ||
|
||
# Way to lazy load torch, only if this is called, | ||
# but still allow this module to be loaded in __init__ | ||
if not torch_initialized: | ||
import torch as _torch | ||
torch = _torch | ||
|
||
_global_func_lookup = { | ||
sympy.Mul: _reduce(torch.mul), | ||
sympy.Add: _reduce(torch.add), | ||
sympy.div: torch.div, | ||
sympy.Abs: torch.abs, | ||
sympy.sign: torch.sign, | ||
# Note: May raise error for ints. | ||
sympy.ceiling: torch.ceil, | ||
sympy.floor: torch.floor, | ||
sympy.log: torch.log, | ||
sympy.exp: torch.exp, | ||
sympy.sqrt: torch.sqrt, | ||
sympy.cos: torch.cos, | ||
sympy.acos: torch.acos, | ||
sympy.sin: torch.sin, | ||
sympy.asin: torch.asin, | ||
sympy.tan: torch.tan, | ||
sympy.atan: torch.atan, | ||
sympy.atan2: torch.atan2, | ||
# Note: May give NaN for complex results. | ||
sympy.cosh: torch.cosh, | ||
sympy.acosh: torch.acosh, | ||
sympy.sinh: torch.sinh, | ||
sympy.asinh: torch.asinh, | ||
sympy.tanh: torch.tanh, | ||
sympy.atanh: torch.atanh, | ||
sympy.Pow: torch.pow, | ||
sympy.re: torch.real, | ||
sympy.im: torch.imag, | ||
sympy.arg: torch.angle, | ||
# Note: May raise error for ints and complexes | ||
sympy.erf: torch.erf, | ||
sympy.loggamma: torch.lgamma, | ||
sympy.Eq: torch.eq, | ||
sympy.Ne: torch.ne, | ||
sympy.StrictGreaterThan: torch.gt, | ||
sympy.StrictLessThan: torch.lt, | ||
sympy.LessThan: torch.le, | ||
sympy.GreaterThan: torch.ge, | ||
sympy.And: torch.logical_and, | ||
sympy.Or: torch.logical_or, | ||
sympy.Not: torch.logical_not, | ||
sympy.Max: torch.max, | ||
sympy.Min: torch.min, | ||
# Matrices | ||
sympy.MatAdd: torch.add, | ||
sympy.HadamardProduct: torch.mul, | ||
sympy.Trace: torch.trace, | ||
# Note: May raise error for integer matrices. | ||
sympy.Determinant: torch.det, | ||
} | ||
|
||
class _Node(torch.nn.Module): | ||
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch""" | ||
def __init__(self, *, expr, _memodict, _func_lookup, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
self._sympy_func = expr.func | ||
|
||
if issubclass(expr.func, sympy.Float): | ||
self._value = torch.nn.Parameter(torch.tensor(float(expr))) | ||
self._torch_func = lambda: self._value | ||
self._args = () | ||
elif issubclass(expr.func, sympy.UnevaluatedExpr): | ||
if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float): | ||
raise ValueError("UnevaluatedExpr should only be used to wrap floats.") | ||
self.register_buffer('_value', torch.tensor(float(expr.args[0]))) | ||
self._torch_func = lambda: self._value | ||
self._args = () | ||
elif issubclass(expr.func, sympy.Integer): | ||
# Can get here if expr is one of the Integer special cases, | ||
# e.g. NegativeOne | ||
self._value = int(expr) | ||
self._torch_func = lambda: self._value | ||
self._args = () | ||
elif issubclass(expr.func, sympy.Symbol): | ||
self._name = expr.name | ||
self._torch_func = lambda value: value | ||
self._args = ((lambda memodict: memodict[expr.name]),) | ||
else: | ||
self._torch_func = _func_lookup[expr.func] | ||
args = [] | ||
for arg in expr.args: | ||
try: | ||
arg_ = _memodict[arg] | ||
except KeyError: | ||
arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs) | ||
_memodict[arg] = arg_ | ||
args.append(arg_) | ||
self._args = torch.nn.ModuleList(args) | ||
|
||
def forward(self, memodict): | ||
args = [] | ||
for arg in self._args: | ||
try: | ||
arg_ = memodict[arg] | ||
except KeyError: | ||
arg_ = arg(memodict) | ||
memodict[arg] = arg_ | ||
args.append(arg_) | ||
return self._torch_func(*args) | ||
|
||
|
||
class SingleSymPyModule(torch.nn.Module): | ||
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch""" | ||
def __init__(self, expression, symbols_in, | ||
extra_funcs=None, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
if extra_funcs is None: | ||
extra_funcs = {} | ||
_func_lookup = co.ChainMap(_global_func_lookup, extra_funcs) | ||
|
||
_memodict = {} | ||
self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup) | ||
self._expression_string = str(expression) | ||
self.symbols_in = [str(symbol) for symbol in symbols_in] | ||
|
||
def __repr__(self): | ||
return f"{type(self).__name__}(expression={self._expression_string})" | ||
|
||
def forward(self, X): | ||
symbols = {symbol: X[:, i] | ||
for i, symbol in enumerate(self.symbols_in)} | ||
return self._node(symbols) | ||
|
||
|
||
def sympy2torch(expression, symbols_in, extra_torch_mappings=None): | ||
"""Returns a module for a given sympy expression with trainable parameters; | ||
This function will assume the input to the module is a matrix X, where | ||
each column corresponds to each symbol you pass in `symbols_in`. | ||
""" | ||
global SingleSymPyModule | ||
|
||
_initialize_torch() | ||
|
||
return SingleSymPyModule(expression, symbols_in, extra_funcs=extra_torch_mappings) |
Oops, something went wrong.