<a href="https://colab.research.google.com/github/667029/KVP10k/blob/main/LayoutMVL3_BinaryRelation_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Dokumentforståelse med LayoutLMv3 på KVP10k-datasettet

Denne notebooken demonstrerer hvordan vi henter inn en preprossessert og tilpasset versjon av KVP10k-datasettet i Hugging Face-format, til å utføre **Key-Value Pair Extraction (KVP)** på  dokumentbilder.

- Datasettet består av over 10k forretningsdokumenter, og inneholder blant annet dokumentbilder og tilhørende nøkkel-verdi-par, som brukes av denne fine-tuned modellen som utvikles her.

- Sluttmålet er å utvikle og trene en ny modell til dokumentforståelse, ved å forstå **visuell layout**, **tekstlig innhold**, og **relasjoner mellom nøkler og verdier** i dokumentene.
  - KVP-Extraction modellen som utvikles i denne notebooken er tenkt å brukes grunnmur i sluttmodellen, for å med stor sannsynlighet beherske å linke mellom nøkkel-verdi-par i ulike dokumenter.

LayoutLMv3 er en multi-modal modell designet for å kombinere tekst, layout og annen bilde-informasjon

---

###**Notbooken dekker følgende steg**:

1. Installasjon av de nødendige biblioteker
2. Lasting av forhåndsprosessert datasett
3. Tokenisering av tekst og input-formatering med Layout sin Processor
- 3.1 Logikk for å anngi predikerte BIO-labels til dokumentets bbox'es
4. Trening av modell for token-klassifisering
5. Evaluering og lagring av modell i Drive
6. Visualisering av modell under inferense

---

###**LayoutLMv3Processor - gjør følgende**:
1. Tekst-tokenisering: Tekst fra dokumentet tokeniseres.
2. Token-connection: Hvert token kobles til en bounding box (bbox) på dokumentet, gjennom *boxes*-parmeteret som inneholder (x0,y0,x1,y1)-kordinater til hvert token.
3. Image-embedding: Dokumentbildet skaleres og legges og blir input til modellen
4. Label-alignment: Hvert token får en BIO-label, som brukes under modellens token-klassifisering

Tokeniseringen handler om å forvandle dokumentet til tokens med alle nødvendige modaliteter (tekst, layout og bilde) slik at modellen lærer sammenhengen mellom dem gjennom trening.

**Det brukes BIO-tagger, og dette er hva det står for:**
 - B --> Begin: første token i en entitet.
 - I --> Inside: inne i en entitet.
 - O --> Outside: tokenen er ikke en del av noen entitet

f.eks.
  - Tokens:  ["Name", "of", "buyer", ":", "Ole", "Martin", "Lystadmoen"]
  - Labels:  ["B-KEY", "I-KEY", "I-KEY", "O", "B-VALUE", "I-VALUE", "I-VALUE"]

# Dataset - forståelse

**Innhold i train/-mappen i KVP10k:**
_____
  - *images*/ --> .png bilder av hvert dokument. Visuell input for modellen.
    - Det modellen "ser".
_____

  - *ocrs*/ --> JSON-filer med **words** og **bboxes** for hvert dokument. Gir tekst og posisjoner fra OCR, og brukes sammen med images.
    - Det modellen "leser" (tokens og posisjonene deres).

_____

  - *gts*/ --> JSON-filer med KVPs og tilhørende bboxes. Inneholder hvilke keys og values som hører sammen.
    - Det som lærer modellen hvilke tokens som er nøkler, verdier, og hvilket som er koblet sammen.
_____

  - *items*/ --> JSON-filer med annotasjoner og layout-objs (rektangler, linker, etiketter)
    - tilleggsinformasjon
    - ikke viktig i for EE
    - Helt nødvendig i RE-delen av dette prosjektet
_____

In [30]:
#transformers: Hugging Face bibliotek som gir tilgang til LayoutLMv3
#datasets: For håndtering av dataset i Huggig Face-format
#seqeval: evalueringsbibliotek for sekvensmerking, brukes for måle metrikker for i dette tilfelle BIO-tagging
!pip install -q transformers datasets seqeval

In [31]:
#Håndterer ulike metrikker inkl. integrasjon med seqeval
!pip install -q evaluate

In [32]:
import os              #navigere mapper og filer, hente filbaner
from PIL import Image  #åpne, vise og manipulere bilder
import json            #lese/skrive til JSON-filer
from transformers import LayoutLMv3Processor
import torch           #modellens input-format for data
from google.colab import drive

In [33]:
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [34]:
base_path = "/content/drive/MyDrive/KVP10k-dataset/kvp10k/"
print(os.listdir(base_path))

['train', 'test']


In [35]:
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False) # <-- Viktig fordi vi allerede har utført OCR på bildet og har tekst og bboxes

