Skip to content

Commit

Permalink
polished tests
Browse files Browse the repository at this point in the history
  • Loading branch information
KennethEnevoldsen committed Dec 26, 2023
1 parent 49784ea commit f4ffc6e
Show file tree
Hide file tree
Showing 12 changed files with 129 additions and 97 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
]
keywords = [
"nlp",
Expand Down Expand Up @@ -62,7 +64,7 @@ da = ["dacy>=2.3.1"]
all = ["nltk>=3.6.7"]
dev = [
"cruft>=2.0.0",
"pyright>=1.1.338",
"pyright>=1.1.343",
"ruff>=0.0.254",
"pyproject-parser[cli, readme]>=0.9.1",
]
Expand Down Expand Up @@ -157,7 +159,7 @@ exclude = [
]
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
target-version = "py39"
target-version = "py38"

[tool.ruff.flake8-annotations]
mypy-init-return = true
Expand Down
15 changes: 8 additions & 7 deletions src/augmenty/doc/subset.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
import random
from functools import partial
from typing import Callable, Iterator, Union
from typing import Iterator, Union

import numpy as np
import spacy
from spacy.language import Language
from spacy.training import Example

from augmenty.augment_utilities import make_text_from_orth
from augmenty.util import Augmenter


def paragraph_subset_augmenter_v1(
nlp: Language,
example: Example,
*,
min_paragraph: Union[float, int], # type: ignore
max_paragraph: Union[float, int], # type: ignore
min_paragraph: Union[float, int],
max_paragraph: Union[float, int],
respect_sentences: bool,
) -> Iterator[Example]: # type: ignore
) -> Iterator[Example]:
example_dict = example.to_dict()
token_anno = example_dict["token_annotation"]
doc_anno = example_dict["doc_annotation"]
Expand Down Expand Up @@ -69,10 +70,10 @@ def paragraph_subset_augmenter_v1(

@spacy.registry.augmenters("paragraph_subset_augmenter_v1") # type: ignore
def create_paragraph_subset_augmenter_v1(
min_paragraph: Union[float, int] = 1, # type: ignore
max_paragraph: Union[float, int] = 1.00, # type: ignore
min_paragraph: Union[float, int] = 1,
max_paragraph: Union[float, int] = 1.00,
respect_sentences: bool = True,
) -> Callable[[Language, Example], Iterator[Example]]: # type: ignore
) -> Augmenter:
"""Create an augmenter that extracts a subset of a document.
Args:
Expand Down
8 changes: 4 additions & 4 deletions src/augmenty/keyboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ class Keyboard(BaseModel):
Keyboard: a Keyboard object
"""

keyboard_array: Dict[str, List[List[str]]] # type: ignore
keyboard_array: Dict[str, List[List[str]]]
shift_distance: int = 3

def coordinate(self, key: str) -> Tuple[int, int]: # type: ignore
def coordinate(self, key: str) -> Tuple[int, int]:
"""Get coordinate for key.
Args:
Expand All @@ -47,7 +47,7 @@ def is_shifted(self, key: str) -> bool:
Returns:
bool: a boolean indicating whether key is shifted.
"""
for x in self.keyboard_array["shift"]:
for x in self.keyboard_array["shift"]: # noqa
if key in x:
return True
return False
Expand Down Expand Up @@ -83,7 +83,7 @@ def all_keys(self):
for x, _ in enumerate(self.keyboard_array[arr]):
yield from self.keyboard_array[arr][x]

