-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PyTorch export #48
PyTorch export #48
Conversation
cdf4798
to
134f41a
Compare
Seems reasonable as is. Although it may be worth leaving this as a separate package as per my comments in #35? No strong feelings. |
9269ce9
to
e7ede78
Compare
Thanks! Yeah, longterm it will probably be best to have a separate modular package for generic SymPy -> trainable modules for numpy/torch/jax/etc. But I guess in the short-term, PySR exports are a bit implementation-specific (for consistency, want an array |
I'd note that if you want an array then the approach I'd take would be something like class PySRSymPyModule(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
self._module = sympytorch.SymPyModule(**kwargs)
def forward(self, X):
symbols = {f"x{i}": xi for i, xi in enumerate(X)}
return self._module(**symbols) to simply wrap existing functionality rather than duplicating it wholesale. At the end of the day no strong feelings though. |
You are right; this is probably a cleaner way to do it. I tried implementing this on However, I realized an issue: there is no good way to deal with the install of the exporter libraries. There's two tricky scenarios:
I think unfortunately the only good way to deal with these is to put the sympy->torch code here (as in this pull request) along with the sympy->jax code, then do a lazy initialization only if the exporter code is called... |
So I think that should be doable as well. Pseudocode: def export(obj_to_export):
try:
import export_library
except ImportError as e:
raise ImportError("Please additionally install `export_library` to export to `deep_learning_library`.") from e
return export_library.export(obj_to_export) If you want you could additionally declare |
That sounds like a great way to do it. So, can one get |
Maybe this? import setuptools
import sys
...
setuptools.setup(
....
install_requires=( [
"numpy",
"pandas",
"sympy"
] + (
['git+https://github.com/patrick-kidger/sympytorch']
if 'torch' in sys.modules else []
)
),
...
) |
By the way, are you planning on putting sympytorch on PyPI? I can't seem to get it working in |
Argh, found some more problems that are specific to PySR. I'm just going to merge the implementation-specific version and try to set up an integration with a generic package later. Maybe for a jax and numpy exporter too. (Maybe all in one repo! Since sympy.lambdify doesn't let one get the parameters out) Also, can the 3.8 requirement for Python be reduced in sympytorch or is it needed? |
Worth noting that this:
won't work as that only checks if At your request, I've just put sympytorch on PyPI, and loosened the requirements to Python 3.7. Let me know if it all seems to work for you-- if it does I'll update the install instructions. |
Thanks!! This seems to have fixed all the issues! I now have sympytorch working on the master branch for PyTorch exports 👍 |
This uses (a slightly-modified version of) @patrick-kidger's sympytorch to export discovered equations to PyTorch. Parameters in the module are set to the default as found by PySR, but are trainable.
This essentially allows one to do the deep learning -> symbolic regression technique from https://github.com/MilesCranmer/symbolic_deep_learning, but plug the discovered equation back into the network, without any manual work.
This is the same format as the JAX export - it expects a matrix
X
as input with columns corresponding to symbols, except it outputs a PyTorch module with embedded parameters instead of a (function, parameters) tuple.@patrick-kidger, what do you think of this way of doing things? The modifications from your package include: expects matrix X as input rather than different vector for each symbol (to unify with the JAX backend) and only one expression per module (was unclear how to merge multiple expressions which the user may or may not use).