Simplify construction of foldable expression with PEP 681 #364

Closed
5 of 11 tasks
redeboer opened this issue Nov 23, 2023 · 0 comments · Fixed by #365
Labels: ✨ Feature (New feature added to the package)

@redeboer (Member) commented Nov 23, 2023

Currently, AmpForm provides the following mechanism for defining 'foldable' expression classes (`UnevaluatedExpression`). These make large expressions easier to read by splitting them into named, folded definitions (see also `/adr/002` in the documentation). The example below is taken from ComPWA/compwa.github.io#204; for another example, see compwa.github.io/polarimetry/appendix/dynamics.html:

```python
from __future__ import annotations  # allows the forward references in return annotations

import sympy as sp
from ampform.io import aslatex
from ampform.sympy import UnevaluatedExpression, create_expression, implement_doit_method
from IPython.display import Math


@implement_doit_method
class ChewMandelstamm(UnevaluatedExpression):
    is_commutative = True
    is_real = False

    def __new__(cls, s, m1, m2, **hints) -> ChewMandelstamm:
        return create_expression(cls, s, m1, m2, **hints)

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        q = BreakupMomentum(s, m1, m2)
        return (
            1
            / (16 * (sp.pi) ** 2)
            * (
                (2 * q / sp.sqrt(s))
                * sp.log((m1**2 + m2**2 - s + 2 * sp.sqrt(s) * q) / (2 * m1 * m2))
                - (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * sp.log(m1 / m2)
            )
        )

    def _latex(self, printer, *args) -> str:
        s = printer._print(self.args[0])
        return Rf"\rho\left({s}\right)"


@implement_doit_method
class BreakupMomentum(UnevaluatedExpression):
    is_commutative = True

    def __new__(cls, s, m1, m2, **hints) -> BreakupMomentum:
        return create_expression(cls, s, m1, m2, **hints)

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2) / (s * 4))

    def _latex(self, printer, *args) -> str:
        s = printer._print(self.args[0])
        return Rf"q\left({s}\right)"


# Render each expression next to its one-level unfolding (deep=False)
s, m1, m2 = sp.symbols("s m1 m2")
expressions = [
    ChewMandelstamm(s, m1, m2),
    BreakupMomentum(s, m1, m2),
]
Math(aslatex({expr: expr.doit(deep=False) for expr in expressions}))
```


The main problem is that there is way too much boilerplate code here. Ideally, the above example could be rewritten into something like the following, where `unevaluated_expression` behaves just like `attrs.define()`:

```python
import sympy as sp
from ampform.sympy import unevaluated_expression  # proposed decorator, does not exist yet


@unevaluated_expression(commutative=True, real=False)
class ChewMandelstamm:
    s: sp.Basic
    m1: sp.Basic
    m2: sp.Basic

    def _implementation_(self) -> sp.Expr:
        s, m1, m2 = self.args
        q = BreakupMomentum(s, m1, m2)
        return (
            1
            / (16 * (sp.pi) ** 2)
            * (
                (2 * q / sp.sqrt(s))
                * sp.log((m1**2 + m2**2 - s + 2 * sp.sqrt(s) * q) / (2 * m1 * m2))
                - (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * sp.log(m1 / m2)
            )
        )

    def _latex_repr_(self) -> str:
        s = self.latex_args[0]
        return Rf"\rho\left({s}\right)"


@unevaluated_expression(commutative=True)
class BreakupMomentum:
    s: sp.Basic
    m1: sp.Basic
    m2: sp.Basic

    def _implementation_(self) -> sp.Expr:
        s, m1, m2 = self.args
        return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2) / (s * 4))

    def _latex_repr_(self) -> str:
        s = self.latex_args[0]
        return Rf"q\left({s}\right)"
```

Of course, even nicer would be something like this:

```python
@unevaluated_expression(commutative=True)
class BreakupMomentum:
    s: sp.Basic
    m1: sp.Basic
    m2: sp.Basic
    _latex_repr_ = R"q\left({s}\right)"

    def _implementation_(self) -> sp.Expr:
        # the fields are referenced by name, without self
        return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2) / (s * 4))
```

but that would be hard, if not impossible, to implement.
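At least the string-template part of this idea looks feasible. The sketch below is a minimal illustration, assuming that the placeholders in `_latex_repr_` are named after the expression's arguments. `LatexTemplateMixin` and `_fields` are hypothetical names, not AmpForm API. The sketch relies on SymPy's `LatexPrinter` calling an expression's `_latex(printer)` method when the class defines one.

```python
import sympy as sp


class LatexTemplateMixin:
    """Hypothetical mixin: render the ``_latex_repr_`` template in LaTeX output."""

    _fields: tuple[str, ...] = ()  # argument names, in the order of self.args
    _latex_repr_: str = ""  # str.format template with one placeholder per field

    def _latex(self, printer, *args, **kwargs) -> str:
        # Print each argument with the active LaTeX printer, then fill in the template
        printed = {
            name: printer._print(arg) for name, arg in zip(self._fields, self.args)
        }
        return self._latex_repr_.format(**printed)


class BreakupMomentum(LatexTemplateMixin, sp.Expr):
    is_commutative = True
    _fields = ("s", "m1", "m2")
    _latex_repr_ = R"q\left({s}\right)"


s, m1, m2 = sp.symbols("s m1 m2")
print(sp.latex(BreakupMomentum(s, m1, m2)))  # q\left(s\right)
```

Because `str.format` treats braces as placeholders, literal LaTeX braces in such a template have to be doubled, e.g. `R"\frac{{1}}{{{s}}}"` renders as `\frac{1}{s}`.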

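We probably have to use `typing.dataclass_transform` (see PEP 681), so that static type checkers understand that the `unevaluated_expression` decorator gives the class 'dataclass'-like behavior, such as a constructor signature derived from the annotated fields. The decorator itself still has to implement that behavior at runtime.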

@redeboer added the ✨ Feature label Nov 23, 2023
@redeboer self-assigned this Nov 23, 2023
@redeboer changed the title from "Simplify construction of foldable expression with" to "Simplify construction of foldable expression with PEP 681" Nov 23, 2023