In [36]:
# Mapping fra tekstlige BIO-labels til tall som modellen bruker
label_map = {
    "O": 0,
    "B-KEY": 1,
    "I-KEY": 2,
    "B-VALUE": 3,
    "I-VALUE": 4,
}

# Funksjon for å skalere bounding boxes til 0-1000 (som LayoutLMv3 krever)
def normalize_bbox(bbox, width, height):
  return [
      int(1000 * (bbox[0] /width)),
      int(1000 * (bbox[1] / height)),
      int(1000 * (bbox[2] / width)),
      int(1000 * (bbox[3] / height))
  ]


def assign_label_for_box(box, boxes, label_type):
  """Returnerer liste med (index, label) for tokens som overlapper box"""
  overlaps = []
  for i, token_box in enumerate(boxes):
    if box_overlap(box, token_box) > 0:
      overlaps.append(i)

  overlaps = sorted(overlaps)

  labeled = []
  for j, idx in enumerate(overlaps):
    tag = f"B-{label_type}" if j == 0 else f"I-{label_type}"
    labeled.append((idx, tag))

  return labeled


#Sjekker om OCR-boksen overlapper med GTS(key/value)-boksen.
#Ved overlapp hører de til hverandre.
def box_overlap(box1, box2):
  x0 = max(box1[0], box2[0])
  y0 = max(box1[1], box2[1])
  x1 = min(box1[2], box2[2])
  y1 = min(box1[3], box2[3])
  return max(0, x1 - x0) * max(0, y1 - y0)


# Funksjon for å generere BIO-labels fra gts (ground truth).
# Lager en BIO-label for hvert token basert på om det overlapper med en key- eller value-boks fra GTS.
# Matcher hvert token fra OCR (word + bbox) mot key/value-bbokser fra gts:
# --> Token overlapper en nøkkelboks: B-KEY eller I-KEY
# --> Token overlapper en verdiboks: B-VALUE eller I-VALUE
# --> Ellers: O
def iob_from_kvps(words, boxes, kvps):
  labels = ["O"] * len(words)

  #Gå igjennom alle key-value-pairs
  for kvp in kvps:
    if "key" in kvp and "bbox" in kvp["key"]:
      key_bbox = kvp["key"]["bbox"]
      for idx, tag in assign_label_for_box(key_bbox, boxes, "KEY"):
        labels[idx] = tag

    if "value" in kvp and "bbox" in kvp["value"]:
      value_box = kvp["value"]["bbox"]
      for idx, tag in assign_label_for_box(value_box, boxes, "VALUE"):
        labels[idx] = tag

  return labels

In [37]:
!ls -lh /content/drive/MyDrive/KVP10k_processed_ready


lrw------- 1 root root 0 Apr 20 20:52 /content/drive/MyDrive/KVP10k_processed_ready -> /content/drive/.shortcut-targets-by-id/1NbM9cwuCpZGK4W3yzn5hIDmqyp3uULqw/KVP10k_processed_ready


#Innlasting av et allerede pre-prossesert KVP10k-dataset spesielt utviklet for LayoutLMv3 (KVP-extraction).
##Ikke kjør denne!

NB: Dette prosessen gjøres i en annen notebook, vi henter inn resultatet her for å spare notebooken for plass og ryddighet.

Datasettet som lastes inn er på omlag 8600 dokumenter da det er antall dokumenter med ground-truths (gts).

In [38]:
from google.colab import drive
from datasets import load_from_disk

# Monter Drive (hvis du ikke har gjort det)
drive.mount("/content/drive", force_remount=True)

# Last inn dataset fra riktig path
dataset = load_from_disk("/content/drive/MyDrive/KVP10k_processed_ready/dataset_all_gts")

# Hent splits
train_dataset = dataset["train"]
eval_dataset = dataset["eval"]
test_dataset = dataset["test"]

Mounted at /content/drive


In [93]:
from evaluate import load
metric = load("seqeval")

import numpy as np

label_list = ["O", "B-KEY", "I-KEY", "B-VALUE", "I-VALUE"]
label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for i, label in enumerate(label_list)}

