In [None]:
!pip install -q 'git+https://github.com/facebookresearch/detectron2.git'
!pip install -q torch torchvision
!pip install -q opencv-python
!pip install -q colorama
!pip install -q torchmetrics
!pip install -q shapely
!pip install -q tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.datasets import VOCDetection
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.image_list import ImageList
from torchvision.ops import boxes as box_ops
from torchvision import transforms as T
from torchvision.transforms import functional as F_transforms
from torchvision.models import ResNet50_Weights
from detectron2.layers import roi_align_rotated
import numpy as np
import cv2
import math
import itertools
import json
import re
import time
import pprint
import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
from PIL import Image
from shapely.geometry import Polygon
from concurrent.futures import ThreadPoolExecutor, as_completed
from colorama import Fore, Style
from tqdm import tqdm
import gc

In [None]:
# Pick the GPU when present; every tensor/model below is moved to DEVICE.
_cuda_ok = torch.cuda.is_available()
DEVICE = torch.device('cuda') if _cuda_ok else torch.device('cpu')
print(f"Device: {DEVICE}")
print(f"CUDA Available: {_cuda_ok}")

# The datasets and model checkpoints live on Google Drive (Colab-only import).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Root of the shared Drive folder holding the pre-split datasets (relative to the Colab cwd).
SHARED_PATH = Path("drive/MyDrive/Colab Notebooks/Shared")
HRSC_PATH, DOTA_PATH, NWPU_PATH = (
    SHARED_PATH / "HRSC2016_Final_Splits",
    SHARED_PATH / "DOTA_Final_Splits",
    SHARED_PATH / "NWPU_VHR-10_Final_Splits",
)

# HRSC2016: XML annotations.
HRSC_TRAIN_IMAGES, HRSC_TRAIN_ANNOTATIONS = HRSC_PATH / "train/images", HRSC_PATH / "train/annotations"
HRSC_VAL_IMAGES, HRSC_VAL_ANNOTATIONS = HRSC_PATH / "val/images", HRSC_PATH / "val/annotations"
HRSC_TEST_IMAGES, HRSC_TEST_ANNOTATIONS = HRSC_PATH / "test/images", HRSC_PATH / "test/annotations"

# DOTA: text annotations stored under "hbb" folders.
DOTA_TRAIN_IMAGES, DOTA_TRAIN_ANNOTATIONS = DOTA_PATH / "train/images", DOTA_PATH / "train/hbb"
DOTA_VAL_IMAGES, DOTA_VAL_ANNOTATIONS = DOTA_PATH / "val/images", DOTA_PATH / "val/hbb"
DOTA_TEST_IMAGES, DOTA_TEST_ANNOTATIONS = DOTA_PATH / "test/images", DOTA_PATH / "test/hbb"

# NWPU VHR-10: JSON annotations.
NWPU_TRAIN_IMAGES, NWPU_TRAIN_ANNOTATIONS = NWPU_PATH / "train/images", NWPU_PATH / "train/annotations"
NWPU_VAL_IMAGES, NWPU_VAL_ANNOTATIONS = NWPU_PATH / "val/images", NWPU_PATH / "val/annotations"
NWPU_TEST_IMAGES, NWPU_TEST_ANNOTATIONS = NWPU_PATH / "test/images", NWPU_PATH / "test/annotations"

In [None]:
# Per-dataset maps from a dataset's class key to a contiguous 0-based index; the label
# parsers fill them in as annotation files are read.
HRSC_CLASSES, DOTA_CLASSES, NWPU_CLASSES = {}, {}, {}

In [None]:
def parse_hrsc_labels(label_path: Path):
    """Parse one HRSC2016 XML annotation into rotated boxes and 0-based class indices.

    Every ``HRSC_Object`` element contributes ``[mbox_cx, mbox_cy, mbox_w, mbox_h, mbox_ang]``
    (taken verbatim from the XML) and the index of its ``Class_ID``. Class IDs not seen
    before are registered in the module-level ``HRSC_CLASSES`` map in first-seen order.

    Objects with a missing or non-numeric field are skipped with a warning, matching the
    behaviour of ``parse_nwpu_labels``.

    Returns:
        (boxes, labels): list of 5-float lists and list of int class indices.
    """
    boxes = []
    labels = []
    root = ET.parse(label_path.as_posix()).getroot()
    for obj in root.findall(".//HRSC_Object"):
        try:
            center_x = float(obj.find('mbox_cx').text)
            center_y = float(obj.find('mbox_cy').text)
            width = float(obj.find('mbox_w').text)
            height = float(obj.find('mbox_h').text)
            angle = float(obj.find('mbox_ang').text)
            class_id = int(obj.find('Class_ID').text)
        except (AttributeError, TypeError, ValueError) as e:
            # AttributeError: a tag is missing (find() returned None);
            # TypeError / ValueError: the tag text is empty or not a number.
            print(Fore.RED + "Warning" + Style.RESET_ALL + f": Could not parse object in {label_path.as_posix()}: {e}")
            continue
        # Register the class only once the whole object parsed, so skipped objects never add classes.
        if class_id not in HRSC_CLASSES:
            HRSC_CLASSES[class_id] = len(HRSC_CLASSES)
        boxes.append([center_x, center_y, width, height, angle])
        labels.append(HRSC_CLASSES[class_id])
    return boxes, labels

def hrsc_image_rescale(label_path: Path, image_width, image_height):
    """Return (x_scale, y_scale) from the size recorded in an HRSC XML to the actual image size.

    The recorded size is read from the ``Img_SizeWidth`` / ``Img_SizeHeight`` elements.
    """
    annotation_root = ET.parse(label_path.as_posix()).getroot()
    recorded_w = int(annotation_root.find('.//Img_SizeWidth').text)
    recorded_h = int(annotation_root.find('.//Img_SizeHeight').text)
    x_scale = image_width / recorded_w
    y_scale = image_height / recorded_h
    return x_scale, y_scale

def parse_dota_labels(label_path: Path):
    """Parse one DOTA-style text annotation into rotated boxes and 0-based class indices.

    Object lines have the form ``x1 y1 x2 y2 x3 y3 x4 y4 category [difficulty]``. Lines
    with fewer than nine whitespace-separated tokens (such as the ``imagesource:...`` and
    ``gsd:...`` header lines) are skipped. Files with or without that two-line header
    are therefore read without losing objects; the old code always dropped the first two
    lines. The optional difficulty field is ignored.

    Each quadrilateral is fitted with ``cv2.minAreaRect``; its angle (degrees) is
    converted to radians. Unseen categories are registered in the module-level
    ``DOTA_CLASSES`` map in first-seen order. Malformed object lines are reported and
    skipped, matching ``parse_nwpu_labels``.

    Returns:
        (boxes, labels): boxes as [cx, cy, w, h, angle_rad], labels as DOTA_CLASSES indices.
    """
    boxes = []
    labels = []
    for line in label_path.read_text().splitlines():
        tokens = line.split()
        if len(tokens) < 9:
            # Header / blank line: not an object record.
            continue
        category = tokens[8]
        try:
            corners = [float(value) for value in tokens[:8]]
            pts_np = np.array(corners, dtype=np.float32).reshape(-1, 1, 2)
            (cx, cy), (w, h), angle_deg = cv2.minAreaRect(pts_np)
        except (ValueError, cv2.error) as e:
            print(Fore.RED + "Warning" + Style.RESET_ALL + f": Could not parse object in {label_path.as_posix()}: {e}")
            continue
        if category not in DOTA_CLASSES:
            DOTA_CLASSES[category] = len(DOTA_CLASSES)
        boxes.append([cx, cy, w, h, math.radians(angle_deg)])
        labels.append(DOTA_CLASSES[category])
    return boxes, labels

