In [16]:
# Mount Google Drive; the dataset paths used later in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Library


In [17]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.init as init
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models.resnet import (
  resnet50,
  ResNet50_Weights,
  resnet101,
  ResNet101_Weights,
)

# RetinaNet definition


### FPN Backbone


In [18]:
def _resnet_trunk(resnet):
  """Drop the last two children (the classification head) of a torchvision ResNet.

  For torchvision ResNets the last two children are ``avgpool`` and ``fc``;
  the remaining ``nn.Sequential`` ends with the four residual stages
  ``layer1``..``layer4``.
  """
  return nn.Sequential(*list(resnet.children())[:-2])


def resnet50_backbone(weights=ResNet50_Weights.IMAGENET1K_V1):
  """Return a ResNet-50 trunk (no avgpool / fc) as an ``nn.Sequential``.

  Args:
    weights: torchvision weights to load. Defaults to the ImageNet-1k V1
      weights; pass ``None`` for a randomly initialised trunk (no download).
  """
  return _resnet_trunk(resnet50(weights=weights))


def resnet101_backbone(weights=ResNet101_Weights.IMAGENET1K_V1):
  """Return a ResNet-101 trunk (no avgpool / fc) as an ``nn.Sequential``.

  Args:
    weights: torchvision weights to load. Defaults to the ImageNet-1k V1
      weights; pass ``None`` for a randomly initialised trunk (no download).
  """
  return _resnet_trunk(resnet101(weights=weights))


class FPN(nn.Module):
  """Feature Pyramid Network producing P3..P7 (256 channels each) on a ResNet trunk.

  Args:
    backbone: an iterable ``nn.Sequential`` whose last four children are the
      residual stages; their outputs (C2..C5) must have 256, 512, 1024 and
      2048 channels respectively, as for the trunk returned by
      ``resnet50_backbone()`` / ``resnet101_backbone()``.
  """

  def __init__(self, backbone):
    super().__init__()
    self.backbone = backbone

    # Lateral 1x1 convs projecting C5, C4, C3 to 256 channels
    self.latLayer1 = nn.Conv2d(2048, 256, kernel_size=1)
    self.latLayer2 = nn.Conv2d(1024, 256, kernel_size=1)
    self.latLayer3 = nn.Conv2d(512, 256, kernel_size=1)

    # 3x3 convs applied to the merged top-down maps
    self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

    # Extra levels: conv4 consumes C5 directly, so it needs C5's 2048 input
    # channels; conv5 then works on the 256-channel P6.
    self.conv4 = nn.Conv2d(2048, 256, kernel_size=3, padding=1, stride=2)
    self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2)

  def _stage_outputs(self, x):
    """Run the backbone child by child and return the outputs of its last four children (C2..C5).

    Calling the ``nn.Sequential`` directly would only return the final output.
    """
    outputs = []
    for stage in self.backbone:
      x = stage(x)
      outputs.append(x)
    return outputs[-4:]

  def forward(self, x):
    """Return ``(p3, p4, p5, p6, p7)``, each with 256 channels."""
    _c2, c3, c4, c5 = self._stage_outputs(x)

    # Top-down pathway. Upsample to the exact lateral size rather than by a
    # fixed factor of 2, so odd-sized feature maps still line up.
    p5 = self.latLayer1(c5)
    p4 = self.latLayer2(c4) + F.interpolate(p5, size=c4.shape[-2:])
    p3 = self.latLayer3(c3) + F.interpolate(p4, size=c3.shape[-2:])

    # Final convolutions
    p5 = self.conv1(p5)
    p4 = self.conv2(p4)
    p3 = self.conv3(p3)

    # Additional levels
    p6 = self.conv4(c5)
    p7 = self.conv5(F.relu(p6))

    return p3, p4, p5, p6, p7

### Classification and Box Regression Subnetwork


In [19]:
class ClassificationSubnet(nn.Module):
  """RetinaNet classification head: four 3x3 conv + ReLU layers, then a 3x3 conv to logits.

  Args:
    num_classes: number of classes predicted per anchor.
    num_anchors: anchors per spatial location; the output has
      ``num_anchors * num_classes`` channels. The default of 1 keeps one
      prediction per location.
  """

  def __init__(self, num_classes, num_anchors=1):
    super().__init__()
    self.num_classes = num_classes
    self.num_anchors = num_anchors
    self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.act1 = nn.ReLU()
    # Kept as `conv5` so its bias can be prior-initialised by name.
    self.conv5 = nn.Conv2d(256, num_anchors * num_classes, kernel_size=3, padding=1)

  def forward(self, x):
    """Map a 256-channel (N, 256, H, W) map to (N, num_anchors * num_classes, H, W) logits."""
    out = x
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
      out = self.act1(conv(out))
    return self.conv5(out)


