# Train YOLOv8 on SpaceNet 6 (Rotterdam)

This notebook automates downloading and preparing the SpaceNet 6 Rotterdam aerial imagery dataset and training a YOLOv8 model on it.

**Runtime Requirement**: Ensure you are connected to a GPU runtime (Runtime > Change runtime type > T4 GPU).

In [None]:
# 1. Install Dependencies
!pip install ultralytics rasterio geopandas boto3

In [None]:
import os
import boto3
from botocore import UNSIGNED
from botocore.config import Config
import rasterio
import geopandas as gpd
import glob
from tqdm import tqdm
from ultralytics import YOLO
import shutil
import random
import yaml
import numpy as np

# 2. Configuration
# S3 source: bucket name plus the key prefixes of the PS-RGB imagery and the
# building-footprint GeoJSON labels.
BUCKET = 'spacenet-dataset'
PREFIX_IMAGES = 'spacenet/SN6_buildings/train/AOI_11_Rotterdam/PS-RGB/'
PREFIX_LABELS = 'spacenet/SN6_buildings/train/AOI_11_Rotterdam/geojson_buildings/'

# Local layout: images, raw GeoJSON and generated YOLO labels all live under LOCAL_DIR.
LOCAL_DIR = '/content/dataset'
IMAGE_DIR, GEOJSON_DIR, LABEL_DIR = (
    os.path.join(LOCAL_DIR, subdir) for subdir in ('images', 'geojson', 'labels')
)

# Create the working directories up front; existing ones are left untouched.
for _directory in (IMAGE_DIR, GEOJSON_DIR, LABEL_DIR):
    os.makedirs(_directory, exist_ok=True)

In [None]:
# 3. Download Data
def download_s3_folder(bucket, prefix, local_dir):
    """Download every object under an S3 prefix into one local directory.

    Requests are made with an unsigned (anonymous) client. Only the final path
    component of each key is kept, so the remote folder structure is flattened
    into ``local_dir``. Files that already exist locally are not fetched again.

    Args:
        bucket: Name of the S3 bucket.
        prefix: Key prefix whose objects are listed and downloaded.
        local_dir: Destination directory; created if it does not exist yet.
    """
    # Make the helper usable on its own, without relying on a prior makedirs call.
    os.makedirs(local_dir, exist_ok=True)

    s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
    paginator = s3.get_paginator('list_objects_v2')
    print(f"Downloading from {prefix}...")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # A page without a 'Contents' entry contributes nothing.
        for obj in page.get('Contents', []):
            key = obj['Key']
            # Keys ending in '/' are folder placeholders, not files.
            if key.endswith('/'):
                continue
            local_path = os.path.join(local_dir, os.path.basename(key))
            if not os.path.exists(local_path):
                s3.download_file(bucket, key, local_path)

# Fetch the building labels first, then the imagery.
for _prefix, _target_dir in ((PREFIX_LABELS, GEOJSON_DIR), (PREFIX_IMAGES, IMAGE_DIR)):
    download_s3_folder(BUCKET, _prefix, _target_dir)

In [None]:
# 4. Convert Labels to YOLO Format
def detect_black_bars(image_array, threshold=10, min_bar_size=5, bar_threshold=0.95):
    """
    Detect black bars along all four edges of an image.

    A row (or column) belongs to a bar when at least ``bar_threshold`` of its
    pixels are darker than ``threshold``. Each edge is scanned inwards and the
    scan stops at the first row/column that is not part of a bar.

    Opposite crops never overlap: for a completely black image the whole
    extent is attributed to the top/left crop and the bottom/right crop is 0,
    so ``top + bottom <= H`` and ``left + right <= W`` always hold.

    Args:
        image_array: numpy array of shape (H, W) or (H, W, C)
        threshold: pixel value threshold below which is considered "black" (0-255)
        min_bar_size: minimum number of rows/columns to consider as a bar
        bar_threshold: fraction of pixels in a row/column that must be black to consider it a bar

    Returns:
        tuple: (top_crop, bottom_crop, left_crop, right_crop) in pixels
    """
    image_array = np.asarray(image_array)

    # Handle multi-channel images by converting to grayscale
    gray = np.mean(image_array, axis=2) if image_array.ndim == 3 else image_array

    height, width = gray.shape
    if height == 0 or width == 0:
        # Nothing to scan in an empty image.
        return 0, 0, 0, 0

    def _leading_run(flags):
        """Return the number of consecutive True values at the start of ``flags``."""
        first_false = np.flatnonzero(~flags)
        return int(first_false[0]) if first_false.size else int(flags.size)

    # Classify every row and column once instead of looping pixel line by pixel line.
    dark = gray < threshold
    row_is_bar = dark.sum(axis=1) / width >= bar_threshold
    col_is_bar = dark.sum(axis=0) / height >= bar_threshold

    top_crop = _leading_run(row_is_bar)
    # Cap the opposite edge so the two crops cannot claim the same rows/columns.
    bottom_crop = min(_leading_run(row_is_bar[::-1]), height - top_crop)
    left_crop = _leading_run(col_is_bar)
    right_crop = min(_leading_run(col_is_bar[::-1]), width - left_crop)

    # Only return crops if they meet minimum size
    top_crop, bottom_crop, left_crop, right_crop = (
        crop if crop >= min_bar_size else 0
        for crop in (top_crop, bottom_crop, left_crop, right_crop)
    )

    return top_crop, bottom_crop, left_crop, right_crop

