# In-Depth Guide: Model Training & Evaluation

This notebook covers the model training process in detail, including dataset registration, configuration for single-class and multi-class problems, data augmentation, and post-training analysis.

For the full tutorial, see the [documentation](https://patball1.github.io/detectree2/tutorials/03_training_and_evaluation.html).

## Setup

In [None]:
!pip install torch torchvision torchaudio
!pip install 'git+https://github.com/facebookresearch/detectron2.git'
!pip install detectree2

In [None]:
# Only needed on Google Colab, when your data are stored on Google Drive
from google.colab import drive
drive.mount('/content/drive')

## 1. Registering Datasets

Before training, register your tiled training data with Detectron2. `register_train_data` uses the fold given by `val_fold` as the validation set and the remaining folds for training.

### Single-class

In [None]:
from detectree2.models.train import register_train_data

train_location = "/path/to/Danum/tiles/train/"
register_train_data(train_location, 'Danum', val_fold=5)

### Multi-class

For multi-class datasets, provide the path to the `class_to_idx.json` file created during data preparation.

In [None]:
from detectree2.models.train import register_train_data

train_dir = "/path/to/Danum_lianas/tiles/train"
class_mapping_file = "/path/to/Danum_lianas/tiles/class_to_idx.json"
data_name = "DanumLiana"

register_train_data(train_dir, data_name, val_fold=5, class_mapping_file=class_mapping_file)

The data will be registered as `<name>_train` and `<name>_val` (e.g., `Danum_train` and `Danum_val`).
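To make the role of `val_fold` concrete, the sketch below mimics the split under one assumption: that data preparation wrote the training tiles into numbered fold subfolders (named `fold_1` to `fold_5` here purely for illustration). The fold matching `val_fold` becomes the `<name>_val` set and the remaining folds form `<name>_train`. `split_folds` is a hypothetical helper, not a detectree2 function, and it only reproduces the naming and hold-out logic, not the actual Detectron2 registration.

```python
def split_folds(fold_names, val_fold, name):
    """Hypothetical helper: hold out one fold for validation.

    fold_names: fold folder names, assumed to be 'fold_1', 'fold_2', ...
    val_fold:   1-based index of the fold kept for validation
    name:       dataset name, as passed to register_train_data
    Returns a dict mapping the registered dataset names to their folds.
    """
    val_name = f"fold_{val_fold}"
    if val_name not in fold_names:
        raise ValueError(f"{val_name} not found in {fold_names}")
    train = [f for f in fold_names if f != val_name]
    return {f"{name}_train": train, f"{name}_val": [val_name]}


folds = [f"fold_{i}" for i in range(1, 6)]
print(split_folds(folds, val_fold=5, name="Danum"))
# {'Danum_train': ['fold_1', 'fold_2', 'fold_3', 'fold_4'], 'Danum_val': ['fold_5']}
```

Changing `val_fold` rotates which fold is held out, which is how you would run a simple cross-validation over the folds.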

## 2. Configuring the Model

Supply a `base_model` config from Detectron2's model zoo. Its pre-trained weights (here, Mask R-CNN with a ResNet-101 FPN backbone trained on COCO) are the starting point for training.

### Single-class configuration

In [None]:
from detectree2.models.train import setup_cfg

# Set the base (pre-trained) model from the detectron2 model_zoo
base_model = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"

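# Every name listed must already be registered (e.g. Paracou via register_train_data, as for Danum)
trains = ("Danum_train", "Paracou_train")  # Registered train data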
tests = ("Danum_val", "Paracou_val")       # Registered validation data

out_dir = "./train_outputs"

cfg = setup_cfg(base_model, trains, tests, workers=4, eval_period=100, max_iter=3000, out_dir=out_dir)
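`max_iter` counts iterations (batches), not passes over the data. The back-of-the-envelope conversion below assumes each iteration processes a fixed number of images (Detectron2 stores this as `cfg.SOLVER.IMS_PER_BATCH`; the value 2 used here is an assumption, so check your own config) and a hypothetical count of 1,500 training tiles. It also shows how often validation runs: with `eval_period=100`, roughly 30 evaluations happen over 3000 iterations.

```python
def iterations_to_epochs(max_iter, ims_per_batch, n_train_tiles):
    """Approximate number of passes over the training tiles."""
    return max_iter * ims_per_batch / n_train_tiles


def n_evaluations(max_iter, eval_period):
    """Approximate number of validation runs, one every `eval_period` iterations."""
    return max_iter // eval_period


# Hypothetical numbers: 2 images per batch, 1,500 training tiles
print(iterations_to_epochs(3000, ims_per_batch=2, n_train_tiles=1500))  # 4.0
print(n_evaluations(3000, eval_period=100))  # 30
```

A small dataset can therefore be seen many times within `max_iter`, which is one reason early stopping (below) is useful.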

### Fine-tuning from a pre-trained detectree2 model

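Starting from a detectree2 model already trained on tree crowns, rather than from the COCO weights alone, is recommended if you have limited training data. Pass the path to its weights as the fourth argument of `setup_cfg`.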

In [None]:
# Download a pre-trained model
# !wget https://zenodo.org/records/15863800/files/250312_flexi.pth

trained_model = "./250312_flexi.pth"
cfg = setup_cfg(base_model, trains, tests, trained_model, workers=4, eval_period=100, max_iter=3000, out_dir=out_dir)

### Multi-class configuration

Pass the same `class_mapping_file` used at registration so that the model is configured with the correct number of classes.

In [None]:
cfg = setup_cfg(
    base_model="COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
    trains=("DanumLiana_train",),
    tests=("DanumLiana_val",),
    max_iter=50000,
    eval_period=50,
    base_lr=0.003,
    out_dir="./liana_outputs",
    class_mapping_file=class_mapping_file
)
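The number of classes follows from the mapping file. The sketch below assumes `class_to_idx.json` is a flat JSON object mapping class names to integer indices starting at 0 (the example class names are hypothetical); opening your own file is the reliable way to confirm its layout. The contiguity check is an extra safeguard added for illustration, not something detectree2 is known to require.

```python
import json
import tempfile
from pathlib import Path


def num_classes_from_mapping(path):
    """Count classes in a {class_name: index} JSON mapping (assumed layout)."""
    with open(path) as f:
        mapping = json.load(f)
    indices = sorted(mapping.values())
    # This sketch assumes contiguous indices: 0, 1, ..., n-1
    if indices != list(range(len(indices))):
        raise ValueError(f"Non-contiguous class indices: {indices}")
    return len(mapping)


# Hypothetical mapping written to a temporary file for illustration
example = {"tree": 0, "liana": 1}
with tempfile.TemporaryDirectory() as tmp:
    mapping_path = Path(tmp) / "class_to_idx.json"
    mapping_path.write_text(json.dumps(example))
    print(num_classes_from_mapping(mapping_path))  # 2
```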

## 3. Running the Trainer

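The trainer includes early stopping via the `patience` parameter: training stops once the validation score has not improved for `patience` consecutive evaluations (an evaluation runs every `eval_period` iterations).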

In [None]:
from detectree2.models.train import MyTrainer

trainer = MyTrainer(cfg, patience=5)
trainer.resume_or_load(resume=False)
trainer.train()
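Patience-based early stopping keeps track of the best validation score seen so far and stops once `patience` consecutive evaluations fail to beat it. The sketch below is a generic illustration of that rule, not detectree2's implementation; which metric `MyTrainer` monitors should be checked in the detectree2 source. The AP50 values are made up. With `patience=5` and the `eval_period=100` of the single-class configuration, training stops about 500 iterations after the last improvement.

```python
def early_stop_iteration(scores, eval_period, patience):
    """Return (stop_iteration, best_iteration) for a sequence of validation scores.

    scores: validation scores (higher is better), one per evaluation;
            evaluation k (1-based) is assumed to happen at iteration k * eval_period.
    stop_iteration is None if training would run to completion.
    """
    best_score = float("-inf")
    best_iter = None
    evals_without_improvement = 0
    for k, score in enumerate(scores, start=1):
        iteration = k * eval_period
        if score > best_score:
            best_score = score
            best_iter = iteration
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return iteration, best_iter
    return None, best_iter


# Hypothetical AP50 values that peak at the third evaluation
ap50 = [40.0, 52.0, 58.0, 57.5, 56.0, 57.9, 55.0, 54.0, 56.5]
print(early_stop_iteration(ap50, eval_period=100, patience=5))  # (800, 300)
```

A larger `patience` tolerates noisier validation curves at the cost of extra training iterations.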

## 4. Post-Training Analysis

Check that the model has converged and is not overfitting by plotting training and validation loss. Training loss that keeps falling while validation loss flattens or rises is a sign of overfitting.

In [None]:
import matplotlib.pyplot as plt
from detectree2.models.train import load_json_arr

experiment_folder = "./train_outputs"
experiment_metrics = load_json_arr(experiment_folder + '/metrics.json')

plt.plot(
    [x['iteration'] for x in experiment_metrics if 'validation_loss' in x],
    [x['validation_loss'] for x in experiment_metrics if 'validation_loss' in x],
    label='Total Validation Loss', color='red')
plt.plot(
    [x['iteration'] for x in experiment_metrics if 'total_loss' in x],
    [x['total_loss'] for x in experiment_metrics if 'total_loss' in x],
    label='Total Training Loss')

plt.legend(loc='upper right')
plt.title('Comparison of the training and validation loss of detectree2')
plt.ylabel('Total Loss')
plt.xlabel('Number of Iterations')
plt.show()
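Detectron2's `JSONWriter` writes `metrics.json` as one JSON object per line, which is why a helper such as `load_json_arr` is used instead of a single `json.load`. If detectree2 is not available, the minimal stand-in below parses such a file and extracts one metric's (iteration, value) pairs, skipping lines where that metric was not logged. The function names `read_metrics_lines` and `metric_series` are hypothetical.

```python
import json


def read_metrics_lines(path):
    """Parse a JSON-lines metrics file into a list of dicts (blank lines skipped)."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]


def metric_series(records, key):
    """Return (iterations, values) for the records that contain `key`."""
    pairs = [(r["iteration"], r[key]) for r in records if key in r]
    iterations = [it for it, _ in pairs]
    values = [v for _, v in pairs]
    return iterations, values


# Usage with the plotting code above:
# records = read_metrics_lines("./train_outputs/metrics.json")
# its, val_loss = metric_series(records, "validation_loss")
# plt.plot(its, val_loss, label="Total Validation Loss", color="red")
```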

### AP50 Score Over Training

In [None]:
plt.plot(
    [x['iteration'] for x in experiment_metrics if 'bbox/AP50' in x],
    [x['bbox/AP50'] for x in experiment_metrics if 'bbox/AP50' in x],
    label='Validation AP50')

plt.legend(loc='lower right')
plt.title('Validation AP50 over training iterations')
plt.ylabel('AP50')
plt.xlabel('Number of Iterations')
plt.show()
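Rather than reading the best score off the plot, you can locate it directly in `experiment_metrics`. The hypothetical helper below returns the iteration with the highest value for a given key such as `'bbox/AP50'`. If your evaluation also logs mask metrics, a key such as `'segm/AP50'` may be more relevant for crown delineation. The example records are made up but follow the same shape as the entries read from `metrics.json`.

```python
def best_iteration(records, key="bbox/AP50"):
    """Return (iteration, value) of the record with the largest `key` value."""
    scored = [r for r in records if key in r]
    if not scored:
        raise KeyError(f"No records contain {key!r}")
    best = max(scored, key=lambda r: r[key])
    return best["iteration"], best[key]


# Hypothetical records in the same shape as experiment_metrics
records = [
    {"iteration": 99, "total_loss": 1.3},
    {"iteration": 99, "bbox/AP50": 41.2},
    {"iteration": 199, "bbox/AP50": 47.8},
    {"iteration": 299, "bbox/AP50": 45.1},
]
print(best_iteration(records))  # (199, 47.8)
```

On real output, call `best_iteration(experiment_metrics)` to find which iteration to compare against any saved checkpoints.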

## Performance Metrics Explained

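- **IoU (Intersection over Union)**: Measures the overlap between a predicted mask and a ground-truth mask, calculated as the area of their intersection divided by the area of their union (0 means no overlap, 1 a perfect match).
- **AP50**: Average Precision at an IoU threshold of 0.5. A prediction counts as a true positive if its IoU with a ground-truth object is ≥ 0.5, and each ground-truth object can be matched only once. AP summarises precision across recall levels and can be computed on boxes (`bbox/AP50`) or masks (`segm/AP50`).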
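The two definitions can be made concrete with boolean masks. The sketch below computes mask IoU with NumPy, then counts true positives at the AP50 threshold by greedily matching predictions (highest score first) to still-unmatched ground-truth masks. This mirrors the idea behind COCO-style matching but is a simplified illustration, not the evaluator Detectron2 uses, and it stops at precision and recall rather than computing the full AP. The masks and scores are hypothetical.

```python
import numpy as np


def mask_iou(a, b):
    """IoU of two boolean masks: |A and B| / |A or B| (0.0 if both are empty)."""
    a = a.astype(bool)
    b = b.astype(bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 0.0
    return float(np.logical_and(a, b).sum() / union)


def match_at_threshold(preds, scores, gts, iou_thr=0.5):
    """Greedy one-to-one matching; returns (true_pos, false_pos, false_neg)."""
    order = np.argsort(scores)[::-1]  # highest-confidence predictions first
    matched_gt = set()
    tp = fp = 0
    for i in order:
        # Find the best still-unmatched ground truth for this prediction
        best_j, best_iou = None, 0.0
        for j, gt in enumerate(gts):
            if j in matched_gt:
                continue
            iou = mask_iou(preds[i], gt)
            if iou > best_iou:
                best_j, best_iou = j, iou
        if best_j is not None and best_iou >= iou_thr:
            matched_gt.add(best_j)
            tp += 1
        else:
            fp += 1
    fn = len(gts) - len(matched_gt)
    return tp, fp, fn


def square(r0, r1, c0, c1, shape=(10, 10)):
    """Boolean mask with a filled rectangle, for the toy example."""
    m = np.zeros(shape, dtype=bool)
    m[r0:r1, c0:c1] = True
    return m


# Hypothetical 10x10 example: two ground-truth crowns, two predictions
gts = [square(0, 4, 0, 4), square(6, 10, 6, 10)]
preds = [square(0, 4, 1, 5), square(9, 10, 6, 10)]
scores = [0.9, 0.8]

print(mask_iou(preds[0], gts[0]))  # 0.6  (12 shared pixels / 20 in the union)
tp, fp, fn = match_at_threshold(preds, scores, gts)
print(tp, fp, fn)  # 1 1 1
print("precision", tp / (tp + fp), "recall", tp / (tp + fn))  # precision 0.5 recall 0.5
```

The second prediction overlaps its crown with an IoU of only 0.25, so at the AP50 threshold it is a false positive and that crown remains a false negative.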