Skip to content

Commit

Permalink
Merge 59d9435 into 6443e24
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Mar 23, 2024
2 parents 6443e24 + 59d9435 commit ae9548a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 15 deletions.
34 changes: 19 additions & 15 deletions pysr/julia_extensions.py
@@ -1,10 +1,16 @@
"""This file installs and loads extensions for SymbolicRegression."""

from .julia_import import jl
from typing import Optional

from .julia_import import Pkg, jl


def load_required_packages(
*, turbo=False, bumper=False, enable_autodiff=False, cluster_manager=None
*,
turbo: bool = False,
bumper: bool = False,
enable_autodiff: bool = False,
cluster_manager: Optional[str] = None,
):
if turbo:
load_package("LoopVectorization", "bdcacae8-1622-11e9-2a5c-532679323890")
Expand All @@ -16,17 +22,15 @@ def load_required_packages(
load_package("ClusterManagers", "34f1f09b-3a8b-5176-ab39-66d58a4d544e")


def load_package(package_name, uuid):
jl.seval(
f"""
try
using {package_name}
catch e
isa(e, ArgumentError) || throw(e)
using Pkg: Pkg
Pkg.add(name="{package_name}", uuid="{uuid}")
using {package_name}
end
"""
)
def isinstalled(uuid_s: str) -> bool:
return jl.haskey(Pkg.dependencies(), jl.Base.UUID(uuid_s))


def load_package(package_name: str, uuid_s: str) -> None:
if not isinstalled(uuid_s):
Pkg.add(name=package_name, uuid=uuid_s)

# TODO: Protect against loading the same symbol from two packages,
# maybe with a @gensym here.
jl.seval(f"using {package_name}")
return None
3 changes: 3 additions & 0 deletions pysr/julia_import.py
Expand Up @@ -63,3 +63,6 @@

jl.seval("using SymbolicRegression")
SymbolicRegression = jl.SymbolicRegression

jl.seval("using Pkg: Pkg")
Pkg = jl.Pkg

0 comments on commit ae9548a

Please sign in to comment.