Skip to content

Commit

Permalink
Merge 7185bbb into fe3d590
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Sep 8, 2022
2 parents fe3d590 + 7185bbb commit c7eaf13
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI_Windows.yml
Expand Up @@ -26,7 +26,7 @@ jobs:
shell: bash
strategy:
matrix:
julia-version: ['1.7.1']
julia-version: ['1.6', '1.7.1']
python-version: ['3.9']
os: [windows-2019]

Expand Down
90 changes: 68 additions & 22 deletions pysr/julia_helpers.py
@@ -1,42 +1,84 @@
"""Functions for initializing the Julia environment and installing deps."""
import sys
import subprocess
import warnings
from pathlib import Path
import os

from .version import __version__, __symbolic_regression_jl_version__

juliainfo = None
julia_initialized = False


def _load_juliainfo():
"""Execute julia.core.JuliaInfo.load(), and store as juliainfo."""
global juliainfo

if juliainfo is None:
from julia.core import JuliaInfo

try:
juliainfo = JuliaInfo.load(julia="julia")
except FileNotFoundError:
env_path = os.environ["PATH"]
raise FileNotFoundError(
f"Julia is not installed in your PATH. Please install Julia and add it to your PATH.\n\nCurrent PATH: {env_path}",
)

return juliainfo


def _get_julia_env_dir():
# Have to manually get env dir:
try:
julia_env_dir_str = subprocess.run(
["julia", "-e using Pkg; print(Pkg.envdir())"], capture_output=True
).stdout.decode()
except FileNotFoundError:
env_path = os.environ["PATH"]
raise FileNotFoundError(
f"Julia is not installed in your PATH. Please install Julia and add it to your PATH.\n\nCurrent PATH: {env_path}",
)
return Path(julia_env_dir_str)


def _set_julia_project_env(julia_project, is_shared):
if is_shared:
if is_julia_version_greater_eq(version=(1, 7, 0)):
os.environ["JULIA_PROJECT"] = "@" + str(julia_project)
else:
julia_env_dir = _get_julia_env_dir()
os.environ["JULIA_PROJECT"] = str(julia_env_dir / julia_project)
else:
os.environ["JULIA_PROJECT"] = str(julia_project)


def install(julia_project=None, quiet=False): # pragma: no cover
"""
Install PyCall.jl and all required dependencies for SymbolicRegression.jl.
Also updates the local Julia registry.
"""
# Set JULIA_PROJECT so that we install in the pysr environment
julia_project, is_shared = _get_julia_project(julia_project)
if is_shared:
os.environ["JULIA_PROJECT"] = "@" + str(julia_project)
else:
os.environ["JULIA_PROJECT"] = str(julia_project)

import julia

# Set JULIA_PROJECT so that we install in the pysr environment
julia_project, is_shared = _process_julia_project(julia_project)
_set_julia_project_env(julia_project, is_shared)

julia.install(quiet=quiet)

if is_shared:
# is_shared is only true if the julia_project arg was None
# See _get_julia_project
# See _process_julia_project
Main = init_julia(None)
else:
Main = init_julia(julia_project)

Main.eval("using Pkg")

io = "devnull" if quiet else "stderr"
io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
io_arg = f"io={io}" if is_julia_version_greater_eq(version=(1, 6, 0)) else ""

# Can't pass IO to Julia call as it evaluates to PyObject, so just directly
# use Main.eval:
Expand All @@ -56,7 +98,7 @@ def install(julia_project=None, quiet=False): # pragma: no cover
)


