In [4]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoProcessor, BartForConditionalGeneration, BartConfig
from transformers.modeling_outputs import BaseModelOutput
from datasets import load_dataset

class LayoutLMv3BART(nn.Module):
    def __init__(self, encoder_name="microsoft/layoutlmv3-base", num_labels=256):
        super().__init__()

        # ----- Encoder -----
        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden_size = self.encoder.config.hidden_size  # 768

        # ----- BART Decoder -----
        config = BartConfig.from_pretrained("facebook/bart-base")
        config.encoder_layers = 0                     # remove BART encoder
        config.d_model = hidden_size                  # match LayoutLMv3 dim
        config.vocab_size = num_labels                # segment/NER label count
        config.max_position_embeddings = 1024         # safe upper bound
        config.num_hidden_layers = 1
        print(config)
        self.decoder = BartForConditionalGeneration(config)

    def forward(self,
                input_ids,
                bbox,
                pixel_values,
                attention_mask,
                labels=None):

        # ----- Encode -----
        enc_out = self.encoder(
            input_ids=input_ids,
            bbox=bbox,
            pixel_values=pixel_values,
            attention_mask=attention_mask
        ).last_hidden_state   # shape: [B, 509, 768]

        # BART expects encoder_outputs as a tuple
        encoder_outputs = (enc_out,)

        # ----- Decode -----
        out = self.decoder(
            encoder_outputs=encoder_outputs,
            labels=labels,                # BART handles shifting internally
            return_dict=True
        )

        return out

    @torch.no_grad()
    def generate(self, input_ids, bbox, pixel_values, attention_mask, max_length=128):
        # Encode like LayoutLMv3
        enc_out = self.encoder(
            input_ids=input_ids,
            bbox=bbox,
            pixel_values=pixel_values,
            attention_mask=attention_mask
        ).last_hidden_state
    
        # WRAP in correct HF object
        encoder_outputs = BaseModelOutput(last_hidden_state=enc_out)
    
        # Now decoder.generate works
        pred = self.decoder.generate(
            encoder_outputs=encoder_outputs,
            max_length=max_length
        )
        
        return pred


processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

train_ds = load_dataset("parquet", data_files="../data/funsd-v3/funsd/train-00000-of-00001.parquet", split="train")
example = train_ds[5]
image = example["image"]
words = example["tokens"]
boxes = example["bboxes"]
labels = example["ner_tags"]

device = 'cuda'

encoding = processor(
    image,
    words,
    boxes=boxes,
    word_labels=labels,
    return_tensors="pt"
)

word_labels = encoding.pop("labels")

model = LayoutLMv3BART(num_labels=256)
model = model.to(device)
out = model(
    input_ids=encoding["input_ids"].to(device),
    bbox=encoding["bbox"].to(device),
    pixel_values=encoding["pixel_values"].to(device),
    attention_mask=encoding["attention_mask"].to(device),
    labels=word_labels.to(device)
)

generated = model.generate(
    input_ids=encoding["input_ids"].to(device),
    bbox=encoding["bbox"].to(device),
    pixel_values=encoding["pixel_values"].to(device),
    attention_mask=encoding["attention_mask"].to(device),
)

print(out.loss, out.logits.shape, generated)