def add_relation_labels(example):
    rel_h = []
    rel_t = []
    rel_labels = []

    words = example.get("words", [])  # Get words, or empty list if not found
    word_labels = example.get("ner_tags", [])  # Get ner_tags if "labels" not found
    kvps = example.get("kvps_list", [])

    key_token_map = {}
    value_token_map = {}

    for idx, label_id in enumerate(word_labels):
        label = id2label.get(label_id, "O")
        if label.startswith("B-KEY"):
            key_token_map.setdefault(len(key_token_map), idx)
        elif label.startswith("B-VALUE"):
            value_token_map.setdefault(len(value_token_map), idx)

    entity_pairs = []  # Create entity_pairs for this example
    for rel_id, kvp in enumerate(kvps):
        key_idx = key_token_map.get(rel_id)
        val_idx = value_token_map.get(rel_id)

        # For relation extraction, we need to find spans of key and value tokens
        key_idxs = []
        val_idxs = []

        if key_idx is not None:
            key_idxs.append(key_idx)
            # Try to find consecutive "I-KEY" tokens
            i = key_idx + 1
            while i < len(word_labels) and id2label.get(word_labels[i], "O").startswith("I-KEY"):
                key_idxs.append(i)
                i += 1

        if val_idx is not None:
            val_idxs.append(val_idx)
            # Try to find consecutive "I-VALUE" tokens
            i = val_idx + 1
            while i < len(word_labels) and id2label.get(word_labels[i], "O").startswith("I-VALUE"):
                val_idxs.append(i)
                i += 1

        if key_idxs and val_idxs:  # Only add pairs if both key and value were found
            entity_pairs.append((key_idxs, val_idxs))
            rel_h.append(key_idx)
            rel_t.append(val_idx)
            rel_labels.append(1)  # 1 = real relation
    #---Endre fra rel_h[0] siden list comprehension har blitt lagt til på linje 23
    example["rel_h"] = rel_h
    example["rel_t"] = rel_t
    example["rel_labels"] = rel_labels

    # Store entity_pairs in the example
    example["entity_pairs"] = entity_pairs

    # Handle cases where no relations were found:
    # For training, it's helpful to have at least one pair, even if it's a "dummy" relation
    if not entity_pairs:
        #example["entity_pairs"] = [([0], [1])]  # Dummy pair of [CLS] and the next token (usually a separator)
        #example["rel_labels"] = [0]  # Label this dummy pair as "no relation" (0)
        example["rel_h"] = []
        example["rel_t"] = []
        example["entity_pairs"] = [] #Dummy pair of [CLS] and the next token (usually a separator)
        example["rel_labels"] = [] # Label this dummy pair as "no relation" (0)

    return example


# Assuming train_dataset, eval_dataset, and test_dataset are already loaded
train_dataset = train_dataset.map(add_relation_labels)
eval_dataset = eval_dataset.map(add_relation_labels)
test_dataset = test_dataset.map(add_relation_labels)

Map:   0%|          | 0/6273 [00:00<?, ? examples/s]

Map:   0%|          | 0/1569 [00:00<?, ? examples/s]

Map:   0%|          | 0/828 [00:00<?, ? examples/s]

In [79]:
train_dataset.features

{'pixel_values': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None),
 'input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'bbox': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'rel_h': Sequence(feature=Value(dtype='null', id=None), length=-1, id=None),
 'rel_t': Sequence(feature=Value(dtype='null', id=None), length=-1, id=None),
 'rel_labels': Sequence(feature=Value(dtype='null', id=None), length=-1, id=None),
 'entity_pairs': Sequence(feature=Value(dtype='null', id=None), length=-1, id=None)}

In [81]:
example = train_dataset[0]
for k,v in example.items():
    try:
        print(k, v.shape)
    except AttributeError:
        print(k, len(v))  # Print length of list instead of shape

pixel_values torch.Size([3, 224, 224])
input_ids torch.Size([512])
attention_mask torch.Size([512])
bbox torch.Size([512, 4])
labels torch.Size([512])
rel_h 0
rel_t 0
rel_labels 0
entity_pairs 0


In [82]:
processor.tokenizer.decode(train_dataset[0]["input_ids"])

'<s> Date Net Zone/Syscode Time Spot Name Len Line Rate Flag 10-722 HIST L Middlesx Smrst Cty Zone 1302/1302 4:44PM CLFTVNJO702H 30 48 $73.00 10-722 HIST L Middlesx Smrst Cty Zone 1302/1302 7:45PM CLFTVNJO702H 30 487 $90.00 10-722 HIST L Middlesx Smrst Cty Zone 1302/1302 11:46PM CLFTVNJO702H 30 487 $90.00 10-722 AEN Morris County Zone 1305/1305 2:13PM CLFTVNJO702H 30 488 $21.00 10-722 AEN Morris County Zone 1305/1305 5:40PM CLFTVNJO702H 30 489 $49.00 10-722 AEN Morris County Zone 1305/1305 7:44PM CLFTVNJO702H 30 490 $59.00 10-722 AEN Morris County Zone 1305/1305 10:48PM CLFTVNJO702H 30 490 $59.00 10-722 CNN Morris County Zone 1305/1305 8:29AM CLFTVNJO702H 30 491 $73.00 10-722 CNN Morris County Zone 1305/1305 3:30PM CLFTVNJO702H 30 492 $65.00 10-722 CNN Morris County Zone 1305/1305 4:26PM CLFTVNJO702H 30 493 $121.00 10-722 CNN Morris County Zone 1305/1305 9:27PM CLFTVNJO702H 30 494 $164.00 10-722 DISC Morris County Zone 1305/1305 9:25AM CLFTVNJO702H 30 495 $21.00 10-722 DISC Morris Coun

