In [1]:
%load_ext autoreload
%autoreload now

# Download and convert data

In [2]:
from tqdm import tqdm
from document_segmentation.pagexml.generale_missiven import GeneraleMissiven
from document_segmentation.settings import (
    GENERALE_MISSIVEN_DOCUMENT_DIR,
    GENERALE_MISSIVEN_SHEET,
)

N = None

GENERALE_MISSIVEN_DOCUMENT_DIR.mkdir(parents=True, exist_ok=True)

sheet = GeneraleMissiven(GENERALE_MISSIVEN_SHEET)

existing_docs = {
    path.stem
    for path in GENERALE_MISSIVEN_DOCUMENT_DIR.glob("*.json")
    if path.is_file()
}

for document in tqdm(
    sheet.to_documents(n=N, skip_ids=existing_docs),
    total=(N or len(sheet)) - len(existing_docs),
    desc="Writing documents",
    unit="doc",
):
    document_file = GENERALE_MISSIVEN_DOCUMENT_DIR / f"{document.id}.json"

    with document_file.open("xt") as f:
        f.write(document.model_dump_json())
        f.write("\n")

Writing documents:   0%|          | 0/5 [00:00<?, ?doc/s]

Skipping row with inventory number 1171 due to status message: 'Niet gedigitaliseerd.'
Skipping row with inventory number 2770 due to status message: 'Niet gedigitaliseerd.'
Skipping row with inventory number 2770 due to status message: 'Niet gedigitaliseerd.'
Skipping row with inventory number 2770 due to status message: 'Niet gedigitaliseerd.'
Skipping row with inventory number 2911 due to status message: 'Niet gedigitaliseerd.'





# Load Data from Disk

In [3]:
from document_segmentation.model.dataset import PageDataset

dataset = PageDataset.from_dir(GENERALE_MISSIVEN_DOCUMENT_DIR)
len(dataset)

Reading JSON files: 100%|██████████| 909/909 [01:08<00:00, 13.37file/s]


191146

In [4]:
dataset[5000]

