In [1]:
%load_ext autoreload
%autoreload now

# Read Data

In [2]:
TRAINING_SIZE = 0.8
TEST_SIZE = 0.2

MIN_REGION_LENGTH = 20

In [3]:
from document_segmentation.model.dataset import PageDataset
from document_segmentation.settings import GENERALE_MISSIVEN_DOCUMENT_DIR

pages = PageDataset.from_dir(GENERALE_MISSIVEN_DOCUMENT_DIR)

Reading JSON files: 100%|██████████| 909/909 [01:09<00:00, 13.16file/s]


In [4]:
from document_segmentation.model.dataset import RegionDataset

all_regions = RegionDataset.from_page_dataset(pages).remove_empty(MIN_REGION_LENGTH)
len(all_regions)

580439

In [5]:
all_regions._class_counts()

Counter({<Label.IN: 2>: 573225, <Label.END: 3>: 3752, <Label.BEGIN: 1>: 3462})

In [6]:
from document_segmentation.pagexml.datamodel.label import Label

sample_size = all_regions._class_counts()[Label.END]
regions = all_regions.balance(sample_size).shuffle()

len(regions)

10966

In [7]:
regions._class_counts()

Counter({<Label.END: 3>: 3752, <Label.IN: 2>: 3752, <Label.BEGIN: 1>: 3462})

In [8]:
regions.class_weights()

[3.166618538839157, 2.9219291233679723, 2.9219291233679723, 10966.0]

In [9]:
split = int(len(regions) * TRAINING_SIZE)

training_data = regions[:split]
test_data = regions[split:]

In [10]:
training_data._class_counts()

Counter({<Label.END: 3>: 3013, <Label.IN: 2>: 2963, <Label.BEGIN: 1>: 2796})

In [11]:
test_data._class_counts()

Counter({<Label.IN: 2>: 789, <Label.END: 3>: 739, <Label.BEGIN: 1>: 666})

# Train Model

In [19]:
%autoreload now

In [20]:
BATCH_SIZE = 16
EPOCHS = 25

DEVICE = "mps"

In [21]:
from document_segmentation.model.region_classifier import RegionClassifier

model = RegionClassifier(device=DEVICE)
model

RegionClassifier(
  (_region_embedding): RegionEmbeddingSentenceTransformer(
    (_transformer_model): SentenceTransformer(
      (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: RobertaModel 
      (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})
    )
    (_region_type): Embedding(9, 16)
    (_linear): Linear(in_features=784, out_features=4, bias=True)
  )
  (_softmax): Softmax(dim=1)
)

In [22]:
model.train_(training_data, EPOCHS, BATCH_SIZE, regions.class_weights())

549batch [01:21,  6.76batch/s]                         


[Loss:	2.424]


549batch [00:02, 244.93batch/s]                         


[Loss:	2.318]


549batch [00:02, 251.10batch/s]                         


[Loss:	2.294]


549batch [00:02, 250.70batch/s]                         


[Loss:	2.282]


549batch [00:02, 254.27batch/s]                         


[Loss:	2.276]


549batch [00:02, 237.54batch/s]                         


[Loss:	2.272]


549batch [00:02, 227.33batch/s]                         


[Loss:	2.270]


549batch [00:02, 244.47batch/s]                         


[Loss:	2.268]


549batch [00:02, 240.26batch/s]                         


[Loss:	2.267]


549batch [00:02, 237.84batch/s]                         


[Loss:	2.267]


549batch [00:02, 237.87batch/s]                         


[Loss:	2.266]


549batch [00:02, 238.39batch/s]                         


[Loss:	2.266]


549batch [00:02, 245.90batch/s]                         


[Loss:	2.265]


549batch [00:02, 232.17batch/s]                         


[Loss:	2.265]


549batch [00:02, 242.73batch/s]                         


[Loss:	2.265]


549batch [00:02, 239.89batch/s]                         


[Loss:	2.265]


549batch [00:02, 234.38batch/s]                         


[Loss:	2.265]


549batch [00:02, 244.53batch/s]                         


[Loss:	2.265]


549batch [00:02, 245.03batch/s]                         


[Loss:	2.265]


549batch [00:02, 239.47batch/s]                         


[Loss:	2.264]


549batch [00:02, 240.51batch/s]                         


[Loss:	2.264]


549batch [00:02, 243.34batch/s]                         


[Loss:	2.264]


549batch [00:02, 248.75batch/s]                         


[Loss:	2.264]


549batch [00:02, 239.98batch/s]                         


[Loss:	2.264]


549batch [00:02, 239.48batch/s]                         


[Loss:	2.264]


549batch [00:02, 244.91batch/s]                         


[Loss:	2.264]


549batch [00:02, 244.73batch/s]                         


[Loss:	2.264]


549batch [00:02, 243.92batch/s]                         


[Loss:	2.264]


