In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import os
from PIL import Image
from torchvision.models.detection import (
    retinanet_resnet50_fpn,
    RetinaNet_ResNet50_FPN_Weights,
)
from tqdm import tqdm
from torchvision.ops import box_iou

# DocBankDataset Class


In [2]:
class DocBankDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        idx_file_path,
        limit,
        images_dir="/content/drive/MyDrive/dataset/images_yolo",
        labels_dir="/content/drive/MyDrive/dataset/labels_retina",
        transforms=None,
    ):
        self.images = []
        self.annotations = []
        self.transforms = transforms
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.idx_file_path = idx_file_path
        self.limit = limit
        self.load_images()
        # print("Images loaded")
        self.load_ann()
        # print("Annotations loaded")

    def __getitem__(self, idx):
        img = self.images[idx]
        target = self.annotations[idx]
        return img, target

    def __len__(self):
        return len(self.images)

    def load_images(self):
        with open(self.idx_file_path, "r") as f:
            lines = f.readlines()
            for line in lines:
                if len(self.images) == self.limit:
                    return
                img_path = os.path.join(self.images_dir, line[:-5] + "_ori.jpg")
                # print("read path {}".format(img_path))
                img = Image.open(img_path).convert("RGB")
                if self.transforms is not None:
                    img = self.transforms(img)
                self.images.append(img)

    def load_ann(self):
        with open(self.idx_file_path, "r") as f:
            files = f.readlines()
            for file in files:
                if len(self.annotations) == self.limit:
                    return
                ann_path = os.path.join(self.labels_dir, file[:-5] + ".txt")
                self.annotations.append(self.process_ann_path(ann_path))

    def process_ann_path(self, ann_path):
        # return dictionary with 2 key: boxes (FloatTensor[N, 4]) and labels (Int64Tensor[N])
        target = {}
        boxes = []
        labels = []
        with open(ann_path, "r") as f:
            lines = f.readlines()
            for line in lines:
                content = line.strip().split("\t")
                token, x0, y0, x1, y1, R, G, B, font, label = content

                # MAY NEED BOX DIMENSION VALIDATION CHECK HERE
                if (
                    (int(x0) < 0)
                    or (int(y0) < 0)
                    or (int(x1) < 0)
                    or (int(y1) < 0)
                    or (x1 <= x0 or y1 <= y0)
                ):
                    continue

                boxes.append([float(x0), float(y0), float(x1), float(y1)])
                if label == "figure":
                    labels.append(torch.tensor(1))
                else:
                    labels.append(torch.tensor(0))

        target["boxes"] = torch.FloatTensor(boxes)
        target["labels"] = torch.tensor(labels)
        # print(target)
        return target

# Dataset Loader


### Directory Definition


In [3]:
images_dir = "D:/Docbank/images"
labels_dir = "D:/Docbank/annotations"
train_idx_file_path = "D:/Docbank/indexed/500K_train.txt"
val_idx_file_path = "D:/Docbank/indexed/500K_dev.txt"
test_idx_file_path = "D:/Docbank/indexed/500K_test.txt"

### Train, Val, Test Dataset


In [4]:
TRAIN_LIMIT = 400
VAL_LIMIT = 50
TEST_LIMIT = 50

transforms = transforms.Compose(
    [transforms.Resize((1000, 1000)), transforms.ToTensor()]  # Resize to a square
)

train_dataset = DocBankDataset(
    train_idx_file_path, TRAIN_LIMIT, images_dir, labels_dir, transforms
)
val_dataset = DocBankDataset(
    val_idx_file_path, VAL_LIMIT, images_dir, labels_dir, transforms
)
test_dataset = DocBankDataset(
    test_idx_file_path, TEST_LIMIT, images_dir, labels_dir, transforms
)

### Train, Val, Test Dataloader


In [5]:
BATCH_SIZE = 4


def collate_fn(batch):
    images, targets = zip(*batch)
    return images, targets


train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn
)

# Model Definition


In [6]:
model = retinanet_resnet50_fpn(num_classes=2)


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


model.to(device)

RetinaNet(
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=1e-05)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=1e-05)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=1e-05)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=1e-05)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d(256, e

### Optimizer and LR Scheduler


In [7]:
LR = 0.001
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0001

optimizer = torch.optim.SGD(
    model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY
)

lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[30, 80], gamma=0.1
)

### Training


In [20]:
def train(model, optimizer, lr_scheduler, train_loader, val_loader, num_epochs, device):
    best_f1_score = -float("inf")
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        print(f"Epoch {epoch+1}/{num_epochs}")

        for images, targets in tqdm(train_loader, desc="Training", leave=False):
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            optimizer.zero_grad()

            # Forward pass
            loss_dict = model(images, targets)
            # print("Output of the model:", loss_dict)
            # print("Type of output:", type(loss_dict))
            # Calculate total loss
            losses = sum(loss for loss in loss_dict.values())

            # Backward pass and optimization
            losses.backward()
            optimizer.step()

            train_loss += losses.item()

        # Step the learning rate scheduler
        lr_scheduler.step()

        # Validation loop
        model.eval()
        f1_score = 0
        with torch.no_grad():
            for images, targets in tqdm(val_loader, desc="Validating", leave=False):
                images = list(image.to(device) for image in images)
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                # Forward pass
                predictions = model(images)

                TP, FP, FN = 0, 0, 0
                iou_threshold = 0.5

                for target, prediction in zip(targets, predictions):
                    target_boxes = target["boxes"].cpu()
                    predicted_boxes = prediction["boxes"].cpu()

                    # Skip this iteration if there are no target boxes
                    if target_boxes.shape[0] == 0:
                        print("Skipping image with no ground truth boxes.")
                        continue  # Skip to the next iteration

                    # Ensure target_boxes is 2D
                    if target_boxes.ndim == 1:  # If only one target box
                        target_boxes = target_boxes.unsqueeze(0)  # Make it [1, 4]

                    # Ensure predicted_boxes is 2D
                    if predicted_boxes.ndim == 1:  # If only one predicted box
                        predicted_boxes = predicted_boxes.unsqueeze(0)  # Make it [1, 4]

                    print(target_boxes.shape, predicted_boxes.shape)

                    iou = box_iou(predicted_boxes, target_boxes)

                    for i in range(predicted_boxes.shape[0]):
                        if (iou[i] > iou_threshold).any():
                            TP += 1  # True Positive
                        else:
                            FP += 1  # False Positive

                    # Count False Negatives
                    FN += len(target_boxes) - TP

        # DocBank Metrics
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0
        recall = TP / (TP + FN) if (TP + FN) > 0 else 0
        f1_score = (
            2 * (precision * recall) / (precision + recall)
            if (precision + recall) > 0
            else 0
        )
        # Calculate average losses
        train_loss /= len(train_loader)

        # Print epoch results
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"F1 Score: {f1_score:.4f}")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

        # Save the best model
        if f1_score > best_f1_score:
            best_f1_score = f1_score
            torch.save(model.state_dict(), "best_retina.pth")
            print("Saved new best model")

    print("Training completed.")

In [21]:
NUM_EPOCHS = 5

train(model, optimizer, lr_scheduler, train_loader, val_loader, NUM_EPOCHS, device)

Epoch 1/5


Training:  68%|██████▊   | 68/100 [03:17<02:24,  4.52s/it]