In [1]:
%load_ext autoreload
%autoreload now

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "MIG-08137aa2-e69b-5e74-8390-7997329b1336"
# os.environ["WORLD_SIZE"] = "1"

# Download and convert data

In [4]:
from tqdm import tqdm

from document_segmentation.pagexml.annotations.renate_analysis import RenateAnalysis
from document_segmentation.settings import RENATE_ANALYSIS_DIR

N = None

RENATE_ANALYSIS_DIR.mkdir(parents=True, exist_ok=True)

sheet = RenateAnalysis()


existing_docs = {
    path.stem for path in RENATE_ANALYSIS_DIR.glob("Globdoc_*.json") if path.is_file()
}

for document in tqdm(
    sheet.to_documents(n=N, skip_ids=existing_docs),
    total=(N or len(sheet)) - len(existing_docs),
    desc="Writing documents",
    unit="doc",
):
    document_file = RENATE_ANALYSIS_DIR / f"{document.id}.json"

    with document_file.open("xt") as f:
        f.write(document.model_dump_json())
        f.write("\n")

Writing documents: 100%|██████████| 78/78 [03:15<00:00,  2.51s/doc]


In [5]:
from tqdm import tqdm

from document_segmentation.pagexml.annotations.renate_analysis import RenateAnalysisInv
from document_segmentation.settings import RENATE_ANALYSIS_DIR, RENATE_ANALYSIS_SHEETS

N = None


sheet = RenateAnalysisInv(RENATE_ANALYSIS_SHEETS[0])  # TODO: use both sheets

existing_docs = {
    path.stem
    for path in RENATE_ANALYSIS_DIR.glob("NL-HaNA_1.04.02_*.json")
    if path.is_file()
}

for document in tqdm(
    sheet.to_documents(n=N, skip_ids=existing_docs),
    total=(N or len(sheet)) - len(existing_docs),  # FIXME: len(sheet) is wrong
    desc="Writing documents",
    unit="doc",
):
    document_file = RENATE_ANALYSIS_DIR / f"{document.id}.json"

    with document_file.open("xt") as f:
        f.write(document.model_dump_json())
        f.write("\n")

Writing documents:   0%|          | 0/690 [00:00<?, ?doc/s]

Writing documents:   4%|▍         | 26/690 [00:11<04:46,  2.32doc/s]


# Load Data

In [6]:
%autoreload now

In [7]:
TRAINING_DATA = 0.8

In [8]:
from document_segmentation.model.dataset import PageDataset
from document_segmentation.settings import MIN_REGION_TEXT_LENGTH

dataset = PageDataset.from_dir(RENATE_ANALYSIS_DIR).remove_short_regions(
    MIN_REGION_TEXT_LENGTH
)
len(dataset)

Reading JSON files: 100%|██████████| 104/104 [00:01<00:00, 86.56file/s] 


2184

In [9]:
dataset._class_counts()

Counter({<Label.IN: 2>: 1907,
         <Label.BEGIN: 1>: 104,
         <Label.END: 3>: 100,
         <Label.OUT: 4>: 73})

In [10]:
dataset.class_weights()

[20.8, 1.1446540880503144, 21.623762376237625, 29.513513513513512]

In [12]:
split = int(len(dataset) * TRAINING_DATA)

training_data = dataset[:split]
training_data._class_counts()

Counter({<Label.IN: 2>: 1533,
         <Label.BEGIN: 1>: 77,
         <Label.END: 3>: 75,
         <Label.OUT: 4>: 62})

In [13]:
test_data = dataset[split:]
test_data._class_counts()

Counter({<Label.IN: 2>: 374,
         <Label.BEGIN: 1>: 27,
         <Label.END: 3>: 25,
         <Label.OUT: 4>: 11})

# Train Model

In [14]:
import torch

BATCH_SIZE = 32
EPOCHS = 3
WEIGHTS = torch.Tensor(dataset.class_weights())  # For an imbalanced dataset

In [15]:
%autoreload now

In [16]:
from document_segmentation.model.page_sequence_tagger import PageSequenceTagger

tagger = PageSequenceTagger()