class RegressionSubnet(nn.Module):
  """RetinaNet box-regression head: four 3x3 conv + ReLU layers, then a 3x3 conv to box offsets.

  Args:
    num_anchors: anchors per spatial location; the output has
      ``4 * num_anchors`` channels. The default of 1 keeps one box per location.
  """

  def __init__(self, num_anchors=1):
    super().__init__()
    self.num_anchors = num_anchors
    self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.act1 = nn.ReLU()
    self.conv5 = nn.Conv2d(256, 4 * num_anchors, kernel_size=3, padding=1)

  def forward(self, x):
    """Map a 256-channel (N, 256, H, W) map to (N, 4 * num_anchors, H, W) outputs."""
    out = x
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
      out = self.act1(conv(out))
    return self.conv5(out)

### RetinaNet


In [20]:
class RetinaNet(nn.Module):
  """RetinaNet: FPN pyramid levels fed through one shared classification head and one shared box head.

  Args:
    backbone: trunk handed to ``FPN`` (and also stored as ``self.backbone``).
    num_classes: forwarded to ``ClassificationSubnet``.
  """

  def __init__(self, backbone, num_classes):
    super().__init__()
    self.backbone = backbone
    self.fpn = FPN(backbone)
    self.classification_subnet = ClassificationSubnet(num_classes)
    self.regression_subnet = RegressionSubnet()

  def forward(self, x):
    """Return ``(cls_outputs, reg_outputs)``: two lists with one entry per level P3..P7."""
    p3, p4, p5, p6, p7 = self.fpn(x)
    levels = (p3, p4, p5, p6, p7)
    cls_outputs = [self.classification_subnet(level) for level in levels]
    reg_outputs = [self.regression_subnet(level) for level in levels]
    return cls_outputs, reg_outputs

### Focal Loss


In [21]:
class FocalLoss(nn.Module):
  """Sigmoid focal loss (Lin et al., 2017), averaged over all elements.

  FL = alpha_t * (1 - p_t) ** gamma * BCE, where alpha_t = alpha for positive
  targets and 1 - alpha for negative targets.

  Args:
    alpha: weight of positive targets; negatives are weighted ``1 - alpha``.
    gamma: focusing exponent; ``gamma = 0`` reduces to alpha-weighted BCE.
  """

  def __init__(self, alpha=0.25, gamma=2.0):
    super().__init__()
    self.alpha = alpha
    self.gamma = gamma

  def forward(self, inputs, targets):
    """Return the mean focal loss.

    Args:
      inputs: raw logits.
      targets: binary targets with the same shape as ``inputs``. They are cast
        to the logits' dtype because BCE-with-logits requires floating
        targets; integer label tensors are therefore accepted.
    """
    targets = targets.to(inputs.dtype)
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    # For 0/1 targets, exp(-bce) is the probability assigned to the true class.
    pt = torch.exp(-bce)

    # Class-balancing weight: alpha on positives, 1 - alpha on negatives.
    alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
    loss = alpha_t * (1 - pt) ** self.gamma * bce

    return loss.mean()

# DocBank Dataset Class

