# Import PyTorch modules

In [None]:
!pip install -q 'git+https://github.com/facebookresearch/detectron2.git'
import torch
import torchvision
from torchvision.datasets import VOCDetection
import os

# Pick the compute device once for the whole notebook: CUDA when torch reports
# it as available, otherwise the CPU. The model code below moves tensors and
# modules to this device.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", DEVICE)

# Load datasets

In [None]:
import os

from pathlib import Path

from google.colab import drive

# Mount Google Drive so the shared dataset folders become visible on the local filesystem.
drive.mount('/content/drive')

# Root folders of the three datasets.
# NOTE(review): these paths are relative while the drive is mounted at the absolute
# '/content/drive' — presumably the working directory is '/content'; confirm.
SHARED_PATH = Path("drive/MyDrive/Colab Notebooks/Shared")
HRSC_PATH = SHARED_PATH / "HRSC2016_Final_Splits"
DOTA_PATH = SHARED_PATH / "DOTA_Final_Splits"
NWPU_PATH = SHARED_PATH / "NWPU_VHR-10_Final_Splits"

# Sanity check: list every "<split>/<entry>" pair found two levels below each dataset root.
# Every direct child of a dataset root is itself iterated, so each child must be a directory.
print("HRSC subfolders:", [f"{subfolder.name}/{path.name}" for subfolder in HRSC_PATH.iterdir() for path in subfolder.iterdir()])
print("DOTA subfolders:", [f"{subfolder.name}/{path.name}" for subfolder in DOTA_PATH.iterdir() for path in subfolder.iterdir()])
print("NWPU subfolders:", [f"{subfolder.name}/{path.name}" for subfolder in NWPU_PATH.iterdir() for path in subfolder.iterdir()])

# HRSC: images and annotations for the train / val / test splits.
HRSC_TRAIN_IMAGES = HRSC_PATH/ "train/images"
HRSC_TRAIN_ANNOTATIONS = HRSC_PATH / "train/annotations"
HRSC_VAL_IMAGES = HRSC_PATH / "val/images"
HRSC_VAL_ANNOTATIONS = HRSC_PATH / "val/annotations"
HRSC_TEST_IMAGES = HRSC_PATH / "test/images"
HRSC_TEST_ANNOTATIONS = HRSC_PATH / "test/annotations"

# DOTA: annotations are read from the "hbb" subfolder of each split.
DOTA_TRAIN_IMAGES = DOTA_PATH / "train/images"
DOTA_TRAIN_ANNOTATIONS = DOTA_PATH / "train/hbb"
DOTA_VAL_IMAGES = DOTA_PATH / "val/images"
DOTA_VAL_ANNOTATIONS = DOTA_PATH / "val/hbb"
DOTA_TEST_IMAGES = DOTA_PATH / "test/images"
DOTA_TEST_ANNOTATIONS = DOTA_PATH / "test/hbb"

# NWPU: images and annotations for the train / val / test splits.
NWPU_TRAIN_IMAGES = NWPU_PATH / "train/images"
NWPU_TRAIN_ANNOTATIONS = NWPU_PATH / "train/annotations"
NWPU_VAL_IMAGES = NWPU_PATH / "val/images"
NWPU_VAL_ANNOTATIONS = NWPU_PATH / "val/annotations"
NWPU_TEST_IMAGES = NWPU_PATH / "test/images"
NWPU_TEST_ANNOTATIONS = NWPU_PATH / "test/annotations"

# Explore datasets

In [None]:
!pip install colorama
from colorama import Fore, Style

def explore_header(dataset: str, subfolder: str):
  """Print a banner naming the dataset (green) and the split (magenta) about to be explored."""
  dataset_label = Fore.GREEN + dataset + Style.RESET_ALL
  split_label = Fore.MAGENTA + subfolder + Style.RESET_ALL
  print("Exploring", dataset_label, split_label, "folder...")

def explore(images_folder: Path, ext: str):
  """Report how many files with extension `ext` live directly in `images_folder`.

  Prints the file count followed by the names of up to three matching files.
  Despite the parameter name, the folder may hold images or annotations.
  """
  extension_label = ext.upper()
  matches = [entry for entry in images_folder.glob(f"*.{ext}")]
  preview = [entry.name for entry in matches[:3]]
  print(f"Number of {extension_label} files:", len(matches))
  print(f"{extension_label} sample files:", preview)

# One row per dataset split: (dataset, split, images folder, image extension,
# annotations folder, annotation extension). Rows are reported in this order.
_EXPLORATION_TARGETS = (
  ("HRSC", "train", HRSC_TRAIN_IMAGES, "bmp", HRSC_TRAIN_ANNOTATIONS, "xml"),
  ("HRSC", "val", HRSC_VAL_IMAGES, "bmp", HRSC_VAL_ANNOTATIONS, "xml"),
  ("HRSC", "test", HRSC_TEST_IMAGES, "bmp", HRSC_TEST_ANNOTATIONS, "xml"),
  ("DOTA", "train", DOTA_TRAIN_IMAGES, "png", DOTA_TRAIN_ANNOTATIONS, "txt"),
  ("DOTA", "val", DOTA_VAL_IMAGES, "png", DOTA_VAL_ANNOTATIONS, "txt"),
  ("DOTA", "test", DOTA_TEST_IMAGES, "png", DOTA_TEST_ANNOTATIONS, "txt"),
  ("NWPU", "train", NWPU_TRAIN_IMAGES, "jpg", NWPU_TRAIN_ANNOTATIONS, "json"),
  ("NWPU", "val", NWPU_VAL_IMAGES, "jpg", NWPU_VAL_ANNOTATIONS, "json"),
  ("NWPU", "test", NWPU_TEST_IMAGES, "jpg", NWPU_TEST_ANNOTATIONS, "json"),
)

for _position, (_dataset, _split, _images, _image_ext, _annotations, _annotation_ext) in enumerate(_EXPLORATION_TARGETS):
  # A blank line separates consecutive reports (none before the first, none after the last).
  if _position:
    print()

  explore_header(_dataset, _split)
  explore(_images, _image_ext)
  explore(_annotations, _annotation_ext)

# Define label parsers

In [None]:
import itertools
import xml.etree.ElementTree as ET
import numpy as np
import cv2

# Maps each raw HRSC Class_ID to a contiguous index; indices are handed out in
# the order in which new IDs are first encountered by parse_hrsc_labels.
HRSC_CLASSES = {}


def parse_hrsc_labels(label_path: Path):
  """Read the rotated boxes and class indices stored in one HRSC XML annotation.

  Returns a `(boxes, labels)` pair of parallel lists. Each box is
  `[center_x, center_y, width, height, angle]`, where the angle is the
  `mbox_ang` value multiplied by 180/pi. Each label is the index registered
  for the object's `Class_ID` in the module-level `HRSC_CLASSES` mapping,
  which this function extends whenever it meets a new ID.

  Objects that fail to parse are reported with a warning and skipped.
  """
  def read_float(node, tag):
    return float(node.find(tag).text)

  document = ET.parse(label_path.as_posix())
  boxes, labels = [], []

  for ship in document.getroot().findall(".//HRSC_Object"):
    try:
      # Geometry first, so a malformed box never registers a new class.
      box = [read_float(ship, tag) for tag in ('mbox_cx', 'mbox_cy', 'mbox_w', 'mbox_h')]
      box.append(read_float(ship, 'mbox_ang') * 180 / torch.pi)

      raw_class = int(ship.find('Class_ID').text)
      class_index = HRSC_CLASSES.setdefault(raw_class, len(HRSC_CLASSES))
    except Exception as e:
      print(Fore.RED + "Warning" + Style.RESET_ALL + f": Could not parse object in {label_path.as_posix()}: {e}")
      continue
    else:
      boxes.append(box)
      labels.append(class_index)

  return boxes, labels


