/
__init__.py
91 lines (67 loc) · 2.26 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Input-output functions for `ampform` and `sympy` objects.
.. tip:: This function are registered with :func:`functools.singledispatch` and can be
extended as follows:
>>> from ampform.io import aslatex
>>> @aslatex.register(int)
... def _(obj: int) -> str:
... return "my custom rendering"
>>> aslatex(1)
'my custom rendering'
>>> aslatex(3.4 - 2j)
'3.4-2i'
"""
from __future__ import annotations
from collections import abc
from functools import singledispatch
from typing import Iterable, Mapping
import sympy as sp
@singledispatch
def aslatex(obj) -> str:
    """Render objects as a LaTeX `str`.

    This is the fallback implementation used for any type that has no more
    specific renderer registered: the object is simply converted with
    :func:`str`. The resulting `str` can for instance be given to
    `IPython.display.Math`.

    .. versionadded:: 0.14.1
    """
    fallback = str(obj)
    return fallback
@aslatex.register(complex)
def _(obj: complex) -> str:
    """Render a complex number as ``<real><sign><imag>i``.

    Parts without a fractional component are shown without a trailing ``.0``,
    so for instance ``3.4 - 2j`` becomes ``3.4-2i``.
    """
    real_part, imag_part = __downcast(obj.real), __downcast(obj.imag)
    if imag_part >= 0:
        sign = "+"
    else:
        # A negative imaginary part already carries its own minus sign.
        sign = ""
    return f"{real_part}{sign}{imag_part}i"
def __downcast(obj: float) -> float | int:
    """Return ``obj`` as an `int` if it has no fractional part, otherwise unchanged.

    Used so that integral float components render without a trailing ``.0``.
    """
    return int(obj) if obj.is_integer() else obj
@aslatex.register(sp.Basic)
def _(obj: sp.Basic) -> str:
    """Render a SymPy object by delegating to :func:`sympy.latex`."""
    rendered = sp.latex(obj)
    return rendered
@aslatex.register(abc.Mapping)
def _(obj: Mapping) -> str:
    """Render a mapping as a LaTeX ``array`` with one equation per item.

    Each item becomes a row of the form ``lhs &=& rhs`` followed by a LaTeX
    line break, where both sides are rendered recursively with
    :func:`aslatex`. Rows are placed in an ``{rcl}`` array environment so the
    equal signs line up.

    Raises:
        ValueError: If the mapping has no items.
    """
    if len(obj) == 0:
        msg = "Need at least one dictionary item"
        raise ValueError(msg)
    # Collect the rows first and join once, instead of growing a string with
    # `+=` inside the loop.
    rows = [
        Rf" {aslatex(lhs)} &=& {aslatex(rhs)} \\"
        for lhs, rhs in obj.items()
    ]
    return "\n".join([R"\begin{array}{rcl}", *rows, R"\end{array}"])
@aslatex.register(abc.Iterable)
def _(obj: Iterable) -> str:
    """Render an iterable as a single-column LaTeX ``array``.

    Each item is rendered recursively with :func:`aslatex` and put on its own
    row of a ``{c}`` array environment.

    A `str` also implements `collections.abc.Iterable`, so `singledispatch`
    routes strings here too. A `str` is therefore rendered as-is, like the
    fallback implementation. Splitting it into characters would recurse
    forever, because every character is itself a `str`.

    Raises:
        ValueError: If the iterable yields no items.
    """
    if isinstance(obj, str):
        return str(obj)
    # Materialize first so that generators can be length-checked.
    items = list(obj)
    if len(items) == 0:
        msg = "Need at least one item to render as LaTeX"
        raise ValueError(msg)
    rows = [Rf" {item} \\" for item in map(aslatex, items)]
    return "\n".join([R"\begin{array}{c}", *rows, R"\end{array}"])
def improve_latex_rendering() -> None:
    """Improve LaTeX rendering of an `~sympy.tensor.indexed.Indexed` object.

    Calling this function patches `sympy.Indexed` globally by assigning it a
    ``_latex`` method. That method renders the object as its base followed by
    a single subscript containing all indices separated by ``", "``, i.e.
    ``<base>_{<index1>, <index2>, ...}``.

    .. versionadded:: 0.14.2
    """

    def _print_Indexed_latex(self, printer, *args):  # noqa: N802
        rendered_base = printer._print(self.base)
        rendered_indices = [printer._print(index) for index in self.indices]
        subscript = ", ".join(rendered_indices)
        return f"{rendered_base}_{{{subscript}}}"

    sp.Indexed._latex = _print_Indexed_latex  # type: ignore[attr-defined]