In [22]:
class DocBankDataset(Dataset):
  """DocBank figure dataset: one sample per sufficiently large 'figure' annotation.

  Args:
    data_dir: directory holding the annotation ``.txt`` files and the images
      they reference.
    index_file: text file listing one annotation filename (relative to
      ``data_dir``) per line; blank lines are ignored.
    transform: optional callable applied to the loaded PIL RGB image.
  """

  # Tab-separated fields per annotation line:
  # token, x0, y0, x1, y1, R, G, B, font, label
  _NUM_FIELDS = 10
  # Boxes whose width or height is <= this many coordinate units are skipped.
  _MIN_BOX_SIDE = 10

  def __init__(self, data_dir, index_file, transform=None):
    self.data_dir = data_dir
    self.transform = transform
    self.image_paths = []
    self.annotations = []

    # Read the index file, skipping blank lines (an empty name would make
    # os.path.join return the directory itself).
    with open(index_file, 'r', encoding='utf-8') as index_f:
      annotation_files = [line.strip() for line in index_f if line.strip()]

    for annotation_file in annotation_files:
      self.process_annotation_file(os.path.join(data_dir, annotation_file))

  def process_annotation_file(self, annotation_path):
    """Append every large-enough 'figure' box found in ``annotation_path``.

    Lines that do not split into exactly ten tab-separated fields are skipped.
    """
    with open(annotation_path, 'r', encoding='utf-8') as f:
      for line in f:
        # Only remove the line terminator: strip() would also remove a
        # leading tab when the token is empty and shift every column.
        content = line.rstrip('\r\n').split('\t')
        if len(content) != self._NUM_FIELDS:
          continue
        token, x0, y0, x1, y1, _r, _g, _b, _font, label = content
        if label.strip() != 'figure':
          continue

        bbox = [int(x0), int(y0), int(x1), int(y1)]
        # Compare the parsed ints; the raw fields are strings.
        width = abs(bbox[2] - bbox[0])
        height = abs(bbox[3] - bbox[1])
        if width <= self._MIN_BOX_SIDE or height <= self._MIN_BOX_SIDE:
          continue

        # NOTE(review): the image name is built from the token text; confirm
        # it matches the image filenames actually stored in data_dir.
        self.image_paths.append(f"{token}.jpg")
        self.annotations.append((bbox, 1))  # Label = 1 for 'figure'

  def __len__(self):
    return len(self.image_paths)

  def __getitem__(self, idx):
    """Return ``(image, bbox, label)`` for sample ``idx``.

    ``bbox`` is a float32 tensor ``[x0, y0, x1, y1]`` and ``label`` an int64
    scalar tensor. ``image`` is the output of ``transform`` if set, otherwise
    a PIL RGB image.
    """
    bbox, label = self.annotations[idx]
    image_path = os.path.join(self.data_dir, self.image_paths[idx])

    # Context manager releases the file handle right away; convert() returns
    # a separate image object.
    with Image.open(image_path) as img:
      image = img.convert('RGB')

    if self.transform:
      image = self.transform(image)

    return (
      image,
      torch.tensor(bbox, dtype=torch.float32),
      torch.tensor(label, dtype=torch.int64),
    )

# Training


### Initialization


In [23]:
def initialize_weights(module):
  """Gaussian-initialise a Conv2d (mean 0, std 0.01) and zero its bias; ignore other modules.

  Intended for use with ``nn.Module.apply``.
  """
  if not isinstance(module, nn.Conv2d):
    return
  init.normal_(module.weight, mean=0.0, std=0.01)
  if module.bias is not None:
    init.zeros_(module.bias)


def initialize_classification_final_layer(layer, pi=0.01):
  """Set every bias of ``layer`` to -log((1 - pi) / pi), the prior-probability init.

  With this bias a sigmoid over the layer's output starts at roughly ``pi``.
  """
  odds_against = torch.tensor((1 - pi) / pi)
  prior_bias = -torch.log(odds_against)
  init.constant_(layer.bias, prior_bias)

### Model declaration


In [24]:
# NOTE(review): DocBankDataset only emits label 1 ('figure') and the heads use
# per-class sigmoid outputs; confirm whether 2 output classes are intended.
NUM_CLASSES = 2

backbone = resnet50_backbone()
# backbone = resnet101_backbone()

model = RetinaNet(backbone, num_classes=NUM_CLASSES)

# Initialise only the layers RetinaNet adds on top of the trunk.
# model.apply(initialize_weights) would also reach every Conv2d inside the
# ResNet (via model.backbone and model.fpn.backbone) and overwrite the
# ImageNet weights that resnet50_backbone() loaded.
new_layers = [
  child for name, child in model.fpn.named_children() if name != "backbone"
]
new_layers += [model.classification_subnet, model.regression_subnet]
for layer in new_layers:
  layer.apply(initialize_weights)

# Prior-probability bias on the final classification conv (pi = 0.01 by default).
initialize_classification_final_layer(model.classification_subnet.conv5)

### Optimizer, Scheduler, Loss Function


In [25]:
# SGD hyper-parameters
LR = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0001

optimizer = torch.optim.SGD(
  params=model.parameters(),
  lr=LR,
  momentum=MOMENTUM,
  weight_decay=WEIGHT_DECAY,
)

# Divide the learning rate by 10 at each milestone.
# NOTE(review): MultiStepLR counts scheduler.step() calls, so whether these
# milestones mean epochs or iterations depends on where step() is called.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
  optimizer,
  milestones=[60000, 80000],
  gamma=0.1,
)