Page(label=<Label.IN: 2>, regions=[Region(id='region_c62b09b5-3b73-455f-bb44-2c07ece8fe82_3', types=(<RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>, <RegionType.PAGE_NUMBER: 'page-number'>), coordinates=((66, 671), (63, 674), (66, 677), (70, 677), (73, 674), (70, 671)), lines=()), Region(id='region_72e9d1bd-256c-4b08-a65a-bafa26c4d572_4', types=(<RegionType.PHYSICAL_STRUCTURE_DOC: 'physical_structure_doc'>, <RegionType.TEXT_REGION: 'text_region'>, <RegionType.PAGEXML_DOC: 'pagexml_doc'>, <RegionType.PAGE_NUMBER: 'page-number'>), coordinates=((2550, 244), (2544, 237), (2544, 234), (2541, 234), (2534, 228), (2531, 228), (2528, 225), (2493, 225), (2490, 228), (2477, 228), (2474, 231), (2462, 231), (2458, 234), (2455, 231), (2389, 231), (2386, 234), (2364, 234), (2360, 237), (2357, 237), (2338, 256), (2338, 259), (2335, 263), (2335, 266), (2332, 269), (2332, 272), (2329, 275), (2329, 285), (232

In [19]:
training_dataset = dataset[:1000]

In [6]:
test_dataset = dataset[1000:1100]

# Train Model

In [7]:
import logging

logging.basicConfig(level=logging.INFO)

In [8]:
from document_segmentation.model.page_sequence_tagger import PageSequenceTagger

tagger = PageSequenceTagger(device="mps")

  return self.fget.__get__(instance, owner)()
INFO:root:Using device: mps
INFO:root:Moving module 'RegionEmbedding._linear' to device 'mps'
INFO:root:Moving module 'RegionEmbedding._region_embedding' to device 'mps'
INFO:root:Moving module 'RegionEmbedding._transformer_model' to device 'mps'
INFO:root:Using device: mps
INFO:root:Moving module 'PageEmbedding._linear' to device 'mps'
INFO:root:Using device: mps
INFO:root:Moving module 'RegionEmbedding._linear' to device 'mps'
INFO:root:Moving module 'RegionEmbedding._region_embedding' to device 'mps'
INFO:root:Moving module 'RegionEmbedding._transformer_model' to device 'mps'
INFO:root:Moving sub-modules of 'PageEmbedding' to device 'mps'
INFO:root:Moving module 'PageEmbedding._region_model' to device 'mps'
INFO:root:Moving module 'PageEmbedding._rnn' to device 'mps'
INFO:root:Moving module 'PageEmbedding.rnn' to device 'mps'
INFO:root:Using device: mps
INFO:root:Moving module 'PageSequenceTagger._linear' to device 'mps'
INFO:root:Using 

In [9]:
tagger._device

'mps'

In [10]:
tagger

PageSequenceTagger(
  (_page_embedding): PageEmbedding(
    (_region_model): RegionEmbedding(
      (_transformer_model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30500, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (out

In [21]:
from tqdm.autonotebook import tqdm

tagger.train_(training_dataset, epochs=3, weights=dataset.class_weights())

  0%|          | 0/3 [00:00<?, ?epoch/s]

# Evaluate Model

In [12]:
tagger.precision(test_dataset)

MulticlassPrecision:  32%|███▏      | 1/3.125 [00:08<00:17,  8.37s/batch]

[MulticlassPrecision: {'BEGIN': 0.0, 'IN': 1.0, 'END': 0.0}]


MulticlassPrecision:  64%|██████▍   | 2/3.125 [00:16<00:09,  8.04s/batch]

[MulticlassPrecision: {'BEGIN': 0.0, 'IN': 1.0, 'END': 0.0}]


MulticlassPrecision:  96%|█████████▌| 3/3.125 [00:26<00:01,  9.04s/batch]

[MulticlassPrecision: {'BEGIN': 0.0, 'IN': 1.0, 'END': 0.0}]


MulticlassPrecision: 4batch [00:27,  6.90s/batch]                        


[MulticlassPrecision: {'BEGIN': 0.0, 'IN': 1.0, 'END': 0.0}]


{'BEGIN': 0.0, 'IN': 1.0, 'END': 0.0}

In [13]:
tagger.recall(test_dataset)

MulticlassRecall: 4batch [00:00, 36.67batch/s]                        


[MulticlassRecall: {'BEGIN': 0.0, 'IN': 0.9375, 'END': 0.0}]
[MulticlassRecall: {'BEGIN': 0.0, 'IN': 0.875, 'END': 0.0}]
[MulticlassRecall: {'BEGIN': 0.0, 'IN': 0.8645833134651184, 'END': 0.0}]
[MulticlassRecall: {'BEGIN': 0.0, 'IN': 0.8600000143051147, 'END': 0.0}]


{'BEGIN': 0.0, 'IN': 0.8600000143051147, 'END': 0.0}

In [14]:
tagger.f1_score(test_dataset)

MulticlassF1Score:  64%|██████▍   | 2/3.125 [00:00<00:00, 17.05batch/s]

[MulticlassF1Score: {'BEGIN': 0.0, 'IN': 0.9677419066429138, 'END': 0.0}]
[MulticlassF1Score: {'BEGIN': 0.0, 'IN': 0.9333333373069763, 'END': 0.0}]


MulticlassF1Score: 4batch [00:00, 25.71batch/s]                        


[MulticlassF1Score: {'BEGIN': 0.0, 'IN': 0.9273743033409119, 'END': 0.0}]
[MulticlassF1Score: {'BEGIN': 0.0, 'IN': 0.9247311949729919, 'END': 0.0}]


{'BEGIN': 0.0, 'IN': 0.9247311949729919, 'END': 0.0}

In [15]:
tagger.accuracy(test_dataset)

  num_correct = mask.new_zeros(num_classes).scatter_(0, target, mask, reduce="add")
MulticlassAccuracy: 4batch [00:00, 35.81batch/s]                        


[MulticlassAccuracy: {'BEGIN': nan, 'IN': 0.9375, 'END': nan}]
[MulticlassAccuracy: {'BEGIN': nan, 'IN': 0.875, 'END': nan}]
[MulticlassAccuracy: {'BEGIN': nan, 'IN': 0.8645833134651184, 'END': nan}]
[MulticlassAccuracy: {'BEGIN': nan, 'IN': 0.8600000143051147, 'END': nan}]


tensor(0.8600)

In [16]:
from document_segmentation.pagexml.datamodel.page import Label


preds = tagger(test_dataset)

print(
    "\t".join(
        ("Page ID", "True Label", "Predicted Label", "Correct?", "Predicted Scores")
    )
)
for page_id, true_label, pred, pred_label in zip(
    test_dataset.doc_ids(),
    test_dataset.labels(),
    preds,
    preds.argmax(dim=1),
    strict=True,
):
    print(
        "\t".join(
            (
                str(page_id),
                true_label.name,
                Label(pred_label.item() + 1).name,
                str(Label(pred_label.item() + 1) == true_label),
                str(pred.tolist()),
            )
        )
    )

Page ID	True Label	Predicted Label	Correct?	Predicted Scores
NL-HaNA_1.04.02_7536_0374.jpg	IN	BEGIN	False	[0.37676846981048584, 0.3392247259616852, 0.2840067744255066]
NL-HaNA_1.04.02_7536_0375.jpg	IN	BEGIN	False	[0.37663406133651733, 0.348863422870636, 0.2745024859905243]
NL-HaNA_1.04.02_7536_0376.jpg	IN	BEGIN	False	[0.37430357933044434, 0.355951726436615, 0.2697446942329407]
NL-HaNA_1.04.02_7536_0377.jpg	IN	BEGIN	False	[0.37243595719337463, 0.36142265796661377, 0.266141414642334]
NL-HaNA_1.04.02_7536_0378.jpg	IN	BEGIN	False	[0.37060606479644775, 0.3651258945465088, 0.26426807045936584]
NL-HaNA_1.04.02_7536_0379.jpg	IN	BEGIN	False	[0.3693685829639435, 0.36769065260887146, 0.26294076442718506]
NL-HaNA_1.04.02_7536_0380.jpg	IN	IN	True	[0.3685202896595001, 0.36943700909614563, 0.2620426416397095]
NL-HaNA_1.04.02_7536_0381.jpg	IN	IN	True	[0.3679761588573456, 0.37064996361732483, 0.2613738477230072]
NL-HaNA_1.04.02_7536_0382.jpg	IN	IN	True	[0.36756062507629395, 0.37146779894828796, 0.26097

In [17]:
from torchview import draw_graph

model_graph = draw_graph(tagger)
print(model_graph.visual_graph)

RuntimeError: Only one of (input_data, input_size) should be specified.