# Drone Detection - Mask R-CNN (ViT-based)

In [8]:
import os
import re
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
import torchvision
from torchvision import transforms
import torchvision.models as models
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from sklearn.metrics import f1_score, precision_score, recall_score

# Local nbutils.py
import nbutils
from nbutils import DroneDetectionDataset

sns.set_theme()

In [9]:
device = "cpu"

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

## Load Dataset

In [10]:
base_path = "./data/drone-detection/drone-detection-new.v5-new-train.yolov8/train"
images_path = os.path.join(base_path, "images")
labels_path = os.path.join(base_path, "labels")

df = nbutils.create_dataset(images_path, labels_path)

In [11]:
nbutils.view_df_summary(df)

Dataset Summary:
Total number of objects: 8997
Total number of unique images: 8818

Class distribution:
class
DRONE         4349
HELICOPTER    2374
AIRPLANE      2274
Name: count, dtype: int64

Image dimensions summary:
       image_width  image_height
count       8997.0        8997.0
mean         640.0         640.0
std            0.0           0.0
min          640.0         640.0
25%          640.0         640.0
50%          640.0         640.0
75%          640.0         640.0
max          640.0         640.0


## Model Training

In [12]:
# Adapter to convert original dataset output to Mask R-CNN format
class DroneDetectionAdapter:
    def __init__(self, device):
        self.device = device
        self.classes = ['AIRPLANE', 'DRONE', 'HELICOPTER']

    def __call__(self, batch):
        images, boxes_batch = zip(*batch)
        images_list = []
        targets_list = []

        for i, (image, boxes) in enumerate(zip(images, boxes_batch)):
            # Process image
            if not isinstance(image, torch.Tensor):
                image = transforms.ToTensor()(image)
            images_list.append(image.to(self.device))

            # Get image dimensions
            _, height, width = image.shape

            # Process target
            target = {}
            valid_boxes = []
            valid_labels = []
            valid_masks = []

            # Process each box
            for box in boxes:
                class_id = int(box[0].item())
                
                # Skip empty boxes (class_id = -1)
                if class_id == -1:
                    continue

                # Extract normalized coordinates
                x_center, y_center = float(box[1].item()), float(box[2].item())
                box_width, box_height = float(box[3].item()), float(box[4].item())

                # Convert to pixel coordinates and corner format
                x1 = (x_center - box_width/2) * width
                y1 = (y_center - box_height/2) * height
                x2 = (x_center + box_width/2) * width
                y2 = (y_center + box_height/2) * height

                # Clip to image boundaries
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(width, x2), min(height, y2)

                # Only add if box is valid
                if x2 > x1 and y2 > y1:
                    valid_boxes.append([x1, y1, x2, y2])
                    valid_labels.append(class_id + 1)  # Add 1 because 0 is background in Mask R-CNN

                    # Create a simple binary mask
                    mask = torch.zeros((height, width), dtype=torch.uint8)
                    x1_int, y1_int = int(x1), int(y1)
                    x2_int, y2_int = int(x2), int(y2)
                    mask[y1_int:y2_int, x1_int:x2_int] = 1
                    valid_masks.append(mask)

            # Handle empty case
            if not valid_boxes:
                target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
                target["labels"] = torch.zeros(0, dtype=torch.int64)
                target["masks"] = torch.zeros((0, height, width), dtype=torch.uint8)
                target["image_id"] = torch.tensor([i])
                target["area"] = torch.zeros(0, dtype=torch.float32)
                target["iscrowd"] = torch.zeros((0,), dtype=torch.int64)
            else:
                # Convert lists to tensors
                boxes_tensor = torch.as_tensor(valid_boxes, dtype=torch.float32)
                labels_tensor = torch.as_tensor(valid_labels, dtype=torch.int64)
                masks_tensor = torch.stack(valid_masks)
                
                # Calculate areas
                areas = (boxes_tensor[:, 2] - boxes_tensor[:, 0]) * (boxes_tensor[:, 3] - boxes_tensor[:, 1])
                
                # Create final target dictionary
                target["boxes"] = boxes_tensor
                target["labels"] = labels_tensor
                target["masks"] = masks_tensor
                target["image_id"] = torch.tensor([i])
                target["area"] = areas
                target["iscrowd"] = torch.zeros((len(valid_boxes),), dtype=torch.int64)

            targets_list.append({k: v.to(self.device) for k, v in target.items()})

        return images_list, targets_list

def get_mask_rcnn_model(num_classes):
    # Load pre-trained model
    weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
    model = maskrcnn_resnet50_fpn(weights=weights)
    
    # Get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    
    # Replace the box predictor with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    
    # Get number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    
    # Replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes
    )
    
    return model

def train_one_epoch(model, optimizer, data_loader, device, adapter):
    model.train()
    total_loss = 0
    
    for images, boxes in tqdm(data_loader):
        # Convert dataset output to Mask R-CNN format
        images, targets = adapter((images, boxes))
        
        # Forward pass and calculate loss
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        
        # Backward pass and optimization
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        
        total_loss += losses.item()
    
    return total_loss / len(data_loader)