def dota_image_rescale(label_path: Path, image_width, image_height):
    return 1.0, 1.0

def parse_nwpu_labels(label_path: Path):
    """Parse one NWPU VHR-10 COCO-style JSON annotation into rotated boxes and 0-based labels.

    Side effect: every entry of ``categories`` is written into the module-level
    ``NWPU_CLASSES`` map as ``name -> id - 1``. Each annotation's first segmentation
    polygon is fitted with ``cv2.minAreaRect`` (angle converted to radians). Annotations
    that fail to parse are reported and skipped.

    Returns:
        (boxes, labels): boxes as [cx, cy, w, h, angle_rad], labels as category_id - 1.
    """
    payload = json.loads(label_path.read_text())

    # categories are 1-indexed -> convert to 0-indexed
    for entry in payload['categories']:
        NWPU_CLASSES[entry['name']] = entry['id'] - 1

    boxes, labels = [], []
    for annotation in payload['annotations']:
        try:
            outline = np.array(annotation['segmentation'][0], dtype=np.float32).reshape(-1, 2)
            (cx, cy), (w, h), angle_deg = cv2.minAreaRect(outline)
            rotated_box = [cx, cy, w, h, math.radians(angle_deg)]
            zero_based_label = annotation['category_id'] - 1
        except Exception as e:
            print(Fore.RED + "Warning" + Style.RESET_ALL + f": Could not parse object in {label_path.as_posix()}: {e}")
            continue
        boxes.append(rotated_box)
        labels.append(zero_based_label)

    return boxes, labels

def nwpu_image_rescale(label_path: Path, image_width, image_height):
    """Return (x_scale, y_scale) from the size recorded in the JSON's first ``images`` entry to the actual image size."""
    recorded = json.loads(label_path.read_text())['images'][0]
    return image_width / recorded['width'], image_height / recorded['height']

def dbbox2delta(proposals, gt, means=[0, 0, 0, 0, 0], stds=[1, 1, 1, 1, 1]):
    """Encode rotated ground-truth boxes as regression deltas relative to rotated proposals.

    Both inputs hold (cx, cy, w, h, angle) in the last dimension. The centre offset is
    expressed in the proposal's rotated frame and divided by its w/h. Sizes are
    log-ratios, and the angle difference is wrapped into [0, 2*pi) and divided by 2*pi.
    The result is normalized as ``(delta - means) / stds``.
    """
    src = proposals.float()
    dst = gt.float()
    theta = src[..., 4]
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    offset_x = dst[..., 0] - src[..., 0]
    offset_y = dst[..., 1] - src[..., 1]
    # Project the centre offset onto the proposal's own (rotated) axes.
    dx = (cos_t * offset_x + sin_t * offset_y) / src[..., 2]
    dy = (-sin_t * offset_x + cos_t * offset_y) / src[..., 3]
    dw = torch.log(dst[..., 2] / src[..., 2])
    dh = torch.log(dst[..., 3] / src[..., 3])
    full_turn = 2 * math.pi
    dangle = (dst[..., 4] - theta) % full_turn / full_turn
    encoded = torch.stack((dx, dy, dw, dh, dangle), -1)
    mean_row = encoded.new_tensor(means).unsqueeze(0)
    std_row = encoded.new_tensor(stds).unsqueeze(0)
    return encoded.sub_(mean_row).div_(std_row)

def delta2dbbox(Rrois, deltas, means=[0, 0, 0, 0, 0], stds=[1, 1, 1, 1, 1], max_shape=None, wh_ratio_clip=16/1000):
    """Decode regression deltas back into rotated boxes (inverse of dbbox2delta).

    ``deltas`` is (N, 5 * K): K groups of (dx, dy, dw, dh, dangle), all decoded against the
    same (N, 5) ``Rrois`` row of (cx, cy, w, h, angle). Log size deltas are clamped to
    ``|log(wh_ratio_clip)|``, and output angles are wrapped into [0, 2*pi). ``max_shape``
    is accepted but not used. Returns a tensor shaped like ``deltas``.
    """
    groups = deltas.size(1) // 5
    scale = deltas.new_tensor(stds).repeat(1, groups)
    shift = deltas.new_tensor(means).repeat(1, groups)
    raw = deltas * scale + shift
    d_x = raw[:, 0::5]
    d_y = raw[:, 1::5]
    limit = np.abs(np.log(wh_ratio_clip))
    d_w = raw[:, 2::5].clamp(min=-limit, max=limit)
    d_h = raw[:, 3::5].clamp(min=-limit, max=limit)
    d_a = raw[:, 4::5]

    def _broadcast(column):
        # Repeat one RoI attribute across every class group.
        return Rrois[:, column].unsqueeze(1).expand_as(d_x)

    roi_x, roi_y, roi_w, roi_h, roi_a = (_broadcast(c) for c in range(5))
    cos_a = torch.cos(roi_a)
    sin_a = torch.sin(roi_a)
    out_x = d_x * roi_w * cos_a - d_y * roi_h * sin_a + roi_x
    out_y = d_x * roi_w * sin_a + d_y * roi_h * cos_a + roi_y
    out_w = roi_w * d_w.exp()
    out_h = roi_h * d_h.exp()
    out_a = ((2 * np.pi) * d_a + roi_a) % (2 * np.pi)
    return torch.stack([out_x, out_y, out_w, out_h, out_a], dim=-1).view_as(deltas)

def hbb2obb_v2(boxes):
    """Convert (N, 4) axis-aligned (x1, y1, x2, y2) boxes to (N, 5) (cx, cy, w, h, angle) rows.

    Uses the inclusive-pixel convention (+1 on extents). The y-extent is stored as the
    width and the x-extent as the height, with every angle set to -pi/2.
    """
    x1 = boxes[..., 0]
    y1 = boxes[..., 1]
    extent_x = boxes[..., 2] - x1 + 1.0
    extent_y = boxes[..., 3] - y1 + 1.0
    centre_x = x1 + 0.5 * (extent_x - 1.0)
    centre_y = y1 + 0.5 * (extent_y - 1.0)
    angles = -extent_x.new_ones(boxes.size(0)) * np.pi / 2
    return torch.stack((centre_x, centre_y, extent_y, extent_x, angles), dim=1)

