# Import PyTorch modules

In [1]:
import torch
import torchvision
from torchvision.datasets import VOCDetection
import os

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", DEVICE)

Device: cpu


# Load datasets

In [2]:
import os

from pathlib import Path

from google.colab import drive

drive.mount('/content/drive')

SHARED_PATH = Path("drive/MyDrive/Colab Notebooks/Shared")
HRSC_PATH = SHARED_PATH / "HRSC2016_Final_Splits"
DOTA_PATH = SHARED_PATH / "DOTA_Final_Splits"

print("HRSC subfolders:", [f"{subfolder.name}/{path.name}" for subfolder in HRSC_PATH.iterdir() for path in subfolder.iterdir()])
print("DOTA subfolders:", [f"{subfolder.name}/{path.name}" for subfolder in DOTA_PATH.iterdir() for path in subfolder.iterdir()])

HRSC_TRAIN_IMAGES = HRSC_PATH/ "train/images"
HRSC_TRAIN_ANNOTATIONS = HRSC_PATH / "train/annotations"
HRSC_VAL_IMAGES = HRSC_PATH / "val/images"
HRSC_VAL_ANNOTATIONS = HRSC_PATH / "val/annotations"
HRSC_TEST_IMAGES = HRSC_PATH / "test/images"
HRSC_TEST_ANNOTATIONS = HRSC_PATH / "test/annotations"

DOTA_TRAIN_IMAGES = DOTA_PATH / "train/images"
DOTA_TRAIN_ANNOTATIONS = DOTA_PATH / "train/hbb"
DOTA_VAL_IMAGES = DOTA_PATH / "val/images"
DOTA_VAL_ANNOTATIONS = DOTA_PATH / "val/hbb"
DOTA_TEST_IMAGES = DOTA_PATH / "test/images"
DOTA_TEST_ANNOTATIONS = DOTA_PATH / "test/hbb"

Mounted at /content/drive
HRSC subfolders: ['train/images', 'train/annotations', 'val/images', 'val/annotations', 'test/images', 'test/annotations']
DOTA subfolders: ['train/images', 'train/hbb', 'val/images', 'val/hbb', 'test/images_without_hbb', 'test/images', 'test/hbb']


# Explore datasets

In [3]:
!pip install colorama
from colorama import Fore, Style

def explore_header(dataset: str, subfolder: str):
  print("Exploring", Fore.GREEN + dataset + Style.RESET_ALL, Fore.MAGENTA + subfolder + Style.RESET_ALL, "folder...")

def explore(images_folder: Path, ext: str):
  files = list(images_folder.glob(f"*.{ext}"))
  print(f"Number of {ext.upper()} files:", len(files))
  print(f"{ext.upper()} sample files:", [path.name for path in files[:3]])

explore_header("HRSC", "train")
explore(HRSC_TRAIN_IMAGES, "bmp")
explore(HRSC_TRAIN_ANNOTATIONS, "xml")

print()

explore_header("HRSC", "val")
explore(HRSC_VAL_IMAGES, "bmp")
explore(HRSC_VAL_ANNOTATIONS, "xml")

print()

explore_header("HRSC", "test")
explore(HRSC_TEST_IMAGES, "bmp")
explore(HRSC_TEST_ANNOTATIONS, "xml")

print()

explore_header("DOTA", "train")
explore(DOTA_TRAIN_IMAGES, "png")
explore(DOTA_TRAIN_ANNOTATIONS, "txt")

print()

explore_header("DOTA", "val")
explore(DOTA_VAL_IMAGES, "png")
explore(DOTA_VAL_ANNOTATIONS, "txt")

print()

explore_header("DOTA", "test")
explore(DOTA_TEST_IMAGES, "png")
explore(DOTA_TEST_ANNOTATIONS, "txt")

