Skip to content

Commit

Permalink
Display output variable in table of expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jul 18, 2022
1 parent fab6f87 commit 3ef2b32
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 21 deletions.
14 changes: 12 additions & 2 deletions pysr/export_latex.py
Expand Up @@ -65,6 +65,7 @@ def generate_single_table(
precision: int = 3,
columns=["equation", "complexity", "loss", "score"],
max_equation_length: int = 50,
output_variable_name: str = "y",
):
"""Generate a booktabs-style LaTeX table for a single set of equations."""
assert isinstance(equations, pd.DataFrame)
Expand Down Expand Up @@ -96,7 +97,9 @@ def generate_single_table(
for col in columns:
if col == "equation":
if len(latex_equation) < max_equation_length:
row_pieces.append("$" + latex_equation + "$")
row_pieces.append(
"$" + output_variable_name + " = " + latex_equation + "$"
)
else:
if not raised_long_equation_warning:
warnings.warn(
Expand All @@ -109,7 +112,7 @@ def generate_single_table(
r"\begin{minipage}{0.8\linewidth}",
r"\vspace{-1em}",
r"\begin{dmath*}",
latex_equation,
output_variable_name + " = " + latex_equation,
r"\end{dmath*}",
r"\end{minipage}",
]
Expand Down Expand Up @@ -137,15 +140,22 @@ def generate_multiple_tables(
indices: List[List[int]] = None,
precision: int = 3,
columns=["equation", "complexity", "loss", "score"],
output_variable_names: str = None,
):
"""Generate multiple latex tables for a list of equation sets."""
# TODO: Let user specify custom output variable

latex_tables = [
generate_single_table(
equations[i],
(None if not indices else indices[i]),
precision=precision,
columns=columns,
output_variable_name=(
"y_{" + str(i) + "}"
if output_variable_names is None
else output_variable_names[i]
),
)
for i in range(len(equations))
]
Expand Down
38 changes: 19 additions & 19 deletions test/test.py
Expand Up @@ -553,9 +553,9 @@ def test_simple_table(self):
columns=["equation", "complexity", "loss"]
)
middle_part = r"""
$x_{0}$ & $1$ & $1.05$ \\
$\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ \\
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
$y = x_{0}$ & $1$ & $1.05$ \\
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ \\
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
"""
true_latex_table_str = self.create_true_latex(middle_part)
self.assertEqual(latex_table_str, true_latex_table_str)
Expand All @@ -565,19 +565,19 @@ def test_other_precision(self):
precision=5, columns=["equation", "complexity", "loss"]
)
middle_part = r"""
$x_{0}$ & $1$ & $1.0520$ \\
$\cos{\left(x_{0} \right)}$ & $2$ & $0.023150$ \\
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.1235 \cdot 10^{-15}$ \\
$y = x_{0}$ & $1$ & $1.0520$ \\
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.023150$ \\
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.1235 \cdot 10^{-15}$ \\
"""
true_latex_table_str = self.create_true_latex(middle_part)
self.assertEqual(latex_table_str, true_latex_table_str)

def test_include_score(self):
latex_table_str = self.model.latex_table()
middle_part = r"""
$x_{0}$ & $1$ & $1.05$ & $0.0$ \\
$\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
$y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
"""
true_latex_table_str = self.create_true_latex(middle_part, include_score=True)
self.assertEqual(latex_table_str, true_latex_table_str)
Expand All @@ -587,7 +587,7 @@ def test_last_equation(self):
indices=[2], columns=["equation", "complexity", "loss"]
)
middle_part = r"""
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
"""
true_latex_table_str = self.create_true_latex(middle_part)
self.assertEqual(latex_table_str, true_latex_table_str)
Expand All @@ -610,14 +610,14 @@ def test_multi_output(self):
equations = [equations1, equations2]
model = manually_create_model(equations)
middle_part_1 = r"""
$x_{0}$ & $1$ & $1.05$ & $0.0$ \\
$\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
$y_{0} = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
$y_{0} = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
$y_{0} = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
"""
middle_part_2 = r"""
$x_{1}$ & $1$ & $1.32$ & $0.0$ \\
$\cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
$x_{0}^{2} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
$y_{1} = x_{1}$ & $1$ & $1.32$ & $0.0$ \\
$y_{1} = \cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
$y_{1} = x_{0}^{2} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
"""
true_latex_table_str = "\n\n".join(
self.create_true_latex(part, include_score=True)
Expand Down Expand Up @@ -667,9 +667,9 @@ def test_latex_break_long_equation(self):
model = manually_create_model(equations)
latex_table_str = model.latex_table()
middle_part = r"""
$x_{0}$ & $1$ & $1.05$ & $0.0$ \\
$\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
\begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} x_{0}^{5} + x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} - 5.20 \sin{\left(2.60 x_{0} - 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
$y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
\begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0}^{5} + x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} - 5.20 \sin{\left(2.60 x_{0} - 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
"""
true_latex_table_str = self.create_true_latex(middle_part, include_score=True)
self.assertEqual(latex_table_str, true_latex_table_str)

0 comments on commit 3ef2b32

Please sign in to comment.