def convert_labels():
    """Convert building footprints in GEOJSON_DIR into YOLO label files in LABEL_DIR.

    For every ``*.geojson`` file whose matching PS-RGB ``.tif`` exists in
    IMAGE_DIR, each footprint is reduced to its axis-aligned bounding box,
    mapped to pixel coordinates through the inverse of the raster's affine
    transform, and written as a class-0 YOLO line
    (``0 x_center y_center width height``).

    Coordinates are normalised by the full image size, because this function
    leaves the image files uncropped on disk and YOLO labels are relative to
    the image as stored. Black bars found by ``detect_black_bars`` are used
    only to clip boxes to the region holding real pixels; boxes that fall
    entirely inside a bar are dropped.

    Files that fail to process are reported and skipped; no exception
    propagates to the caller.
    """
    geojson_files = glob.glob(os.path.join(GEOJSON_DIR, "*.geojson"))
    print(f"Found {len(geojson_files)} label files.")

    for geojson_path in tqdm(geojson_files, desc="Converting"):
        filename = os.path.basename(geojson_path)
        # SN6_Train_AOI_11_Rotterdam_Buildings_... -> SN6_Train_AOI_11_Rotterdam_PS-RGB_...
        image_filename = filename.replace("Buildings", "PS-RGB").replace(".geojson", ".tif")
        image_path = os.path.join(IMAGE_DIR, image_filename)

        if not os.path.exists(image_path):
            continue

        try:
            with rasterio.open(image_path) as src:
                img_height, img_width = src.height, src.width
                # Inverse affine transform: maps world (x, y) to pixel (col, row).
                # Computed once per image rather than once per box.
                world_to_pixel = ~src.transform
                src_crs = src.crs
                image_array = src.read()

            # Convert from (C, H, W) to (H, W, C) for detection
            if image_array.ndim == 3:
                image_array = np.transpose(image_array, (1, 2, 0))

            top_crop, bottom_crop, left_crop, right_crop = detect_black_bars(
                image_array, threshold=10, min_bar_size=5, bar_threshold=0.95
            )

            # Region of the stored image that holds real pixels, in full-image
            # pixel coordinates.
            valid_left, valid_top = left_crop, top_crop
            valid_right = img_width - right_crop
            valid_bottom = img_height - bottom_crop

            gdf = gpd.read_file(geojson_path)
            yolo_lines = []

            if not gdf.empty:
                if src_crs and src_crs.to_string() != "EPSG:4326":
                    try:
                        gdf = gdf.to_crs(src_crs)
                    except Exception as crs_error:
                        # Best effort: keep going with the coordinates as read,
                        # but make the failure visible instead of hiding it.
                        print(
                            f"Warning: could not reproject {filename} to the image CRS "
                            f"({crs_error}); using coordinates as read."
                        )

                for geom in gdf.geometry:
                    # Features without a geometry are read as None.
                    if geom is None or geom.is_empty:
                        continue

                    minx, miny, maxx, maxy = geom.bounds
                    c1, r1 = world_to_pixel * (minx, maxy)
                    c2, r2 = world_to_pixel * (maxx, miny)

                    # Clip to the non-black region of the image.
                    xmin_px = max(valid_left, min(c1, c2))
                    xmax_px = min(valid_right, max(c1, c2))
                    ymin_px = max(valid_top, min(r1, r2))
                    ymax_px = min(valid_bottom, max(r1, r2))

                    if xmax_px <= xmin_px or ymax_px <= ymin_px:
                        continue

                    # Convert to YOLO format (normalize by the stored image dimensions)
                    bbox_width = (xmax_px - xmin_px) / img_width
                    bbox_height = (ymax_px - ymin_px) / img_height
                    x_center = (xmin_px + xmax_px) / 2 / img_width
                    y_center = (ymin_px + ymax_px) / 2 / img_height

                    yolo_lines.append(
                        f"0 {x_center:.6f} {y_center:.6f} {bbox_width:.6f} {bbox_height:.6f}"
                    )

            # An empty label file is written when no box survives.
            with open(os.path.join(LABEL_DIR, image_filename.replace(".tif", ".txt")), "w") as f:
                f.write("\n".join(yolo_lines))

        except Exception as e:
            print(f"Error processing {filename}: {e}")

