bulkinstance too
Signed-off-by: dafnapension <dafnashein@yahoo.com>
dafnapension committed May 6, 2024
1 parent 93eb629 commit 1267565
Showing 1 changed file with 61 additions and 66 deletions.
src/unitxt/metrics.py
@@ -528,14 +528,14 @@ def score_groups_globally(

         groups_global_scores = {}
         for group_name, group in grouped_instances.items():
-            if isinstance(self, InstanceMetric):
+            if isinstance(self, (InstanceMetric, BulkInstanceMetric)):
                 groups_global_scores[group_name] = {}
                 for score_name in score_names:
                     if isinstance(group, list):  # not split to control and comparison
                         groups_global_scores[group_name][score_name] = self.aggregating[
                             "aggregating_function"
                         ](instances=group, score_name=score_name)
-                    else:
+                    else:  # split to control and comparison
                         control_scores = [
                             instance["score"]["instance"][score_name]
                             for instance in group["control"]
@@ -565,10 +565,6 @@ def score_groups_globally(
                             predictions=predictions,
                             task_data=task_data,
                         )
-            elif isinstance(self, BulkInstanceMetric):
-                raise ValueError(
-                    "What are you doing here? nowhere in BulkInstanceMetric is this method invoked"
-                )
             else:
                 raise ValueError(
                     f"Unrecognized extension of MetricWithConfidence: {type(self)}"
@@ -786,31 +782,31 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
     n_resamples: int = OptionalField(
         default_factory=lambda: settings.num_resamples_for_instance_metrics
     )
     main_score: str
-    reduction_map: Dict[str, List[str]]
+    aggregating: dict = None
+    score_names: List[str] = None
 
-    implemented_reductions: List[str] = field(default_factory=lambda: ["mean"])
-
     def prepare(self):
+        if self.score_names is None:
+            self.score_names = [self.main_score]
+        if self.aggregating is None:
+            self.aggregating = {
+                "aggregating_function_name": "mean",
+                "aggregating_function": MetricWithConfidenceInterval.average_item_scores,
+            }
         super().prepare()
         if self.main_score is None:
             self.main_score = "f1"

     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         global_score = {}
-        instances = []
 
         # consume the stream
-        references, predictions = map(
-            list,
-            zip(
-                *[
-                    (instance["references"], instance["prediction"])
-                    for instance in stream
-                ]
-            ),
+        predictions, references, task_data, instances = self.consume_stream(
+            stream=stream, task_data_field_name="task_data"
         )
 
-        task_data = [
-            instance["task_data"] if "task_data" in instance else {}
-            for instance in stream
-        ]
         self._validate_references_and_prediction(references, predictions)
 
         # compute the metric over all refs and preds
         instance_scores = self.compute(
             references=references,
@@ -823,45 +819,53 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
instance_score["score"] = instance_score[self.main_score]
instance_score["score_name"] = self.main_score

for instance, score in zip(stream, instance_scores):
for instance, score in zip(instances, instance_scores):
if "score" not in instance:
instance["score"] = {"global": global_score, "instance": {}}
else:
global_score = instance["score"]["global"]

instance["score"]["instance"].update(score)

instances.append(instance)

for reduction, fields in self.reduction_map.items():
assert (
reduction in self.implemented_reductions
), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"

if reduction == "mean":
for field_name in fields:
global_score[field_name] = mean(
[
instance["score"]["instance"][field_name]
for instance in instances
]
)
if field_name == self.main_score:
global_score["score"] = global_score[field_name]
global_score["score_name"] = self.main_score

ci_fields = (
list(set(self.ci_scores))
if self.ci_scores is not None
else [self.main_score]
# groups also covers for non-grouped, where the whole stream is treated as a single group
groups_global_scores = self.score_groups_globally(
instances=instances, score_names=self.score_names
)
# no playing with field names here as in InstanceMetric, so we simply average over the groups (one or more)
for score_name in self.score_names:
if self.grouping is None:
# there is only one group here
global_score.update(
{score_name: groups_global_scores["all"][score_name]}
)
confidence_interval = self.score_based_confidence_interval(
instances=instances, score_names=ci_fields
else:
global_score.update(
{
score_name: nan_mean(
[
group_global_scores[score_name]
for group_global_scores in groups_global_scores.values()
if isinstance(groups_global_scores, dict)
]
)
}
)
global_score.update(confidence_interval)
if score_name == self.main_score:
global_score["score"] = global_score[self.main_score]
global_score["score_name"] = self.main_score

for instance in instances:
yield instance
ci_fields = (
list(set(self.ci_scores))
if self.ci_scores is not None
else [self.main_score]
)
# working non-grouped, and hence no variation on field names
confidence_interval = self.score_based_confidence_interval(
instances=instances, score_names=ci_fields
)
global_score.update(confidence_interval)

yield from instances
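[Editor's note] A sketch of the global-score update just above, with hypothetical per-group scores; self.grouping is mimicked by a local variable, and nan_mean is assumed to be a NaN-ignoring mean in the spirit of numpy.nanmean.

    import numpy as np

    def nan_mean(values):
        # assumption: unitxt's nan_mean ignores NaNs, like numpy's nanmean
        return float(np.nanmean(values))

    groups_global_scores = {"group_a": {"f1": 0.7}, "group_b": {"f1": 0.4}}
    grouping = "variant"  # non-None, so the groups' global scores are averaged
    global_score = {}

    for score_name in ["f1"]:
        if grouping is None:
            # the single group keyed "all" covers the non-grouped case
            global_score[score_name] = groups_global_scores["all"][score_name]
        else:
            global_score[score_name] = nan_mean(
                [
                    group_global_scores[score_name]
                    for group_global_scores in groups_global_scores.values()
                    if isinstance(group_global_scores, dict)
                ]
            )
    print(global_score)  # roughly {'f1': 0.55}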

     @abstractmethod
     def compute(
@@ -1022,8 +1026,9 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
                         {
                             to_score_name: nan_mean(
                                 [
-                                    groups_global_scores[group_name][score_name]
-                                    for group_name in groups_global_scores.keys()
+                                    group_global_scores[score_name]
+                                    for group_global_scores in groups_global_scores.values()
+                                    if isinstance(group_global_scores, dict)
                                 ]
                             )
                         }
@@ -2076,7 +2081,7 @@ def _compute_single_ref(
 class BertScore(HuggingfaceBulkMetric):
     hf_metric_name = "bertscore"
     main_score = "f1"
-    reduction_map = {"mean": ["f1", "precision", "recall"]}
+    score_names = ["f1", "precision", "recall"]
     hf_metric_fields = ["f1", "precision", "recall"]
     ci_scores = ["f1", "precision", "recall"]
     model_name: str
@@ -2094,7 +2099,6 @@ def prepare(self):


 class SentenceBert(BulkInstanceMetric):
-    reduction_map = {"mean": ["score"]}
     main_score = "score"
     batch_size: int = 32

@@ -2145,7 +2149,6 @@ def compute(


 class Reward(BulkInstanceMetric):
-    reduction_map = {"mean": ["score"]}
     main_score = "score"
     batch_size: int = 32

@@ -2186,7 +2189,6 @@ def compute(


 class Detector(BulkInstanceMetric):
-    reduction_map = {"mean": ["score"]}
     main_score = "score"
     batch_size: int = 32

@@ -2343,7 +2345,6 @@ class Perplexity(BulkInstanceMetric):
"""Computes the likelihood of generating text Y after text X - P(Y|X)."""

main_score = "perplexity"
reduction_map = {"mean": ["perplexity"]}
prediction_type = "str"

source_template: str
@@ -3525,20 +3526,14 @@ class BinaryAccuracy(InstanceMetric):
"aggregating_function_name": "mean",
"aggregating_function": MetricWithConfidenceInterval.average_item_scores,
}

def _validate_reference(self, reference):
super()._validate_reference(reference)
assert reference[0] in [
0,
1,
], f"all references of {self.main_score} must by 0 or 1"

def compute(
self, references: List[Any], prediction: Any, task_data: List[Dict]
) -> dict:
float_prediction = to_float_or_default(prediction)
prediction = str(int(float_prediction > self.threshold))
references = ["1"] if references[0].lower() in self.pos_classes else ["0"]

def compute(
self, references: List[float], prediction: float, task_data: List[Dict]
) -> dict:
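[Editor's note] The net effect on BulkInstanceMetric subclasses (SentenceBert, Reward, Detector, Perplexity above): the explicit reduction_map disappears, score_names defaults to [main_score] in prepare(), and aggregating defaults to a plain mean over per-instance scores. A minimal sketch of a subclass written against this commit's tree; the class name and the constant scoring are illustrative only, not from the repository.

    from typing import Any, Dict, List

    from unitxt.metrics import BulkInstanceMetric  # assumes this commit's code

    class ConstantScore(BulkInstanceMetric):  # hypothetical example
        main_score = "score"
        # no reduction_map: prepare() fills score_names = ["score"] and a
        # mean-over-instances aggregating function

        def compute(
            self,
            references: List[List[Any]],
            predictions: List[Any],
            task_data: List[Dict],
        ) -> List[Dict[str, Any]]:
            # one score dict per prediction, as BulkInstanceMetric.compute expects
            return [{"score": 1.0} for _ in predictions]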
