In [85]:
# Load IPython's autoreload extension, then reload imported modules right away
# so that edited source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload now

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [86]:
import os

# Limit which CUDA device is visible to this process. The value has the form of
# a MIG instance UUID.
# NOTE(review): this identifier is machine-specific — confirm it before running
# the notebook on another host.
os.environ.update(CUDA_VISIBLE_DEVICES="MIG-08137aa2-e69b-5e74-8390-7997329b1336")
# os.environ["WORLD_SIZE"] = "1"

# Download and Convert Data

In [87]:
from tqdm import tqdm

from document_segmentation.pagexml.annotations.renate_analysis import RenateAnalysis
from document_segmentation.settings import RENATE_ANALYSIS_DIR

# Passed as `n` to `to_documents()`; when falsy, `len(sheet)` is used for the
# progress-bar total instead.
N = None

RENATE_ANALYSIS_DIR.mkdir(parents=True, exist_ok=True)

sheet = RenateAnalysis()


# Stems of the JSON files already on disk. They are handed to `to_documents()`
# as `skip_ids` and subtracted from the progress-bar total.
existing_docs = {
    path.stem for path in RENATE_ANALYSIS_DIR.glob("Globdoc_*.json") if path.is_file()
}

for document in tqdm(
    sheet.to_documents(n=N, skip_ids=existing_docs),
    total=(N or len(sheet)) - len(existing_docs),
    desc="Writing documents",
    unit="doc",
):
    document_file = RENATE_ANALYSIS_DIR / f"{document.id}.json"

    # Serialise BEFORE creating the file: if serialisation raised after the
    # file had been created, an empty file would stay behind and the document
    # would be treated as "existing" (and skipped) on every later run.
    serialized = document.model_dump_json()

    # Mode "x" refuses to overwrite an existing file. The encoding is explicit
    # so the JSON does not depend on the platform's default locale encoding.
    with document_file.open("xt", encoding="utf-8") as f:
        f.write(serialized)
        f.write("\n")

Writing documents: 0doc [00:00, ?doc/s]


In [88]:
import logging

from tqdm import tqdm

from document_segmentation.pagexml.annotations.renate_analysis import RenateAnalysisInv
from document_segmentation.settings import RENATE_ANALYSIS_DIR, RENATE_ANALYSIS_SHEETS

# Passed as `n` to `to_documents()`.
N = None


sheet = RenateAnalysisInv(RENATE_ANALYSIS_SHEETS[0])  # TODO: use both sheets

# NOTE(review): `total=26` is hard-coded; presumably the number of documents in
# the first sheet — confirm, and derive it from the sheet if possible.
for document in tqdm(
    sheet.to_documents(n=N), desc="Writing documents", unit="doc", total=26
):
    document_file = RENATE_ANALYSIS_DIR / f"{document.id}.json"

    if document_file.exists():
        logging.info(f"Document {document.id} already exists, skipping")
    else:
        # Serialise BEFORE creating the file: if serialisation raised after the
        # file had been created, an empty file would stay behind and the
        # document would be skipped as "already exists" on every later run.
        serialized = document.model_dump_json()

        # Mode "x" refuses to overwrite an existing file. The encoding is
        # explicit so the JSON does not depend on the platform's default
        # locale encoding.
        with document_file.open("xt", encoding="utf-8") as f:
            f.write(serialized)
            f.write("\n")

Writing documents:   0%|          | 0/26 [00:00<?, ?doc/s]

Writing documents: 100%|██████████| 26/26 [00:10<00:00,  2.37doc/s]


# Load Data

In [89]:
# Reload imported modules now so that recent source edits take effect.
%autoreload now

In [90]:
# Argument given to `dataset.split()` below; presumably the fraction of the
# dataset used for training (the remainder becomes the test set) — confirm.
TRAINING_DATA = 0.8

In [91]:
from document_segmentation.model.dataset import DocumentDataset

# Build the dataset from the JSON files in RENATE_ANALYSIS_DIR.
dataset: DocumentDataset = DocumentDataset.from_dir(RENATE_ANALYSIS_DIR)
# NOTE(review): no seed is visible here, so the order (and therefore the
# train/test split made later) may differ between runs — confirm whether
# DocumentDataset.shuffle() is seeded internally.
dataset.shuffle()

