fix/task test support check #378

Merged
merged 12 commits on May 5, 2023
2 changes: 2 additions & 0 deletions nlptest/augmentation/__init__.py
@@ -147,11 +147,13 @@ def fix(
                     test_type['robustness']['swap_entities']['parameters']['labels'] = [
                         self.label[each]]
                     hash_map[each] = TestFactory.transform(
+                        self.task,
                         [hash_map[each]], test_type)[0]

         else:
             sample_data = random.choices(data, k=int(sample_length))
             aug_data = TestFactory.transform(
+                self.task,
                 sample_data, test_type)
             fianl_aug_data.extend(aug_data)
             if inplace:
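Both call sites in fix() now pass self.task through as the first argument to TestFactory.transform, matching the signature change in nlptest/transform/__init__.py below. A minimal sketch of the new call shape (the test config and empty sample list are hypothetical stand-ins, not part of this diff):

    from nlptest.transform import TestFactory

    tests = {'robustness': {'lowercase': {'min_pass_rate': 0.5}}}
    samples = []  # stand-in for the List[Sample] that AugmentRobustness draws from its dataset

    # The task name now comes first, so per-test task support can be validated up front.
    augmented = TestFactory.transform("ner", samples, tests)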
22 changes: 10 additions & 12 deletions nlptest/nlptest.py
@@ -11,7 +11,6 @@
 from .augmentation import AugmentRobustness
 from .datahandler.datasource import DataFactory
 from .modelhandler import ModelFactory
-from .testrunner import BaseRunner
 from .transform import TestFactory


@@ -66,8 +65,8 @@ def __init__(
         self.task = task

         if isinstance(model, str) and hub is None:
-            raise ValueError(f"When passing a string argument to the 'model' parameter, you must provide an argument "
-                             f"for the 'hub' parameter as well.")
+            raise ValueError("When passing a string argument to the 'model' parameter, you must provide an argument "
+                             "for the 'hub' parameter as well.")

         if hub is not None and hub not in self.SUPPORTED_HUBS:
             raise ValueError(
@@ -81,13 +80,12 @@ def __init__(
             if model == "textcat_imdb":
                 model = resource_filename("nlptest", "data/textcat_imdb")
                 self.is_default = True
-                logging.info(
-                    f"Default dataset '{(task, model, hub)}' successfully loaded.")
+                logging.info("Default dataset '%s' successfully loaded.", (task, model, hub))

         elif data is None and (task, model, hub) not in self.DEFAULTS_DATASET.keys():
-            raise ValueError(f"You haven't specified any value for the parameter 'data' and the configuration you "
-                             f"passed is not among the default ones. You need to either specify the parameter 'data' "
-                             f"or use a default configuration.")
+            raise ValueError("You haven't specified any value for the parameter 'data' and the configuration you "
+                             "passed is not among the default ones. You need to either specify the parameter 'data' "
+                             "or use a default configuration.")
         elif isinstance(data, list):
             self.data = data
         else:
@@ -136,7 +134,7 @@ def configure(self, config: Union[str, dict]) -> dict:
         if type(config) == dict:
             self._config = config
         else:
-            with open(config, 'r') as yml:
+            with open(config, 'r', encoding="utf-8") as yml:
                 self._config = yaml.safe_load(yml)
         self._config_copy = self._config
         return self._config
@@ -159,7 +157,7 @@ def generate(self) -> "Harness":
             _ = [setattr(sample, 'expected_results', self.model(sample.original))
                  for sample in m_data]
         self._testcases = TestFactory.transform(
-            self.data, tests, m_data=m_data)
+            self.task, self.data, tests, m_data=m_data)
         return self

     def run(self) -> "Harness":
@@ -280,7 +278,7 @@ def augment(self, input_path: str, output_path: str, inplace: bool = False) -> "Harness":
         """

         dtypes = list(map(
-            lambda x: str(x),
+            str,
             self.df_report[['pass_rate', 'minimum_pass_rate']].dtypes.values.tolist()))
         if dtypes not in [['int64'] * 2, ['int32'] * 2]:
             self.df_report['pass_rate'] = self.df_report['pass_rate'].str.replace(
@@ -332,7 +330,7 @@ def save(self, save_dir: str) -> None:
         if not os.path.isdir(save_dir):
             os.mkdir(save_dir)

-        with open(os.path.join(save_dir, "config.yaml"), 'w') as yml:
+        with open(os.path.join(save_dir, "config.yaml"), 'w', encoding="utf-8") as yml:
             yml.write(yaml.safe_dump(self._config_copy))

         with open(os.path.join(save_dir, "test_cases.pkl"), "wb") as writer:
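The logging call above also moves from an f-string to %-style arguments, the stdlib-recommended form: the tuple is only interpolated into the message if a handler actually emits the record. A self-contained sketch of the difference:

    import logging

    logging.basicConfig(level=logging.INFO)
    task, model, hub = "ner", "textcat_imdb", "spacy"

    # f-string: the message is built eagerly, even if the level is disabled.
    logging.debug(f"Default dataset '{(task, model, hub)}' successfully loaded.")

    # %-style args: formatting is deferred until the record is actually emitted.
    logging.info("Default dataset '%s' successfully loaded.", (task, model, hub))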
41 changes: 27 additions & 14 deletions nlptest/transform/__init__.py
@@ -24,7 +24,7 @@ class TestFactory:
     is_augment = False

     @staticmethod
-    def transform(data: List[Sample], test_types: dict, *args, **kwargs) -> List[Result]:
+    def transform(task: str, data: List[Sample], test_types: dict, *args, **kwargs) -> List[Result]:
         """
         Runs the specified tests on the given data and returns a list of results.

@@ -45,11 +45,23 @@ def transform(data: List[Sample], test_types: dict, *args, **kwargs) -> List[Result]:
         tests = tqdm(test_types.keys(), desc="Generating testcases...",
                      disable=TestFactory.is_augment)
         m_data = kwargs.get('m_data', None)

+        # Check that every requested test supports the current task
+        for test_category in tests:
+            if test_category in all_categories.keys():
+                sub_test_types = test_types[test_category]
+                for sub_test in sub_test_types:
+                    supported = all_categories[test_category].available_tests()[sub_test].supported_tasks
+                    if task not in supported:
+                        raise ValueError(f"The test type \"{sub_test}\" is not supported for the task \"{task}\". \"{sub_test}\" only supports {supported}.")
+            elif test_category != "defaults":
+                raise ValueError(f"The test category {test_category} does not exist. Available categories are: {all_categories.keys()}.")
+
+        # Generate testcases
         for each in tests:
             tests.set_description(f"Generating testcases... ({each})")
             if each in all_categories:
                 sub_test_types = test_types[each]

                 all_results.extend(
                     all_categories[each](m_data, sub_test_types,
                                          raw_data=data).transform()
@@ -141,7 +153,7 @@ class ITests(ABC):
     alias_name = None

     @abstractmethod
-    def transform(cls):
+    def transform(self):
         """
         Runs the test and returns the results.

@@ -151,8 +163,9 @@ def transform(cls):
         """
         return NotImplementedError

+    @staticmethod
     @abstractmethod
-    def available_tests(cls):
+    def available_tests():
         """
         Returns a list of available test scenarios for the test class.

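available_tests needs neither an instance nor the class, so it is now declared as a static abstract method; stacking @staticmethod above @abstractmethod (with @abstractmethod innermost) is the supported combination since Python 3.3. A stand-alone sketch of the pattern (class names are illustrative, not from nlptest):

    from abc import ABC, abstractmethod

    class ITestsSketch(ABC):
        @staticmethod
        @abstractmethod
        def available_tests() -> dict:
            """Map test names to the classes implementing them."""
            raise NotImplementedError

    class RobustnessSketch(ITestsSketch):
        @staticmethod
        def available_tests() -> dict:
            return {"lowercase": object}  # placeholder mapping

    print(RobustnessSketch.available_tests())  # {'lowercase': <class 'object'>}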
@@ -279,8 +292,8 @@ def transform(self) -> List[Sample]:
         all_samples.extend(transformed_samples)
         return all_samples

-    @classmethod
-    def available_tests(cls) -> dict:
+    @staticmethod
+    def available_tests() -> dict:
         """
         Get a dictionary of all available tests, with their names as keys and their corresponding classes as values.

@@ -414,8 +427,8 @@ def transform(self) -> List[Sample]:
         all_samples.extend(transformed_samples)
         return all_samples

-    @classmethod
-    def available_tests(cls) -> Dict:
+    @staticmethod
+    def available_tests() -> Dict:
         """
         Get a dictionary of all available tests, with their names as keys and their corresponding classes as values.

@@ -481,8 +494,8 @@ def transform(self) -> List[Sample]:

         return all_samples

-    @classmethod
-    def available_tests(cls) -> Dict:
+    @staticmethod
+    def available_tests() -> Dict:
         """
         Get a dictionary of all available tests, with their names as keys and their corresponding classes as values.

@@ -546,8 +559,8 @@ def transform(self) -> List[Sample]:
         all_samples.extend(transformed_samples)
         return all_samples

-    @classmethod
-    def available_tests(cls) -> dict:
+    @staticmethod
+    def available_tests() -> dict:
         """
         Get a dictionary of all available tests, with their names as keys and their corresponding classes as values.

@@ -623,8 +636,8 @@ def transform(self) -> List[Sample]:
         all_samples.extend(transformed_samples)
         return all_samples

-    @classmethod
-    def available_tests(cls) -> dict:
+    @staticmethod
+    def available_tests() -> dict:
         """
         Get a dictionary of all available tests, with their names as keys and their corresponding classes as values.

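With the support check in place, requesting a NER-only test from a text-classification harness now fails fast while test cases are being generated, instead of erroring deep inside a transform. The new unit test at the bottom of this diff exercises exactly this path; sketched from the user's side (same default spacy model as in that test):

    from nlptest import Harness

    h = Harness(
        task="text-classification",
        model="textcat_imdb",
        hub="spacy",
        config={'tests': {'robustness': {'swap_entities': {'min_pass_rate': 0.5}}}},
    )
    h.generate()
    # ValueError: The test type "swap_entities" is not supported for the task
    # "text-classification". "swap_entities" only supports ['ner'].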
1 change: 1 addition & 0 deletions nlptest/transform/accuracy.py
@@ -18,6 +18,7 @@ class BaseAccuracy(ABC):
     transform(data: List[Sample]) -> Any: Transforms the input data into an output based on the implemented accuracy measure.
     """
     alias_name = None
+    supported_tasks = ["ner", "text-classification"]

     @staticmethod
     @abstractmethod
1 change: 1 addition & 0 deletions nlptest/transform/bias.py
@@ -21,6 +21,7 @@ class BaseBias(ABC):
     transform(data: List[Sample]) -> Any: Transforms the input data into an output based on the implemented bias measure.
     """
     alias_name = None
+    supported_tasks = ["ner", "text-classification"]

     @abstractmethod
     def transform(self, sample_list: List[Sample], *args, **kwargs) -> List[Sample]:
1 change: 1 addition & 0 deletions nlptest/transform/fairness.py
@@ -22,6 +22,7 @@ class BaseFairness(ABC):
     output based on the implemented accuracy measure.
     """
     alias_name = None
+    supported_tasks = ["ner", "text-classification"]

     @staticmethod
     @abstractmethod
2 changes: 2 additions & 0 deletions nlptest/transform/representation.py
@@ -29,6 +29,7 @@ class BaseRepresentation(ABC):
     based on the implemented representation measure.
     """
     alias_name = None
+    supported_tasks = ["ner", "text-classification"]

     @staticmethod
     @abstractmethod
@@ -453,6 +454,7 @@ class ReligionRepresentation(BaseRepresentation):
         "min_religion_name_representation_count",
         "min_religion_name_representation_proportion"
     ]
+    supported_tasks = ["ner", "text-classification"]

     def transform(test, data, params):
         """
8 changes: 2 additions & 6 deletions nlptest/transform/robustness.py
@@ -22,6 +22,7 @@ class BaseRobustness(ABC):
     transform(data: List[Sample]) -> Any: Transforms the input data into an output based on the implemented robustness measure.
     """
     alias_name = None
+    supported_tasks = ["ner", "text-classification", "question-answering"]

     @staticmethod
     @abstractmethod
@@ -134,7 +135,6 @@ def transform(sample_list: List[Sample]) -> List[Sample]:
             sample.category = "robustness"
         return sample_list

-
 class TitleCase(BaseRobustness):
     alias_name = 'titlecase'

@@ -267,7 +267,6 @@ def check_whitelist(text, whitelist):
             sample.category = "robustness"
         return sample_list

-
 class AddTypo(BaseRobustness):
     alias_name = 'add_typo'

@@ -331,9 +330,9 @@ def keyboard_typo(string):

         return sample_list

-
 class SwapEntities(BaseRobustness):
     alias_name = 'swap_entities'
+    supported_tasks = ["ner"]

     @staticmethod
     def transform(
@@ -411,7 +410,6 @@ def transform(
         ]
         return sample_list

-
 class ConvertAccent(BaseRobustness):
     alias_name = ["american_to_british", "british_to_american"]

@@ -462,7 +460,6 @@ def convert_accent(string: str, accent_map: Dict[str, str]) -> str:

         return sample_list

-
 class AddContext(BaseRobustness):
     alias_name = 'add_context'

@@ -616,7 +613,6 @@ def transform(
             sample.category = "robustness"
         return sample_list

-
 class AddContraction(BaseRobustness):
     alias_name = 'add_contraction'

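supported_tasks is a plain class attribute, so the permissive default declared on BaseRobustness is inherited by every perturbation, and a test like SwapEntities, which needs gold NER labels, simply narrows it. A minimal sketch of the lookup behaviour (class names are illustrative, not from nlptest):

    from abc import ABC

    class BaseRobustnessSketch(ABC):
        # Default: most robustness perturbations are text-level and task-agnostic.
        supported_tasks = ["ner", "text-classification", "question-answering"]

    class LowercaseSketch(BaseRobustnessSketch):
        pass  # inherits the permissive default

    class SwapEntitiesSketch(BaseRobustnessSketch):
        supported_tasks = ["ner"]  # needs entity annotations, so it overrides

    print("question-answering" in LowercaseSketch.supported_tasks)     # True
    print("question-answering" in SwapEntitiesSketch.supported_tasks)  # False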
11 changes: 11 additions & 0 deletions tests/test_harness.py
@@ -72,6 +72,17 @@ def test_incompatible_tasks(self):
                 hub="huggingface"
             )

+    def test_unsupported_test_for_task(self):
+        """"""
+        with self.assertRaises(ValueError):
+            h = Harness(
+                task="text-classification",
+                model="textcat_imdb",
+                hub="spacy",
+                config={'tests':{'robustness':{'swap_entities':{'min_pass_rate':0.5}}}}
+            )
+            h.generate()
+
     def test_save(self):
         """"""
         save_dir = "/tmp/saved_harness_test"