In [1]:
%load_ext autoreload
%autoreload now

# Read Data

In [2]:
TRAINING_SIZE = 0.8
TEST_SIZE = 0.2

In [3]:
from document_segmentation.model.dataset import PageDataset
from document_segmentation.settings import GENERALE_MISSIVEN_DOCUMENT_DIR

pages = PageDataset.from_dir(GENERALE_MISSIVEN_DOCUMENT_DIR)

Reading JSON files: 100%|██████████| 909/909 [01:49<00:00,  8.29file/s]


In [120]:
from document_segmentation.model.dataset import RegionDataset
from document_segmentation.settings import MIN_REGION_TEXT_LENGTH

regions = (
    RegionDataset.from_page_dataset(pages)
    .remove_empty(MIN_REGION_TEXT_LENGTH)
    .shuffle()
)
len(regions)

580439

In [121]:
regions._class_counts()

Counter({<Label.IN: 2>: 573225, <Label.END: 3>: 3752, <Label.BEGIN: 1>: 3462})

In [122]:
regions.class_weights()

[167.6116084319954, 1.0125831696398977, 154.6600053290701, 580439.0]

In [123]:
split = int(len(regions) * TRAINING_SIZE)

training_data = regions[:split]
test_data = regions[split:]

In [124]:
training_data._class_counts()

Counter({<Label.IN: 2>: 458584, <Label.END: 3>: 3023, <Label.BEGIN: 1>: 2744})

In [125]:
test_data._class_counts()

Counter({<Label.IN: 2>: 114641, <Label.END: 3>: 729, <Label.BEGIN: 1>: 718})

# Train Model

In [126]:
%autoreload now

In [168]:
BATCH_SIZE = 256
EPOCHS = 100

DEVICE = "mps"

In [169]:
from document_segmentation.model.region_classifier import RegionClassifier

model = RegionClassifier(device=DEVICE)
model

RegionClassifier(
  (_region_embedding): RegionEmbeddingSentenceTransformer(
    (_transformer_model): SentenceTransformer(
      (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: RobertaModel 
      (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})
    )
    (_region_type): Embedding(9, 16)
    (_linear): Linear(in_features=784, out_features=4, bias=True)
  )
  (_softmax): Softmax(dim=1)
)

In [170]:
from document_segmentation.pagexml.datamodel.label import Label

model.train_(
    training_data.balance(training_data._class_counts()[Label.END]).shuffle(),
    EPOCHS,
    BATCH_SIZE,
    # regions.class_weights(),
)

35batch [00:44,  1.27s/batch]                             


[Loss:	3.269]


35batch [00:00, 80.81batch/s]                             


[Loss:	3.114]


35batch [00:00, 89.39batch/s]                             


[Loss:	3.019]


35batch [00:00, 83.06batch/s]                             


[Loss:	2.956]


35batch [00:00, 92.36batch/s]                             


[Loss:	2.912]


35batch [00:00, 91.76batch/s]                             


[Loss:	2.880]


35batch [00:00, 92.35batch/s]                             


[Loss:	2.855]


35batch [00:00, 91.78batch/s]                             


[Loss:	2.836]


35batch [00:00, 93.24batch/s]                             


[Loss:	2.820]


35batch [00:00, 93.20batch/s]                             


[Loss:	2.806]


35batch [00:00, 90.33batch/s]                             


[Loss:	2.794]


35batch [00:00, 83.43batch/s]                             


[Loss:	2.783]


35batch [00:00, 92.03batch/s]                             


[Loss:	2.772]


35batch [00:00, 93.29batch/s]                             


[Loss:	2.763]


35batch [00:00, 91.85batch/s]                             


[Loss:	2.754]


35batch [00:00, 92.83batch/s]                             


[Loss:	2.746]


35batch [00:00, 89.22batch/s]                             


[Loss:	2.739]


35batch [00:00, 89.28batch/s]                             


[Loss:	2.732]


35batch [00:00, 86.61batch/s]                             


[Loss:	2.725]


35batch [00:00, 92.92batch/s]                             


[Loss:	2.718]


35batch [00:00, 93.35batch/s]                             


[Loss:	2.712]


35batch [00:00, 92.65batch/s]                             


[Loss:	2.705]


35batch [00:00, 93.55batch/s]                             


[Loss:	2.700]


