Skip to content

Commit

Permalink
Merge a5eaab9 into df788e5
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Apr 28, 2024
2 parents df788e5 + a5eaab9 commit 3f38cb5
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 51 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ profile = "black"
dev-dependencies = [
"pre-commit>=3.7.0",
"ipython>=8.23.0",
"mypy>=1.10.0",
]
12 changes: 7 additions & 5 deletions pysr/export_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import sympy
from sympy import sympify

from .utils import ArrayLike

sympy_mappings = {
"div": lambda x, y: x / y,
"mult": lambda x, y: x * y,
Expand All @@ -30,8 +32,8 @@
"acosh": lambda x: sympy.acosh(x),
"acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
"asinh": sympy.asinh,
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - sympy.S(1)),
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - sympy.S(1)),
"abs": abs,
"mod": sympy.Mod,
"erf": sympy.erf,
Expand Down Expand Up @@ -60,21 +62,21 @@


def create_sympy_symbols_map(
feature_names_in: List[str],
feature_names_in: ArrayLike[str],
) -> Dict[str, sympy.Symbol]:
return {variable: sympy.Symbol(variable) for variable in feature_names_in}


def create_sympy_symbols(
feature_names_in: List[str],
feature_names_in: ArrayLike[str],
) -> List[sympy.Symbol]:
return [sympy.Symbol(variable) for variable in feature_names_in]


def pysr2sympy(
equation: str,
*,
feature_names_in: Optional[List[str]] = None,
feature_names_in: Optional[ArrayLike[str]] = None,
extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
):
if feature_names_in is None:
Expand Down
4 changes: 4 additions & 0 deletions pysr/julia_import.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sys
import warnings
from typing import Any

# Check if JuliaCall is already loaded, and if so, warn the user
# about the relevant environment variables. If not loaded,
Expand Down Expand Up @@ -37,6 +38,9 @@

from juliacall import Main as jl # type: ignore

jl: Any = jl # type: ignore


jl_version = (jl.VERSION.major, jl.VERSION.minor, jl.VERSION.patch)

# Next, automatically load the juliacall extension if we're in a Jupyter notebook
Expand Down

0 comments on commit 3f38cb5

Please sign in to comment.