Skip to content

Commit

Permalink
docs: recombine component-wise lambdified function
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed May 28, 2021
1 parent db93637 commit a62c544
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 8 deletions.
5 changes: 5 additions & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"lambdifies",
"lambdify",
"lambdifying",
"matplotlib",
"numpy",
"pylint",
"qrules",
Expand All @@ -79,10 +80,12 @@
"cmath",
"codemirror",
"commitlint",
"componentwise",
"csqrt",
"dotprint",
"evaluatable",
"expertsystem",
"filterwarnings",
"getsource",
"graphviz",
"hasattr",
Expand All @@ -100,8 +103,10 @@
"nbformat",
"numpycode",
"pandoc",
"phsp",
"py's",
"pygments",
"pyplot",
"pythoncode",
"redeboer",
"rtfd",
Expand Down
265 changes: 259 additions & 6 deletions docs/reports/sympy/lambdify-speedup.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,22 @@
},
"outputs": [],
"source": [
"import inspect\n",
"import logging\n",
"import warnings\n",
"\n",
"import ampform\n",
"import graphviz\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import qrules as q\n",
"import sympy as sp\n",
"from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff"
"from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff\n",
"from tensorwaves.data import generate_phsp\n",
"from tensorwaves.data.transform import HelicityTransformer\n",
"from tensorwaves.model import LambdifiedFunction, SympyModel\n",
"\n",
"logger = logging.getLogger()"
]
},
{
Expand All @@ -57,7 +66,6 @@
},
"outputs": [],
"source": [
"logger = logging.getLogger()\n",
"logger.setLevel(logging.ERROR)"
]
},
Expand Down Expand Up @@ -269,7 +277,6 @@
},
"outputs": [],
"source": [
"logger = logging.getLogger()\n",
"logger.setLevel(logging.INFO)"
]
},
Expand Down Expand Up @@ -301,9 +308,255 @@
"outputs": [],
"source": [
"%%time\n",
"for name, expr in amplitudes.items():\n",
" logging.info(f\"Lambdifying {name}\")\n",
" sp.lambdify(free_symbols, expr.doit(), \"numpy\")"
"np_amplitudes = {}\n",
"for expr, symbol in amplitude_to_symbol.items():\n",
" logging.info(f\"Lambdifying {symbol.name}\")\n",
" np_expr = sp.lambdify(free_symbols, expr.doit(), \"numpy\")\n",
" np_amplitudes[symbol] = np_expr"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Recombining lambdified components"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"{ref}`Recall <reports/sympy/lambdify-speedup:Structure of helicity model components>` what amplitude module expressed in its amplitude components looks like:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"amplitude_expr"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have to lambdify that top expression as well:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sorted_amplitude_symbols = sorted(np_amplitudes, key=lambda s: s.name)\n",
"np_amplitude_expr = sp.lambdify(\n",
" sorted_amplitude_symbols, amplitude_expr, \"numpy\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"source = inspect.getsource(np_amplitude_expr)\n",
"print(source)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now have a lambdified expression for the complete amplitude model, as well as lambdified expressions that are to be plugged in to its arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def componentwise_lambdified(*args):\n",
" \"\"\"Lambdified amplitude model, recombined from its amplitude components.\n",
"\n",
" .. warning:: Order of the ``args`` has to be the same as that\n",
" of the ``args`` of the lambdified amplitude components.\n",
" \"\"\"\n",
" amplitude_values = []\n",
" for amp_symbol in sorted_amplitude_symbols:\n",
" np_amplitude = np_amplitudes[amp_symbol]\n",
" values = np_amplitude(*args)\n",
" amplitude_values.append(values)\n",
" return np_amplitude_expr(*amplitude_values)"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"### Test with data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Okay, so does all this work? Let's first generate a phase space sample with good-old {mod}`tensorwaves`. We can then use this sample as input to the component-wise lambdified function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"logger.setLevel(logging.ERROR)\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"sympy_model = SympyModel(\n",
" expression=model.expression,\n",
" parameters=model.parameter_defaults,\n",
")\n",
"intensity = LambdifiedFunction(sympy_model, backend=\"jax\")\n",
"data_converter = HelicityTransformer(model.adapter)\n",
"phsp_sample = generate_phsp(10_000, model.adapter.reaction_info)\n",
"phsp_set = data_converter.transform(phsp_sample)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"plt.hist(phsp_set[\"m_12\"], bins=50, alpha=0.5, density=True)\n",
"plt.hist(\n",
" phsp_set[\"m_12\"],\n",
" bins=50,\n",
" alpha=0.5,\n",
" density=True,\n",
" weights=intensity(phsp_set),\n",
")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The arguments of the component-wise lambdified amplitude model should be covered by the entries in the phase space set and the provided parameter defaults:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"kinematic_variable_names = {key for key in phsp_set}\n",
"parameter_names = {symbol.name for symbol in model.parameter_defaults}\n",
"free_symbol_names = {symbol.name for symbol in free_symbols}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert free_symbol_names <= kinematic_variable_names ^ parameter_names"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That allows us to sort the input arrays and parameter defaults so that they can be used as positional argument input to the component-wise lambdified amplitude model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"merged_par_var_values = {\n",
" symbol.name: value for symbol, value in model.parameter_defaults.items()\n",
"}\n",
"merged_par_var_values.update(phsp_set)\n",
"args_values = [merged_par_var_values[symbol.name] for symbol in free_symbols]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, here's the result of plugging that back into the component-wise lambdified expression:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"componentwise_result = componentwise_lambdified(*args_values)\n",
"componentwise_result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And it's indeed the same as that the intensity computed by {mod}`tensorwaves` (direct lambdify):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tensorwaves_result = np.array(intensity(phsp_set))\n",
"mean_difference = (componentwise_result - tensorwaves_result).mean()\n",
"mean_difference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert mean_difference < 1e-9"
]
}
],
Expand Down
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ setup_requires =
install_requires =
ampform[viz]
expertsystem
jax
jaxlib
matplotlib
pandas
sympy
tensorwaves[jax]

[options.extras_require]
doc =
Expand Down

0 comments on commit a62c544

Please sign in to comment.