# Run the conversion. The train/val split and data.yaml are produced by
# step 5 in the following cell, so that code is not repeated here.
convert_labels()

In [None]:
# 5. Split Dataset and Config
SPLIT_SEED = 42        # fixed seed so the train/val split is the same on every run
TRAIN_FRACTION = 0.8   # share of images assigned to the training list

# glob returns files in filesystem order; sort first so the seeded shuffle
# always starts from the same sequence.
images = sorted(glob.glob(os.path.join(IMAGE_DIR, "*.tif")))
# A dedicated Random instance keeps the global random state untouched.
random.Random(SPLIT_SEED).shuffle(images)
split = int(len(images) * TRAIN_FRACTION)
train_imgs = images[:split]
val_imgs = images[split:]

# One image path per line.
with open('train.txt', 'w') as f:
    f.write('\n'.join(train_imgs))

with open('val.txt', 'w') as f:
    f.write('\n'.join(val_imgs))

# Dataset description: a single class, with the list files given relative to
# the current working directory.
data_yaml = f"""
names:
  0: building
path: {os.getcwd()}
train: train.txt
val: val.txt
"""

with open('data.yaml', 'w') as f:
    f.write(data_yaml)

In [None]:
# 6. Load DIOR Model and Quantize to INT8 with Calibration
from huggingface_hub import hf_hub_download

# Download DIOR model
print("Downloading DIOR model from HuggingFace...")
dior_model_path = hf_hub_download(
    repo_id="pauhidalgoo/yolov8-DIOR",
    filename="DIOR_yolov8n_backbone.pt"
)
print(f"Model downloaded to: {dior_model_path}")

# Load the model
model = YOLO(dior_model_path)
print("Model loaded successfully!")

# Prepare calibration dataset
# Option 1: Use validation images from training (if available)
if 'val_imgs' in globals() and len(val_imgs) > 0:
    calibration_images = val_imgs[:100]  # Use first 100 validation images
    print(f"Using {len(calibration_images)} validation images for calibration")
# Option 2: Use images from IMAGE_DIR if available
elif 'IMAGE_DIR' in globals() and os.path.exists(IMAGE_DIR):
    all_images = glob.glob(os.path.join(IMAGE_DIR, "*.tif"))
    calibration_images = all_images[:100]  # Use first 100 images
    print(f"Using {len(calibration_images)} images from {IMAGE_DIR} for calibration")
# Option 3: Manually specify image paths (uncomment and modify as needed)
# calibration_images = [
#     "/path/to/image1.tif",
#     "/path/to/image2.tif",
#     # ... add more image paths
# ]
else:
    raise ValueError("No calibration images found! Please run cells 2-5 first, or manually specify calibration_images list.")

# Export to ONNX first (required for INT8 quantization)
print("\nExporting to ONNX format...")
onnx_path = dior_model_path.replace('.pt', '.onnx')
model.export(
    format='onnx',
    imgsz=416,
    simplify=True,
    opset=12
)
print(f"ONNX model exported to: {onnx_path}")

# Now quantize to INT8 with calibration
print("\nQuantizing to INT8 with calibration...")
!pip install onnxruntime -q

from onnxruntime.quantization import quantize_static, QuantType, CalibrationDataReader
from PIL import Image
import numpy as np