In [83]:
for id, label in zip(train_dataset[0]["input_ids"], train_dataset[0]["labels"]):
  print(processor.tokenizer.decode([id]), label.item())

<s> -100
 Date 0
 Net 0
 Zone 0
/ -100
Sys -100
code -100
 Time 0
 Spot 0
 Name 0
 Len 0
 Line 0
 Rate 0
 Flag 0
 10 0
- -100
7 -100
22 -100
 H 0
IST -100
 L 0
 Middles 0
x -100
 Sm 0
r -100
st -100
 C 0
ty -100
 Zone 0
 130 0
2 -100
/ -100
130 -100
2 -100
 4 0
: -100
44 -100
PM -100
 CL 0
F -100
TV -100
NJ -100
O -100
702 -100
H -100
 30 0
 48 0
 $ 0
73 -100
. -100
00 -100
 10 0
- -100
7 -100
22 -100
 H 0
IST -100
 L 0
 Middles 0
x -100
 Sm 0
r -100
st -100
 C 0
ty -100
 Zone 0
 130 0
2 -100
/ -100
130 -100
2 -100
 7 0
: -100
45 -100
PM -100
 CL 0
F -100
TV -100
NJ -100
O -100
702 -100
H -100
 30 0
 48 0
7 -100
 $ 0
90 -100
. -100
00 -100
 10 0
- -100
7 -100
22 -100
 H 0
IST -100
 L 0
 Middles 0
x -100
 Sm 0
r -100
st -100
 C 0
ty -100
 Zone 0
 130 0
2 -100
/ -100
130 -100
2 -100
 11 0
: -100
46 -100
PM -100
 CL 0
F -100
TV -100
NJ -100
O -100
702 -100
H -100
 30 0
 48 0
7 -100
 $ 0
90 -100
. -100
00 -100
 10 0
- -100
7 -100
22 -100
 A 0
EN -100
 Morris 0
 County 0
 Zone 0
 130 0
5 -1

In [84]:
from evaluate import load
metric = load("seqeval")

In [85]:
import numpy as np

label_list = ["O", "B-KEY", "I-KEY", "B-VALUE", "I-VALUE"]
label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for i, label in enumerate(label_list)}


def compute_metrics(p):
    predictions, labels = p

    # === Token classification metrics (BIO tagging)
    token_preds = predictions["token_logits"]
    token_labels = labels["labels"]

    token_preds = np.argmax(token_preds, axis=2)
    true_preds = [
        [id2label[p] for (p, l) in zip(pred, lab) if l != -100]
        for pred, lab in zip(token_preds, token_labels)
    ]
    true_labels = [
        [id2label[l] for (p, l) in zip(pred, lab) if l != -100]
        for pred, lab in zip(token_preds, token_labels)
    ]

    bio_result = metric.compute(predictions=true_preds, references=true_labels)

    # === Relation classification metrics
    rel_preds = predictions["rel_logits"]
    rel_labels = labels["rel_labels"]

    if rel_preds is not None and rel_labels is not None:
        rel_preds = np.argmax(rel_preds, axis=1)
        rel_labels = np.array(rel_labels)

        rel_acc = accuracy_score(rel_labels, rel_preds)
        rel_f1 = f1_score(rel_labels, rel_preds, average="macro")
        rel_precision = precision_score(rel_labels, rel_preds, average="macro")
        rel_recall = recall_score(rel_labels, rel_preds, average="macro")
    else:
        rel_acc = rel_f1 = rel_precision = rel_recall = 0.0

    return {
        # BIO tagging
        "precision": bio_result["overall_precision"],
        "recall": bio_result["overall_recall"],
        "f1": bio_result["overall_f1"],
        "accuracy": bio_result["overall_accuracy"],

        # Relation prediction
        "rel_acc": rel_acc,
        "rel_f1": rel_f1,
        "rel_precision": rel_precision,
        "rel_recall": rel_recall
    }


#Innlasting av modell, valg av hyperparams og modell-argumenter

In [94]:
import torch
import torch.nn as nn
from transformers import LayoutLMv3Model

