Skip to content

Commit

Permalink
Merge pull request #654 from JohnSnowLabs/feature/age-tests
Browse files Browse the repository at this point in the history
feature/add random age test
  • Loading branch information
JulesBelveze committed Jul 24, 2023
2 parents 8269ede + 6103996 commit b917bc7
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 0 deletions.
83 changes: 83 additions & 0 deletions langtest/transform/robustness.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,3 +1632,86 @@ def check_whitelist(text):
sample.category = "robustness"

return sample_list


class RandomAge(BaseRobustness):
    """A class for randomly shifting the ages mentioned in the input text."""

    alias_name = "randomize_age"

    @staticmethod
    def transform(
        sample_list: List[Sample],
        prob: Optional[float] = 1.0,
        random_amount: int = 5,
        count: int = 1,
    ) -> List[Sample]:
        """Transforms the given sample list by randomizing the ages by a certain amount.

        Age expressions of the form "<number> years|months|weeks|days old" are
        located and their number is shifted by a random integer drawn from
        [-random_amount, +random_amount]. The resulting age is clamped to a
        minimum of 1.

        Args:
            sample_list (List[Sample]): The list of samples to transform. Plain strings
                are accepted as well and are returned as perturbed strings.
            prob (Optional[float]): The probability of perturbing each age expression found.
                Defaults to 1.0, which means every age expression is perturbed
                (unless the random shift happens to leave it unchanged).
                None is treated as 1.0.
            random_amount (int): The range to randomize the ages. +-random_amount is added to ages.
                Default is 5.
            count (int): Number of variations to create from one sample.

        Returns:
            List[Sample]: The transformed list of samples with randomized ages,
                containing `count` entries per input sample.
        """
        if prob is None:
            prob = 1.0

        # One pattern for every supported unit, so matches come back in text
        # order and can be spliced in a single left-to-right pass.
        age_pattern = re.compile(r"\d+ (?:years|months|weeks|days) old")

        def randomize_ages(text):
            """Return (perturbed_text, transformations) for a single string."""
            pieces = []
            transformations = []
            cursor = 0  # end of the last replaced match, in original-text coordinates
            shift = 0  # cumulative length difference introduced by earlier replacements

            for match in age_pattern.finditer(text):
                start, end = match.span()
                token = match.group()
                number, _, unit = token.partition(" ")

                new_age = int(number) + random.randint(-random_amount, random_amount)
                new_age = max(new_age, 1)  # an age can never be zero or negative
                corrected_token = f"{new_age} {unit}"

                if corrected_token == token or random.random() >= prob:
                    continue

                pieces.append(text[cursor:start])
                pieces.append(corrected_token)
                cursor = end

                # original_span indexes the original text; new_span indexes the
                # perturbed text, hence the accumulated shift.
                new_start = start + shift
                transformations.append(
                    Transformation(
                        original_span=Span(start=start, end=end, word=token),
                        new_span=Span(
                            start=new_start,
                            end=new_start + len(corrected_token),
                            word=corrected_token,
                        ),
                        ignore=False,
                    )
                )
                shift += len(corrected_token) - len(token)

            pieces.append(text[cursor:])
            return "".join(pieces), transformations

        perturbed_samples = []
        for sample in sample_list:
            for _ in range(count):
                if isinstance(sample, str):
                    perturbed_text, _ = randomize_ages(sample)
                    perturbed_samples.append(perturbed_text)
                    continue

                perturbed = deepcopy(sample)
                perturbed.test_case, transformations = randomize_ages(
                    perturbed.original
                )
                if perturbed.task in ("ner", "text-classification"):
                    perturbed.transformations = transformations
                perturbed.category = "robustness"
                # Without this append, Sample inputs produced an empty result.
                perturbed_samples.append(perturbed)

        return perturbed_samples
13 changes: 13 additions & 0 deletions tests/test_robustness.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ def setUp(self) -> None:
),
SequenceClassificationSample(original="They have a beautiful house."),
]
self.age_sentences = [
SequenceClassificationSample(original="I am 75 years old."),
SequenceClassificationSample(original="The baby is 40 days old."),
]
self.test_qa = [
"20 euro note -- Until now there has been only one complete series of euro notes; however a new series, similar to the current one, is being released. The European Central Bank will, in due time, announce when banknotes from the first series lose legal tender status.",
"is the first series 20 euro note still legal tender",
Expand Down Expand Up @@ -421,3 +425,12 @@ def test_adj_antonym_swap(self) -> None:
self.assertIsInstance(transformed_samples, list)
for sample in transformed_samples:
self.assertNotEqual(sample.test_case, sample.original)

def test_random_age(self) -> None:
    """
    Check that RandomAge.transform yields a list whose samples were perturbed.
    """
    # A wide random_amount makes an unchanged age unlikely.
    outputs = RandomAge.transform(self.age_sentences, random_amount=100)
    self.assertIsInstance(outputs, list)
    # NOTE: if `outputs` is empty this loop asserts nothing.
    for transformed in outputs:
        self.assertNotEqual(transformed.test_case, transformed.original)

0 comments on commit b917bc7

Please sign in to comment.