def rotated_box_to_polygon(box):
    """Return a shapely Polygon for one box.

    Accepts (cx, cy, w, h, angle_rad) or axis-aligned (x1, y1, x2, y2) as a sequence,
    numpy array or torch tensor. Corners are rotated by ``[[cos, -sin], [sin, cos]]``
    about the centre.

    Raises:
        ValueError: if the box does not have 4 or 5 elements.
    """
    values = box.cpu().numpy() if isinstance(box, torch.Tensor) else box
    if len(values) == 5:
        cx, cy, w, h, angle = values
    elif len(values) == 4:
        x1, y1, x2, y2 = values
        cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
        w, h = x2 - x1, y2 - y1
        angle = 0.0
    else:
        raise ValueError(f"Box must have 4 or 5 elements, got {len(values)}")
    half_w, half_h = w / 2, h / 2
    local_corners = np.array(
        [[-half_w, -half_h], [half_w, -half_h], [half_w, half_h], [-half_w, half_h]],
        dtype=np.float32,
    )
    c, s = np.cos(angle), np.sin(angle)
    rotation = np.array([[c, -s], [s, c]])
    world_corners = local_corners @ rotation.T
    world_corners[:, 0] += cx
    world_corners[:, 1] += cy
    return Polygon(world_corners.tolist())

def box_iou_rotated(boxes1, boxes2):
    """Pairwise IoU between two sets of boxes using exact polygon intersection.

    Args:
        boxes1: (N, 5) rotated (cx, cy, w, h, angle_rad) or (N, 4) axis-aligned boxes,
            any per-row format rotated_box_to_polygon accepts.
        boxes2: (M, 5) or (M, 4) boxes, likewise.

    Returns:
        (N, M) float32 tensor on ``boxes1.device``. Pairs whose polygons cannot be built or
        intersected are left at 0, as are rows of ``boxes1`` with non-positive area.
    """
    def _to_polygons(boxes):
        # detach() first: .numpy() raises on tensors that require grad, which the old bare
        # ``except:`` silently turned into an all-zero IoU matrix.
        polygons = []
        for row in boxes.detach().cpu().numpy():
            try:
                polygons.append(rotated_box_to_polygon(row))
            except Exception:
                polygons.append(None)
        return polygons

    # Build each polygon once instead of rebuilding boxes2's polygons for every row of boxes1.
    polygons1 = _to_polygons(boxes1)
    polygons2 = _to_polygons(boxes2)
    areas2 = [p.area if p is not None else 0.0 for p in polygons2]

    # Fill a host array and copy it once rather than writing device tensor elements one by one.
    ious = np.zeros((len(polygons1), len(polygons2)), dtype=np.float32)
    for i, p1 in enumerate(polygons1):
        if p1 is None:
            continue
        area1 = p1.area
        if area1 <= 0:
            continue
        for j, p2 in enumerate(polygons2):
            if p2 is None:
                continue
            try:
                inter = p1.intersection(p2).area
            except Exception:
                # e.g. geometry errors on invalid polygons: leave this pair at 0.
                continue
            union = area1 + areas2[j] - inter
            if union > 0:
                ious[i, j] = inter / union
    return torch.from_numpy(ious).to(boxes1.device)

BASIC_TRANSFORM = T.Compose([T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

class TorchDataset(torch.utils.data.Dataset):
    """Paired image/annotation dataset yielding (normalized image tensor, target dict).

    Images and annotation files are paired by file stem. Each image is resized so its
    longer side equals ``target_size``, then padded on the right/bottom to a
    ``target_size`` x ``target_size`` square. Box centres and sizes (columns 0-3) are
    scaled by the same factor; angles are unchanged. Labels are shifted by +1 so that 0
    is background.
    """

    IMAGE_SUFFIXES = (".jpg", ".png", ".bmp")
    ANNOTATION_SUFFIXES = (".xml", ".txt", ".json")

    def __init__(self, images_folder: Path, annotations_folder: Path, label_parser, image_rescale, transforms=BASIC_TRANSFORM, max_files=10, max_objects=100, target_size=800):
        """
        Args:
            images_folder: directory scanned (non-recursively) for images; missing -> empty.
            annotations_folder: directory scanned for annotation files; missing -> empty.
            label_parser: callable(label_path) -> (boxes, labels) for one annotation file.
            image_rescale: callable(label_path, image_width, image_height), used by get_image_rescale.
            transforms: applied to the image tensor after resizing and padding.
            max_files: keep at most this many image/annotation pairs (<= 0 keeps all).
            max_objects: keep at most this many objects per image (<= 0 keeps all).
            target_size: side length of the square output image.
        """
        super().__init__()
        self.label_parser = label_parser
        self.image_rescale = image_rescale
        self.transforms = transforms
        self.max_objects = max_objects
        self.target_size = target_size

        images_by_stem = {f.stem: f for f in self._list_files(images_folder, self.IMAGE_SUFFIXES)}
        annotations_by_stem = {f.stem: f for f in self._list_files(annotations_folder, self.ANNOTATION_SUFFIXES)}

        # Sort the paired stems. A plain set of strings iterates in hash-randomized order,
        # which changes between Python processes. That would change sample order and,
        # through label parsers that register classes on first sight, class indices too.
        self.ids = sorted(images_by_stem.keys() & annotations_by_stem.keys(), key=self._stem_sort_key)
        # Truncate after pairing so max_files counts usable pairs, not raw files per folder.
        if max_files > 0:
            self.ids = self.ids[:max_files]
        self.images = {stem: images_by_stem[stem] for stem in self.ids}
        self.annotations = {stem: annotations_by_stem[stem] for stem in self.ids}

    @staticmethod
    def _stem_sort_key(stem):
        """Order by the first integer in the stem (stems without one go last), then by the stem itself."""
        match = re.search(r'(\d+)', stem)
        return (int(match.group()) if match else float('inf'), stem)

    @staticmethod
    def _list_files(folder: Path, suffixes):
        """Return files directly inside ``folder`` whose lower-cased suffix is in ``suffixes``, sorted by name."""
        if not folder.exists():
            return []
        return sorted((f for f in folder.iterdir() if f.suffix.lower() in suffixes), key=lambda f: f.name)

    def __getitem__(self, index):
        """Return (image, {"boxes": (K, 5) float32, "labels": (K,) int64}) for one sample."""
        sample_id = self.ids[index]
        image_path = self.images[sample_id]
        label_path = self.annotations[sample_id]
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        boxes, labels = self.label_parser(label_path)
        if self.max_objects > 0:
            boxes = boxes[:self.max_objects]
            labels = labels[:self.max_objects]
        if len(boxes) == 0:
            boxes = torch.zeros((0, 5), dtype=torch.float32)
        else:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            if boxes.dim() == 1:
                boxes = boxes.unsqueeze(0)
        if len(labels) == 0:
            labels = torch.zeros((0,), dtype=torch.int64)
        else:
            labels = torch.as_tensor(labels, dtype=torch.int64)
        # Shift so that class index 0 is reserved for background.
        labels = labels + 1
        image, boxes = self.preprocess(image, boxes)
        image = F_transforms.to_tensor(image)
        target = {"boxes": boxes, "labels": labels}
        image = self.transforms(image)
        image.filepath = image_path
        return image, target

    def preprocess(self, img, boxes):
        """Resize ``img`` so its longer side is target_size, pad to a square, and scale box columns 0-3 to match."""
        old_w, old_h = img.width, img.height
        scale = self.target_size / max(old_h, old_w)
        new_w = int(old_w * scale)
        new_h = int(old_h * scale)
        img = F_transforms.resize(img, (new_h, new_w))
        boxes = boxes.clone()
        if boxes.numel() > 0 and boxes.dim() == 2:
            boxes[:, 0] *= scale
            boxes[:, 1] *= scale
            boxes[:, 2] *= scale
            boxes[:, 3] *= scale
        pad_w = self.target_size - new_w
        pad_h = self.target_size - new_h
        # Padding goes on the right/bottom only, so (0, 0) stays the image origin and boxes need no shift.
        img = F_transforms.pad(img, (0, 0, pad_w, pad_h))
        return img, boxes

    def __len__(self):
        return len(self.ids)

    def get_image_rescale(self, index):
        """Return ``image_rescale(label_path, width, height)`` for sample ``index``."""
        sample_id = self.ids[index]
        image_path = self.images[sample_id]
        label_path = self.annotations[sample_id]
        # Only the size is needed: read it from the header instead of decoding the pixels.
        with Image.open(image_path) as image:
            width, height = image.size
        return self.image_rescale(label_path, width, height)

    def compute_total_number_of_objects(self, max_workers=8):
        """Return the total number of annotated objects over all samples.

        Some label parsers (e.g. parse_hrsc_labels, parse_dota_labels) register unseen
        classes in module-level dicts in first-seen order. Files are therefore parsed
        sequentially in ``self.ids`` order, which keeps class indices deterministic and
        free of races between threads. ``max_workers`` is kept for backward compatibility
        but is no longer used.
        """
        return sum(len(self.label_parser(self.annotations[sample_id])[0]) for sample_id in self.ids)

def _build_dataset(dataset_name, split_name, images_folder, annotations_folder, label_parser, image_rescale):
    """Create a TorchDataset for one split, print its object count, and return it."""
    print(f"Preparing {dataset_name} {split_name} dataset...")
    dataset = TorchDataset(images_folder, annotations_folder, label_parser, image_rescale)
    print(f"...Dataset prepared: {dataset.compute_total_number_of_objects()} total objects")
    return dataset


# Order matters: training splits are parsed first, so parsers that register classes on
# first sight (HRSC, DOTA) assign class indices from the training annotations.
HRSC_TRAIN_DATASET = _build_dataset("HRSC", "training", HRSC_TRAIN_IMAGES, HRSC_TRAIN_ANNOTATIONS, parse_hrsc_labels, hrsc_image_rescale)
HRSC_VAL_DATASET = _build_dataset("HRSC", "validation", HRSC_VAL_IMAGES, HRSC_VAL_ANNOTATIONS, parse_hrsc_labels, hrsc_image_rescale)
HRSC_TEST_DATASET = _build_dataset("HRSC", "testing", HRSC_TEST_IMAGES, HRSC_TEST_ANNOTATIONS, parse_hrsc_labels, hrsc_image_rescale)

DOTA_TRAIN_DATASET = _build_dataset("DOTA", "training", DOTA_TRAIN_IMAGES, DOTA_TRAIN_ANNOTATIONS, parse_dota_labels, dota_image_rescale)
DOTA_VAL_DATASET = _build_dataset("DOTA", "validation", DOTA_VAL_IMAGES, DOTA_VAL_ANNOTATIONS, parse_dota_labels, dota_image_rescale)
DOTA_TEST_DATASET = _build_dataset("DOTA", "testing", DOTA_TEST_IMAGES, DOTA_TEST_ANNOTATIONS, parse_dota_labels, dota_image_rescale)

NWPU_TRAIN_DATASET = _build_dataset("NWPU", "training", NWPU_TRAIN_IMAGES, NWPU_TRAIN_ANNOTATIONS, parse_nwpu_labels, nwpu_image_rescale)
NWPU_VAL_DATASET = _build_dataset("NWPU", "validation", NWPU_VAL_IMAGES, NWPU_VAL_ANNOTATIONS, parse_nwpu_labels, nwpu_image_rescale)
NWPU_TEST_DATASET = _build_dataset("NWPU", "testing", NWPU_TEST_IMAGES, NWPU_TEST_ANNOTATIONS, parse_nwpu_labels, nwpu_image_rescale)

class BenchmarkTracker:
    """Records wall-clock durations (via time.time) for named, repeatable code sections."""

    def __init__(self):
        # key -> list of {"start": float, "end": float | None}, one record per start() call
        self.data = defaultdict(list)

    def start(self, key):
        """Open a new timing record for ``key``."""
        self.data[key].append({"start": time.time(), "end": None})

    def stop(self, key):
        """Close the most recently opened record for ``key`` (IndexError if it was never started)."""
        self.data[key][-1]["end"] = time.time()

    def summary(self):
        """Return ``{key: {"total", "avg", "std", "count"}}`` over completed records.

        Records that were started but never stopped are ignored. A key with no completed
        record reports total 0, count 0 and NaN avg/std. These NaNs are produced
        explicitly, so numpy's "mean of empty slice" warning is never raised.
        """
        report = {}
        for key, records in self.data.items():
            durations = [r["end"] - r["start"] for r in records if r["end"] is not None]
            report[key] = {
                "total": sum(durations),
                "avg": np.mean(durations) if durations else float("nan"),
                "std": np.std(durations) if durations else float("nan"),
                "count": len(durations)
            }
        return report

class RotatedBoxPredictor(nn.Module):
    """Linear classification and rotated-box regression heads over pooled RoI features.

    Produces class logits of shape (N, num_classes) and box deltas of shape
    (N, num_classes * 5), i.e. one (dx, dy, dw, dh, dangle) group per class.
    4-D inputs are flattened from dim 1 onwards first.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Attribute names form the state_dict keys stored in checkpoints.
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 5)

    def forward(self, x):
        flat = x.flatten(start_dim=1) if x.dim() == 4 else x
        return self.cls_score(flat), self.bbox_pred(flat)

class RotatedRoIHeads(RoIHeads):
    """RoI heads that classify proposals and regress rotated (cx, cy, w, h, angle) boxes.

    Built on torchvision's RoIHeads for construction. Features are pooled with
    detectron2's roi_align_rotated, proposals are matched to ground truth by rotated IoU,
    and detections are decoded and suppressed as rotated boxes.
    """

    def __init__(self, box_roi_pool, box_head, box_predictor, fg_iou_thresh=0.5, bg_iou_thresh=0.5, batch_size_per_image=512, positive_fraction=0.25, bbox_reg_weights=None, score_thresh=0.05, nms_thresh=0.5, detections_per_img=100):
        super().__init__(box_roi_pool, box_head, box_predictor, fg_iou_thresh, bg_iou_thresh, batch_size_per_image, positive_fraction, bbox_reg_weights, score_thresh, nms_thresh, detections_per_img)
        self.box_predictor = box_predictor
        self.fg_iou_thresh = fg_iou_thresh
        self.bg_iou_thresh = bg_iou_thresh
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

    def forward(self, features, proposals, image_shapes, targets=None):
        """Run the box branch on per-image horizontal proposals (x1, y1, x2, y2).

        Returns:
            In training mode with targets: a dict with 'class_logits', 'box_regression',
            'proposals', 'matched_idxs', 'labels' and 'regression_targets'.
            Otherwise: the (boxes, scores, labels) per-image lists from postprocess_detections.
        """
        if self.training and targets is not None:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
        else:
            labels = None
            regression_targets = None
            matched_idxs = None
        proposals = [p.to(DEVICE) for p in proposals]
        # Each horizontal proposal becomes an angle-0 rotated RoI row
        # (batch_idx, cx, cy, w, h, angle) for roi_align_rotated. (box_roi_pool is
        # kept for RoIHeads construction only; its horizontal features are not needed.)
        rotated_proposals_list = []
        for i, props in enumerate(proposals):
            x1, y1, x2, y2 = props.unbind(1)
            cx = (x1 + x2) / 2
            cy = (y1 + y2) / 2
            w = x2 - x1
            h = y2 - y1
            angle = torch.zeros_like(cx)
            batch_idx = torch.full_like(cx, i, dtype=torch.float32, device=DEVICE)
            rotated_proposals_list.append(torch.stack([batch_idx, cx, cy, w, h, angle], dim=1))
        rotated_boxes = torch.cat(rotated_proposals_list, dim=0)
        # Only the first feature map is pooled.
        # NOTE(review): spatial_scale assumes that map has stride 4 — confirm for the backbone.
        feature_map = list(features.values())[0] if isinstance(features, dict) else features
        spatial_scale = 1.0 / 4.0
        rotated_features = roi_align_rotated(feature_map, rotated_boxes, (7, 7), spatial_scale, 2)
        rotated_features = self.box_head(rotated_features)
        class_logits, box_regression = self.box_predictor(rotated_features)
        if self.training:
            return {'class_logits': class_logits, 'box_regression': box_regression, 'proposals': proposals, 'matched_idxs': matched_idxs, 'labels': labels, 'regression_targets': regression_targets}
        return self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)

    def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes):
        """Decode, score-filter and NMS detections for each image.

        Background (label 0) and detections scoring <= score_thresh are dropped. NMS is
        run per class, as torchvision's RoIHeads does, and at most detections_per_img are
        kept per image, highest score first.

        Returns:
            (all_boxes, all_scores, all_labels): per-image lists of (K, 5), (K,), (K,) tensors.
        """
        device = DEVICE
        num_classes = class_logits.shape[-1]
        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
        pred_boxes = self.decode_boxes(box_regression, proposals)  # (P, num_classes, 5)
        pred_scores = F.softmax(class_logits, -1)
        all_boxes = []
        all_scores = []
        all_labels = []
        for boxes, scores, image_shape in zip(pred_boxes.split(boxes_per_image, 0), pred_scores.split(boxes_per_image, 0), image_shapes):
            boxes = self.clip_boxes_to_image(boxes.to(device), image_shape)
            scores = scores.to(device)
            labels = torch.arange(num_classes, device=device).view(1, -1).expand_as(scores)
            boxes = boxes.reshape(-1, 5)
            scores = scores.reshape(-1)
            labels = labels.reshape(-1)
            keep = (labels > 0) & (scores > self.score_thresh)
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
            keep = self._batched_rotated_nms(boxes, scores, labels, self.nms_thresh)
            keep = keep[:self.detections_per_img]
            all_boxes.append(boxes[keep])
            all_scores.append(scores[keep])
            all_labels.append(labels[keep])
        return all_boxes, all_scores, all_labels

    def _batched_rotated_nms(self, boxes, scores, labels, iou_threshold):
        """Run rotated_nms separately for each label; return kept indices sorted by descending score."""
        if boxes.numel() == 0:
            return torch.empty((0,), dtype=torch.long, device=boxes.device)
        kept = []
        for cls in labels.unique():
            cls_inds = (labels == cls).nonzero(as_tuple=True)[0]
            kept.append(cls_inds[self.rotated_nms(boxes[cls_inds], scores[cls_inds], iou_threshold)])
        keep = torch.cat(kept)
        return keep[scores[keep].argsort(descending=True)]

    def decode_boxes(self, box_regression, proposals):
        """Decode (P, num_classes * 5) deltas against the proposals into (P, num_classes, 5) rotated boxes."""
        rotated_proposals = torch.cat([hbb2obb_v2(props.to(DEVICE)) for props in proposals], dim=0)
        num_classes = box_regression.shape[1] // 5
        # delta2dbbox decodes every 5-column group at once against the same proposal row.
        return delta2dbbox(rotated_proposals, box_regression).view(-1, num_classes, 5)

    def clip_boxes_to_image(self, boxes, image_shape):
        """Clamp centres into the image and w/h to >= 1 for boxes shaped (..., 5); angles are untouched.

        Indexing the last dimension makes this correct for both (N, 5) and
        (N, num_classes, 5) inputs. The previous ``boxes[:, k]`` form clamped whole
        class rows of a 3-D tensor, and raised IndexError when num_classes < 4.
        """
        h, w = image_shape
        boxes = boxes.clone().to(DEVICE)
        boxes[..., 0] = boxes[..., 0].clamp(0, w)
        boxes[..., 1] = boxes[..., 1].clamp(0, h)
        boxes[..., 2] = boxes[..., 2].clamp(min=1)
        boxes[..., 3] = boxes[..., 3].clamp(min=1)
        return boxes

    def rotated_nms(self, boxes, scores, iou_threshold):
        """Greedy NMS on rotated boxes; returns kept indices in descending score order."""
        if len(boxes) == 0:
            return torch.empty((0,), dtype=torch.long, device=boxes.device)
        # NMS is not differentiable; detaching lets the IoU helper convert boxes to numpy.
        boxes = boxes.detach()
        keep = []
        order = scores.argsort(descending=True)
        while order.numel() > 0:
            best = int(order[0])
            keep.append(best)
            if order.numel() == 1:
                break
            ious = box_iou_rotated(boxes[best:best + 1], boxes[order[1:]]).squeeze(0)
            order = order[1:][ious <= iou_threshold]
        return torch.tensor(keep, dtype=torch.long, device=boxes.device)

    def select_training_samples(self, proposals, targets):
        """Assign each proposal a label and a rotated regression target.

        A proposal is foreground when its best rotated IoU with a ground-truth box exceeds
        fg_iou_thresh; it then takes that box's label and the dbbox2delta target. All other
        proposals get label 0 and a zero target. Images without ground truth produce only
        background (previously ``ious.max(dim=1)`` raised on the empty IoU matrix).

        Returns:
            (proposals, matched_idxs, labels, regression_targets), the last three per image.
        """
        labels = []
        matched_idxs = []
        regression_targets = []
        for props, target in zip(proposals, targets):
            props = props.to(DEVICE)
            rotated_props = hbb2obb_v2(props)
            target_boxes = target['boxes'].to(DEVICE)
            target_labels = target['labels'].to(DEVICE)
            num_props = len(props)
            label = torch.zeros(num_props, dtype=torch.long, device=DEVICE)
            reg_target = torch.zeros(num_props, 5, device=DEVICE)
            if target_boxes.numel() == 0:
                matched_idx = torch.zeros(num_props, dtype=torch.long, device=DEVICE)
            else:
                ious = box_iou_rotated(rotated_props, target_boxes)
                max_ious, matched_idx = ious.max(dim=1)
                pos_mask = max_ious > self.fg_iou_thresh
                if pos_mask.any():
                    label[pos_mask] = target_labels[matched_idx[pos_mask]]
                    matched_gt = target_boxes[matched_idx[pos_mask]]
                    reg_target[pos_mask] = dbbox2delta(rotated_props[pos_mask], matched_gt)
            labels.append(label)
            matched_idxs.append(matched_idx)
            regression_targets.append(reg_target)
        return proposals, matched_idxs, labels, regression_targets

class RotatedFasterRCNNModel(FasterRCNN):
    """ResNet-50-FPN Faster R-CNN with rotated-box RoI heads, bundled with its data loaders.

    ``class_names`` is prefixed with "__background__", so label 0 is background and the
    dataset's labels 1..C map to the given names.
    """

    def __init__(self, model_path: Path, train_dataset: TorchDataset, val_dataset: TorchDataset, test_dataset: TorchDataset, class_names: list, batch_size=4, shuffle_datasets=False):
        self.model_path = model_path
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.class_names = ["__background__"] + list(class_names)
        self.num_classes = len(self.class_names)
        collate_fn = lambda x: tuple(zip(*x))
        # Only the training data is shuffled. test_results returns per-image predictions in
        # loader order, so the test loader must follow test_dataset's index order.
        # Validation order does not affect the summed loss.
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=batch_size, shuffle=shuffle_datasets, collate_fn=collate_fn)
        self.val_loader = torch.utils.data.DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
        if len(self.test_dataset) > 0:
            self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
        else:
            self.test_loader = None
        super().__init__(backbone=resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT), num_classes=self.num_classes, rpn_anchor_generator=None)
        self.in_features = self.roi_heads.box_predictor.cls_score.in_features
        # box_detector is also registered as roi_heads.box_predictor; both names appear in
        # the state_dict, so keep them for checkpoint compatibility.
        self.box_detector = RotatedBoxPredictor(self.in_features, self.num_classes)
        self.roi_heads = RotatedRoIHeads(self.roi_heads.box_roi_pool, self.roi_heads.box_head, self.box_detector)
        self.to(DEVICE)
        self.bench = BenchmarkTracker()
        self.batch_size = batch_size

    def train_model(self, num_epochs=50, learning_rate=0.0005):
        """Train with SGD + StepLR, report summed train/val losses, and save weights every 5 epochs and at the end."""
        optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0005)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
        for epoch in range(num_epochs):
            print(f"Starting epoch {epoch+1}/{num_epochs}...")
            print("\tStarting training loop...")
            self.bench.start("train_epoch")
            train_loss = 0.0
            self.train()
            for images, targets in tqdm(self.train_loader, desc=f"Epoch {epoch+1}"):
                self.bench.start("train_batch")
                loss_dict = self(images, targets)
                losses = sum(loss for loss in loss_dict.values())
                optimizer.zero_grad()
                losses.backward()
                optimizer.step()
                train_loss += losses.item()
                self.bench.stop("train_batch")
            lr_scheduler.step()
            self.bench.stop("train_epoch")
            print("\t...Training loop complete.")
            print("\tStarting validation loop...")
            self.bench.start("val_epoch")
            val_loss = 0.0
            with torch.no_grad():
                # Train mode (under no_grad) so forward returns losses instead of detections.
                self.train()
                for images, targets in self.val_loader:
                    self.bench.start("val_batch")
                    loss_dict = self(images, targets)
                    losses = sum(loss for loss in loss_dict.values())
                    val_loss += losses.item()
                    self.bench.stop("val_batch")
            self.bench.stop("val_epoch")
            print("\t...Validation loop complete.")
            print(f"...Finished epoch {epoch+1}/{num_epochs}, Training loss: {train_loss:.4f}, Validation loss: {val_loss:.4f}")
            if (epoch + 1) % 5 == 0:
                self.save_weights()
        self.save_weights()

    def forward(self, images, targets=None):
        """Return a loss dict in training mode with targets, else a list of per-image detection dicts."""
        # Local flag: the old code assigned self.training here, which persisted across calls.
        compute_losses = self.training and targets is not None
        images = [img.to(DEVICE) for img in images]
        image_sizes = [img.shape[-2:] for img in images]
        images = ImageList(torch.stack(images), image_sizes)
        features = self.backbone(images.tensors)
        # The RPN works on horizontal boxes: use each rotated target's unrotated (x1, y1, x2, y2) extent.
        rpn_targets = []
        if targets is not None:
            for t in targets:
                boxes = t['boxes'].to(DEVICE)
                labels = t['labels'].to(DEVICE)
                if boxes.numel() > 0 and boxes.dim() == 2 and boxes.shape[1] == 5:
                    cx, cy, w, h, theta = boxes.unbind(1)
                    x1 = cx - w/2
                    y1 = cy - h/2
                    x2 = cx + w/2
                    y2 = cy + h/2
                    rpn_targets.append({'boxes': torch.stack([x1, y1, x2, y2], dim=1), 'labels': labels})
                else:
                    rpn_targets.append({'boxes': torch.zeros((0, 4), dtype=torch.float32, device=DEVICE), 'labels': labels})
        proposals, rpn_losses = self.rpn(images, features, rpn_targets if targets else None)
        if targets is not None:
            targets = [{'boxes': t['boxes'].to(DEVICE), 'labels': t['labels'].to(DEVICE)} for t in targets]
        roi_outputs = self.roi_heads(features, proposals, images.image_sizes, targets)
        if compute_losses:
            loss_dict = dict(rpn_losses)
            class_logits = roi_outputs['class_logits']
            labels = torch.cat(roi_outputs['labels'], dim=0)
            loss_dict['loss_classifier'] = F.cross_entropy(class_logits, labels)
            box_regression = roi_outputs['box_regression']
            regression_targets = torch.cat(roi_outputs['regression_targets'], dim=0)
            pos_mask = labels > 0
            if pos_mask.any():
                num_classes = box_regression.shape[1] // 5
                box_regression = box_regression.view(-1, num_classes, 5)
                # Use the class's own slot (background occupies slot 0). Inference decodes
                # slot c for label c, so training must regress the same slot, not c - 1.
                reg_for_labels = box_regression[pos_mask, labels[pos_mask], :]
                loss_dict['loss_box_reg'] = F.smooth_l1_loss(reg_for_labels, regression_targets[pos_mask], reduction='mean')
            else:
                loss_dict['loss_box_reg'] = torch.tensor(0.0, device=class_logits.device)
            return loss_dict
        all_boxes, all_scores, all_labels = roi_outputs
        return [{'boxes': boxes, 'scores': scores, 'labels': labels} for boxes, scores, labels in zip(all_boxes, all_scores, all_labels)]

    def save_weights(self):
        """Save the state dict and class names to model_path."""
        torch.save({"model_state_dict": self.state_dict(), "class_names": self.class_names}, self.model_path)

    def load_weights(self):
        """Load the state dict and class names written by save_weights."""
        checkpoint = torch.load(self.model_path, map_location=DEVICE)
        self.load_state_dict(checkpoint["model_state_dict"])
        self.class_names = checkpoint["class_names"]

    def test_results(self):
        """Load saved weights and predict on the test set.

        Returns:
            (all_boxes, all_labels, all_scores): per-image numpy arrays in test-loader order,
            or three empty lists if the test set is empty.
        """
        if self.test_loader is None:
            print("Test dataset is empty. Skipping evaluation.")
            return [], [], []
        self.load_weights()
        self.bench.start("test_total")
        all_boxes, all_labels, all_scores = [], [], []
        with torch.no_grad():
            self.eval()
            for images, _ in self.test_loader:
                self.bench.start("test_batch")
                images = [img.to(DEVICE) for img in images]
                predictions = self(images)
                for prediction in predictions:
                    boxes = prediction['boxes'].detach().cpu().numpy()
                    labels = prediction['labels'].detach().cpu().numpy()
                    scores = prediction['scores'].detach().cpu().numpy()
                    if len(boxes) == 0:
                        boxes = np.zeros((0, 5), dtype=np.float32)
                    if len(labels) == 0:
                        labels = np.zeros((0,), dtype=np.int64)
                    if len(scores) == 0:
                        scores = np.zeros((0,), dtype=np.float32)
                    all_boxes.append(boxes)
                    all_labels.append(labels)
                    all_scores.append(scores)
                self.bench.stop("test_batch")
        self.bench.stop("test_total")
        return all_boxes, all_labels, all_scores

MODEL_SAVE_PATH = Path("drive/MyDrive/Colab Notebooks/Models")
MODEL_SAVE_PATH.mkdir(parents=True, exist_ok=True)


def _class_names_by_index(class_map):
    """Return the keys of a {class key: contiguous index} map, ordered by index."""
    return [key for key, _ in sorted(class_map.items(), key=lambda item: item[1])]


def _load_or_train(model, dataset_name, num_epochs=3):
    """Load ``model``'s checkpoint if it exists, otherwise train it for ``num_epochs`` epochs."""
    if model.model_path.exists():
        print(f"Loading {dataset_name} weights...")
        model.load_weights()
        print(f"...{dataset_name} weights loaded.")
    else:
        print(f"Training {dataset_name} model...")
        # Release cached GPU memory and Python garbage before starting a training run.
        torch.cuda.empty_cache()
        gc.collect()
        model.train_model(num_epochs)
        print(f"...{dataset_name} model trained.")


HRSC_MODEL = RotatedFasterRCNNModel(MODEL_SAVE_PATH / "hrsc_faster_rcnn_model.pth", HRSC_TRAIN_DATASET, HRSC_VAL_DATASET, HRSC_TEST_DATASET, _class_names_by_index(HRSC_CLASSES), batch_size=2)
DOTA_MODEL = RotatedFasterRCNNModel(MODEL_SAVE_PATH / "dota_faster_rcnn_model.pth", DOTA_TRAIN_DATASET, DOTA_VAL_DATASET, DOTA_TEST_DATASET, _class_names_by_index(DOTA_CLASSES), batch_size=2)
NWPU_MODEL = RotatedFasterRCNNModel(MODEL_SAVE_PATH / "nwpu_faster_rcnn_model.pth", NWPU_TRAIN_DATASET, NWPU_VAL_DATASET, NWPU_TEST_DATASET, _class_names_by_index(NWPU_CLASSES), batch_size=2)

_load_or_train(HRSC_MODEL, "HRSC")
_load_or_train(DOTA_MODEL, "DOTA")
_load_or_train(NWPU_MODEL, "NWPU")

In [None]:
def _predict_with_threshold(model, score_thresh):
    """Set ``model``'s RoI-head score threshold, then return its (boxes, labels, scores) test predictions."""
    model.roi_heads.score_thresh = score_thresh
    return model.test_results()


# Per-dataset score thresholds for keeping low-confidence detections at test time.
HRSC_PRED_BOXES, HRSC_PRED_LABELS, HRSC_PRED_SCORES = _predict_with_threshold(HRSC_MODEL, 1.0e-6)
DOTA_PRED_BOXES, DOTA_PRED_LABELS, DOTA_PRED_SCORES = _predict_with_threshold(DOTA_MODEL, 0.012)
NWPU_PRED_BOXES, NWPU_PRED_LABELS, NWPU_PRED_SCORES = _predict_with_threshold(NWPU_MODEL, 0.001)

In [None]:
class Visualizer:
    """Render ground-truth and predicted rotated boxes onto test images and save them as PNGs.

    Ground truth is drawn in green and predictions in red. Colors are specified in RGB
    because the canvas is RGB until it is converted to BGR right before ``cv2.imwrite``.
    """

    def __init__(self, test_dataset: TorchDataset, class_names: list,
                 boxes: list, labels: list, scores: list, results_folder: Path,
                 normalize_mean=(0.485, 0.456, 0.406),
                 normalize_std=(0.229, 0.224, 0.225)):
        """
        Args:
            test_dataset: indexable as ``image, target = test_dataset[i]``; must also provide
                ``get_image_rescale(i) -> (wmult, hmult)`` and an ``ids`` sequence that is
                used for output file names.
            class_names: indexed with an integer label to get the text drawn next to a box.
            boxes, labels, scores: per-image predictions indexed like ``test_dataset``.
                Rows of ``boxes[i]`` are ``(cx, cy, w, h, theta)`` with ``theta`` in radians.
                Indices past the end of these sequences are treated as "no predictions".
                Tensors (CPU or CUDA), numpy arrays and nested lists are all accepted.
            results_folder: output directory; created (with parents) if missing.
            normalize_mean, normalize_std: per-channel normalization to undo when the
                dataset yields normalized tensors (defaults are the ImageNet statistics).
        """
        self.test_dataset = test_dataset
        self.class_names = class_names
        self.boxes = boxes
        self.labels = labels
        self.scores = scores
        self.results_folder = results_folder
        self.results_folder.mkdir(parents=True, exist_ok=True)

        # normalization undo params
        self.mean = np.array(normalize_mean)
        self.std = np.array(normalize_std)

    # HELPERS

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a numpy array, detaching and moving torch tensors to host first."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def unnormalize(self, img):
        """Undo per-channel normalization: ``img * std + mean``.

        img: float numpy array in **channels-last** ``[H, W, C]`` layout. The mean/std
        vectors broadcast over the last axis, so a ``[C, H, W]`` array would be
        un-normalized along the wrong axis.
        Returns the un-normalized image (nominally in [0, 1]; not clipped here).
        """
        return img * self.std + self.mean

    def tensor_to_uint8(self, tensor):
        """Convert a normalized CHW torch tensor to a contiguous uint8 HWC RGB image."""
        img = tensor.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
        # Undo normalization (ImageNet); result may fall slightly outside [0, 1].
        img = self.unnormalize(img)
        # Clip and convert
        img = np.clip(img * 255.0, 0, 255).astype(np.uint8)
        # Ensure contiguous memory layout for OpenCV drawing calls
        img = np.ascontiguousarray(img)
        return img

    # BOX DRAWING

    def overlay_rotated_box(self, output, box, wmult, hmult, color, label, score):
        """Draw one rotated box and its caption onto ``output`` in place.

        Args:
            output: uint8 HWC image that is drawn on.
            box: ``(cx, cy, w, h, theta)`` with ``theta`` in radians.
            wmult, hmult: factors applied to x/width and y/height before drawing.
            color: RGB tuple.
            label: integer index into ``self.class_names``.
            score: confidence to print, or ``None`` to print only the class name.
        """
        center_x, center_y, width, height, theta = box
        angle_deg = math.degrees(theta)  # cv2.boxPoints expects degrees

        category = self.class_names[label]
        text = f"{category}" if score is None else f"{category} - score={score:.3g}"

        # image rescaling
        center_x *= wmult
        center_y *= hmult
        width *= wmult
        height *= hmult

        # draw box + label
        box_points = cv2.boxPoints(((center_x, center_y), (width, height), angle_deg)).astype(np.int32)
        cv2.drawContours(output, [box_points], 0, color, 2)
        # cv2.putText requires plain Python ints for the origin; numpy scalars can be rejected.
        text_pos = (int(box_points[1][0]), int(box_points[1][1]))
        cv2.putText(output, text, text_pos, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    # VISUALIZATION

    def visualize(self, index):
        """Draw GT (green) and predictions (red) for ``test_dataset[index]`` and save ``<id>.png``."""
        image, target = self.test_dataset[index]

        if isinstance(image, np.ndarray):
            output = image.copy()
        else:
            output = self.tensor_to_uint8(image)

        wmult, hmult = self.test_dataset.get_image_rescale(index)

        # ground truth (tensors / array-likes -> host numpy)
        true_boxes = self._to_numpy(target["boxes"])
        true_labels = self._to_numpy(target["labels"])

        # predictions; converted the same way as the ground truth so tensor
        # predictions (including CUDA tensors) can be fed to OpenCV
        predicted_boxes = self._to_numpy(self.boxes[index]) if index < len(self.boxes) else np.zeros((0, 5))
        predicted_labels = self._to_numpy(self.labels[index]) if index < len(self.labels) else np.zeros((0,), dtype=np.int64)
        predicted_scores = self._to_numpy(self.scores[index]) if index < len(self.scores) else np.zeros((0,), dtype=np.float32)

        # draw ground truths
        if len(true_boxes) > 0 and true_boxes.shape[1] == 5:
            for box, label in zip(true_boxes, true_labels):
                self.overlay_rotated_box(output, box, wmult, hmult,
                                         (0, 255, 0), int(label), None)

        # draw predictions
        # NOTE(review): predictions are drawn unscaled while GT uses get_image_rescale —
        # presumably predictions are already in the displayed image's coordinates; confirm.
        if len(predicted_boxes) > 0 and predicted_boxes.shape[1] == 5:
            for box, label, score in zip(predicted_boxes, predicted_labels, predicted_scores):
                self.overlay_rotated_box(output, box, 1.0, 1.0,
                                         (255, 0, 0), int(label), float(score))

        # save files (canvas is RGB; OpenCV writes BGR)
        output_path = self.results_folder / f"{self.test_dataset.ids[index]}.png"
        cv2.imwrite(str(output_path), cv2.cvtColor(output, cv2.COLOR_RGB2BGR))

        print(f"Saved visualization to {output_path}")

    # BATCH VISUALIZATION

    def visualize_multiple(self, count=100, start_index=0, max_workers=4):
        """Visualize ``count`` images starting at ``start_index`` using a thread pool.

        ``count <= 0`` means "every image from ``start_index`` to the end". Failures are
        reported per index and do not stop the remaining images.
        """
        end_index = min(start_index + count, len(self.test_dataset)) if count > 0 else len(self.test_dataset)
        indices = list(range(start_index, end_index))

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(self.visualize, i): i for i in indices}
            for future in as_completed(futures):
                idx = futures[future]
                try:
                    future.result()
                except Exception as e:
                    print(f"Visualization failed for index {idx}: {e}")


In [None]:
RESULTS_PARENT_FOLDER = Path("drive/MyDrive/Colab Notebooks/Results")


def _build_visualizer(test_dataset, model, boxes, labels, scores, subfolder):
    """Create a Visualizer for ``model``'s classes that writes into ``RESULTS_PARENT_FOLDER / subfolder``."""
    return Visualizer(test_dataset, model.class_names, boxes, labels, scores,
                      RESULTS_PARENT_FOLDER / subfolder)


# Construct all visualizers first (each creates its output folder), then render.
HRSC_VISUALIZER = _build_visualizer(HRSC_TEST_DATASET, HRSC_MODEL, HRSC_PRED_BOXES, HRSC_PRED_LABELS, HRSC_PRED_SCORES, "HRSC")
DOTA_VISUALIZER = _build_visualizer(DOTA_TEST_DATASET, DOTA_MODEL, DOTA_PRED_BOXES, DOTA_PRED_LABELS, DOTA_PRED_SCORES, "DOTA")
NWPU_VISUALIZER = _build_visualizer(NWPU_TEST_DATASET, NWPU_MODEL, NWPU_PRED_BOXES, NWPU_PRED_LABELS, NWPU_PRED_SCORES, "NWPU")

for _visualizer in (HRSC_VISUALIZER, DOTA_VISUALIZER, NWPU_VISUALIZER):
    _visualizer.visualize_multiple()

# Benchmark Summaries

In [None]:
# Print the batch size and benchmark summary for each trained model, in dataset order.
for _tag, _model in (("HRSC", HRSC_MODEL), ("DOTA", DOTA_MODEL), ("NWPU", NWPU_MODEL)):
    print(f"\n=== {_tag} Benchmark Summary ===\n")
    print(f"Batch Size: {_model.batch_size}")
    pprint.pprint(_model.bench.summary())