class LayoutLMv3WithBinaryRelation(nn.Module):
    def __init__(self, model_name="microsoft/layoutlmv3-base", hidden_size=768, num_labels=5):
        super().__init__()
        self.layoutlmv3 = LayoutLMv3Model.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)

        # Token classification head
        self.token_classifier = nn.Linear(hidden_size, num_labels)

        # Relation classification head
        self.rel_fc = nn.Sequential(
            nn.Linear(4 * hidden_size + 1, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, input_ids, bbox, attention_mask, entity_pairs=None, labels=None, rel_labels=None):
        outputs = self.layoutlmv3(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)

        # Token classification logits
        token_logits = self.token_classifier(sequence_output)

        rel_logits = None
        rel_loss = None
        if entity_pairs is not None and entity_pairs and entity_pairs[0]:
            batch_rel_logits = []
            for batch_idx, pairs in enumerate(entity_pairs):
                logits = []
                for (head_idxs, tail_idxs) in pairs:
                    h = sequence_output[batch_idx]
                    head = h[head_idxs].mean(dim=0)
                    tail = h[tail_idxs].mean(dim=0)

                    h_mul = head * tail
                    h_diff = head - tail
                    h_dot = torch.sum(head * tail).unsqueeze(0)

                    feats = torch.cat([head, tail, h_mul, h_diff, h_dot], dim=0)
                    logit = self.rel_fc(feats).squeeze()
                    logits.append(logit)
                batch_rel_logits.append(torch.stack(logits))
            rel_logits = torch.stack(batch_rel_logits)

            if rel_labels is not None:
                loss_fn = nn.BCEWithLogitsLoss()
                rel_loss = loss_fn(rel_logits, rel_labels)
        else:
            rel_logits = torch.tensor([])

        token_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # reshape to (batch*seq_len, num_labels)
            token_loss = loss_fct(token_logits.view(-1, token_logits.size(-1)), labels.view(-1))

        loss = None
        if token_loss is not None and rel_loss is not None:
            loss = token_loss + rel_loss
        elif token_loss is not None:
            loss = token_loss
        elif rel_loss is not None:
            loss = rel_loss

        return {
            "loss": loss,
            "token_logits": token_logits,
            "rel_logits": rel_logits
        }


In [95]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="/content/layoutlmv3_finetuned_kvp10k",
    num_train_epochs=4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    eval_strategy="steps",
    eval_steps=100,
    save_steps=200,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    logging_dir="./logs",
    logging_steps=50,
    report_to="tensorboard",
    lr_scheduler_type="cosine",
    warmup_steps=500,
    fp16=True
)


model = LayoutLMv3WithBinaryRelation()


#Trainer oppsett
Inneholder:
  - Modellen (LayoutLMv3ForTokenClassification)
  - Args (hyperparametre som: epochs, batch_size, lr, lr_scheduler,    regularisering, eval_steps, metrics)
  - Datasetsplit (train, eval)
  - Tokenizer (from processor)
  - Collator (litt usikker på denne)
  - Metrikker for modellen


In [96]:
from transformers import Trainer
import torch.nn.functional as F
from transformers.data.data_collator import default_data_collator

class CombinedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(
            input_ids=inputs["input_ids"],
            bbox=inputs["bbox"],
            attention_mask=inputs.get("attention_mask"),
            labels=inputs.get("labels"),
            rel_labels=inputs.get("rel_labels"),
            entity_pairs=inputs.get("entity_pairs")
        )

        loss = outputs["loss"]

        return (loss, outputs) if return_outputs else loss



def combined_data_collator(features):
    batch = default_data_collator(features)
    if "rel_labels" in batch:
        batch["rel_labels"] = batch["rel_labels"].float()
    batch["entity_pairs"] = [f["entity_pairs"] for f in features]

    return batch

In [97]:
from transformers import Trainer, EarlyStoppingCallback
training_args

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=processor.tokenizer,
    data_collator=combined_data_collator,
    compute_metrics=compute_metrics,
)

  trainer = Trainer(


#Trening

In [99]:
trainer.train()



Step,Training Loss,Validation Loss


Step,Training Loss,Validation Loss


TypeError: tuple indices must be integers or slices, not str

In [None]:
trainer.evaluate()

#Evaluering på test-datasettet

In [None]:
trainer.evaluate(test_dataset)


#Lagring av beste Modell (i Drive)

In [None]:
# Angi en mappe i Drive (eller lokalt hvis du vil kopiere senere)
output_dir = "/content/drive/MyDrive/layoutlmv3_kvp10k_model_full_dataset"

# Lagre modell og tokenizer
trainer.save_model(output_dir)
processor.save_pretrained(output_dir)  # dette lagrer både tokenizer + feature extracto

#INFERENCE
Laster inn beste fine-tuned modell og dens tilhørende processor fra Drive, samt tilleggsinformasjon som kreves av processoren.



In [None]:
from transformers import AutoProcessor, AutoModelForTokenClassification
from google.colab import drive
drive.mount('/content/drive')


model_path = "/content/drive/MyDrive/layoutlmv3_kvp10k_model_full_dataset"

# Last inn modellen (med dine finetunede vekter)
model = AutoModelForTokenClassification.from_pretrained(model_path)

# Last inn processor (inneholder både tokenizer og feature extractor)
processor = AutoProcessor.from_pretrained(model_path)

# Sett modellen til riktig device (valgfritt, men vanlig)
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

id2label = model.config.id2label
id2label = model.config.id2label
label_map = label2id

#Kode prediksjon og visualisering av dette
###*Tokenisering og input-prosessering med Layout sin Processor ved inference*
Processor brukes her til å gjøre om tekst, bboxes, og bilde til format modell krever. Dette inkl:
- Tokenisering
- Normalisering av bboxes tilhørende hvert token
- Skalering av bilde
- Generering av input-tensorer

NB: Denne prosessen gjøres allerede i Data_Processor notebooken som ferdigstilte datasettet for **denne** notebooken. Selve prosessen er dermed nesten indentisk

In [None]:
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display

#Nødvendig for å plassere boksene på originalt-format på bilde-dokumentet
def unnormalize_box(bbox, width, height):
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]

