Skip to content

Commit

Permalink
FEAT: implement perform_cached_doit() (#333)
Browse files Browse the repository at this point in the history
* DOC: show expression unfolding with `perform_cached_doit`
* DX: activate `pyright` strict checking mode
* FIX: implement stable hashing for `EnergyDependentWidth`
* MAINT: fix indent `.flake8` config
* MAINT: sort options in `tox.ini` alphabetically
* MAINT: set `PYTHONHASHSEED=0` in Conda and tox environments
  • Loading branch information
redeboer committed Oct 18, 2022
1 parent 107e249 commit 8eb0cf1
Show file tree
Hide file tree
Showing 14 changed files with 327 additions and 83 deletions.
105 changes: 53 additions & 52 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,63 +1,64 @@
[flake8]
application-import-names =
ampform
ampform
filename =
./docs/*.py
./src/*.py
./tests/*.py
./docs/*.py
./src/*.py
./tests/*.py
exclude =
**/__pycache__
**/_build
*.pyi
/typings/**
**/__pycache__
**/_build
*.pyi
/typings/**
ignore =
# False positive with attribute docstrings
B018
# https://github.com/psf/black#slices
E203
# allowed by black
E231
# https://github.com/psf/black#line-length
E501
# should be possible to use {} in latex strings
FS003
# block quote ends without a blank line (black formatting)
RST201
# missing pygments
RST299
# unexpected indentation (related to google style docstring)
RST301
# false-positive error in math directive
RST307
# enforce type ignore with mypy error codes (combined --extend-select=TI100)
TI1
# https://github.com/psf/black#line-breaks--binary-operators
W503
# False positive with attribute docstrings
B018
# https://github.com/psf/black#slices
E203
# allowed by black
E231
# https://github.com/psf/black#line-length
E501
# should be possible to use {} in latex strings
FS003
# block quote ends without a blank line (black formatting)
RST201
# missing pygments
RST299
# unexpected indentation (related to google style docstring)
RST301
# false-positive error in math directive
RST307
# enforce type ignore with mypy error codes (combined --extend-select=TI100)
TI1
# https://github.com/psf/black#line-breaks--binary-operators
W503
extend-select =
TI100
TI100
per-file-ignores =
# unused imports for backward compatibility
src/ampform/dynamics/__init__.py:F401
# unused imports for backward compatibility
src/ampform/dynamics/__init__.py:F401
tests/sympy/test_caching.py:C408
radon-max-cc = 8
radon-no-assert = True
rst-roles =
attr
cite
class
doc
download
eq
file
func
meth
mod
pdg-review
ref
term
attr
cite
class
doc
download
eq
file
func
meth
mod
pdg-review
ref
term
rst-directives =
autolink-preface
automethod
deprecated
envvar
exception
seealso
autolink-preface
automethod
deprecated
envvar
exception
seealso
2 changes: 2 additions & 0 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
name: pytest
env:
PYTHONHASHSEED: "0"

on:
push:
Expand Down
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
],
"python.analysis.autoImportCompletions": false,
"python.analysis.diagnosticMode": "workspace",
"python.analysis.typeCheckingMode": "strict",
"python.formatting.provider": "black",
"python.linting.banditEnabled": false,
"python.linting.enabled": true,
Expand Down
2 changes: 1 addition & 1 deletion docs/_extend_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def _graphviz_to_image( # pylint: disable=too-many-arguments
options = {}
global _GRAPHVIZ_COUNTER # pylint: disable=global-statement
output_file = f"graphviz_{_GRAPHVIZ_COUNTER}"
_GRAPHVIZ_COUNTER += 1
_GRAPHVIZ_COUNTER += 1 # pyright: reportConstantRedefinition=false
graphviz.Source(dot).render(f"{_IMAGE_DIR}/{output_file}", format=format)
restructuredtext = "\n"
if label:
Expand Down
3 changes: 3 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@

import requests

# pyright: reportConstantRedefinition=false
# pyright: reportMissingImports=false
# pyright: reportUntypedBaseClass=false
# pyright: reportUntypedFunctionDecorator=false
from pybtex.database import Entry
from pybtex.plugin import register_plugin
from pybtex.richtext import Tag, Text
Expand Down
38 changes: 37 additions & 1 deletion docs/usage/amplitude.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,34 @@
" model = pickle.load(stream)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cached expression 'unfolding'\n",
"\n",
"Amplitude model expressions can be extremely large. AmpForm can formulate such expressions relatively fast, but {mod}`sympy` has to 'unfold' these expressions with {meth}`~sympy.core.basic.Basic.doit`, which can take a long time. AmpForm provides a function that can cache the 'unfolded' expression to disk, so that the expression unfolding runs faster upon the next run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ampform.sympy import perform_cached_doit\n",
"\n",
"full_expression = perform_cached_doit(model.expression)\n",
"sp.count_ops(full_expression)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See {func}`.perform_cached_doit` for some tips on how to improve performance."
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -1046,8 +1074,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.8.12"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ dependencies:
- |
-c .constraints/py3.8.txt
-e .[dev]
variables:
PYTHONHASHSEED: 0
16 changes: 15 additions & 1 deletion pyrightconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,24 @@
"exclude": [".git", ".tox", "docs/_build", "docs/adr"],
"include": ["docs", "src", "tests"],
"reportGeneralTypeIssues": false,
"reportIncompatibleMethodOverride": false,
"reportMissingParameterType": false,
"reportMissingTypeArgument": false,
"reportMissingTypeStubs": false,
"reportOverlappingOverload": false,
"reportPrivateImportUsage": false,
"reportPrivateUsage": false,
"reportUnboundVariable": false,
"reportUnknownArgumentType": false,
"reportUnknownMemberType": false,
"reportUnknownParameterType": false,
"reportUnknownVariableType": false,
"reportUnnecessaryComparison": false,
"reportUnnecessaryContains": false,
"reportUnnecessaryIsInstance": false,
"reportUnusedClass": true,
"reportUnusedFunction": true,
"reportUnusedImport": true,
"reportUnusedVariable": true
"reportUnusedVariable": true,
"typeCheckingMode": "strict"
}
3 changes: 2 additions & 1 deletion src/ampform/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def __getnewargs_ex__(self) -> tuple[tuple, dict]:

def _hashable_content(self) -> tuple:
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165
return (*self.args, self.phsp_factor, self._name)
# phsp_factor is converted to string because of unstable hash for classes
return (*super()._hashable_content(), str(self.phsp_factor))

def evaluate(self) -> sp.Expr:
s, mass0, gamma0, m_a, m_b, angular_momentum, meson_radius = self.args
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/helicity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def unfold_poolsums(expr: sp.Expr) -> sp.Expr:

intensity = self.intensity.evaluate()
intensity = unfold_poolsums(intensity)
return intensity.subs(self.amplitudes)
return intensity.xreplace(self.amplitudes)

def rename_symbols( # noqa: R701
self, renames: Iterable[tuple[str, str]] | Mapping[str, str]
Expand Down
84 changes: 84 additions & 0 deletions src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,23 @@
from __future__ import annotations

import functools
import hashlib
import itertools
import logging
import os
import pickle
from abc import abstractmethod
from os.path import abspath, dirname, expanduser
from textwrap import dedent
from typing import Callable, Iterable, Sequence, SupportsFloat, TypeVar

import sympy as sp
from sympy.printing.latex import LatexPrinter
from sympy.printing.numpy import NumPyPrinter
from sympy.printing.precedence import PRECEDENCE

_LOGGER = logging.getLogger(__name__)


class UnevaluatedExpression(sp.Expr):
"""Base class for expression classes with an :meth:`evaluate` method.
Expand Down Expand Up @@ -93,6 +101,11 @@ def __getnewargs_ex__(self) -> tuple[tuple, dict]:
kwargs = {"name": self._name}
return args, kwargs

def _hashable_content(self) -> tuple:
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165
# name is converted to string because unstable hash for None
return (*super()._hashable_content(), str(self._name))

@abstractmethod
def evaluate(self) -> sp.Expr:
"""Evaluate and 'unfold' this `UnevaluatedExpression` by one level.
Expand Down Expand Up @@ -456,3 +469,74 @@ def _is_regular_series(values: Sequence[SupportsFloat]) -> bool:
if difference != 1.0:
return False
return True


def perform_cached_doit(
unevaluated_expr: sp.Expr, directory: str | None = None
) -> sp.Expr:
"""Perform :meth:`~sympy.core.basic.Basic.doit` cache the result to disk.
The cached result is fetched from disk if the hash of the original expression is the
same as the hash embedded in the filename.
Args:
unevaluated_expr: A `sympy.Expr <sympy.core.expr.Expr>` on which to call
:meth:`~sympy.core.basic.Basic.doit`.
directory: The directory in which to cache the result. If `None`, the cache
directory will be put under the home directory.
.. tip:: For a faster cache, set `PYTHONHASHSEED
<https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_ to a
fixed value.
"""
if directory is None:
home_directory = expanduser("~")
directory = abspath(f"{home_directory}/.sympy-cache")
h = get_readable_hash(unevaluated_expr)
filename = f"{directory}/{h}.pkl"
os.makedirs(dirname(filename), exist_ok=True)
if os.path.exists(filename):
with open(filename, "rb") as f:
return pickle.load(f)
_LOGGER.warning(
f"Cached expression file {filename} not found, performing doit()..."
)
unfolded_expr = unevaluated_expr.doit()
with open(filename, "wb") as f:
pickle.dump(unfolded_expr, f)
return unfolded_expr


def get_readable_hash(obj) -> str:
python_hash_seed = _get_python_hash_seed()
if python_hash_seed is not None:
return f"pythonhashseed-{python_hash_seed}{hash(obj):+}"
b = _to_bytes(obj)
return hashlib.sha256(b).hexdigest()


def _to_bytes(obj) -> bytes:
if isinstance(obj, sp.Expr):
# Using the str printer is slower and not necessarily unique,
# but pickle.dumps() does not always result in the same bytes stream.
_warn_about_unsafe_hash()
return str(obj).encode()
return pickle.dumps(obj)


def _get_python_hash_seed() -> int | None:
python_hash_seed = os.environ.get("PYTHONHASHSEED", "")
if python_hash_seed is not None and python_hash_seed.isdigit():
return int(python_hash_seed)
return None


@functools.lru_cache(maxsize=None) # warn once
def _warn_about_unsafe_hash():
message = """
PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions,
set the PYTHONHASHSEED environment variable to a fixed value and rerun the program.
See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED
"""
message = dedent(message).replace("\n", " ").strip()
_LOGGER.warning(message)
1 change: 1 addition & 0 deletions tests/dynamics/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def test_breit_wigner_with_energy_dependent_width(

builder.form_factor = True
bw_with_ff, parameters = builder(particle, variable_set)
# pyright: reportConstantRedefinition=false
L = variable_set.angular_momentum # noqa: N806
form_factor = formulate_form_factor(
s, m1, m2, angular_momentum=L, meson_radius=d
Expand Down

0 comments on commit 8eb0cf1

Please sign in to comment.