549batch [00:02, 250.43batch/s]                         


[Loss:	2.264]


549batch [00:02, 246.62batch/s]                         


[Loss:	2.264]


549batch [00:02, 246.03batch/s]                         


[Loss:	2.264]


549batch [00:02, 214.20batch/s]                         


[Loss:	2.264]


549batch [00:02, 235.15batch/s]                         


[Loss:	2.264]


549batch [00:02, 236.65batch/s]                         


[Loss:	2.264]


549batch [00:02, 249.11batch/s]                         


[Loss:	2.264]


549batch [00:02, 247.63batch/s]                         


[Loss:	2.264]


549batch [00:02, 249.84batch/s]                         


[Loss:	2.264]


549batch [00:02, 237.35batch/s]                         


[Loss:	2.264]


549batch [00:02, 239.06batch/s]                         


[Loss:	2.264]


549batch [00:02, 238.84batch/s]                         


[Loss:	2.264]


549batch [00:02, 236.49batch/s]                         


[Loss:	2.264]


549batch [00:02, 249.69batch/s]                         


[Loss:	2.264]


549batch [00:02, 244.64batch/s]                         


[Loss:	2.264]


549batch [00:02, 219.79batch/s]                         


[Loss:	2.264]


549batch [00:02, 228.25batch/s]                         


[Loss:	2.264]


549batch [00:02, 201.53batch/s]                         


[Loss:	2.264]


549batch [00:02, 206.15batch/s]                         


[Loss:	2.264]


549batch [00:02, 243.81batch/s]                         


[Loss:	2.264]


549batch [00:02, 250.38batch/s]                         


[Loss:	2.264]


549batch [00:02, 250.87batch/s]                         

[Loss:	2.264]





# Evaluation

In [23]:
import csv
import random
import sys

import torch
from torcheval.metrics import MulticlassF1Score, MulticlassPrecision, MulticlassRecall

from document_segmentation.pagexml.datamodel.label import Label

writer = csv.DictWriter(
    sys.stdout,
    fieldnames=("Predicted", "Actual", "Lines", "Scores", "Types"),
    delimiter="\t",
)

writer.writeheader()

precision_metric = MulticlassPrecision(average=None, num_classes=len(Label))
recall_metric = MulticlassRecall(average=None, num_classes=len(Label))
f1_metric = MulticlassF1Score(average=None, num_classes=len(Label))

for batch in test_data.batches(BATCH_SIZE):
    predicted = model(batch)
    labels = batch.labels()
    types = [region.types for region in batch.regions()]

    _labels = torch.Tensor([label.value - 1 for label in labels]).to(int)
    precision_metric.update(predicted, _labels)
    recall_metric.update(predicted, _labels)
    f1_metric.update(predicted, _labels)

    # print every nth line:
    if random.random() > 0.9:
        for region, pred, label, types in zip(
            batch.regions(), predicted, labels, types
        ):
            pred_label = Label(pred.argmax().item() + 1)
            writer.writerow(
                {
                    "Predicted": pred_label.name,
                    "Actual": label.name,
                    "Lines": region.lines,
                    "Scores": str(pred.tolist()),
                    "Types": str(types),
                }
            )

Predicted	Actual	Lines	Scores	Types
END	END	('H: Moens.', 'Smith:', 'G', 'Lagber', 'J Viegerman', 'F')	[2.6800581198749285e-10, 6.606832594034662e-13, 1.0, 4.1578657006451563e-17]	(<RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>, <RegionType.SIGNATURE_MARK: 'signature-mark'>)
END	IN	('sach gaerne', 'wij daer vaste')	[0.17536962032318115, 0.07071225345134735, 0.7539181113243103, 3.4276377309377937e-12]	(<RegionType.MARGINALIA: 'marginalia'>, <RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>)
IN	IN	"(""'t afgaan van ons Eerbiedig schrij„"", '„ven aan VEdele hoog agtb: de dato', ""29:' maart pass:o hebben wij"", 'van dit Comptoir ontfangen en', 'derwaarts gesonden verscheijde', 'brieven waar van geen singu„', '„liere specificatie zullen doen', 'om dat alle deselve in de Indiase', 'afgaande en aankomend

In [24]:
writer = csv.DictWriter(
    sys.stdout, fieldnames=["Metric"] + [label.name for label in Label], delimiter="\t"
)
writer.writeheader()

for metric in (precision_metric, recall_metric, f1_metric):
    scores = {
        label.name: f"{score:.4f}"
        for label, score in zip(Label, metric.compute().tolist())
    }
    writer.writerow({"Metric": metric.__class__.__name__} | scores)



Metric	BEGIN	IN	END	OUT
MulticlassPrecision	0.7827	0.7841	0.8238	0.0000
MulticlassRecall	0.8003	0.8606	0.7212	0.0000
MulticlassF1Score	0.7914	0.8205	0.7691	0.0000
