Commit: Revert to old torch export
- Installing a separate but optional library with a dependency on torch introduced
  too many difficulties. In the end, the simplest solution is to just
  maintain a separate copy of the code here.
MilesCranmer committed Jun 1, 2021
1 parent 3bea8e3 commit d18011f
Showing 4 changed files with 145 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -73,7 +73,7 @@ jobs:
         run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_jax
         shell: bash
       - name: "Install Torch"
-        run: pip install torch sympytorch # (optional import)
+        run: pip install torch # (optional import)
         shell: bash
       - name: "Run Torch tests"
         run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_torch
2 changes: 1 addition & 1 deletion .github/workflows/CI_Windows.yml
@@ -65,7 +65,7 @@ jobs:
         run: python -m unittest test.test
         shell: bash
       - name: "Install Torch"
-        run: pip install torch sympytorch # (optional import)
+        run: pip install torch # (optional import)
         shell: bash
       - name: "Run Torch tests"
         run: python -m unittest test.test_torch
2 changes: 1 addition & 1 deletion .github/workflows/CI_mac.yml
@@ -71,7 +71,7 @@ jobs:
         run: python -m unittest test.test_jax
         shell: bash
       - name: "Install Torch"
-        run: pip install torch sympytorch # (optional import)
+        run: pip install torch # (optional import)
         shell: bash
       - name: "Run Torch tests"
         run: python -m unittest test.test_torch
168 changes: 142 additions & 26 deletions pysr/export_torch.py
@@ -1,47 +1,164 @@
+#####
+# From https://github.com/patrick-kidger/sympytorch
+# Copied here to allow PySR-specific tweaks
+#####
+
+import collections as co
 import functools as ft
 import sympy
 
+def _reduce(fn):
+    def fn_(*args):
+        return ft.reduce(fn, args)
+    return fn_
+
 torch_initialized = False
 torch = None
-sympytorch = None
-PySRTorchModule = None
+_global_func_lookup = None
+_Node = None
+SingleSymPyModule = None
 
 def _initialize_torch():
     global torch_initialized
     global torch
-    global sympytorch
-    global PySRTorchModule
+    global _global_func_lookup
+    global _Node
+    global SingleSymPyModule
 
-    # Way to lazy load torch and sympytorch, only if this is called,
+    # Way to lazy load torch, only if this is called,
     # but still allow this module to be loaded in __init__
     if not torch_initialized:
-        try:
-            import torch
-            import sympytorch
-        except ImportError:
-            raise ImportError("You need to pip install `torch` and `sympytorch` before exporting to pytorch.")
         torch_initialized = True
+        import torch as _torch
+        torch = _torch
+
+        _global_func_lookup = {
+            sympy.Mul: _reduce(torch.mul),
+            sympy.Add: _reduce(torch.add),
+            sympy.div: torch.div,
+            sympy.Abs: torch.abs,
+            sympy.sign: torch.sign,
+            # Note: May raise error for ints.
+            sympy.ceiling: torch.ceil,
+            sympy.floor: torch.floor,
+            sympy.log: torch.log,
+            sympy.exp: torch.exp,
+            sympy.sqrt: torch.sqrt,
+            sympy.cos: torch.cos,
+            sympy.acos: torch.acos,
+            sympy.sin: torch.sin,
+            sympy.asin: torch.asin,
+            sympy.tan: torch.tan,
+            sympy.atan: torch.atan,
+            sympy.atan2: torch.atan2,
+            # Note: May give NaN for complex results.
+            sympy.cosh: torch.cosh,
+            sympy.acosh: torch.acosh,
+            sympy.sinh: torch.sinh,
+            sympy.asinh: torch.asinh,
+            sympy.tanh: torch.tanh,
+            sympy.atanh: torch.atanh,
+            sympy.Pow: torch.pow,
+            sympy.re: torch.real,
+            sympy.im: torch.imag,
+            sympy.arg: torch.angle,
+            # Note: May raise error for ints and complexes
+            sympy.erf: torch.erf,
+            sympy.loggamma: torch.lgamma,
+            sympy.Eq: torch.eq,
+            sympy.Ne: torch.ne,
+            sympy.StrictGreaterThan: torch.gt,
+            sympy.StrictLessThan: torch.lt,
+            sympy.LessThan: torch.le,
+            sympy.GreaterThan: torch.ge,
+            sympy.And: torch.logical_and,
+            sympy.Or: torch.logical_or,
+            sympy.Not: torch.logical_not,
+            sympy.Max: torch.max,
+            sympy.Min: torch.min,
+            # Matrices
+            sympy.MatAdd: torch.add,
+            sympy.HadamardProduct: torch.mul,
+            sympy.Trace: torch.trace,
+            # Note: May raise error for integer matrices.
+            sympy.Determinant: torch.det,
+        }
+
+        class _Node(torch.nn.Module):
+            """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
+            def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
+                super().__init__(**kwargs)
+
+                self._sympy_func = expr.func
+
+                if issubclass(expr.func, sympy.Float):
+                    self._value = torch.nn.Parameter(torch.tensor(float(expr)))
+                    self._torch_func = lambda: self._value
+                    self._args = ()
+                elif issubclass(expr.func, sympy.UnevaluatedExpr):
+                    if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float):
+                        raise ValueError("UnevaluatedExpr should only be used to wrap floats.")
+                    self.register_buffer('_value', torch.tensor(float(expr.args[0])))
+                    self._torch_func = lambda: self._value
+                    self._args = ()
+                elif issubclass(expr.func, sympy.Integer):
+                    # Can get here if expr is one of the Integer special cases,
+                    # e.g. NegativeOne
+                    self._value = int(expr)
+                    self._torch_func = lambda: self._value
+                    self._args = ()
+                elif issubclass(expr.func, sympy.Symbol):
+                    self._name = expr.name
+                    self._torch_func = lambda value: value
+                    self._args = ((lambda memodict: memodict[expr.name]),)
+                else:
+                    self._torch_func = _func_lookup[expr.func]
+                    args = []
+                    for arg in expr.args:
+                        try:
+                            arg_ = _memodict[arg]
+                        except KeyError:
+                            arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
+                            _memodict[arg] = arg_
+                        args.append(arg_)
+                    self._args = torch.nn.ModuleList(args)
 
-        class PySRTorchModule(torch.nn.Module):
-            def __init__(self, *, expression, symbols_in,
+            def forward(self, memodict):
+                args = []
+                for arg in self._args:
+                    try:
+                        arg_ = memodict[arg]
+                    except KeyError:
+                        arg_ = arg(memodict)
+                        memodict[arg] = arg_
+                    args.append(arg_)
+                return self._torch_func(*args)
+
+
+        class SingleSymPyModule(torch.nn.Module):
+            """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
+            def __init__(self, expression, symbols_in,
                          selection=None, extra_funcs=None, **kwargs):
                 super().__init__(**kwargs)
-                self._module = sympytorch.SymPyModule(
-                        expressions=[expression],
-                        extra_funcs=extra_funcs)
-                self._selection = selection
-                self._symbols = symbols_in
 
+                if extra_funcs is None:
+                    extra_funcs = {}
+                _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
+
+                _memodict = {}
+                self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
+                self._expression_string = str(expression)
+                self._selection = selection
+                self.symbols_in = [str(symbol) for symbol in symbols_in]
+
+            def __repr__(self):
+                return f"{type(self).__name__}(expression={self._expression_string})"
+
             def forward(self, X):
                 if self._selection is not None:
                     X = X[:, self._selection]
-                symbols = {str(symbol): X[:, i]
-                           for i, symbol in enumerate(self._symbols)}
-                return self._module(**symbols)[..., 0]
+                symbols = {symbol: X[:, i]
+                           for i, symbol in enumerate(self.symbols_in)}
+                return self._node(symbols)


@@ -51,11 +168,10 @@ def sympy2torch(expression, symbols_in,
     This function will assume the input to the module is a matrix X, where
     each column corresponds to each symbol you pass in `symbols_in`.
     """
-    global PySRTorchModule
+    global SingleSymPyModule
 
     _initialize_torch()
 
-    return PySRTorchModule(expression=expression,
-                           symbols_in=symbols_in,
-                           selection=selection,
-                           extra_funcs=extra_torch_mappings)
+    return SingleSymPyModule(expression, symbols_in,
+                             selection=selection,
+                             extra_funcs=extra_torch_mappings)
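For context, a minimal usage sketch of the module this diff vendors in (variable names are illustrative; assumes `torch`, `sympy`, and this version of PySR are installed):

    import sympy
    import torch

    from pysr.export_torch import sympy2torch

    # Two symbols, so the module expects a matrix with two columns.
    x, y = sympy.symbols("x y")
    expression = x ** 2 + sympy.sin(y)

    module = sympy2torch(expression, [x, y])

    X = torch.rand(100, 2)  # column 0 -> x, column 1 -> y
    out = module(X)         # tensor of shape (100,)

Per the `_Node.__init__` branches above, any floating-point constants in the expression become trainable `torch.nn.Parameter` values in the resulting module.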
