
Commit

Merge pull request #446 from JohnSnowLabs/feature/accuracy-fairness-for-summarization

Feature/accuracy fairness for summarization
ArshaanNazir committed May 23, 2023
2 parents 4c0cdf2 + e17c1bf commit a1d56cc
Showing 6 changed files with 52 additions and 24 deletions.
4 changes: 3 additions & 1 deletion nlptest/datahandler/datasource.py
@@ -483,12 +483,14 @@ def load_data(self):
)

elif (self.task=='summarization'):
expected_results = item.get("summary",None)
if isinstance(expected_results, str) or isinstance(expected_results, bool): expected_results = [str(expected_results)]
data.append(
SummarizationSample(
original = item['document'],
expected_results=expected_results,
task=self.task,
dataset_name=self._file_path.split('/')[-2]

)
)

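The loader change above boils down to one normalization step: a scalar "summary" field is wrapped into a single-element list of strings before the SummarizationSample is built. A minimal standalone sketch of that step (the `item` record here is hypothetical, not taken from a real dataset):

import_free_sketch = True  # no imports needed for this snippet

item = {"document": "The quick brown fox jumped over the lazy dog.",
        "summary": "A fox jumped over a dog."}

expected_results = item.get("summary", None)
# Wrap scalar references (str or bool) into a list of strings so that
# downstream accuracy/fairness metrics can always iterate over a list of references.
if isinstance(expected_results, (str, bool)):
    expected_results = [str(expected_results)]

print(expected_results)  # ['A fox jumped over a dog.']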
6 changes: 5 additions & 1 deletion nlptest/nlptest.py
@@ -290,8 +290,12 @@ def generated_results(self) -> Optional[pd.DataFrame]:
return
generated_results_df = pd.DataFrame.from_dict(
[x.to_dict() for x in self._generated_results])
if "test_case" in generated_results_df.columns and "original_question" in generated_results_df.columns:
generated_results_df['original_question'].update(generated_results_df.pop('test_case'))

return generated_results_df.fillna('-')
generated_results_df=generated_results_df[generated_results_df.columns.drop("pass").to_list() + ["pass"]]

return generated_results_df.fillna("-")

def augment(self, input_path: str, output_path: str, inplace: bool = False) -> "Harness":
"""
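The effect of the generated_results change is easier to see on a toy DataFrame: any perturbed text in test_case is folded into original_question, and the pass column is moved to the end of the report. A sketch on illustrative data (column names beyond those in the diff are made up for the example):

import pandas as pd

df = pd.DataFrame({
    "original_question": ["What is NLP?", None],
    "test_case": [None, "what is nlp"],
    "pass": [True, False],
    "actual_result": ["...", "..."],
})

# Fold the perturbed question into original_question where one exists.
if "test_case" in df.columns and "original_question" in df.columns:
    df["original_question"].update(df.pop("test_case"))

# Move the pass column to the end so it reads as the final verdict per row.
df = df[df.columns.drop("pass").to_list() + ["pass"]]

print(df.fillna("-"))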
50 changes: 36 additions & 14 deletions nlptest/transform/__init__.py
@@ -562,7 +562,7 @@ def transform(self) -> List[Sample]:
y_true = pd.Series(data_handler_copy).apply(lambda x: [y.entity for y in x.expected_results.predictions])
elif isinstance(data_handler_copy[0], SequenceClassificationSample):
y_true = pd.Series(data_handler_copy).apply(lambda x: [y.label for y in x.expected_results.predictions])
elif isinstance(data_handler_copy[0], QASample):
elif data_handler_copy[0].task in ["question-answering", "summarization"]:
y_true = pd.Series(data_handler_copy).apply(lambda x: x.expected_results)

y_true = y_true.explode().apply(lambda x: x.split("-")
@@ -628,18 +628,29 @@ def run(cls, sample_list: Dict[str, List[Sample]], model: ModelFactory, raw_data
y_true = y_true.explode()
y_pred = y_pred.explode()

elif isinstance(data[0], QASample):
elif data[0].task == "question-answering":
dataset_name = data[0].dataset_name.split('-')[0].lower()
user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
prompt_template = """Context: {context}\nQuestion: {question}\n """ + user_prompt

if data[0].expected_results is None:
logging.warning('The dataset %s does not contain labels and fairness tests cannot be run with it. Skipping the fairness tests.', dataset_name)
return []
raise RuntimeError(f'The dataset {dataset_name} does not contain labels and fairness tests cannot be run with it. Skipping the fairness tests.')
y_true = pd.Series(data).apply(lambda x: x.expected_results)
X_test = pd.Series(data)
y_pred = X_test.apply(lambda sample: model(text={'context':sample.original_context, 'question': sample.original_question}, prompt={"template":prompt_template, 'input_variables':["context", "question"]}))
y_pred = y_pred.apply(lambda x: x.strip())

elif data[0].task == "summarization":
dataset_name = data[0].dataset_name.split('-')[0].lower()
user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
prompt_template = user_prompt + """Context: {context}\n\n Summary: """
if data[0].expected_results is None:
raise RuntimeError(f'The dataset {dataset_name} does not contain labels and fairness tests cannot be run with it. Skipping the fairness tests.')

y_true = pd.Series(data).apply(lambda x: x.expected_results)
X_test = pd.Series(data)
y_pred = X_test.apply(lambda sample: model(text={'context':sample.original}, prompt={"template":prompt_template, 'input_variables':["context"]}))
y_pred = y_pred.apply(lambda x: x.strip())

if kwargs['is_default']:
y_pred = y_pred.apply(lambda x: '1' if x in ['pos', 'LABEL_1', 'POS'] else (
@@ -717,15 +728,15 @@ def transform(self) -> List[Sample]:
for test_name, params in self.tests.items():
data_handler_copy = [x.copy() for x in self._data_handler]

if isinstance(data_handler_copy[0], NERSample):
if data_handler_copy[0].task=="ner":
y_true = pd.Series(data_handler_copy).apply(lambda x: [y.entity for y in x.expected_results.predictions])
elif isinstance(data_handler_copy[0], SequenceClassificationSample):
y_true = pd.Series(data_handler_copy).apply(lambda x: [y.label for y in x.expected_results.predictions])
elif isinstance(data_handler_copy[0], QASample):
y_true = pd.Series(data_handler_copy).apply(lambda x: x.expected_results)
y_true = y_true.explode().apply(lambda x: x.split("-")
[-1] if isinstance(x, str) else x)
elif data_handler_copy[0].task=="text-classification":
y_true = pd.Series(data_handler_copy).apply(lambda x: [y.label for y in x.expected_results.predictions]).explode()
elif data_handler_copy[0].task=="question-answering" or data_handler_copy[0].task=="summarization":
y_true = pd.Series(data_handler_copy).apply(lambda x: x.expected_results).explode()

y_true = y_true.explode().apply(lambda x: x.split("-")
[-1] if isinstance(x, str) else x)
y_true = y_true.dropna()
params["test_name"] = test_name
transformed_samples = self.supported_tests[test_name].transform(
@@ -781,19 +792,30 @@ def run(cls, sample_list: Dict[str, List[Sample]], model: ModelFactory, raw_data
y_true = y_true.explode()
y_pred = y_pred.explode()

elif isinstance(raw_data[0], QASample):
elif raw_data[0].task=="question-answering":
dataset_name = raw_data[0].dataset_name.split('-')[0].lower()
user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
prompt_template = """Context: {context}\nQuestion: {question}\n """ + user_prompt

if raw_data[0].expected_results is None:
logging.warning('The dataset %s does not contain labels and accuracy tests cannot be run with it. Skipping the accuracy tests.', dataset_name)
return []
raise RuntimeError(f'The dataset {dataset_name} does not contain labels and fairness tests cannot be run with it. Skipping the fairness tests.')
y_true = pd.Series(raw_data).apply(lambda x: x.expected_results)
X_test = pd.Series(raw_data)
y_pred = X_test.apply(lambda sample: model(text={'context':sample.original_context, 'question': sample.original_question}, prompt={"template":prompt_template, 'input_variables':["context", "question"]}))
y_pred = y_pred.apply(lambda x: x.strip())

elif raw_data[0].task=="summarization":
dataset_name = raw_data[0].dataset_name.split('-')[0].lower()
user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
prompt_template = user_prompt + """Context: {context}\n\n Summary: """
if raw_data[0].expected_results is None:
raise RuntimeError(f'The dataset {dataset_name} does not contain labels and fairness tests cannot be run with it. Skipping the fairness tests.')

y_true = pd.Series(raw_data).apply(lambda x: x.expected_results)
X_test = pd.Series(raw_data)
y_pred = X_test.apply(lambda sample: model(text={'context':sample.original}, prompt={"template":prompt_template, 'input_variables':["context"]}))
y_pred = y_pred.apply(lambda x: x.strip())

if kwargs['is_default']:
y_pred = y_pred.apply(lambda x: '1' if x in ['pos', 'LABEL_1', 'POS'] else (
'0' if x in ['neg', 'LABEL_0', 'NEG'] else x))
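The new summarization branches in both run methods follow the same recipe: build a prompt that asks only for a summary of the context, call the model once per sample, and strip whitespace from the generations. A condensed, self-contained sketch of that recipe; the model function and Sample class below are simplified stand-ins for ModelFactory and SummarizationSample, not the library's real implementations:

import pandas as pd

def model(text, prompt):
    # Stand-in for the ModelFactory call: a real model fills the template
    # with text["context"] and returns a generated summary.
    return "  a short summary  "

class Sample:
    def __init__(self, original, expected_results, dataset_name="xsum-test"):
        self.original = original
        self.expected_results = expected_results
        self.dataset_name = dataset_name
        self.task = "summarization"

data = [Sample("A long source document ...", ["a short summary"])]

dataset_name = data[0].dataset_name.split('-')[0].lower()
user_prompt = ""  # normally taken from kwargs / default_user_prompt
prompt_template = user_prompt + """Context: {context}\n\n Summary: """

if data[0].expected_results is None:
    raise RuntimeError(f"The dataset {dataset_name} does not contain labels.")

y_true = pd.Series(data).apply(lambda x: x.expected_results)
X_test = pd.Series(data)
y_pred = X_test.apply(lambda s: model(
    text={"context": s.original},
    prompt={"template": prompt_template, "input_variables": ["context"]},
))
y_pred = y_pred.apply(lambda x: x.strip())

print(list(y_pred))  # ['a short summary']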
6 changes: 3 additions & 3 deletions nlptest/transform/accuracy.py
@@ -500,7 +500,7 @@ class MinEMcore(BaseAccuracy):
"""

alias_name = "min_exact_match_score"
supported_tasks = ["question-answering"]
supported_tasks = ["question-answering", "summarization"]

@staticmethod
def transform(y_true, params):
@@ -563,7 +563,7 @@ class MinBLEUcore(BaseAccuracy):
"""

alias_name = "min_bleu_score"
supported_tasks = ["question-answering"]
supported_tasks = ["question-answering", "summarization"]

@staticmethod
def transform(y_true, params):
@@ -626,7 +626,7 @@ class MinROUGEcore(BaseAccuracy):
"""

alias_name = ["min_rouge1_score","min_rouge2_score","min_rougeL_score","min_rougeLsum_score"]
supported_tasks = ["question-answering"]
supported_tasks = ["question-answering", "summarization"]

@staticmethod
def transform(y_true, params):
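With summarization added to supported_tasks, the minimum-score accuracy tests that already existed for question answering (exact match, BLEU, ROUGE) can now be requested for summarization runs as well. A hypothetical configuration sketch, written as a Python dict and dumped to YAML; the tests/defaults/accuracy nesting is assumed from the harness's usual config layout, and the thresholds are illustrative only:

import yaml

# Hypothetical test configuration; key layout and threshold values are
# assumptions for illustration, only the test alias names come from the diff.
config = {
    "tests": {
        "defaults": {"min_pass_rate": 1.0},
        "accuracy": {
            "min_exact_match_score": {"min_score": 0.6},
            "min_bleu_score": {"min_score": 0.5},
            "min_rouge1_score": {"min_score": 0.5},
        },
    },
}

with open("summarization_accuracy_config.yml", "w") as fh:
    yaml.safe_dump(config, fh)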
6 changes: 3 additions & 3 deletions nlptest/transform/fairness.py
@@ -23,7 +23,7 @@ class BaseFairness(ABC):
output based on the implemented accuracy measure.
"""
alias_name = None
supported_tasks = ["ner", "text-classification", "question-answering"]
supported_tasks = ["ner", "text-classification", "question-answering", "summarization"]

@staticmethod
@abstractmethod
@@ -100,7 +100,7 @@ def transform(data: List[Sample], params):
samples = []
for key, val in min_scores.items():
sample = MinScoreSample(
original="-",
original=None,
category="fairness",
test_type="min_gender_f1_score",
test_case=key,
Expand Down Expand Up @@ -183,7 +183,7 @@ def transform(data: List[Sample], params):
samples = []
for key, val in max_scores.items():
sample = MaxScoreSample(
original="-",
original=None,
category="fairness",
test_type="max_gender_f1_score",
test_case=key,
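Switching original from "-" to None leaves any dash rendering to the report layer: the fillna("-") in generated_results (see the nlptest.py hunk above) can then fill the blank instead of the dash being baked into the fairness sample itself. A tiny illustration of that rendering step:

import pandas as pd

# Fairness rows now carry original=None; a report-level fillna("-")
# renders the dash when the results are tabulated.
rows = [{"original": None, "test_case": "male", "pass": True}]
print(pd.DataFrame(rows).fillna("-"))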
4 changes: 2 additions & 2 deletions nlptest/utils/custom_types/sample.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from copy import deepcopy
from pydantic import BaseModel, PrivateAttr, validator
from .helpers import Transformation, Span
@@ -417,7 +417,7 @@ def is_pass(self) -> bool:
class SummarizationSample(BaseModel):
original: str = None
test_case: str = None
expected_results: str = None
expected_results: Union[str, List] = None
actual_results: str = None
state: str = None
dataset_name: str = None
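The widened expected_results annotation means a SummarizationSample now accepts either a single reference string or a list of references. A minimal pydantic sketch of just that field; the real class carries more fields and validators than shown here:

from typing import List, Optional, Union
from pydantic import BaseModel

class SummarizationSample(BaseModel):
    original: Optional[str] = None
    # Accept a single reference summary or a list of references.
    expected_results: Union[str, List, None] = None

print(SummarizationSample(original="doc", expected_results="a summary").expected_results)
print(SummarizationSample(original="doc", expected_results=["a summary"]).expected_results)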
