In [80]:
import torch
import torch.nn as nn
from transformers import AutoModel, BartForConditionalGeneration, BartConfig
from transformers.modeling_outputs import BaseModelOutput

class LayoutLMv3BART(nn.Module):
    def __init__(self, encoder_name="microsoft/layoutlmv3-base", num_labels=256):
        super().__init__()

        # ----- Encoder -----
        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden_size = self.encoder.config.hidden_size  # 768

        # ----- BART Decoder -----
        config = BartConfig.from_pretrained("facebook/bart-base")
        config.encoder_layers = 0                     # remove BART encoder
        config.d_model = hidden_size                  # match LayoutLMv3 dim
        config.vocab_size = num_labels                # segment/NER label count
        config.max_position_embeddings = 1024         # safe upper bound

        self.decoder = BartForConditionalGeneration(config)

    def forward(self,
                input_ids,
                bbox,
                pixel_values,
                attention_mask,
                labels=None):

        # ----- Encode -----
        enc_out = self.encoder(
            input_ids=input_ids,
            bbox=bbox,
            pixel_values=pixel_values,
            attention_mask=attention_mask
        ).last_hidden_state   # shape: [B, 509, 768]

        # BART expects encoder_outputs as a tuple
        encoder_outputs = (enc_out,)

        # ----- Decode -----
        out = self.decoder(
            encoder_outputs=encoder_outputs,
            labels=labels,                # BART handles shifting internally
            return_dict=True
        )

        return out

    @torch.no_grad()
    def generate(self, input_ids, bbox, pixel_values, attention_mask, max_length=128):
        # Encode like LayoutLMv3
        enc_out = self.encoder(
            input_ids=input_ids,
            bbox=bbox,
            pixel_values=pixel_values,
            attention_mask=attention_mask
        ).last_hidden_state
    
        # WRAP in correct HF object
        encoder_outputs = BaseModelOutput(last_hidden_state=enc_out)
    
        # Now decoder.generate works
        pred = self.decoder.generate(
            encoder_outputs=encoder_outputs,
            max_length=max_length
        )
        
        return pred


processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

train_ds = load_dataset("parquet", data_files="../data/funsd-v3/funsd/train-00000-of-00001.parquet", split="train")
example = train_ds[5]
image = example["image"]
words = example["tokens"]
boxes = example["bboxes"]
labels = example["ner_tags"]

device = 'cuda'

encoding = processor(
    image,
    words,
    boxes=boxes,
    word_labels=labels,
    return_tensors="pt"
)

word_labels = encoding.pop("labels")

model = LayoutLMv3BART(num_labels=256)
model = model.to(device)
out = model(
    input_ids=encoding["input_ids"].to(device),
    bbox=encoding["bbox"].to(device),
    pixel_values=encoding["pixel_values"].to(device),
    attention_mask=encoding["attention_mask"].to(device),
    labels=word_labels.to(device)
)

generated = model.generate(
    input_ids=encoding["input_ids"].to(device),
    bbox=encoding["bbox"].to(device),
    pixel_values=encoding["pixel_values"].to(device),
    attention_mask=encoding["attention_mask"].to(device),
)

print(out.loss, out.logits.shape, generated)

tensor(3.2976, device='cuda:0', grad_fn=<NllLossBackward0>) torch.Size([1, 312, 256]) tensor([[  2,   0,  95,  95,  95,  77,  77,  77,  40,  40,  40,  14,  14,  14,
         106, 106, 106, 172, 172, 172, 166, 166, 166,  49,  49,  49,  23,  23,
          23, 101, 101, 101, 247, 247, 247,  78,  78,  78, 213, 213, 213,  64,
          64,  64,  69,  69,  69,  38,  38,  38,  82,  82,  82, 190, 190, 190,
          76,  76,  76,  31,  31,  94,  94,  94,  66,  66,  66,  69,  69, 192,
         192, 192, 184, 184, 184,  41,  41,  41, 120, 120, 189, 189, 189, 201,
         201, 201, 120, 120, 120, 165, 165, 165,  55,  55,  55, 130, 130, 130,
         161, 161, 161,  83,  83,  83, 255, 255, 255,  97,  97,  97, 208, 208,
         208, 179, 179, 179,  93,  93,  93, 246, 246, 246, 153, 153, 153,  41,
          41,   2]], device='cuda:0')


In [81]:
from PIL import Image

def preprocess_batch(batch):
    images = [img.convert("RGB") for img in batch["image"]]
    words = batch["tokens"]
    boxes = batch["bboxes"]
    word_labels = batch["ner_tags"]

    enc = processor(
        images,
        words,
        boxes=boxes,
        word_labels=word_labels,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    return enc

train_ds = load_dataset("parquet", data_files="../data/funsd-v3/funsd/train-00000-of-00001.parquet", split="train")
train_ds = train_ds.map(
    preprocess_batch,
    batched=True,
    remove_columns=train_ds.column_names
)
train_ds.set_format('pytorch')

test_ds = load_dataset("parquet", data_files="../data/funsd-v3/funsd/test-00000-of-00001.parquet", split="train")
test_ds = test_ds.map(
    preprocess_batch,
    batched=True,
    remove_columns=test_ds.column_names
)
test_ds.set_format('pytorch')

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

In [83]:
import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"

model = LayoutLMv3BART(num_labels=256).to(device)


train_dataloader = DataLoader(processed, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=32)

optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)