def get_neighbours(self, key: str, distance: int = 1) -> List[int]: # type: ignore
def get_neighbours(self, key: str, distance: int = 1) -> List[int]:
"""Gets the neighbours of a key with a specified distance.
Args:
Expand Down
62 changes: 32 additions & 30 deletions src/augmenty/util.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,35 @@
"""Utility functions for the package."""

from typing import Callable, Dict, Iterable, Iterator, List
from typing import Any, Callable, Dict, Iterable, Iterator, List

import catalogue # type: ignore
import spacy # type: ignore
import thinc # type: ignore
from spacy.language import Language # type: ignore
from spacy.tokens import Doc # type: ignore
from spacy.training import Example # type: ignore
import catalogue
import spacy
import thinc
from spacy.language import Language
from spacy.tokens import Doc
from spacy.training import Example

Augmenter = Callable[[Language, Example], Iterator[Example]]


class registry(thinc.registry):
keyboards = catalogue.create("augmenty", "keyboards", entry_points=True)


def docs(
docs: Iterable[Doc], # type: ignore
augmenter: Callable[[Language, Example], Iterator[Example]], # type: ignore
docs: Iterable[Doc],
augmenter: Augmenter,
nlp: Language,
) -> Iterator[Doc]: # type: ignore
) -> Iterator[Doc]:
"""Augments an iterable of spaCy Doc.
Args:
docs (Iterable[Doc]): A iterable of spaCy Docs
augmenter (Callable[[Language, Example], Iterator[Example]]): An augmenter
nlp (Language): A spaCy language pipeline.
docs: A iterable of spaCy Docs
augmenter: An augmenter
nlp: A spaCy language pipeline.
Return:
Iterator[Doc]: An iterator of the augmented Docs.
An iterator of the augmented Docs.
Yields:
Doc: The augmented Docs.
Expand All @@ -50,22 +52,22 @@ def docs(


def texts(
texts: Iterable[str], # type: ignore
augmenter: Callable[[Language, Example], Iterator[Example]], # type: ignore
texts: Iterable[str],
augmenter: Augmenter,
nlp: Language,
) -> Iterable[str]: # type: ignore
) -> Iterable[str]:
"""Augments an list of texts.
Args:
texts (Iterable[str]): A iterable of strings
augmenter (Callable[[Language, Example], Iterator[Example]]): An augmenter
nlp (Language): A spaCy language pipeline.
texts: A iterable of strings
augmenter: An augmenter
nlp: A spaCy language pipeline.
Return:
Iterator[str]: An iterator of the augmented texts.
An iterator of the augmented texts.
Yields:
str: The augmented text.
The augmented text.
"""
if isinstance(texts, str):
texts = [texts]
Expand All @@ -78,11 +80,11 @@ def __gen() -> Iterable[Doc]: # type: ignore
yield doc.text


def augmenters() -> Dict[str, Callable]: # type: ignore
def augmenters() -> Dict[str, Augmenter]:
"""A utility function to get an overview of all augmenters.
Returns:
Dict[str, Callable]: Dictionary of all augmenters
Dictionary of all augmenters
Example:
>>> augmenters = augmenty.augmenters()
Expand All @@ -92,7 +94,7 @@ def augmenters() -> Dict[str, Callable]: # type: ignore
return spacy.registry.augmenters.get_all() # type: ignore


def load(augmenter: str, **kwargs) -> Callable: # type: ignore
def load(augmenter: str, **kwargs: Any) -> Augmenter:
"""A utility functionload an augmenter.
Returns:
Expand All @@ -114,19 +116,19 @@ def keyboards() -> List[str]: # type: ignore
"""A utility function to get an overview of all keyboards.
Returns:
List[str]]: List of all keyboards
List of all keyboards
Example:
>>> keyboards = augmenty.keyboards()
"""
return list(registry.keyboards.get_all().keys())