# Size of the dataset, displayed as the cell result.
len(dataset)

Reading JSON files: 100%|██████████| 104/104 [00:00<00:00, 164.60file/s]


104

In [92]:
# Label distribution over the full dataset (a Counter keyed by Label).
# Note: `_class_counts` is a private method of the dataset class.
dataset._class_counts()

Counter({<Label.IN: 1>: 1907,
         <Label.BEGIN: 0>: 104,
         <Label.END: 2>: 100,
         <Label.OUT: 3>: 73})

In [93]:
# Per-class weights of the full dataset; these are later wrapped in a tensor
# (WEIGHTS) and passed to training.
dataset.class_weights()

[0.9904761904761905,
 0.05450733752620545,
 1.0297029702970297,
 1.4054054054054055]

In [94]:
# Split the shuffled dataset into a training part and a test part, using
# TRAINING_DATA as the argument to split().
training_data, test_data = dataset.split(TRAINING_DATA)

In [95]:
# Label distribution of the training split (private helper, see above usage on
# the full dataset).
training_data._class_counts()

Counter({<Label.IN: 1>: 1690,
         <Label.BEGIN: 0>: 83,
         <Label.END: 2>: 81,
         <Label.OUT: 3>: 72})

In [96]:
# Label distribution of the test split; check that every label is represented
# well enough for the per-class metrics computed later to be meaningful.
test_data._class_counts()

Counter({<Label.IN: 1>: 217,
         <Label.BEGIN: 0>: 21,
         <Label.END: 2>: 19,
         <Label.OUT: 3>: 1})

# Train Model

In [None]:
import torch

# Training settings used by the training and evaluation cells below.
BATCH_SIZE = 64
# NOTE(review): the saved training output shows 10 progress bars, which does
# not match EPOCHS = 5 — the stored output may be stale; confirm.
EPOCHS = 5
# Class weights for an imbalanced dataset. `torch.tensor` with an explicit
# dtype is the documented factory function; `torch.Tensor(...)` is the legacy
# constructor. The resulting values and dtype (float32) are the same.
WEIGHTS = torch.tensor(dataset.class_weights(), dtype=torch.float32)

In [None]:
# Reload imported modules now so that recent source edits take effect.
%autoreload now

In [None]:
from document_segmentation.model.page_sequence_tagger import PageSequenceTagger

# Instantiate the tagger with its default constructor arguments.
tagger = PageSequenceTagger()

In [None]:
# Show which device the tagger uses. Note: `_device` is a private attribute.
tagger._device

'mps'

In [None]:
# Display the model's module structure.
tagger

