Skip to content

Commit

Permalink
feat: support lambdify to tensorflow (#245)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer authored Mar 24, 2021
1 parent 10ba92f commit ea97355
Show file tree
Hide file tree
Showing 21 changed files with 51 additions and 30 deletions.
6 changes: 3 additions & 3 deletions reqs/3.6/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ jupyter-client==6.1.12
jupyter-console==6.3.0
jupyter-core==4.7.1
jupyter-packaging==0.7.12
jupyter-server==1.4.1
jupyter-server==1.5.0
jupyter-sphinx==0.3.1
jupyter==1.0.0
jupyterlab-code-formatter==1.4.5
Expand All @@ -94,7 +94,7 @@ jupyterlab==3.0.12
keras-preprocessing==1.1.2
kiwisolver==1.3.1
labels==20.1.0
lazy-object-proxy==1.5.2
lazy-object-proxy==1.6.0
livereload==2.6.3
llvmlite==0.36.0
markdown-it-py==0.6.2
Expand Down Expand Up @@ -126,7 +126,7 @@ packaging==20.9
pandas==1.1.5
pandocfilters==1.4.3
parso==0.8.1
particle==0.14.0
particle==0.14.1
pathspec==0.8.1
pep517==0.10.0
pep8-naming==0.11.1
Expand Down
2 changes: 1 addition & 1 deletion reqs/3.6/requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ packaging==20.9
pandas==1.1.5
pandocfilters==1.4.3
parso==0.8.1
particle==0.14.0
particle==0.14.1
pexpect==4.8.0
phasespace==1.2.0
pickleshare==0.7.5
Expand Down
2 changes: 1 addition & 1 deletion reqs/3.6/requirements-extras.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ numba==0.53.0
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
particle==0.14.0
particle==0.14.1
phasespace==1.2.0
protobuf==3.15.6
pyasn1-modules==0.2.8
Expand Down
4 changes: 2 additions & 2 deletions reqs/3.6/requirements-sty.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jaxlib==0.1.64
jsonschema==3.2.0
jupyter-core==4.7.1
keras-preprocessing==1.1.2
lazy-object-proxy==1.5.2
lazy-object-proxy==1.6.0
llvmlite==0.36.0
markdown==3.3.4
mccabe==0.6.1
Expand All @@ -69,7 +69,7 @@ numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
packaging==20.9
particle==0.14.0
particle==0.14.1
pathspec==0.8.1
pep8-naming==0.11.1
phasespace==1.2.0
Expand Down
2 changes: 1 addition & 1 deletion reqs/3.6/requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
packaging==20.9
particle==0.14.0
particle==0.14.1
phasespace==1.2.0
pluggy==0.13.1
protobuf==3.15.6
Expand Down
2 changes: 1 addition & 1 deletion reqs/3.6/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ mpmath==1.2.1
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
particle==0.14.0
particle==0.14.1
phasespace==1.2.0
protobuf==3.15.6
pyasn1-modules==0.2.8
Expand Down
6 changes: 3 additions & 3 deletions reqs/3.7/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ jupyter-client==6.1.12
jupyter-console==6.3.0
jupyter-core==4.7.1
jupyter-packaging==0.7.12
jupyter-server==1.4.1
jupyter-server==1.5.0
jupyter-sphinx==0.3.1
jupyter==1.0.0
jupyterlab-code-formatter==1.4.5
Expand All @@ -90,7 +90,7 @@ jupyterlab==3.0.12
keras-preprocessing==1.1.2
kiwisolver==1.3.1
labels==20.1.0
lazy-object-proxy==1.5.2
lazy-object-proxy==1.6.0
livereload==2.6.3
llvmlite==0.36.0
markdown-it-py==0.6.2
Expand Down Expand Up @@ -122,7 +122,7 @@ packaging==20.9
pandas==1.2.3
pandocfilters==1.4.3
parso==0.8.1
particle==0.14.0
particle==0.14.1
pathspec==0.8.1
pep517==0.10.0
pep8-naming==0.11.1
Expand Down
2 changes: 1 addition & 1 deletion reqs/3.7/requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ packaging==20.9
pandas==1.2.3
pandocfilters==1.4.3
parso==0.8.1
particle==0.14.0
particle==0.14.1
pexpect==4.8.0
phasespace==1.2.0
pickleshare==0.7.5
Expand Down
2 changes: 1 addition & 1 deletion reqs/3.7/requirements-extras.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ numba==0.53.0
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
particle==0.14.0
particle==0.14.1
phasespace==1.2.0
protobuf==3.15.6
pyasn1-modules==0.2.8
Expand Down
4 changes: 2 additions & 2 deletions reqs/3.7/requirements-sty.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jaxlib==0.1.64
jsonschema==3.2.0
jupyter-core==4.7.1
keras-preprocessing==1.1.2
lazy-object-proxy==1.5.2
lazy-object-proxy==1.6.0
llvmlite==0.36.0
markdown==3.3.4
mccabe==0.6.1
Expand All @@ -67,7 +67,7 @@ numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
packaging==20.9
particle==0.14.0
particle==0.14.1
pathspec==0.8.1
pep8-naming==0.11.1
phasespace==1.2.0
Expand Down
2 changes: 1 addition & 1 deletion reqs/3.7/requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
packaging==20.9
particle==0.14.0
particle==0.14.1
phasespace==1.2.0
pluggy==0.13.1
protobuf==3.15.6
Expand Down
2 changes: 1 addition & 1 deletion reqs/3.7/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ mpmath==1.2.1
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
particle==0.14.0
particle==0.14.1
phasespace==1.2.0
protobuf==3.15.6
pyasn1-modules==0.2.8
Expand Down
6 changes: 3 additions & 3 deletions reqs/3.8/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ jupyter-client==6.1.12
jupyter-console==6.3.0
jupyter-core==4.7.1
jupyter-packaging==0.7.12
jupyter-server==1.4.1
jupyter-server==1.5.0
jupyter-sphinx==0.3.1
jupyter==1.0.0
jupyterlab-code-formatter==1.4.5
Expand All @@ -90,7 +90,7 @@ jupyterlab==3.0.12
keras-preprocessing==1.1.2
kiwisolver==1.3.1
labels==20.1.0
lazy-object-proxy==1.5.2
lazy-object-proxy==1.6.0
livereload==2.6.3
llvmlite==0.36.0
markdown-it-py==0.6.2
Expand Down Expand Up @@ -122,7 +122,7 @@ packaging==20.9
pandas==1.2.3
pandocfilters==1.4.3
parso==0.8.1
particle==0.14.0
particle==0.14.1
pathspec==0.8.1
pep517==0.10.0
pep8-naming==0.11.1
Expand Down
2 changes: 1 addition & 1 deletion reqs/3.8/requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ packaging==20.9
pandas==1.2.3
pandocfilters==1.4.3
parso==0.8.1
particle==0.14.0
particle==0.14.1
pexpect==4.8.0
phasespace==1.2.0
pickleshare==0.7.5
Expand Down
2 changes: 1 addition & 1 deletion reqs/3.8/requirements-extras.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ numba==0.53.0
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
particle==0.14.0
particle==0.14.1
phasespace==1.2.0
protobuf==3.15.6
pyasn1-modules==0.2.8
Expand Down
4 changes: 2 additions & 2 deletions reqs/3.8/requirements-sty.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jaxlib==0.1.64
jsonschema==3.2.0
jupyter-core==4.7.1
keras-preprocessing==1.1.2
lazy-object-proxy==1.5.2
lazy-object-proxy==1.6.0
llvmlite==0.36.0
markdown==3.3.4
mccabe==0.6.1
Expand All @@ -66,7 +66,7 @@ numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
packaging==20.9
particle==0.14.0
particle==0.14.1
pathspec==0.8.1
pep8-naming==0.11.1
phasespace==1.2.0
Expand Down
2 changes: 1 addition & 1 deletion reqs/3.8/requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
packaging==20.9
particle==0.14.0
particle==0.14.1
phasespace==1.2.0
pluggy==0.13.1
protobuf==3.15.6
Expand Down
2 changes: 1 addition & 1 deletion reqs/3.8/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ mpmath==1.2.1
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
particle==0.14.0
particle==0.14.1
phasespace==1.2.0
protobuf==3.15.6
pyasn1-modules==0.2.8
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ install_requires =
phasespace >= 1.2.0
PyYAML
sympy
tensorflow >= 2.0
tensorflow >= 2.4 # tensorflow.experimental.numpy
tqdm >= 4.24.0 # autonotebook
typing-extensions; python_version < "3.8.0"
packages = find:
Expand Down
23 changes: 22 additions & 1 deletion src/tensorwaves/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def get_backend_modules(
return (jnp, jsp.special)
if backend in {"numpy", "numba"}:
return np.__dict__
if backend in {"tensorflow", "tf"}:
# pylint: disable=import-error
import tensorflow.experimental.numpy as tnp # pyright: reportMissingImports=false

return tnp.__dict__

return backend

Expand Down Expand Up @@ -119,7 +124,7 @@ def __init__(

def lambdify(self, backend: Union[str, tuple, dict]) -> Callable:
"""Lambdify the model using `~sympy.utilities.lambdify.lambdify`."""
# pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel,too-many-return-statements
ordered_symbols = self.__argument_order

def jax_lambdify() -> Callable:
Expand Down Expand Up @@ -147,17 +152,33 @@ def numba_lambdify() -> Callable:
parallel=True,
)

def tensorflow_lambdify() -> Callable:
# pylint: disable=import-error
import tensorflow.experimental.numpy as tnp # pyright: reportMissingImports=false

return sp.lambdify(
ordered_symbols,
self.__expression,
modules=tnp,
)

backend_modules = get_backend_modules(backend)
if isinstance(backend, str):
if backend == "jax":
return jax_lambdify()
if backend == "numba":
return numba_lambdify()
if backend in {"tensorflow", "tf"}:
return tensorflow_lambdify()
if isinstance(backend, tuple):
if any("jax" in x.__name__ for x in backend):
return jax_lambdify()
if any("numba" in x.__name__ for x in backend):
return numba_lambdify()
if any("tensorflow" in x.__name__ for x in backend) or any(
"tf" in x.__name__ for x in backend
):
return tensorflow_lambdify()
return sp.lambdify(
ordered_symbols,
self.__expression,
Expand Down
2 changes: 1 addition & 1 deletion src/tensorwaves/optimizer/minuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def wrapped_function(pars: list) -> float:
update_parameters(pars)
parameters = parameter_handler.unflatten(flattened_parameters)
estimator_value = estimator(parameters)
progress_bar.set_postfix({"estimator": estimator_value})
progress_bar.set_postfix({"estimator": float(estimator_value)})
progress_bar.update()
logs = create_log(estimator_value, parameters)
self.__callback.on_function_call_end(n_function_calls, logs)
Expand Down

0 comments on commit ea97355

Please sign in to comment.