In [17]:
tagger._device

'mps'

In [18]:
tagger

PageSequenceTagger(
  (_page_embedding): PageEmbedding(
    (_region_model): RegionEmbeddingSentenceTransformer(
      (_transformer_model): SentenceTransformer(
        (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: RobertaModel 
        (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})
      )
      (_region_type): Embedding(9, 16)
      (_linear): Linear(in_features=784, out_features=512, bias=True)
    )
    (_rnn): LSTM(512, 256, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)
    (_linear): Linear(in_features=512, out_features=256, bias=True)
  )
  (_rnn): LSTM(256, 256, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)
  (_linear): Linear(in_features=512, out_features=4, bias=True)
  (_soft

In [19]:
tagger.train_(training_data, EPOCHS, BATCH_SIZE, WEIGHTS.to(tagger._device))

  full_bar = Bar(frac,
101%|██████████| 55/54.59375 [01:31<00:00,  1.66s/batch]


Current allocated memory (MPS): 1170 MB
Driver allocated memory (MPS): 2887 MB
[Loss:	1.163]


101%|██████████| 55/54.59375 [00:03<00:00, 13.83batch/s]


Current allocated memory (MPS): 1162 MB
Driver allocated memory (MPS): 2967 MB
[Loss:	1.165]


101%|██████████| 55/54.59375 [00:03<00:00, 13.88batch/s]

Current allocated memory (MPS): 1162 MB
Driver allocated memory (MPS): 2851 MB
[Loss:	1.095]





# Evaluate Model

In [20]:
import csv
import sys

from torcheval.metrics import (
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)
from tqdm import tqdm

from document_segmentation.pagexml.datamodel.label import Label

writer = csv.DictWriter(
    sys.stdout,
    fieldnames=("Predicted", "Actual", "Page ID", "Text", "Scores"),
    delimiter="\t",
)

writer.writeheader()

accuracy = MulticlassAccuracy(num_classes=len(Label))
precision = MulticlassPrecision(average=None, num_classes=len(Label))
recall = MulticlassRecall(average=None, num_classes=len(Label))
f1_score = MulticlassF1Score(average=None, num_classes=len(Label))

for batch in tqdm(
    test_data.batches(BATCH_SIZE), total=len(test_data) / BATCH_SIZE, unit="batch"
):
    predicted = tagger(batch)
    labels = batch.labels()

    _labels = torch.Tensor([label.value - 1 for label in labels]).to(int)
    accuracy.update(predicted, _labels)
    precision.update(predicted, _labels)
    recall.update(predicted, _labels)
    f1_score.update(predicted, _labels)

    for page, pred, label in zip(batch.pages, predicted, labels):
        pred_label = Label(pred.argmax().item() + 1)
        if pred_label != Label.IN or label != Label.IN:
            writer.writerow(
                {
                    "Predicted": pred_label.name,
                    "Actual": label.name,
                    "Page ID": page.doc_id,
                    "Text": page.text(delimiter="; ")[:50],
                    "Scores": str(pred.tolist()),
                }
            )

Predicted	Actual	Page ID	Text	Scores


  7%|▋         | 1/13.65625 [00:00<00:10,  1.21batch/s]

END	IN	NL-HaNA_1.04.02_3524_0986.jpg	van 10000 kyserdaalders ter leen zullende hij alle	[0.12512490153312683, 0.30856895446777344, 0.5628307461738586, 0.00347542786039412]
END	IN	NL-HaNA_1.04.02_3524_0987.jpg	Hier meede deesen bekortende maatigen wij ons de; 	[0.06166611984372139, 0.08441950380802155, 0.851412832736969, 0.0025015235878527164]
END	END	NL-HaNA_1.04.02_3524_0988.jpg	Notitie van Zodaanige Coopmansz:; etc: als Sariepr	[0.003870869055390358, 0.005092552397400141, 0.9894911050796509, 0.0015453965170308948]
OUT	OUT	NL-HaNA_1.04.02_1547_0481.jpg		[0.0005603920435532928, 0.0002981974685098976, 0.003644691314548254, 0.9954967498779297]
OUT	OUT	NL-HaNA_1.04.02_1547_0482.jpg		[0.0008527957834303379, 0.00024216008023358881, 0.0014517775271087885, 0.9974531531333923]
BEGIN	BEGIN	NL-HaNA_1.04.02_1547_0483.jpg	Saterdagh den 4=en xber; Present den Commandeur; P	[0.9136616587638855, 0.04032803326845169, 0.03998707979917526, 0.006023257505148649]
BEGIN	IN	NL-HaNA_1.04.02_1547_0484.jpg	na 

 15%|█▍        | 2/13.65625 [00:02<00:18,  1.58s/batch]

IN	END	NL-HaNA_1.04.02_1547_0467.jpg	Waer seven daegen zoo aen mijn lighaem te verschoo	[0.11907272785902023, 0.7566057443618774, 0.12283354252576828, 0.001488027162849903]
IN	BEGIN	NL-HaNA_1.04.02_1268_1095.jpg	Is nae aen roepingh van godes Heijsige naeme; vers	[0.11873476952314377, 0.7482216954231262, 0.13157106935977936, 0.0014724008506163955]
IN	END	NL-HaNA_1.04.02_1268_1109.jpg	g'admitteert en tot godes kerke aengenomen werden 	[0.11895566433668137, 0.7560752034187317, 0.12349596619606018, 0.0014732044655829668]
IN	BEGIN	NL-HaNA_1.04.02_3446_0595.jpg	Aan De Edele Hoog Agtb„e Heeren; Bewindhebberen va	[0.11902810633182526, 0.7544757723808289, 0.12500500679016113, 0.001491079805418849]
IN	END	NL-HaNA_1.04.02_3446_0603.jpg	Batavia; in 't Casteel den; 25:e october; waar mee	[0.12001077830791473, 0.661048412322998, 0.2169722318649292, 0.001968554686754942]
IN	BEGIN	NL-HaNA_1.04.02_2542_0114.jpg	Van Mallabaar onder dato; 9. stux te weeten; 6. ma	[0.13269509375095367, 0.5253822207450867,

 22%|██▏       | 3/13.65625 [00:04<00:17,  1.61s/batch]

IN	BEGIN	NL-HaNA_1.04.02_1647_0632.jpg	Van Macassar Anno 1701—; Vervolgens quam by den He	[0.16252586245536804, 0.6899905204772949, 0.14497889578342438, 0.0025046654045581818]
IN	END	NL-HaNA_1.04.02_1647_0642.jpg	Van Macassar A„o 1701; Van Macassar Anno 1701; had	[0.12201889604330063, 0.7082138061523438, 0.16794952750205994, 0.0018177941674366593]
IN	BEGIN	NL-HaNA_1.04.02_1060_0435.jpg	Alsoo het schip der Goes, als t'Jacht Cleijn Enckh	[0.12901876866817474, 0.7075014114379883, 0.161713108420372, 0.001766662928275764]
END	IN	NL-HaNA_1.04.02_1060_0446.jpg	ge; Alsoo 'T schip Esterre, tot zijnne te doene vo	[0.13430988788604736, 0.41885554790496826, 0.4433428645133972, 0.0034917532466351986]
END	IN	NL-HaNA_1.04.02_1060_0447.jpg	Maendach den xiiije Decemb. @ 1615 &; Alsoo bij de	[0.13999177515506744, 0.24487237632274628, 0.6115007996559143, 0.0036350605078041553]
END	END	NL-HaNA_1.04.02_1060_0448.jpg	Sivert Sipkens sargeant die de voors compe. tot nu	[0.04871416836977005, 0.0617064647376537

 29%|██▉       | 4/13.65625 [00:08<00:23,  2.47s/batch]

IN	BEGIN	NL-HaNA_1.04.02_1506_1034.jpg	-; d; E; 7.; decken; 8; 2; van de; 5; ƒ; 3; E; E; 	[0.18040326237678528, 0.5605207681655884, 0.2548779249191284, 0.004198049660772085]
END	IN	NL-HaNA_1.04.02_1506_1036.jpg	30 vrs; o; Janor; 6; 5o; e; x; 116; E; 6.; 1.; :; 	[0.12800917029380798, 0.25582608580589294, 0.6122598648071289, 0.0039049226325005293]
END	END	NL-HaNA_1.04.02_1506_1037.jpg	k; „noortvelt; rogons; „1; rsame; uijt; eruijt; ƒ;	[0.04888710007071495, 0.05560712888836861, 0.8927388191223145, 0.002766897203400731]
END	BEGIN	NL-HaNA_1.04.02_3060_0043.jpg	Na dat de Leeden deeser Vergaadering bij een geroe	[0.0022203424014151096, 0.0032034567557275295, 0.9931526184082031, 0.0014235622948035598]
OUT	IN	NL-HaNA_1.04.02_3060_0044.jpg		[0.0003995909064542502, 0.0002084755542455241, 0.00244927522726357, 0.9969426989555359]
OUT	IN	NL-HaNA_1.04.02_3060_0045.jpg		[0.00030325588886626065, 0.00010475625458639115, 0.0008695845608599484, 0.9987223744392395]
OUT	IN	NL-HaNA_1.04.02_3060_0046.jpg		[0.

 37%|███▋      | 5/13.65625 [00:09<00:18,  2.12s/batch]

END	IN	NL-HaNA_1.04.02_1088_0519.jpg	te mogen werden, met cruijt ende Loot, ende soo he	[0.1329692304134369, 0.357964426279068, 0.5050461292266846, 0.004020269960165024]
END	IN	NL-HaNA_1.04.02_1088_0520.jpg	Den 9=en d=o smorgens, quamen voor Oerien, ofte no	[0.10437708348035812, 0.19867351651191711, 0.6932567954063416, 0.003692629048600793]
END	IN	NL-HaNA_1.04.02_1088_0521.jpg	mede die van Cabau, den Sergeant aldaer op Hatuha 	[0.01955864578485489, 0.021936854347586632, 0.9564468264579773, 0.0020575968082994223]
END	IN	NL-HaNA_1.04.02_1088_0522.jpg	Sijn voorts doorgepangaijt naer Oma, de corcoiren 	[0.002283766632899642, 0.003132769837975502, 0.9928218722343445, 0.0017616436816751957]
OUT	IN	NL-HaNA_1.04.02_1088_0523.jpg		[0.0006294280756264925, 0.00033357689972035587, 0.00575008150190115, 0.9932869076728821]
OUT	IN	NL-HaNA_1.04.02_1088_0524.jpg		[0.00022018681920599192, 8.433008042629808e-05, 0.0006946068024262786, 0.9990008473396301]
OUT	IN	NL-HaNA_1.04.02_1088_0525.jpg		[0.000205579

 44%|████▍     | 6/13.65625 [00:11<00:14,  1.86s/batch]

IN	END	NL-HaNA_1.04.02_1088_0570.jpg	ende geen meer en haddent, schreeff Helenij aen on	[0.12697729468345642, 0.7508246302604675, 0.12078149616718292, 0.0014165047323331237]
IN	BEGIN	NL-HaNA_1.04.02_8099_0205.jpg	Van Ternaten onder dato 11:' 7ber: 1732; van alle 	[0.11800751090049744, 0.7573683857917786, 0.12315021455287933, 0.001473847427405417]


 51%|█████▏    | 7/13.65625 [00:12<00:11,  1.72s/batch]

END	END	NL-HaNA_1.04.02_8099_0217.jpg	Van Ternaten onder dato 11:' Septemb: 1732; s geli	[0.12939570844173431, 0.41732364892959595, 0.45073649287223816, 0.002544065937399864]
END	BEGIN	NL-HaNA_1.04.02_1070_0199.jpg	2 saeckers Elck van 3000 lb; 2 halve dittos elck v	[0.032241739332675934, 0.05895461514592171, 0.906846821308136, 0.0019568265415728092]
OUT	IN	NL-HaNA_1.04.02_1070_0200.jpg		[0.002627877052873373, 0.0011011192109435797, 0.005175719037652016, 0.9910953044891357]
BEGIN	END	NL-HaNA_1.04.02_1070_0201.jpg	Adriaen gerritsz van utrecht sergiant; marijn Ding	[0.6208497881889343, 0.29111674427986145, 0.08381491154432297, 0.004218594171106815]
IN	BEGIN	NL-HaNA_1.04.02_1509_1538.jpg	Monsterolle van alle sComp:s Loontreckende; Monste	[0.31582915782928467, 0.5470865964889526, 0.13468828797340393, 0.002395995892584324]


 59%|█████▊    | 8/13.65625 [00:16<00:13,  2.34s/batch]

IN	END	NL-HaNA_1.04.02_1509_1567.jpg	103 en 84. persoonen p:r transport; 1. toatongoe v	[0.11780481785535812, 0.7591096758842468, 0.12157411873340607, 0.0015114103443920612]
IN	BEGIN	NL-HaNA_1.04.02_1490_0583.jpg	Copije Secrete Resolutien; genomen bij de Ho: Rege	[0.119136281311512, 0.7563238143920898, 0.12305018305778503, 0.001489725662395358]


 66%|██████▌   | 9/13.65625 [00:18<00:10,  2.30s/batch]

IN	END	NL-HaNA_1.04.02_1490_0631.jpg	Bo0; Secrets.; wee; De bovenstaende secrete resolu	[0.11935272067785263, 0.7553587555885315, 0.12381473928689957, 0.0014738241443410516]
IN	BEGIN	NL-HaNA_1.04.02_1547_0110.jpg	gemerkt &E.; Waarmeede; Edele hoog agtbaare gebied	[0.12711282074451447, 0.7507976293563843, 0.12067340314388275, 0.001416230108588934]
IN	END	NL-HaNA_1.04.02_1547_0112.jpg	l  eene heben em maede slekt e ondergeende kopmans	[0.1190159022808075, 0.7594931721687317, 0.12002599984407425, 0.0014648709911853075]
IN	BEGIN	NL-HaNA_1.04.02_8820_0069.jpg	Van Cormandel deder 24: November ao 1702.; Jck ond	[0.12250299751758575, 0.7552060484886169, 0.12091794610023499, 0.0013730255886912346]


 73%|███████▎  | 10/13.65625 [00:26<00:14,  4.02s/batch]

IN	END	NL-HaNA_1.04.02_8820_0077.jpg	Van Cormandel onder 24: November 1702.; gevolgen, 	[0.12185464799404144, 0.7601072788238525, 0.11653806269168854, 0.0015000335406512022]
IN	BEGIN	NL-HaNA_1.04.02_2682_0249.jpg	Met het ondergenoemde schip; vertrekken over China	[0.11959156394004822, 0.7603518962860107, 0.11862767487764359, 0.0014288848033174872]
IN	BEGIN	NL-HaNA_1.04.02_3095_0015.jpg	Register der Papieren; werdende versonden per het 	[0.11868148297071457, 0.7593616843223572, 0.12050323933362961, 0.0014535108348354697]
IN	END	NL-HaNA_1.04.02_3095_0037.jpg	N:o 57: Sommarium van het geladene in tien; Retoúr	[0.1216193437576294, 0.7538580298423767, 0.1231299340724945, 0.0013927026884630322]
IN	BEGIN	NL-HaNA_1.04.02_8260_0061.jpg	S morgens te agt uuren nog geen bevoeging in het B	[0.120484858751297, 0.7481884360313416, 0.12981632351875305, 0.0015103251207619905]


 81%|████████  | 11/13.65625 [00:28<00:08,  3.33s/batch]

IN	END	NL-HaNA_1.04.02_8260_0075.jpg	Aan den Lieutenant militair Alexander LeCerf; Comm	[0.12889909744262695, 0.6869235038757324, 0.18219859898090363, 0.001978822285309434]
IN	BEGIN	NL-HaNA_1.04.02_3248_0877.jpg	Op Huijden den 5: October a„o 1761: voor mij Wolfe	[0.12941108644008636, 0.6866259574890137, 0.18196438252925873, 0.0019985877443104982]
END	IN	NL-HaNA_1.04.02_3248_0885.jpg	anders te zeggen, als 't geen in 't Casteel orange	[0.1377173364162445, 0.40787774324417114, 0.4511462450027466, 0.0032586543820798397]
END	IN	NL-HaNA_1.04.02_3248_0886.jpg	hij relatant in dien zoo wel onder Ambon als Terna	[0.11374401301145554, 0.20430836081504822, 0.678893506526947, 0.003054120345041156]
END	IN	NL-HaNA_1.04.02_3248_0887.jpg	Gevende voor redenen van wetenschap als in den tex	[0.013966867700219154, 0.017833277583122253, 0.9664645791053772, 0.001735275611281395]
OUT	IN	NL-HaNA_1.04.02_3248_0888.jpg		[0.0017112784553319216, 0.0006986105581745505, 0.007705260533839464, 0.9898848533630371]
BEGI

 88%|████████▊ | 12/13.65625 [00:30<00:04,  2.90s/batch]

END	IN	NL-HaNA_1.04.02_3248_0899.jpg	Waar mede den relatant dit zijne gegevene relaas c	[0.018082819879055023, 0.019184289500117302, 0.9540315866470337, 0.008701292797923088]
OUT	IN	NL-HaNA_1.04.02_3248_0900.jpg		[0.0009222699445672333, 0.00037951141712255776, 0.0023496318608522415, 0.9963486194610596]
BEGIN	IN	NL-HaNA_1.04.02_3248_0901.jpg	Op heden den 8:e Februarij 1768: Compa„; „reerde v	[0.7902325987815857, 0.09338448196649551, 0.10774803906679153, 0.008634885773062706]
END	END	NL-HaNA_1.04.02_3248_0902.jpg	waar mede den relatant zijn gegevene relaas quam t	[0.38019809126853943, 0.18494176864624023, 0.4292825162410736, 0.005577538628131151]
END	BEGIN	NL-HaNA_1.04.02_1547_0107.jpg	Extracten uijt de Daagelijkse Aanteeckeningen, Con	[0.10026136785745621, 0.0484062023460865, 0.8476727604866028, 0.0036596853751689196]
END	END	NL-HaNA_1.04.02_1547_0108.jpg	op desen holvn en dass: matthijs in ve totenende a	[0.003603915683925152, 0.0031431664247065783, 0.9913488030433655, 0.00190408644266

 95%|█████████▌| 13/13.65625 [00:32<00:01,  2.82s/batch]

IN	END	NL-HaNA_1.04.02_1083_0048.jpg	Des avondts voor het aronsteten is den predicanten	[0.11937545984983444, 0.7591578960418701, 0.12000031024217606, 0.0014663924230262637]
IN	BEGIN	NL-HaNA_1.04.02_8696_0051.jpg	Van Siam onder dato 20: april 1737; Dag-register, 	[0.12174913287162781, 0.7546426057815552, 0.122219018638134, 0.0013892798451706767]


103%|██████████| 14/13.65625 [00:33<00:00,  2.41s/batch]

IN	END	NL-HaNA_1.04.02_8696_0077.jpg	Van Siam onder dato: 20: april 1737; Translaet Sia	[0.13300307095050812, 0.5275892615318298, 0.3366990089416504, 0.0027086276095360518]





In [21]:
writer = csv.DictWriter(
    sys.stdout,
    fieldnames=["Metric", "Average"] + [label.name for label in Label],
    delimiter="\t",
)
writer.writeheader()

for metric in (precision, recall, f1_score):
    scores = {
        label.name: f"{score:.4f}"
        for label, score in zip(Label, metric.compute().tolist())
    }
    writer.writerow(
        {"Metric": metric.__class__.__name__, "Average": str(metric.average)} | scores
    )

print(f"Accuracy ({accuracy.average} average):\t{accuracy.compute().item():.4f}")

Metric	Average	BEGIN	IN	END	OUT
MulticlassPrecision	None	0.1739	0.9045	0.2051	0.2750
MulticlassRecall	None	0.1481	0.8102	0.3200	1.0000
MulticlassF1Score	None	0.1600	0.8547	0.2500	0.4314
Accuracy (micro average):	0.7460
