Skip to content

Commit

Permalink
fix: remove doc extension instead of pipe component. TODO double chec…
Browse files Browse the repository at this point in the history
…k all assigns are correct
  • Loading branch information
HLasse committed Jan 4, 2023
1 parent fb33e19 commit bc32d47
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
5 changes: 3 additions & 2 deletions src/textdescriptives/components/descriptive_stats.py
Expand Up @@ -154,8 +154,9 @@ def __call__(self, doc):
@Language.factory(
"textdescriptives/descriptive_stats",
assigns=[
"doc._.n_sentences",
"doc._.n_tokens",
"doc._._n_sentences",
"doc._._n_tokens",
"doc._._n_syllables",
"doc._.token_length",
"doc._.sentence_length",
"doc._.syllables",
Expand Down
23 changes: 20 additions & 3 deletions src/textdescriptives/extractors.py
Expand Up @@ -8,7 +8,7 @@
from spacy.tokens import Doc
from wasabi import msg

from textdescriptives.utils import get_valid_metrics
from textdescriptives.utils import get_assigns, get_valid_metrics


def __get_quality(doc: Doc) -> dict:
Expand Down Expand Up @@ -142,7 +142,12 @@ def extract_metrics(
metrics = get_valid_metrics()

# load spacy model if any component requires it
nlp = load_spacy_model(spacy_model, lang, metrics, spacy_model_size)
nlp = load_spacy_model(
spacy_model=spacy_model,
lang=lang,
metrics=metrics,
spacy_model_size=spacy_model_size,
)

# add pipeline components
for component in metrics:
Expand All @@ -152,7 +157,10 @@ def extract_metrics(
text = [text]
docs = nlp.pipe(text)

return extract_df(docs)
df = extract_df(docs)
_clean_doc_extensions(metrics=metrics)

return df


def load_spacy_model(
Expand Down Expand Up @@ -217,3 +225,12 @@ def download_spacy_model(lang: str, size: str) -> str:
return spacy_model
spacy.cli.download(spacy_model)
return spacy_model


def _clean_doc_extensions(metrics: Iterable[str]) -> None:
    """Remove the `Doc` extensions assigned by the given textdescriptives metrics.

    This is necessary to avoid errors when running `extract_metrics` multiple
    times with different metrics.

    Args:
        metrics: Names of the metrics whose extensions should be removed. Each
            name is passed to `get_assigns` to look up its extension names.
    """
    # NOTE(review): assumes `get_assigns` returns bare extension names (the
    # form `Doc.remove_extension` expects), not "doc._.<name>" strings - confirm.
    for metric in metrics:
        for assigned in get_assigns(metric):
            # `Doc.remove_extension` raises ValueError for a name that is not
            # currently registered, e.g. when two metrics list the same
            # extension and an earlier iteration already removed it. Skipping
            # those keeps the cleanup idempotent.
            if Doc.has_extension(assigned):
                Doc.remove_extension(assigned)
14 changes: 14 additions & 0 deletions tests/test_extractors.py
Expand Up @@ -140,3 +140,17 @@ def test_extract_model_not_needed():
lang="en",
)
assert "n_tokens" in df.columns


def test_extract_metrics_twice():
    """Smoke test: call `extract_metrics` twice in a row with different metrics.

    There are no assertions; the test passes as long as neither call raises.
    """
    sample = "Just a small test"
    # Same call order as before: "coherence" first, then "descriptive_stats".
    for metric_name in ("coherence", "descriptive_stats"):
        td.extract_metrics(
            sample,
            metrics=metric_name,
            lang="en",
        )

0 comments on commit bc32d47

Please sign in to comment.