In [None]:
# Install required libraries
# dotadevkit is a pip-installable packaging of the DOTA devkit; it is used below to slice the DOTA images into patches
!pip install -q dotadevkit
!pip install -q opencv-python-headless

# Import libraries
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as T
from PIL import Image
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import shutil # For file operations
from google.colab import drive # For mounting Drive

# Report the library versions this run is using
print(f"Torch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")

# Pick the GPU when one is visible to torch, otherwise fall back to the CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

In [None]:
import gdown
import os
import zipfile

folder_url = 'https://drive.google.com/drive/folders/1gmeE3D7R62UAtuIFOB9j2M5cUPTwtsxK?usp=drive_link'


def _extract_archive(zip_path, dest_dir):
    """Extract one zip archive into ``dest_dir``.

    Prints a progress line before extracting. If ``zip_path`` does not exist,
    prints a warning and returns False instead of skipping silently, so an
    incomplete download is visible here rather than later as an empty dataset.

    Returns:
        bool: True if the archive was extracted, False if it was missing.
    """
    if not os.path.exists(zip_path):
        print(f"Warning: expected archive not found, skipping: {zip_path}")
        return False
    print(f"Unzipping {os.path.basename(zip_path)}...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dest_dir)
    return True


# Download the shared Drive folder into /content/dataset
gdown.download_folder(folder_url, output='/content/dataset', quiet=False)

# Then unzip the image archives into the DOTA root
for file in ['part1.zip', 'part2.zip', 'part3.zip']:
    _extract_archive(f'/content/dataset/images/{file}', '/content/DOTA/')

# Unzip every label archive found in the label folder
labels_path = '/content/dataset/labelTxt-v1.0'
if os.path.exists(labels_path):
    # Sorted so the extraction order does not depend on os.listdir ordering
    label_archives = sorted(f for f in os.listdir(labels_path) if f.endswith('.zip'))
    if not label_archives:
        print(f"Warning: no .zip files found in {labels_path}")
    for file in label_archives:
        _extract_archive(f'{labels_path}/{file}', '/content/DOTA/labelTxt')
else:
    print(f"Warning: label folder not found, no annotations extracted: {labels_path}")

In [None]:
# Directory passed to the devkit as its input (where the archives were extracted)
SOURCE_DATA_PATH = '/content/DOTA'
# This is where the sliced patches will be saved
PATCHES_PATH = '/content/DOTA_patches'
# exist_ok=True keeps this cell re-runnable: without it a second run raises
# FileExistsError because the directory was already created by the first run.
os.makedirs(PATCHES_PATH, exist_ok=True)

In [None]:
# Slice the images under SOURCE_DATA_PATH into patches written under PATCHES_PATH.
print("Starting data slicing with dotadevkit...")
# This will take a long time (15-30+ minutes)
# Command: dotadevkit split <input_dir> <output_dir> <num_processes> <tile_size> <overlap>
# NOTE(review): the argument order above is this notebook's own description of the CLI;
# confirm against `dotadevkit split --help` for the installed version.
# The "{NAME}" placeholders are filled in by IPython from the Python variables of that name.
!dotadevkit split "{SOURCE_DATA_PATH}" "{PATCHES_PATH}" 8 1024 200

# NOTE: a `!` shell line does not raise on a non-zero exit status, so the two
# messages below are printed even if the split command failed.
print("Slicing complete.")
print(f"Patches saved to {PATCHES_PATH}")

In [None]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as T
from PIL import Image
import os
import numpy as np # Make sure numpy is imported

# Transform pipeline factory shared by the training and validation datasets
def get_transform(train):
    """Build the torchvision v2 transform pipeline for one dataset split.

    Args:
        train: when truthy, a random horizontal flip (p=0.5) is appended
            after the conversion steps.

    Returns:
        A ``T.Compose`` that converts the input to an image tensor, casts it
        to float32 with value scaling, and optionally applies the flip.
    """
    steps = [
        T.ToImage(),                            # PIL image -> image tensor
        T.ToDtype(torch.float32, scale=True),   # float32, values scaled to [0, 1]
    ]
    if train:
        # The only augmentation used for training
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

class DOTAPatchDataset(Dataset):
    """
    Detection dataset over DOTA patches that contain 'small-vehicle' objects.

    ``patches_dir`` must contain an ``images`` folder with ``.png`` patches and
    a ``labelTxt`` folder with one ``.txt`` file per patch (same file stem).
    Each annotation line is split on whitespace; a line is used when it has at
    least 9 fields and field 8 is the class name 'small-vehicle'. Its first 8
    fields are read as 4 (x, y) polygon corners and reduced to an axis-aligned
    box ``[xmin, ymin, xmax, ymax]``.

    Only patches with at least one 'small-vehicle' line are kept. The kept
    file names are sorted so that an index always maps to the same patch.

    Each item is ``(image, target)`` where ``target`` holds ``boxes``,
    ``labels`` (all 1), ``image_id``, ``area`` and ``iscrowd``. ``boxes`` is a
    ``tv_tensors.BoundingBoxes`` so that torchvision v2 geometric transforms
    (e.g. a horizontal flip) move the boxes together with the image; plain
    tensors would be passed through unchanged.

    Progress of the filtering pass is printed (first 10 hits and misses only).
    """

    TARGET_CLASS = 'small-vehicle'
    TARGET_LABEL = 1  # label id assigned to every kept box

    def __init__(self, patches_dir, transforms):
        """
        Args:
            patches_dir: folder containing ``images`` and ``labelTxt``.
            transforms: callable applied as ``transforms(image, target)``, or
                a falsy value to return the PIL image and raw target.
        """
        self.patches_dir = patches_dir
        self.transforms = transforms

        self.img_dir = os.path.join(patches_dir, "images")
        self.ann_dir = os.path.join(patches_dir, "labelTxt")

        # os.listdir order is arbitrary; sort so dataset indices are stable
        all_img_files = sorted(f for f in os.listdir(self.img_dir) if f.endswith('.png'))

        self.img_files = []
        print(f"--- STARTING VERBOSE FILTER ---")
        print(f"Found {len(all_img_files)} total image patches.")
        print(f"Now checking for 'small-vehicle' in corresponding annotation files...")

        # To avoid spamming, we'll only print first 10 successes and failures
        success_prints = 0
        fail_prints = 0

        for img_name in all_img_files:
            ann_name, ann_path = self._annotation_for(img_name)

            if not os.path.exists(ann_path):
                print(f"  [SKIP] Image {img_name} has no matching annotation file.")
                continue

            try:
                with open(ann_path, 'r') as f:
                    # any() stops at the first matching line
                    found = any(self._is_target_line(line.strip().split()) for line in f)
            except (OSError, UnicodeError) as e:
                print(f"  [ERROR] Could not read {ann_name}. Error: {e}")
                continue

            if found:
                self.img_files.append(img_name)
                if success_prints < 10:
                    print(f"  [SUCCESS] Found 'small-vehicle' in {ann_name}!")
                    success_prints += 1
            else:
                if fail_prints < 10:
                    print(f"  [FAIL] No 'small-vehicle' found in {ann_name}.")
                    fail_prints += 1

        print(f"--- VERBOSE FILTER COMPLETE ---")
        print(f"Filter complete. Found {len(self.img_files)} images with 'small-vehicle'.")
        if len(self.img_files) == 0:
            print("Warning: No images with 'small-vehicle' were found.")

    def _annotation_for(self, img_name):
        """Return ``(annotation file name, full annotation path)`` for an image name."""
        # splitext only swaps the final extension (str.replace would also
        # rewrite a '.png' occurring earlier in the name)
        ann_name = os.path.splitext(img_name)[0] + '.txt'
        return ann_name, os.path.join(self.ann_dir, ann_name)

    @classmethod
    def _is_target_line(cls, parts):
        """True when the split annotation line has the target class at index 8."""
        return len(parts) >= 9 and parts[8] == cls.TARGET_CLASS

    def _read_boxes(self, ann_name, ann_path):
        """Return the ``[xmin, ymin, xmax, ymax]`` boxes of all target-class lines.

        Lines whose coordinates are not numeric are reported and skipped;
        boxes with zero width or height are dropped.
        """
        boxes = []
        with open(ann_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if not self._is_target_line(parts):
                    continue
                try:
                    # 8 numbers -> 4 polygon corners as (x, y) rows
                    corners = np.array([float(p) for p in parts[:8]]).reshape(4, 2)
                except ValueError:
                    print(f"Skipping malformed line in {ann_name}: {line}")
                    continue

                # Oriented box -> enclosing horizontal box
                xmin, ymin = (float(v) for v in corners.min(axis=0))
                xmax, ymax = (float(v) for v in corners.max(axis=0))
                if xmax > xmin and ymax > ymin:
                    boxes.append([xmin, ymin, xmax, ymax])
        return boxes

    def __len__(self):
        """Number of patches that passed the 'small-vehicle' filter."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load patch ``idx`` and return ``(image, target)`` (see class docstring)."""
        # Local import keeps this class self-contained; torchvision is already
        # a dependency of this notebook.
        from torchvision import tv_tensors

        img_name = self.img_files[idx]
        image = Image.open(os.path.join(self.img_dir, img_name)).convert("RGB")
        width, height = image.size  # PIL reports (width, height)

        ann_name, ann_path = self._annotation_for(img_name)
        box_list = self._read_boxes(ann_name, ann_path)

        if box_list:
            boxes = torch.as_tensor(box_list, dtype=torch.float32)
        else:
            # Keep the (N, 4) shape even when no usable box was found
            boxes = torch.zeros((0, 4), dtype=torch.float32)

        num_boxes = boxes.shape[0]
        labels = torch.full((num_boxes,), self.TARGET_LABEL, dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        # Wrapped so v2 transforms transform the boxes along with the image
        target["boxes"] = tv_tensors.BoundingBoxes(
            boxes, format="XYXY", canvas_size=(height, width)
        )
        target["labels"] = labels
        target["image_id"] = torch.tensor([idx])
        target["area"] = area
        target["iscrowd"] = torch.zeros((num_boxes,), dtype=torch.int64)

        if self.transforms:
            image, target = self.transforms(image, target)

        return image, target

# Batch assembler passed to the DataLoaders
def collate_fn(batch):
    """Regroup a list of per-sample tuples into one tuple per field.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``.
    Nothing is stacked, so samples may differ in size.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
def get_model(num_classes=2):
    """Create a Faster R-CNN (ResNet-50 FPN) with a fresh box-prediction head.

    The detector is built with torchvision's ``"DEFAULT"`` weights, then its
    box predictor is swapped for a new ``FastRCNNPredictor`` sized for
    ``num_classes`` outputs.

    Args:
        num_classes: number of outputs of the new head (default 2).
            # presumably index 0 is background and 1 the object class, as the
            # dataset in this notebook labels its boxes with 1 - confirm

    Returns:
        The modified detection model (not moved to any device).
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

    # The new head must accept the same feature width as the one it replaces
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)

    return detector

In [None]:
# --- Configuration ---
NUM_EPOCHS = 10
BATCH_SIZE = 4
LEARNING_RATE = 0.005
SPLIT_SEED = 42  # fixes the train/val partition so runs are comparable
# ---------------------

import copy  # stdlib; gives the validation split its own transform pipeline
from torchvision import tv_tensors  # used only for the augmentation safety check

# 1. Prepare DataLoaders
full_dataset = DOTAPatchDataset(PATCHES_PATH, get_transform(train=True))

if len(full_dataset) < 2:
    raise ValueError(
        f"Need at least 2 images to build train and validation splits, "
        f"found {len(full_dataset)} in {PATCHES_PATH}."
    )

# v2 geometric transforms only move boxes that are tv_tensors.BoundingBoxes;
# with plain-tensor boxes a flip would mirror the image but not its labels.
# In that case train without the flip rather than on corrupted targets.
train_transform = get_transform(train=True)
if not isinstance(full_dataset[0][1]["boxes"], tv_tensors.BoundingBoxes):
    print("Warning: dataset boxes are plain tensors; disabling flip augmentation "
          "because it would not be applied to the boxes.")
    train_transform = get_transform(train=False)
full_dataset.transforms = train_transform

# Split dataset into training and validation (90% train, 10% val)
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_split = torch.utils.data.random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Both subsets returned by random_split point at the SAME full_dataset object,
# so assigning `.dataset.transforms` on one also changes the other. Give the
# validation indices a shallow copy of the dataset with its own transforms.
val_view = copy.copy(full_dataset)
val_view.transforms = get_transform(train=False)
val_dataset = torch.utils.data.Subset(val_view, val_split.indices)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)

print(f"Training on {len(train_dataset)} images, Validating on {len(val_dataset)} images.")

# 2. Get Model
model = get_model(num_classes=2)
model.to(DEVICE)

# 3. Set up Optimizer (only parameters that are trainable)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=LEARNING_RATE,
    momentum=0.9,
    weight_decay=0.0005
)

# Learning rate scheduler: multiply the LR by 0.1 every 3 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

print("--- Starting Training ---")

for epoch in range(NUM_EPOCHS):

    # --- Training Phase ---
    model.train() # Set model to training mode
    total_train_loss = 0

    for i, (images, targets) in enumerate(train_loader):
        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        # Forward pass: in train mode the model returns a dict of losses
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Backward pass
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        total_train_loss += losses.item()

        if (i + 1) % 100 == 0:
            print(f"  Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], Loss: {losses.item():.4f}")

    avg_train_loss = total_train_loss / len(train_loader)

    # Update the learning rate
    lr_scheduler.step()

    # --- Validation Phase ---
    # Deliberately NOT model.eval(): torchvision detection models return
    # detections instead of the loss dict in eval mode. The model stays in
    # train mode and torch.no_grad() prevents any gradient computation.
    # NOTE(review): presumably safe here because this backbone uses frozen
    # batch-norm with pretrained weights - confirm if the model is changed.
    model.train()
    total_val_loss = 0

    with torch.no_grad():
        for images, targets in val_loader:
            images = list(image.to(DEVICE) for image in images)
            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

            # Targets are required for the model to compute the losses
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            total_val_loss += losses.item()

    avg_val_loss = total_val_loss / len(val_loader)

    print(f"--- Epoch {epoch+1} Summary ---")
    print(f"Avg Training Loss: {avg_train_loss:.4f}")
    print(f"Avg Validation Loss: {avg_val_loss:.4f}")
    print("--------------------------")

