# Render a satellite image (required before prediction or training)

In [None]:
from functions import render_raster

# Render the raw raster to data/test_raster_rendered.tif, the file read by the slicing and SAHI prediction cells.
render_raster(
    input_tif="data/test_raster_raw.tif",
    output_tif="data/test_raster_rendered.tif",
    rgb_bands=[1,2,3], # bands are in the correct order for RGB (e.g., 1=Red, 2=Green, 3=Blue)
    min_percentile=2, # clip the lower 2% of pixel values to enhance contrast
    max_percentile=99.9, # clip the upper 0.1% of pixel values to enhance contrast
)

# Slice a satellite image into square tiles

In [7]:
from functions import slice_raster

# Tile the rendered raster into data/tiles, the directory the tile-based YOLO and SAM3 cells read from.
slice_raster(
    raster_in_path="data/test_raster_rendered.tif",
    raster_out_dir="data/tiles",
    tile_size=512, # square tiles of 512x512 pixels
    skip_empty_tiles=True, # skip tiles that are empty (all values are the same, e.g. all zeros)
)

# YOLO OBB inference

### Make YOLO OBB georeferenced predictions over a satellite image using SAHI for slicing

In [None]:
from sahi import AutoDetectionModel
from functions import yolo_obb_predict

import torch  # already a dependency of ultralytics, which the model below requires

# Use the first CUDA GPU when one is usable, otherwise fall back to the CPU,
# so this cell also runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load model
detection_model = AutoDetectionModel.from_pretrained(
    model_type="ultralytics", # Make sure ultralytics is installed: pip install ultralytics
    model_path="yolo11m-obb.pt", # This model is trained on DOTAv1 dataset (will be automatically downloaded). Use an OBB model
    confidence_threshold=0.2,
    device=device,
)

# Run prediction
yolo_obb_predict(
    image_file="data/test_raster_rendered.tif", # input georeferenced raster file (e.g. GeoTIFF)
    labels_file="data/test_raster_rendered.geojson", # output geojson file with georeferenced bounding boxes
    detection_model=detection_model,
    tile_size=512, # square tiles only, size in pixels
    overlap_ratio=0.1, # overlap between tiles, value between 0 and 1
    classes_to_keep=[1], # Keep only class ID 1 (ship class in DOTAv1 dataset)
)

### OR make predictions on a set of tiles

In [None]:
from ultralytics import YOLO

import torch  # already a dependency of ultralytics

# Load model
model = YOLO('yolo11m-obb.pt') # This model is trained on DOTAv1 dataset (will be automatically downloaded)

# Process each tile in the directory
results = model.predict( # Output goes to the runs directory by default, e.g. runs/obb/predict/labels for an OBB model
    source="data/tiles",
    conf=0.01,
    save=True, # Set to True to save plots with bounding boxes
    save_txt=True, # write one label .txt file per tile
    save_conf=True, # append the confidence score to every label line
    device=0 if torch.cuda.is_available() else "cpu", # first GPU when CUDA is usable, otherwise CPU
)

# SAM3 inference

### Make SAM3 predictions on a set of tiles (SAM3 doesn't work with SAHI yet)

In [None]:
import os
from glob import glob
from ultralytics.models.sam import SAM3SemanticPredictor

DIR_TILES = "data/tiles"
PROMPTS = ["motor boat", "sailboat", "ship"] # Text prompts sent to SAM3 for every tile

# Predictor configuration
overrides = {
    "conf": 0.3,
    "task": "segment",
    "mode": "predict",
    "model": "/home/luka/Desktop/weights/sam3.pt", # Local weights file; download instructions: https://docs.ultralytics.com/models/sam-3/#installation
    "half": True,  # FP16 for faster inference
    "save": False,  # No default saving
    "save_txt": True,  # Results are saved automatically in the run directory
}
predictor = SAM3SemanticPredictor(overrides=overrides)

# Run the text prompts against every .tif tile, in sorted file-name order
tile_pattern = os.path.join(DIR_TILES, "*.tif")
list_images = sorted(glob(tile_pattern))
for image_path in list_images:
    predictor.set_image(image_path) # Load the tile into the predictor
    results = predictor(text=PROMPTS) # Query it with the text prompts

### Visualize predictions and save plots

In [None]:
import matplotlib.pyplot as plt
from glob import glob
import cv2
import os

DIR_TILES = "data/tiles"
DIR_LABELS = "runs/segment/predict/labels"
COLORS = ["yellow", "magenta", "cyan", "red", "blue", "black"]


def parse_label_line(line, width, height):
    """Parse one label line into a class id and a polygon in pixel coordinates.

    The line is expected to look like ``class x1 y1 x2 y2 ...`` with the
    coordinates normalized by the image size. When the number of floats after
    the class id is odd, the last one is assumed to be a confidence score
    (as written with ``save_conf=True``) and is dropped.

    Args:
        line: Raw text line from a label file.
        width: Image width in pixels, used to scale x coordinates.
        height: Image height in pixels, used to scale y coordinates.

    Returns:
        ``(cls, x_pixels, y_pixels)``, or ``None`` when the line is blank or
        does not contain at least one complete point.
    """
    parts = line.split()
    if len(parts) < 3:  # need a class id plus at least one (x, y) pair
        return None

    cls = int(parts[0])
    coords = [float(value) for value in parts[1:]]
    if len(coords) % 2:
        coords.pop()  # trailing confidence value, not a coordinate

    x_pixels = [x * width for x in coords[0::2]]
    y_pixels = [y * height for y in coords[1::2]]
    return cls, x_pixels, y_pixels


out_dir = DIR_LABELS.replace("labels", "plots") # Default output directory for plots
os.makedirs(out_dir, exist_ok=True)

# Sorted so the plots are produced in a reproducible order
list_label_files = sorted(glob(os.path.join(DIR_LABELS, "*.txt")))
for label_file in list_label_files:
    # splitext only strips the extension; str.replace would also rewrite a ".txt" inside the name
    stem = os.path.splitext(os.path.basename(label_file))[0]
    image_file = os.path.join(DIR_TILES, stem + ".tif")
    print(f"Image: {image_file}, Label: {label_file}")

    image = cv2.imread(image_file)
    if image is None:
        # cv2.imread returns None instead of raising when the file is missing or unreadable
        print(f"Skipping {label_file}: could not read {image_file}")
        continue
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    height, width = image.shape[:2]

    # One explicit figure per tile so polygons never pile up on a previous plot
    fig, ax = plt.subplots()
    ax.imshow(image)

    with open(label_file, 'r') as f:
        for line in f:
            parsed = parse_label_line(line, width, height)
            if parsed is None:
                continue
            cls, x_pixels, y_pixels = parsed

            # Repeat the first vertex to close the polygon outline
            ax.plot(x_pixels + x_pixels[:1], y_pixels + y_pixels[:1], color=COLORS[cls % len(COLORS)], linewidth=1)

    # Save before show(): some backends clear the figure once it has been shown
    fig.savefig(os.path.join(out_dir, stem + ".png"))
    plt.show()
    plt.close(fig)  # release the figure so long runs do not accumulate open figures