def bbox_iou(box1, box2):
    """
    Calculate IoU between box1 and box2
    box1, box2: tensors of shape (4,) with coordinates [x1, y1, x2, y2]
    Returns: scalar tensor with IoU value
    """
    # Get coordinates of intersection
    x1 = torch.max(box1[0], box2[0])
    y1 = torch.max(box1[1], box2[1])
    x2 = torch.min(box1[2], box2[2])
    y2 = torch.min(box1[3], box2[3])
    
    # Calculate area of intersection and union
    width = torch.clamp(x2 - x1, min=0)
    height = torch.clamp(y2 - y1, min=0)
    intersection = width * height
    
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = box1_area + box2_area - intersection
    
    return intersection / union

def evaluate(model, data_loader, device, adapter):
    model.eval()
    all_predictions = []
    all_targets = []
    all_ious = []
    
    with torch.no_grad():
        for images, boxes in tqdm(data_loader):
            # Convert dataset output to Mask R-CNN format
            images, targets = adapter((images, boxes))
            
            outputs = model(images)
            
            for i, (output, target) in enumerate(zip(outputs, targets)):
                # Skip if no ground truth boxes
                if len(target['boxes']) == 0:
                    continue
                
                # Filter predictions by confidence
                score_threshold = 0.5
                keep_idxs = output['scores'] > score_threshold
                pred_boxes = output['boxes'][keep_idxs]
                pred_labels = output['labels'][keep_idxs]
                pred_scores = output['scores'][keep_idxs]
                
                if len(pred_boxes) == 0:
                    continue
                
                # For each ground truth box, find best matching prediction
                gt_boxes = target['boxes']
                gt_labels = target['labels']
                
                for gt_idx, (gt_box, gt_label) in enumerate(zip(gt_boxes, gt_labels)):
                    best_iou = 0
                    best_pred_idx = -1
                    
                    # Find prediction with highest IoU
                    for pred_idx, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)):
                        iou = bbox_iou(gt_box, pred_box)
                        if iou > best_iou:
                            best_iou = iou
                            best_pred_idx = pred_idx
                    
                    # If we found a matching prediction
                    if best_pred_idx >= 0:
                        all_ious.append(best_iou.cpu().item())
                        all_predictions.append(pred_labels[best_pred_idx].cpu().item())
                        all_targets.append(gt_label.cpu().item())
    
    metrics = {}

    if all_ious:
        metrics['mean_iou'] = sum(all_ious) / len(all_ious)
    
    if all_predictions and all_targets:
        metrics['f1'] = f1_score(all_targets, all_predictions, average='macro')
        metrics['precision'] = precision_score(all_targets, all_predictions, average='macro', zero_division=0)
        metrics['recall'] = recall_score(all_targets, all_predictions, average='macro', zero_division=0)
    
    return metrics

def calculate_map(model, data_loader, device, adapter, iou_threshold=0.5):
    model.eval()
    ap_per_class = {}
    
    with torch.no_grad():
        # Collect all predictions and ground truths
        all_detections = {1: [], 2: [], 3: []}  # class_id -> list of [confidence, TP/FP]
        num_gt_per_class = {1: 0, 2: 0, 3: 0}
        
        for images, boxes in tqdm(data_loader):
            # Convert dataset output to Mask R-CNN format
            images, targets = adapter((images, boxes))
            
            # Get predictions
            outputs = model(images)
            
            # Process each image
            for output, target in zip(outputs, targets):
                # Count ground truths per class
                for gt_label in target['labels']:
                    class_id = gt_label.item()
                    num_gt_per_class[class_id] = num_gt_per_class.get(class_id, 0) + 1
                
                # Skip empty predictions or targets
                if len(output['boxes']) == 0 or len(target['boxes']) == 0:
                    continue
                
                # For each predicted box
                for pred_idx, (pred_box, pred_score, pred_label) in enumerate(zip(output['boxes'], output['scores'], output['labels'])):
                    class_id = pred_label.item()
                    
                    # Find best matching ground truth
                    best_iou = 0
                    best_gt_idx = -1
                    
                    for gt_idx, (gt_box, gt_label) in enumerate(zip(target['boxes'], target['labels'])):
                        # Skip if not the same class
                        if gt_label.item() != class_id:
                            continue
                            
                        # Calculate IoU
                        iou = bbox_iou(pred_box, gt_box)
                        if iou > best_iou:
                            best_iou = iou
                            best_gt_idx = gt_idx
                    
                    # Determine if detection is TP or FP
                    if best_gt_idx >= 0 and best_iou >= iou_threshold:
                        all_detections[class_id].append([pred_score.item(), 1])  # TP
                    else:
                        all_detections[class_id].append([pred_score.item(), 0])  # FP
        
        # Calculate AP for each class
        for class_id in all_detections:
            if not all_detections[class_id] or num_gt_per_class.get(class_id, 0) == 0:
                ap_per_class[class_id] = 0
                continue
                
            # Sort by confidence
            detections = sorted(all_detections[class_id], key=lambda x: x[0], reverse=True)
            
            # Calculate precision and recall
            tp_cumsum = 0
            fp_cumsum = 0
            precision = []
            recall = []
            
            for i, (_, is_tp) in enumerate(detections):
                if is_tp:
                    tp_cumsum += 1
                else:
                    fp_cumsum += 1
                
                precision.append(tp_cumsum / (tp_cumsum + fp_cumsum))
                recall.append(tp_cumsum / num_gt_per_class[class_id])
            
            # Calculate AP using 11-point interpolation
            ap = 0
            for r in np.arange(0, 1.1, 0.1):
                if not recall or recall[-1] < r:
                    p = 0
                else:
                    p = max([precision[i] for i, rec in enumerate(recall) if rec >= r])
                ap += p / 11
            
            ap_per_class[class_id] = ap
    
    # Calculate mAP
    map_score = sum(ap_per_class.values()) / len(ap_per_class)
    
    return map_score, ap_per_class