def predict_and_visualize(doc_id, show_gt=True):
    base_path = "/content/drive/MyDrive/DAT255/KVP10k-dataset/kvp10k/test"

    # === Last inn bilde og metadata
    image_path = f"{base_path}/images/{doc_id}.png"
    ocr_path = f"{base_path}/ocrs/{doc_id}.json"
    gt_path = f"{base_path}/gts/{doc_id}.json"

    image = Image.open(image_path).convert("RGB")
    with open(ocr_path, "r", encoding="utf-8") as f:
        ocr_data = json.load(f)
    with open(gt_path, "r", encoding="utf-8") as f:
        gt_data = json.load(f)

    # === Hent tekst og bokser
    page = ocr_data["pages"][0]
    words = [w["text"] for w in page["words"]]
    raw_boxes = [w["bbox"] for w in page["words"]]
    width, height = page["width"], page["height"]
    norm_boxes = [normalize_bbox(b, width, height) for b in raw_boxes]

    # === Lag word_labels fra GT
    string_labels = iob_from_kvps(words, raw_boxes, gt_data["kvps_list"])
    word_labels = [label_map[l] for l in string_labels]

    # === Encoding for modellen
    encoding = processor(
        image,
        words,
        boxes=norm_boxes,
        word_labels=word_labels,
        return_tensors="pt",
        truncation=True,
        padding="max_length"
    )
    inputs = {k: v.to(model.device) for k, v in encoding.items()}

    # === Modellprediksjon
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    input_ids = encoding["input_ids"].squeeze().tolist()
    labels = encoding["labels"].squeeze().tolist()
    bboxes = encoding["bbox"].squeeze().tolist()
    predictions = outputs.logits.argmax(-1).squeeze().tolist()

    # === Unnormalize bboxes
    unnorm_boxes = [unnormalize_box(b, width, height) for b in bboxes]
    tokens = [processor.tokenizer.decode([tid]) for tid in input_ids]

    # === Filtrer vekk padding og spesialtokens
    filtered = [
        (token, id2label[label], id2label[pred], box)
        for token, label, pred, box in zip(tokens, labels, predictions, unnorm_boxes)
        if label != -100 and token not in ["[PAD]", "[CLS]", "[SEP]"]
    ]

    # === Tegn prediksjoner
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()

    def iob_to_label(label):
        return label[2:].lower() if label.startswith(("B-", "I-")) else "other"

    label2color = {
        "key": "blue",
        "value": "green",
        "other": "gray"
    }

    for token, true, pred, box in filtered:
        if pred == "O":
            continue
        label = iob_to_label(pred)
        draw.rectangle(box, outline=label2color.get(label, "red"), width=2)
        draw.text((box[0] + 5, box[1] - 10), label, fill=label2color.get(label, "red"), font=font)

    print("📷 Modellens prediksjoner:")
    display(image)

    # === Fasit (valgfritt)
    if show_gt:
        gt_img = Image.open(image_path).convert("RGB")
        draw_gt = ImageDraw.Draw(gt_img)

        for word, box, label_id in zip(words, raw_boxes, string_labels):
            if label_id == "O":
                continue
            label_type = iob_to_label(label_id)
            draw_gt.rectangle(box, outline=label2color.get(label_type, "gray"), width=2)
            draw_gt.text((box[0] + 5, box[1] - 10), label_type, fill=label2color.get(label_type, "gray"), font=font)

        print("✅ Ground Truth:")
        display(gt_img)


