In [27]:
%load_ext autoreload
%autoreload 2
import os
import torch
import json
import numpy as np
from tqdm import tqdm
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using device: cuda


In [28]:
class FUNSDDataset(Dataset):
    def __init__(self, data_path, processor, max_length=256):
        with open(data_path, 'r') as f:
            self.examples = json.load(f)
        
        self.processor = processor
        self.max_length = max_length
        
        all_labels = set()
        for example in self.examples:
            all_labels.update(example['labels'])
        
        self.label2id = {label: idx for idx, label in enumerate(sorted(all_labels))}
        self.id2label = {idx: label for label, idx in self.label2id.items()}
        
        print(f"Loaded {len(self.examples)} documents")
        print(f"Found {len(self.label2id)} unique labels")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        
        image = Image.open(example['image_path']).convert("RGB")
        
        words = example['words']
        boxes = example['bboxes']
        
        labels = [self.label2id[label] for label in example['labels']]

        encoding = self.processor(
            image,
            words,
            boxes=boxes,
            word_labels=labels,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        
        for key, val in encoding.items():
            encoding[key] = val.squeeze(0)
            
        return encoding

In [29]:
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

train_data_path = "../data/processed/FUNSD/training.json"
train_dataset = FUNSDDataset(train_data_path, processor)

sample = train_dataset[0]
print("Sample keys:", sample.keys())
print("Input IDs shape:", sample["input_ids"].shape)
print("Bbox shape:", sample["bbox"].shape)
print("Labels shape:", sample["labels"].shape)

Loaded 149 documents
Found 4 unique labels
Sample keys: KeysView({'input_ids': tensor([    0, 15231,  3935, 39033,  6597,  6034, 39658, 39261,  6823,  5945,
        26744,  1691, 39477, 25054,  2808,  4186, 44583, 34300, 11126, 35086,
         1691,  1777,  1245,  5768,  1018, 38793,  6178,  3293,  3858, 30438,
        44731,  3243, 15421, 43784, 18578, 11350,  5382, 27560,  6034, 24566,
        30596,  5121, 15421, 43784, 42699,  3602,   248,  3293, 10760, 15823,
        37962,    35, 15823, 12613,  5121, 12901, 11088,    35,  2808,  4186,
        44583, 10786,  5168,  2492, 13471,     6, 24316, 26896,     6,  5198,
          359, 36441, 14452,     6, 18012,     4,   163,  6597, 24258,  2444,
         4516, 44335,  1862,    12,  2808,  4186, 44583, 13060,  3675,     4,
         4017,    20, 13108,  1913,     4,   132,    73,   291,   359,   155,
           73,   291,   155,    73,   883,    73,  8301,   163,  6597, 24258,
         2444, 42699, 16948,   305,  7981, 42699, 16948,  4250,

In [30]:
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = torch.utils.data.random_split(
    train_dataset, [train_size, val_size]
)

print(f"Train size: {len(train_subset)}")
print(f"Validation size: {len(val_subset)}")

batch_size = 2

train_loader = DataLoader(
    train_subset, 
    batch_size=batch_size, 
    shuffle=True,
    collate_fn=lambda batch: {
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "bbox": torch.stack([item["bbox"] for item in batch]),
        "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
        "labels": torch.stack([item["labels"] for item in batch])
    }
)

val_loader = DataLoader(
    val_subset, 
    batch_size=batch_size, 
    collate_fn=lambda batch: {
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "bbox": torch.stack([item["bbox"] for item in batch]),
        "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
        "labels": torch.stack([item["labels"] for item in batch])
    }
)

for batch in train_loader:
    print("Batch shapes:")
    for key, value in batch.items():
        print(f"{key}: {value.shape}")
    break

Train size: 134
Validation size: 15
Batch shapes:
input_ids: torch.Size([2, 256])
attention_mask: torch.Size([2, 256])
bbox: torch.Size([2, 256, 4])
pixel_values: torch.Size([2, 3, 224, 224])
labels: torch.Size([2, 256])


In [31]:
model = LayoutLMv3ForTokenClassification.from_pretrained(
    "microsoft/layoutlmv3-base",
    num_labels=len(train_dataset.label2id),
    id2label=train_dataset.id2label,
    label2id=train_dataset.label2id
)
model.to(device)

print("Model initialized with:")
print(f" - Number of labels: {len(train_dataset.label2id)}")
print(f" - Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Some weights of LayoutLMv3ForTokenClassification were not initialized from the model checkpoint at microsoft/layoutlmv3-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized with:
 - Number of labels: 4
 - Model parameters: 125,330,052


In [None]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR

epochs = 4
learning_rate = 5e-5

optimizer = AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_loader) * epochs

from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR

warmup_scheduler = LinearLR(
    optimizer,
    start_factor=0.01,
    end_factor=1.0,
    total_iters=int(0.1 * total_steps)
)
    
main_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=total_steps - int(0.1 * total_steps)
)

from torch.optim.lr_scheduler import SequentialLR
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, main_scheduler],
    milestones=[int(0.1 * total_steps)]
)

In [35]:
model.train()
for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    total_loss = 0
    
    for batch in tqdm(train_loader, desc="Training"):
        batch = {k: v.to(device) for k, v in batch.items()}
        
        outputs = model(**batch)
        loss = outputs.loss
        
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        
        total_loss += loss.item()
    
    avg_train_loss = total_loss / len(train_loader)
    print(f"Average training loss: {avg_train_loss:.4f}")
    
    model.eval()
    val_loss = 0
    for batch in tqdm(val_loader, desc="Validation"):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        val_loss += outputs.loss.item()
    
    avg_val_loss = val_loss / len(val_loader)
    print(f"Average validation loss: {avg_val_loss:.4f}")
    
    checkpoint_path = f"../models/trained/layoutlmv3-epoch-{epoch+1}"
    model.save_pretrained(checkpoint_path)
    processor.save_pretrained(checkpoint_path)
    print(f"Saved checkpoint to {checkpoint_path}")

print("Training complete!")


Epoch 1/4


Training:   1%|▏         | 1/67 [00:00<00:17,  3.74it/s]

Training: 100%|██████████| 67/67 [00:17<00:00,  3.73it/s]


Average training loss: 0.4163


Validation: 100%|██████████| 8/8 [00:00<00:00, 10.57it/s]


Average validation loss: 0.5579
Saved checkpoint to ../models/trained/layoutlmv3-epoch-1

Epoch 2/4


Training: 100%|██████████| 67/67 [00:17<00:00,  3.78it/s]


Average training loss: 0.3011


Validation: 100%|██████████| 8/8 [00:00<00:00, 10.26it/s]


Average validation loss: 0.5226
Saved checkpoint to ../models/trained/layoutlmv3-epoch-2

Epoch 3/4


Training: 100%|██████████| 67/67 [00:17<00:00,  3.79it/s]


Average training loss: 0.1451


Validation: 100%|██████████| 8/8 [00:00<00:00,  9.73it/s]


Average validation loss: 0.5237
Saved checkpoint to ../models/trained/layoutlmv3-epoch-3

Epoch 4/4


Training: 100%|██████████| 67/67 [00:17<00:00,  3.77it/s]


Average training loss: 0.0852


Validation: 100%|██████████| 8/8 [00:00<00:00, 10.43it/s]


Average validation loss: 0.5336
Saved checkpoint to ../models/trained/layoutlmv3-epoch-4
Training complete!


In [36]:
model_path = "../models/trained/layoutlmv3-funsd-final"
model.save_pretrained(model_path)
processor.save_pretrained(model_path)
print(f"Model saved to {model_path}")

Model saved to ../models/trained/layoutlmv3-funsd-final
