Commit

Merge 4db1c62 into 6a4fa2c
MilesCranmer committed May 31, 2021
2 parents 6a4fa2c + 4db1c62 commit fa0233f
Showing 7 changed files with 350 additions and 66 deletions.
12 changes: 9 additions & 3 deletions .github/workflows/CI.yml
@@ -61,17 +61,23 @@ jobs:
python setup.py install
- name: "Install Coverage tool"
run: pip install coverage coveralls
- name: "Run tests"
run: coverage run --source=pysr --omit='*/feynman_problems.py' -m unittest test.test
shell: bash
- name: "Install JAX"
if: matrix.os != 'windows-latest'
run: pip install jax jaxlib # (optional import)
shell: bash
- name: "Run tests"
run: coverage run --source=pysr --omit='*/feynman_problems.py' -m unittest test.test
shell: bash
- name: "Run JAX tests"
if: matrix.os != 'windows-latest'
run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_jax
shell: bash
- name: "Install Torch"
run: pip install torch # (optional import)
shell: bash
- name: "Run Torch tests"
run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_torch
shell: bash
- name: Coveralls
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
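The reordering above installs the optional JAX dependency before the base test run, and the new Torch steps append their coverage to the same report. A rough local equivalent of the test sequence, sketched in Python under the assumption that pysr, coverage, jax, and torch are already installed:

```python
# Sketch of the CI test sequence; not part of the workflow itself.
# Assumes pysr, coverage, jax, and torch are installed. --append merges
# all three suites into a single .coverage file (delete any stale one first).
import subprocess

for suite in ["test.test", "test.test_jax", "test.test_torch"]:
    subprocess.run(
        ["coverage", "run", "--append", "--source=pysr",
         "--omit=*/feynman_problems.py", "-m", "unittest", suite],
        check=True,
    )
```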
3 changes: 2 additions & 1 deletion pysr/__init__.py
@@ -1,3 +1,4 @@
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
from .feynman_problems import Problem, FeynmanProblem
from .export import sympy2jax
from .export_jax import sympy2jax
from .export_torch import sympy2torch
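Splitting `pysr/export.py` into `export_jax.py` and `export_torch.py` keeps both exporters importable from the package root. A minimal sketch of the resulting import surface; neither jax nor torch is imported until the corresponding function is first called:

```python
# Importing the package root stays cheap: jax and torch are loaded
# lazily on the first call to sympy2jax / sympy2torch.
from pysr import sympy2jax, sympy2torch
```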
131 changes: 76 additions & 55 deletions pysr/export.py → pysr/export_jax.py
@@ -3,59 +3,53 @@
import string
import random

try:
import jax
from jax import numpy as jnp
from jax.scipy import special as jsp

# Special since need to reduce arguments.
MUL = 0
ADD = 1

_jnp_func_lookup = {
sympy.Mul: MUL,
sympy.Add: ADD,
sympy.div: "jnp.div",
sympy.Abs: "jnp.abs",
sympy.sign: "jnp.sign",
# Note: May raise error for ints.
sympy.ceiling: "jnp.ceil",
sympy.floor: "jnp.floor",
sympy.log: "jnp.log",
sympy.exp: "jnp.exp",
sympy.sqrt: "jnp.sqrt",
sympy.cos: "jnp.cos",
sympy.acos: "jnp.acos",
sympy.sin: "jnp.sin",
sympy.asin: "jnp.asin",
sympy.tan: "jnp.tan",
sympy.atan: "jnp.atan",
sympy.atan2: "jnp.atan2",
# Note: Also may give NaN for complex results.
sympy.cosh: "jnp.cosh",
sympy.acosh: "jnp.acosh",
sympy.sinh: "jnp.sinh",
sympy.asinh: "jnp.asinh",
sympy.tanh: "jnp.tanh",
sympy.atanh: "jnp.atanh",
sympy.Pow: "jnp.power",
sympy.re: "jnp.real",
sympy.im: "jnp.imag",
sympy.arg: "jnp.angle",
# Note: May raise error for ints and complexes
sympy.erf: "jsp.erf",
sympy.erfc: "jsp.erfc",
sympy.LessThan: "jnp.less",
sympy.GreaterThan: "jnp.greater",
sympy.And: "jnp.logical_and",
sympy.Or: "jnp.logical_or",
sympy.Not: "jnp.logical_not",
sympy.Max: "jnp.max",
sympy.Min: "jnp.min",
sympy.Mod: "jnp.mod",
}
except ImportError:
...
MUL = 0
ADD = 1

_jnp_func_lookup = {
sympy.Mul: MUL,
sympy.Add: ADD,
sympy.div: "jnp.div",
sympy.Abs: "jnp.abs",
sympy.sign: "jnp.sign",
# Note: May raise error for ints.
sympy.ceiling: "jnp.ceil",
sympy.floor: "jnp.floor",
sympy.log: "jnp.log",
sympy.exp: "jnp.exp",
sympy.sqrt: "jnp.sqrt",
sympy.cos: "jnp.cos",
sympy.acos: "jnp.acos",
sympy.sin: "jnp.sin",
sympy.asin: "jnp.asin",
sympy.tan: "jnp.tan",
sympy.atan: "jnp.atan",
sympy.atan2: "jnp.atan2",
# Note: Also may give NaN for complex results.
sympy.cosh: "jnp.cosh",
sympy.acosh: "jnp.acosh",
sympy.sinh: "jnp.sinh",
sympy.asinh: "jnp.asinh",
sympy.tanh: "jnp.tanh",
sympy.atanh: "jnp.atanh",
sympy.Pow: "jnp.power",
sympy.re: "jnp.real",
sympy.im: "jnp.imag",
sympy.arg: "jnp.angle",
# Note: May raise error for ints and complexes
sympy.erf: "jsp.erf",
sympy.erfc: "jsp.erfc",
sympy.LessThan: "jnp.less",
sympy.GreaterThan: "jnp.greater",
sympy.And: "jnp.logical_and",
sympy.Or: "jnp.logical_or",
sympy.Not: "jnp.logical_not",
sympy.Max: "jnp.max",
sympy.Min: "jnp.min",
sympy.Mod: "jnp.mod",
}


def sympy2jaxtext(expr, parameters, symbols_in):
if issubclass(expr.func, sympy.Float):
@@ -75,7 +69,7 @@ def sympy2jaxtext(expr, parameters, symbols_in):
else:
return f'{_func}({", ".join(args)})'

def sympy2jax(equation, symbols_in):

jax_initialized = False
jax = None
jnp = None
jsp = None

def _initialize_jax():
global jax_initialized
global jax
global jnp
global jsp

if not jax_initialized:
import jax as _jax
from jax import numpy as _jnp
from jax.scipy import special as _jsp
jax = _jax
jnp = _jnp
jsp = _jsp


def sympy2jax(expression, symbols_in):
"""Returns a function f and its parameters;
the function takes an input matrix, and a list of arguments:
f(X, parameters)
@@ -146,9 +161,15 @@ def sympy2jax(equation, symbols_in):
# 3.5427954 , -2.7479894 ], dtype=float32)
```
"""
_initialize_jax()
global jax_initialized
global jax
global jnp
global jsp

parameters = []
functional_form_text = sympy2jaxtext(equation, parameters, symbols_in)
hash_string = 'A_' + str(abs(hash(str(equation) + str(symbols_in))))
functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
hash_string = 'A_' + str(abs(hash(str(expression) + str(symbols_in))))
text = f"def {hash_string}(X, parameters):\n"
text += " return "
text += functional_form_text
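A short usage sketch of the renamed module, following the calling convention `f(X, parameters)` from the docstring above (assumes jax and jaxlib are installed; the expression and shapes are illustrative only):

```python
# Sketch: compile a sympy expression to a JAX function. jax itself is
# only imported inside sympy2jax, via _initialize_jax().
import sympy
from pysr import sympy2jax

x, y = sympy.symbols("x y")
expression = 1.0 * sympy.cos(x) + 2.1 * y

# Float constants in the expression are returned as trainable parameters.
f, parameters = sympy2jax(expression, [x, y])

from jax import numpy as jnp
X = jnp.ones((10, 2))    # column i corresponds to symbols_in[i]
out = f(X, parameters)   # shape (10,)
```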
171 changes: 171 additions & 0 deletions pysr/export_torch.py
@@ -0,0 +1,171 @@
#####
# From https://github.com/patrick-kidger/sympytorch
# Copied here to allow PySR-specific tweaks
#####

import collections as co
import functools as ft
import sympy

def _reduce(fn):
def fn_(*args):
return ft.reduce(fn, args)
return fn_

torch_initialized = False
torch = None
_global_func_lookup = None
_Node = None
SingleSymPyModule = None

def _initialize_torch():
global torch_initialized
global torch
global _global_func_lookup
global _Node
global SingleSymPyModule

# Way to lazy load torch, only if this is called,
# but still allow this module to be loaded in __init__
if not torch_initialized:
import torch as _torch
torch = _torch

_global_func_lookup = {
sympy.Mul: _reduce(torch.mul),
sympy.Add: _reduce(torch.add),
sympy.div: torch.div,
sympy.Abs: torch.abs,
sympy.sign: torch.sign,
# Note: May raise error for ints.
sympy.ceiling: torch.ceil,
sympy.floor: torch.floor,
sympy.log: torch.log,
sympy.exp: torch.exp,
sympy.sqrt: torch.sqrt,
sympy.cos: torch.cos,
sympy.acos: torch.acos,
sympy.sin: torch.sin,
sympy.asin: torch.asin,
sympy.tan: torch.tan,
sympy.atan: torch.atan,
sympy.atan2: torch.atan2,
# Note: May give NaN for complex results.
sympy.cosh: torch.cosh,
sympy.acosh: torch.acosh,
sympy.sinh: torch.sinh,
sympy.asinh: torch.asinh,
sympy.tanh: torch.tanh,
sympy.atanh: torch.atanh,
sympy.Pow: torch.pow,
sympy.re: torch.real,
sympy.im: torch.imag,
sympy.arg: torch.angle,
# Note: May raise error for ints and complexes
sympy.erf: torch.erf,
sympy.loggamma: torch.lgamma,
sympy.Eq: torch.eq,
sympy.Ne: torch.ne,
sympy.StrictGreaterThan: torch.gt,
sympy.StrictLessThan: torch.lt,
sympy.LessThan: torch.le,
sympy.GreaterThan: torch.ge,
sympy.And: torch.logical_and,
sympy.Or: torch.logical_or,
sympy.Not: torch.logical_not,
sympy.Max: torch.max,
sympy.Min: torch.min,
# Matrices
sympy.MatAdd: torch.add,
sympy.HadamardProduct: torch.mul,
sympy.Trace: torch.trace,
# Note: May raise error for integer matrices.
sympy.Determinant: torch.det,
}

class _Node(torch.nn.Module):
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
super().__init__(**kwargs)

self._sympy_func = expr.func

if issubclass(expr.func, sympy.Float):
self._value = torch.nn.Parameter(torch.tensor(float(expr)))
self._torch_func = lambda: self._value
self._args = ()
elif issubclass(expr.func, sympy.UnevaluatedExpr):
if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float):
raise ValueError("UnevaluatedExpr should only be used to wrap floats.")
self.register_buffer('_value', torch.tensor(float(expr.args[0])))
self._torch_func = lambda: self._value
self._args = ()
elif issubclass(expr.func, sympy.Integer):
# Can get here if expr is one of the Integer special cases,
# e.g. NegativeOne
self._value = int(expr)
self._torch_func = lambda: self._value
self._args = ()
elif issubclass(expr.func, sympy.Symbol):
self._name = expr.name
self._torch_func = lambda value: value
self._args = ((lambda memodict: memodict[expr.name]),)
else:
self._torch_func = _func_lookup[expr.func]
args = []
for arg in expr.args:
try:
arg_ = _memodict[arg]
except KeyError:
arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
_memodict[arg] = arg_
args.append(arg_)
self._args = torch.nn.ModuleList(args)

def forward(self, memodict):
args = []
for arg in self._args:
try:
arg_ = memodict[arg]
except KeyError:
arg_ = arg(memodict)
memodict[arg] = arg_
args.append(arg_)
return self._torch_func(*args)


class SingleSymPyModule(torch.nn.Module):
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
def __init__(self, expression, symbols_in,
extra_funcs=None, **kwargs):
super().__init__(**kwargs)

if extra_funcs is None:
extra_funcs = {}
_func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)

_memodict = {}
self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
self._expression_string = str(expression)
self.symbols_in = [str(symbol) for symbol in symbols_in]

def __repr__(self):
return f"{type(self).__name__}(expression={self._expression_string})"

def forward(self, X):
symbols = {symbol: X[:, i]
for i, symbol in enumerate(self.symbols_in)}
return self._node(symbols)


def sympy2torch(expression, symbols_in, extra_torch_mappings=None):
"""Returns a module for a given sympy expression with trainable parameters;
This function will assume the input to the module is a matrix X, where
each column corresponds to each symbol you pass in `symbols_in`.
"""
global SingleSymPyModule

_initialize_torch()

return SingleSymPyModule(expression, symbols_in, extra_funcs=extra_torch_mappings)
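A matching sketch for the Torch exporter (assumes torch is installed; expression and shapes are illustrative). Floats in the expression become `torch.nn.Parameter`s on the returned module, so gradients flow to them:

```python
# Sketch: export the same expression to a torch.nn.Module. torch is
# imported lazily via _initialize_torch() on the first call.
import sympy
import torch
from pysr import sympy2torch

x, y = sympy.symbols("x y")
module = sympy2torch(1.0 * sympy.cos(x) + 2.1 * y, [x, y])

X = torch.rand(100, 2)    # column i corresponds to symbols_in[i]
out = module(X)           # shape (100,)
out.sum().backward()      # gradients reach the float constants
print(module)             # SingleSymPyModule(expression=...)
```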
