-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
use aggregators as a uniform incarnation for the across-instances computation for both instance and global metrics #890
Conversation
355a298
to
73d8ce8
Compare
7c48176
to
ee1dde6
Compare
9541e46
to
41d3f97
Compare
e36feee
to
dd2a491
Compare
src/unitxt/metrics.py
Outdated
class Aggregator(Artifact): | ||
@abstractmethod | ||
def aggregate_one_group_score_named( | ||
self, instances: List[Dict[str, Any]], score_names: List[str] | ||
) -> dict: | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class Aggregator(Artifact): | |
@abstractmethod | |
def aggregate_one_group_score_named( | |
self, instances: List[Dict[str, Any]], score_names: List[str] | |
) -> dict: | |
pass | |
class Aggregator(Artifact): | |
""" | |
Aggregate list of instances to a dictionary of scores. | |
"""" | |
score_names: List[str] | |
@abstractmethod | |
def aggregate( | |
self, instances: List[Dict[str, Any]] | |
) -> Dict[str, Any]: | |
pass | |
def __call__(self, instances): | |
return self.aggregate(instances) |
src/unitxt/metrics.py
Outdated
class SimpleAggregator(Aggregator): | ||
aggregating_func: Callable[[List[Dict[str, Any]], str], float] | ||
|
||
def aggregate_one_group_score_named( | ||
self, instances: List[Dict[str, Any]], score_names: List[str] | ||
) -> dict: | ||
result = {} | ||
for score_name in score_names: | ||
result[score_name] = self.aggregating_func(instances, score_name) | ||
return result | ||
|
||
|
||
def average_item_scores(instances: List[dict], score_name: str) -> float: | ||
"""Calculate mean of a set of instance scores (given by score_name), omitting NaN values. | ||
|
||
Args: | ||
instances: list of dicts of each instance's instance scores. | ||
score_name: score field names to compute the mean for. | ||
""" | ||
return nan_mean( | ||
[instance["score"]["instance"][score_name] for instance in instances] | ||
) | ||
|
||
|
||
def max_item_scores(instances: List[dict], score_name: str) -> float: | ||
"""Calculate max of a set of instance scores (given by score_name), omitting NaN values. | ||
|
||
Args: | ||
instances: list of dicts of each instance's instance scores. | ||
score_name: score field names to compute the mean for. | ||
""" | ||
return nan_max( | ||
[instance["score"]["instance"][score_name] for instance in instances] | ||
) | ||
|
||
|
||
class AverageItemsAggregator(SimpleAggregator): | ||
aggregating_func = Field(default_factory=lambda: average_item_scores) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class SimpleAggregator(Aggregator): | |
aggregating_func: Callable[[List[Dict[str, Any]], str], float] | |
def aggregate_one_group_score_named( | |
self, instances: List[Dict[str, Any]], score_names: List[str] | |
) -> dict: | |
result = {} | |
for score_name in score_names: | |
result[score_name] = self.aggregating_func(instances, score_name) | |
return result | |
def average_item_scores(instances: List[dict], score_name: str) -> float: | |
"""Calculate mean of a set of instance scores (given by score_name), omitting NaN values. | |
Args: | |
instances: list of dicts of each instance's instance scores. | |
score_name: score field names to compute the mean for. | |
""" | |
return nan_mean( | |
[instance["score"]["instance"][score_name] for instance in instances] | |
) | |
def max_item_scores(instances: List[dict], score_name: str) -> float: | |
"""Calculate max of a set of instance scores (given by score_name), omitting NaN values. | |
Args: | |
instances: list of dicts of each instance's instance scores. | |
score_name: score field names to compute the mean for. | |
""" | |
return nan_max( | |
[instance["score"]["instance"][score_name] for instance in instances] | |
) | |
class AverageItemsAggregator(SimpleAggregator): | |
aggregating_func = Field(default_factory=lambda: average_item_scores) | |
class Mean(Aggregator): | |
def aggregate( | |
self, instances: List[Dict[str, Any]] | |
) -> dict: | |
result = {} | |
for score_name in self.score_names: | |
result[score_name] = nan_mean( | |
[instance["score"]["instance"][score_name] for instance in instances] | |
) | |
class Max(Aggregator): | |
def aggregate( | |
self, instances: List[Dict[str, Any]] | |
) -> dict: | |
result = {} | |
for score_name in self.score_names: | |
result[score_name] = nan_max( | |
[instance["score"]["instance"][score_name] for instance in instances] | |
) | |
class ConfidenceInterval(Aggregator): | |
aggregator: Aggregator = Mean() | |
def aggregate( | |
self, instances: List[Dict[str, Any]] | |
) -> dict: | |
results = [] | |
for sample in sample(instances): | |
result = self.aggregate(sample) | |
results.append(result) | |
result["ci_low"], results["ci_high"] = ci(results) | |
class Filter(Aggregator): | |
aggregator: Aggregator = Mean() | |
filter: Func | |
def aggregate( | |
self, instances: List[Dict[str, Any]] | |
) -> dict: | |
instances = [instance for instance in instances if self.filter(instance)] | |
return self.aggregate(instances) | |
class Group(Aggregator): | |
group_aggregator: Aggregator = Mean() | |
all_groups_aggregator: Aggregator = Mean() | |
results = {} | |
group_results =[] | |
for group, group_name in split_to_groups(instances): | |
result = self.group_aggregator(group) | |
group_results.append(result) | |
update_result(result, group_name) | |
results.update(self. all_groups_aggregator(group_results)) | |
return results | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class Mapper:
def map(instance: Dict[str, Any]): -> Dict[str, Any]
pass
def __call__(instances):
return [self.map(instance) for instance in instances]
class Metric:
aggregate: Aggregator = Group(ConfigenceInterval(Mean(fields=["f1"]), group_by="group_id")
# aggregator: Aggregator = Group(ConfigenceInterval(RougeAggregator()), group_by="group_id")
# map: Mapper = RougeInstanceScore()
def compute(instances):
instances = self.map(instances)
return self.aggregate(instances)
So every metric has one and only aggregator
src/unitxt/metrics.py
Outdated
ci_samples_from_groups_scores: bool = False | ||
|
||
# the basic aggregation along the instances: no split to groups, no filtering | ||
aggregator: Aggregator = Field(default_factory=lambda: AverageItemsAggregator()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
aggregator: Aggregator = Field(default_factory=lambda: AverageItemsAggregator()) | |
aggregator: Aggregator = Mean() |
src/unitxt/metrics.py
Outdated
@@ -317,7 +380,7 @@ def score_based_confidence_interval( | |||
# if aggregation_func is None, we simply take the mean of the resampled instance scores | |||
# otherwise, the aggregation_func needs to be applied AFTER resampling the instances; | |||
# that is, re-form the groups, calculate the function, and take the mean of the group scores | |||
aggregation_func = self.average_item_scores | |||
aggregation_func = AverageItemsAggregator().aggregate_one_group_score_named |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
aggregation_func = AverageItemsAggregator().aggregate_one_group_score_named | |
aggregation_func = self.aggregator |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See comments
56c3af3
to
76e76c1
Compare
af4cb71
to
4f279f1
Compare
edc5749
to
89b583f
Compare
…yping of prediction (type rather than string) Signed-off-by: dafnapension <dafnashein@yahoo.com>
89b583f
to
6bf93e6
Compare
Maybe return to this after the war ends |
too complicated to maintain rebase-able. So closing for now. |
Thus simplified and thereby extended grouping and filtering over to global metrics and to bulk-instance metrics