In [1]:
# Install into the running kernel's environment with the explicit %pip magic.
# Versions are pinned to the ones this notebook was originally executed with.
%pip install torch==2.7.0 torchvision==0.22.0 pycocotools==2.0.8 matplotlib

Collecting torch
  Downloading torch-2.7.0-cp312-none-macosx_11_0_arm64.whl.metadata (29 kB)
Collecting torchvision
  Downloading torchvision-0.22.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.1 kB)
Collecting pycocotools
  Downloading pycocotools-2.0.8-cp312-cp312-macosx_10_9_universal2.whl.metadata (1.1 kB)
Collecting filelock (from torch)
  Using cached filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting typing-extensions>=4.10.0 (from torch)
  Using cached typing_extensions-4.13.2-py3-none-any.whl.metadata (3.0 kB)
Collecting setuptools (from torch)
  Downloading setuptools-80.4.0-py3-none-any.whl.metadata (6.5 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Using cached networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2025.3.2-py3-none

In [10]:
import os
import torch
from PIL import Image
from torch.utils.data import Dataset
from pycocotools.coco import COCO
import torchvision.transforms as T

class DocLayNetDataset(Dataset):
    """COCO-format detection dataset yielding ``(image, target)`` pairs.

    Each item is an RGB image (optionally passed through ``transforms``) and a
    target dict with:

    * ``"boxes"``    - float32 tensor of ``[x1, y1, x2, y2]`` boxes,
    * ``"labels"``   - int64 tensor holding each annotation's ``category_id``,
    * ``"image_id"`` - one-element tensor with the COCO image id.

    Annotations whose width or height is not positive are dropped. If an image
    has no valid annotation left, the next image (wrapping around) that does
    have one is returned in its place.

    Args:
        root: Directory that the ``file_name`` of every image is relative to.
        annotation: Path to the COCO annotation JSON file.
        transforms: Optional callable applied to the PIL image only.
    """

    def __init__(self, root, annotation, transforms=None):
        self.root = root
        self.coco = COCO(annotation)
        # Sorted so that an index always maps to the same image id.
        self.ids = list(sorted(self.coco.imgs.keys()))
        self.transforms = transforms

    def _valid_boxes_and_labels(self, img_id):
        """Return ``(boxes, labels)`` lists for the usable annotations of ``img_id``."""
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        boxes = []
        labels = []
        for ann in anns:
            x, y, w, h = ann['bbox'][:4]
            if w <= 0 or h <= 0:
                continue  # Skip invalid (zero or negative sized) boxes
            # Convert [x, y, width, height] into corner format [x1, y1, x2, y2].
            boxes.append([x, y, x + w, y + h])
            labels.append(ann['category_id'])
        return boxes, labels

    def __getitem__(self, index):
        """Return the sample at ``index``, or the next one that has valid boxes.

        Raises:
            IndexError: If ``index`` is out of range for the image id list.
            RuntimeError: If no image in the dataset has a valid box.
        """
        num_images = len(self)
        candidate = index
        attempts = 0
        # Bounded scan instead of unbounded recursion: every image is tried at
        # most once, so a dataset without any usable box fails with a clear
        # error rather than exhausting the interpreter's recursion limit.
        while True:
            img_id = self.ids[candidate]
            boxes, labels = self._valid_boxes_and_labels(img_id)
            if boxes:
                break
            attempts += 1
            if attempts >= num_images:
                raise RuntimeError(
                    "No image in the dataset has a valid bounding box"
                )
            candidate = (candidate + 1) % num_images  # move to next image

        # Decode the image only once we know this sample will be returned.
        img_info = self.coco.loadImgs(img_id)[0]
        path = img_info['file_name']
        img = Image.open(os.path.join(self.root, path)).convert("RGB")

        target = {
            "boxes": torch.as_tensor(boxes, dtype=torch.float32),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor([img_id]),
        }

        if self.transforms:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        """Number of images listed in the annotation file."""
        return len(self.ids)

In [11]:
def get_transform():
    """Build the image preprocessing pipeline: a single ``ToTensor`` step wrapped in ``Compose``."""
    steps = [T.ToTensor()]
    return T.Compose(steps)

In [12]:
from torch.utils.data import DataLoader

# Where the page images and the COCO-format training annotations live
image_dir = '../Dataset/DocLayNet/DocLayNet_core/PNG'
annotation_file = '../Dataset/DocLayNet/DocLayNet_core/COCO/train.json'


def _collate_detection_batch(batch):
    """Regroup ``[(img, target), ...]`` into ``(images, targets)`` tuples without stacking."""
    return tuple(zip(*batch))


dataset = DocLayNetDataset(
    root=image_dir,
    annotation=annotation_file,
    transforms=get_transform(),
)
data_loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=_collate_detection_batch,
)

loading annotations into memory...
Done (t=3.76s)
creating index...
index created!


In [14]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Load a Faster R-CNN (ResNet-50 + FPN) model pre-trained on COCO.
# `weights="DEFAULT"` replaces the deprecated `pretrained=True` flag.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

# Replace the box classifier with a new head sized for this dataset.
# The dataset uses raw COCO category ids as labels, so the head needs one
# output per id up to the largest one, plus index 0 for background. Sizing by
# the largest id (rather than the category count) stays correct even when the
# ids are not a contiguous 1..N range.
num_classes = max(dataset.coco.getCatIds(), default=0) + 1
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

In [None]:
import torch

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Only hand the optimizer the parameters that actually receive gradients.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    for step, (images, targets) in enumerate(data_loader):
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Called with targets in train mode, the model returns its loss terms
        # as a dict; their sum is the quantity that gets optimised.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Stop immediately on a diverged loss: stepping on NaN/inf gradients
        # would silently poison the weights that are saved afterwards.
        if not torch.isfinite(losses):
            raise RuntimeError(
                f"Non-finite loss at epoch {epoch}, step {step}: {loss_dict}"
            )

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if step % 10 == 0:
            print(f"Epoch [{epoch}], Step [{step}], Loss: {losses.item():.4f}")

Epoch [0], Step [0], Loss: 10.5104
Epoch [0], Step [10], Loss: 1.7429
Epoch [0], Step [20], Loss: 1.5296
Epoch [0], Step [30], Loss: 1.3245
Epoch [0], Step [40], Loss: 1.3574
Epoch [0], Step [50], Loss: 1.3224
Epoch [0], Step [60], Loss: 1.1687
Epoch [0], Step [70], Loss: 1.1190
Epoch [0], Step [80], Loss: 0.7337
Epoch [0], Step [90], Loss: 1.2066
Epoch [0], Step [100], Loss: 1.4907
Epoch [0], Step [110], Loss: 0.9181
Epoch [0], Step [120], Loss: 0.9526
Epoch [0], Step [130], Loss: 1.3742
Epoch [0], Step [140], Loss: 0.7967
Epoch [0], Step [150], Loss: 1.1172
Epoch [0], Step [160], Loss: 1.3187
Epoch [0], Step [170], Loss: 2.0671
Epoch [0], Step [180], Loss: 1.1406
Epoch [0], Step [190], Loss: 1.0542
Epoch [0], Step [200], Loss: 1.2457
Epoch [0], Step [210], Loss: 1.3627
Epoch [0], Step [220], Loss: 1.0903
Epoch [0], Step [230], Loss: 0.9638
Epoch [0], Step [240], Loss: 0.9786
Epoch [0], Step [250], Loss: 1.0625
Epoch [0], Step [260], Loss: 1.2226
Epoch [0], Step [270], Loss: 1.1220
Ep

In [None]:
# Persist only the learned weights (the state dict), not the whole model object.
checkpoint_path = "fasterrcnn_doclaynet.pth"
trained_weights = model.state_dict()
torch.save(trained_weights, checkpoint_path)