Collecting colorama
  Downloading colorama-0.4.6-py2.py3-none-any.whl.metadata (17 kB)
Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Installing collected packages: colorama
Successfully installed colorama-0.4.6
Exploring [32mHRSC[0m [35mtrain[0m folder...
Number of BMP files: 436
BMP sample files: ['100000001.bmp', '100000002.bmp', '100000004.bmp']
Number of XML files: 436
XML sample files: ['100000001.xml', '100000002.xml', '100000004.xml']

Exploring [32mHRSC[0m [35mval[0m folder...
Number of BMP files: 181
BMP sample files: ['100000006.bmp', '100000010.bmp', '100000622.bmp']
Number of XML files: 181
XML sample files: ['100000006.xml', '100000010.xml', '100000622.xml']

Exploring [32mHRSC[0m [35mtest[0m folder...
Number of BMP files: 453
BMP sample files: ['100000003.bmp', '100000005.bmp', '100000623.bmp']
Number of XML files: 453
XML sample files: ['100000003.xml', '100000005.xml', '100000623.xml']

Exploring [32mDOTA[0m [35mtrain[0m folder...
Number of PNG

# Define label parsers

In [4]:
import itertools
import xml.etree.ElementTree as ET

HRSC_CLASSES = {}
DOTA_CLASSES = {}


def parse_hrsc_labels(label_path: Path):
  boxes = []
  labels = []
  tree = ET.parse(label_path.as_posix())
  root = tree.getroot()

  # Get image dimensions
  width = int(root.find('.//Img_SizeWidth').text)
  height = int(root.find('.//Img_SizeHeight').text)

  objects = root.findall(".//HRSC_Object")
  for obj in objects:
    try:
      # Bounds
      xmin = float(obj.find('box_xmin').text)
      ymin = float(obj.find('box_ymin').text)
      xmax = float(obj.find('box_xmax').text)
      ymax = float(obj.find('box_ymax').text)

      # Category
      class_id = int(obj.find('Class_ID').text)
      if class_id not in HRSC_CLASSES:
        HRSC_CLASSES[class_id] = len(HRSC_CLASSES)
      class_id_index = HRSC_CLASSES[class_id]

      boxes.append([xmin, ymin, xmax, ymax])
      labels.append(class_id_index)

    except Exception as e:
      print(Fore.RED + "Warning" + Style.RESET_ALL + f": Could not parse object in {label_path.as_posix()}: {e}")
      continue

  return boxes, labels


def parse_dota_labels(label_path: Path):
  boxes = []
  labels = []

  # Find all objects
  for line in itertools.islice(label_path.read_text().splitlines(), 2, None):  # Start from line index 2
    try:
      obj = line.strip().split()
      x1, y1, x2, y2, x3, y3, x4, y4, category, difficulty = (*map(float, obj[:8]), *obj[8:])

      # Bounds
      xmin = min(x1, x2, x3, x4)
      ymin = min(y1, y2, y3, y4)
      xmax = max(x1, x2, x3, x4)
      ymax = max(y1, y2, y3, y4)

      # Category
      if category not in DOTA_CLASSES:
        DOTA_CLASSES[category] = len(DOTA_CLASSES)
      class_id_index = DOTA_CLASSES[category]

      boxes.append([xmin, ymin, xmax, ymax])
      labels.append(class_id_index)

    except Exception as e:
      print(Fore.RED + "Warning" + Style.RESET_ALL + f": Could not parse object in {label_path.as_posix()}: {e}")
      continue

  return boxes, labels

# PyTorch dataset structure

In [5]:
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import torchvision.transforms as T

from PIL import Image

BASIC_TRANSFORM = T.Compose([T.ToTensor()])

class TorchDataset(torch.utils.data.Dataset):
  def __init__(self, images_folder: Path, annotations_folder: Path, label_parser, transforms=BASIC_TRANSFORM, max_files=10, max_objects=100) -> None:
    super().__init__()
    self.label_parser = label_parser
    self.transforms = transforms
    self.max_objects = max_objects
    images = [f for f in images_folder.iterdir() if f.suffix in (".jpg", ".png", ".bmp")] if images_folder.exists() else []
    if max_files > 0:
      images = images[:max_files]
    annotations = [f for f in annotations_folder.iterdir() if f.suffix in (".xml", ".txt")] if annotations_folder.exists() else []
    if max_files > 0:
      annotations = annotations[:max_files]

    # Keep only images/annotations one-to-one correspondences
    image_set = set(f.stem for f in images)
    annotation_set = set(f.stem for f in annotations)
    self.ids = image_set.intersection(annotation_set)
    self.images = {f.stem: f for f in images if f.stem in self.ids}
    self.annotations = {f.stem: f for f in annotations if f.stem in self.ids}
    self.ids = list(self.ids)

  def __getitem__(self, index):
    id = self.ids[index]
    image_path = self.images[id]
    label_path = self.annotations[id]
    image = Image.open(image_path).convert("RGB")
    boxes, labels = self.label_parser(label_path)
    if self.max_objects > 0:
      boxes = boxes[:self.max_objects]
      labels = labels[:self.max_objects]
    target = {
        "boxes": torch.as_tensor(boxes, dtype=torch.float32),
        "labels": torch.as_tensor(labels, dtype=torch.int64) + 1,  # Offset by 1 for background label
    }
    image = self.transforms(image)
    image.filepath = image_path;
    return image, target

  def __len__(self):
    return len(self.ids)

  def compute_total_number_of_objects(self, max_workers=8):
    def count_objects(id):
      boxes, _ = self.label_parser(self.annotations[id])
      return len(boxes)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
      results = executor.map(count_objects, self.ids)
    return sum(results)

# Prepare datasets for PyTorch

In [6]:
print("Preparing HRSC training dataset...")
HRSC_TRAIN_DATASET = TorchDataset(HRSC_TRAIN_IMAGES, HRSC_TRAIN_ANNOTATIONS, parse_hrsc_labels)
print(f"...Dataset prepared: {HRSC_TRAIN_DATASET.compute_total_number_of_objects()} total objects")
print("HRSC training dataset sample:", HRSC_TRAIN_DATASET.ids[:5])

print("Preparing HRSC validation dataset...")
HRSC_VAL_DATASET = TorchDataset(HRSC_VAL_IMAGES, HRSC_VAL_ANNOTATIONS, parse_hrsc_labels)
print(f"...Dataset prepared: {HRSC_VAL_DATASET.compute_total_number_of_objects()} total objects")
print("HRSC validation dataset sample:", HRSC_VAL_DATASET.ids[:5])

print("Preparing HRSC testing dataset...")
HRSC_TEST_DATASET = TorchDataset(HRSC_TEST_IMAGES, HRSC_TEST_ANNOTATIONS, parse_hrsc_labels)
print(f"...Dataset prepared: {HRSC_TEST_DATASET.compute_total_number_of_objects()} total objects")
print("HRSC testing dataset sample:", HRSC_TEST_DATASET.ids[:5])

print("Preparing DOTA training dataset...")
DOTA_TRAIN_DATASET = TorchDataset(DOTA_TRAIN_IMAGES, DOTA_TRAIN_ANNOTATIONS, parse_dota_labels)
print(f"...Dataset prepared: {DOTA_TRAIN_DATASET.compute_total_number_of_objects()} total objects")
print("DOTA training dataset sample:", DOTA_TRAIN_DATASET.ids[:5])

print("Preparing DOTA validation dataset...")
DOTA_VAL_DATASET = TorchDataset(DOTA_VAL_IMAGES, DOTA_VAL_ANNOTATIONS, parse_dota_labels)
print(f"...Dataset prepared: {DOTA_VAL_DATASET.compute_total_number_of_objects()} total objects")
print("DOTA validation dataset sample:", DOTA_VAL_DATASET.ids[:5])

print("Preparing DOTA testing dataset...")
DOTA_TEST_DATASET = TorchDataset(DOTA_TEST_IMAGES, DOTA_TEST_ANNOTATIONS, parse_dota_labels)
print(f"...Dataset prepared: {DOTA_TEST_DATASET.compute_total_number_of_objects()} total objects")
print("DOTA testing dataset sample:", DOTA_TEST_DATASET.ids[:5])

Preparing HRSC training dataset...
...Dataset prepared: 13 total objects
HRSC training dataset sample: ['100000007', '100000638', '100000004', '100000001', '100000002']
Preparing HRSC validation dataset...
...Dataset prepared: 14 total objects
HRSC validation dataset sample: ['100000627', '100000010', '100000622', '100000637', '100000629']
Preparing HRSC testing dataset...
...Dataset prepared: 5 total objects
HRSC testing dataset sample: ['100000634', '100000625', '100000626', '100000628', '100000005']
Preparing DOTA training dataset...
...Dataset prepared: 523 total objects
DOTA training dataset sample: ['P0039', 'P0001', 'P0000']
Preparing DOTA validation dataset...
...Dataset prepared: 1235 total objects
DOTA validation dataset sample: ['P0007', 'P0003', 'P0019', 'P0056', 'P0027']
Preparing DOTA testing dataset...
...Dataset prepared: 1280 total objects
DOTA testing dataset sample: ['P0049', 'P0058', 'P0054', 'P0044', 'P0042']


# Define PyTorch Faster R-CNN model

In [7]:
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms as T

class FasterRCNNModel:
  def __init__(self, train_dataset: TorchDataset, val_dataset: TorchDataset, test_dataset: TorchDataset, class_names: list, batch_size=4, shuffle_datasets=False) -> None:
    self.train_dataset = train_dataset
    self.val_dataset = val_dataset
    self.test_dataset = test_dataset

    self.class_names = ["__background__"] + list(class_names)
    self.num_classes = len(self.class_names)

    collate_fn = lambda x: tuple(zip(*x))
    self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=batch_size, shuffle=shuffle_datasets, collate_fn=collate_fn)
    self.val_loader = torch.utils.data.DataLoader(self.val_dataset, batch_size=batch_size, shuffle=shuffle_datasets, collate_fn=collate_fn)
    if len(self.test_dataset) > 0:
        self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size=batch_size, shuffle=shuffle_datasets, collate_fn=collate_fn)
    else:
        self.test_loader = None

    weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    self.model = fasterrcnn_resnet50_fpn(weights=weights)
    self.in_features = self.model.roi_heads.box_predictor.cls_score.in_features
    self.model.roi_heads.box_predictor = FastRCNNPredictor(self.in_features, self.num_classes)
    self.model.to(DEVICE)

  def train(self, num_epochs=50):
    # TODO Add an "RoI Learner" after ROI pooling to predit rotated offsets (dx, dy, dw, dh, dtheta)
    # TODO Replace roi_align with torchvision.ops.roi_align_rotated
    # TODO Add rotation-aware losses

    # Example optimizer
    optimizer = torch.optim.SGD(self.model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    for epoch in range(num_epochs):
      print(f"Starting epoch {epoch+1}/{num_epochs}...")

      # Training loop
      print("\tStarting training loop...")
      train_loss = 0.0
      self.model.train()
      for images, targets in self.train_loader:
        print(f"\t\tProcessing images: {[image.filepath.stem for image in images]}")
        images = [img.to(DEVICE) for img in images]
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        loss_dict = self.model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        train_loss += losses.item()

      lr_scheduler.step()
      print("\t...Training loop complete.")

      # Validation loop
      print("\tStarting validation loop...")
      val_loss = 0.0
      with torch.no_grad():
        self.model.train()
        for images, targets in self.val_loader:
          print(f"\t\tProcessing images: {[image.filepath.stem for image in images]}")
          images = [img.to(DEVICE) for img in images]
          targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
          loss_dict = self.model(images, targets)
          losses = sum(loss for loss in loss_dict.values())
          val_loss += losses.item()
      print("\t...Validation loop complete.")

      print(f"...Finished epoch {epoch+1}/{num_epochs}, Training loss: {train_loss:.4f}, Validation loss: {val_loss:.4f}")

    torch.save({
        "model_state_dict": self.model.state_dict(),
        "class_names": self.class_names
    }, "faster_rcnn_model.pth")

  def test_results(self):
    if self.test_loader is None:
        print("Test dataset is empty. Skipping evaluation.")
        return [], [], []

    self.model.eval()
    all_boxes, all_labels, all_scores = [], [], []
    with torch.no_grad():
      for images, _ in self.test_loader:
        print(f"\tProcessing images: {[image.filepath.stem for image in images]}")
        images = [img.to(DEVICE) for img in images]
        predictions = self.model(images)
        for prediction in predictions:
          boxes, labels, scores = prediction['boxes'], prediction['labels'], prediction['scores']
          all_boxes.append(boxes.cpu().numpy())
          all_labels.append(labels.cpu().numpy())
          all_scores.append(scores.cpu().numpy())
    return all_boxes, all_labels, all_scores

# TODO
# -------------------------------------------------------------------
# FUTURE UPGRADE: RoI TRANSFORMER INTEGRATION
# -------------------------------------------------------------------
# To integrate the RoI Transformer:
#   1. Add a custom RRoI Learner layer:
#        fc = nn.Linear(roi_feature_dim, 5)
#        -> predicts (dx, dy, dw, dh, dtheta)
#   2. Compute rotated boxes and apply torchvision.ops.roi_align_rotated
#   3. Replace standard SmoothL1 loss with rotated IoU or 5D regression loss
#   4. Use torchvision.ops.box_iou_rotated() and nms_rotated() for matching and inference
#   5. Dataset boxes should include rotation (x, y, w, h, theta)

# Prepare dataset models

In [8]:
HRSC_MODEL = FasterRCNNModel(HRSC_TRAIN_DATASET, HRSC_VAL_DATASET, HRSC_TEST_DATASET, [id for id, _ in sorted(HRSC_CLASSES.items(), key=lambda item: item[1])])
DOTA_MODEL = FasterRCNNModel(DOTA_TRAIN_DATASET, DOTA_VAL_DATASET, DOTA_TEST_DATASET, [id for id, _ in sorted(DOTA_CLASSES.items(), key=lambda item: item[1])])

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth


100%|██████████| 160M/160M [00:03<00:00, 42.7MB/s]


# Train dataset models

In [9]:
print("Training HRSC model...")
HRSC_MODEL.train(3)
print("...HRSC model trained.")

print("Training DOTA model...")
DOTA_MODEL.train(3)
print("...DOTA model trained.")

Training HRSC model...
Starting epoch 1/3...
	Starting training loop...
		Processing images: ['100000007', '100000638', '100000004', '100000001']
		Processing images: ['100000002', '100000632', '100000631', '100000011']
		Processing images: ['100000008', '100000009']
	...Training loop complete.
	Starting validation loop...
		Processing images: ['100000627', '100000010', '100000622', '100000637']
		Processing images: ['100000629', '100000640', '100000644', '100000636']
		Processing images: ['100000006', '100000641']
	...Validation loop complete.
...Finished epoch 1/3, Training loss: 3.3431, Validation loss: 0.6569
Starting epoch 2/3...
	Starting training loop...
		Processing images: ['100000007', '100000638', '100000004', '100000001']
		Processing images: ['100000002', '100000632', '100000631', '100000011']
		Processing images: ['100000008', '100000009']
	...Training loop complete.
	Starting validation loop...
		Processing images: ['100000627', '100000010', '100000622', '100000637']
		P

# Evaluate dataset models

In [10]:
hrsc_boxes, hrsc_labels, hrsc_scores = HRSC_MODEL.test_results()
for score in hrsc_scores:
  print(score)
dota_boxes, dota_labels, dota_scores = DOTA_MODEL.test_results()
for score in dota_scores:
  print(score)

# TODO analyse results and visualize them

	Processing images: ['100000634', '100000625', '100000626', '100000628']
	Processing images: ['100000005', '100000624', '100000633', '100000630']
	Processing images: ['100000623', '100000003']
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
	Processing images: ['P0049', 'P0058', 'P0054', 'P0044']
	Processing images: ['P0042', 'P0010', 'P0012', 'P0013']
	Processing images: ['P0002']
[]
[]
[]
[]
[]
[]
[]
[]
[]
