Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch export #48

Merged
merged 16 commits into from
May 31, 2021
Merged

PyTorch export #48

merged 16 commits into from
May 31, 2021

Conversation

MilesCranmer
Copy link
Owner

This uses (a slightly-modified version of) @patrick-kidger's sympytorch to export discovered equations to PyTorch. Parameters in the module are set to the default as found by PySR, but are trainable.

This essentially allows one to do the deep learning -> symbolic regression technique from https://github.com/MilesCranmer/symbolic_deep_learning, but plug the discovered equation back into the network, without any manual work.

This is the same format as the JAX export - it expects a matrix X as input with columns corresponding to symbols, except it outputs a PyTorch module with embedded parameters instead of a (function, parameters) tuple.

@patrick-kidger, what do you think of this way of doing things? The modifications from your package include: expects matrix X as input rather than different vector for each symbol (to unify with the JAX backend) and only one expression per module (was unclear how to merge multiple expressions which the user may or may not use).

@patrick-kidger
Copy link

Seems reasonable as is. Although it may be worth leaving this as a separate package as per my comments in #35? No strong feelings.

@MilesCranmer
Copy link
Owner Author

Thanks!

Yeah, longterm it will probably be best to have a separate modular package for generic SymPy -> trainable modules for numpy/torch/jax/etc. But I guess in the short-term, PySR exports are a bit implementation-specific (for consistency, want an array X as input, rather than inputting vectors for each symbol like sympy.lambdify does), and as a PySR user, I'd rather just have it do the exporting itself without an extra package. Thoughts?

@patrick-kidger
Copy link

patrick-kidger commented May 31, 2021

I'd note that if you want an array then the approach I'd take would be something like

class PySRSymPyModule(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self._module = sympytorch.SymPyModule(**kwargs)

    def forward(self, X):
        symbols = {f"x{i}": xi for i, xi in enumerate(X)}
        return self._module(**symbols)

to simply wrap existing functionality rather than duplicating it wholesale.

At the end of the day no strong feelings though.

@MilesCranmer
Copy link
Owner Author

You are right; this is probably a cleaner way to do it. I tried implementing this on sympytorch-export.

However, I realized an issue: there is no good way to deal with the install of the exporter libraries. There's two tricky scenarios:

  • User doesn't want torch(/jax) export. So sympy->torch(/jax) and torch(/jax) shouldn't be installed.
  • User does want torch(/jax) export. But they only install torch(/jax) and PySR, not realizing they need a third library for exporting.

I think unfortunately the only good way to deal with these is to put the sympy->torch code here (as in this pull request) along with the sympy->jax code, then do a lazy initialization only if the exporter code is called...

@patrick-kidger
Copy link

So I think that should be doable as well. Pseudocode:

def export(obj_to_export):
    try:
        import export_library
    except ImportError as e:
        raise ImportError("Please additionally install `export_library` to export to `deep_learning_library`.") from e
    return export_library.export(obj_to_export)

If you want you could additionally declare export_library as an optional dependency.

@MilesCranmer
Copy link
Owner Author

MilesCranmer commented May 31, 2021

That sounds like a great way to do it. So, can one get setup.py to install sympytorch if torch is already installed? (But not otherwise)?

@MilesCranmer
Copy link
Owner Author

Maybe this?

import setuptools
import sys

...

setuptools.setup(
    ....
    install_requires=( [
            "numpy",
            "pandas",
            "sympy"
            ] + (
            ['git+https://github.com/patrick-kidger/sympytorch']
            if 'torch' in sys.modules else []
            )
        ),
    ...
)

@MilesCranmer
Copy link
Owner Author

MilesCranmer commented May 31, 2021

By the way, are you planning on putting sympytorch on PyPI? I can't seem to get it working in setup.py without any wheel in your repo or on PyPI.

@MilesCranmer
Copy link
Owner Author

Argh, found some more problems that are specific to PySR. I'm just going to merge the implementation-specific version and try to set up an integration with a generic package later. Maybe for a jax and numpy exporter too. (Maybe all in one repo! Since sympy.lambdify doesn't let one get the parameters out)

Also, can the 3.8 requirement for Python be reduced in sympytorch or is it needed?

@MilesCranmer MilesCranmer merged commit 45b290b into master May 31, 2021
@MilesCranmer MilesCranmer deleted the torch-export branch May 31, 2021 20:05
@patrick-kidger
Copy link

Worth noting that this:

['git+https://github.com/patrick-kidger/sympytorch'] if 'torch' in sys.modules else []

won't work as that only checks if torch has been imported, not installed.

At your request, I've just put sympytorch on PyPI, and loosened the requirements to Python 3.7. Let me know if it all seems to work for you-- if it does I'll update the install instructions.

@MilesCranmer
Copy link
Owner Author

Thanks!! This seems to have fixed all the issues! I now have sympytorch working on the master branch for PyTorch exports 👍

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants