# KVP10k Key-Value Relation Extraction

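Denne Colab-notebooken viser hvordan du kan kjøre LayoutLMv3-basert nøkkel-verdi-relasjonsekstraksjon på KVP10k-datasettet, og visualisere resultatet med piler mellom `KEY`- og `VALUE`-spans. Merk at klassifiserings- og relasjonslagene må lastes med finjusterte vekter; ellers er de tilfeldig initialisert, og prediksjonene blir meningsløse.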


In [None]:
# Install the required packages (IPython `!` shell escape: runs pip in the notebook VM)
!pip install transformers torch pillow

In [None]:
# Mount Google Drive at /content/drive so dataset files stored on Drive can be read
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import json
import math
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import LayoutLMv3Processor, LayoutLMv3Model

# Hjelpefunksjoner
def normalize_bbox(bbox, width, height):
    """Scale a page-space box onto LayoutLMv3's 0-1000 coordinate grid.

    Args:
        bbox: ``(left, top, right, bottom)`` in the same units as ``width``/``height``.
        width: Page width.
        height: Page height.

    Returns:
        ``[left, top, right, bottom]`` as ints, truncated toward zero and clamped
        to the closed range ``[0, 1000]``.
    """
    left, top, right, bottom = bbox

    def scale(value, extent):
        # OCR boxes can extend slightly past the page edge; LayoutLMv3 expects
        # coordinates normalized to 0-1000, so clamp rather than pass them through.
        return min(max(int(1000 * value / extent), 0), 1000)

    return [
        scale(left, width),
        scale(top, height),
        scale(right, width),
        scale(bottom, height),
    ]

def unnormalize_box(box, width, height):
    """Convert a box on the 0-1000 grid back to integer page coordinates.

    The first four entries of ``box`` are read as ``(left, top, right, bottom)``.
    x values are scaled by ``width`` and y values by ``height``, and each result
    is truncated toward zero with ``int``.
    """
    extents = (width, height, width, height)
    return [int(box[axis] * extent / 1000) for axis, extent in enumerate(extents)]

def extract_spans(labels, mask, kind):
    """Group token positions into ``B-<kind>`` / ``I-<kind>`` spans.

    A span opens at ``B-<kind>`` and grows while consecutive tokens are
    ``I-<kind>``. Any other label, or a position whose mask value is 0, closes
    the open span. An ``I-<kind>`` with no open span is ignored.

    Returns:
        List of spans, each a list of token positions in order.
    """
    begin_tag, inside_tag = f"B-{kind}", f"I-{kind}"
    spans = []
    current = []
    for position, (label, keep) in enumerate(zip(labels, mask)):
        if keep != 0 and label == inside_tag and current:
            current.append(position)
            continue
        # Anything that does not extend the open span closes it ...
        if current:
            spans.append(current)
        # ... and a B- tag on an unmasked position immediately starts a new one.
        current = [position] if keep != 0 and label == begin_tag else []
    if current:
        spans.append(current)
    return spans

In [None]:
# Modellklasse
class LayoutWithRelationModel(torch.nn.Module):
    """LayoutLMv3 encoder with a token-classification head and a pairwise relation head.

    The token head predicts one of ``num_labels`` classes per token. The relation
    head scores candidate (key span, value span) pairs built from mean-pooled span
    representations and mean-pooled token-class probabilities.
    """

    def __init__(self, model_name='microsoft/layoutlmv3-base',
                 hidden_size=768, num_labels=5):
        """
        Args:
            model_name: Checkpoint passed to ``LayoutLMv3Model.from_pretrained``.
            hidden_size: Encoder hidden size; must match the checkpoint
                (768 for ``layoutlmv3-base``).
            num_labels: Number of token classes.
        """
        super().__init__()
        self.layout = LayoutLMv3Model.from_pretrained(model_name)
        self.dropout = torch.nn.Dropout(0.1)
        self.token_classifier = torch.nn.Linear(hidden_size, num_labels)
        # Pair features: h_i, h_j, h_i*h_j, h_i-h_j (4 * hidden), the dot product
        # h_i.h_j (1), and both spans' mean class probabilities (2 * num_labels).
        rel_in = 4 * hidden_size + 1 + 2 * num_labels
        self.rel_fc = torch.nn.Sequential(
            torch.nn.Linear(rel_in, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
        )

    def _pair_features(self, h_b, p_b, k_idx, v_idx):
        """Build the relation-head input vector for one (key span, value span) pair."""
        h_i = h_b[k_idx].mean(0)
        h_j = h_b[v_idx].mean(0)
        p_i = p_b[k_idx].mean(0)
        p_j = p_b[v_idx].mean(0)
        h_dot = torch.sum(h_i * h_j).unsqueeze(0)
        return torch.cat([h_i, h_j, h_i * h_j, h_i - h_j, h_dot, p_i, p_j], dim=0)

    def _score_pairs(self, hidden, token_probs, entity_pairs):
        """Score every candidate pair; return a zero-padded ``(len(entity_pairs), max_pairs)`` tensor."""
        batch_scores = []
        for b_idx, pairs in enumerate(entity_pairs):
            h_b = hidden[b_idx]
            p_b = token_probs[b_idx]
            if not pairs:
                # torch.stack([]) raises, so an element without candidates gets an
                # empty score row that is padded like any other short row.
                batch_scores.append(hidden.new_zeros(0))
                continue
            scores = [
                self.rel_fc(self._pair_features(h_b, p_b, k_idx, v_idx)).squeeze()
                for k_idx, v_idx in pairs
            ]
            batch_scores.append(torch.stack(scores))
        if not batch_scores:
            return hidden.new_zeros((0, 0))
        max_len = max(s.size(0) for s in batch_scores)
        # Pad with 0.0 logits; callers must use their own pair counts to ignore them.
        padded = [
            torch.cat([s, s.new_zeros(max_len - s.size(0))], dim=0)
            for s in batch_scores
        ]
        return torch.stack(padded)

    def forward(self, input_ids, bbox, pixel_values,
                attention_mask, entity_pairs=None):
        """Run the encoder and both heads.

        Args:
            input_ids, bbox, pixel_values, attention_mask: Encoder inputs, forwarded
                unchanged to ``LayoutLMv3Model``.
            entity_pairs: Optional list with one entry per batch element. Each entry
                is a list of ``(key_token_indices, value_token_indices)`` pairs, and
                every index list must be non-empty. An entry may itself be empty.

        Returns:
            dict with:
                ``'token_logits'``: the token-head output.
                ``'rel_logits'``: ``None`` when ``entity_pairs`` is ``None``, otherwise
                a ``(len(entity_pairs), max_pairs)`` tensor whose padded tail is 0.0.
        """
        out = self.layout(
            input_ids=input_ids,
            bbox=bbox,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden = self.dropout(out.last_hidden_state)
        token_logits = self.token_classifier(hidden)
        token_probs = torch.softmax(token_logits, dim=-1)

        rel_logits = None
        if entity_pairs is not None:
            rel_logits = self._score_pairs(hidden, token_probs, entity_pairs)
        return {'token_logits': token_logits, 'rel_logits': rel_logits}

In [None]:
# Funksjon for å kjøre hele pipeline og tegne relasjoner
def _span_center(span, token_boxes, width, height):
    """Return the pixel-space centre of ``span`` as the mean of its token-box centres.

    ``token_boxes`` holds boxes on the 0-1000 grid; ``width``/``height`` are the
    target pixel dimensions.
    """
    xs = [(token_boxes[i][0] + token_boxes[i][2]) / 2 for i in span]
    ys = [(token_boxes[i][1] + token_boxes[i][3]) / 2 for i in span]
    return ((sum(xs) / len(xs)) * width / 1000, (sum(ys) / len(ys)) * height / 1000)


def _draw_arrow(draw, start, end, head_len=15, head_angle=math.radians(25), color='red'):
    """Draw a line from ``start`` to ``end`` that ends in a filled triangular head."""
    dx, dy = end[0] - start[0], end[1] - start[1]
    dist = math.hypot(dx, dy)
    if dist == 0:
        # Zero-length arrow: there is no direction to point the head in.
        return
    ux, uy = dx / dist, dy / dist
    if dist > head_len:
        # Stop the shaft at the base of the head so the two do not overlap.
        base = (end[0] - ux * head_len, end[1] - uy * head_len)
        draw.line([start, base], fill=color, width=2)
    cos_a, sin_a = math.cos(head_angle), math.sin(head_angle)
    # Barbs: the unit direction rotated by -/+ head_angle, stepping back from the tip.
    left = (end[0] - head_len * (ux * cos_a + uy * sin_a),
            end[1] - head_len * (uy * cos_a - ux * sin_a))
    right = (end[0] - head_len * (ux * cos_a - uy * sin_a),
             end[1] - head_len * (uy * cos_a + ux * sin_a))
    draw.polygon([end, left, right], fill=color)


def predict_relations_per_key(doc_id, threshold=0.5,
                              model=None, processor=None,
                              id2label=None, device=None,
                              base_path=None):
    """Predict KEY/VALUE spans for one document, link each KEY to its best VALUE, and display it.

    Reads ``{base_path}/images/{doc_id}.png`` and ``{base_path}/ocrs/{doc_id}.json``
    and uses only the first OCR page. The model runs twice: a token pass, then a
    relation pass over every KEY x VALUE pair. The highest-scoring VALUE is kept
    per KEY. KEY tokens are drawn in blue, VALUE tokens in green, and accepted
    relations as red arrows labelled with their probability.

    Args:
        doc_id: Document id used to build the image and OCR file names.
        threshold: Minimum sigmoid relation probability for an arrow to be drawn.
        model: Module returning a dict with ``'token_logits'`` and ``'rel_logits'``.
        processor: LayoutLMv3 processor; it must have OCR disabled because words
            and boxes are supplied here.
        id2label: Mapping from token-class id to a BIO label such as ``'B-KEY'``.
        device: Device the encoded inputs are moved to.
        base_path: Dataset split directory containing ``images/`` and ``ocrs/``.
    """
    # Load image and OCR
    img_p = f"{base_path}/images/{doc_id}.png"
    ocr_p = f"{base_path}/ocrs/{doc_id}.json"
    # Close the file handle right away; convert() returns an in-memory copy.
    with Image.open(img_p) as src:
        image = src.convert('RGB')
    with open(ocr_p, 'r', encoding='utf-8') as f:
        ocr = json.load(f)
    page = ocr['pages'][0]
    words = [w['text'] for w in page['words']]
    W, H = page['width'], page['height']
    norm = [normalize_bbox(w['bbox'], W, H) for w in page['words']]
    # Drawing happens on the image, so map normalized boxes back with the image's
    # pixel size; the OCR page size is only right for normalization.
    img_w, img_h = image.size

    # Encoding
    enc = processor(image, words, boxes=norm,
                    return_tensors='pt', truncation=True, padding='max_length')
    for k in enc:
        enc[k] = enc[k].to(device)
    token_boxes = enc['bbox'][0].cpu().tolist()
    mask = enc['attention_mask'][0].cpu().tolist()
    inputs = {name: enc[name]
              for name in ('input_ids', 'bbox', 'pixel_values', 'attention_mask')}

    # 1) Token predictions
    model.eval()
    with torch.no_grad():
        out_t = model(**inputs)
    preds = out_t['token_logits'][0].argmax(-1).tolist()
    labels = [id2label[i] for i in preds]
    key_spans = extract_spans(labels, mask, 'KEY')
    value_spans = extract_spans(labels, mask, 'VALUE')

    # 2) Relation pass. Skip it when there is nothing to pair, because an empty
    #    pair list would make the relation head stack an empty list of scores.
    pairs = [(k, v) for k in key_spans for v in value_spans]
    best = {}
    if pairs:
        with torch.no_grad():
            out_r = model(**inputs, entity_pairs=[pairs])
        # Only the first len(pairs) logits are real; anything after is padding.
        probs = torch.sigmoid(out_r['rel_logits'][0][:len(pairs)]).tolist()
        # 3) Keep the most probable VALUE for each KEY.
        for (k, v), p in zip(pairs, probs):
            kt = tuple(k)
            if kt not in best or p > best[kt][2]:
                best[kt] = (k, v, p)
    relations = [(k, v, p) for (k, v, p) in best.values() if p >= threshold]

    # 4) Draw spans and relations
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for spans, colour in ((key_spans, 'blue'), (value_spans, 'green')):
        for span in spans:
            for i in span:
                draw.rectangle(unnormalize_box(token_boxes[i], img_w, img_h),
                               outline=colour, width=2)

    if not relations:
        draw.text((10, 10), 'Ingen relasjoner funnet',
                  fill='orange', font=font)
    else:
        for k, v, p in relations:
            kc = _span_center(k, token_boxes, img_w, img_h)
            vc = _span_center(v, token_boxes, img_w, img_h)
            _draw_arrow(draw, kc, vc)
            draw.text((vc[0] + 3, vc[1] - 10), f"{p:.2f}",
                      fill='red', font=font)

    display(image)

In [None]:
# Demo
# Set up device, processor, model and id2label.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Words and boxes come from the dataset's own OCR files, so the processor's built-in
# OCR must be off: with apply_ocr=True the processor rejects user-supplied boxes.
processor = LayoutLMv3Processor.from_pretrained('microsoft/layoutlmv3-base',
                                                apply_ocr=False)
model = LayoutWithRelationModel().to(device)
# The token and relation heads are freshly initialised here. Point CHECKPOINT at a
# fine-tuned state_dict to get meaningful predictions. weights_only=True avoids
# unpickling arbitrary objects from the file.
CHECKPOINT = None
if CHECKPOINT is not None:
    model.load_state_dict(torch.load(CHECKPOINT, map_location=device,
                                     weights_only=True))
# Placeholder BIO label map sized to the model's 5 token classes. It must match the
# label order used during training -- TODO confirm against the training setup.
id2label = {0: 'O', 1: 'B-KEY', 2: 'I-KEY', 3: 'B-VALUE', 4: 'I-VALUE'}

# Run prediction for one document
predict_relations_per_key(
    doc_id='f1689dd6-ac75-46ca-a32d-ae56d571dfa6',
    threshold=0.5,
    model=model,
    processor=processor,
    id2label=id2label,
    device=device,
    base_path='/content/drive/MyDrive/KVP10k-dataset/kvp10k/test'
)