Only Modified Base.py #21
Merged
ConnorStoneAstro merged 2 commits into Autostronomy:main from ranbir7:ranbir7/convert-docstrings-to-google-18 on Feb 7, 2023
@@ -19,6 +19,29 @@ class BaseOptimizer(object):
    """
    def __init__(self, model, initial_state = None, relative_tolerance = 1e-3, **kwargs):
        """
        Initializes a new instance of the class.

        Args:
            model (object): An object representing the model.
            initial_state (Union[None, Tensor]): The initial state of the model. If `None`, the model's default initial state will be used.
            relative_tolerance (float): The relative tolerance for the optimization.
            **kwargs (dict): Additional keyword arguments.

        Attributes:
            model (object): An object representing the model.
            verbose (int): The verbosity level.
            current_state (Tensor): The current state of the model.
            max_iter (int): The maximum number of iterations.
            iteration (int): The current iteration number.
            save_steps (Union[None, int]): The frequency at which to save intermediate results.
            relative_tolerance (float): The relative tolerance for the optimization.
            lambda_history (List[float]): A list of the optimization steps.
            loss_history (List[float]): A list of the optimization losses.
            message (str): An informational message.
        """

        self.model = model
        self.verbose = kwargs.get("verbose", 0)
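The constructor documented in the hunk above can be sketched as a standalone class. This is a minimal, hypothetical reconstruction: only `self.model` and `self.verbose` appear in the visible diff lines, so the remaining attribute defaults are assumptions based on the docstring.

```python
class BaseOptimizer:
    """Minimal sketch of the documented constructor (defaults beyond the diff are assumed)."""

    def __init__(self, model, initial_state=None, relative_tolerance=1e-3, **kwargs):
        self.model = model
        self.verbose = kwargs.get("verbose", 0)       # shown in the diff
        self.initial_state = initial_state            # assumed: stored as-is
        self.relative_tolerance = relative_tolerance
        self.lambda_history = []                      # filled in during optimization
        self.loss_history = []
        self.message = ""
```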
@@ -44,19 +67,54 @@ def __init__(self, model, initial_state = None, relative_tolerance = 1e-3, **kwa
        self.message = ""

    def fit(self):
        """
        Raises:
            NotImplementedError: Error is raised if this method is not implemented in a subclass of BaseOptimizer.
        """
        raise NotImplementedError("Please use a subclass of BaseOptimizer for optimization")

    def step(self, current_state = None):
        """
        Args:
            current_state (torch.Tensor, optional): Current state of the model parameters. Defaults to None.

        Raises:
            NotImplementedError: Error is raised if this method is not implemented in a subclass of BaseOptimizer.
        """
        raise NotImplementedError("Please use a subclass of BaseOptimizer for optimization")

    def chi2min(self):
        """
        Returns the minimum value of chi^2 loss in the loss history.

        Returns:
            float: Minimum value of chi^2 loss.
        """
        return np.nanmin(self.loss_history)

    def res(self):
        """
        Returns the value of lambda (regularization strength) at which minimum chi^2 loss was achieved.

        Returns:
            float: Value of lambda at which minimum chi^2 loss was achieved.
        """
        N = np.isfinite(self.loss_history)
        return np.array(self.lambda_history)[N][np.argmin(np.array(self.loss_history)[N])]

Review comment (on the res return type above): this is right except the return type is ndarray, not float.

    def chi2contour(self, n_params, confidence = 0.682689492137):
        """
        Calculates the chi^2 contour for the given number of parameters.

        Args:
            n_params (int): The number of parameters.
            confidence (float, optional): The confidence interval (default is 0.682689492137).

        Returns:
            float: The calculated chi^2 contour value.

        Raises:
            RuntimeError: If unable to compute the Chi^2 contour for the given number of parameters.
        """
        def _f(x, nu):
            """Helper function for calculating chi^2 contour."""
            return (gammainc(nu/2, x/2) - confidence)**2

        for method in ["L-BFGS-B", "Powell", "Nelder-Mead"]:
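The hunk is truncated after the loop over optimizer methods, so the body of that loop is not visible. A plausible, self-contained completion is sketched below, assuming each method is tried with `scipy.optimize.minimize` until one succeeds; the exact call and the RuntimeError message are assumptions, not the PR's code. The key identity is that `gammainc(nu/2, x/2)` is the chi^2 CDF with `nu` degrees of freedom, so minimizing `_f` finds the chi^2 value at the requested confidence level.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import gammainc


def chi2contour(n_params, confidence=0.682689492137):
    """Chi^2 value whose CDF with n_params degrees of freedom equals `confidence`."""

    def _f(x, nu):
        # gammainc(nu/2, x/2) is the regularized chi^2 CDF with nu degrees of freedom
        return (gammainc(nu / 2, x / 2) - confidence) ** 2

    # Try progressively more robust minimizers until one converges (assumed strategy)
    for method in ["L-BFGS-B", "Powell", "Nelder-Mead"]:
        result = minimize(_f, x0=n_params, args=(n_params,), method=method)
        if result.success:
            return result.x[0]
    raise RuntimeError(f"Unable to compute chi^2 contour for n_params: {n_params}")
```

For one degree of freedom the default confidence (68.27%, i.e. 1-sigma) gives a contour value of 1.0, which is a quick sanity check on any completion of this method.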
Review comment (on the lambda_history docstring): The type here is List[ndarray], not list[float].

Reply: Okay, I'll resolve it right now. Thanks. So I have to do it for all the code in the autoprof dir?

Reply: For now you can just change float -> ndarray in the line above. Eventually it would be great to do docstrings for every function, but that can be done in a future PR.
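The return-type point raised above can be checked with a short standalone snippet. The values are hypothetical, and it assumes `lambda_history` holds per-parameter ndarrays as the review comment states; under that assumption the masked-argmin expression from `res()` returns an ndarray, not a float.

```python
import numpy as np

loss_history = [np.nan, 5.0, 3.0, 4.0]   # NaN losses are masked out
lambda_history = [np.array([0.1, 0.2]), np.array([0.5, 0.6]),
                  np.array([1.0, 1.1]), np.array([2.0, 2.1])]

N = np.isfinite(loss_history)            # mask of finite losses
# Same expression as res(): lambda row at the minimum finite loss
best = np.array(lambda_history)[N][np.argmin(np.array(loss_history)[N])]
# `best` is np.array([1.0, 1.1]) here, an ndarray as the reviewer notes
```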