35batch [00:00, 92.86batch/s]                             


[Loss:	2.695]


35batch [00:00, 92.25batch/s]                             


[Loss:	2.691]


35batch [00:00, 91.91batch/s]                             


[Loss:	2.688]


35batch [00:00, 85.20batch/s]                             


[Loss:	2.685]


35batch [00:00, 93.84batch/s]                             


[Loss:	2.682]


35batch [00:00, 91.82batch/s]                             


[Loss:	2.680]


35batch [00:00, 92.69batch/s]                             


[Loss:	2.678]


35batch [00:00, 92.99batch/s]                             


[Loss:	2.676]


35batch [00:00, 93.15batch/s]                             


[Loss:	2.674]


35batch [00:00, 92.59batch/s]                             


[Loss:	2.672]


35batch [00:00, 90.88batch/s]                             


[Loss:	2.671]


35batch [00:00, 87.39batch/s]                             


[Loss:	2.669]


35batch [00:00, 92.87batch/s]                             


[Loss:	2.668]


35batch [00:00, 91.58batch/s]                             


[Loss:	2.666]


35batch [00:00, 91.65batch/s]                             


[Loss:	2.665]


35batch [00:00, 92.00batch/s]                             


[Loss:	2.664]


35batch [00:00, 93.32batch/s]                             


[Loss:	2.663]


35batch [00:00, 93.47batch/s]                             


[Loss:	2.662]


35batch [00:00, 92.93batch/s]                             


[Loss:	2.660]


35batch [00:00, 88.32batch/s]                             


[Loss:	2.659]


35batch [00:00, 92.08batch/s]                             


[Loss:	2.658]


35batch [00:00, 92.19batch/s]                             


[Loss:	2.657]


35batch [00:00, 92.91batch/s]                             


[Loss:	2.656]


35batch [00:00, 91.81batch/s]                             


[Loss:	2.655]


35batch [00:00, 92.24batch/s]                             


[Loss:	2.654]


35batch [00:00, 91.82batch/s]                             


[Loss:	2.653]


35batch [00:00, 92.55batch/s]                             


[Loss:	2.652]


35batch [00:00, 86.78batch/s]                             


[Loss:	2.651]


35batch [00:00, 92.32batch/s]                             


[Loss:	2.650]


35batch [00:00, 92.76batch/s]                             


[Loss:	2.648]


35batch [00:00, 92.75batch/s]                             


[Loss:	2.647]


35batch [00:00, 92.91batch/s]                             


[Loss:	2.646]


35batch [00:00, 92.27batch/s]                             


[Loss:	2.645]


35batch [00:00, 92.80batch/s]                             


[Loss:	2.644]


35batch [00:00, 93.14batch/s]                             


[Loss:	2.643]


35batch [00:00, 86.85batch/s]                             


[Loss:	2.642]


35batch [00:00, 92.73batch/s]                             


[Loss:	2.641]


35batch [00:00, 91.80batch/s]                             


[Loss:	2.640]


35batch [00:00, 92.96batch/s]                             


[Loss:	2.639]


35batch [00:00, 93.76batch/s]                             


[Loss:	2.638]


35batch [00:00, 93.55batch/s]                             


[Loss:	2.637]


35batch [00:00, 92.19batch/s]                             


[Loss:	2.635]


35batch [00:00, 92.06batch/s]                             


[Loss:	2.634]


35batch [00:00, 87.86batch/s]                             


[Loss:	2.633]


35batch [00:00, 93.08batch/s]                             


[Loss:	2.632]


35batch [00:00, 93.21batch/s]                             


[Loss:	2.631]


35batch [00:00, 93.00batch/s]                             


[Loss:	2.630]


35batch [00:00, 93.22batch/s]                             


[Loss:	2.629]


35batch [00:00, 93.00batch/s]                             


[Loss:	2.628]


35batch [00:00, 92.38batch/s]                             


[Loss:	2.627]


35batch [00:00, 93.34batch/s]                             


[Loss:	2.626]


35batch [00:00, 87.30batch/s]                             


[Loss:	2.625]


35batch [00:00, 92.44batch/s]                             


[Loss:	2.624]


35batch [00:00, 91.47batch/s]                             


[Loss:	2.623]


35batch [00:00, 92.77batch/s]                             


[Loss:	2.623]


