In [1]:
%load_ext autoreload
%autoreload now

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "MIG-08137aa2-e69b-5e74-8390-7997329b1336"
# os.environ["WORLD_SIZE"] = "1"

# Download and convert data

In [7]:
from tqdm import tqdm

from document_segmentation.pagexml.annotations.renate_analysis import RenateAnalysis
from document_segmentation.settings import RENATE_ANALYSIS_DIR

N = None

RENATE_ANALYSIS_DIR.mkdir(parents=True, exist_ok=True)

sheet = RenateAnalysis()


existing_docs = {
    path.stem for path in RENATE_ANALYSIS_DIR.glob("Globdoc_*.json") if path.is_file()
}

for document in tqdm(
    sheet.to_documents(n=N, skip_ids=existing_docs),
    total=(N or len(sheet)) - len(existing_docs),
    desc="Writing documents",
    unit="doc",
):
    document_file = RENATE_ANALYSIS_DIR / f"{document.id}.json"

    with document_file.open("xt") as f:
        f.write(document.model_dump_json())
        f.write("\n")

Writing documents: 100%|██████████| 78/78 [03:02<00:00,  2.34s/doc]


In [6]:
from tqdm import tqdm

from document_segmentation.pagexml.annotations.renate_analysis import RenateAnalysisInv
from document_segmentation.settings import RENATE_ANALYSIS_DIR, RENATE_ANALYSIS_SHEETS

N = None


sheet = RenateAnalysisInv(RENATE_ANALYSIS_SHEETS[0])  # TODO: use both sheets

for document in tqdm(sheet.to_documents(n=N), desc="Writing documents", unit="doc"):
    document_file = RENATE_ANALYSIS_DIR / f"{document.id}.json"

    with document_file.open("xt") as f:
        f.write(document.model_dump_json())
        f.write("\n")

Writing documents:   0%|          | 0/664 [00:00<?, ?doc/s]

Writing documents:   4%|▍         | 26/664 [00:20<08:26,  1.26doc/s]


# Load Data

In [None]:
%autoreload now

In [8]:
TRAINING_DATA = 0.8

In [9]:
from document_segmentation.model.dataset import PageDataset
from document_segmentation.settings import MIN_REGION_TEXT_LENGTH

dataset = PageDataset.from_dir(RENATE_ANALYSIS_DIR).remove_short_regions(
    MIN_REGION_TEXT_LENGTH
)
len(dataset)

Reading JSON files: 100%|██████████| 104/104 [00:00<00:00, 152.02file/s]


2184

In [10]:
dataset._class_counts()

Counter({<Label.IN: 1>: 1907,
         <Label.BEGIN: 0>: 104,
         <Label.END: 2>: 100,
         <Label.OUT: 3>: 73})

In [11]:
dataset.class_weights()

[20.8, 1.1446540880503144, 21.623762376237625, 29.513513513513512]

In [12]:
split = int(len(dataset) * TRAINING_DATA)

training_data = dataset[:split]
training_data._class_counts()

Counter({<Label.IN: 1>: 1533,
         <Label.BEGIN: 0>: 77,
         <Label.END: 2>: 75,
         <Label.OUT: 3>: 62})

In [13]:
test_data = dataset[split:]
test_data._class_counts()

Counter({<Label.IN: 1>: 374,
         <Label.BEGIN: 0>: 27,
         <Label.END: 2>: 25,
         <Label.OUT: 3>: 11})

# Train Model

In [14]:
import torch

BATCH_SIZE = 32
EPOCHS = 10
WEIGHTS = torch.Tensor(dataset.class_weights())  # For an imbalanced dataset

In [15]:
%autoreload now

In [16]:
from document_segmentation.model.page_sequence_tagger import PageSequenceTagger

tagger = PageSequenceTagger()

In [17]:
tagger._device

'mps'

In [18]:
tagger

