In [9]:
# --- 📦 Imports ---
import os
import json
import xml.etree.ElementTree as ET
from pathlib import Path
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt

# --- 💡 Utilities to Parse XML Annotations ---
def parse_structure_annotations(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()
    boxes, labels = [], []
    for obj in root.findall('object'):
        label = obj.find('name').text
        bbox = obj.find('bndbox')
        box = [
            int(float(bbox.find('xmin').text)),
            int(float(bbox.find('ymin').text)),
            int(float(bbox.find('xmax').text)),
            int(float(bbox.find('ymax').text)),
        ]
        boxes.append(box)
        labels.append(label)
    return boxes, labels


Define Dataset Class

In [10]:
from torch.utils.data import Dataset
from pathlib import Path
from PIL import Image
import torch
import torchvision.transforms.functional as F

class PubTablesTSRDataset(Dataset):
    def __init__(self, img_dir, ann_dir, label_map, transforms=None, max_samples=3500):
        self.img_dir = Path(img_dir)
        self.ann_dir = Path(ann_dir)
        self.transforms = transforms
        self.label_map = label_map

        # ✅ Keep only valid image-annotation pairs
        valid_files = sorted([
            f for f in self.img_dir.glob("*.jpg")
            if (self.ann_dir / (f.stem + ".xml")).exists()
        ])

        # ✅ Limit to first N valid samples
        self.img_files = valid_files[:max_samples]

    def __len__(self):
        return len(self.img_files)

    def random_valid_index(self, idx):
        return (idx + 1) % len(self)

    def __getitem__(self, idx, retry_count=0):
        try:
            img_path = self.img_files[idx]
            ann_path = self.ann_dir / (img_path.stem + ".xml")

            if not img_path.exists() or not ann_path.exists():
                raise FileNotFoundError(f"Missing file: {img_path} or {ann_path}")

            image = Image.open(img_path).convert("RGB")
            boxes, labels = parse_structure_annotations(ann_path)

            # Filter labels and boxes together
            filtered = [(box, lbl) for box, lbl in zip(boxes, labels) if lbl in self.label_map]
            if len(filtered) == 0:
                raise ValueError(f"No valid boxes/labels in {ann_path}")

            boxes, labels = zip(*filtered)
            boxes = torch.tensor(boxes, dtype=torch.float32)
            labels = torch.tensor([self.label_map[lbl] for lbl in labels], dtype=torch.int64)

            target = {
                "boxes": boxes,
                "labels": labels,
                "image_id": torch.tensor([idx])
            }

            if self.transforms:
                image = self.transforms(image)
            else:
                image = F.to_tensor(image)

            return image, target

        except (FileNotFoundError, ValueError, KeyError) as e:
            print(f"⚠️ Skipping idx {idx}: {e}")
            if retry_count < 3:
                return self.__getitem__(self.random_valid_index(idx), retry_count + 1)
            else:
                raise RuntimeError("Too many invalid samples in a row.")


Label Map

In [11]:
label_map = {
    "table": 0,
    "table row": 1,
    "table column": 2,
    "column header": 3,
    "projected row header": 4,
    "spanning cell": 5,
    "no cell": 6,
    "table spanning cell": 7,
    "table column header": 8,
    "table projected column header": 9,
}



Setup Transforms and DataLoaders

In [12]:
# --- 🔁 Transformations ---
transform = T.Compose([
    T.Resize((512, 512)),
    T.ToTensor(),
])

# --- 📁 Paths ---
base_path = Path("archive")  # replace this with your dataset root
train_dataset = PubTablesTSRDataset(base_path / "images", base_path / "train", label_map, transforms=transform)
val_dataset = PubTablesTSRDataset(base_path / "images", base_path / "val", label_map, transforms=transform)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))


Model (Using Faster R-CNN)

In [13]:
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# --- ⚙️ Load pre-trained Faster R-CNN ---
model = fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, len(label_map))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

Training Function

In [14]:
import torch.nn as nn
import torch.optim as optim

optimizer = optim.AdamW(model.parameters(), lr=1e-4)

def train_one_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0.0
    for images, targets in tqdm(loader, desc="Training"):
        images = list(img.to(device) for img in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        total_loss += losses.item()

    return total_loss / len(loader)


Validation

In [15]:
@torch.no_grad()
def evaluate_model(model, loader, device):
    model.eval()
    total = 0
    for images, targets in tqdm(loader, desc="Evaluating"):
        images = list(img.to(device) for img in images)
        outputs = model(images)
        # you can add evaluation metrics here like IoU, AP etc
    print("✅ Evaluation completed.")


Training Loop

In [16]:
epochs = 1
best_loss = float("inf")

for epoch in range(epochs):
    print(f"\n📘 Epoch {epoch+1}/{epochs}")
    loss = train_one_epoch(model, train_loader, optimizer, device)
    print(f"📉 Training Loss: {loss:.4f}")

    # Save checkpoint every epoch
    os.makedirs("checkpoints", exist_ok=True)
    checkpoint_path = f"checkpoints/tsr_model_epoch_{epoch+1}.pth"
    torch.save(model.state_dict(), checkpoint_path)
    print(f"💾 Saved checkpoint: {checkpoint_path}")

    


📘 Epoch 1/1


Training: 100%|██████████| 875/875 [6:40:44<00:00, 27.48s/it]   


📉 Training Loss: 1.5911
💾 Saved checkpoint: checkpoints/tsr_model_epoch_1.pth


In [17]:
# Validation loss tracking (you can refine this with actual val loss or mAP)
evaluate_model(model, val_loader, device)

    # Save best model
if loss < best_loss:
        best_loss = loss
        torch.save(model.state_dict(), "best_tsr_model.pth")
        print("✅ Saved best model (best_tsr_model.pth)")

# Save final model
torch.save(model.state_dict(), "final_tsr_model.pth")
print("✅ Saved final model (final_tsr_model.pth)")


Evaluating: 100%|██████████| 1750/1750 [2:15:08<00:00,  4.63s/it] 


✅ Evaluation completed.
✅ Saved best model (best_tsr_model.pth)
✅ Saved final model (final_tsr_model.pth)
