Skip to content

Commit

Permalink
Allow user to pass extra torch operators
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed May 31, 2021
1 parent 4697288 commit 9269ce9
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pysr/export_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def forward(self, X):
return self._node(symbols)


def sympy2torch(expression, symbols_in):
def sympy2torch(expression, symbols_in, extra_torch_mappings=None):
"""Returns a module for a given sympy expression with trainable parameters;
This function will assume the input to the module is a matrix X, where
Expand All @@ -170,4 +170,4 @@ def sympy2torch(expression, symbols_in):

_initialize_torch()

return SingleSymPyModule(expression, symbols_in)
return SingleSymPyModule(expression, symbols_in, extra_funcs=extra_torch_mappings)

0 comments on commit 9269ce9

Please sign in to comment.