PageSequenceTagger(
  (_page_embedding): PageEmbedding(
    (_region_model): RegionEmbeddingSentenceTransformer(
      (_transformer_model): SentenceTransformer(
        (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: RobertaModel 
        (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})
      )
      (_region_type): Embedding(9, 16)
      (_linear): Linear(in_features=784, out_features=512, bias=True)
    )
    (_rnn): LSTM(512, 256, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)
    (_linear): Linear(in_features=512, out_features=256, bias=True)
  )
  (_rnn): LSTM(256, 256, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)
  (_linear): Linear(in_features=512, out_features=4, bias=True)
  (_soft

In [24]:
tagger.train_(training_data, EPOCHS, BATCH_SIZE, WEIGHTS.to(tagger._device))

  0%|          | 0/54.59375 [00:00<?, ?batch/s]

101%|██████████| 55/54.59375 [00:04<00:00, 11.73batch/s]


Current allocated memory (MPS): 1378 MB
Driver allocated memory (MPS): 3060 MB
[Loss:	1.054]


101%|██████████| 55/54.59375 [00:04<00:00, 13.37batch/s]


Current allocated memory (MPS): 1378 MB
Driver allocated memory (MPS): 3070 MB
[Loss:	1.085]


101%|██████████| 55/54.59375 [00:04<00:00, 13.61batch/s]


Current allocated memory (MPS): 1378 MB
Driver allocated memory (MPS): 3070 MB
[Loss:	1.098]


101%|██████████| 55/54.59375 [00:04<00:00, 13.21batch/s]


Current allocated memory (MPS): 1378 MB
Driver allocated memory (MPS): 3070 MB
[Loss:	1.109]


101%|██████████| 55/54.59375 [00:04<00:00, 13.74batch/s]


Current allocated memory (MPS): 1378 MB
Driver allocated memory (MPS): 3070 MB
[Loss:	1.116]


101%|██████████| 55/54.59375 [00:04<00:00, 13.54batch/s]


Current allocated memory (MPS): 1378 MB
Driver allocated memory (MPS): 3070 MB
[Loss:	1.097]


101%|██████████| 55/54.59375 [00:04<00:00, 13.56batch/s]


Current allocated memory (MPS): 1378 MB
Driver allocated memory (MPS): 3070 MB
[Loss:	1.079]


101%|██████████| 55/54.59375 [00:04<00:00, 13.56batch/s]


Current allocated memory (MPS): 1378 MB
Driver allocated memory (MPS): 3070 MB
[Loss:	1.054]


101%|██████████| 55/54.59375 [00:03<00:00, 13.83batch/s]


Current allocated memory (MPS): 1378 MB
Driver allocated memory (MPS): 3070 MB
[Loss:	1.144]


101%|██████████| 55/54.59375 [00:04<00:00, 13.45batch/s]

Current allocated memory (MPS): 1378 MB
Driver allocated memory (MPS): 3070 MB
[Loss:	1.250]





# Evaluate Model

In [27]:
import csv
import sys

from torcheval.metrics import (
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)
from tqdm import tqdm

from document_segmentation.pagexml.datamodel.label import Label

writer = csv.DictWriter(
    sys.stdout,
    fieldnames=("Predicted", "Actual", "Page ID", "Text", "Scores"),
    delimiter="\t",
)

writer.writeheader()

accuracy = MulticlassAccuracy(num_classes=len(Label))
precision = MulticlassPrecision(average=None, num_classes=len(Label))
recall = MulticlassRecall(average=None, num_classes=len(Label))
f1_score = MulticlassF1Score(average=None, num_classes=len(Label))

for batch in tqdm(
    test_data.batches(BATCH_SIZE), total=len(test_data) / BATCH_SIZE, unit="batch"
):
    predicted = tagger(batch)
    labels = batch.labels()

    _labels = torch.Tensor([label.value for label in labels]).to(int)
    accuracy.update(predicted, _labels)
    precision.update(predicted, _labels)
    recall.update(predicted, _labels)
    f1_score.update(predicted, _labels)

    for page, pred, label in zip(batch.pages, predicted, labels):
        pred_label = Label(pred.argmax().item())
        # if pred_label != Label.IN or label != Label.IN:
        writer.writerow(
            {
                "Predicted": pred_label.name,
                "Actual": label.name,
                "Page ID": page.doc_id,
                "Text": page.text(delimiter="; ")[:50],
                "Scores": str(pred.tolist()),
            }
        )

Predicted	Actual	Page ID	Text	Scores


  7%|▋         | 1/13.65625 [00:00<00:01,  8.55batch/s]

IN	IN	NL-HaNA_1.04.02_3524_0970.jpg	Een Praurmaijang en arriveerd en den 4 te Mampau„;	[0.1231585219502449, 0.6661331057548523, 0.20988833904266357, 0.0008200352895073593]
IN	IN	NL-HaNA_1.04.02_3524_0971.jpg	met den dogter van den Panumbahan aldaar ge„; stro	[0.13595041632652283, 0.6782530546188354, 0.18533653020858765, 0.0004599464882630855]
IN	IN	NL-HaNA_1.04.02_3524_0972.jpg	Raadsaam was dat de Pangerang zijn schoonzoon; zel	[0.14104604721069336, 0.6879023313522339, 0.17063266038894653, 0.000418944691773504]
IN	IN	NL-HaNA_1.04.02_3524_0973.jpg	dat hem dit zeerlieff was dat de Landakers affalli	[0.1505967080593109, 0.6758900284767151, 0.1730610728263855, 0.00045220437459647655]
IN	IN	NL-HaNA_1.04.02_3524_0974.jpg	Raadsaam was dat de Pangerang zijn schoonzoon; zel	[0.15059442818164825, 0.6761804819107056, 0.1727742999792099, 0.0004507791600190103]
IN	IN	NL-HaNA_1.04.02_3524_0975.jpg	dat hem dit zeerlieff was dat de Landakers affalli	[0.15530365705490112, 0.6721768975257874, 0.17207029

 22%|██▏       | 3/13.65625 [00:00<00:00, 13.72batch/s]

IN	BEGIN	NL-HaNA_1.04.02_1647_0632.jpg	Van Macassar Anno 1701—; Vervolgens quam by den He	[0.12395356595516205, 0.6638548970222473, 0.21137934923171997, 0.0008122037397697568]
IN	IN	NL-HaNA_1.04.02_1647_0633.jpg	Van Macassar Anno 1701; Van Macassar Anno 1701.; a	[0.13473038375377655, 0.6803565621376038, 0.1844596564769745, 0.00045336579205468297]
IN	IN	NL-HaNA_1.04.02_1647_0634.jpg	Van Macassar Anno 1701; Van Macassar Anno 1701; va	[0.140029177069664, 0.6896322965621948, 0.16992247104644775, 0.00041610473999753594]
IN	IN	NL-HaNA_1.04.02_1647_0635.jpg	Van Macassar Anno 1701; Van Macassar Anno 1701; ve	[0.1512736678123474, 0.6765015125274658, 0.1717754304409027, 0.0004493932065088302]
IN	IN	NL-HaNA_1.04.02_1647_0636.jpg	Van Macassar Anno 1701; Van Macassar A„o 1701; het	[0.1493915617465973, 0.6775469183921814, 0.172615647315979, 0.00044592225458472967]
IN	IN	NL-HaNA_1.04.02_1647_0637.jpg	Van Macassar Anno 1701; Van Macassar A„o 1701.; mo	[0.1560990810394287, 0.6709672808647156, 0.1724813

 44%|████▍     | 6/13.65625 [00:00<00:00, 16.61batch/s]

IN	IN	NL-HaNA_1.04.02_1088_0515.jpg	ons voor antwoort, dat wel waar was, dat sijn Vade	[0.12300992757081985, 0.6676484942436218, 0.2085222452878952, 0.0008192991372197866]
IN	IN	NL-HaNA_1.04.02_1088_0516.jpg	noch al wel toeginck, ende mijn hier van noch vrij	[0.1371178925037384, 0.6777298450469971, 0.18468590080738068, 0.0004663012514356524]
IN	IN	NL-HaNA_1.04.02_1088_0517.jpg	blijcken wat parthij hij hielt van dese uijr affso	[0.14017949998378754, 0.6881232261657715, 0.17127247154712677, 0.0004248098412062973]
IN	IN	NL-HaNA_1.04.02_1088_0518.jpg	dus lange was gedaen, hebben den raet dit in Beden	[0.15155616402626038, 0.673597514629364, 0.17435970902442932, 0.00048659005551598966]
IN	IN	NL-HaNA_1.04.02_1088_0519.jpg	te mogen werden, met cruijt ende Loot, ende soo he	[0.14337728917598724, 0.664137065410614, 0.19196680188179016, 0.0005187909700907767]
IN	IN	NL-HaNA_1.04.02_1088_0520.jpg	Den 9=en d=o smorgens, quamen voor Oerien, ofte no	[0.15297247469425201, 0.6291505098342896, 0.2173487

 59%|█████▊    | 8/13.65625 [00:00<00:00, 14.59batch/s]

IN	IN	NL-HaNA_1.04.02_1509_1562.jpg	464. persoonen p:r transport; Namen Toenamen en Ge	[0.12717807292938232, 0.6669479608535767, 0.20509423315525055, 0.0007797127473168075]
IN	IN	NL-HaNA_1.04.02_1509_1563.jpg	p:r coopm: en opperh:t van; tegenep: en p:r novo: 	[0.20447637140750885, 0.6038725972175598, 0.19107909500598907, 0.0005719168693758547]
IN	IN	NL-HaNA_1.04.02_1509_1564.jpg	504. persoonen p=r Transport; Namen, Toenamen, en 	[0.1593000441789627, 0.6723264455795288, 0.16791807115077972, 0.0004554121114779264]
IN	IN	NL-HaNA_1.04.02_1509_1565.jpg	w:t p:r met wat schip in; presente qualitijt maent	[0.16074635088443756, 0.6690497398376465, 0.16975799202919006, 0.00044589844765141606]
IN	IN	NL-HaNA_1.04.02_1509_1566.jpg	Op Nagapatnam zijn nogh bescheijden d'volgende; Bo	[0.16036903858184814, 0.6660071015357971, 0.17316178977489471, 0.0004620686231646687]
IN	END	NL-HaNA_1.04.02_1509_1567.jpg	103 en 84. persoonen p:r transport; 1. toatongoe v	[0.1625780314207077, 0.6644097566604614, 0.1725

 73%|███████▎  | 10/13.65625 [00:00<00:00, 15.65batch/s]

IN	IN	NL-HaNA_1.04.02_8820_0075.jpg	Van Cormandel onder 24: November 1702.; ten dienst	[0.12715326249599457, 0.6582581996917725, 0.21381030976772308, 0.0007782831089571118]
IN	IN	NL-HaNA_1.04.02_8820_0076.jpg	Van Cormandel onder 24: November 1702.; de allermi	[0.13509301841259003, 0.6820201873779297, 0.18245910108089447, 0.0004277309635654092]
IN	END	NL-HaNA_1.04.02_8820_0077.jpg	Van Cormandel onder 24: November 1702.; gevolgen, 	[0.13931158185005188, 0.6907104253768921, 0.1695808470249176, 0.0003971432743128389]
IN	BEGIN	NL-HaNA_1.04.02_2682_0249.jpg	Met het ondergenoemde schip; vertrekken over China	[0.1499280035495758, 0.6791777014732361, 0.17046895623207092, 0.0004253980587236583]
IN	BEGIN	NL-HaNA_1.04.02_3095_0015.jpg	Register der Papieren; werdende versonden per het 	[0.1480630487203598, 0.6783959865570068, 0.17311644554138184, 0.0004244910378474742]
IN	IN	NL-HaNA_1.04.02_3095_0016.jpg	4.; orig: in genaagt, a:o p„o; d'Edele Groot Agtba	[0.15643933415412903, 0.6712308526039124, 0.

 88%|████████▊ | 12/13.65625 [00:00<00:00, 16.48batch/s]

IN	IN	NL-HaNA_1.04.02_3248_0899.jpg	Waar mede den relatant dit zijne gegevene relaas c	[0.16347582638263702, 0.5671754479408264, 0.26833581924438477, 0.0010129263391718268]
BEGIN	IN	NL-HaNA_1.04.02_3248_0900.jpg		[0.404065877199173, 0.30468520522117615, 0.29047891497612, 0.0007700325222685933]
IN	IN	NL-HaNA_1.04.02_3248_0901.jpg	Op heden den 8:e Februarij 1768: Compa„; „reerde v	[0.2528277635574341, 0.5486882328987122, 0.19793754816055298, 0.0005464744754135609]
IN	END	NL-HaNA_1.04.02_3248_0902.jpg	waar mede den relatant zijn gegevene relaas quam t	[0.24869173765182495, 0.5353215932846069, 0.21540138125419617, 0.0005852750036865473]
IN	BEGIN	NL-HaNA_1.04.02_1547_0107.jpg	Extracten uijt de Daagelijkse Aanteeckeningen, Con	[0.24311599135398865, 0.5059977769851685, 0.2502225935459137, 0.0006636455073021352]
IN	END	NL-HaNA_1.04.02_1547_0108.jpg	op desen holvn en dass: matthijs in ve totenende a	[0.2662496566772461, 0.4181278347969055, 0.3150073289871216, 0.0006151841371320188]
BEGIN	OUT	NL

103%|██████████| 14/13.65625 [00:00<00:00, 15.93batch/s]

IN	IN	NL-HaNA_1.04.02_8696_0057.jpg	Van Siam onder dato 20: april 1737; vergadert was 	[0.1308359056711197, 0.6738881468772888, 0.1945958137512207, 0.0006801239214837551]
IN	IN	NL-HaNA_1.04.02_8696_0058.jpg	Van Siam onder dato: 20: april 1737; Maart voogd v	[0.12245719879865646, 0.6935535669326782, 0.18355098366737366, 0.0004382397746667266]
IN	IN	NL-HaNA_1.04.02_8696_0059.jpg	Van Siam onder dato: 20: april 1737; houden dit go	[0.1412268877029419, 0.6839978098869324, 0.17438267171382904, 0.0003925894561689347]
IN	IN	NL-HaNA_1.04.02_8696_0060.jpg	Siam onder dato: 20: april 1737; an; kittelde en h	[0.1528184711933136, 0.6743114590644836, 0.17245063185691833, 0.00041942019015550613]
IN	IN	NL-HaNA_1.04.02_8696_0061.jpg	Van Siam onder dato: 20:' april 1737; Excuseert, m	[0.15676839649677277, 0.6717154383659363, 0.1710730344057083, 0.0004431256093084812]
IN	IN	NL-HaNA_1.04.02_8696_0062.jpg	an Siam onder dato: 20: april 1737; kittelde en he	[0.15608811378479004, 0.6699219346046448, 0.17355549




In [26]:
writer = csv.DictWriter(
    sys.stdout,
    fieldnames=["Metric"] + [label.name for label in Label],
    delimiter="\t",
)
writer.writeheader()

for metric in (precision, recall, f1_score):
    scores = {
        label.name: f"{score:.4f}"
        for label, score in zip(Label, metric.compute().tolist())
    }
    writer.writerow({"Metric": metric.__class__.__name__} | scores)

print(f"Accuracy ({accuracy.average} average):\t{accuracy.compute().item():.4f}")



Metric	BEGIN	IN	END	OUT
MulticlassPrecision	0.0000	0.8747	0.0789	0.0000
MulticlassRecall	0.0000	0.8770	0.1200	0.0000
MulticlassF1Score	0.0000	0.8758	0.0952	0.0000
Accuracy (micro average):	0.7574
