Feature/accuracy fairness for summarization #446

Merged
4 changes: 3 additions & 1 deletion nlptest/datahandler/datasource.py
@@ -474,12 +474,14 @@ def load_data(self):
            )

        elif (self.task=='summarization'):
            expected_results = item.get("summary", None)
            if isinstance(expected_results, str) or isinstance(expected_results, bool):
                expected_results = [str(expected_results)]
            data.append(
                SummarizationSample(
                    original=item['document'],
                    expected_results=expected_results,
                    task=self.task,
                    dataset_name=self._file_path.split('/')[-2]
                )
            )
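For reference, a minimal sketch of the kind of JSONL record this branch consumes and how a plain-string (or boolean) `summary` field gets normalized into a list of reference summaries. The field names come from the code above; the record itself is illustrative:

```python
import json

# Hypothetical JSONL line for a summarization dataset (values are illustrative).
line = '{"document": "The quick brown fox jumps over the lazy dog.", "summary": "A fox jumps over a dog."}'
item = json.loads(line)

expected_results = item.get("summary", None)
# Wrap a bare string (or bool) in a list so downstream metrics can always
# iterate over one or more reference summaries.
if isinstance(expected_results, str) or isinstance(expected_results, bool):
    expected_results = [str(expected_results)]

print(expected_results)  # ['A fox jumps over a dog.']
```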

6 changes: 5 additions & 1 deletion nlptest/nlptest.py
@@ -290,8 +290,12 @@ def generated_results(self) -> Optional[pd.DataFrame]:
            return
        generated_results_df = pd.DataFrame.from_dict(
            [x.to_dict() for x in self._generated_results])
        if "test_case" in generated_results_df.columns and "original_question" in generated_results_df.columns:
            generated_results_df['original_question'].update(generated_results_df.pop('test_case'))

        return generated_results_df.fillna('-')
        generated_results_df=generated_results_df[generated_results_df.columns.drop("pass").to_list() + ["pass"]]

        return generated_results_df.fillna("-")

    def augment(self, input_path: str, output_path: str, inplace: bool = False) -> "Harness":
        """
50 changes: 36 additions & 14 deletions nlptest/transform/__init__.py
@@ -562,7 +562,7 @@ def transform(self) -> List[Sample]:
                y_true = pd.Series(data_handler_copy).apply(lambda x: [y.entity for y in x.expected_results.predictions])
            elif isinstance(data_handler_copy[0], SequenceClassificationSample):
                y_true = pd.Series(data_handler_copy).apply(lambda x: [y.label for y in x.expected_results.predictions])
            elif isinstance(data_handler_copy[0], QASample):
            elif data_handler_copy[0].task in ["question-answering", "summarization"]:
                y_true = pd.Series(data_handler_copy).apply(lambda x: x.expected_results)

            y_true = y_true.explode().apply(lambda x: x.split("-")
@@ -628,18 +628,29 @@ def run(cls, sample_list: Dict[str, List[Sample]], model: ModelFactory, raw_data
            y_true = y_true.explode()
            y_pred = y_pred.explode()

        elif isinstance(data[0], QASample):
        elif data[0].task == "question-answering":
            dataset_name = data[0].dataset_name.split('-')[0].lower()
            user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
            prompt_template = """Context: {context}\nQuestion: {question}\n """ + user_prompt

            if data[0].expected_results is None:
                logging.warning('The dataset %s does not contain labels and fairness tests cannot be run with it. Skipping the fairness tests.', dataset_name)
                return []
                raise RuntimeError(f'The dataset {dataset_name} does not contain labels and fairness tests cannot be run with it. Skipping the fairness tests.')
            y_true = pd.Series(data).apply(lambda x: x.expected_results)
            X_test = pd.Series(data)
            y_pred = X_test.apply(lambda sample: model(text={'context':sample.original_context, 'question': sample.original_question}, prompt={"template":prompt_template, 'input_variables':["context", "question"]}))
            y_pred = y_pred.apply(lambda x: x.strip())

        elif data[0].task == "summarization":
            dataset_name = data[0].dataset_name.split('-')[0].lower()
            user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
            prompt_template = user_prompt + """Context: {context}\n\n Summary: """
            if data[0].expected_results is None:
                raise RuntimeError(f'The dataset {dataset_name} does not contain labels and fairness tests cannot be run with it. Skipping the fairness tests.')

            y_true = pd.Series(data).apply(lambda x: x.expected_results)
            X_test = pd.Series(data)
            y_pred = X_test.apply(lambda sample: model(text={'context':sample.original}, prompt={"template":prompt_template, 'input_variables':["context"]}))
            y_pred = y_pred.apply(lambda x: x.strip())

        if kwargs['is_default']:
            y_pred = y_pred.apply(lambda x: '1' if x in ['pos', 'LABEL_1', 'POS'] else (
@@ -717,15 +728,15 @@ def transform(self) -> List[Sample]:
        for test_name, params in self.tests.items():
            data_handler_copy = [x.copy() for x in self._data_handler]

            if isinstance(data_handler_copy[0], NERSample):
            if data_handler_copy[0].task=="ner":
                y_true = pd.Series(data_handler_copy).apply(lambda x: [y.entity for y in x.expected_results.predictions])
            elif isinstance(data_handler_copy[0], SequenceClassificationSample):
                y_true = pd.Series(data_handler_copy).apply(lambda x: [y.label for y in x.expected_results.predictions])
            elif isinstance(data_handler_copy[0], QASample):
                y_true = pd.Series(data_handler_copy).apply(lambda x: x.expected_results)
                y_true = y_true.explode().apply(lambda x: x.split("-")
                                                [-1] if isinstance(x, str) else x)
            elif data_handler_copy[0].task=="text-classification":
                y_true = pd.Series(data_handler_copy).apply(lambda x: [y.label for y in x.expected_results.predictions]).explode()
            elif data_handler_copy[0].task=="question-answering" or data_handler_copy[0].task=="summarization":
                y_true = pd.Series(data_handler_copy).apply(lambda x: x.expected_results).explode()

            y_true = y_true.explode().apply(lambda x: x.split("-")
                                            [-1] if isinstance(x, str) else x)
            y_true = y_true.dropna()
            params["test_name"] = test_name
            transformed_samples = self.supported_tests[test_name].transform(
@@ -781,19 +792,30 @@ def run(cls, sample_list: Dict[str, List[Sample]], model: ModelFactory, raw_data
            y_true = y_true.explode()
            y_pred = y_pred.explode()

        elif isinstance(raw_data[0], QASample):
        elif raw_data[0].task=="question-answering":
            dataset_name = raw_data[0].dataset_name.split('-')[0].lower()
            user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
            prompt_template = """Context: {context}\nQuestion: {question}\n """ + user_prompt

            if raw_data[0].expected_results is None:
                logging.warning('The dataset %s does not contain labels and accuracy tests cannot be run with it. Skipping the accuracy tests.', dataset_name)
                return []
                raise RuntimeError(f'The dataset {dataset_name} does not contain labels and accuracy tests cannot be run with it.')
            y_true = pd.Series(raw_data).apply(lambda x: x.expected_results)
            X_test = pd.Series(raw_data)
            y_pred = X_test.apply(lambda sample: model(text={'context':sample.original_context, 'question': sample.original_question}, prompt={"template":prompt_template, 'input_variables':["context", "question"]}))
            y_pred = y_pred.apply(lambda x: x.strip())

        elif raw_data[0].task=="summarization":
            dataset_name = raw_data[0].dataset_name.split('-')[0].lower()
            user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
            prompt_template = user_prompt + """Context: {context}\n\n Summary: """
            if raw_data[0].expected_results is None:
                raise RuntimeError(f'The dataset {dataset_name} does not contain labels and accuracy tests cannot be run with it.')

            y_true = pd.Series(raw_data).apply(lambda x: x.expected_results)
            X_test = pd.Series(raw_data)
            y_pred = X_test.apply(lambda sample: model(text={'context':sample.original}, prompt={"template":prompt_template, 'input_variables':["context"]}))
            y_pred = y_pred.apply(lambda x: x.strip())

        if kwargs['is_default']:
            y_pred = y_pred.apply(lambda x: '1' if x in ['pos', 'LABEL_1', 'POS'] else (
                '0' if x in ['neg', 'LABEL_0', 'NEG'] else x))
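To make the new summarization branches concrete, here is a hedged sketch of how the prompt is assembled and how predictions are lined up against references. `FakeSample` and the `model` callable below are stand-ins for `SummarizationSample` and the harness's `ModelFactory` wrapper, not the library's actual classes:

```python
import pandas as pd

# Minimal stand-ins for a SummarizationSample and the model wrapper (both hypothetical).
class FakeSample:
    def __init__(self, original, expected_results, dataset_name="xsum-test", task="summarization"):
        self.original = original
        self.expected_results = expected_results
        self.dataset_name = dataset_name
        self.task = task

def model(text, prompt):
    # A real ModelFactory call would render prompt["template"] with `text`
    # and query the LLM; here we just echo a canned summary.
    return "  A fox jumps over a dog.  "

data = [FakeSample("The quick brown fox jumps over the lazy dog.", ["A fox jumps over a dog."])]

user_prompt = "You are a helpful assistant. "  # illustrative default_user_prompt entry
prompt_template = user_prompt + """Context: {context}\n\n Summary: """

y_true = pd.Series(data).apply(lambda x: x.expected_results)
X_test = pd.Series(data)
y_pred = X_test.apply(lambda sample: model(
    text={'context': sample.original},
    prompt={"template": prompt_template, 'input_variables': ["context"]}))
y_pred = y_pred.apply(lambda x: x.strip())

print(y_true.explode().tolist(), y_pred.tolist())
```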
6 changes: 3 additions & 3 deletions nlptest/transform/accuracy.py
@@ -500,7 +500,7 @@ class MinEMcore(BaseAccuracy):
"""

alias_name = "min_exact_match_score"
supported_tasks = ["question-answering"]
supported_tasks = ["question-answering", "summarization"]

@staticmethod
def transform(y_true, params):
@@ -563,7 +563,7 @@ class MinBLEUcore(BaseAccuracy):
"""

alias_name = "min_bleu_score"
supported_tasks = ["question-answering"]
supported_tasks = ["question-answering", "summarization"]

@staticmethod
def transform(y_true, params):
@@ -626,7 +626,7 @@ class MinROUGEcore(BaseAccuracy):
"""

alias_name = ["min_rouge1_score","min_rouge2_score","min_rougeL_score","min_rougeLsum_score"]
supported_tasks = ["question-answering"]
supported_tasks = ["question-answering", "summarization"]

@staticmethod
def transform(y_true, params):
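Since `min_exact_match_score`, `min_bleu_score`, and the `min_rouge*_score` tests now also cover summarization, here is a rough sketch of what such a threshold check boils down to, using the `evaluate` package (requirements already include `evaluate` and `rouge_score`). Exact return types can vary between `evaluate` versions, and the 0.50 threshold is purely illustrative:

```python
import evaluate

rouge = evaluate.load("rouge")  # backed by the rouge_score package

predictions = ["a fox jumps over a dog"]
references = [["the quick brown fox jumps over the lazy dog"]]

scores = rouge.compute(predictions=predictions, references=references)
# Recent evaluate versions return plain floats for rouge1/rouge2/rougeL/rougeLsum.
min_score = 0.50  # hypothetical value for a min_rouge1_score test
print(scores["rouge1"], scores["rouge1"] >= min_score)
```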
6 changes: 3 additions & 3 deletions nlptest/transform/fairness.py
@@ -23,7 +23,7 @@ class BaseFairness(ABC):
    output based on the implemented accuracy measure.
    """
    alias_name = None
    supported_tasks = ["ner", "text-classification", "question-answering"]
    supported_tasks = ["ner", "text-classification", "question-answering", "summarization"]

    @staticmethod
    @abstractmethod
@@ -100,7 +100,7 @@ def transform(data: List[Sample], params):
        samples = []
        for key, val in min_scores.items():
            sample = MinScoreSample(
                original="-",
                original=None,
                category="fairness",
                test_type="min_gender_f1_score",
                test_case=key,
@@ -183,7 +183,7 @@ def transform(data: List[Sample], params):
        samples = []
        for key, val in max_scores.items():
            sample = MaxScoreSample(
                original="-",
                original=None,
                category="fairness",
                test_type="max_gender_f1_score",
                test_case=key,
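For orientation, a hedged sketch of what the min/max gender F1 tests boil down to once predictions are grouped by inferred gender: compute an F1 score per group and compare it against a threshold. The group names, labels, and the scikit-learn dependency are all assumptions here, not the harness code:

```python
from sklearn.metrics import f1_score

# Hypothetical per-group gold/predicted labels after gender inference.
groups = {
    "male":    (["PER", "ORG", "O"], ["PER", "O", "O"]),
    "female":  (["PER", "LOC", "O"], ["PER", "LOC", "O"]),
    "unknown": (["ORG", "O"],        ["ORG", "O"]),
}

scores = {g: f1_score(y_true, y_pred, average="macro", zero_division=0)
          for g, (y_true, y_pred) in groups.items()}

min_score = 0.60  # illustrative threshold for a min_gender_f1_score test
for group, score in scores.items():
    print(group, round(score, 3), "pass" if score >= min_score else "fail")
```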
4 changes: 2 additions & 2 deletions nlptest/utils/custom_types/sample.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from copy import deepcopy
from pydantic import BaseModel, PrivateAttr, validator
from .helpers import Transformation, Span
@@ -417,7 +417,7 @@ def is_pass(self) -> bool:
class SummarizationSample(BaseModel):
    original: str = None
    test_case: str = None
    expected_results: str = None
    expected_results: Union[str, List] = None
    actual_results: str = None
    state: str = None
    dataset_name: str = None
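A quick sketch of why `expected_results` is widened to `Union[str, List]`: a pydantic field with that annotation accepts either a single reference summary or a list of them. The model below is a minimal stand-in, not the real `SummarizationSample` (which has more fields and validators):

```python
from typing import List, Optional, Union
from pydantic import BaseModel

class SummarizationSampleSketch(BaseModel):
    original: Optional[str] = None
    expected_results: Union[str, List, None] = None

single = SummarizationSampleSketch(original="doc text", expected_results="one summary")
multi = SummarizationSampleSketch(original="doc text", expected_results=["summary a", "summary b"])
print(type(single.expected_results), type(multi.expected_results))  # str, list
```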
3 changes: 2 additions & 1 deletion requirements.txt
@@ -15,4 +15,5 @@ openai
langchain
evaluate
inflect
rouge_score
rouge_score
typing-extensions < 4.6.0