# Matching algorithm
In this notebook, we present the OneForest matching algorithm and use it to match tree detections to field data measurements

In [1]:
import os

import numpy as np
import pandas as pd
import ot

from src import (
    MatchedFieldData,
    ImagesRegistry,
    FieldDataRegistry,
    OrthomosaicGps,
    SiteShape,
    DeepForestDetectionRegistry,
    calculate_ot_map,
    match_detections_to_field_data
)

DEEPFOREST_DETECTIONS_DIR = "../data/processed/predicted_bbox/"
FIELD_DATA_FILEPATH = "../data/raw/field_data.csv"
ANNOTATIONS_FILEPATH = "../data/raw/annotations/all_annotations.csv"
GPS_DATA = "../data/raw/ortho_data.csv"
SITE_SHAPEFILE = "../data/raw/Merged_final_plots/Merged_final_plots.shp"
OUTPUT_DIR: str = "../data/processed/mappings"

Load data and metadata registries

In [2]:
field_data_registry = FieldDataRegistry(FIELD_DATA_FILEPATH)
deepforest_detection_registry = DeepForestDetectionRegistry(DEEPFOREST_DETECTIONS_DIR)  # To use hand annotated bounding boxes, use src.tree_detection.HandAnnotatedDetectionRegistry instead
orthomosaic_names = ImagesRegistry(ANNOTATIONS_FILEPATH).get_orthomosaic_names()
ortho_gps = OrthomosaicGps(GPS_DATA)
site_shape = SiteShape(SITE_SHAPEFILE)

## Define otimal transport algorithm

In [3]:
ot_func = ot.bregman.sinkhorn

In [4]:
def load_field_data_and_deepforest_detections(orthomosaic_name):
    field_data = field_data_registry.get_field_data_for_image(orthomosaic_name)
    deepforest_detections = deepforest_detection_registry.load_detections_for_image(orthomosaic_name)
    return field_data, deepforest_detections


In [6]:
def filter_out_of_site_detections(
    field_data, deepforest_detections_unfiltered,
    ortho_gps, orthomosaic_name,
    dilation_step = 0.0001
):
    ortho_gps.set_orthomosaic_name(orthomosaic_name)
    num_field_data = len(field_data)
    num_detections, dilation = 0, 0
    while num_detections < num_field_data:
        filtered_detections = []
        site_shape.set_dilation(dilation)
        for detection in deepforest_detections_unfiltered:
            detection_x = (detection.xmax + detection.xmin) / 2
            detection_y = (detection.ymax + detection.ymin) / 2
            detection_lon, detection_lat = ortho_gps.calculate_x_y_to_lon_lat(detection_x, detection_y, True)
            if site_shape.is_in_site(orthomosaic_name, detection_lon, detection_lat):
                filtered_detections.append(detection)
        num_detections = len(filtered_detections)
        dilation += dilation_step
        print(f"Dilation = {dilation}")
        print("====================")
        print(f"Number of detections = {num_detections}")
        print(f"Number of field data = {num_field_data}")
    return filtered_detections

In [7]:
def match_field_data_and_detections(detections, field_data, mu, greedy):
    matched_detections = []
    num_detections = len(detections)
    num_field_data = len(field_data)
    # Extract tree detection data for OT
    detection_coord = np.zeros((num_detections, 2))
    detection_proba = np.zeros((num_detections, 1))
    for i, detection in enumerate(detections):
        detection_x = (detection.xmax + detection.xmin) / 2
        detection_y = (detection.ymax + detection.ymin) / 2
        detection_coord[i, 0], detection_coord[i, 1] = detection_x, detection_y
        detection_proba[i] = detection.score
    # Extract field data for OT
    field_data_coord = np.zeros((num_field_data, 2))
    field_data_proba = np.ones((num_field_data, 1))
    for i, row in enumerate(field_data):
        lon, lat = row.lon, row.lat
        field_data_x, field_data_y = ortho_gps.calculate_lon_lat_to_x_y(lon, lat, True)
        field_data_coord[i, 0], field_data_coord[i, 1] = field_data_x, field_data_y
    # Optimal transport
    ot_plan = calculate_ot_map(ot_func, detection_coord, field_data_coord, detection_proba, field_data_proba, mu)
    matched_detections.extend(match_detections_to_field_data(
        detections,
        field_data, ot_plan, greedy
    ))
    return matched_detections

## Match detections to field measurements

In [9]:
use_out_of_site_filter = True
mu = 0.5
greedy = True

matched_detections = []
for orthomosaic_name in orthomosaic_names:
    print(f"Matching on: {orthomosaic_name}")
    cur_field_data, cur_detections = load_field_data_and_deepforest_detections(orthomosaic_name)
    if use_out_of_site_filter:
        cur_detections = filter_out_of_site_detections(cur_field_data, cur_detections, ortho_gps, orthomosaic_name)
    matched_detections.extend(match_field_data_and_detections(cur_detections, cur_field_data, mu, greedy))


Matching on: Nestor Macias RGB
Dilation = 0.0001
Number of detections = 394
Number of field data = 872
Dilation = 0.0002
Number of detections = 675
Number of field data = 872
Dilation = 0.00030000000000000003
Number of detections = 982
Number of field data = 872
Number of detection: 872
Number of field data measurements: 872
Number of unique matches: 872
Matching on: Leonor Aspiazu RGB
Dilation = 0.0001
Number of detections = 322
Number of field data = 789
Dilation = 0.0002
Number of detections = 544
Number of field data = 789
Dilation = 0.00030000000000000003
Number of detections = 798
Number of field data = 789
Number of detection: 789
Number of field data measurements: 789
Number of unique matches: 789
Matching on: Carlos Vera Arteaga RGB
Dilation = 0.0001
Number of detections = 250
Number of field data = 743
Dilation = 0.0002
Number of detections = 395
Number of field data = 743
Dilation = 0.00030000000000000003
Number of detections = 607
Number of field data = 743
Dilation = 0.000

Save matches

In [12]:
matched_detections_df = pd.DataFrame.from_dict([MatchedFieldData.to_dict(d) for d in matched_detections])
csv_filename = "reproduced_final_matching.csv"
csv_filepath = os.path.join(OUTPUT_DIR, csv_filename)
matched_detections_df.to_csv(csv_filepath)