def hrsc_image_rescale(label_path: Path, image_width, image_height):
  """Return `(width_ratio, height_ratio)` between a given image size and the size in the XML.

  The annotated size is read from the `Img_SizeWidth` / `Img_SizeHeight`
  elements of the HRSC annotation at `label_path`; each ratio is the given
  dimension divided by the annotated one.
  """
  root = ET.parse(label_path.as_posix()).getroot()
  annotated_width, annotated_height = (
    int(root.find(f'.//{tag}').text) for tag in ('Img_SizeWidth', 'Img_SizeHeight')
  )
  return image_width / annotated_width, image_height / annotated_height


# Maps each DOTA category name to a contiguous index; indices are handed out in
# the order in which new names are first encountered by parse_dota_labels.
DOTA_CLASSES = {}


def parse_dota_labels(label_path: Path):
  """Read the boxes and class indices stored in one DOTA text annotation.

  The first two lines of the file are skipped. Every remaining line must hold
  exactly ten whitespace-separated fields: eight corner coordinates, a
  category name and a difficulty value (the latter is parsed but not used).
  The four corners are converted with `cv2.minAreaRect` into
  `[cx, cy, w, h, angle]`.

  Returns a `(boxes, labels)` pair of parallel lists. Lines that fail to
  parse are reported with a warning and skipped.
  """
  boxes, labels = [], []

  # Drop the two leading lines; everything after them is treated as an object.
  annotation_lines = label_path.read_text().splitlines()[2:]

  for raw_line in annotation_lines:
    try:
      fields = raw_line.strip().split()
      # Unpacking into exactly ten names rejects lines with a different field count.
      ax, ay, bx, by, qx, qy, dx, dy, category, _difficulty = (*map(float, fields[:8]), *fields[8:])

      corners = np.array([[ax, ay], [bx, by], [qx, qy], [dx, dy]], dtype=np.float32).reshape(-1, 1, 2)
      (center_x, center_y), (width, height), angle = cv2.minAreaRect(corners)

      class_index = DOTA_CLASSES.setdefault(category, len(DOTA_CLASSES))

      boxes.append([center_x, center_y, width, height, angle])
      labels.append(class_index)

    except Exception as e:
      print(Fore.RED + "Warning" + Style.RESET_ALL + f": Could not parse object in {label_path.as_posix()}: {e}")
      continue

  return boxes, labels


def dota_image_rescale(label_path: Path, image_width, image_height):
  """Return the identity scale `(1.0, 1.0)`; all three arguments are ignored.

  The signature mirrors the other `*_image_rescale` helpers so the functions
  can be used interchangeably.
  """
  identity_scale = 1.0
  return identity_scale, identity_scale


# Maps each NWPU category name to its zero-based index (the JSON category id minus one).
NWPU_CLASSES = {}


def parse_nwpu_labels(label_path: Path):
  """Read the boxes and class indices stored in one NWPU JSON annotation.

  The JSON document must provide a `categories` list (entries with `name` and
  a 1-based `id`) and an `annotations` list (entries with `segmentation` and
  `category_id`). Every category is recorded in the module-level
  `NWPU_CLASSES` mapping. For each annotation, the first segmentation polygon
  is converted with `cv2.minAreaRect` into `[cx, cy, w, h, angle]`.

  Returns a `(boxes, labels)` pair of parallel lists with zero-based labels.
  Annotations that fail to parse are reported with a warning and skipped.
  """
  # Imported locally because no module-level `import json` exists in this file;
  # without it the json.loads call below raises NameError.
  import json

  boxes = []
  labels = []

  data = json.loads(label_path.read_text())
  for category in data['categories']:
    NWPU_CLASSES[category['name']] = category['id'] - 1  # categories are 1-indexed -> convert to 0-indexed

  for annotation in data['annotations']:
    try:
      # Bounds: only the first polygon of the segmentation is used.
      segmentation = annotation['segmentation'][0]
      pts_np = np.array(segmentation, dtype=np.float32).reshape(-1, 2)
      (cx, cy), (w, h), angle = cv2.minAreaRect(pts_np)

      # Category
      class_id_index = annotation['category_id'] - 1

      boxes.append([cx, cy, w, h, angle])
      labels.append(class_id_index)

    except Exception as e:
      print(Fore.RED + "Warning" + Style.RESET_ALL + f": Could not parse object in {label_path.as_posix()}: {e}")
      continue

  return boxes, labels

def nwpu_image_rescale(label_path: Path, image_width, image_height):
  """Return `(width_ratio, height_ratio)` between a given image size and the size in the JSON.

  The annotated size is taken from the first entry of the `images` list in
  the NWPU annotation at `label_path`; each ratio is the given dimension
  divided by the annotated one.
  """
  # Imported locally because no module-level `import json` exists in this file;
  # without it the json.loads call below raises NameError.
  import json

  data = json.loads(label_path.read_text())
  image = data['images'][0]  # only the first image record is consulted
  width = image['width']
  height = image['height']
  return image_width / width, image_height / height

# PyTorch dataset structure

In [None]:
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import torchvision.transforms as T
from torchvision.transforms import functional as F

from PIL import Image