def predict_relations(doc_id):
    from PIL import Image, ImageDraw, ImageFont

    base_path = "/content/drive/MyDrive/DAT255/KVP10k-dataset/kvp10k/test"
    image_path = f"{base_path}/images/{doc_id}.png"
    ocr_path = f"{base_path}/ocrs/{doc_id}.json"

    image = Image.open(image_path).convert("RGB")
    with open(ocr_path, "r", encoding="utf-8") as f:
        ocr_data = json.load(f)

    page = ocr_data["pages"][0]
    words = [w["text"] for w in page["words"]]
    raw_boxes = [w["bbox"] for w in page["words"]]
    width, height = page["width"], page["height"]
    norm_boxes = [normalize_bbox(b, width, height) for b in raw_boxes]

    encoding = processor(
        image,
        words,
        boxes=norm_boxes,
        return_tensors="pt",
        truncation=True,
        padding="max_length"
    )

    input_ids = encoding["input_ids"]
    bbox = encoding["bbox"]
    image_tensor = encoding["image"]
    attention_mask = encoding["attention_mask"]

    model.eval()
    with torch.no_grad():
        output = model(
            input_ids=input_ids.to(model.device),
            bbox=bbox.to(model.device),
            image=image_tensor.to(model.device),
            attention_mask=attention_mask.to(model.device)
        )

    # === Extract token predictions and hidden states
    logits = output["token_logits"].squeeze()
    hidden_states = output["hidden_states"].squeeze()
    predictions = torch.argmax(logits, dim=-1)

    # === Identify B-KEY and B-VALUE token indices
    keys = []
    values = []
    for idx, pred in enumerate(predictions.tolist()):
        label = id2label.get(pred, "O")
        if label == "B-KEY":
            keys.append(idx)
        elif label == "B-VALUE":
            values.append(idx)

    # === Predict relations
    relations = []
    for k in keys:
        for v in values:
            h_i = hidden_states[k].unsqueeze(0)
            h_j = hidden_states[v].unsqueeze(0)
            rel_logits = model.relation(h_i, h_j)
            pred_rel = torch.argmax(rel_logits, dim=-1).item()
            if pred_rel == 1:  # Only keep predicted "real" relations
                relations.append((k, v))

    # === Visualize arrows
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for k, v in relations:
        key_box = unnormalize_box(bbox[0][k].tolist(), width, height)
        val_box = unnormalize_box(bbox[0][v].tolist(), width, height)

        # draw arrow line
        x0, y0 = (key_box[0] + key_box[2]) / 2, (key_box[1] + key_box[3]) / 2
        x1, y1 = (val_box[0] + val_box[2]) / 2, (val_box[1] + val_box[3]) / 2
        draw.line([x0, y0, x1, y1], fill="red", width=2)

        # draw tokens for context
        draw.rectangle(key_box, outline="blue", width=2)
        draw.rectangle(val_box, outline="green", width=2)

    display(image)


#Velg vilkårlig dokument fra datasettet og prediker

In [None]:
#predict_and_visualize("aaf643426f0250efd10de3d9df63b407292f3fcc2aa335e399c37aca32443ea1")
#predict_relations("aaf643426f0250efd10de3d9df63b407292f3fcc2aa335e399c37aca32443ea1")

predict_and_visualize("aaed61e79aa3edbae844f5775789ebb6aa1a94a23d9cb3468d2cfc974af304e5")
predict_relations("aaed61e79aa3edbae844f5775789ebb6aa1a94a23d9cb3468d2cfc974af304e5")

#predict_and_visualize("aa35720ba3611f946c372cc99d8cd1d78e81265b8ceb51dcdb4672d196944c2b")
#predict_relations("aa35720ba3611f946c372cc99d8cd1d78e81265b8ceb51dcdb4672d196944c2b")

#predict_and_visualize("aa7c58830d0e84f98e9fdec1bc9e131227f9b00106aa3c78bc8ea346cfb9eac0")
#predict_relations("aa7c58830d0e84f98e9fdec1bc9e131227f9b00106aa3c78bc8ea346cfb9eac0")

#predict_and_visualize("faa5d71172e2e9959b41a5aec4fd2ab700534d1b2729484d2d5f26472cd56cfa")
#predict_relations("faa5d71172e2e9959b41a5aec4fd2ab700534d1b2729484d2d5f26472cd56cfa")

#predict_and_visualize("ffe462e43b9dff12e78ea8fb69332abfb789da171a8597f5bb961853e06e6fa2")
#predict_relations("ffe462e43b9dff12e78ea8fb69332abfb789da171a8597f5bb961853e06e6fa2")

#predict_and_visualize("feb2c4b21388318c7a51cc0aaf0e7c673a07f5204a40549a281bef065bb77925")
#predict_relations("feb2c4b21388318c7a51cc0aaf0e7c673a07f5204a40549a281bef065bb77925")

#predict_and_visualize("feaf84d435bd46100db82de51f5a989ff4d39fdcdb040a7044720b943e34b7d7")
#predict_relations("feaf84d435bd46100db82de51f5a989ff4d39fdcdb040a7044720b943e34b7d7")

#predict_and_visualize("df6b0a4cf1908bb95be874e4efa59411c685095d7bb596879961563503b5c239")
#predict_relations("df6b0a4cf1908bb95be874e4efa59411c685095d7bb596879961563503b5c239")


In [None]:

import torch.nn as nn
from transformers import LayoutLMv3Model