BartConfig {
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "dtype": "float32",
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 0,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "no_



tensor(3.2852, device='cuda:0', grad_fn=<NllLossBackward0>) torch.Size([1, 312, 256]) tensor([[  2,   0,   0,   0, 220, 220, 220,  79,  79,  79, 134, 134, 134,  41,
          41,  41, 222, 222, 222, 253, 253, 253, 119, 119, 119,  81,  81,  81,
           5,   5,   5, 236, 236, 236, 207, 207, 207, 137, 137, 137, 161, 161,
         161, 232, 232, 232, 215, 215, 215,  14,  14,  14, 147, 147, 147, 226,
         226, 226,   4,   4,   4,  75,  75,  75, 247, 247, 247, 219, 219, 219,
          46,  46,  46,  24,  24,  24,   9,   9,   9, 146, 146, 146,  86,  86,
          86, 165, 165, 165,  99,  99,  99, 239, 239, 239, 195, 195, 195, 248,
         248, 248, 120, 120, 120,  12,  12,  12,  71,  71,  71, 169, 169, 169,
         161, 161, 201, 201, 201, 212, 212, 212,  46,  46, 113, 113, 113,  46,
          46,   2]], device='cuda:0')


In [5]:
from PIL import Image
from torch.nn.utils.rnn import pad_sequence


def preprocess_batch(batch):
    max_len = 512
    images = [img.convert("RGB") for img in batch["image"]]
    words = batch["tokens"]
    boxes = batch["bboxes"]
    word_labels = [[t + 3 for t in tags][:max_len - 1] +  [2] + [-100] * max(max_len - len(tags) - 1, 0) for tags in batch["ner_tags"]]
    enc = processor(
        images,
        words,
        boxes=boxes,
        word_labels=word_labels,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        max_length=max_len
    )
    enc['decoder_labels'] = word_labels

    return enc

train_ds = load_dataset("parquet", data_files="../data/funsd-v3/funsd/train-00000-of-00001.parquet", split="train")
train_ds = train_ds.map(
    preprocess_batch,
    batched=True,
    remove_columns=train_ds.column_names
)
train_ds.set_format('pytorch')

test_ds = load_dataset("parquet", data_files="../data/funsd-v3/funsd/test-00000-of-00001.parquet", split="train")
test_ds = test_ds.map(
    preprocess_batch,
    batched=True,
    remove_columns=test_ds.column_names
)
test_ds.set_format('pytorch')

train_ds[0]['input_ids'].shape, train_ds[0]['decoder_labels'].shape, train_ds[0]['decoder_labels']

Map:   0%|          | 0/149 [00:00<?, ? examples/s]

(torch.Size([512]),
 torch.Size([512]),
 tensor([   3,    6,    6,    6,    8,    6,    6,    3,    4,    5,    5,    5,
            5,    5,    6,    7,    7,    7,    7,    8,    9,    9,    9,    9,
            9,    8,    9,    9,    6,    7,    8,    9,    9,    6,    7,    7,
            8,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,
            9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,
            9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,
            9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,
            9,    9,    9,    9,    9,    9,    6,    7,    7,    7,    8,    9,
            9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,
            9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,    9,
            9,    9,    9,    4,    5,    5,    5,    5,    5,    6,    7,    8,
            9,    9,    9,    9,    9,    8,    9,    9,    9,    9, 

In [15]:
import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"

model = LayoutLMv3BART(num_labels=256).to(device)


train_dataloader = DataLoader(train_ds, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=32)

optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)

num_epochs = 20
num_training_steps = len(train_dataloader) * num_epochs
num_warmup_steps = int(0.1 * num_training_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

for epoch in range(num_epochs):
    model.train()

    loop = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for batch in loop:

        # Move all encoder inputs
        input_ids = batch["input_ids"].to(device)
        bbox = batch["bbox"].to(device)
        pixel_values = batch["pixel_values"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Decoder labels (word-level labels)
        labels = batch["labels"].to(device)

        # Forward
        outputs = model(
            input_ids=input_ids,
            bbox=bbox,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs.loss

        # Backprop
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        loop.set_postfix(loss=loss.item())

    model.eval()
    total_val_loss = 0
    
    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Validation"):
    
            input_ids = batch["input_ids"].to(device)
            bbox = batch["bbox"].to(device)
            pixel_values = batch["pixel_values"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["decoder_labels"].to(device)
    
            outputs = model(
                input_ids=input_ids,
                bbox=bbox,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                labels=labels
            )
    
            total_val_loss += outputs.loss.item()
    
    print("Validation Loss:", total_val_loss / len(test_dataloader))

    # acc = exact_match_accuracy(model, test_dataloader)
    # print("Accuracy:", acc)

torch.save(model.state_dict(), "../models/layoutlmv3_bart_segmentation.pth")

BartConfig {
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "dtype": "float32",
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 0,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "no_

Epoch 1/20: 100%|██████████| 5/5 [00:09<00:00,  1.83s/it, loss=3.31]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 1.873741626739502


Epoch 2/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.72]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 1.1236109733581543


Epoch 3/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.37]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.9227676391601562


Epoch 4/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.31]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.8137649893760681


Epoch 5/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.17]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.7981394231319427


Epoch 6/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.11]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


Validation Loss: 0.7640697956085205


Epoch 7/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.13]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.7617313265800476


Epoch 8/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.18]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.745508998632431


Epoch 9/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.13]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


Validation Loss: 0.7329736351966858


Epoch 10/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.1] 
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.7095493674278259


Epoch 11/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.06]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.7204688489437103


Epoch 12/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.1]  
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.6976237893104553


Epoch 13/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.08]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.7044256031513214


Epoch 14/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=0.963]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.7024900615215302


