# Make cutouts from a UAV image

Data sources:

- tree_labels_merged.gpkg: [link](https://zenodo.org/records/13828591/files/tree_labels.zip?download=1)
- Map1_Orthomosaic_export_SatJun10172428194829.tif: [link](https://zenodo.org/records/13828591/files/uav_img.zip?download=1)

In [1]:
import xarray as xr
import math
from pathlib import Path
import geopandas as gpd
from matplotlib import pyplot as plt
from shapely.geometry import box
import pandas as pd
import numpy as np

from geocoded_object_extractor import ObjectExtractor
from geocoded_object_extractor.utils import write_cutouts

In [2]:
# Input files, relative to this notebook (download links are listed above).
tree_labels_path, image_path = (
    Path('../data/tree_labels/tree_labels_merged.gpkg'),
    Path('../data/uav_img/Map1_Orthomosaic_export_SatJun10172428194829.tif'),
)

In [None]:
# Load the labelled tree geometries into a GeoDataFrame and display it.
tree_labels = gpd.read_file(filename=tree_labels_path)
tree_labels

In [None]:
# Lookup table: tree ID -> {'ESPECIE': ..., 'TIPO': ...}, built from the
# distinct (ID, ESPECIE, TIPO) combinations in the labels.
# NOTE: to_dict(orient='index') raises ValueError if an ID still occurs more
# than once after de-duplication, i.e. if one ID carries several
# ESPECIE/TIPO combinations.
id_species_mapping = (
    tree_labels[['ID', 'ESPECIE', 'TIPO']]
    .drop_duplicates()
    .set_index('ID')
    .to_dict(orient='index')
)
id_species_mapping


In [None]:
# Extract one cutout per labelled geometry from the orthomosaic.
# Labels are the raw tree IDs (encode_labels=False), so they can be looked up
# in id_species_mapping later.
# NOTE(review): pixel_size / max_pixel_size = 256 presumably set the cutout
# edge length in pixels -- confirm against geocoded_object_extractor.
geoms, labels = tree_labels.geometry, tree_labels['ID']

obj_extr = ObjectExtractor(
    images=[image_path],
    geoms=geoms,
    labels=labels,
    pixel_size=256,
    max_pixel_size=256,
    encode_labels=False,
)

# get_cutouts() returns the labels belonging to the produced cutouts, so
# `labels` is rebound here to stay aligned with `cutouts`.
labels, transform_params, crs, cutouts = obj_extr.get_cutouts()

In [8]:
# Distinct labels (tree IDs) among the cutouts. `labels` comes from
# get_cutouts(); it presumably is a pandas Series (it has .unique()), in which
# case the order is first appearance -- confirm.
unique_labels = labels.unique()

In [None]:
# Visualize one sample per unique label from the cutouts, five samples per
# row, each titled with its species name (ESPECIE).
num_unique_labels = len(unique_labels)
num_samples_per_row = 5
# At least one row, so plt.subplots never receives nrows=0.
num_samples_per_col = max(1, math.ceil(num_unique_labels / num_samples_per_row))
# squeeze=False keeps `axs` 2-D even when every label fits into one row; with
# the default squeeze=True a single row comes back 1-D and axs[row, col] fails.
fig, axs = plt.subplots(
    num_samples_per_col,
    num_samples_per_row,
    figsize=(20, 4 * num_samples_per_col),
    squeeze=False,
)
for i, label in enumerate(unique_labels):
    row, col = divmod(i, num_samples_per_row)
    # First cutout carrying this label.
    sample = cutouts[labels == label][0]
    axs[row, col].imshow(sample)
    axs[row, col].set_title(f'{id_species_mapping[label]["ESPECIE"]}')
    axs[row, col].axis('off')
# Hide the unused panels at the end of the last row.
for ax in axs.flat[num_unique_labels:]:
    ax.axis('off')

In [None]:
# Pack the cutouts and their labels into one xarray Dataset; the
# ID -> {ESPECIE, TIPO} lookup table is stored as dataset attributes.
data_vars = {
    'X': (['sample', 'x', 'y', 'channel'], cutouts),
    'Y': (['sample'], labels),
}
ds = xr.Dataset(data_vars=data_vars, attrs=id_species_mapping)
# Keep only channels 0-2 (presumably RGB; any further band is dropped).
ds = ds.isel(channel=[0, 1, 2])
ds

In [None]:
# Write the Dataset to a Zarr store; mode='w' overwrites any existing store.
ds_path = Path('.', 'Tree_labels_merged', 'tree_labels_merged.zarr')
ds.to_zarr(store=ds_path, mode='w')

## Statistics of the selected cutouts

In [None]:
# Reload the Dataset from the Zarr store so the statistics below run on the
# data exactly as it was written to disk.
ds = xr.open_zarr(store=ds_path)
ds

In [13]:
# Loop through the sample dimension, crop the zero padding from every cutout
# and record the real (cropped) width and height of each one.
def _crop_zero_padding(img):
    """Return ``img`` cropped to the bounding box of its non-zero pixels.

    ``img`` is a (row, column, channel) array. A row/column counts as padding
    when all of its values, across every channel, are exactly 0. Only the
    padding border is removed; all-zero rows/columns inside the object are
    kept. An all-zero image becomes an empty (0, 0, channel) array.
    """
    # (row, column) mask: True where at least one channel is non-zero.
    has_data = (img != 0).any(axis=2)
    # Reduce over columns for the row mask and over rows for the column mask;
    # mixing these axes up still "works" on square cutouts but drops the
    # wrong rows.
    rows = np.flatnonzero(has_data.any(axis=1))
    cols = np.flatnonzero(has_data.any(axis=0))
    if rows.size == 0:
        return img[:0, :0]
    return img[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]


imgs = []
widths = []
heights = []
for i in range(ds['X'].sizes['sample']):
    img = _crop_zero_padding(ds['X'].isel(sample=i).values)
    imgs.append(img)
    widths.append(img.shape[1])
    heights.append(img.shape[0])

In [None]:
# Per-species histograms of the cropped cutout sizes, grouped by ds['Y']:
# width in the left column, height in the right column.
widths = np.array(widths)
heights = np.array(heights)
labels = ds['Y'].values
unique_labels = np.unique(labels)
num_unique_labels = len(unique_labels)
# Common bin edges over 0-256 px (256 = pixel_size used for the cutouts), so
# every panel uses the same bins and panels can be compared directly.
size_bins = np.linspace(0, 256, 21)
# squeeze=False keeps `axs` 2-D even for a single species. sharey=True puts
# all panels on one count scale instead of a fixed y-limit, which would clip
# any bin holding more cutouts than the limit.
fig, axs = plt.subplots(
    num_unique_labels,
    2,
    figsize=(10, max(4, 2 * num_unique_labels)),
    squeeze=False,
    sharey=True,
)
for i, label in enumerate(unique_labels):
    mask = labels == label
    title = f'{id_species_mapping[label]["ESPECIE"]}'
    panels = (
        (axs[i, 0], widths, 'width'),
        (axs[i, 1], heights, 'height'),
    )
    for ax, values, xlabel in panels:
        ax.hist(values[mask], bins=size_bins)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('count')
# Same x range for every panel.
for ax in axs.flat:
    ax.set_xlim([0, 256])
fig.tight_layout()



In [None]:
# Bar chart of the number of cutouts per species, most frequent first, with
# the exact count written above each bar.
species_counts = ds['Y'].to_pandas().value_counts()
# Relabel the index from tree ID to species name (ESPECIE).
species_counts.index = species_counts.index.map(
    lambda tree_id: id_species_mapping[tree_id]['ESPECIE']
)
species_counts = species_counts.sort_values(ascending=False)
ax = species_counts.plot.bar(figsize=(15, 10))
for position, (_, count) in enumerate(species_counts.items()):
    ax.text(position, count, count, ha='center', va='bottom')

In [None]:
# Number of cutouts per species (ESPECIE): one bar per species, in
# alphabetical order, with the species name centred under its bar.
species = [id_species_mapping[s]['ESPECIE'] for s in ds['Y'].values]
# Count explicitly rather than using plt.hist on strings: hist() spreads
# len(unique) equal-width bins over the category positions, so bars are not
# centred on the categories, and ticks moved to -0.5, 0.5, ... get labelled
# with shifted/repeated names by the category formatter.
species_names, species_totals = np.unique(species, return_counts=True)
positions = np.arange(len(species_names))
fig, ax = plt.subplots()
ax.bar(positions, species_totals)
ax.set_xticks(positions)
ax.set_xticklabels(species_names, rotation=90)
ax.set_ylabel('count')