# Getting Started: Your First Prediction

This notebook provides a concise, end-to-end walkthrough to get you from an orthomosaic to a final crown prediction map using **detectree2**.

The key steps are:
1. Preparing data (tiling)
2. Training a model
3. Making landscape-level predictions
4. Saving and visualizing the crown map

For the full tutorial, see the [documentation](https://patball1.github.io/detectree2/tutorials/01_getting_started.html).

Example data is available on [Zenodo](https://zenodo.org/records/8136161).

## Setup

In [None]:
!pip install torch torchvision torchaudio
!pip install 'git+https://github.com/facebookresearch/detectron2.git'
!pip install detectree2

In [None]:
# Only needed on Google Colab: mount Google Drive to access data stored there
from google.colab import drive
drive.mount('/content/drive')

## 1. Preparing Data

First, we tile our large orthomosaic and crown data into smaller images suitable for training.

You will need:
- An orthomosaic (`.tif`)
- Corresponding tree crown polygons (`.gpkg` or `.shp`)

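For best results, supply manual crowns as dense clusters rather than scattering them sparsely across the landscape. Any tree inside a training tile that has not been delineated is treated as background, so sparse labelling teaches the model to ignore real crowns.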
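
The `threshold` used when tiling below is the minimum fraction of a tile that must be covered by crowns for that tile to be kept for training. This is another reason dense clusters help: scattered crowns rarely push a tile over the threshold. The sketch below illustrates that filter under simplifying assumptions. Crowns are axis-aligned boxes rather than polygons, they do not overlap, and the buffer is ignored. `crown_coverage` and `clipped_area` are hypothetical helpers, not part of detectree2.

```python
from typing import List, Tuple

# A box is (xmin, ymin, xmax, ymax) in map units (e.g. metres)
Box = Tuple[float, float, float, float]


def clipped_area(crown: Box, tile: Box) -> float:
    """Area of the part of `crown` that falls inside `tile`."""
    w = min(crown[2], tile[2]) - max(crown[0], tile[0])
    h = min(crown[3], tile[3]) - max(crown[1], tile[1])
    return max(w, 0.0) * max(h, 0.0)


def crown_coverage(crowns: List[Box], tile: Box) -> float:
    """Fraction of `tile` covered by crowns (hypothetical; assumes no overlaps)."""
    tile_area = (tile[2] - tile[0]) * (tile[3] - tile[1])
    return sum(clipped_area(c, tile) for c in crowns) / tile_area


tile = (0.0, 0.0, 40.0, 40.0)  # one 40 m x 40 m tile
dense = [(0, 0, 20, 20), (20, 0, 40, 20), (0, 20, 20, 40)]  # three 20 m crowns
sparse = [(5, 5, 15, 15)]  # a single 10 m crown

for name, boxes in [("dense", dense), ("sparse", sparse)]:
    cov = crown_coverage(boxes, tile)
    print(name, f"{cov:.4f}", "keep" if cov >= 0.6 else "drop")
# dense 0.7500 keep
# sparse 0.0625 drop
```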

In [None]:
from detectree2.preprocessing.tiling import tile_data, to_traintest_folders
import geopandas as gpd
import rasterio

In [None]:
# Set up input paths
site_path = "./Paracou"  # Example path
img_path = site_path + "/rgb/Paracou_RGB_2016_10cm.tif"
crown_path = site_path + "/crowns/UpdatedCrowns8.gpkg"

# Read in crowns and match CRS to the image
data = rasterio.open(img_path)
crowns = gpd.read_file(crown_path)
crowns = crowns.to_crs(data.crs.data)

In [None]:
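# Set tiling parameters (distances are in the units of the image CRS,
# i.e. metres for a projected CRS such as UTM)
buffer = 30        # extra context added around each tile
tile_width = 40    # tile width, excluding the buffer
tile_height = 40   # tile height, excluding the buffer
threshold = 0.6    # minimum fraction of a tile covered by crowns for it to be kept
out_dir = site_path + "/tiles/"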

# Tile the data for training
tile_data(img_path, out_dir, buffer, tile_width, tile_height, crowns, threshold, mode="rgb")
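
Each tile covers `tile_width` x `tile_height` map units, and a `buffer` of extra context is added around it, so neighbouring tiles overlap. The sketch below enumerates such windows over an image extent. It also converts the buffered tile size to pixels at the 10 cm resolution of the Paracou orthomosaic. It assumes a regular grid stepped by the tile size, with the buffer added on all four sides. `tile_windows` is a hypothetical helper, and detectree2's `tile_data` may differ in detail, for example in how it treats image edges.

```python
from typing import Iterator, Tuple

Window = Tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax) in map units


def tile_windows(
    extent: Window, tile_width: float, tile_height: float, buffer: float
) -> Iterator[Window]:
    """Yield buffered tile windows covering `extent` (hypothetical helper)."""
    xmin, ymin, xmax, ymax = extent
    y = ymin
    while y < ymax:
        x = xmin
        while x < xmax:
            # Core tile [x, x + tile_width] grown by `buffer` on every side
            yield (x - buffer, y - buffer,
                   x + tile_width + buffer, y + tile_height + buffer)
            x += tile_width
        y += tile_height


extent = (0.0, 0.0, 120.0, 80.0)  # a 120 m x 80 m image
windows = list(tile_windows(extent, tile_width=40, tile_height=40, buffer=30))
print(len(windows))  # 3 columns x 2 rows -> 6
print(windows[0])    # (-30.0, -30.0, 70.0, 70.0)

resolution = 0.10     # metres per pixel (10 cm imagery)
side_m = 40 + 2 * 30  # core tile plus buffer on both sides
print(round(side_m / resolution))  # 1000 pixels per side
```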

In [None]:
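# Hold out 15% of tiles for testing and split the rest into 5 folds
# (fold 5 is registered as the validation set below)
to_traintest_folders(out_dir, out_dir, test_frac=0.15, folds=5)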
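
`to_traintest_folders` holds out a fraction of the tiles (`test_frac=0.15`) for testing and divides the remainder into folds. One of those folds is later registered as the validation set. The sketch below shows that kind of split on a list of tile names. It assumes a seeded random shuffle and round-robin fold assignment. `split_tiles` is hypothetical and does not reproduce detectree2's implementation.

```python
import random
from typing import Dict, List, Tuple


def split_tiles(
    tiles: List[str], test_frac: float, folds: int, seed: int = 0
) -> Tuple[List[str], Dict[int, List[str]]]:
    """Split tile names into a test set and `folds` training folds (hypothetical)."""
    shuffled = tiles[:]  # copy so the caller's list is untouched
    random.Random(seed).shuffle(shuffled)
    n_test = round(len(shuffled) * test_frac)
    test, rest = shuffled[:n_test], shuffled[n_test:]
    # Round-robin assignment keeps fold sizes within one tile of each other
    fold_map = {k: rest[k - 1::folds] for k in range(1, folds + 1)}
    return test, fold_map


tiles = [f"tile_{i:03d}.png" for i in range(100)]
test, fold_map = split_tiles(tiles, test_frac=0.15, folds=5)
print(len(test))                            # 15
print([len(v) for v in fold_map.values()])  # [17, 17, 17, 17, 17]

# Fold 5 plays the role of the validation set; folds 1-4 are used for training
val = fold_map[5]
train = [t for k, v in fold_map.items() if k != 5 for t in v]
print(len(train), len(val))                 # 68 17
```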

## 2. Training a Model

Register the training data, configure the model, and train.
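
During training, the model is evaluated on the validation set every `eval_period` iterations. `MyTrainer(cfg, patience=5)` stops training early once validation performance stops improving. The generic patience rule is sketched below in plain Python, assuming a higher score is better and that the counter resets whenever a new best score appears. `early_stop_iteration` is hypothetical, and the scores are illustrative numbers, not real training results.

```python
from typing import List, Optional


def early_stop_iteration(
    scores: List[float], eval_period: int, patience: int
) -> Optional[int]:
    """Return the iteration at which training would stop, or None if it never stops.

    scores[i] is the validation score measured after (i + 1) * eval_period iterations.
    """
    best = float("-inf")
    evals_without_improvement = 0
    for i, score in enumerate(scores):
        if score > best:
            best = score
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return (i + 1) * eval_period
    return None


# Illustrative validation scores: improve for a while, then plateau
scores = [10.0, 20.0, 25.0, 24.0, 24.5, 23.0, 24.9, 22.0, 30.0]
print(early_stop_iteration(scores, eval_period=100, patience=5))  # 800
```

With `eval_period=100` and `max_iter=3000`, at most 3000 / 100 = 30 evaluations run. In this sketch, `patience=5` ends training 500 iterations after the last improvement.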

In [None]:
from detectree2.models.train import register_train_data, MyTrainer, setup_cfg

train_location = out_dir + "train/"
# Register fold 5 as the validation set ("Paracou_val"); the other folds form "Paracou_train"
register_train_data(train_location, 'Paracou', val_fold=5)

In [None]:
# Set the base (pre-trained) model from the detectron2 model_zoo
base_model = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"

trains = ("Paracou_train",)  # Registered train data
tests = ("Paracou_val",)    # Registered validation data

model_output_dir = "./train_outputs"

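cfg = setup_cfg(
    base_model,
    trains,
    tests,
    workers=4,          # data-loader worker processes
    eval_period=100,    # run validation every 100 iterations
    max_iter=3000,      # upper bound on training iterations
    out_dir=model_output_dir,
)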

In [None]:
# Stop early if the validation score does not improve for 5 consecutive evaluations
trainer = MyTrainer(cfg, patience=5)
trainer.resume_or_load(resume=False)
trainer.train()

## 3. Making Landscape-Level Predictions

Tile the full orthomosaic, run predictions, then project back to geographic coordinates.
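
The model predicts crown outlines in the pixel coordinates of each tile image. Conceptually, projecting them back to geographic coordinates applies each tile's affine geotransform. For a north-up raster with no rotation, that reduces to the mapping sketched below, where `(x0, y0)` is the map position of the tile's top-left corner and `res` is the pixel size. The helper names and coordinates are illustrative. In practice the transform comes from each tile's GeoTIFF metadata, which rasterio exposes as `dataset.transform`.

```python
from typing import List, Tuple


def pixel_to_map(
    col: float, row: float, x0: float, y0: float, res: float
) -> Tuple[float, float]:
    """Map a pixel position to map coordinates for a north-up raster.

    Columns increase eastwards (x grows) and rows increase southwards (y shrinks).
    """
    return x0 + col * res, y0 - row * res


def polygon_to_map(
    pixels: List[Tuple[float, float]], x0: float, y0: float, res: float
) -> List[Tuple[float, float]]:
    """Convert a crown outline given as (col, row) vertices to map coordinates."""
    return [pixel_to_map(c, r, x0, y0, res) for c, r in pixels]


# Illustrative tile: top-left corner at (500000, 600000) in a metric CRS, 10 cm pixels
outline_px = [(100, 100), (300, 100), (300, 250), (100, 250)]
outline_map = polygon_to_map(outline_px, x0=500000.0, y0=600000.0, res=0.1)
print(outline_map[0])  # (500010.0, 599990.0)
```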

In [None]:
from detectree2.models.predict import predict_on_data
from detectree2.models.outputs import project_to_geojson, stitch_crowns, clean_crowns
from detectron2.engine import DefaultPredictor

# Path to the full orthomosaic
img_path = site_path + "/rgb/Paracou_RGB_2016_10cm.tif"
pred_tiles_path = site_path + "/tiles_pred/"

# Specify tiling parameters (should be similar to training)
buffer = 30
tile_width = 40
tile_height = 40
# No crowns are passed, so tiles are not filtered by crown coverage
tile_data(img_path, pred_tiles_path, buffer, tile_width, tile_height)

In [None]:
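# Use your own trained model (a .pth checkpoint written to model_output_dir)
# or download a pre-trained one:
# !wget https://zenodo.org/records/15863800/files/250312_flexi.pth

trained_model = "./250312_flexi.pth"  # path to the model weights
cfg = setup_cfg(update_model=trained_model)
predictor = DefaultPredictor(cfg)
# Predict on every tile; results are read from pred_tiles_path + "predictions/" below
predict_on_data(pred_tiles_path, predictor=predictor)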

In [None]:
# Project tile predictions to geo-referenced crowns
project_to_geojson(pred_tiles_path, pred_tiles_path + "predictions/", pred_tiles_path + "predictions_geo/")

# Stitch the per-tile crowns into a single layer
crowns = stitch_crowns(pred_tiles_path + "predictions_geo/")
# Remove overlapping duplicates (IoU above 0.6) and crowns with confidence below 0.5
clean = clean_crowns(crowns, 0.6, confidence=0.5)
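
Neighbouring tiles overlap because of the buffer, so the same tree is often predicted in several tiles and the stitched layer contains near-duplicate crowns. `clean_crowns(crowns, 0.6, confidence=0.5)` removes these using an intersection-over-union (IoU) threshold and drops low-confidence crowns. The sketch below shows a greedy version of that idea. It operates on axis-aligned boxes rather than polygons and keeps the more confident crown of each overlapping pair. The `Crown` type and the `dedupe` helper are hypothetical; this illustrates the technique and is not detectree2's code.

```python
from dataclasses import dataclass
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax)


@dataclass
class Crown:
    box: Box
    confidence: float


def area(b: Box) -> float:
    return max(b[2] - b[0], 0.0) * max(b[3] - b[1], 0.0)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    inter = area((max(a[0], b[0]), max(a[1], b[1]), min(a[2], b[2]), min(a[3], b[3])))
    union = area(a) + area(b) - inter
    return inter / union if union > 0 else 0.0


def dedupe(crowns: List[Crown], iou_threshold: float, confidence: float) -> List[Crown]:
    """Drop low-confidence crowns, then greedily drop overlaps (hypothetical helper)."""
    candidates = sorted(
        (c for c in crowns if c.confidence >= confidence),
        key=lambda c: c.confidence,
        reverse=True,
    )
    kept: List[Crown] = []
    for c in candidates:
        # Keep c only if it does not heavily overlap a more confident crown
        if all(iou(c.box, k.box) <= iou_threshold for k in kept):
            kept.append(c)
    return kept


example_crowns = [
    Crown((0, 0, 10, 10), 0.9),
    Crown((1, 0, 11, 10), 0.7),    # same tree seen from a neighbouring tile
    Crown((20, 20, 30, 30), 0.8),
    Crown((40, 40, 45, 45), 0.3),  # below the confidence cut-off
]
for c in dedupe(example_crowns, iou_threshold=0.6, confidence=0.5):
    print(c.box, c.confidence)
# (0, 0, 10, 10) 0.9
# (20, 20, 30, 30) 0.8
```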

## 4. Saving and Visualizing

Simplify and save the cleaned crown map as a GeoPackage, which can be opened in QGIS or ArcGIS.

In [None]:
# Simplify crown outlines (tolerance 0.3 in CRS units, e.g. 30 cm) for easier editing in GIS software
clean = clean.set_geometry(clean.simplify(0.3))

# Save to file
clean.to_file(site_path + "/crowns_out.gpkg", driver="GPKG")
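
`simplify(0.3)` removes vertices from each crown outline that lie within the tolerance of the simplified shape. GeoPandas delegates this to Shapely, which uses the Douglas-Peucker algorithm and, by default, a topology-preserving variant of it. The sketch below shows the plain Douglas-Peucker recursion for a single open line. It is a simplified illustration, not Shapely's implementation, and it does not handle closed rings specially.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def point_line_distance(p: Point, a: Point, b: Point) -> float:
    """Perpendicular distance from p to the line through a and b."""
    (px, py), (ax, ay), (bx, by) = p, a, b
    dx, dy = bx - ax, by - ay
    length = math.hypot(dx, dy)
    if length == 0:
        return math.hypot(px - ax, py - ay)
    # |cross product| / |ab| gives the distance to the infinite line
    return abs(dx * (ay - py) - dy * (ax - px)) / length


def douglas_peucker(points: List[Point], tolerance: float) -> List[Point]:
    """Simplify an open polyline, keeping both endpoints."""
    if len(points) < 3:
        return list(points)
    a, b = points[0], points[-1]
    # Find the interior vertex farthest from the chord a-b
    dists = [point_line_distance(p, a, b) for p in points[1:-1]]
    i_max = max(range(len(dists)), key=dists.__getitem__) + 1
    if dists[i_max - 1] <= tolerance:
        return [a, b]  # every interior vertex is within tolerance: drop them all
    # Otherwise keep the farthest vertex and recurse on both halves
    left = douglas_peucker(points[: i_max + 1], tolerance)
    right = douglas_peucker(points[i_max:], tolerance)
    return left[:-1] + right


line = [(0.0, 0.0), (1.0, 0.1), (2.0, -0.1), (3.0, 5.0), (4.0, 6.0), (5.0, 7.0)]
print(douglas_peucker(line, tolerance=0.3))
# [(0.0, 0.0), (2.0, -0.1), (3.0, 5.0), (5.0, 7.0)]
```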