class DIORCalibrationDataReader(CalibrationDataReader):
    """Calibration data reader for DIOR model.

    Each call to ``get_next`` yields one preprocessed image as
    ``{input_name: array}``, where the array is float32 in [0, 1] with shape
    ``(1, 3, image_size, image_size)``. Images that cannot be read or
    preprocessed are reported and skipped. ``None`` marks the end of the data.

    Args:
        image_paths: Paths of the raster images used for calibration.
        image_size: Side length, in pixels, that images are resized to.
        input_name: Name of the model input; defaults to ``'images'``.
    """
    def __init__(self, image_paths, image_size=416, input_name=None):
        self.image_paths = list(image_paths)
        self.image_size = image_size
        self.current_index = 0
        self.input_name = input_name or 'images'  # Default input name

    def _load_input(self, image_path):
        """Read one image and return it as a (1, 3, S, S) float32 array in [0, 1]."""
        import rasterio
        with rasterio.open(image_path) as src:
            # Bands 1-3, rearranged from (C, H, W) to (H, W, C).
            img = src.read([1, 2, 3]).transpose(1, 2, 0)

        # Percentile stretch to 8-bit (the original comment says this mirrors
        # the inference preprocessing - confirm against the inference code).
        p2, p98 = np.percentile(img, (2, 98))
        value_range = p98 - p2
        if value_range > 0:
            img = np.clip((img - p2) / value_range * 255.0, 0, 255).astype(np.uint8)
        else:
            # Flat image (e.g. an all-black tile): stretching would divide by
            # zero, so feed a zero image instead of NaN-derived garbage.
            img = np.zeros(img.shape, dtype=np.uint8)

        # Resize to model input size
        pil_img = Image.fromarray(img)
        pil_img = pil_img.resize((self.image_size, self.image_size))

        # Convert to numpy array and normalize to [0, 1]
        img_array = np.array(pil_img).astype(np.float32) / 255.0

        # Convert to CHW format and add batch dimension
        img_array = img_array.transpose(2, 0, 1)  # HWC -> CHW
        return np.expand_dims(img_array, axis=0)

    def get_next(self):
        """Return the next calibration input, or ``None`` when exhausted."""
        # Loop rather than recurse, so a long run of unreadable images cannot
        # grow the call stack.
        while self.current_index < len(self.image_paths):
            image_path = self.image_paths[self.current_index]
            self.current_index += 1
            try:
                return {self.input_name: self._load_input(image_path)}
            except Exception as e:
                print(f"Error processing {image_path}: {e}")
        return None

# Get the correct input name from ONNX model
import os
import onnx
onnx_model = onnx.load(onnx_path)
input_name = onnx_model.graph.input[0].name
print(f"ONNX model input name: {input_name}")

# Create calibration data reader with correct input name
calibration_reader = DIORCalibrationDataReader(calibration_images, image_size=416, input_name=input_name)

# Quantize with static calibration
# Build the output name from the file stem so that only the extension is replaced.
int8_model_path = os.path.splitext(dior_model_path)[0] + '_int8_calibrated.onnx'
print(f"\nQuantizing model (this may take a few minutes)...")

# quantize_static takes separate activation/weight types; it has no
# `quant_type` keyword, and `optimize_model` is not accepted by current
# onnxruntime releases, so neither is passed.
quantize_static(
    model_input=onnx_path,
    model_output=int8_model_path,
    calibration_data_reader=calibration_reader,
    activation_type=QuantType.QInt8,  # Use signed int8 for better accuracy
    weight_type=QuantType.QInt8
)

print(f"\n✅ INT8 quantized model saved to: {int8_model_path}")
print(f"Model size comparison:")
# Sizes in MiB.
original_size = os.path.getsize(onnx_path) / (1024 * 1024)
quantized_size = os.path.getsize(int8_model_path) / (1024 * 1024)
print(f"  Original ONNX: {original_size:.2f} MB")
print(f"  Quantized INT8: {quantized_size:.2f} MB")
print(f"  Size reduction: {(1 - quantized_size/original_size)*100:.1f}%")

In [None]:
# 7. Zip Results for Download
import zipfile

# Bundle the ONNX models produced in step 6. A `spacenet_rotterdam` directory
# is added as well when one exists in the working directory; os.walk yields
# nothing for a missing directory, so its absence is not an error.
ARCHIVE_PATH = 'trained_model.zip'
with zipfile.ZipFile(ARCHIVE_PATH, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
    for model_file in (onnx_path, int8_model_path):
        if os.path.isfile(model_file):
            # Store models at the archive root rather than under their cache path.
            archive.write(model_file, arcname=os.path.basename(model_file))
    for root, _dirs, files in os.walk('spacenet_rotterdam'):
        for name in files:
            member = os.path.join(root, name)
            archive.write(member, arcname=member)
    archived_count = len(archive.namelist())

print(f"Wrote {archived_count} file(s) to {ARCHIVE_PATH}")