test fused and split through metric
Signed-off-by: dafnapension <dafnashein@yahoo.com>
dafnapension committed May 5, 2024
1 parent 9c261d2 commit e479174
Showing 1 changed file with 135 additions and 0 deletions.
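
In brief, the added test checks that running a metric over a fused stream and over its per-origin splits yields consistent scores: the per-origin global rougeL values average back to the global value computed over the whole fused stream. Below is a minimal sketch of the fuse-then-split flow, condensed from the diff that follows; the instance contents and origin names here are illustrative only and are not part of the commit.

from unitxt.fusion import FixedFusion
from unitxt.operators import IterableSource, SplitByGroup

# each origin contributes its own "test" stream of instances
fused = FixedFusion(
    origins={
        "originH1": IterableSource({"test": [{"prediction": "hello there", "references": ["hello", "there"]}]}),
        "originH2": IterableSource({"test": [{"prediction": "Tel Aviv", "references": ["Tel Mond", "Aviv Tel"]}]}),
    },
)()  # calling the fusion yields a MultiStream with one fused "test" stream

# split the fused stream back into one stream per origin
per_origin = SplitByGroup(number_of_fusion_generations=1)(fused)
assert set(per_origin.keys()) == {"test_originH1", "test_originH2"}
# a metric can now be applied either to the fused MultiStream or to the split one;
# the test below verifies that the per-origin global scores average to the fused global score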
135 changes: 135 additions & 0 deletions tests/library/test_metrics.py
@@ -1,5 +1,7 @@
from math import isnan
from typing import Dict, List

from unitxt.fusion import FixedFusion
from unitxt.inference import HFPipelineBasedInferenceEngine
from unitxt.llm_as_judge import LLMAsJudge
from unitxt.logging_utils import get_logger
@@ -47,6 +49,8 @@
    TokenOverlap,
    UnsortedListExactMatch,
)
from unitxt.operators import IterableSource, SplitByGroup
from unitxt.stream import MultiStream
from unitxt.test_utils.metrics import apply_metric

from tests.utils import UnitxtTestCase
@@ -693,6 +697,137 @@ def test_rouge(self):
        global_target = 5 / 6
        self.assertAlmostEqual(global_target, outputs[0]["score"]["global"]["score"])

    def test_rougel_simple_avg_with_fuse_and_split(self):
        from statistics import mean

        class AvgRougeNoBootstrap(Rouge):
            # no bootstrap whatsoever
            # hf rouge, for use_aggregator = True:
            # if use_aggregator:
            # aggregator = scoring.BootstrapAggregator()
            # and returns a bootstrapped version of the mean score
            def prepare(self):
                self.n_resamples = None
                self.rouge_types = ["rougeL"]
                self.use_aggregator = False
                super().prepare()

            def compute(self, references, predictions, task_data: List[Dict]):
                res_list = super().compute(references, predictions, task_data)["rougeL"]
                return {"rougeL": mean(res_list)}

        metric = AvgRougeNoBootstrap()
        references = [
            ["hello", "there"],
            ["general kenobi", "general yoda"],
            ["I sing", "singing in the rain"],
            ["As a cloud", "I wonder"],
            ["Tel Mond", "Aviv Tel"],
            ["no such zone", "return to sender"],
            ["my party it is", "I cry if I want to"],
            ["tell him right now", "I know something"],
        ]
        predictions = [
            "hello there",
            "general kenobi",
            "I am singing",
            "I wandered",
            "Tel Aviv",
            "no such number",
            "its my party",
            "tell him",
        ]

        self.assertEqual(len(references), len(predictions))
        grand_input = [
            {"references": reference, "prediction": prediction}
            for (reference, prediction) in zip(references, predictions)
        ]
        ms = MultiStream.from_iterables({"test": grand_input})
        all_through_rouge = list(metric(ms)["test"])
        self.assertDictEqual(
            {
                "global": {
                    "rougeL": 0.6214285714285714,
                    "score": 0.6214285714285714,
                    "score_name": "rougeL",
                },
                "instance": {
                    "rougeL": 0.6666666666666666,
                    "score": 0.6666666666666666,
                    "score_name": "rougeL",
                },
            },
            all_through_rouge[0]["score"],
        )

        # now make the same input instances, a fusion of first half and second half of grand_input:
        grand_input = [
            {"references": reference, "prediction": prediction}
            for (reference, prediction) in zip(references, predictions)
        ]

        origin_h1 = grand_input[:4]
        origin_h2 = grand_input[4:]
        fixed_fusion = FixedFusion(
            origins={
                "originH1": IterableSource({"test": origin_h1}),
                "originH2": IterableSource({"test": origin_h2}),
            },
        )
        ms_fused_two_halves = fixed_fusion()
        splitter = SplitByGroup(number_of_fusion_generations=1)
        split_ms_fused_two_halves = splitter(ms_fused_two_halves)
        self.assertSetEqual(
            {"test_originH1", "test_originH2"}, set(split_ms_fused_two_halves.keys())
        )
        split_through_rouge = metric(split_ms_fused_two_halves)
        self.assertDictEqual(
            {
                "global": {
                    "rougeL": 0.6416666666666666,
                    "score": 0.6416666666666666,
                    "score_name": "rougeL",
                },
                "instance": {
                    "rougeL": 0.6666666666666666,
                    "score": 0.6666666666666666,
                    "score_name": "rougeL",
                },
            },
            split_through_rouge["test_originH1"].peek()["score"],
        )
        self.assertDictEqual(
            {
                "global": {
                    "rougeL": 0.6011904761904762,
                    "score": 0.6011904761904762,
                    "score_name": "rougeL",
                },
                "instance": {"rougeL": 0.5, "score": 0.5, "score_name": "rougeL"},
            },
            split_through_rouge["test_originH2"].peek()["score"],
        )
        self.assertEqual(
            mean(
                [
                    split_through_rouge["test_originH1"].peek()["score"]["global"][
                        "score"
                    ],
                    split_through_rouge["test_originH2"].peek()["score"]["global"][
                        "score"
                    ],
                ]
            ),
            all_through_rouge[0]["score"]["global"]["score"],
        )

        # outputs = apply_metric(
        # metric=metric, predictions=predictions, references=references
        # )
        # global_target = 5 / 6
        # self.assertAlmostEqual(global_target, outputs[0]["score"]["global"]["score"])

    def test_rouge_l(self):
        metric = Rouge(
            n_resamples=None, # disable confidence interval calculation which fails for this metric configuration