35batch [00:00, 93.57batch/s]                             


[Loss:	2.622]


35batch [00:00, 93.10batch/s]                             


[Loss:	2.621]


35batch [00:00, 92.60batch/s]                             


[Loss:	2.620]


35batch [00:00, 91.78batch/s]                             


[Loss:	2.619]


35batch [00:00, 88.33batch/s]                             


[Loss:	2.619]


35batch [00:00, 93.56batch/s]                             


[Loss:	2.618]


35batch [00:00, 93.16batch/s]                             


[Loss:	2.617]


35batch [00:00, 93.03batch/s]                             


[Loss:	2.617]


35batch [00:00, 94.45batch/s]                             


[Loss:	2.616]


35batch [00:00, 93.54batch/s]                             


[Loss:	2.615]


35batch [00:00, 92.86batch/s]                             


[Loss:	2.615]


35batch [00:00, 94.08batch/s]                             


[Loss:	2.614]


35batch [00:00, 88.20batch/s]                             


[Loss:	2.614]


35batch [00:00, 92.53batch/s]                             


[Loss:	2.613]


35batch [00:00, 90.57batch/s]                             


[Loss:	2.613]


35batch [00:00, 93.24batch/s]                             


[Loss:	2.612]


35batch [00:00, 92.39batch/s]                             


[Loss:	2.612]


35batch [00:00, 93.18batch/s]                             


[Loss:	2.611]


35batch [00:00, 93.23batch/s]                             


[Loss:	2.610]


35batch [00:00, 91.51batch/s]                             


[Loss:	2.610]


35batch [00:00, 87.59batch/s]                             


[Loss:	2.609]


35batch [00:00, 93.30batch/s]                             

[Loss:	2.609]





# Evaluation

In [171]:
import csv
import sys

import torch
from torcheval.metrics import (
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)
from tqdm import tqdm

from document_segmentation.pagexml.datamodel.label import Label

writer = csv.DictWriter(
    sys.stdout,
    fieldnames=("Predicted", "Actual", "Lines", "Scores", "Types"),
    delimiter="\t",
)

writer.writeheader()

accuracy = MulticlassAccuracy(num_classes=len(Label))
precision = MulticlassPrecision(average=None, num_classes=len(Label))
recall = MulticlassRecall(average=None, num_classes=len(Label))
f1_score = MulticlassF1Score(average=None, num_classes=len(Label))

for batch in tqdm(
    test_data[:1000].batches(BATCH_SIZE),
    total=len(test_data) / BATCH_SIZE,
    unit="batch",
):
    predicted = model(batch)
    labels = batch.labels()
    types = [region.types for region in batch.regions()]

    _labels = torch.Tensor([label.value - 1 for label in labels]).to(int)
    accuracy.update(predicted, _labels)
    precision.update(predicted, _labels)
    recall.update(predicted, _labels)
    f1_score.update(predicted, _labels)

    for region, pred, label, types in zip(batch.regions(), predicted, labels, types):
        pred_label = Label(pred.argmax().item() + 1)
        if pred_label != Label.IN or label != Label.IN:
            writer.writerow(
                {
                    "Predicted": pred_label.name,
                    "Actual": label.name,
                    "Lines": region.lines,
                    "Scores": str(pred.tolist()),
                    "Types": str(types),
                }
            )

Predicted	Actual	Lines	Scores	Types


  0%|          | 1/453.46875 [00:01<12:19,  1.64s/batch]

BEGIN	IN	('de sagoe boschens zijn',)	[0.6851197481155396, 0.3146204948425293, 4.4579286623047665e-05, 0.00021517790446523577]	(<RegionType.MARGINALIA: 'marginalia'>, <RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>)
END	IN	('Van sijn Edelheijd', 'den hoog Edelen heere', 'henric Zwaerdecroon', 'gouverneur Generaal,', 'en de verdere Edele heeren', 'Raeden van Nederlands', 'India.')	[0.12724101543426514, 2.1550866222241893e-05, 0.8727362155914307, 1.1965151998083456e-06]	(<RegionType.PARAGRAPH: 'paragraph'>, <RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>)
BEGIN	IN	"('Het geene gezegt werd uijt de Pa¬', 'poese Eylanden tot Zoeloe te', 'vallen, daer af hebben wij uijt', 'Ternaten UEd. e Hoog Agtbare', 'een half Aem van toegezonden', 'gehad benevens een Salf Pot', ""van't Súijverste afgeklopt van"", 'zi

  0%|          | 2/453.46875 [00:03<11:26,  1.52s/batch]