# Default tensor transform for TorchDataset: per-channel normalization only.
# NOTE(review): the mean/std values look like the usual ImageNet statistics — confirm they match the backbone weights.
BASIC_TRANSFORM = T.Compose([T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

class TorchDataset(torch.utils.data.Dataset):
  """Dataset of (image, rotated-box target) pairs read from an images folder and an annotations folder.

  Images and annotations are paired by file stem. Each sample is resized so
  its longer side equals `target_size`, padded on the right/bottom to a
  `target_size` x `target_size` square, converted to a tensor and passed
  through `transforms`. Targets are dicts with `boxes` (N x 5 float tensor,
  `[cx, cy, w, h, angle]`) and `labels` (N int64 tensor, offset by one so
  that 0 is free for the background class).

  Args:
    images_folder: folder holding `.jpg` / `.png` / `.bmp` files.
    annotations_folder: folder holding `.xml` / `.txt` / `.json` files.
    label_parser: callable `path -> (boxes, labels)`.
    image_rescale: callable `(label_path, image_width, image_height) -> (wmult, hmult)`.
    transforms: callable applied to the image tensor.
    max_files: keep at most this many image/annotation pairs; values <= 0 keep everything.
    target_size: side length of the square output image.
  """

  IMAGE_SUFFIXES = (".jpg", ".png", ".bmp")
  ANNOTATION_SUFFIXES = (".xml", ".txt", ".json")

  def __init__(self, images_folder: Path, annotations_folder: Path, label_parser, image_rescale, transforms=BASIC_TRANSFORM, max_files=10, target_size=800) -> None:
    super().__init__()
    self.label_parser = label_parser
    self.image_rescale = image_rescale
    self.transforms = transforms
    self.target_size = target_size

    images = self._list_files(images_folder, self.IMAGE_SUFFIXES)
    annotations = self._list_files(annotations_folder, self.ANNOTATION_SUFFIXES)

    # Pair images with annotations BEFORE truncating. Truncating the two
    # listings independently drops valid pairs as soon as one folder holds a
    # file without a counterpart in the other.
    # Sorting gives a reproducible sample order; a set would not.
    shared_ids = sorted(images.keys() & annotations.keys(), key=lambda stem: (stem.lower(), stem))
    if max_files > 0:
      shared_ids = shared_ids[:max_files]

    self.ids = shared_ids
    self.images = {stem: images[stem] for stem in shared_ids}
    self.annotations = {stem: annotations[stem] for stem in shared_ids}

  @staticmethod
  def _list_files(folder, suffixes):
    """Map file stem -> path for the files in `folder` whose suffix is in `suffixes` (empty if the folder is missing)."""
    if not folder.exists():
      return {}
    candidates = sorted(
      (f for f in folder.iterdir() if f.suffix.lower() in suffixes),
      key=lambda p: p.stem.lower()
    )
    return {f.stem: f for f in candidates}

  def __getitem__(self, index):
    """Load sample `index` and return `(image_tensor, target_dict)`."""
    sample_id = self.ids[index]
    image_path = self.images[sample_id]
    label_path = self.annotations[sample_id]

    # Decode inside a context manager so the file handle is released promptly.
    with Image.open(image_path) as handle:
      image = handle.convert("RGB")
    boxes, labels = self.label_parser(label_path)

    # reshape(-1, 5) keeps the column dimension for images without objects;
    # a bare empty list would give a 1-D tensor and break column indexing.
    boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 5)
    labels = torch.as_tensor(labels, dtype=torch.int64) + 1  # Offset by 1 for background label

    image, boxes = self.preprocess(image, boxes)
    # Convert to tensor (normalized 0-1)
    image = F.to_tensor(image)

    target = {"boxes": boxes, "labels": labels}

    image = self.transforms(image)
    image.filepath = image_path  # kept for callers that want to trace a tensor back to its file
    return image, target

  def preprocess(self, img, boxes):
    """Resize `img` uniformly to fit `target_size`, pad it to a square and scale `boxes` to match.

    `boxes` must be an (N, 5) tensor; the angle column is left untouched.
    Returns the processed image and a scaled copy of the boxes.
    """
    old_w, old_h = img.width, img.height
    scale = self.target_size / max(old_h, old_w)  # uniform scale

    new_w = int(old_w * scale)
    new_h = int(old_h * scale)

    # Resize image
    img = F.resize(img, (new_h, new_w))

    # Scale boxes (cx, cy, w, h); theta stays unchanged under a uniform scale
    boxes = boxes.clone()
    boxes[:, :4] *= scale

    # Pad to target_size x target_size (right and bottom), so box coordinates stay valid
    pad_w = self.target_size - new_w
    pad_h = self.target_size - new_h
    img = F.pad(img, (0, 0, pad_w, pad_h))
    return img, boxes

  def __len__(self):
    """Number of image/annotation pairs."""
    return len(self.ids)

  def get_image_rescale(self, index):
    """Return `image_rescale(label_path, width, height)` for the on-disk image of sample `index`."""
    sample_id = self.ids[index]
    label_path = self.annotations[sample_id]
    # Only the size is needed, so the pixels are never converted.
    with Image.open(self.images[sample_id]) as image:
      width, height = image.width, image.height
    return self.image_rescale(label_path, width, height)

  def compute_total_number_of_objects(self, max_workers=8):
    """Parse every annotation with `label_parser` and return the total number of boxes.

    NOTE(review): the parser is invoked concurrently from worker threads, so
    a parser that updates shared state must tolerate that — confirm.
    """
    def count_objects(sample_id):
      boxes, _ = self.label_parser(self.annotations[sample_id])
      return len(boxes)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
      return sum(executor.map(count_objects, self.ids))

# Prepare datasets for PyTorch

In [None]:
# Build one TorchDataset per dataset split. Each stanza below:
#   1. constructs the dataset from the split's image/annotation folders together with
#      the dataset-specific label parser and rescale helper,
#   2. parses every annotation once to report the total object count, and
#   3. prints the first few sample ids as a sanity check.
# All datasets use TorchDataset's default max_files / target_size arguments.

# --- HRSC ---
print("Preparing HRSC training dataset...")
HRSC_TRAIN_DATASET = TorchDataset(HRSC_TRAIN_IMAGES, HRSC_TRAIN_ANNOTATIONS, parse_hrsc_labels, hrsc_image_rescale)
print(f"...Dataset prepared: {HRSC_TRAIN_DATASET.compute_total_number_of_objects()} total objects")
print("HRSC training dataset sample:", HRSC_TRAIN_DATASET.ids[:5])

print("Preparing HRSC validation dataset...")
HRSC_VAL_DATASET = TorchDataset(HRSC_VAL_IMAGES, HRSC_VAL_ANNOTATIONS, parse_hrsc_labels, hrsc_image_rescale)
print(f"...Dataset prepared: {HRSC_VAL_DATASET.compute_total_number_of_objects()} total objects")
print("HRSC validation dataset sample:", HRSC_VAL_DATASET.ids[:5])

print("Preparing HRSC testing dataset...")
HRSC_TEST_DATASET = TorchDataset(HRSC_TEST_IMAGES, HRSC_TEST_ANNOTATIONS, parse_hrsc_labels, hrsc_image_rescale)
print(f"...Dataset prepared: {HRSC_TEST_DATASET.compute_total_number_of_objects()} total objects")
print("HRSC testing dataset sample:", HRSC_TEST_DATASET.ids[:5])

# --- DOTA ---
print("Preparing DOTA training dataset...")
DOTA_TRAIN_DATASET = TorchDataset(DOTA_TRAIN_IMAGES, DOTA_TRAIN_ANNOTATIONS, parse_dota_labels, dota_image_rescale)
print(f"...Dataset prepared: {DOTA_TRAIN_DATASET.compute_total_number_of_objects()} total objects")
print("DOTA training dataset sample:", DOTA_TRAIN_DATASET.ids[:5])

print("Preparing DOTA validation dataset...")
DOTA_VAL_DATASET = TorchDataset(DOTA_VAL_IMAGES, DOTA_VAL_ANNOTATIONS, parse_dota_labels, dota_image_rescale)
print(f"...Dataset prepared: {DOTA_VAL_DATASET.compute_total_number_of_objects()} total objects")
print("DOTA validation dataset sample:", DOTA_VAL_DATASET.ids[:5])

print("Preparing DOTA testing dataset...")
DOTA_TEST_DATASET = TorchDataset(DOTA_TEST_IMAGES, DOTA_TEST_ANNOTATIONS, parse_dota_labels, dota_image_rescale)
print(f"...Dataset prepared: {DOTA_TEST_DATASET.compute_total_number_of_objects()} total objects")
print("DOTA testing dataset sample:", DOTA_TEST_DATASET.ids[:5])

# --- NWPU ---
print("Preparing NWPU training dataset...")
NWPU_TRAIN_DATASET = TorchDataset(NWPU_TRAIN_IMAGES, NWPU_TRAIN_ANNOTATIONS, parse_nwpu_labels, nwpu_image_rescale)
print(f"...Dataset prepared: {NWPU_TRAIN_DATASET.compute_total_number_of_objects()} total objects")
print("NWPU training dataset sample:", NWPU_TRAIN_DATASET.ids[:5])

print("Preparing NWPU validation dataset...")
NWPU_VAL_DATASET = TorchDataset(NWPU_VAL_IMAGES, NWPU_VAL_ANNOTATIONS, parse_nwpu_labels, nwpu_image_rescale)
print(f"...Dataset prepared: {NWPU_VAL_DATASET.compute_total_number_of_objects()} total objects")
print("NWPU validation dataset sample:", NWPU_VAL_DATASET.ids[:5])

print("Preparing NWPU testing dataset...")
NWPU_TEST_DATASET = TorchDataset(NWPU_TEST_IMAGES, NWPU_TEST_ANNOTATIONS, parse_nwpu_labels, nwpu_image_rescale)
print(f"...Dataset prepared: {NWPU_TEST_DATASET.compute_total_number_of_objects()} total objects")
print("NWPU testing dataset sample:", NWPU_TEST_DATASET.ids[:5])

# RRoI Layer

In [None]:
import torch
import torch.nn as nn
from torchvision.models.detection.roi_heads import RoIHeads
from detectron2.layers import roi_align_rotated

class RRoILearner(nn.Module):
  """Single linear layer mapping each RoI feature vector to five rotated-RoI offsets."""

  def __init__(self, in_features):
    super().__init__()
    # Output columns are consumed as (dx, dy, dw, dh, dtheta).
    self.fc = nn.Linear(in_features, 5)

  def forward(self, roi_features):
    """Return the offsets predicted for `roi_features`, one row of five values per RoI."""
    return self.fc(roi_features)


class RotatedBoxPredictor(nn.Module):
  """Prediction head producing class scores and five box values per class for each RoI."""

  def __init__(self, in_channels, num_classes):
    super().__init__()
    self.cls_score = nn.Linear(in_channels, num_classes)
    self.bbox_pred = nn.Linear(in_channels, num_classes * 5)

  def forward(self, x):
    """Return `(scores, bbox_deltas)`; 4-D inputs are flattened to one vector per RoI first."""
    flat = x.flatten(start_dim=1) if x.dim() == 4 else x
    return self.cls_score(flat), self.bbox_pred(flat)


class RotatedRoIHeads(RoIHeads):
  """RoI heads that learn a rotated RoI for every horizontal proposal and re-pool features from it.

  `forward` returns the raw `(class_logits, box_regression, offsets)` tensors;
  it does not compute losses or post-process detections itself.
  """

  def __init__(self, box_roi_pool, box_head, box_predictor,
               fg_iou_thresh=0.5,
               bg_iou_thresh=0.5,
               batch_size_per_image=512,
               positive_fraction=0.25,
               bbox_reg_weights=None,
               score_thresh=0.05,
               nms_thresh=0.5,
               detections_per_img=100):
    parent_arguments = (
      box_roi_pool, box_head, box_predictor,
      fg_iou_thresh, bg_iou_thresh,
      batch_size_per_image, positive_fraction,
      bbox_reg_weights,
      score_thresh, nms_thresh, detections_per_img,
    )
    super().__init__(*parent_arguments)

    # The offset learner consumes the box head's output, so it is sized from the head's last layer.
    self.rroi_learner = RRoILearner(self.box_head.fc7.out_features)

  def forward(self, features, proposals, image_shapes, targets=None):
    """Run the two-stage rotated head.

    Args:
      features: dict of backbone feature maps.
      proposals: list with one (N_i, 4) tensor of x1, y1, x2, y2 boxes per image.
      image_shapes: per-image sizes, forwarded to the horizontal RoI pooler.
      targets: accepted for interface compatibility; not used here.

    Returns:
      `(class_logits, box_regression, offsets)` where `offsets` holds the five
      values predicted per RoI by the rotated-RoI learner.
    """
    # 1. Pool horizontal RoIs, embed them, and predict one offset row per RoI.
    horizontal_embedding = self.box_head(self.box_roi_pool(features, proposals, image_shapes))
    offsets = self.rroi_learner(horizontal_embedding)  # (N_total_rois, 5)

    # 2. Turn every horizontal proposal into a rotated RoI using its offsets.
    cursor = 0
    rois_per_image = []
    for image_index, boxes_xyxy in enumerate(proposals):
      count = boxes_xyxy.shape[0]
      deltas = offsets[cursor:cursor + count]
      cursor += count

      x_min, y_min, x_max, y_max = (boxes_xyxy[:, column] for column in range(4))
      center_x = (x_min + x_max) / 2
      center_y = (y_min + y_max) / 2
      width = x_max - x_min
      height = y_max - y_min

      rois_per_image.append(torch.stack([
        torch.full_like(center_x, image_index, dtype=torch.float32),  # batch index column
        center_x + deltas[:, 0] * width,
        center_y + deltas[:, 1] * height,
        width * torch.exp(deltas[:, 2]),
        height * torch.exp(deltas[:, 3]),
        deltas[:, 4],  # the fifth offset is used directly as the RoI angle
      ], dim=1))

    rotated_boxes = torch.cat(rois_per_image, dim=0)

    # 3. Rotated RoI Align on the first feature map only.
    if not isinstance(features, dict) or len(features) == 0:
      raise ValueError("Backbone features are not a non-empty dictionary as expected for Rotated RoI Align.")
    feature_map = next(iter(features.values()))
    spatial_scale = 1.0 / 4.0  # the first map is assumed to be 1/4 of the input resolution

    # Arguments are passed positionally: input, boxes, output_size, spatial_scale, sampling_ratio.
    rotated_features = roi_align_rotated(feature_map, rotated_boxes, (7, 7), spatial_scale, 2)

    # 4. Embed the rotated features with the same box head, then classify / regress.
    class_logits, box_regression = self.box_predictor(self.box_head(rotated_features))

    return class_logits, box_regression, offsets


from shapely.geometry import Polygon

def rotated_box_to_polygon(box):
  """Convert one box tensor into a shapely `Polygon` of its four corners.

  A 5-element box is read as `(cx, cy, w, h, angle)`; the angle is passed
  directly to `torch.cos` / `torch.sin`, i.e. it is interpreted in radians.
  A 4-element box is read as axis-aligned `(x1, y1, x2, y2)`.

  Raises:
    ValueError: if the box has neither 4 nor 5 elements.
  """
  element_count = box.numel()
  if element_count == 5:
    cx, cy, w, h, angle = box.tolist()
  elif element_count == 4:
    # Axis-aligned box: derive center/size and use a zero angle.
    x1, y1, x2, y2 = box.tolist()
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    w, h = x2 - x1, y2 - y1
    angle = 0.0
  else:
    raise ValueError(f"Box must have 4 or 5 elements, got {box.numel()}")

  # Corners of the unrotated rectangle, centered on the origin.
  half_w, half_h = w / 2, h / 2
  corners = torch.tensor(
    [[-half_w, -half_h], [half_w, -half_h], [half_w, half_h], [-half_w, half_h]],
    dtype=torch.float32,
  )

  # 2-D rotation matrix for `angle`.
  theta = torch.tensor(angle)
  cos_t, sin_t = torch.cos(theta), torch.sin(theta)
  rotation = torch.tensor([
    [cos_t, -sin_t],
    [sin_t,  cos_t]
  ])

  # Rotate around the origin, then shift to the box center.
  placed = corners @ rotation.T
  placed[:, 0] += cx
  placed[:, 1] += cy

  return Polygon(placed.tolist())


def box_iou_rotated(boxes1, boxes2):
  """Pairwise IoU between two sets of boxes, computed with polygon intersection.

  Args:
    boxes1: tensor with N rows; each row is converted with `rotated_box_to_polygon`.
    boxes2: tensor with M rows; each row is converted with `rotated_box_to_polygon`.

  Returns:
    A float32 CPU tensor of shape (N, M). Pairs whose union area is not
    positive get an IoU of 0.
  """
  N = boxes1.shape[0]
  M = boxes2.shape[0]

  # Nothing to compare: skip polygon construction entirely.
  if N == 0 or M == 0:
    return torch.zeros((N, M), dtype=torch.float32)

  # Build every polygon exactly once (N + M conversions). Rebuilding the
  # boxes2 polygons inside the outer loop would cost N * M conversions.
  # Moving the boxes to the CPU up front also avoids a device round trip per row.
  polygons1 = [rotated_box_to_polygon(box) for box in boxes1.detach().cpu()]
  polygons2 = [rotated_box_to_polygon(box) for box in boxes2.detach().cpu()]
  areas2 = [polygon.area for polygon in polygons2]

  rows = []
  for p1 in polygons1:
    area1 = p1.area
    row = []
    for p2, area2 in zip(polygons2, areas2):
      inter = p1.intersection(p2).area
      union = area1 + area2 - inter
      row.append(inter / union if union > 0 else 0.0)
    rows.append(row)

  return torch.tensor(rows, dtype=torch.float32)

# Define PyTorch Rotated Faster R-CNN Model

In [None]:
from torchvision.models import ResNet50_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.image_list import ImageList
from torchvision.ops import boxes as box_ops
from torchvision import transforms as T

class RotatedFasterRCNNModel(FasterRCNN):
  """Faster R-CNN with a ResNet-50 FPN backbone whose RoI heads predict rotated boxes.

  The model owns its train / validation / test data loaders and exposes
  `train_model`, `test_results`, `save_weights` and `load_weights`.

  Args:
    model_path: checkpoint file written by `save_weights` and read by `load_weights`.
    train_dataset, val_dataset, test_dataset: datasets yielding `(image, target)`
      pairs where `target` holds `boxes` (N x 5: cx, cy, w, h, angle) and `labels`.
    class_names: foreground class names; a background entry is prepended.
    batch_size: batch size of every data loader.
    shuffle_datasets: whether the data loaders shuffle.
  """

  def __init__(self, model_path: Path, train_dataset: TorchDataset, val_dataset: TorchDataset, test_dataset: TorchDataset, class_names: list, batch_size=4, shuffle_datasets=False) -> None:
    all_class_names = ["__background__"] + list(class_names)

    # Initialise the nn.Module machinery first, so every attribute assigned
    # below goes through a fully constructed module.
    super().__init__(backbone=resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT), num_classes=len(all_class_names), rpn_anchor_generator=None)

    self.model_path = model_path
    self.train_dataset = train_dataset
    self.val_dataset = val_dataset
    self.test_dataset = test_dataset

    self.class_names = all_class_names
    self.num_classes = len(all_class_names)

    # Keep images and targets as per-sample tuples instead of stacking them.
    collate_fn = lambda x: tuple(zip(*x))
    self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=batch_size, shuffle=shuffle_datasets, collate_fn=collate_fn)
    self.val_loader = torch.utils.data.DataLoader(self.val_dataset, batch_size=batch_size, shuffle=shuffle_datasets, collate_fn=collate_fn)
    if len(self.test_dataset) > 0:
      self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size=batch_size, shuffle=shuffle_datasets, collate_fn=collate_fn)
    else:
      self.test_loader = None

    # Replace the stock RoI heads with the rotated variant, reusing the pooler and box head.
    self.in_features = self.roi_heads.box_predictor.cls_score.in_features
    self.box_detector = RotatedBoxPredictor(self.in_features, self.num_classes)
    self.roi_heads = RotatedRoIHeads(self.roi_heads.box_roi_pool, self.roi_heads.box_head, self.box_detector)
    self.to(DEVICE)

  def train_model(self, num_epochs=50, learning_rate=0.0005):
    """Train with SGD for `num_epochs`, print summed losses per epoch, then save the weights."""
    optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    for epoch in range(num_epochs):
      print(f"Starting epoch {epoch+1}/{num_epochs}...")

      # Training loop
      print("\tStarting training loop...")
      train_loss = 0.0
      self.train()
      for images, targets in self.train_loader:
        loss_dict = self(images, targets)
        losses = sum(loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        train_loss += losses.item()

      lr_scheduler.step()
      print("\t...Training loop complete.")

      # Validation loop: the model stays in training mode because forward()
      # only returns losses in that mode; no_grad() keeps it from learning.
      print("\tStarting validation loop...")
      val_loss = 0.0
      with torch.no_grad():
        self.train()
        for images, targets in self.val_loader:
          loss_dict = self(images, targets)
          val_loss += sum(loss_dict.values()).item()
      print("\t...Validation loop complete.")

      print(f"...Finished epoch {epoch+1}/{num_epochs}, Training loss: {train_loss:.4f}, Validation loss: {val_loss:.4f}")

    self.save_weights()

  def forward(self, images, targets=None):
    """Run the detector.

    Args:
      images: list of tensors [C,H,W], all of the same size.
      targets: list of dicts {'boxes': [N,5], 'labels': [N]}; required in training mode.

    Returns:
      In training mode, a dict of losses. Otherwise the raw
      `(class_logits, box_regression)` tensors of the RoI heads.

    Raises:
      ValueError: if the model is in training mode and `targets` is None.
    """
    if self.training and targets is None:
      raise ValueError("targets must be provided when the model is in training mode")

    proposals, rpn_losses, class_logits, box_regression, pred_offsets, targets = self._run_detector(images, targets)

    if not self.training:
      # inference
      return class_logits, box_regression

    loss_dict = {}
    # RPN losses
    loss_dict.update(rpn_losses)
    # Rotated box loss
    loss_dict['loss_rbox'] = self.rotated_box_loss(pred_offsets, proposals, targets)
    # Classification loss: labels are assigned image by image so that an RoI
    # can only be matched to ground truth of its own image.
    roi_labels = torch.cat([
      self.assign_labels_to_rois(self.xyxy_to_cxcywh_angle(image_proposals), [image_target])
      for image_proposals, image_target in zip(proposals, targets)
    ], dim=0)
    loss_dict['loss_classifier'] = nn.CrossEntropyLoss()(class_logits, roi_labels)
    return loss_dict

  def _run_detector(self, images, targets):
    """Run backbone, RPN and RoI heads; shared by `forward` and `test_results`.

    Returns `(proposals, rpn_losses, class_logits, box_regression, pred_offsets, targets)`
    where `targets` is the input moved to DEVICE (or None).
    """
    # 1. Move images to device and batch them
    images = [img.to(DEVICE) for img in images]
    image_sizes = [tuple(img.shape[-2:]) for img in images]  # list of (H,W)
    image_list = ImageList(torch.stack(images), image_sizes)

    # 2. Backbone
    features = self.backbone(image_list.tensors)

    # 3. RPN proposals. Targets are only available (and only needed) when
    # losses are requested; iterating a None target list is never attempted.
    rpn_targets = None
    if targets is not None:
      # The data loader yields CPU tensors; anchors and proposals live on DEVICE.
      targets = [
        {**t, 'boxes': t['boxes'].to(DEVICE).reshape(-1, 5), 'labels': t['labels'].to(DEVICE)}
        for t in targets
      ]
      rpn_targets = []
      for t in targets:
        cx, cy, w, h, theta = t['boxes'].unbind(1)
        # NOTE(review): the rotation is ignored here; the RPN is trained on the
        # unrotated (cx, cy, w, h) rectangle rather than the enclosing box.
        rpn_targets.append({
          'boxes': torch.stack([cx - w/2, cy - h/2, cx + w/2, cy + h/2], dim=1),
          'labels': t['labels']
        })

    proposals, rpn_losses = self.rpn(image_list, features, rpn_targets)

    # 4. Rotated RoI heads
    class_logits, box_regression, pred_offsets = self.roi_heads(features, proposals, image_list.image_sizes, targets)
    return proposals, rpn_losses, class_logits, box_regression, pred_offsets, targets

  def _decode_detections(self, proposals, class_logits, pred_offsets):
    """Turn raw head outputs into per-image `(boxes, labels, scores)` tensors.

    Boxes are decoded from the proposals and the learned offsets with the same
    formulas the RoI heads use to build their rotated RoIs. Each RoI gets its
    best foreground class; detections are filtered with
    `roi_heads.score_thresh` and capped at `roi_heads.detections_per_img`.
    No non-maximum suppression is applied.
    """
    score_thresh = self.roi_heads.score_thresh
    max_detections = self.roi_heads.detections_per_img
    probabilities = torch.softmax(class_logits, dim=1)
    counts = [image_proposals.shape[0] for image_proposals in proposals]

    detections = []
    for image_proposals, probs, offs in zip(proposals, probabilities.split(counts), pred_offsets.split(counts)):
      cx, cy, w, h, _ = self.xyxy_to_cxcywh_angle(image_proposals).unbind(dim=1)
      boxes = torch.stack([
        cx + offs[:, 0] * w,
        cy + offs[:, 1] * h,
        w * torch.exp(offs[:, 2]),
        h * torch.exp(offs[:, 3]),
        offs[:, 4],
      ], dim=1)

      if probs.shape[1] > 1:
        # Column 0 is the background class and can never be a detection.
        scores, labels = probs[:, 1:].max(dim=1)
        labels = labels + 1
      else:
        scores = probs.new_zeros(probs.shape[0])
        labels = torch.zeros(probs.shape[0], dtype=torch.long, device=probs.device)

      keep = scores > score_thresh
      boxes, labels, scores = boxes[keep], labels[keep], scores[keep]
      order = scores.argsort(descending=True)[:max_detections]
      detections.append((boxes[order], labels[order], scores[order]))
    return detections

  @staticmethod
  def xyxy_to_cxcywh_angle(boxes):
    """Convert [N,4] x1,y1,x2,y2 boxes to [N,5] cx,cy,w,h,angle (angle 0); [N,5] input is returned as is."""
    if boxes.shape[1] == 4:
      x1, y1, x2, y2 = boxes.unbind(dim=1)
      cx = (x1 + x2) / 2
      cy = (y1 + y2) / 2
      w = x2 - x1
      h = y2 - y1
      angle = torch.zeros_like(cx)
      return torch.stack([cx, cy, w, h, angle], dim=1)
    elif boxes.shape[1] == 5:
      return boxes
    else:
      raise ValueError(f"Boxes must be 4 or 5 dims, got {boxes.shape[1]}")

  @staticmethod
  def assign_labels_to_rois(proposals, targets, iou_threshold=0.5):
    """
    proposals: [N_total_rois, 5]  (cx,cy,w,h,θ)
    targets: list of dicts {'boxes':[M_i,5], 'labels':[M_i]}
    returns:
      roi_labels: [N_total_rois]; 0 (background) unless the best IoU with any
      of the given target boxes exceeds `iou_threshold`.

    All target boxes passed in are pooled, so callers that batch several
    images should call this once per image.
    """
    device = proposals.device
    roi_labels = torch.zeros(proposals.shape[0], dtype=torch.long, device=device)  # background=0

    target_boxes = torch.cat([t['boxes'] for t in targets], dim=0).to(device)
    target_labels = torch.cat([t['labels'] for t in targets], dim=0).to(device)

    # Without RoIs or without ground truth everything stays background.
    if proposals.shape[0] == 0 or target_boxes.shape[0] == 0:
      return roi_labels

    ious = box_iou_rotated(proposals, target_boxes).to(device)  # [N_rois, N_gt]
    max_iou, max_idx = ious.max(dim=1)

    positive_mask = max_iou > iou_threshold
    roi_labels[positive_mask] = target_labels[max_idx[positive_mask]]

    return roi_labels

  @staticmethod
  def rotated_box_loss(pred_offsets, proposals, targets):
    """Smooth-L1 loss between predicted offsets and the offsets to each RoI's best-matching ground truth.

    Args:
      pred_offsets: [N_total, 5] offsets, ordered like the concatenated proposals.
      proposals: list with one [N_i, 4] (or [N_i, 5]) tensor per image.
      targets: list with one dict per image holding 'boxes' [M_i, 5].

    Matching is done per image. RoIs of images without ground truth do not
    contribute; if nothing contributes the result is a zero that is still
    connected to `pred_offsets`.

    NOTE(review): every RoI of an image with ground truth is regressed, not
    only the positive ones, and dtheta is (ga - pa) mod 2*pi scaled to [0, 1)
    — confirm both against the intended angle unit and sampling scheme.
    """
    device = pred_offsets.device
    eps = 1e-6  # guards the divisions and logarithms against zero-sized boxes

    offset_targets = []
    contributes = []
    for image_proposals, target in zip(proposals, targets):
      proposal_boxes = RotatedFasterRCNNModel.xyxy_to_cxcywh_angle(image_proposals.to(device))  # [N_i, 5]
      roi_count = proposal_boxes.shape[0]
      if roi_count == 0:
        continue

      target_boxes = target["boxes"].to(device)  # [M_i, 5]
      if target_boxes.shape[0] == 0:
        offset_targets.append(torch.zeros((roi_count, 5), dtype=pred_offsets.dtype, device=device))
        contributes.append(torch.zeros(roi_count, dtype=torch.bool, device=device))
        continue

      with torch.no_grad():
        # For each RoI, take the best-overlapping ground truth of the same image
        ious = box_iou_rotated(proposal_boxes, target_boxes).to(device)
        matched_gt = target_boxes[ious.argmax(dim=1)]

        # Compute rotated regression targets (5-d)
        px, py, pw, ph, pa = proposal_boxes.unbind(dim=1)
        gx, gy, gw, gh, ga = matched_gt.unbind(dim=1)
        pw = pw.clamp(min=eps)
        ph = ph.clamp(min=eps)

        # Eq(1) from the paper
        dx = ((gx - px) * torch.cos(pa) + (gy - py) * torch.sin(pa)) / pw
        dy = ((gy - py) * torch.cos(pa) - (gx - px) * torch.sin(pa)) / ph
        dw = torch.log(gw.clamp(min=eps) / pw)
        dh = torch.log(gh.clamp(min=eps) / ph)
        dtheta = ((ga - pa) % (2 * torch.pi)) / (2 * torch.pi)

      offset_targets.append(torch.stack([dx, dy, dw, dh, dtheta], dim=1).to(pred_offsets.dtype))
      contributes.append(torch.ones(roi_count, dtype=torch.bool, device=device))

    if not offset_targets:
      return pred_offsets.sum() * 0

    target_offsets = torch.cat(offset_targets, dim=0)
    mask = torch.cat(contributes, dim=0)
    if not bool(mask.any()):
      return pred_offsets.sum() * 0

    loss_fn = nn.SmoothL1Loss()
    return loss_fn(pred_offsets[mask], target_offsets[mask])

  def test_results(self):
    """Load the saved weights and run the test loader.

    Returns `(all_boxes, all_labels, all_scores)`: three lists with one numpy
    array per test image, in dataset order. Boxes have five columns
    (cx, cy, w, h, angle). All three lists are empty when there is no test data.
    """
    if self.test_loader is None:
      print("Test dataset is empty. Skipping evaluation.")
      return [], [], []

    self.load_weights()
    all_boxes, all_labels, all_scores = [], [], []
    with torch.no_grad():
      self.eval()
      for images, _ in self.test_loader:
        # The raw head outputs are decoded here; forward() returns tensors,
        # not per-image prediction dicts.
        proposals, _, class_logits, _, pred_offsets, _ = self._run_detector(images, None)
        for boxes, labels, scores in self._decode_detections(proposals, class_logits, pred_offsets):
          all_boxes.append(boxes.detach().cpu().numpy())
          all_labels.append(labels.detach().cpu().numpy())
          all_scores.append(scores.detach().cpu().numpy())
    return all_boxes, all_labels, all_scores

  def save_weights(self):
    """Write the state dict and the class names to `model_path`."""
    torch.save({
      "model_state_dict": self.state_dict(),
      "class_names": self.class_names
    }, self.model_path)

  def load_weights(self):
    """Restore the state dict and the class names from `model_path`."""
    checkpoint = torch.load(self.model_path, map_location=DEVICE)
    self.load_state_dict(checkpoint["model_state_dict"])
    self.class_names = checkpoint["class_names"]


# TODO Replace standard SmoothL1 loss with rotated IoU or 5D regression loss
# TODO use nms_rotated() for matching and inference?

# Prepare dataset models

In [None]:
# One model per dataset. The class-name list handed to each model is the keys of the
# dataset's *_CLASSES mapping ordered by their assigned index, so list position == class index.
# The mappings are filled in by the label parsers, so they only contain the
# classes that have been parsed up to this point.
HRSC_MODEL = RotatedFasterRCNNModel(Path("hrsc_faster_rcnn_model.pth"), HRSC_TRAIN_DATASET, HRSC_VAL_DATASET, HRSC_TEST_DATASET, [id for id, _ in sorted(HRSC_CLASSES.items(), key=lambda item: item[1])])
DOTA_MODEL = RotatedFasterRCNNModel(Path("dota_faster_rcnn_model.pth"), DOTA_TRAIN_DATASET, DOTA_VAL_DATASET, DOTA_TEST_DATASET, [id for id, _ in sorted(DOTA_CLASSES.items(), key=lambda item: item[1])])
NWPU_MODEL = RotatedFasterRCNNModel(Path("nwpu_faster_rcnn_model.pth"), NWPU_TRAIN_DATASET, NWPU_VAL_DATASET, NWPU_TEST_DATASET, [id for id, _ in sorted(NWPU_CLASSES.items(), key=lambda item: item[1])])

# Train dataset models

In [None]:
# For every dataset: reuse the saved checkpoint when one exists, otherwise train
# for three epochs (train_model saves the weights when it finishes).
# All three models go through train_model(); nn.Module.train() only switches the
# training-mode flag and must not be used to start training.
for _dataset_name, _dataset_model in (("HRSC", HRSC_MODEL), ("DOTA", DOTA_MODEL), ("NWPU", NWPU_MODEL)):
  if _dataset_model.model_path.exists():
    print(f"Loading {_dataset_name} weights...")
    _dataset_model.load_weights()
    print(f"...{_dataset_name} weights loaded.")
  else:
    print(f"Training {_dataset_name} model...")
    _dataset_model.train_model(3)
    print(f"...{_dataset_name} model trained.")

# Evaluate dataset models

In [None]:
# Set a per-dataset score threshold on each model's RoI heads, then collect the
# test-set predictions (one entry per test image in each returned list).
# The RoI heads are reached directly as `<model>.roi_heads`; the model class has
# no intermediate `.model` attribute.
HRSC_MODEL.roi_heads.score_thresh = 1.0e-6
HRSC_PRED_BOXES, HRSC_PRED_LABELS, HRSC_PRED_SCORES = HRSC_MODEL.test_results()

DOTA_MODEL.roi_heads.score_thresh = 0.012
DOTA_PRED_BOXES, DOTA_PRED_LABELS, DOTA_PRED_SCORES = DOTA_MODEL.test_results()

NWPU_MODEL.roi_heads.score_thresh = 0.001
NWPU_PRED_BOXES, NWPU_PRED_LABELS, NWPU_PRED_SCORES = NWPU_MODEL.test_results()

# Visualization

In [None]:
!pip install opencv-python

from concurrent.futures import ThreadPoolExecutor, as_completed

import cv2
import numpy as np

class Visualizer:
  """Draws ground-truth and predicted rotated boxes on test images and saves the result as PNG files.

  Args:
    test_dataset: dataset yielding `(image, target)`; also provides `ids`,
      `get_image_rescale` and (optionally) `transforms`.
    class_names: list indexed by label.
    boxes, labels, scores: per-image predictions, indexed like the dataset.
    results_folder: output folder; created if missing.
  """

  def __init__(self, test_dataset: TorchDataset, class_names: list, boxes: list, labels: list, scores: list, results_folder: Path):
    self.test_dataset = test_dataset
    self.class_names = class_names
    self.boxes = boxes
    self.labels = labels
    self.scores = scores
    self.results_folder = results_folder
    self.results_folder.mkdir(parents=True, exist_ok=True)

  def overlay_rotated_box(self, output, box, wmult, hmult, color, label, score):
    """Draw one `(cx, cy, w, h, theta)` box with its class name (and score, if given) onto `output`.

    The center and size are multiplied by `wmult` / `hmult`; theta is passed
    unchanged to `cv2.boxPoints`.
    """
    # Plain Python numbers: the inputs may be torch or numpy scalars, which
    # OpenCV does not reliably accept.
    center_x, center_y, width, height, theta = (float(value) for value in box)

    category = self.class_names[int(label)]
    if score is None:
      text = f"{category}"
    else:
      text = f"{category} - score={float(score):.3g}"

    center_x *= wmult
    center_y *= hmult
    width *= wmult
    height *= hmult
    # TODO scale theta?
    box_points = cv2.boxPoints(((center_x, center_y), (width, height), theta)).astype(np.int32)

    cv2.drawContours(output, [box_points], 0, color, 1)
    text_pos = (int(box_points[1][0]), int(box_points[1][1]))
    cv2.putText(output, text, text_pos, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

  def _to_display_image(self, image):
    """Return a writable uint8 (H, W, C) array for `image`.

    Tensors are de-normalized first: every step of the dataset's transform
    pipeline that exposes `mean` and `std` attributes is inverted. Without
    that, normalized values fall outside [0, 1] and wrap around when cast to
    uint8. Values are clamped to [0, 1] before scaling to 0-255.
    """
    if isinstance(image, np.ndarray):
      return image.copy()

    tensor = image.detach().cpu().float().clone()
    pipeline = getattr(self.test_dataset, "transforms", None)
    steps = list(getattr(pipeline, "transforms", [pipeline]))
    for step in reversed(steps):
      mean = getattr(step, "mean", None)
      std = getattr(step, "std", None)
      if mean is None or std is None:
        continue
      mean = torch.as_tensor(mean, dtype=tensor.dtype).view(-1, 1, 1)
      std = torch.as_tensor(std, dtype=tensor.dtype).view(-1, 1, 1)
      tensor = tensor * std + mean

    # Convert tensor image (C,H,W) → numpy (H,W,C)
    pixels = tensor.clamp(0.0, 1.0).permute(1, 2, 0).numpy()
    return (pixels * 255).astype(np.uint8).copy()

  def visualize(self, index):
    """Render sample `index` with ground truth (green) and predictions (red) and save it as `<id>.png`."""
    image, target = self.test_dataset[index]
    wmult, hmult = self.test_dataset.get_image_rescale(index)

    output = self._to_display_image(image)

    # Ground truth
    true_boxes = target["boxes"]
    true_labels = target["labels"]

    # Predictions
    predicted_boxes = self.boxes[index]
    predicted_labels = self.labels[index]
    predicted_scores = self.scores[index]

    # Draw boxes (colors are RGB here; the image is converted to BGR when saved)
    for box, label in zip(true_boxes, true_labels):
      self.overlay_rotated_box(output, box, wmult, hmult, (0, 255, 0), label, None)

    for box, label, score in zip(predicted_boxes, predicted_labels, predicted_scores):
      self.overlay_rotated_box(output, box, 1.0, 1.0, (255, 0, 0), label, score)

    # Save output
    output_path = self.results_folder / f"{self.test_dataset.ids[index]}.png"
    cv2.imwrite(output_path.as_posix(), cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
    print(f"Saved visualization to {output_path.as_posix()}")

  def visualize_multiple(self, count = 100, start_index = 0, max_workers=4):
    """Visualize up to `count` samples starting at `start_index` using a thread pool.

    A non-positive `count` means "through the end of the dataset". A failure
    on one sample is reported and does not stop the others.
    """
    end_index = min(start_index + count, len(self.test_dataset)) if count > 0 else len(self.test_dataset)
    indices = list(range(start_index, end_index))

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
      futures = {executor.submit(self.visualize, i): i for i in indices}
      for future in as_completed(futures):
        idx = futures[future]
        try:
          future.result()
        except Exception as e:
          print(Fore.RED + f"Visualization failed for index {idx}:" + Style.RESET_ALL, e)

# Visualize results

In [None]:
# Render the test-set predictions of every model into per-dataset folders on Google Drive.
# NOTE(review): this path is relative, while the drive is mounted at '/content/drive' — confirm the working directory.
RESULTS_PARENT_FOLDER = Path("drive/MyDrive/Colab Notebooks/Results")
HRSC_VISUALIZER = Visualizer(HRSC_TEST_DATASET, HRSC_MODEL.class_names, HRSC_PRED_BOXES, HRSC_PRED_LABELS, HRSC_PRED_SCORES, RESULTS_PARENT_FOLDER / "HRSC")
DOTA_VISUALIZER = Visualizer(DOTA_TEST_DATASET, DOTA_MODEL.class_names, DOTA_PRED_BOXES, DOTA_PRED_LABELS, DOTA_PRED_SCORES, RESULTS_PARENT_FOLDER / "DOTA")
# NOTE(review): only NWPU writes into an extra "unrotated" subfolder, unlike HRSC and DOTA above — confirm this is intended.
NWPU_VISUALIZER = Visualizer(NWPU_TEST_DATASET, NWPU_MODEL.class_names, NWPU_PRED_BOXES, NWPU_PRED_LABELS, NWPU_PRED_SCORES, RESULTS_PARENT_FOLDER / "unrotated" / "NWPU")

# Each call uses the defaults of visualize_multiple (at most 100 images, starting at index 0).
HRSC_VISUALIZER.visualize_multiple()
DOTA_VISUALIZER.visualize_multiple()
NWPU_VISUALIZER.visualize_multiple()

# Statistics Computation

In [None]:
!pip install torchmetrics
from torchmetrics.detection.mean_ap import MeanAveragePrecision

class Statistician:
  """Computes mean-average-precision statistics for a set of predictions against a test dataset.

  The metric works on axis-aligned `x1, y1, x2, y2` boxes. Rotated
  `[cx, cy, w, h, angle]` boxes (targets and predictions alike) are therefore
  replaced by their axis-aligned enclosing boxes before the metric is updated,
  so the reported numbers are horizontal-box statistics.

  Args:
    test_dataset: dataset yielding `(image, target)` with `boxes` and `labels`.
    predicted_boxes, predicted_labels, predicted_scores: per-image predictions,
      in dataset order.
    angles_in_degrees: whether the angle column of 5-column boxes is in degrees
      (True, the default) or radians.
  """

  def __init__(self, test_dataset: "TorchDataset", predicted_boxes: list, predicted_labels: list, predicted_scores: list, angles_in_degrees: bool = True):
    self.metric = MeanAveragePrecision()
    self.targets = [
      {**target, "boxes": self._to_xyxy(target["boxes"], angles_in_degrees)}
      for _, target in test_dataset
    ]
    self.predictions = [
      {
        "boxes": self._to_xyxy(boxes, angles_in_degrees),
        # reshape(-1) keeps images without detections one-dimensional
        "labels": torch.as_tensor(labels, dtype=torch.int64).reshape(-1),
        "scores": torch.as_tensor(scores, dtype=torch.float32).reshape(-1)
      }
      for boxes, labels, scores in zip(predicted_boxes, predicted_labels, predicted_scores)
    ]
    self.metric.update(self.predictions, self.targets)
    self.result = self.metric.compute()

  @staticmethod
  def _to_xyxy(boxes, angles_in_degrees=True):
    """Return `boxes` as an (N, 4) float32 tensor of x1, y1, x2, y2.

    Four-column input is returned unchanged, five-column input is converted to
    the enclosing axis-aligned box, and empty input becomes a (0, 4) tensor.

    Raises:
      ValueError: if the boxes have neither 4 nor 5 columns.
    """
    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    if boxes.numel() == 0:
      return boxes.reshape(0, 4)

    columns = boxes.shape[-1]
    if columns == 4:
      return boxes
    if columns != 5:
      raise ValueError(f"Boxes must have 4 or 5 columns, got {columns}")

    cx, cy, w, h, angle = boxes.unbind(dim=-1)
    theta = angle * (torch.pi / 180) if angles_in_degrees else angle
    cos_t = torch.cos(theta).abs()
    sin_t = torch.sin(theta).abs()
    # Half extents of the rotated rectangle projected onto the x and y axes.
    half_w = (w * cos_t + h * sin_t) / 2
    half_h = (w * sin_t + h * cos_t) / 2
    return torch.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h], dim=-1)

  def get_map(self):
    return self.result["map"]

  def get_map_percent(self):
    """The overall mAP as a numpy value scaled to percent."""
    return self.get_map().detach().cpu().numpy() * 100

  def get_map_50(self):
    return self.result["map_50"]

  def get_map_75(self):
    return self.result["map_75"]

  def get_map_small(self):
    return self.result["map_small"]

  def get_map_medium(self):
    return self.result["map_medium"]

  def get_map_large(self):
    return self.result["map_large"]

  def get_mar_1(self):
    return self.result["mar_1"]

  def get_mar_10(self):
    return self.result["mar_10"]

  def get_mar_100(self):
    return self.result["mar_100"]

  def get_mar_small(self):
    return self.result["mar_small"]

  def get_mar_medium(self):
    return self.result["mar_medium"]

  def get_mar_large(self):
    return self.result["mar_large"]

  def get_map_per_class(self):
    return self.result["map_per_class"]

  def get_mar_100_per_class(self):
    return self.result["mar_100_per_class"]

  def get_classes(self):
    return self.result["classes"]

# Print Statistics

In [None]:
# Compute the detection metrics for each dataset's test predictions.
# Constructing a Statistician iterates the whole test dataset to collect the targets.
HRSC_PRED_STATS = Statistician(HRSC_TEST_DATASET, HRSC_PRED_BOXES, HRSC_PRED_LABELS, HRSC_PRED_SCORES)
DOTA_PRED_STATS = Statistician(DOTA_TEST_DATASET, DOTA_PRED_BOXES, DOTA_PRED_LABELS, DOTA_PRED_SCORES)
NWPU_PRED_STATS = Statistician(NWPU_TEST_DATASET, NWPU_PRED_BOXES, NWPU_PRED_LABELS, NWPU_PRED_SCORES)

# Report the overall mAP of each dataset as a percentage.
print("HRSC prediction statistics:")
print(f"{HRSC_PRED_STATS.get_map_percent()}%")
print("DOTA prediction statistics:")
print(f"{DOTA_PRED_STATS.get_map_percent()}%")
print("NWPU prediction statistics:")
print(f"{NWPU_PRED_STATS.get_map_percent()}%")