def import_error_string(julia_project=None):
def _import_error_string(julia_project=None):
s = """
Required dependencies are not installed or built. Run the following code in the Python REPL:
Expand All @@ -71,7 +113,7 @@ def import_error_string(julia_project=None):
return s


def _get_julia_project(julia_project):
def _process_julia_project(julia_project):
if julia_project is None:
is_shared = True
julia_project = f"pysr-{__version__}"
Expand All @@ -81,12 +123,19 @@ def _get_julia_project(julia_project):
return julia_project, is_shared


def is_julia_version_greater_eq(Main, version="1.6"):
def is_julia_version_greater_eq(juliainfo=None, version=(1, 6, 0)):
"""Check if Julia version is greater than specified version."""
return Main.eval(f'VERSION >= v"{version}"')
if juliainfo is None:
juliainfo = _load_juliainfo()
current_version = (
juliainfo.version_major,
juliainfo.version_minor,
juliainfo.version_patch,
)
return current_version >= version


def check_for_conflicting_libraries(): # pragma: no cover
def _check_for_conflicting_libraries(): # pragma: no cover
"""Check whether there are conflicting modules, and display warnings."""
# See https://github.com/pytorch/pytorch/issues/78829: importing
# pytorch before running `pysr.fit` causes a segfault.
Expand All @@ -106,15 +155,12 @@ def init_julia(julia_project=None):
global julia_initialized

if not julia_initialized:
check_for_conflicting_libraries()
_check_for_conflicting_libraries()

from julia.core import JuliaInfo, UnsupportedPythonError

julia_project, is_shared = _get_julia_project(julia_project)
if is_shared:
os.environ["JULIA_PROJECT"] = "@" + str(julia_project)
else:
os.environ["JULIA_PROJECT"] = str(julia_project)
julia_project, is_shared = _process_julia_project(julia_project)
_set_julia_project_env(julia_project, is_shared)

try:
info = JuliaInfo.load(julia="julia")
Expand All @@ -125,7 +171,7 @@ def init_julia(julia_project=None):
)

if not info.is_pycall_built():
raise ImportError(import_error_string())
raise ImportError(_import_error_string())

Main = None
try:
Expand Down Expand Up @@ -160,7 +206,7 @@ def _add_sr_to_julia_project(Main, io_arg):


def _escape_filename(filename):
"""Turns a file into a string representation with correctly escaped backslashes"""
"""Turn a path into a string with correctly escaped backslashes."""
str_repr = str(filename)
str_repr = str_repr.replace("\\", "\\\\")
return str_repr
12 changes: 7 additions & 5 deletions pysr/sr.py
Expand Up @@ -23,11 +23,11 @@

from .julia_helpers import (
init_julia,
_get_julia_project,
_process_julia_project,
is_julia_version_greater_eq,
_escape_filename,
_add_sr_to_julia_project,
import_error_string,
_import_error_string,
)
from .export_numpy import CallableEquation
from .export_latex import generate_single_table, generate_multiple_tables, to_latex
Expand Down Expand Up @@ -1437,10 +1437,12 @@ def _run(self, X, y, mutated_params, weights, seed):
cluster_manager = Main.eval(f"addprocs_{cluster_manager}")

if not already_ran:
julia_project, is_shared = _get_julia_project(self.julia_project)
julia_project, is_shared = _process_julia_project(self.julia_project)
Main.eval("using Pkg")
io = "devnull" if update_verbosity == 0 else "stderr"
io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
io_arg = (
f"io={io}" if is_julia_version_greater_eq(version=(1, 6, 0)) else ""
)

Main.eval(
f'Pkg.activate("{_escape_filename(julia_project)}", shared = Bool({int(is_shared)}), {io_arg})'
Expand All @@ -1453,7 +1455,7 @@ def _run(self, X, y, mutated_params, weights, seed):
_add_sr_to_julia_project(Main, io_arg)
Main.eval(f"Pkg.resolve({io_arg})")
except (JuliaError, RuntimeError) as e:
raise ImportError(import_error_string(julia_project)) from e
raise ImportError(_import_error_string(julia_project)) from e
Main.eval("using SymbolicRegression")

Main.plus = Main.eval("(+)")
Expand Down

0 comments on commit c7eaf13

Please sign in to comment.