# Save the trained model (weights only)
MODEL_SAVE_PATH = '/content/fasterrcnn_dota.pth'
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"Training complete. Model saved to {MODEL_SAVE_PATH}")

In [None]:
from google.colab import files

# Location of the checkpoint file written by the training cell
MODEL_SAVE_PATH = '/content/fasterrcnn_dota.pth'

# Announce the transfer, then hand the file to Colab's download helper
print(f"Downloading {MODEL_SAVE_PATH}...")
files.download(MODEL_SAVE_PATH)

In [None]:
# Helper function to visualize
def visualize_prediction(img, pred, threshold=0.7):
    """
    Show an image with the predicted boxes whose score exceeds ``threshold``.

    Args:
        img: image tensor laid out as (C, H, W); it is moved to the CPU and
            permuted to (H, W, C) for display.
        pred: mapping with ``'boxes'`` (rows of [xmin, ymin, xmax, ymax]) and
            ``'scores'`` tensors, on any device.
        threshold: boxes with ``score <= threshold`` are not drawn.

    Each drawn box gets a red outline and a ``Vehicle: <score>`` caption.
    """
    img = img.detach().cpu().permute(1, 2, 0).numpy() # Convert from (C, H, W) to (H, W, C)

    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(img)

    # Predictions may live on the GPU; matplotlib needs plain Python numbers,
    # so move everything to the CPU once, up front.
    boxes = pred['boxes'].detach().cpu().tolist()
    scores = pred['scores'].detach().cpu().tolist()

    for (xmin, ymin, xmax, ymax), score in zip(boxes, scores):
        if score <= threshold:
            continue

        # Rectangle takes the top-left corner plus width and height
        rect = patches.Rectangle(
            (xmin, ymin),
            xmax - xmin,
            ymax - ymin,
            linewidth=2,
            edgecolor='r',
            facecolor='none'
        )
        ax.add_patch(rect)

        # Caption just above the box
        ax.text(
            xmin, ymin - 5,
            f'Vehicle: {score:.2f}',
            bbox=dict(facecolor='red', alpha=0.5, pad=0),
            color='white'
        )

    ax.axis('off')
    plt.show()

# --- Load Model and Run Inference ---

# 1. Load the saved model
model = get_model(num_classes=2)
# map_location lets a checkpoint saved on a GPU runtime load on a CPU-only one;
# weights_only=True restricts unpickling to tensors/state-dict content.
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE, weights_only=True))
model.to(DEVICE)
model.eval() # Set to evaluation mode (model returns detections, not losses)

# 2. Get a random image from the validation set
img, target = random.choice(val_dataset)

# 3. Run inference
with torch.no_grad():
    # The model takes a list of images; [0] picks the result for our single image
    prediction = model([img.to(DEVICE)])[0]

# 4. Visualize
print("Displaying prediction on a random validation image:")
visualize_prediction(img, prediction)