In [1]:
%load_ext autoreload
%autoreload now

# Read Data

In [2]:
TRAINING_SIZE = 0.8
TEST_SIZE = 0.2

MIN_REGION_LENGTH = 20

In [3]:
from document_segmentation.model.dataset import PageDataset
from document_segmentation.settings import GENERALE_MISSIVEN_DOCUMENT_DIR

pages = PageDataset.from_dir(GENERALE_MISSIVEN_DOCUMENT_DIR)

Reading JSON files: 100%|██████████| 909/909 [01:09<00:00, 13.13file/s]


In [4]:
from document_segmentation.model.dataset import RegionDataset


all_regions = RegionDataset.from_page_dataset(pages).remove_empty(MIN_REGION_LENGTH)
len(all_regions)

580439

In [5]:
all_regions._class_counts()

Counter({<Label.IN: 2>: 573225, <Label.END: 3>: 3752, <Label.BEGIN: 1>: 3462})

In [6]:
from document_segmentation.pagexml.datamodel.label import Label

sample_size = all_regions._class_counts()[Label.END]
regions = all_regions.balance(sample_size).shuffle()

len(regions)

10966

In [7]:
regions._class_counts()

Counter({<Label.IN: 2>: 3752, <Label.END: 3>: 3752, <Label.BEGIN: 1>: 3462})

In [8]:
regions.class_weights()

[3.166618538839157, 2.9219291233679723, 2.9219291233679723, 10966.0]

In [9]:
split = int(len(regions) * TRAINING_SIZE)

training_data = regions[:split]
test_data = regions[split:]

In [10]:
training_data._class_counts()

Counter({<Label.END: 3>: 3026, <Label.IN: 2>: 3013, <Label.BEGIN: 1>: 2733})

In [11]:
test_data._class_counts()

Counter({<Label.IN: 2>: 739, <Label.BEGIN: 1>: 729, <Label.END: 3>: 726})

# Train Model

In [12]:
BATCH_SIZE = 16
EPOCHS = 50

DEVICE = "mps"

In [13]:
from document_segmentation.model.region_classifier import RegionClassifier

model = RegionClassifier(device=DEVICE)
model

RegionClassifier(
  (_transformer_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30500, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

In [18]:
model.train_(training_data, EPOCHS, BATCH_SIZE, regions.class_weights())

549batch [00:10, 51.22batch/s]                         


[Loss:	2.401]


549batch [00:07, 75.41batch/s]                          


[Loss:	2.388]


549batch [00:05, 94.44batch/s]                          


[Loss:	2.378]


549batch [00:08, 63.01batch/s]                          


[Loss:	2.371]


549batch [00:05, 96.49batch/s]                          


[Loss:	2.363]


549batch [00:03, 148.86batch/s]                         


[Loss:	2.357]


549batch [00:04, 134.42batch/s]                         


[Loss:	2.350]


549batch [00:10, 53.25batch/s]                         


[Loss:	2.344]


549batch [00:08, 63.35batch/s]                          


[Loss:	2.339]


549batch [00:09, 58.38batch/s]                          


[Loss:	2.334]


549batch [00:08, 64.77batch/s]                         


[Loss:	2.330]


549batch [00:08, 62.90batch/s]                         


[Loss:	2.327]


549batch [00:07, 76.46batch/s]                          


[Loss:	2.325]


549batch [00:07, 69.52batch/s]                         


[Loss:	2.323]


549batch [00:08, 62.65batch/s]                         


[Loss:	2.320]


549batch [00:09, 55.06batch/s]                         


[Loss:	2.318]


549batch [00:06, 80.45batch/s]                          


[Loss:	2.316]


549batch [00:03, 158.14batch/s]                         


[Loss:	2.313]


549batch [00:05, 92.03batch/s]                          


[Loss:	2.311]


549batch [00:08, 62.32batch/s]                          

[Loss:	2.309]





# Evaluation

In [25]:
import csv
import random
import sys

import torch
from torcheval.metrics import MulticlassF1Score, MulticlassPrecision, MulticlassRecall

from document_segmentation.pagexml.datamodel.label import Label

writer = csv.DictWriter(
    sys.stdout,
    fieldnames=("Predicted", "Actual", "Lines", "Scores", "Types"),
    delimiter="\t",
)

writer.writeheader()

precision_metric = MulticlassPrecision(average=None, num_classes=len(Label))
recall_metric = MulticlassRecall(average=None, num_classes=len(Label))
f1_metric = MulticlassF1Score(average=None, num_classes=len(Label))

for batch in test_data.batches(BATCH_SIZE):
    predicted = model(batch)
    labels = batch.labels()
    types = [region.types for region in batch.regions()]

    _labels = torch.Tensor([label.value - 1 for label in labels]).to(int)
    precision_metric.update(predicted, _labels)
    recall_metric.update(predicted, _labels)
    f1_metric.update(predicted, _labels)

    # print every nth line:
    if random.random() > 0.9:
        for region, pred, label, types in zip(
            batch.regions(), predicted, labels, types
        ):
            pred_label = Label(pred.argmax().item() + 1)
            writer.writerow(
                {
                    "Predicted": pred_label.name,
                    "Actual": label.name,
                    "Lines": region.lines,
                    "Scores": str(pred.tolist()),
                    "Types": str(types),
                }
            )

Predicted	Actual	Lines	Scores	Types
IN	IN	"(""§ 29 Het versoek van 't Opperhoofd van Este, om"", ""de versogte voordragt van 't op na Expiratie van zijn onderkoopmans verband,"", 'uwe welEd: Hoog Agtb: tot de Qualiteid van', 'koopman te worden voorgedragen, is door', 'ons als strijdende met het reglement op de', 'bevordering der', 'Dienaaren zo wel ontzegd, als zijne teffens', 'gedaane instantie, om middelerwijle met dien', 'titul en rang te worden gebeneficeerd. —')"	[0.14257393777370453, 0.846371591091156, 0.011054456233978271, 1.006368165884508e-12]	(<RegionType.PARAGRAPH: 'paragraph'>, <RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>)
IN	IN	('zodanig gebruik gemaakt hebben, als', 'wij na de gesteldheid van tyden en', 'e', 'plaatsen, dienstig hebben geoordeeld.', 'Bij voorschreeven ons secreet schrij„', 'ren van den 2 Febr: ll: sullen uw', 'wel Edele Hoog Achtb: so wij vertrou„', '„wen voo

In [26]:
writer = csv.DictWriter(
    sys.stdout, fieldnames=["Metric"] + [label.name for label in Label], delimiter="\t"
)
writer.writeheader()

for metric in (precision_metric, recall_metric, f1_metric):
    scores = {
        label.name: f"{score:.4f}"
        for label, score in zip(Label, metric.compute().tolist())
    }
    writer.writerow({"Metric": metric.__class__.__name__} | scores)



Metric	BEGIN	IN	END	OUT
MulticlassPrecision	0.7871	0.7390	0.7719	0.0000
MulticlassRecall	0.6845	0.8850	0.7176	0.0000
MulticlassF1Score	0.7322	0.8054	0.7438	0.0000
