Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sijunhe committed Feb 9, 2023
1 parent c0eed44 commit 41d483a
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
6 changes: 3 additions & 3 deletions paddlenlp/experimental/autonlp/auto_trainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,17 @@ def __init__(
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.greater_is_better = greater_is_better
if language not in self.supported_language:
if language not in self.supported_languages:
raise ValueError(
f"'{language}' is not supported. Please choose among the following: {self.supported_language}"
f"'{language}' is not supported. Please choose among the following: {self.supported_languages}"
)

self.language = language
self.output_dir = output_dir

@property
@abstractmethod
def supported_language(self) -> List[str]:
def supported_languages(self) -> List[str]:
"""
Override to store the supported languages for each auto trainer class
"""
Expand Down
2 changes: 1 addition & 1 deletion paddlenlp/experimental/autonlp/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(
)

@property
def supported_language(self) -> List[str]:
def supported_languages(self) -> List[str]:
return ["Chinese", "English"]

@property
Expand Down
31 changes: 31 additions & 0 deletions tests/experimental/autonlp/test_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,37 @@ def test_untrained_auto_trainer(self):
# test export
auto_trainer.export(temp_dir)

def test_unsupported_languages(self):
with TemporaryDirectory() as temp_dir:
train_ds = copy.deepcopy(self.multi_class_train_ds)
dev_ds = copy.deepcopy(self.multi_class_dev_ds)
with self.assertRaises(ValueError):
AutoTrainerForTextClassification(
train_dataset=train_ds,
eval_dataset=dev_ds,
label_column="label_desc",
text_column="sentence",
language="Spanish", # spanish is unsupported for now
output_dir=temp_dir,
)

def test_model_language_filter(self):
with TemporaryDirectory() as temp_dir:
train_ds = copy.deepcopy(self.multi_class_train_ds)
dev_ds = copy.deepcopy(self.multi_class_dev_ds)
auto_trainer = AutoTrainerForTextClassification(
train_dataset=train_ds,
eval_dataset=dev_ds,
label_column="label_desc",
text_column="sentence",
language="Chinese",
output_dir=temp_dir,
)
for language in auto_trainer.supported_languages:
model_candidates = auto_trainer._filter_model_candidates(language=language)
for candidate in model_candidates:
self.assertEqual(candidate["language"], language)


if __name__ == "__main__":
unittest.main()

0 comments on commit 41d483a

Please sign in to comment.