diff --git a/.cspell.json b/.cspell.json index df8d82ea..0e244bdf 100644 --- a/.cspell.json +++ b/.cspell.json @@ -63,6 +63,7 @@ "lambdifies", "lambdify", "lambdifying", + "matplotlib", "numpy", "pylint", "qrules", @@ -79,10 +80,12 @@ "cmath", "codemirror", "commitlint", + "componentwise", "csqrt", "dotprint", "evaluatable", "expertsystem", + "filterwarnings", "getsource", "graphviz", "hasattr", @@ -100,8 +103,10 @@ "nbformat", "numpycode", "pandoc", + "phsp", "py's", "pygments", + "pyplot", "pythoncode", "redeboer", "rtfd", diff --git a/docs/reports/sympy/lambdify-speedup.ipynb b/docs/reports/sympy/lambdify-speedup.ipynb index 66374935..98f01e5f 100644 --- a/docs/reports/sympy/lambdify-speedup.ipynb +++ b/docs/reports/sympy/lambdify-speedup.ipynb @@ -24,13 +24,22 @@ }, "outputs": [], "source": [ + "import inspect\n", "import logging\n", + "import warnings\n", "\n", "import ampform\n", "import graphviz\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "import qrules as q\n", "import sympy as sp\n", - "from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff" + "from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff\n", + "from tensorwaves.data import generate_phsp\n", + "from tensorwaves.data.transform import HelicityTransformer\n", + "from tensorwaves.model import LambdifiedFunction, SympyModel\n", + "\n", + "logger = logging.getLogger()" ] }, { @@ -57,7 +66,6 @@ }, "outputs": [], "source": [ - "logger = logging.getLogger()\n", "logger.setLevel(logging.ERROR)" ] }, @@ -269,7 +277,6 @@ }, "outputs": [], "source": [ - "logger = logging.getLogger()\n", "logger.setLevel(logging.INFO)" ] }, @@ -301,9 +308,255 @@ "outputs": [], "source": [ "%%time\n", - "for name, expr in amplitudes.items():\n", - " logging.info(f\"Lambdifying {name}\")\n", - " sp.lambdify(free_symbols, expr.doit(), \"numpy\")" + "np_amplitudes = {}\n", + "for expr, symbol in amplitude_to_symbol.items():\n", + " logging.info(f\"Lambdifying {symbol.name}\")\n", + " np_expr = sp.lambdify(free_symbols, expr.doit(), \"numpy\")\n", + " np_amplitudes[symbol] = np_expr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Recombining lambdified components" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "{ref}`Recall ` what amplitude module expressed in its amplitude components looks like:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "amplitude_expr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We have to lambdify that top expression as well:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sorted_amplitude_symbols = sorted(np_amplitudes, key=lambda s: s.name)\n", + "np_amplitude_expr = sp.lambdify(\n", + " sorted_amplitude_symbols, amplitude_expr, \"numpy\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "source = inspect.getsource(np_amplitude_expr)\n", + "print(source)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now have a lambdified expression for the complete amplitude model, as well as lambdified expressions that are to be plugged in to its arguments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def componentwise_lambdified(*args):\n", + " \"\"\"Lambdified amplitude model, recombined from its amplitude components.\n", + "\n", + " .. warning:: Order of the ``args`` has to be the same as that\n", + " of the ``args`` of the lambdified amplitude components.\n", + " \"\"\"\n", + " amplitude_values = []\n", + " for amp_symbol in sorted_amplitude_symbols:\n", + " np_amplitude = np_amplitudes[amp_symbol]\n", + " values = np_amplitude(*args)\n", + " amplitude_values.append(values)\n", + " return np_amplitude_expr(*amplitude_values)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "### Test with data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Okay, so does all this work? Let's first generate a phase space sample with good-old {mod}`tensorwaves`. We can then use this sample as input to the component-wise lambdified function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "logger.setLevel(logging.ERROR)\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "sympy_model = SympyModel(\n", + " expression=model.expression,\n", + " parameters=model.parameter_defaults,\n", + ")\n", + "intensity = LambdifiedFunction(sympy_model, backend=\"jax\")\n", + "data_converter = HelicityTransformer(model.adapter)\n", + "phsp_sample = generate_phsp(10_000, model.adapter.reaction_info)\n", + "phsp_set = data_converter.transform(phsp_sample)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "plt.hist(phsp_set[\"m_12\"], bins=50, alpha=0.5, density=True)\n", + "plt.hist(\n", + " phsp_set[\"m_12\"],\n", + " bins=50,\n", + " alpha=0.5,\n", + " density=True,\n", + " weights=intensity(phsp_set),\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The arguments of the component-wise lambdified amplitude model should be covered by the entries in the phase space set and the provided parameter defaults:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "kinematic_variable_names = {key for key in phsp_set}\n", + "parameter_names = {symbol.name for symbol in model.parameter_defaults}\n", + "free_symbol_names = {symbol.name for symbol in free_symbols}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert free_symbol_names <= kinematic_variable_names ^ parameter_names" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That allows us to sort the input arrays and parameter defaults so that they can be used as positional argument input to the component-wise lambdified amplitude model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "merged_par_var_values = {\n", + " symbol.name: value for symbol, value in model.parameter_defaults.items()\n", + "}\n", + "merged_par_var_values.update(phsp_set)\n", + "args_values = [merged_par_var_values[symbol.name] for symbol in free_symbols]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, here's the result of plugging that back into the component-wise lambdified expression:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "componentwise_result = componentwise_lambdified(*args_values)\n", + "componentwise_result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And it's indeed the same as that the intensity computed by {mod}`tensorwaves` (direct lambdify):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "tensorwaves_result = np.array(intensity(phsp_set))\n", + "mean_difference = (componentwise_result - tensorwaves_result).mean()\n", + "mean_difference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert mean_difference < 1e-9" ] } ], diff --git a/setup.cfg b/setup.cfg index cf7aa958..608f9a8e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,9 +17,10 @@ setup_requires = install_requires = ampform[viz] expertsystem - jax - jaxlib + matplotlib + pandas sympy + tensorwaves[jax] [options.extras_require] doc =