Add RAG (response generation part) tasks and datasets #811

Merged
16 commits merged on May 13, 2024
prepare/cards/rag/response_generation/clapnq.py (55 additions, 0 deletions)
@@ -0,0 +1,55 @@
from copy import deepcopy

from unitxt import add_to_catalog
from unitxt.blocks import (
    LoadHF,
    SplitRandomMix,
    TaskCard,
    TemplatesDict,
)
from unitxt.operators import (
    AddFields,
    CopyFields,
)
from unitxt.test_utils.card import test_card

card = TaskCard(
    loader=LoadHF(
        path="PrimeQA/clapnq",
    ),
    preprocess_steps=[
        SplitRandomMix({"train": "train", "test": "validation"}),
        CopyFields(
            field_to_field={
                "passages/*/text": "contexts",
                "input": "question",
                "output/*/answer": "reference_answers",
            }
        ),
        AddFields(
            fields={
                "contexts_ids": [],
            }
        ),
    ],
    task="tasks.rag.response_generation",
    templates=TemplatesDict({"default": "templates.rag.response_generation.simple"}),
)

# testing the card is too slow with the bert-score metric, so dropping it
card_for_test = deepcopy(card)
card_for_test.task.metrics = [
    "metrics.rag.response_generation.correctness.token_overlap",
    "metrics.rag.response_generation.faithfullness.token_overlap",
]

test_card(
    card_for_test,
    strict=True,
    demos_taken_from="test",
)
add_to_catalog(
    card,
    "cards.rag.response_generation.clapnq",
    overwrite=True,
)
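
Not part of the diff itself, but for orientation: a minimal sketch of how the card registered above could be consumed once it is in the catalog. It assumes the standard unitxt load_dataset/evaluate API and reuses the catalog names defined in this file; the predictions are placeholders (the first reference of each instance), so the scores only illustrate the plumbing.

# Sketch only (assumes the standard unitxt API); not part of this PR.
from unitxt import evaluate, load_dataset

dataset = load_dataset(
    "card=cards.rag.response_generation.clapnq,"
    "template=templates.rag.response_generation.simple,"
    "max_test_instances=10"
)
test_set = dataset["test"]

# In a real evaluation these would come from the RAG system under test.
predictions = [instance["references"][0] for instance in test_set]

results = evaluate(predictions=predictions, data=test_set)
print(results[0]["score"]["global"])
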
prepare/metrics/rag.py (38 additions, 39 deletions)
@@ -33,7 +33,6 @@
model_name="OpenAssistant/reward-model-deberta-v3-large-v2"
),
}

predictions = ["apple", "boy", "cat"]
references = [["apple2"], ["boys"], ["dogs"]]
task_data = [{"context": "apple 2e"}, {"context": "boy"}, {"context": "dog"}]
@@ -42,10 +41,8 @@
{"f1": 1.0, "precision": 1.0, "recall": 1.0, "score": 1.0, "score_name": "f1"},
{"f1": 0, "precision": 0, "recall": 0, "score": 0, "score_name": "f1"},
]

# Currently, CopyFields does not delete the source fields,
# so we get both "precision" and "precision_overlap_with_context" in results

global_target = {
"f1": 0.56,
"f1_ci_high": 0.89,
@@ -64,8 +61,6 @@
"score_ci_low": 0.0,
"score_name": "f1",
}


metric = MetricPipeline(
main_score="score",
preprocess_steps=[
@@ -87,8 +82,6 @@
],
)
add_to_catalog(metric, "metrics.token_overlap_with_context", overwrite=True)


outputs = test_metric(
metric=metric,
predictions=predictions,
@@ -97,7 +90,6 @@
global_target=global_target,
task_data=task_data,
)

metric = metrics["metrics.bert_score.deberta_xlarge_mnli"]
predictions = ["hello there general dude", "foo bar foobar"]
references = [
@@ -108,7 +100,6 @@
{"f1": 0.8, "precision": 0.86, "recall": 0.84, "score": 0.8, "score_name": "f1"},
{"f1": 1.0, "precision": 1.0, "recall": 1.0, "score": 1.0, "score_name": "f1"},
]

global_target = {
"f1": 0.9,
"f1_ci_high": 1.0,
@@ -124,15 +115,13 @@
"score_ci_low": 0.8,
"score_name": "f1",
}

# test_metric(
# metric=metric,
# predictions=predictions,
# references=references,
# instance_targets=instance_targets,
# global_target=global_target,
# )

metric = metrics["metrics.bert_score.deberta_large_mnli"]
predictions = ["hello there general dude", "foo bar foobar"]
references = [
@@ -143,7 +132,6 @@
{"f1": 0.73, "precision": 0.83, "recall": 0.79, "score": 0.73, "score_name": "f1"},
{"f1": 1.0, "precision": 1.0, "recall": 1.0, "score": 1.0, "score_name": "f1"},
]

global_target = {
"f1": 0.87,
"f1_ci_high": 1.0,
@@ -159,15 +147,13 @@
"score_ci_low": 0.73,
"score_name": "f1",
}

# test_metric(
# metric=metric,
# predictions=predictions,
# references=references,
# instance_targets=instance_targets,
# global_target=global_target,
# )

metric = metrics["metrics.bert_score.deberta_base_mnli"]
predictions = ["hello there general dude", "foo bar foobar"]
references = [
@@ -178,7 +164,6 @@
{"f1": 0.81, "precision": 0.85, "recall": 0.81, "score": 0.81, "score_name": "f1"},
{"f1": 1.0, "precision": 1.0, "recall": 1.0, "score": 1.0, "score_name": "f1"},
]

global_target = {
"f1": 0.9,
"f1_ci_high": 1.0,
@@ -194,15 +179,13 @@
"score_ci_low": 0.81,
"score_name": "f1",
}

test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)

metric = metrics["metrics.bert_score.distilbert_base_uncased"]
predictions = ["hello there general dude", "foo bar foobar"]
references = [
@@ -213,7 +196,6 @@
{"f1": 0.85, "precision": 0.91, "recall": 0.86, "score": 0.85, "score_name": "f1"},
{"f1": 1.0, "precision": 1.0, "recall": 1.0, "score": 1.0, "score_name": "f1"},
]

global_target = {
"f1": 0.93,
"f1_ci_high": 1.0,
@@ -229,15 +211,13 @@
"score_ci_low": 0.85,
"score_name": "f1",
}

test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)

metric = metrics["metrics.bert_score.deberta_v3_base_mnli_xnli_ml"]
predictions = ["hello there general dude", "foo bar foobar"]
references = [
@@ -248,7 +228,6 @@
{"f1": 0.74, "precision": 0.81, "recall": 0.71, "score": 0.74, "score_name": "f1"},
{"f1": 1.0, "precision": 1.0, "recall": 1.0, "score": 1.0, "score_name": "f1"},
]

global_target = {
"f1": 0.87,
"f1_ci_high": 1.0,
@@ -264,22 +243,19 @@
"score_ci_low": 0.74,
"score_name": "f1",
}

test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)

metric = metrics["metrics.sentence_bert.mpnet_base_v2"]
predictions = ["hello there general dude", "foo bar foobar"]
references = [
["hello there general kenobi", "hello there!"],
["foo bar foobar", "foo bar"],
]

instance_targets = [
{"score": 0.71, "score_name": "score"},
{"score": 1.0, "score_name": "score"},
@@ -290,15 +266,13 @@
"score_ci_low": 0.71,
"score_name": "score",
}

test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)

metric = metrics["metrics.reward.deberta_v3_large_v2"]
predictions = ["hello there General Dude", "foo bar foobar"]
references = [["How do you greet General Dude"], ["What is your name?"]]
@@ -312,20 +286,16 @@
"score_ci_low": 0.03,
"score_name": "score",
}

# test_metric(
# metric=metric,
# predictions=predictions,
# references=references,
# instance_targets=instance_targets,
# global_target=global_target,
# )

for metric_id, metric in metrics.items():
add_to_catalog(metric, metric_id, overwrite=True)

# rag metrics:

# reference-less:
# context-relevance:
# metrics.rag.context_relevance
@@ -336,7 +306,6 @@
# metrics.rag.bert_k_precision ("k_precision" stands for "knowledge precision")
# answer-relevance:
# metrics.rag.answer_relevance

# reference-based:
# context-correctness:
# metrics.rag.mrr
@@ -346,7 +315,6 @@
# metrics.rag.correctness
# metrics.rag.recall
# metrics.rag.bert_recall

for metric_name, catalog_name in [
("map", "metrics.rag.map"),
("mrr", "metrics.rag.mrr"),
@@ -363,8 +331,6 @@
metric=f"metrics.{metric_name}",
)
add_to_catalog(metric, catalog_name, overwrite=True)


context_relevance = MetricPipeline(
main_score="perplexity",
preprocess_steps=[
@@ -376,7 +342,6 @@
metric="metrics.perplexity_q.flan_t5_small",
)
add_to_catalog(context_relevance, "metrics.rag.context_relevance", overwrite=True)

context_perplexity = MetricPipeline(
main_score="score",
preprocess_steps=[
@@ -395,7 +360,6 @@
],
)
add_to_catalog(context_perplexity, "metrics.rag.context_perplexity", overwrite=True)

for new_catalog_name, base_catalog_name in [
("metrics.rag.faithfulness", "metrics.token_overlap"),
("metrics.rag.k_precision", "metrics.token_overlap"),
@@ -416,7 +380,6 @@
metric=base_catalog_name,
)
add_to_catalog(metric, new_catalog_name, overwrite=True)

for new_catalog_name, base_catalog_name in [
("metrics.rag.answer_correctness", "metrics.token_overlap"),
("metrics.rag.recall", "metrics.token_overlap"),
@@ -434,7 +397,6 @@
metric=base_catalog_name,
)
add_to_catalog(metric, new_catalog_name, overwrite=True)

answer_reward = MetricPipeline(
main_score="score",
preprocess_steps=[
@@ -450,7 +412,6 @@
metric="metrics.reward.deberta_v3_large_v2",
)
add_to_catalog(answer_reward, "metrics.rag.answer_reward", overwrite=True)

answer_inference = MetricPipeline(
main_score="perplexity",
preprocess_steps=[
@@ -462,3 +423,41 @@
metric="metrics.perplexity_nli.t5_nli_mixture",
)
add_to_catalog(answer_inference, "metrics.rag.answer_inference", overwrite=True)

for axis, base_metric, main_score in [
    ("correctness", "token_overlap", "f1"),
    ("correctness", "bert_score.deberta_large_mnli", "recall"),
    ("correctness", "bert_score.deberta_v3_base_mnli_xnli_ml", "recall"),
    ("faithfullness", "token_overlap", "precision"),
]:
    preprocess_steps = (
        [
            CopyFields(field_to_field=[("task_data/contexts", "references")]),
        ]
        if axis == "faithfullness"
        else []
    )

    metric = MetricPipeline(
        main_score=main_score,
        preprocess_steps=preprocess_steps,
        postpreprocess_steps=[
            CopyFields(
                field_to_field={
                    "score/instance/f1": f"score/instance/{axis}_f1_{base_metric}",
                    "score/instance/recall": f"score/instance/{axis}_recall_{base_metric}",
                    "score/instance/precision": f"score/instance/{axis}_precision_{base_metric}",
                    "score/global/f1": f"score/global/{axis}_f1_{base_metric}",
                    "score/global/recall": f"score/global/{axis}_recall_{base_metric}",
                    "score/global/precision": f"score/global/{axis}_precision_{base_metric}",
                },
                not_exist_ok=True,
            ),
        ],
        metric=f"metrics.{base_metric}",
        prediction_type="str",
    )

    add_to_catalog(
        metric, f"metrics.rag.response_generation.{axis}.{base_metric}", overwrite=True
    )
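
Outside the diff, a note on intent: the CopyFields post-processing above duplicates each score under an axis- and metric-prefixed name (for example correctness_f1_token_overlap), so several of these pipelines can be attached to one task without their shared "f1"/"recall"/"precision" keys colliding. A rough sketch of pulling one of the newly registered pipelines back from the catalog, assuming unitxt's fetch_artifact helper:

# Sketch only, not part of this PR.
from unitxt.artifact import fetch_artifact

# fetch_artifact returns the artifact together with the catalog it was resolved from.
metric, _ = fetch_artifact("metrics.rag.response_generation.correctness.token_overlap")
print(type(metric).__name__)  # MetricPipeline
print(metric.main_score)      # "f1" for the correctness/token_overlap variant
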
prepare/tasks/rag/response_generation.py (23 additions, 0 deletions)
@@ -0,0 +1,23 @@
from unitxt import add_to_catalog
from unitxt.blocks import (
    FormTask,
)

add_to_catalog(
    FormTask(
        inputs={
            "contexts": "List[str]",
            "contexts_ids": "List[int]",
            "question": "str",
        },
        outputs={"reference_answers": "List[str]"},
        metrics=[
            "metrics.rag.response_generation.correctness.token_overlap",
            "metrics.rag.response_generation.faithfullness.token_overlap",
            "metrics.rag.response_generation.correctness.bert_score.deberta_large_mnli",
        ],
        augmentable_inputs=["contexts", "question"],
    ),
    "tasks.rag.response_generation",
    overwrite=True,
)
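
The template referenced by the ClapNQ card, templates.rag.response_generation.simple, is added in another file of this PR and is not shown in this section. Purely as a hypothetical illustration of a template whose fields line up with the task above (contexts and question as inputs, reference_answers as the references), assuming unitxt's MultiReferenceTemplate is a suitable base:

# Hypothetical sketch only; the real templates.rag.response_generation.simple
# is defined elsewhere in the PR. The catalog name below is invented for illustration.
from unitxt import add_to_catalog
from unitxt.templates import MultiReferenceTemplate

add_to_catalog(
    MultiReferenceTemplate(
        input_format="Context: {contexts}\nQuestion: {question}\nAnswer:",
        references_field="reference_answers",
    ),
    "templates.rag.response_generation.example_sketch",  # invented name
    overwrite=True,
)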