# Focal loss for classification, smooth L1 for box regression.
cls_loss_fn = FocalLoss(alpha=0.25, gamma=2.0)
reg_loss_fn = torch.nn.SmoothL1Loss()

### Data loader


In [27]:
BATCH_SIZE = 8

DATASET_ROOT = '/content/drive/MyDrive/dataset'
ANNOTATION_DIR = os.path.join(DATASET_ROOT, 'DocBank_500K_txt')
INDEX_DIR = os.path.join(DATASET_ROOT, 'indexed_files')

# DataLoader's default collate cannot batch PIL images, so convert to tensors.
# NOTE(review): the default collate also needs every image in a batch to have
# the same size; if DocBank pages vary, a custom collate_fn (or a resize that
# also rescales the boxes) is required.
image_transform = transforms.ToTensor()


def make_docbank_split(split):
  """Build the DocBankDataset for ``split`` ('train', 'val' or 'test')."""
  index_file = os.path.join(INDEX_DIR, f'500K_{split}.txt')
  return DocBankDataset(ANNOTATION_DIR, index_file, transform=image_transform)


train_docbank = make_docbank_split('train')
val_docbank = make_docbank_split('val')
test_docbank = make_docbank_split('test')

train_dataloader = DataLoader(train_docbank, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_docbank, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_docbank, batch_size=BATCH_SIZE, shuffle=False)

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/dataset/DocBank_500K_txt/194.tar_1807.02545.gz_lexicography_hand_gestures_final_version_0.txt'

### Training


In [None]:
def _detection_loss(cls_preds, reg_preds, bboxes, labels, cls_loss_fn, reg_loss_fn):
  """Return the summed classification + regression loss for one batch.

  TODO(review): this zips the per-FPN-level prediction lists with the
  per-image targets of the batch, so level i is compared with image i and the
  loop stops at the shorter of the two. RetinaNet needs per-level anchors
  matched to ground-truth boxes before the focal / smooth-L1 losses apply;
  until that exists this loss is not a meaningful detection objective.
  """
  cls_loss = 0.0
  reg_loss = 0.0
  for cls_pred, reg_pred, bbox, label in zip(cls_preds, reg_preds, bboxes, labels):
    cls_loss += cls_loss_fn(cls_pred, label)
    reg_loss += reg_loss_fn(reg_pred, bbox)
  return cls_loss + reg_loss


def train(model, optimizer, cls_loss_fn, reg_loss_fn, train_loader, val_loader,
          num_epochs, scheduler=None):
  """Train ``model`` and print the mean train / validation loss after every epoch.

  Args:
    model: module whose forward returns ``(cls_preds, reg_preds)``.
    optimizer: optimizer over ``model``'s parameters.
    cls_loss_fn, reg_loss_fn: callables ``(prediction, target) -> loss tensor``.
    train_loader, val_loader: sized iterables yielding
      ``(images, bboxes, labels)`` tensor batches.
    num_epochs: number of passes over ``train_loader``.
    scheduler: LR scheduler, stepped once per training iteration. Defaults to
      the notebook-level ``lr_scheduler`` so existing calls keep working.

  Batches are moved to the device of the model's first parameter (CPU if the
  model has none).
  """
  if scheduler is None:
    scheduler = lr_scheduler
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device("cpu")

  for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for images, bboxes, labels in train_loader:
      images, bboxes, labels = images.to(device), bboxes.to(device), labels.to(device)
      optimizer.zero_grad()
      cls_preds, reg_preds = model(images)
      loss = _detection_loss(cls_preds, reg_preds, bboxes, labels, cls_loss_fn, reg_loss_fn)
      loss.backward()
      optimizer.step()
      # Step per batch: MultiStepLR counts step() calls, and milestones in the
      # tens of thousands (e.g. 60k/80k) only make sense as iteration counts.
      scheduler.step()
      train_loss += loss.item()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
      for images, bboxes, labels in val_loader:
        images, bboxes, labels = images.to(device), bboxes.to(device), labels.to(device)
        cls_preds, reg_preds = model(images)
        loss = _detection_loss(cls_preds, reg_preds, bboxes, labels, cls_loss_fn, reg_loss_fn)
        val_loss += loss.item()

    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    train_loss /= max(len(train_loader), 1)
    val_loss /= max(len(val_loader), 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")


In [None]:
NUM_EPOCHS = 1

train(
  model,
  optimizer,
  cls_loss_fn,
  reg_loss_fn,
  train_loader=train_dataloader,
  val_loader=val_dataloader,
  num_epochs=NUM_EPOCHS,
)