
Commit

Merge 34f4e3f into 8da5000
MilesCranmer committed Aug 5, 2022
2 parents 8da5000 + 34f4e3f commit 8f488fa
Showing 4 changed files with 273 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +3,7 @@
*.csv
*.csv.out*
*.bkup
*.pkl
performance*txt
*.out
trials*
194 changes: 182 additions & 12 deletions pysr/sr.py
@@ -1,3 +1,4 @@
import copy
import os
import sys
import numpy as np
@@ -8,6 +9,7 @@
import tempfile
import shutil
from pathlib import Path
import pickle as pkl
from datetime import datetime
import warnings
from multiprocessing import cpu_count
@@ -562,6 +564,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
equation_file_contents_ : list[pandas.DataFrame]
Contents of the equation file output by the Julia backend.
show_pickle_warnings_ : bool
Whether to show warnings about which attributes cannot be pickled.

Notes
-----
Most default parameters have been tuned over several example equations,
@@ -805,6 +810,119 @@ def __init__(
f"{k} is not a valid keyword argument for PySRRegressor."
)

@classmethod
def from_file(
cls,
equation_file,
*,
binary_operators=None,
unary_operators=None,
n_features_in=None,
feature_names_in=None,
selection_mask=None,
nout=1,
**pysr_kwargs,
):
"""
Create a model from a saved model checkpoint or equation file.

Parameters
----------
equation_file : str
Path to a pickle file containing a saved model, or a csv file
containing equations.
binary_operators : list[str]
The same binary operators used when creating the model.
Not needed if loading from a pickle file.
unary_operators : list[str]
The same unary operators used when creating the model.
Not needed if loading from a pickle file.
n_features_in : int
Number of features passed to the model.
Not needed if loading from a pickle file.
feature_names_in : list[str]
Names of the features passed to the model.
Not needed if loading from a pickle file.
selection_mask : list[bool]
If using select_k_features, you must pass `model.selection_mask_` here.
Not needed if loading from a pickle file.
nout : int, default=1
Number of outputs of the model.
Not needed if loading from a pickle file.
pysr_kwargs : dict
Any other keyword arguments to initialize the PySRRegressor object.
These will overwrite those stored in the pickle file.

Returns
-------
model : PySRRegressor
The model with fitted equations.
"""
if os.path.splitext(equation_file)[1] != ".pkl":
pkl_filename = _csv_filename_to_pkl_filename(equation_file)
else:
pkl_filename = equation_file

# Try to load model from <equation_file>.pkl
print(f"Checking if {pkl_filename} exists...")
if os.path.exists(pkl_filename):
print(f"Loading model from {pkl_filename}")
assert binary_operators is None
assert unary_operators is None
assert n_features_in is None
with open(pkl_filename, "rb") as f:
model = pkl.load(f)
# Update any parameters if necessary, such as
# extra_sympy_mappings:
model.set_params(**pysr_kwargs)
if "equations_" not in model.__dict__ or model.equations_ is None:
model.refresh()

return model

# Else, we re-create it.
print(
f"{equation_file} does not exist, "
"so we must create the model from scratch."
)
assert binary_operators is not None
assert unary_operators is not None
assert n_features_in is not None

# TODO: copy .bkup file if exists.
model = cls(
equation_file=equation_file,
binary_operators=binary_operators,
unary_operators=unary_operators,
**pysr_kwargs,
)

model.nout_ = nout
model.n_features_in_ = n_features_in

if feature_names_in is None:
model.feature_names_in_ = [f"x{i}" for i in range(n_features_in)]
else:
assert len(feature_names_in) == n_features_in
model.feature_names_in_ = feature_names_in

if selection_mask is None:
model.selection_mask_ = np.ones(n_features_in, dtype=bool)
else:
model.selection_mask_ = selection_mask

model.refresh(checkpoint_file=equation_file)

return model

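Usage sketch for the new classmethod (file names here are hypothetical stand-ins for whatever `equation_file` an earlier fit produced):

    from pysr import PySRRegressor

    # Load directly from a pickled checkpoint written by a previous fit:
    model = PySRRegressor.from_file("hall_of_fame.pkl")

    # Loading from a bare equations csv (no .pkl checkpoint present) requires
    # re-stating the original operators and feature count:
    model = PySRRegressor.from_file(
        "hall_of_fame.csv",
        binary_operators=["+", "*"],
        unary_operators=["cos"],
        n_features_in=2,
    )
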
def __repr__(self):
"""
Prints all current equations fitted by the model.
@@ -873,17 +991,31 @@ def __getstate__(self):
from the pickled instance.
"""
state = self.__dict__
if "raw_julia_state_" in state:
show_pickle_warning = not (
"show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
)
if "raw_julia_state_" in state and show_pickle_warning:
warnings.warn(
"raw_julia_state_ cannot be pickled and will be removed from the "
"serialized instance. This will prevent a `warm_start` fit of any "
"model that is deserialized via `pickle.load()`."
)
state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
for state_key in state_keys_containing_lambdas:
if state[state_key] is not None and show_pickle_warning:
warnings.warn(
f"`{state_key}` cannot be pickled and will be removed from the "
"serialized instance. When loading the model, please redefine "
f"`{state_key}` at runtime."
)
state_keys_to_clear = ["raw_julia_state_"] + state_keys_containing_lambdas
pickled_state = {
key: None if key == "raw_julia_state_" else value
key: (None if key in state_keys_to_clear else value)
for key, value in state.items()
}
if "equations_" in pickled_state:
if ("equations_" in pickled_state) and (
pickled_state["equations_"] is not None
):
pickled_state["output_torch_format"] = False
pickled_state["output_jax_format"] = False
if self.nout_ == 1:
@@ -906,6 +1038,16 @@ def __getstate__(self):
]
return pickled_state

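A hedged illustration of the resulting pickling behavior (`model` is assumed to be a fitted PySRRegressor; the `inv` mapping is hypothetical):

    import pickle

    # Lambda-valued settings such as extra_sympy_mappings are dropped by
    # __getstate__ (with a warning) instead of breaking the pickle call:
    with open("model.pkl", "wb") as f:
        pickle.dump(model, f)

    with open("model.pkl", "rb") as f:
        model2 = pickle.load(f)

    # After loading, redefine any lambda-valued settings at runtime:
    model2.set_params(extra_sympy_mappings={"inv": lambda x: 1.0 / x})
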
def _checkpoint(self):
"""Saves the model's current state to a checkpoint file.
This should only be used internally by PySRRegressor."""
# Save model state:
self.show_pickle_warnings_ = False
with open(_csv_filename_to_pkl_filename(self.equation_file_), "wb") as f:
pkl.dump(self, f)
self.show_pickle_warnings_ = True

@property
def equations(self): # pragma: no cover
warnings.warn(
@@ -1606,8 +1748,20 @@ def fit(
y,
)

# Fitting procedure
return self._run(X, y, mutated_params, weights=weights, seed=seed)
# Initially, just save model parameters, so that
# it can be loaded from an early exit:
if not self.temp_equation_file:
self._checkpoint()

# Perform the search:
self._run(X, y, mutated_params, weights=weights, seed=seed)

# Then, after fit, we save again, so the pickle file contains
# the equations:
if not self.temp_equation_file:
self._checkpoint()

return self

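A sketch of the workflow this enables (data and file name are hypothetical):

    import numpy as np
    from pysr import PySRRegressor

    X = np.random.randn(100, 2)
    y = X[:, 0] ** 2 + np.cos(X[:, 1])

    model = PySRRegressor(
        binary_operators=["+", "*"],
        unary_operators=["cos"],
        equation_file="hall_of_fame.csv",
    )
    model.fit(X, y)

    # fit() now checkpoints the estimator to hall_of_fame.pkl both before and
    # after the search, so a later session can reload it directly:
    model2 = PySRRegressor.from_file("hall_of_fame.csv")
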
def refresh(self, checkpoint_file=None):
"""
@@ -1619,10 +1773,10 @@ def refresh(self, checkpoint_file=None):
checkpoint_file : str, default=None
Path to checkpoint hall of fame file to be loaded.
"""
check_is_fitted(self, attributes=["equation_file_"])
if checkpoint_file:
self.equation_file_ = checkpoint_file
self.equation_file_contents_ = None
check_is_fitted(self, attributes=["equation_file_"])
self.equations_ = self.get_hof()

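A short usage sketch (path hypothetical): passing `checkpoint_file` points the model at a different hall-of-fame csv before rebuilding `equations_`:

    # Re-read equations from another hall-of-fame file written by the Julia
    # backend, without re-running the search:
    model.refresh(checkpoint_file="other_run_hall_of_fame.csv")
    print(model.equations_)
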
def predict(self, X, index=None):
@@ -1812,10 +1966,10 @@ def _read_equation_file(self):
if self.nout_ > 1:
all_outputs = []
for i in range(1, self.nout_ + 1):
df = pd.read_csv(
str(self.equation_file_) + f".out{i}" + ".bkup",
sep="|",
)
cur_filename = str(self.equation_file_) + f".out{i}" + ".bkup"
if not os.path.exists(cur_filename):
cur_filename = str(self.equation_file_) + f".out{i}"
df = pd.read_csv(cur_filename, sep="|")
# Rename Complexity column to complexity:
df.rename(
columns={
@@ -1828,7 +1982,10 @@

all_outputs.append(df)
else:
all_outputs = [pd.read_csv(str(self.equation_file_) + ".bkup", sep="|")]
filename = str(self.equation_file_) + ".bkup"
if not os.path.exists(filename):
filename = str(self.equation_file_)
all_outputs = [pd.read_csv(filename, sep="|")]
all_outputs[-1].rename(
columns={
"Complexity": "complexity",
@@ -1886,7 +2043,9 @@ def get_hof(self):

ret_outputs = []

for output in self.equation_file_contents_:
equation_file_contents = copy.deepcopy(self.equation_file_contents_)

for output in equation_file_contents:

scores = []
lastMSE = None
@@ -2035,3 +2194,14 @@ def run_feature_selection(X, y, select_k_features, random_state=None):
clf, threshold=-np.inf, max_features=select_k_features, prefit=True
)
return selector.get_support(indices=True)


def _csv_filename_to_pkl_filename(csv_filename) -> str:
# Assume that the csv filename is of the form "foo.csv"
dirname = str(os.path.dirname(csv_filename))
basename = str(os.path.basename(csv_filename))
base = str(os.path.splitext(basename)[0])

pkl_basename = base + ".pkl"

return os.path.join(dirname, pkl_basename)
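
For reference, hypothetical inputs and the paths this helper would produce:

    _csv_filename_to_pkl_filename("outputs/hall_of_fame.csv")
    # -> "outputs/hall_of_fame.pkl"
    _csv_filename_to_pkl_filename("hall_of_fame.csv.bkup")
    # -> "hall_of_fame.csv.pkl"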