Epoch 15/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.05]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.6880208551883698


Epoch 16/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=0.973]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


Validation Loss: 0.6976084411144257


Epoch 17/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.03] 
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.6843081414699554


Epoch 18/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.08] 
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.6929122507572174


Epoch 19/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=0.968]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.6839370131492615


Epoch 20/20: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=0.999]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 0.6900151669979095


In [16]:
import torch
from tqdm import tqdm

def exact_match_accuracy(model, dataloader, device="cuda"):
    model.eval()
    total_correct = 0
    total_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader):

            input_ids = batch["input_ids"].to(device)
            bbox = batch["bbox"].to(device)
            pixel_values = batch["pixel_values"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["decoder_labels"].to(device)     # [B, T_gt]

            B, T_gt = labels.shape

            # 1. Generate predictions (can be shorter or longer)
            preds = model.generate(
                input_ids=input_ids,
                bbox=bbox,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                max_length=T_gt,
            )   # [B, T_pred <= T_gt]

            # 2. Pad predictions to match GT length
            T_pred = preds.size(1)
            if T_pred < T_gt:
                pad = torch.full(
                    (B, T_gt - T_pred),
                    fill_value=-999,   # invalid token (never matches)
                    device=device,
                    dtype=preds.dtype
                )
                preds = torch.cat([preds, pad], dim=1)
            elif T_pred > T_gt:
                preds = preds[:, :T_gt]

            # 3. Mask ignore_index in labels
            mask = labels != -100

            # 4. Count matches
            correct = ((preds == labels) & mask).sum().item()
            total = mask.sum().item()

            total_correct += correct
            total_count += total

    return total_correct / total_count if total_count > 0 else 0.0

exact_match_accuracy(model, test_dataloader)

100%|██████████| 2/2 [00:42<00:00, 21.43s/it]


0.17757222793194016

In [105]:
test_ds1 = load_dataset("parquet", data_files="../data/funsd-v3/funsd/test-00000-of-00001.parquet", split="train")
test_ds2 = test_ds1.map(
    preprocess_batch,
    batched=True,
    remove_columns=test_ds1.column_names
)
test_ds2.set_format('pytorch')
len(test_ds1[0]['tokens']), len(test_ds1[0]['ner_tags']), len(test_ds2[0]['input_ids']), test_ds1

(223,
 223,
 512,
 Dataset({
     features: ['id', 'tokens', 'bboxes', 'ner_tags', 'image'],
     num_rows: 50
 }))

In [106]:
test_ds1[0]['tokens']

['TO:',
 'DATE:',
 '3',
 'Fax:',
 'NOTE:',
 '82092117',
 '614',
 '-466',
 '-5087',
 'Dec',
 '10',
 "'98",
 '17',
 ':46',
 'P.',
 '01',
 'ATT.',
 'GEN.',
 'ADMIN.',
 'OFFICE',
 'Attorney',
 'General',
 'Betty',
 'D.',
 'Montgomery',
 'CONFIDENTIAL',
 'FACSIMILE',
 'TRANSMISSION',
 'COVER',
 'SHEET',
 '(614)',
 '466-',
 '5087',
 'FAX',
 'NO.',
 'George',
 'Baroody',
 '(336)',
 '335-',
 '7392',
 'FAX',
 'NUMBER:',
 'PHONE',
 'NUMBER:',
 '(336)',
 '335-',
 '7363',
 'NUMBER',
 'OF',
 'PAGES',
 'INCLUDING',
 'COVER',
 'SHEET:',
 'June',
 'Flynn',
 'for',
 'Eric',
 'Brown/',
 '(614)',
 '466-',
 '8980',
 'SENDER',
 '/PHONE',
 'NUMBER:',
 'SPECIAL',
 'INSTRUCTIONS:',
 'IF',
 'YOU',
 'DO',
 'NOT',
 'RECEIVE',
 'ANY',
 'OF',
 'THE',
 'PAGES',
 'PROPERLY,',
 'PLEASE',
 'CONTACT',
 'SENDER',
 'AS',
 'SOON',
 'AS',
 'POSSIBLE',
 'THIS',
 'MESSAGE',
 'IS',
 'INTENDED',
 'ONLY',
 'FOR',
 'THE',
 'USE',
 'OF',
 'THE',
 'INDIVIDUAL',
 'OR',
 'ENTITY',
 'TO',
 'WHOM',
 'IT',
 'IS',
 'ADDRESSED',
 'AND',
