In [73]:
%load_ext autoreload
%autoreload now

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Read Data

In [74]:
TRAINING_SIZE = 0.8
TEST_SIZE = 0.2

MIN_REGION_LENGTH = 20

In [75]:
from document_segmentation.model.dataset import PageDataset
from document_segmentation.settings import GENERALE_MISSIVEN_DOCUMENT_DIR

pages = PageDataset.from_dir(GENERALE_MISSIVEN_DOCUMENT_DIR)

Reading JSON files:   0%|          | 0/909 [00:00<?, ?file/s]

Reading JSON files: 100%|██████████| 909/909 [04:11<00:00,  3.61file/s]


In [76]:
from document_segmentation.model.dataset import RegionDataset

all_regions = RegionDataset.from_page_dataset(pages).remove_empty(MIN_REGION_LENGTH)
len(all_regions)

580439

In [77]:
all_regions._class_counts()

Counter({<Label.IN: 2>: 573225, <Label.END: 3>: 3752, <Label.BEGIN: 1>: 3462})

In [78]:
from document_segmentation.pagexml.datamodel.label import Label

sample_size = all_regions._class_counts()[Label.END]
regions = all_regions.balance(sample_size).shuffle()

len(regions)

10966

In [79]:
regions._class_counts()

Counter({<Label.END: 3>: 3752, <Label.IN: 2>: 3752, <Label.BEGIN: 1>: 3462})

In [80]:
regions.class_weights()

[3.166618538839157, 2.9219291233679723, 2.9219291233679723, 10966.0]

In [81]:
split = int(len(regions) * TRAINING_SIZE)

training_data = regions[:split]
test_data = regions[split:]

In [82]:
training_data._class_counts()

Counter({<Label.END: 3>: 3007, <Label.IN: 2>: 2988, <Label.BEGIN: 1>: 2777})

In [83]:
test_data._class_counts()

Counter({<Label.IN: 2>: 764, <Label.END: 3>: 745, <Label.BEGIN: 1>: 685})

# Train Model

In [92]:
%autoreload now

In [93]:
BATCH_SIZE = 16
EPOCHS = 50

DEVICE = "mps"

In [94]:
from document_segmentation.model.region_classifier import (
    RegionClassifierSentenceTransformer,
)

model = RegionClassifierSentenceTransformer(device=DEVICE)
model

RegionClassifierSentenceTransformer(
  (_transformer_model): SentenceTransformer(
    (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: RobertaModel 
    (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})
  )
  (_region_type): Embedding(9, 16)
  (_linear): Linear(in_features=784, out_features=4, bias=True)
  (_softmax): Softmax(dim=1)
)

In [95]:
model.train_(training_data, EPOCHS, BATCH_SIZE, regions.class_weights())

  0%|          | 0/548.25 [00:00<?, ?batch/s]

549batch [01:21,  6.76batch/s]                         


[Loss:	2.417]


549batch [00:02, 235.31batch/s]                         


[Loss:	2.313]


549batch [00:02, 238.58batch/s]                         


[Loss:	2.290]


549batch [00:02, 218.69batch/s]                         


[Loss:	2.281]


549batch [00:02, 235.20batch/s]                         


[Loss:	2.277]


549batch [00:02, 219.35batch/s]                         


[Loss:	2.274]


549batch [00:02, 205.61batch/s]                         


[Loss:	2.272]


549batch [00:02, 241.21batch/s]                         


[Loss:	2.271]


549batch [00:02, 240.18batch/s]                         


[Loss:	2.270]


549batch [00:02, 236.14batch/s]                         


[Loss:	2.269]


549batch [00:02, 244.59batch/s]                         


[Loss:	2.268]


549batch [00:02, 239.52batch/s]                         


[Loss:	2.268]


549batch [00:02, 241.77batch/s]                         


[Loss:	2.267]


549batch [00:02, 240.86batch/s]                         


[Loss:	2.267]


549batch [00:02, 240.42batch/s]                         


[Loss:	2.266]


549batch [00:02, 229.30batch/s]                         


[Loss:	2.266]


549batch [00:02, 236.73batch/s]                         


[Loss:	2.266]


549batch [00:02, 235.54batch/s]                         


[Loss:	2.266]


549batch [00:02, 241.98batch/s]                         


[Loss:	2.266]


549batch [00:02, 237.24batch/s]                         


[Loss:	2.265]


549batch [00:02, 246.51batch/s]                         


[Loss:	2.265]


549batch [00:02, 245.46batch/s]                         


[Loss:	2.265]


549batch [00:02, 238.43batch/s]                         


[Loss:	2.265]


549batch [00:02, 236.04batch/s]                         


[Loss:	2.265]


549batch [00:02, 227.92batch/s]                         


[Loss:	2.265]


549batch [00:02, 237.54batch/s]                         


[Loss:	2.265]


549batch [00:02, 235.38batch/s]                         


[Loss:	2.265]


549batch [00:02, 236.96batch/s]                         


[Loss:	2.265]