num_epochs = 5
num_training_steps = len(train_dataloader) * num_epochs
num_warmup_steps = int(0.1 * num_training_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

for epoch in range(num_epochs):
    model.train()

    loop = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for batch in loop:

        # Move all encoder inputs
        input_ids = batch["input_ids"].to(device)
        bbox = batch["bbox"].to(device)
        pixel_values = batch["pixel_values"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Decoder labels (word-level labels)
        labels = batch["labels"].to(device)

        # Forward
        outputs = model(
            input_ids=input_ids,
            bbox=bbox,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs.loss

        # Backprop
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        loop.set_postfix(loss=loss.item())

    model.eval()
    total_val_loss = 0
    
    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Validation"):
    
            input_ids = batch["input_ids"].to(device)
            bbox = batch["bbox"].to(device)
            pixel_values = batch["pixel_values"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
    
            outputs = model(
                input_ids=input_ids,
                bbox=bbox,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                labels=labels
            )
    
            total_val_loss += outputs.loss.item()
    
    print("Validation Loss:", total_val_loss / len(test_dataloader))

torch.save(model.state_dict(), "../models/layoutlmv3_bart_segmentation.pth")

Epoch 1/5: 100%|██████████| 5/5 [00:09<00:00,  1.83s/it, loss=1.94]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 1.4637821912765503


Epoch 2/5: 100%|██████████| 5/5 [00:09<00:00,  1.83s/it, loss=1.58]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.43it/s]


Validation Loss: 1.3286529779434204


Epoch 3/5: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.48]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 1.1660884022712708


Epoch 4/5: 100%|██████████| 5/5 [00:09<00:00,  1.83s/it, loss=1.35]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.43it/s]


Validation Loss: 1.1454375982284546


Epoch 5/5: 100%|██████████| 5/5 [00:09<00:00,  1.82s/it, loss=1.4] 
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


Validation Loss: 1.1301340460777283


In [109]:
model.eval()

item = next(next(iter(test_dataloader)))

with torch.no_grad():
    pred = model.generate(
        input_ids=item['input_ids'].to(device),
        bbox=item['bbox'].to(device),
        pixel_values=item['pixel_values'].to(device),
        attention_mask=item['attention_mask'].to(device),
        max_length=128
    )

TypeError: 'dict' object is not an iterator

In [105]:
test_ds1 = load_dataset("parquet", data_files="../data/funsd-v3/funsd/test-00000-of-00001.parquet", split="train")
test_ds2 = test_ds1.map(
    preprocess_batch,
    batched=True,
    remove_columns=test_ds1.column_names
)
test_ds2.set_format('pytorch')
len(test_ds1[0]['tokens']), len(test_ds1[0]['ner_tags']), len(test_ds2[0]['input_ids']), test_ds1

(223,
 223,
 512,
 Dataset({
     features: ['id', 'tokens', 'bboxes', 'ner_tags', 'image'],
     num_rows: 50
 }))

In [106]:
test_ds1[0]['tokens']

['TO:',
 'DATE:',
 '3',
 'Fax:',
 'NOTE:',
 '82092117',
 '614',
 '-466',
 '-5087',
 'Dec',
 '10',
 "'98",
 '17',
 ':46',
 'P.',
 '01',
 'ATT.',
 'GEN.',
 'ADMIN.',
 'OFFICE',
 'Attorney',
 'General',
 'Betty',
 'D.',
 'Montgomery',
 'CONFIDENTIAL',
 'FACSIMILE',
 'TRANSMISSION',
 'COVER',
 'SHEET',
 '(614)',
 '466-',
 '5087',
 'FAX',
 'NO.',
 'George',
 'Baroody',
 '(336)',
 '335-',
 '7392',
 'FAX',
 'NUMBER:',
 'PHONE',
 'NUMBER:',
 '(336)',
 '335-',
 '7363',
 'NUMBER',
 'OF',
 'PAGES',
 'INCLUDING',
 'COVER',
 'SHEET:',
 'June',
 'Flynn',
 'for',
 'Eric',
 'Brown/',
 '(614)',
 '466-',
 '8980',
 'SENDER',
 '/PHONE',
 'NUMBER:',
 'SPECIAL',
 'INSTRUCTIONS:',
 'IF',
 'YOU',
 'DO',
 'NOT',
 'RECEIVE',
 'ANY',
 'OF',
 'THE',
 'PAGES',
 'PROPERLY,',
 'PLEASE',
 'CONTACT',
 'SENDER',
 'AS',
 'SOON',
 'AS',
 'POSSIBLE',
 'THIS',
 'MESSAGE',
 'IS',
 'INTENDED',
 'ONLY',
 'FOR',
 'THE',
 'USE',
 'OF',
 'THE',
 'INDIVIDUAL',
 'OR',
 'ENTITY',
 'TO',
 'WHOM',
 'IT',
 'IS',
 'ADDRESSED',
 'AND',
