diff --git a/pysr/sr.py b/pysr/sr.py index 2ce25c10a..ce02e5d18 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -779,21 +779,21 @@ def get_params(self, deep=True): **{key: self.__getattribute__(key) for key in self.surface_parameters}, } - def get_best(self, row=None): + def get_best(self, index=None): """Get best equation using `model_selection`. - :param row: Optional. If you wish to select a particular equation + :param index: Optional. If you wish to select a particular equation from `self.equations`, give the row number here. This overrides the `model_selection` parameter. - :type row: int + :type index: int :returns: Dictionary representing the best expression found. :type: pd.Series """ if self.equations is None: raise ValueError("No equations have been generated yet.") - if row is not None: - return self.equations.iloc[row] + if index is not None: + return self.equations.iloc[index] if self.model_selection == "accuracy": if isinstance(self.equations, list): @@ -838,7 +838,7 @@ def refresh(self): # such as extra_sympy_mappings. self.equations = self.get_hof() - def predict(self, X, row=None): + def predict(self, X, index=None): """Predict y from input X using the equation chosen by `model_selection`. You may see what equation is used by printing this object. X should have the same @@ -846,60 +846,60 @@ def predict(self, X, row=None): :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces). :type X: np.ndarray/pandas.DataFrame - :param row: Optional. If you want to predict an expression using a particular row of - `self.equations`, you may specify the row here. - :type row: int + :param index: Optional. If you want to predict an expression using a particular row of + `self.equations`, you may specify the index here. + :type index: int :returns: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs). :type: np.ndarray """ self.refresh() - best = self.get_best(row=row) + best = self.get_best(index=index) if self.multioutput: return np.stack([eq["lambda_format"](X) for eq in best], axis=1) return best["lambda_format"](X) - def sympy(self, row=None): + def sympy(self, index=None): """Return sympy representation of the equation(s) chosen by `model_selection`. - :param row: Optional. If you wish to select a particular equation - from `self.equations`, give the row number here. This overrides + :param index: Optional. If you wish to select a particular equation + from `self.equations`, give the index number here. This overrides the `model_selection` parameter. - :type row: int + :type index: int :returns: SymPy representation of the best expression. """ self.refresh() - best = self.get_best(row=row) + best = self.get_best(index=index) if self.multioutput: return [eq["sympy_format"] for eq in best] return best["sympy_format"] - def latex(self, row=None): + def latex(self, index=None): """Return latex representation of the equation(s) chosen by `model_selection`. - :param row: Optional. If you wish to select a particular equation - from `self.equations`, give the row number here. This overrides + :param index: Optional. If you wish to select a particular equation + from `self.equations`, give the index number here. This overrides the `model_selection` parameter. - :type row: int + :type index: int :returns: LaTeX expression as a string :type: str """ self.refresh() - sympy_representation = self.sympy(row=row) + sympy_representation = self.sympy(index=index) if self.multioutput: return [sympy.latex(s) for s in sympy_representation] return sympy.latex(sympy_representation) - def jax(self, row=None): + def jax(self, index=None): """Return jax representation of the equation(s) chosen by `model_selection`. Each equation (multiple given if there are multiple outputs) is a dictionary containing {"callable": func, "parameters": params}. To call `func`, pass func(X, params). This function is differentiable using `jax.grad`. - :param row: Optional. If you wish to select a particular equation - from `self.equations`, give the row number here. This overrides + :param index: Optional. If you wish to select a particular equation + from `self.equations`, give the index number here. This overrides the `model_selection` parameter. - :type row: int + :type index: int :returns: Dictionary of callable jax function in "callable" key, and jax array of parameters as "parameters" key. :type: dict @@ -912,12 +912,12 @@ def jax(self, row=None): ) self.set_params(output_jax_format=True) self.refresh() - best = self.get_best(row=row) + best = self.get_best(index=index) if self.multioutput: return [eq["jax_format"] for eq in best] return best["jax_format"] - def pytorch(self, row=None): + def pytorch(self, index=None): """Return pytorch representation of the equation(s) chosen by `model_selection`. Each equation (multiple given if there are multiple outputs) is a PyTorch module @@ -926,10 +926,10 @@ def pytorch(self, row=None): column ordering as trained with. - :param row: Optional. If you wish to select a particular equation + :param index: Optional. If you wish to select a particular equation from `self.equations`, give the row number here. This overrides the `model_selection` parameter. - :type row: int + :type index: int :returns: PyTorch module representing the expression. :type: torch.nn.Module """ @@ -941,7 +941,7 @@ def pytorch(self, row=None): ) self.set_params(output_torch_format=True) self.refresh() - best = self.get_best(row=row) + best = self.get_best(index=index) if self.multioutput: return [eq["torch_format"] for eq in best] return best["torch_format"]