END	IN	"('Batavia 3„en 6„en 10„en en 24„en october', 'bij onsen brieff van den 13 Ianuarij', 'deses Jaers hebben wij VEd hoog agtb', 'bekent gemaekt ons gevoelen over', 'de groote oegsten ende het bederff', 'der Nagulen, Noten, en foelij', 'mitsgaders dat wij genoodsaekt', 'zoude sijn retreeden tot vernie„', 'tiging van de bedorvenagulen', 'op VEd: hoog agtb: vorige ordres', 'waer op dan ook bij onse recolutien', 'van den 28:en feb:, 4„en april passado', 'hebben geresolveert 611049 lb: op', 'de vorige wijse te vernietigen, gelyk', 'bij ons besluijt van den 2„en maij', ""daer aen g'insereert staat de"")"	[0.09277204424142838, 0.013756833970546722, 0.893462598323822, 8.451814210275188e-06]	(<RegionType.PARAGRAPH: 'paragraph'>, <RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>)
BEGIN	IN	('item dat nogh 9 schepen in', 'weijnigh maenden stonden', 'te volgen')	[0.9162551164627075, 0.0836177095770835

  1%|          | 3/453.46875 [00:04<11:02,  1.47s/batch]

BEGIN	IN	('zijnde te samen 12. schepen', 'voorde eerste besendinge', 'deses zaijsoens')	[0.9968668818473816, 0.0031296685338020325, 3.3960714063141495e-06, 6.860248813467251e-09]	(<RegionType.MARGINALIA: 'marginalia'>, <RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>)
END	IN	('o 1696 1697 8: 69 5 696. dit jaar',)	[0.15360620617866516, 0.08114391565322876, 0.7651458978652954, 0.00010400544124422595]	(<RegionType.PARAGRAPH: 'paragraph'>, <RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>)
END	IN	('staats personen in', 'golconda.')	[0.025345075875520706, 0.014373749494552612, 0.9595934748649597, 0.0006877743289805949]	(<RegionType.MARGINALIA: 'marginalia'>, <RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>)
END	

  1%|          | 4/453.46875 [00:06<12:43,  1.70s/batch]

BEGIN	IN	('Omtrent de door de Ministers, bij de beantwoording', 'het douceur voor den Derecteur van het', 'Extract uit uwel Edele Hoog Achtb: missive van den 9. No„', '„vember 1789 „ geinsereerd bij opgemelde gemeene briev va„', 'den 15: December 1790. gemaakte aanmerkingen, op het')	[0.9753358364105225, 0.02357635460793972, 0.0010852472623810172, 2.5172125788230915e-06]	(<RegionType.PARAGRAPH: 'paragraph'>, <RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>)
BEGIN	IN	('Ordre voor den corporael Hendrick',)	[0.5170659422874451, 0.4827541410923004, 0.00017452039173804224, 5.438733296614373e-06]	(<RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>, <RegionType.HEADER: 'header'>)
BEGIN	IN	('Uit des ag ons toegesonden staat- Reekening, getrok„', '„ken uit de afgeslootene', 'Negotie- Boeken ne 17 9/0, ons gebl




In [179]:
writer = csv.DictWriter(
    sys.stdout,
    fieldnames=["Metric", "Average"] + [label.name for label in Label],
    delimiter="\t",
)
writer.writeheader()

for metric in (precision, recall, f1_score):
    scores = {
        label.name: f"{score:.4f}"
        for label, score in zip(Label, metric.compute().tolist())
    }
    writer.writerow(
        {"Metric": metric.__class__.__name__, "Average": str(metric.average)} | scores
    )

print(f"Accuracy ({accuracy.average} average):\t{accuracy.compute().item():.4f}")



Metric	Average	BEGIN	IN	END	OUT
MulticlassPrecision	None	0.0430	0.9977	0.0625	0.0000
MulticlassRecall	None	0.8000	0.8648	0.7500	0.0000
MulticlassF1Score	None	0.0816	0.9265	0.1154	0.0000
Accuracy (micro average):	0.8640
