Skip to content

Commit

Permalink
Add call_fitted_method to Vset
Browse files Browse the repository at this point in the history
  • Loading branch information
jpdunc23 committed Apr 26, 2023
1 parent 22fcb92 commit c6e85d5
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions vflow/vset.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,19 @@ def evaluate(self, *args):
"""
return self._apply_func(*args)

def call_fitted_method(self, *args, method: str, with_uncertainty: bool=False, group_by: list=None):
if not self._fitted:
raise AttributeError('Please fit the Vset object before calling call_fitted_method.')
pred_dict = {}
for k, v in self.fitted_vfuncs.items():
if k != '__prev__':
assert hasattr(v, method), f'{v} does not have a "{method}" method.'
pred_dict[k] = getattr(v, method)
preds = self._apply_func(*args, out_dict=pred_dict)
if with_uncertainty:
return prediction_uncertainty(preds, group_by)
return preds

def __call__(self, *args, n_out: int = None, keys=None, **kwargs):
"""Call args using `_apply_func`, optionally seperating
output dictionary into `n_out` dictionaries with `keys`
Expand Down

0 comments on commit c6e85d5

Please sign in to comment.