Merge pull request #21 from ranbir7/ranbir7/convert-docstrings-to-google-18

Modified Base.py with improved Google-format docstrings
ConnorStoneAstro committed Feb 7, 2023
2 parents ed36960 + e3dc1e8 commit 40901c4
Showing 1 changed file with 59 additions and 1 deletion.
60 changes: 59 additions & 1 deletion autoprof/fit/base.py
@@ -19,6 +19,29 @@ class BaseOptimizer(object):
"""
def __init__(self, model, initial_state = None, relative_tolerance = 1e-3, **kwargs):
"""
Initializes a new instance of the class.
Args:
model (object): An object representing the model.
initial_state (Union[None, Tensor]): The initial state of the model. If `None`, the model's default
initial state will be used.
relative_tolerance (float): The relative tolerance for the optimization.
**kwargs (dict): Additional keyword arguments.
Attributes:
model (object): An object representing the model.
verbose (int): The verbosity level.
current_state (Tensor): The current state of the model.
max_iter (int): The maximum number of iterations.
iteration (int): The current iteration number.
save_steps (Union[None, int]): The frequency at which to save intermediate results.
relative_tolerance (float): The relative tolerance for the optimization.
lambda_history (List[ndarray]): A list of the optimization steps.
loss_history (List[float]): A list of the optimization losses.
message (str): An informational message.
"""

self.model = model
self.verbose = kwargs.get("verbose", 0)

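As a reading aid for the signature above, here is a minimal sketch of how a subclass of BaseOptimizer might be constructed. The subclass name and placeholder model are hypothetical, and reading max_iter from **kwargs is an assumption suggested by the Attributes list (the relevant lines are collapsed in this diff):

import torch

from autoprof.fit.base import BaseOptimizer

class GradientOptimizer(BaseOptimizer):  # hypothetical subclass, for illustration only
    def fit(self):
        return self  # a real subclass would run its optimization loop here

model = object()  # placeholder; a real AutoProf model object would go here

optimizer = GradientOptimizer(
    model,
    initial_state=torch.zeros(5),  # optional; None falls back to the model's default state
    relative_tolerance=1e-4,       # convergence threshold for the optimization
    verbose=1,                     # read from **kwargs, as the constructor shows
    max_iter=500,                  # assumed **kwargs option, per the Attributes list
)
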
@@ -44,19 +67,54 @@ def __init__(self, model, initial_state = None, relative_tolerance = 1e-3, **kwargs):
        self.message = ""

    def fit(self):
        """Runs the optimization.

        Raises:
            NotImplementedError: Raised if this method is not implemented in a
                subclass of BaseOptimizer.
        """
        raise NotImplementedError("Please use a subclass of BaseOptimizer for optimization")

    def step(self, current_state=None):
        """Performs a single optimization step.

        Args:
            current_state (torch.Tensor, optional): Current state of the model
                parameters. Defaults to None.

        Raises:
            NotImplementedError: Raised if this method is not implemented in a
                subclass of BaseOptimizer.
        """
        raise NotImplementedError("Please use a subclass of BaseOptimizer for optimization")

    def chi2min(self):
        """Returns the minimum value of the chi^2 loss in the loss history.

        Returns:
            float: Minimum value of the chi^2 loss.
        """
        return np.nanmin(self.loss_history)

    def res(self):
        """Returns the value of lambda (regularization strength) at which the
        minimum chi^2 loss was achieved.

        Returns:
            ndarray: Value of lambda at which the minimum chi^2 loss was
                achieved, considering only finite loss values.
        """
        N = np.isfinite(self.loss_history)
        return np.array(self.lambda_history)[N][np.argmin(np.array(self.loss_history)[N])]

    def chi2contour(self, n_params, confidence=0.682689492137):
        """Calculates the chi^2 contour for the given number of parameters.

        Args:
            n_params (int): The number of parameters.
            confidence (float, optional): The confidence level. Defaults to
                0.682689492137, the 1-sigma confidence level.

        Returns:
            float: The calculated chi^2 contour value.

        Raises:
            RuntimeError: If unable to compute the chi^2 contour for the given
                number of parameters.
        """

        def _f(x, nu):
            """Squared difference between the chi^2 CDF and the target confidence."""
            return (gammainc(nu / 2, x / 2) - confidence) ** 2

        for method in ["L-BFGS-B", "Powell", "Nelder-Mead"]:
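
Since fit and step only raise NotImplementedError, the actual loop lives in subclasses. A minimal sketch of that contract, using the attributes documented in __init__; the shrink-toward-zero update rule and the convergence test are purely illustrative, not AutoProf's:

import torch

class ToyOptimizer(BaseOptimizer):  # hypothetical subclass, for illustration only
    def step(self, current_state=None):
        if current_state is not None:
            self.current_state = current_state
        # illustrative update rule: shrink the state toward zero
        self.current_state = 0.9 * self.current_state
        # record the step and its loss, matching the documented history attributes
        self.lambda_history.append(self.current_state.detach().cpu().numpy())
        self.loss_history.append(float((self.current_state**2).sum()))
        self.iteration += 1

    def fit(self):
        # assumes the base constructor has set current_state from initial_state
        while self.iteration < self.max_iter:
            self.step()
            if len(self.loss_history) > 1:
                prev, curr = self.loss_history[-2], self.loss_history[-1]
                if abs(prev - curr) < self.relative_tolerance * abs(curr):
                    self.message = "success"
                    break
        return self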
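The finite-value mask in res pairs with np.nanmin in chi2min: a diverged step that produced a NaN loss can never be selected as the best state. A standalone demonstration with hypothetical history values:

import numpy as np

loss_history = [10.0, np.nan, 2.5, 4.0]              # hypothetical chi^2 losses
lambda_history = [np.array([0.0]), np.array([1.0]),
                  np.array([2.0]), np.array([3.0])]  # hypothetical states

print(np.nanmin(loss_history))  # 2.5, what chi2min() returns

N = np.isfinite(loss_history)   # mask out the NaN entry
best = np.array(lambda_history)[N][np.argmin(np.array(loss_history)[N])]
print(best)                     # [2.], what res() returns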

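The helper _f is the squared difference between the chi^2 cumulative distribution function, gammainc(nu / 2, x / 2), and the target confidence, so minimizing it over x locates the contour. A standalone sketch of the same computation; the closed-form cross-check via scipy.stats.chi2.ppf is not used by this code and is shown only for verification:

import numpy as np
from scipy.optimize import minimize
from scipy.special import gammainc
from scipy.stats import chi2

confidence = 0.682689492137  # 1-sigma confidence level
n_params = 2

# minimize the squared difference between the chi^2 CDF and the confidence level
result = minimize(
    lambda x: (gammainc(n_params / 2, x[0] / 2) - confidence) ** 2,
    x0=[float(n_params)],
    method="Nelder-Mead",
)
print(result.x[0])                     # ~2.2957 for two parameters
print(chi2.ppf(confidence, n_params))  # closed-form cross-check, same value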