Skip to content

Commit

Permalink
Clean up more unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jun 8, 2021
1 parent 925fb38 commit 5af6354
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 11 deletions.
8 changes: 3 additions & 5 deletions pysr/export_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,12 @@ def fn_(*args):

torch_initialized = False
torch = None
_global_func_lookup = None
_Node = None
SingleSymPyModule = None


def _initialize_torch():
global torch_initialized
global torch
global _global_func_lookup
global _Node
global SingleSymPyModule

# Way to lazy load torch, only if this is called,
Expand Down Expand Up @@ -148,7 +144,7 @@ def forward(self, memodict):
args.append(arg_)
return self._torch_func(*args)

class SingleSymPyModule(torch.nn.Module):
class _SingleSymPyModule(torch.nn.Module):
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch"""

def __init__(
Expand Down Expand Up @@ -177,6 +173,8 @@ def forward(self, X):
symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)}
return self._node(symbols)

SingleSymPyModule = _SingleSymPyModule


def sympy2torch(expression, symbols_in, selection=None, extra_torch_mappings=None):
"""Returns a module for a given sympy expression with trainable parameters;
Expand Down
3 changes: 1 addition & 2 deletions pysr/feynman_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def mk_problems(first=100, gen=False, dp=500, data_dir=FEYNMAN_DATASET):
ret.append(p)
except Exception as e:
traceback.print_exc()
print(f"FAILED ON ROW {i}")
print(f"FAILED ON ROW {i} with {e}")
ind += 1
return ret

Expand Down Expand Up @@ -168,7 +168,6 @@ def do_feynman_experiments(
problems = FeynmanProblem.mk_problems(
first=first, gen=True, dp=dp, data_dir=data_dir
)
indx = range(len(problems))
ids = []
predictions = []
true_equations = []
Expand Down
1 change: 0 additions & 1 deletion pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,6 @@ def pysr(
kwargs = {**dict(equation_file=equation_file), **kwargs}

pkg_directory = kwargs["pkg_directory"]
manifest_file = None
if kwargs["julia_project"] is not None:
manifest_filepath = Path(kwargs["julia_project"]) / "Manifest.toml"
else:
Expand Down
6 changes: 3 additions & 3 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def test_multioutput_weighted_with_callable(self):
w[w >= 0.5] = 1.0

# Double equation when weights are 0:
y += (1 - w) * y
y = (2 - w) * y
# Thus, pysr needs to use the weights to find the right equation!

equations = pysr(
pysr(
self.X,
y,
weights=w,
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_feature_selection_handler(self):
X,
select_k_features=2,
use_custom_variable_names=True,
variable_names=[f"x{i}" for i in range(5)],
variable_names=var_names,
y=y,
)
self.assertTrue((2 in selection) and (3 in selection))
Expand Down

0 comments on commit 5af6354

Please sign in to comment.