549batch [00:02, 239.68batch/s]                         


[Loss:	2.265]


549batch [00:02, 239.01batch/s]                         


[Loss:	2.265]


549batch [00:02, 240.98batch/s]                         


[Loss:	2.265]


549batch [00:02, 242.35batch/s]                         


[Loss:	2.265]


549batch [00:02, 243.50batch/s]                         


[Loss:	2.265]


549batch [00:02, 246.23batch/s]                         


[Loss:	2.265]


549batch [00:02, 246.59batch/s]                         


[Loss:	2.265]


549batch [00:02, 243.56batch/s]                         


[Loss:	2.265]


549batch [00:02, 250.54batch/s]                         


[Loss:	2.265]


549batch [00:02, 245.53batch/s]                         


[Loss:	2.265]


549batch [00:02, 246.20batch/s]                         


[Loss:	2.265]


549batch [00:02, 245.64batch/s]                         


[Loss:	2.265]


549batch [00:02, 246.67batch/s]                         


[Loss:	2.265]


549batch [00:02, 229.39batch/s]                         


[Loss:	2.265]


549batch [00:02, 242.07batch/s]                         


[Loss:	2.264]


549batch [00:02, 245.64batch/s]                         


[Loss:	2.264]


549batch [00:02, 242.68batch/s]                         


[Loss:	2.264]


549batch [00:02, 244.54batch/s]                         


[Loss:	2.264]


549batch [00:02, 247.15batch/s]                         


[Loss:	2.264]


549batch [00:02, 245.91batch/s]                         


[Loss:	2.264]


549batch [00:02, 246.24batch/s]                         


[Loss:	2.264]


549batch [00:02, 246.30batch/s]                         

[Loss:	2.264]





# Evaluation

In [96]:
import csv
import random
import sys

import torch
from torcheval.metrics import MulticlassF1Score, MulticlassPrecision, MulticlassRecall

from document_segmentation.pagexml.datamodel.label import Label

writer = csv.DictWriter(
    sys.stdout,
    fieldnames=("Predicted", "Actual", "Lines", "Scores", "Types"),
    delimiter="\t",
)

writer.writeheader()

precision_metric = MulticlassPrecision(average=None, num_classes=len(Label))
recall_metric = MulticlassRecall(average=None, num_classes=len(Label))
f1_metric = MulticlassF1Score(average=None, num_classes=len(Label))

for batch in test_data.batches(BATCH_SIZE):
    predicted = model(batch)
    labels = batch.labels()
    types = [region.types for region in batch.regions()]

    _labels = torch.Tensor([label.value - 1 for label in labels]).to(int)
    precision_metric.update(predicted, _labels)
    recall_metric.update(predicted, _labels)
    f1_metric.update(predicted, _labels)

    # print every nth line:
    if random.random() > 0.9:
        for region, pred, label, types in zip(
            batch.regions(), predicted, labels, types
        ):
            pred_label = Label(pred.argmax().item() + 1)
            writer.writerow(
                {
                    "Predicted": pred_label.name,
                    "Actual": label.name,
                    "Lines": region.lines,
                    "Scores": str(pred.tolist()),
                    "Types": str(types),
                }
            )

Predicted	Actual	Lines	Scores	Types
BEGIN	BEGIN	('Ed: Erntfeste. Achtbare. Wijse:', 'voorsienige seer discrete heeren,')	[0.8025423288345337, 5.298455107549671e-06, 0.19745247066020966, 5.164666472268209e-10]	(<RegionType.PARAGRAPH: 'paragraph'>, <RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>)
IN	IN	"('verdrukte goed regt te laaten weedervaaren, dog seer', 'oneijgen oordeelende, dat alle die ingehoudene', 'maandgelden ten laste van de Comp:e soude komen;', ""is zulx meede g'inpro, over zulx ook g'arresteerd, gem: Gouvern:r Abele¬"", 'hebbende gagie ensz. „ ven insgelijks op te leggen, niet alleen, aan ged:te', 'te', 'Elin te vergoeden, zijne stil gestaan hebbende Gagie„', 'kostgelden en Emolumenten, gereekent van den tijd', 'sijner afsettinge, tot desselfs overkomste alhier, die', 'men teffens heeft gereserveerd gelaaten, zijne Actie we„', '„gens geleede Injurien, smaat en schaade, maar ook

In [97]:
writer = csv.DictWriter(
    sys.stdout, fieldnames=["Metric"] + [label.name for label in Label], delimiter="\t"
)
writer.writeheader()

for metric in (precision_metric, recall_metric, f1_metric):
    scores = {
        label.name: f"{score:.4f}"
        for label, score in zip(Label, metric.compute().tolist())
    }
    writer.writerow({"Metric": metric.__class__.__name__} | scores)



Metric	BEGIN	IN	END	OUT
MulticlassPrecision	0.7842	0.7793	0.8668	0.0000
MulticlassRecall	0.8117	0.8783	0.7248	0.0000
MulticlassF1Score	0.7977	0.8258	0.7895	0.0000
