Skip to content

Commit

Permalink
DX: lint type hints with MyPy (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Apr 25, 2024
1 parent e3fc71f commit b5d3195
Show file tree
Hide file tree
Showing 16 changed files with 141 additions and 69 deletions.
1 change: 1 addition & 0 deletions .gitpod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ vscode:
- github.vscode-github-actions
- github.vscode-pull-request-github
- ms-python.python
- ms-python.mypy-type-checker
- ms-python.vscode-pylance
- ms-toolsai.vscode-jupyter-cell-tags
- ms-vscode.live-server
Expand Down
11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ ci:
autoupdate_commit_msg: "MAINT: update pip constraints and pre-commit"
autoupdate_schedule: quarterly # already done by requirements-cron.yml
skip:
- mypy
- prettier
- pyright
- taplo
Expand Down Expand Up @@ -108,6 +109,16 @@ repos:
.*\.py
)$
- repo: local
hooks:
- id: mypy
name: mypy
entry: mypy
language: system
require_serial: true
types:
- python

- repo: https://github.com/ComPWA/mirrors-pyright
rev: v1.1.359
hooks:
Expand Down
2 changes: 1 addition & 1 deletion .vscode/extensions.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"github.vscode-github-actions",
"github.vscode-pull-request-github",
"ms-python.python",
"ms-python.mypy-type-checker",
"ms-python.vscode-pylance",
"ms-toolsai.vscode-jupyter-cell-tags",
"ms-vscode.live-server",
Expand All @@ -27,7 +28,6 @@
"ms-python.black-formatter",
"ms-python.flake8",
"ms-python.isort",
"ms-python.mypy-type-checker",
"ms-python.pylint",
"travisillig.vscode-json-stable-stringify",
"tyriar.sort-lines"
Expand Down
7 changes: 7 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
"gitlens.telemetry.enabled": false,
"livePreview.defaultPreviewPath": "docs/_build/html",
"multiDiffEditor.experimental.enabled": true,
"mypy-type-checker.args": ["--config-file=${workspaceFolder}/pyproject.toml"],
"mypy-type-checker.importStrategy": "fromEnvironment",
"notebook.codeActionsOnSave": {
"notebook.source.organizeImports": "explicit"
},
Expand All @@ -54,6 +56,11 @@
"ruff.enable": true,
"ruff.importStrategy": "fromEnvironment",
"ruff.organizeImports": true,
"search.exclude": {
"**/tests/**/__init__.py": true,
".constraints/*.txt": true,
"typings/**": true
},
"telemetry.telemetryLevel": "off",
"yaml.schemas": {
"https://raw.githubusercontent.com/ComPWA/qrules/0.10.x/src/qrules/particle-validation.json": [
Expand Down
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
"show-inheritance": True,
}
autodoc_member_order = "bysource"
autodoc_type_aliases = {}
autodoc_typehints_format = "short"
autosectionlabel_prefix_document = True
autosectionlabel_maxdepth = 2
Expand Down
16 changes: 16 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ numba = [
]
sty = [
"ampform-dpd[types]",
"mypy",
"pre-commit >=1.4.0",
"ruff",
]
Expand Down Expand Up @@ -134,6 +135,21 @@ where = ["src"]
[tool.setuptools_scm]
write_to = "src/ampform_dpd/version.py"

[tool.mypy]
exclude = "_build"
show_error_codes = true
warn_unused_configs = true

[[tool.mypy.overrides]]
ignore_missing_imports = true
module = ["cloudpickle.*"]

[[tool.mypy.overrides]]
check_untyped_defs = true
disallow_incomplete_defs = false
disallow_untyped_defs = false
module = ["tests.*"]

[tool.pyright]
reportArgumentType = false
reportAssignmentType = false
Expand Down
58 changes: 30 additions & 28 deletions src/ampform_dpd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ThreeBodyDecay,
ThreeBodyDecayChain,
get_decay_product_ids,
get_particle,
)
from ampform_dpd.io import (
simplify_latex_rendering, # noqa: F401 # pyright:ignore[reportUnusedImport]
Expand All @@ -35,9 +36,9 @@ class AmplitudeModel:
intensity: sp.Expr = sp.S.One
amplitudes: dict[sp.Indexed, sp.Expr] = field(factory=dict)
variables: dict[sp.Symbol, sp.Expr] = field(factory=dict)
parameter_defaults: dict[sp.Symbol, float] = field(factory=dict)
masses: dict[sp.Symbol, float] = field(factory=dict)
invariants: dict[sp.Symbol, float] = field(factory=dict)
parameter_defaults: dict[sp.Symbol, float | complex] = field(factory=dict)
masses: dict[sp.Symbol, float | complex] = field(factory=dict)
invariants: dict[sp.Symbol, sp.Expr] = field(factory=dict)

@property
def full_expression(self) -> sp.Expr:
Expand Down Expand Up @@ -79,17 +80,20 @@ def formulate(
reference_subsystem: Literal[1, 2, 3] = 1,
cleanup_summations: bool = False,
) -> AmplitudeModel:
helicity_symbols = sp.symbols("lambda:4", rational=True)
helicity_symbols: tuple[sp.Symbol, sp.Symbol, sp.Symbol, sp.Symbol] = (
sp.symbols("lambda:4", rational=True)
)
allowed_helicities = {
symbol: create_spin_range(self.decay.states[i].spin)
symbol: create_spin_range(self.decay.states[i].spin) # type:ignore[index]
for i, symbol in enumerate(helicity_symbols)
}
amplitude_definitions = {}
angle_definitions = {}
parameter_defaults = {}
for args in product(*allowed_helicities.values()):
for sub_system in [1, 2, 3]:
chain_model = self.formulate_subsystem_amplitude(*args, sub_system)
args: tuple[sp.Rational, sp.Rational, sp.Rational, sp.Rational]
for args in product(*allowed_helicities.values()): # type:ignore[assignment]
for sub_system in (1, 2, 3):
chain_model = self.formulate_subsystem_amplitude(*args, sub_system) # type:ignore[arg-type]
amplitude_definitions.update(chain_model.amplitudes)
angle_definitions.update(chain_model.variables)
parameter_defaults.update(chain_model.parameter_defaults)
Expand All @@ -100,13 +104,13 @@ def formulate(
masses = create_mass_symbol_mapping(self.decay)
parameter_defaults.update(masses)
if cleanup_summations:
aligned_amp = aligned_amp.cleanup()
aligned_amp = aligned_amp.cleanup() # type:ignore[assignment]
intensity = PoolSum(
sp.Abs(aligned_amp) ** 2,
*allowed_helicities.items(),
)
if cleanup_summations:
intensity = intensity.cleanup()
intensity = intensity.cleanup() # type:ignore[assignment]
return AmplitudeModel(
decay=self.decay,
intensity=PoolSum(
Expand Down Expand Up @@ -165,8 +169,8 @@ def formulate_subsystem_amplitude( # noqa: PLR0914
interaction=chain.outgoing_ls,
typ="decay",
)
parameter_defaults[h_prod] = 1 + 0j
parameter_defaults[h_dec] = 1
parameter_defaults[h_prod] = 1 + 0j # type:ignore[index]
parameter_defaults[h_dec] = 1 # type:ignore[index]
sub_amp_expr = (
sp.KroneckerDelta(λ[0], λR - λ[k])
* (-1) ** (spin[k] - λ[k])
Expand All @@ -189,9 +193,8 @@ def formulate_subsystem_amplitude( # noqa: PLR0914
* (-1) ** (spin[j] - λ[j])
)
if not self.use_decay_helicity_couplings:
resonance_isobar = chain.decay.child1
sub_amp_expr *= _formulate_clebsch_gordan_factors(
resonance_isobar,
chain.decay_node,
helicities={
self.decay.final_state[i]: λ[i],
self.decay.final_state[j]: λ[j],
Expand Down Expand Up @@ -255,20 +258,23 @@ def formulate_aligned_amplitude(
(_λ2, create_spin_range(j2)),
(_λ3, create_spin_range(j3)),
)
return amp_expr, wigner_generator.angle_definitions
return amp_expr, wigner_generator.angle_definitions # type:ignore[return-value]


def _create_coupling_symbol(
helicity_coupling: bool,
resonance: Str,
helicities: tuple[sp.Basic, sp.Basic],
interaction: LSCoupling,
interaction: LSCoupling | None,
typ: Literal["production", "decay"],
) -> sp.Indexed:
H = _get_coupling_base(helicity_coupling, typ)
if helicity_coupling:
λi, λj = helicities
return H[resonance, λi, λj]
if interaction is None:
msg = "Cannot formulate LS-coupling without LS combinations"
raise ValueError(msg)
return H[resonance, interaction.L, interaction.S]


Expand Down Expand Up @@ -314,15 +320,9 @@ def _formulate_clebsch_gordan_factors(
return sqrt_factor * cg_ll * cg_ss


def get_particle(isobar: IsobarNode | Particle) -> Particle:
if isinstance(isobar, IsobarNode):
return isobar.parent
return isobar


@lru_cache(maxsize=None)
def _generate_amplitude_index_bases() -> dict[Literal[1, 2, 3], sp.IndexedBase]:
return dict(enumerate(sp.symbols(R"A^(1:4)", cls=sp.IndexedBase), 1))
return dict(enumerate(sp.symbols(R"A^(1:4)", cls=sp.IndexedBase), 1)) # type:ignore[arg-type]


class _AlignmentWignerGenerator:
Expand All @@ -333,8 +333,8 @@ def __init__(self, reference_subsystem: Literal[1, 2, 3] = 1) -> None:
def __call__(
self,
j: sp.Rational,
m: sp.Rational,
m_prime: sp.Rational,
m: sp.Rational | sp.Symbol,
m_prime: sp.Rational | sp.Symbol,
rotated_state: int,
aligned_subsystem: int,
) -> sp.Rational | WignerD:
Expand Down Expand Up @@ -380,16 +380,18 @@ def decay(self) -> ThreeBodyDecay:
class DynamicsBuilder(Protocol):
def __call__(
self, decay_chain: ThreeBodyDecayChain
) -> tuple[sp.Expr, dict[sp.Symbol, float]]: ...
) -> tuple[sp.Expr, dict[sp.Symbol, float | complex]]: ...


def formulate_non_resonant(
decay_chain: ThreeBodyDecayChain,
) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
) -> tuple[sp.Expr, dict[sp.Symbol, float | complex]]:
return sp.Rational(1), {}


def create_mass_symbol_mapping(decay: ThreeBodyDecay) -> dict[sp.Symbol, float]:
def create_mass_symbol_mapping(
decay: ThreeBodyDecay,
) -> dict[sp.Symbol, float | complex]:
return {
sp.Symbol(f"m{i}", nonnegative=True): decay.states[i].mass
for i in sorted(decay.states) # ensure that dict keys are sorted by state ID
Expand Down
10 changes: 7 additions & 3 deletions src/ampform_dpd/_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from __future__ import annotations

from typing import TYPE_CHECKING, SupportsFloat
from typing import TYPE_CHECKING, Iterable, SupportsFloat

import sympy as sp

if TYPE_CHECKING:
from attrs import Attribute

from ampform_dpd.decay import LSCoupling
from ampform_dpd.decay import LSCoupling, ThreeBodyDecayChain


def assert_spin_value(instance, attribute: Attribute, value: sp.Rational) -> None:
Expand All @@ -18,7 +18,7 @@ def assert_spin_value(instance, attribute: Attribute, value: sp.Rational) -> Non
raise ValueError(msg)


def to_ls(obj: LSCoupling | tuple[int, SupportsFloat] | None) -> LSCoupling:
def to_ls(obj: LSCoupling | tuple[int, SupportsFloat] | None) -> LSCoupling | None:
from ampform_dpd.decay import LSCoupling # noqa: PLC0415

if obj is None:
Expand All @@ -32,5 +32,9 @@ def to_ls(obj: LSCoupling | tuple[int, SupportsFloat] | None) -> LSCoupling:
raise TypeError(msg)


def to_chains(obj: Iterable[ThreeBodyDecayChain]) -> tuple[ThreeBodyDecayChain, ...]:
return tuple(obj)


def to_rational(obj: SupportsFloat) -> sp.Rational:
return sp.Rational(obj)
14 changes: 7 additions & 7 deletions src/ampform_dpd/adapter/qrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def to_three_body_decay(
for i, idx in enumerate(sorted(some_transition.final_states), 1)
}
return ThreeBodyDecay(
states={0: initial_state, **final_states},
states={0: initial_state, **final_states}, # type:ignore[dict-item]
chains=tuple(sorted(to_decay_chain(t) for t in transitions)),
)

Expand Down Expand Up @@ -97,9 +97,9 @@ def _convert_edge(state: Any) -> Particle:
raise NotImplementedError(msg)
return Particle(
name=particle.name,
latex=particle.latex,
latex=particle.name if particle.latex is None else particle.latex,
spin=particle.spin,
parity=particle.parity,
parity=int(particle.parity), # type:ignore[arg-type]
mass=particle.mass,
width=particle.width,
)
Expand Down Expand Up @@ -131,11 +131,11 @@ def filter_min_ls(
min_transitions = []
for group in grouped_transitions.values():
transition, *_ = group
min_transition = FrozenTransition(
min_transition: FrozenTransition[EdgeType, NodeType] = FrozenTransition(
topology=transition.topology,
states=transition.states,
interactions={
i: min(t.interactions[i] for t in group)
i: min(t.interactions[i] for t in group) # type:ignore[type-var]
for i in transition.interactions
},
)
Expand All @@ -146,6 +146,6 @@ def filter_min_ls(
def load_particles() -> qrules.particle.ParticleCollection:
src_dir = Path(__file__).parent.parent
particle_database = qrules.load_default_particles()
additional_definitions = qrules.io.load(src_dir / "particle-definitions.yml")
particle_database.update(additional_definitions)
additional_definitions = qrules.io.load(src_dir / "particle-definitions.yml") # type:ignore[arg-type]
particle_database.update(additional_definitions) # type:ignore[arg-type]
return particle_database
Loading

0 comments on commit b5d3195

Please sign in to comment.