class LayoutLMv3WithBinaryRelation(nn.Module):
    def __init__(self, model_name="microsoft/layoutlmv3-base", hidden_size=768):
        super().__init__()
        self.layoutlmv3 = LayoutLMv3Model.from_pretrained(model_name)
        self.relation_classifier = nn.Bilinear(hidden_size, hidden_size, 1)
        self.dropout = nn.Dropout(0.1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, bbox, attention_mask, entity_pairs):
        outputs = self.layoutlmv3(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state

        batch_relations = []
        for batch_idx, pairs in enumerate(entity_pairs):
            head_tail_preds = []
            for (head_idxs, tail_idxs) in pairs:
                head_repr = sequence_output[batch_idx, head_idxs].mean(dim=0)
                tail_repr = sequence_output[batch_idx, tail_idxs].mean(dim=0)

                rel_logit = self.relation_classifier(self.dropout(head_repr), self.dropout(tail_repr))
                rel_prob = self.sigmoid(rel_logit).squeeze()
                head_tail_preds.append(rel_prob)
            batch_relations.append(torch.stack(head_tail_preds))
        return batch_relations


In [None]:

# Example usage after BIO tagging to get entity pairs:
# entity_pairs = [[([start1, start1+1], [start2]), ([start3], [start4, start4+1])], ...]
# This should be constructed based on decoded BIO tag spans

model = LayoutLMv3WithBinaryRelation()
model.eval()

# Dummy input
batch_size, seq_len = 2, 512
input_ids = torch.randint(0, 1000, (batch_size, seq_len))
bbox = torch.randint(0, 1000, (batch_size, seq_len, 4))
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)

entity_pairs = [
    [([10, 11], [15]), ([20], [30, 31])],
    [([5], [6]), ([100], [101])]
]

with torch.no_grad():
    relation_probs = model(input_ids, bbox, attention_mask, entity_pairs)

for i, probs in enumerate(relation_probs):
    print(f"Batch {i}: {[float(p) for p in probs]}")


In [None]:

# ─────────────── Innlasting av RE‑dataset ───────────────
from datasets import load_from_disk
import torch
from torch.utils.data import Dataset, DataLoader

# Last inn train og test
train_hf = load_from_disk("/content/drive/MyDrive/RE_ready/re_dataset_train_combined")
test_hf  = load_from_disk("/content/drive/MyDrive/RE_ready/re_dataset_test_combined")

class REPartnerDataset(Dataset):
    def __init__(self, hf_dataset):
        self.ds = hf_dataset

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        ex = self.ds[idx]
        hi    = torch.tensor(ex["h_i"],    dtype=torch.float)
        hj    = torch.tensor(ex["h_j"],    dtype=torch.float)
        himul = torch.tensor(ex["h_mul"],  dtype=torch.float)
        hidiff= torch.tensor(ex["h_diff"], dtype=torch.float)
        hdot  = torch.tensor([ex["h_dot"]], dtype=torch.float)
        label = torch.tensor(ex["label"],   dtype=torch.float)
        return hi, hj, himul, hidiff, hdot, label

train_ds = REPartnerDataset(train_hf)
test_ds  = REPartnerDataset(test_hf)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=32)


In [None]:

# ─────────────── Klassifikator ───────────────
import torch.nn as nn

class RelationClassifier(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(4 * hidden_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, hi, hj, himul, hidiff, hdot):
        x = torch.cat([hi, hj, himul, hidiff, hdot], dim=-1)
        return self.fc(x).squeeze(-1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_dim = len(train_hf[0]["h_i"])
model = RelationClassifier(hidden_dim).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()


In [None]:

# ─────────────── Trenings‑loop ───────────────
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for hi, hj, himul, hidiff, hdot, label in train_loader:
        hi, hj, himul, hidiff, hdot, label = [t.to(device) for t in (hi, hj, himul, hidiff, hdot, label)]
        optimizer.zero_grad()
        logits = model(hi, hj, himul, hidiff, hdot)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * hi.size(0)
    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs} — loss: {avg_loss:.4f}")


In [None]:

# ─────────────── Evaluerings‑loop ───────────────
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for hi, hj, himul, hidiff, hdot, label in test_loader:
        hi, hj, himul, hidiff, hdot = [t.to(device) for t in (hi, hj, himul, hidiff, hdot)]
        logits = model(hi, hj, himul, hidiff, hdot)
        probs = torch.sigmoid(logits).cpu().numpy()
        all_preds.extend(probs)
        all_labels.extend(label.numpy())

bin_preds = [1 if p >= 0.5 else 0 for p in all_preds]

acc = accuracy_score(all_labels, bin_preds)
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, bin_preds, average="binary")

print("\n=== Evalueringsresultater ===")
print(f"Accuracy : {acc:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall   : {recall:.4f}")
print(f"F1-score : {f1:.4f}")