PageSequenceTagger(
  (_page_embedding): PageEmbedding(
    (_region_model): RegionEmbeddingSentenceTransformer(
      (_transformer_model): SentenceTransformer(
        (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: RobertaModel 
        (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})
      )
      (_region_type): Embedding(9, 16)
      (_linear): Linear(in_features=784, out_features=512, bias=True)
    )
    (_rnn): LSTM(512, 256, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)
    (_linear): Linear(in_features=512, out_features=256, bias=True)
  )
  (_rnn): LSTM(256, 256, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)
  (_linear): Linear(in_features=512, out_features=4, bias=True)
  (_soft

In [None]:
# Train on the training split, passing EPOCHS, BATCH_SIZE and the class
# weights; the weights tensor is moved to the tagger's device first.
tagger.train_(training_data, EPOCHS, BATCH_SIZE, WEIGHTS.to(tagger._device))

100%|██████████| 92/92 [01:10<00:00,  1.30batch/s]


[Loss:	0.520]


100%|██████████| 92/92 [00:05<00:00, 18.09batch/s]


[Loss:	0.515]


100%|██████████| 92/92 [00:04<00:00, 18.47batch/s]


[Loss:	0.515]


100%|██████████| 92/92 [00:04<00:00, 18.64batch/s]


[Loss:	0.515]


100%|██████████| 92/92 [00:04<00:00, 18.62batch/s]


[Loss:	0.515]


100%|██████████| 92/92 [00:05<00:00, 18.18batch/s]


[Loss:	0.515]


100%|██████████| 92/92 [00:05<00:00, 18.19batch/s]


[Loss:	0.515]


100%|██████████| 92/92 [00:04<00:00, 18.48batch/s]


[Loss:	0.515]


100%|██████████| 92/92 [00:04<00:00, 19.07batch/s]


[Loss:	0.514]


100%|██████████| 92/92 [00:05<00:00, 18.26batch/s]

[Loss:	0.514]





# Evaluate Model

In [None]:
import csv
import sys

from torcheval.metrics import (
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)
from tqdm import tqdm

from document_segmentation.pagexml.datamodel.label import Label

# One tab-separated row per test page is written to stdout.
writer = csv.DictWriter(
    sys.stdout,
    fieldnames=("Predicted", "Actual", "Page ID", "Text", "Scores"),
    delimiter="\t",
)

writer.writeheader()

# Accuracy uses the metric's default averaging; the other metrics are kept
# per class (average=None). All are updated batch by batch and read out in the
# next cell.
accuracy = MulticlassAccuracy(num_classes=len(Label))
precision = MulticlassPrecision(average=None, num_classes=len(Label))
recall = MulticlassRecall(average=None, num_classes=len(Label))
f1_score = MulticlassF1Score(average=None, num_classes=len(Label))

# Evaluate in eval mode and without gradient tracking: the model contains
# dropout layers, which must be inactive for deterministic predictions, and no
# autograd graph is needed here. The previous mode is restored afterwards so
# that a later training call is not affected.
was_training = tagger.training
tagger.eval()
try:
    with torch.no_grad():
        for batch in tqdm(
            test_data.batches(BATCH_SIZE),
            total=test_data.n_batches(BATCH_SIZE),
            unit="batch",
        ):
            predicted = tagger(batch)
            labels = batch.labels()

            # Integer class indices, built directly as int64 instead of going
            # through a float tensor first.
            _labels = torch.tensor(
                [label.value for label in labels], dtype=torch.long
            )
            accuracy.update(predicted, _labels)
            precision.update(predicted, _labels)
            recall.update(predicted, _labels)
            f1_score.update(predicted, _labels)

            for page, pred, label in zip(batch.pages, predicted, labels):
                # The predicted label is the class with the highest score.
                pred_label = Label(pred.argmax().item())
                writer.writerow(
                    {
                        "Predicted": pred_label.name,
                        "Actual": label.name,
                        "Page ID": page.doc_id,
                        # Only the first 50 characters of the page text.
                        "Text": page.text(delimiter="; ")[:50],
                        "Scores": str(pred.tolist()),
                    }
                )
finally:
    tagger.train(was_training)

Predicted	Actual	Page ID	Text	Scores


  4%|▍         | 1/24 [00:00<00:04,  5.53batch/s]

OUT	OUT	NL-HaNA_1.04.02_1547_0430.jpg		[0.00023680477170273662, 0.0006503228796645999, 0.00011872695904457942, 0.9989941716194153]
BEGIN	BEGIN	NL-HaNA_1.04.02_1547_0431.jpg	Mondelingh Berigt ge„; „daen door den E: Cop=n; 1;	[0.9998762607574463, 6.241842493182048e-05, 3.1839925213716924e-05, 2.9442173399729654e-05]
IN	IN	NL-HaNA_1.04.02_1547_0432.jpg	En motiven wat uijt d' verweijderingen tusschen; d	[0.002645435044541955, 0.9959866404533386, 0.0009351801709271967, 0.0004327444767113775]
IN	IN	NL-HaNA_1.04.02_1547_0433.jpg	het geheele gebret van het Land van Elledetta„; „s	[0.0004950308357365429, 0.9990591406822205, 0.0003787148161791265, 6.719260272802785e-05]
IN	IN	NL-HaNA_1.04.02_1547_0434.jpg	vant perisolisehe Rijck, zijn h:r den ragia; van c	[0.0002472178020980209, 0.9994496703147888, 0.0002673076814971864, 3.573816502466798e-05]
IN	IN	NL-HaNA_1.04.02_1547_0435.jpg	Betuijginge: bij haer hoogh=t wierde, versogt te; 	[0.00018776272190734744, 0.999562680721283, 0.0002247396914754063, 

 21%|██        | 5/24 [00:00<00:02,  6.80batch/s]

OUT	OUT	NL-HaNA_1.04.02_1547_0614.jpg		[0.0001189231697935611, 0.00040132092544808984, 4.494664608500898e-05, 0.9994348883628845]
OUT	BEGIN	NL-HaNA_1.04.02_1547_0615.jpg		[4.286131661501713e-05, 8.152954978868365e-05, 2.2358128262567334e-05, 0.999853253364563]
BEGIN	IN	NL-HaNA_1.04.02_1547_0616.jpg	1: de weduwe van Cornelis verdonck zal=r - - - - -	[0.9996368885040283, 7.999297667993233e-05, 5.104193769511767e-05, 0.00023214492830447853]
IN	IN	NL-HaNA_1.04.02_1547_0617.jpg	geruwde Persoonen met haare familjen,; anur Adriaa	[0.011218822561204433, 0.9818529486656189, 0.002009269082918763, 0.004918962717056274]
IN	IN	NL-HaNA_1.04.02_1547_0618.jpg	26. huijs gesinnen Naamen toenamen - - - - - - - -	[0.0015154351713135839, 0.9976119995117188, 0.0006167812971398234, 0.0002557664120104164]
IN	IN	NL-HaNA_1.04.02_1547_0619.jpg	P„r Transport Copper : 348. —; 1: - - - - - - - - 	[0.00031454244162887335, 0.9993447661399841, 0.00028837512945756316, 5.231601244304329e-05]
IN	IN	NL-HaNA_1.04.02_1547_0

 42%|████▏     | 10/24 [00:01<00:01, 13.70batch/s]

IN	IN	NL-HaNA_1.04.02_1547_0253.jpg	van d' Ed=le H„r Directeur; en Raad tot Zouratta z	[0.0001891237625386566, 0.999565064907074, 0.00022197455109562725, 2.3785076336935163e-05]
IN	IN	NL-HaNA_1.04.02_1547_0254.jpg	twoorenstaande formelier, gisteren voor de; middag	[0.00015430280473083258, 0.9995915293693542, 0.00023246802447829396, 2.1642721549142152e-05]
IN	IN	NL-HaNA_1.04.02_1547_0255.jpg	No/a, ook wel 140. maar tegens 135. gereek. t,; â 	[0.0001369184465147555, 0.9996145963668823, 0.0002278776082675904, 2.0600311472662725e-05]
IN	IN	NL-HaNA_1.04.02_1547_0256.jpg	meugt hoedanigh ick volgens mijn schuldige; pligt;	[0.00015061532030813396, 0.9996059536933899, 0.00022092173458077013, 2.2558742784895003e-05]
IN	IN	NL-HaNA_1.04.02_1547_0257.jpg	uE: twee aangenaame brieven gedat. t; 22. en en 25	[0.00014808033301960677, 0.9995786547660828, 0.0002490175829734653, 2.4197321181418374e-05]
IN	IN	NL-HaNA_1.04.02_1547_0258.jpg	vorsten mondelijk raatspleeginge gehouden; hebbend	[0.000136817223392

 62%|██████▎   | 15/24 [00:01<00:00, 17.50batch/s]

BEGIN	BEGIN	NL-HaNA_1.04.02_1631_0289.jpg	Twee Copie translaat briefien; door de hofs-groten	[0.9999105930328369, 6.585857045138255e-05, 2.121354191331193e-05, 2.356615141252405e-06]
IN	IN	NL-HaNA_1.04.02_1631_0290.jpg	wederseijdse gesonth:t &:a bestaande :/ gevoerd; z	[0.001078862464055419, 0.9983059167861938, 0.0005363075761124492, 7.894612645031884e-05]
IN	IN	NL-HaNA_1.04.02_1631_0291.jpg	Luijden sig onthielden, en ondervondt; dat die dri	[0.00024444443988613784, 0.9994757771492004, 0.00025385277695022523, 2.5926610760507174e-05]
IN	IN	NL-HaNA_1.04.02_1631_0292.jpg	daarom niet en konden gaan, derhalve; gelieve uEd:	[0.00022657186491414905, 0.999492883682251, 0.0002580515865702182, 2.2533637093147263e-05]
IN	IN	NL-HaNA_1.04.02_1631_0293.jpg	als 't haar maar In de zin quam; baldadigh, en sto	[0.00019673268252518028, 0.9995738863945007, 0.00021093316900078207, 1.8427284885547124e-05]
IN	IN	NL-HaNA_1.04.02_1631_0294.jpg	hebben verstaan, en begrepen: waar—; na nogh eenig	[0.0001670555939

 75%|███████▌  | 18/24 [00:01<00:00, 17.28batch/s]

BEGIN	IN	NL-HaNA_1.04.02_3577_0713.jpg	intermen van klem vermaand in deezen 't zijne met 	[0.9999111890792847, 6.942303298274055e-05, 1.7352729628328234e-05, 1.954658046088298e-06]
IN	IN	NL-HaNA_1.04.02_3577_0714.jpg	oudstens te neemen, om niet door zulke vijandelijk	[0.001180727151222527, 0.9981112480163574, 0.0006211033323779702, 8.702724153408781e-05]
IN	IN	NL-HaNA_1.04.02_3577_0715.jpg	hier toe dover eene Ernstige aanspraak aangespoord	[0.0002669957175385207, 0.9994702935218811, 0.00023714298731647432, 2.5595105398679152e-05]
IN	IN	NL-HaNA_1.04.02_3577_0716.jpg	volbragt en in het Alphoirs gebergte van waijsamoe	[0.0002059456892311573, 0.9994897842407227, 0.0002802750386763364, 2.396959280304145e-05]
IN	IN	NL-HaNA_1.04.02_3577_0717.jpg	na de Negorij Tiehoelale welkers Orangkanj en scho	[0.00016071864229161292, 0.9995928406715393, 0.00022814009571447968, 1.8367962184129283e-05]
IN	IN	NL-HaNA_1.04.02_3577_0718.jpg	vervolgens aan het Huijs van den Orangkaij gekoome	[0.00016177234647329

100%|██████████| 24/24 [00:01<00:00, 14.57batch/s]

BEGIN	BEGIN	NL-HaNA_1.04.02_1108_0653.jpg	d' E. d: heeren Bewindehebberen; mpult. febr. 1633	[0.9999197721481323, 6.24519307166338e-05, 1.5677125702495687e-05, 1.9859278381773038e-06]
IN	IN	NL-HaNA_1.04.02_1108_0654.jpg	syn gecompareert, die terstont liet vast binden, e	[0.0010502388468012214, 0.9982824325561523, 0.0005819913931190968, 8.536814129911363e-05]
IN	IN	NL-HaNA_1.04.02_1108_0655.jpg	3a di dito smorgens sagen seecker vaertuyg onder't	[0.00026778728351928294, 0.9994713664054871, 0.0002345546381548047, 2.6388303012936376e-05]
IN	IN	NL-HaNA_1.04.02_1108_0656.jpg	voor de Manipo, vandewelcke voor desen geschreven 	[0.00020679035515058786, 0.9994906187057495, 0.0002777914342004806, 2.475599649187643e-05]
IN	IN	NL-HaNA_1.04.02_1108_0657.jpg	tegenwoordigh inde Negrij geen nagelen waren: dien	[0.0001687015756033361, 0.9995935559272766, 0.00021830877813044935, 1.936741500685457e-05]
IN	IN	NL-HaNA_1.04.02_1108_0658.jpg	die haer aenhouden, te straffen, so veel als in on	[0.00016586948186




In [None]:
# Print the per-class metrics as a tab-separated table: one row per metric,
# one column per label, followed by the overall accuracy.
metric_columns = [label.name for label in Label]

writer = csv.DictWriter(
    sys.stdout,
    fieldnames=["Metric", *metric_columns],
    delimiter="\t",
)
writer.writeheader()

for metric in (precision, recall, f1_score):
    per_class = metric.compute().tolist()
    # Format each class score with four decimals, keyed by label name.
    scores = dict(zip(metric_columns, (f"{value:.4f}" for value in per_class)))
    writer.writerow({"Metric": type(metric).__name__, **scores})

print(f"Accuracy ({accuracy.average} average):\t{accuracy.compute().item():.4f}")

Metric	BEGIN	IN	END	OUT
MulticlassPrecision	0.7308	1.0000	0.8696	0.7368
MulticlassRecall	0.9048	0.9748	0.9524	1.0000
MulticlassF1Score	0.8085	0.9872	0.9091	0.8485
Accuracy (micro average):	0.9718