def train_mask_rcnn(train_loader, val_loader, num_epochs=20, learning_rate=0.001):
    global device
    train_adapter = DroneDetectionAdapter(device)
    val_adapter = DroneDetectionAdapter(device)
    
    # 4 classes: background + 3 object classes (AIRPLANE, DRONE, HELICOPTER)
    num_classes = 4
    
    # Get model
    model = get_mask_rcnn_model(num_classes)
    model.to(device)
    
    # Define optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=learning_rate, momentum=0.9, weight_decay=0.0005)
    
    # Learning rate scheduler
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='max', 
        factor=0.1, 
        patience=3,
        verbose=True
    )
    
    # Training loop
    best_f1 = 0.0
    best_map = 0.0
    
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        
        # Train
        train_loss = train_one_epoch(model, optimizer, train_loader, device, train_adapter)
        print(f"Train Loss: {train_loss:.4f}")
        
        # Evaluate
        metrics = evaluate(model, val_loader, device, val_adapter)
        print(f"Validation Metrics: {metrics}")
        
        # Calculate mAP
        map_score, ap_per_class = calculate_map(model, val_loader, device, val_adapter)
        print(f"mAP: {map_score:.4f}")
        print(f"AP per class: {ap_per_class}")
        
        # Update learning rate based on F1 score
        if 'f1' in metrics:
            lr_scheduler.step(metrics['f1'])
            
            # Save best model based on F1
            if metrics['f1'] > best_f1:
                best_f1 = metrics['f1']
                torch.save(model.state_dict(), 'best_f1_model.pth')
                print(f"Saved new best F1 model with F1: {best_f1:.4f}")
        
        # Save best model based on mAP
        if map_score > best_map:
            best_map = map_score
            torch.save(model.state_dict(), 'best_map_model.pth')
            print(f"Saved new best mAP model with mAP: {best_map:.4f}")
    
    return model

# Function to make predictions with the trained model
def predict(model, image_path, device, confidence_threshold=0.5):
    model.eval()
    
    # Load and transform image
    image = Image.open(image_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((800, 800)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image_tensor = transform(image).unsqueeze(0).to(device)
    
    # Get predictions
    with torch.no_grad():
        predictions = model(image_tensor)
    
    # Extract predictions
    pred_boxes = predictions[0]['boxes'].cpu().numpy()
    pred_scores = predictions[0]['scores'].cpu().numpy()
    pred_labels = predictions[0]['labels'].cpu().numpy()
    pred_masks = predictions[0]['masks'].cpu().numpy()
    
    # Filter by confidence
    high_conf_indices = pred_scores > confidence_threshold
    pred_boxes = pred_boxes[high_conf_indices]
    pred_scores = pred_scores[high_conf_indices]
    pred_labels = pred_labels[high_conf_indices]
    pred_masks = pred_masks[high_conf_indices]
    
    # Convert boxes from corner format back to center format
    height, width = image.size
    center_boxes = []
    for box in pred_boxes:
        x1, y1, x2, y2 = box
        x_center = (x1 + x2) / 2 / width
        y_center = (y1 + y2) / 2 / height
        box_width = (x2 - x1) / width
        box_height = (y2 - y1) / height
        center_boxes.append([x_center, y_center, box_width, box_height])
    
    # Create output
    class_names = ['AIRPLANE', 'DRONE', 'HELICOPTER']
    results = []
    for i in range(len(pred_labels)):
        class_id = pred_labels[i] - 1  # Convert back to 0-indexed
        results.append({
            'class': class_names[class_id] if 0 <= class_id < len(class_names) else 'UNKNOWN',
            'confidence': float(pred_scores[i]),
            'box': center_boxes[i]
        })
    
    return results

In [13]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
])

dataset = DroneDetectionDataset(df, transform=transform)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

model = train_mask_rcnn(train_loader, val_loader, num_epochs=1)



Epoch 1/1


  0%|          | 0/882 [00:00<?, ?it/s]


ValueError: too many values to unpack (expected 2)