def meta() -> Dict[str, dict]: # type: ignore
def meta() -> Dict[str, dict]:
"""Returns a a dictionary containing metadata for each augmenter.
Returns:
Dict[str, dict]: A dictionary of meta data
A dictionary of meta data
Example:
>>> metadata = augmenty.meta()
Expand All @@ -137,7 +139,7 @@ def meta() -> Dict[str, dict]: # type: ignore
import pathlib

p = pathlib.Path(__file__).parent.resolve()
p = os.path.join(p, "meta.json") # type: ignore
with open(p) as f: # type: ignore
p = p / "meta.json"
with p.open() as f:
r = json.load(f)
return r
File renamed without changes.
21 changes: 10 additions & 11 deletions tests/test_all_augmenters.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
"""Pytest script for testing all augmenters in a variety of cases."""

from typing import Callable

from typing import Iterable # noqa

import augmenty
import numpy as np
import pytest
from spacy.language import Language
from spacy.tokens import Token

from .fixtures import (
books_w_annotations,
books_without_annotations,
dane_test,
nlp_da,
nlp_en,
nlp_en_md,
)
from spacy.training import Example

np.seterr(divide="raise", invalid="raise")

Expand Down Expand Up @@ -100,7 +93,13 @@ def is_pronoun(token: Token) -> bool:
),
],
)
def test_augmenters(aug: str, args: dict, examples, nlp: Language, level: float): # noqa # type: ignore
def test_augmenters(
aug: str,
args: dict,
examples: Iterable[Example],
nlp: Language,
level: float,
):
args["level"] = level
augmenter = augmenty.load(aug, **args)
augmented_examples = [e for ex in examples for e in augmenter(nlp=nlp, example=ex)]
Expand Down
11 changes: 5 additions & 6 deletions tests/test_augmentation_utilities.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import augmenty
from spacy.language import Language
from spacy.tokens import Doc, Span

from .fixtures import nlp_en, nlp_en_md


def test_combine(nlp_en_md): # noqa F811
def test_combine(nlp_en_md: Language):
words = ["Augmenty", "is", "a", "wonderful", "tool", "for", "augmentation", "."]
spaces = [True, True, True, True, True, True, False, False]
doc = Doc(nlp_en_md.vocab, words=words, spaces=spaces)
Expand All @@ -25,7 +24,7 @@ def test_combine(nlp_en_md): # noqa F811
assert augmented_docs[0][0].text == "spaCy" # type: ignore


def test_yield_original(nlp_en): # noqa F811
def test_yield_original(nlp_en: Language): # F811
texts = ["Augmenty is a wonderful tool for augmentation."]

aug = augmenty.load("upper_case_v1", level=1)
Expand All @@ -37,7 +36,7 @@ def test_yield_original(nlp_en): # noqa F811
assert len(augmented_docs) == 2


def test_repeat(nlp_en): # noqa F811
def test_repeat(nlp_en: Language): # F811
texts = ["Augmenty is a wonderful tool for augmentation."]

aug = augmenty.load("upper_case_v1", level=1)
Expand All @@ -49,7 +48,7 @@ def test_repeat(nlp_en): # noqa F811
assert len(augmented_docs) == 3


def test_set_doc_level(nlp_en): # noqa F811
def test_set_doc_level(nlp_en: Language): # F811
texts = ["Augmenty is a wonderful tool for augmentation."]

aug = augmenty.load("upper_case_v1", level=1)
Expand Down
15 changes: 7 additions & 8 deletions tests/test_character.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import augmenty
import spacy
from spacy.language import Language
from spacy.tokens import Doc

from .fixtures import nlp_da, nlp_en


def test_create_random_casing_augmenter(nlp_en): # noqa F811
def test_create_random_casing_augmenter(nlp_en: Language):
text = (
"some of the cases here should not be lowercased."
+ " there is naturally a chance that it might not end up that way,"
Expand All @@ -19,7 +18,7 @@ def test_create_random_casing_augmenter(nlp_en): # noqa F811
assert next(docs).text != text # type: ignore


def test_create_char_replace_random_augmenter(nlp_en): # noqa F811
def test_create_char_replace_random_augmenter(nlp_en: Language):
text = "The augmented version of this should not be the same"

aug = spacy.registry.augmenters.get("char_replace_random_v1")(level=1) # type: ignore
Expand All @@ -29,7 +28,7 @@ def test_create_char_replace_random_augmenter(nlp_en): # noqa F811
assert next(docs).text != text # type: ignore


def test_create_char_replace_augmenter(nlp_en): # noqa F811
def test_create_char_replace_augmenter(nlp_en: Language):
aug = spacy.registry.augmenters.get("char_replace_v1")( # type: ignore
level=1,
replace={"b": ["p"], "q": ["a", "b"]},
Expand All @@ -46,7 +45,7 @@ def test_create_char_replace_augmenter(nlp_en): # noqa F811
assert doc[1].text == "w" # type: ignore


def test_create_keystroke_error_augmenter(nlp_da): # noqa F811
def test_create_keystroke_error_augmenter(nlp_da: Language):
text = "q"

aug = spacy.registry.augmenters.get("keystroke_error_v1")( # type: ignore
Expand All @@ -60,15 +59,15 @@ def test_create_keystroke_error_augmenter(nlp_da): # noqa F811
assert aug_doc.text in "12wsa"


def test_create_char_swap_augmenter(nlp_en): # noqa F811
def test_create_char_swap_augmenter(nlp_en: Language):
aug = spacy.registry.augmenters.get("char_swap_v1")(level=1) # type: ignore
doc = nlp_en("qw")
docs = augmenty.docs([doc], augmenter=aug, nlp=nlp_en)
aug_doc: Doc = next(docs) # type: ignore
assert aug_doc.text == "wq"


def test_create_spacing_augmenter(nlp_en): # noqa F811
def test_create_spacing_augmenter(nlp_en: Language):
aug = spacy.registry.augmenters.get("remove_spacing_v1")(level=1) # type: ignore
doc = nlp_en("a sentence.")
docs = augmenty.docs([doc], augmenter=aug, nlp=nlp_en)
Expand Down
Loading

0 comments on commit f4ffc6e

Please sign in to comment.