Output reason from classify_with_scores().
carschno committed Sep 21, 2023
1 parent fcbe136 commit 0df3c98
Showing 4 changed files with 41 additions and 13 deletions.
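
In short: Pipeline.classify_with_scores() now returns a Reason value alongside the quality class and the scores dict, and the CLI script adds a "Reason" column to its CSV output when --output-scores is set. The sketch below is a minimal, hypothetical illustration of the new calling convention; the write_row helper and its pipeline, name, and page arguments are illustrative assumptions, not part of this commit.

import csv
import sys

from text_quality.classifier.pipeline import Pipeline, Reason

REASON_FIELDNAME = "Reason"


def write_row(pipeline: Pipeline, name: str, page) -> None:
    """Classify one input (a Page object or a plain string) and write a CSV row
    that includes the classification reason."""
    quality_class, classifier_scores, reason = pipeline.classify_with_scores(page)

    # Same column layout as the script: OutputRow fields, then scores, then Reason.
    fieldnames = (
        ["filename", "quality_class"] + list(classifier_scores) + [REASON_FIELDNAME]
    )
    writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerow(
        {"filename": name, "quality_class": quality_class}
        | classifier_scores
        | {REASON_FIELDNAME: reason.name}  # "CLASSIFIER", "SHORT_COLUMNS", or "EMPTY"
    )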

scripts/classify_text_quality.py: 13 changes (10 additions, 3 deletions)

@@ -30,8 +30,12 @@

logging.basicConfig(level=LOG_LEVEL)

+ REASON_FIELDNAME = "Reason"


class OutputRow(TypedDict):
"""Container class for the rows in the CSV output."""

filename: str
quality_class: int

@@ -77,7 +81,7 @@ class OutputRow(TypedDict):
parser.add_argument(
"--output-scores",
action="store_true",
help="Output scores and text statistics.",
help="Output scores and text statistics, and reason for classification.",
)
args = parser.parse_args()

@@ -113,7 +117,7 @@ class OutputRow(TypedDict):

fieldnames = list(OutputRow.__annotations__.keys())
if args.output_scores:
- fieldnames += list(ClassifierScores.__annotations__.keys())
+ fieldnames += list(ClassifierScores.__annotations__.keys()) + [REASON_FIELDNAME]

writer = csv.DictWriter(args.output, fieldnames=fieldnames)
writer.writeheader()
@@ -122,10 +126,13 @@ class OutputRow(TypedDict):
(text_inputs | pagexml_inputs).items(), desc="Processing", unit="file"
):
if args.output_scores:
- quality_class, classifier_scores = pipeline.classify_with_scores(page)
+ quality_class, classifier_scores, reason = pipeline.classify_with_scores(
+ page
+ )
row = (
OutputRow(filename=name, quality_class=quality_class)
| classifier_scores
+ | {REASON_FIELDNAME: reason.name}
)
else:
row = OutputRow(filename=name, quality_class=pipeline.classify(page))

setup.cfg: 1 change (1 addition, 0 deletions)

@@ -39,6 +39,7 @@ install_requires =
scikit-learn==1.2.1
spylls==0.1.7
tqdm==4.65.0
+ openpyxl~=3.1.2
scripts =
scripts/classify_text_quality.py

tests/classifier/test_pipeline.py: 15 changes (11 additions, 4 deletions)

@@ -4,7 +4,7 @@
import sklearn
from pagexml.model.physical_document_model import PageXMLScan
from pagexml.model.physical_document_model import PageXMLTextLine
- from text_quality.classifier.pipeline import ClassifierScores
+ from text_quality.classifier.pipeline import ClassifierScores, Reason
from text_quality.classifier.pipeline import Pipeline
from text_quality.classifier.pipeline import default_scores_dict
from text_quality.feature.featurize import Scorers
@@ -57,7 +57,7 @@ def test_classify(self, pipeline, page, expected):
assert pipeline.classify(page) == expected

@pytest.mark.parametrize(
"text, expected_class, expected_scores",
"text, expected_class, expected_scores, expected_reason",
[
(
"",
Expand All @@ -71,6 +71,7 @@ def test_classify(self, pipeline, page, expected):
n_characters=0,
n_tokens=0,
),
+ Reason.EMPTY,
),
(
"een Nederlands tekst",
Expand All @@ -84,6 +85,7 @@ def test_classify(self, pipeline, page, expected):
n_characters=20,
n_tokens=3,
),
+ Reason.CLASSIFIER,
),
(
Page(PageXMLScan(lines=[PageXMLTextLine(text="test")])),
@@ -97,6 +99,7 @@ def test_classify(self, pipeline, page, expected):
n_characters=4,
n_tokens=0,
),
+ Reason.SHORT_COLUMNS,
),
(
Page(PageXMLScan(lines=[PageXMLTextLine(text="een Nederlands tekst")])),
@@ -110,6 +113,7 @@ def test_classify(self, pipeline, page, expected):
n_characters=20,
n_tokens=3,
),
+ Reason.CLASSIFIER,
),
(
Page(PageXMLScan(lines=[PageXMLTextLine(text="test")] * 10)),
@@ -123,15 +127,18 @@ def test_classify(self, pipeline, page, expected):
n_characters=49,
n_tokens=0,
),
+ Reason.SHORT_COLUMNS,
),
],
)
+ # pylint: disable=too-many-arguments
def test_classify_with_scores(
- self, pipeline, text, expected_class, expected_scores
+ self, pipeline, text, expected_class, expected_scores, expected_reason
):
- quality, scores = pipeline.classify_with_scores(text)
+ quality, scores, reason = pipeline.classify_with_scores(text)
assert quality == expected_class
assert scores == pytest.approx(expected_scores)
+ assert reason == expected_reason


@pytest.mark.parametrize(

text_quality/classifier/pipeline.py: 25 changes (19 additions, 6 deletions)

@@ -1,6 +1,8 @@
"""Classification pipeline."""

import logging
+ from enum import Enum
+ from enum import auto
from pathlib import Path
from typing import List
from typing import TypedDict
@@ -24,6 +26,14 @@
"""Container class for the scores returned by the classifier."""


+ class Reason(Enum):
+ """Reasons for the classification result."""
+
+ CLASSIFIER = auto()
+ SHORT_COLUMNS = auto()
+ EMPTY = auto()


def default_scores_dict(default_value, **fields) -> ClassifierScores:
"""Generate a ClassifierScores dict with default values.
@@ -93,17 +103,18 @@ def _classify_pagexml(self, pagexml: Page) -> int:

def classify_with_scores(
self, page: Union[Page, str]
- ) -> tuple[int, ClassifierScores]:
+ ) -> tuple[int, ClassifierScores, Reason]:
"""Single instance classification with scores."""

if isinstance(page, Page):
- quality, scores = self._classify_pagexml_with_scores(page)
+ quality, scores, reason = self._classify_pagexml_with_scores(page)
elif self._is_short(page):
logging.debug(
"Skipping short text: '%s' (%d characters).", page, len(page.strip())
)
quality = EMPTY_PAGE_OUTPUT
scores = default_scores_dict(0, confidence=1.0, n_characters=len(page))
+ reason = Reason.EMPTY
else:
features, tokens = self._featurizer.featurize(page)
features_df: pd.DataFrame = Featurizer.as_dataframe(features)
@@ -116,12 +127,13 @@ def classify_with_scores(
n_tokens=len(tokens),
**features,
)
+ reason = Reason.CLASSIFIER

- return quality, scores
+ return quality, scores, reason

def _classify_pagexml_with_scores(
self, pagexml: Page
- ) -> tuple[int, ClassifierScores]:
+ ) -> tuple[int, ClassifierScores, Reason]:
"""Classify a Page object with scores."""

if all(len(line) < SHORT_COLUMN_WIDTH for line in pagexml.lines()):
@@ -131,10 +143,11 @@ def _classify_pagexml_with_scores(
scores = default_scores_dict(
0, confidence=1.0, n_characters=len(pagexml.get_text())
)
+ reason = Reason.SHORT_COLUMNS
else:
- quality, scores = self.classify_with_scores(pagexml.get_text())
+ quality, scores, reason = self.classify_with_scores(pagexml.get_text())

- return quality, scores
+ return quality, scores, reason

@staticmethod
def _is_short(text: str):
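
Taken together, classify_with_scores() now reports which code path produced the result: Reason.EMPTY when the input text is too short to featurize, Reason.SHORT_COLUMNS when every line of a PageXML page is shorter than SHORT_COLUMN_WIDTH, and Reason.CLASSIFIER when the scikit-learn model was actually consulted. A hypothetical caller distinguishing the heuristic shortcuts from a genuine model prediction might look like the sketch below (the pipeline and page objects are assumed to be set up elsewhere):

from text_quality.classifier.pipeline import Reason

quality, scores, reason = pipeline.classify_with_scores(page)

if reason is Reason.CLASSIFIER:
    # The model produced the class, so the confidence score reflects a real prediction.
    print(f"Model prediction: {quality} (confidence {scores['confidence']:.2f})")
else:
    # Reason.EMPTY or Reason.SHORT_COLUMNS: heuristic shortcut with default scores.
    print(f"Heuristic result ({reason.name}): quality class {quality}")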
