In [None]:
import geopandas as gpd
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from os import listdir
from os.path import isfile, join
from shapely.geometry import Polygon
from geocoded_object_extractor import ObjectExtractor
from geocoded_object_extractor.utils import hash_classname
import pandas as pd
from pathlib import Path

# Suppress rioxarray warnings.
# NOTE(review): this filter is global -- it silences every warning from every
# library for the rest of the session, not only those raised by rioxarray.
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Configuration: dataset root, region names, class names and annotations.
data_dir = Path('/home/oku/Developments/XAI4GEO/data')

# One entry per region; the extraction loop looks for the tiles of each
# region under "reforestree/tiles/<region> RGB/".
regions = [
    'Carlos Vera Guevara',
    'Carlos Vera Arteaga',
    'Flora Pluas',
    'Leonor Aspiazu',
    'Manuel Macias',
    'Nestor Macias',
]

# Class names of the dataset.
# NOTE(review): this name is the same as the stdlib module `dataclasses`
# and is not referenced in the visible cells -- confirm it is still needed.
dataclasses = [
    "other",
    "banana",
    "cacao",
    "citrus",
    "fruit",
    "timber",
]

# Annotation table, loaded from the mapping CSV below the dataset root.
classes_path = data_dir / 'reforestree/mapping/final_dataset.csv'
annot = pd.read_csv(classes_path)

In [None]:
# Loop over regions
# For each region, extract image cutouts and save them to a separate zarr
# store named "<region>.zarr" (relative to the current working directory).

# Optional cap on the number of tiles processed per region. Leave as None for
# a full run; set it to a small integer (e.g. 3) for a quick debug run.
max_tiles_per_region = None

for region in regions:
    print(region)
    root = data_dir / f"reforestree/tiles/{region} RGB/"

    # rglob gives no ordering guarantee, so sort to make the tile order (and
    # any debug subset) reproducible. The directory is scanned only once.
    tile_paths = sorted(root.rglob("*.png"))
    if max_tiles_per_region is not None:
        tile_paths = tile_paths[:max_tiles_per_region]
    rgb_filenames = [p.as_posix() for p in tile_paths]

    # Keep only the annotation rows whose img_path matches a tile that is
    # actually passed to the extractor, so images and geometries stay in sync
    # even when the tile list has been truncated.
    tile_names = {p.name for p in tile_paths}
    annot_selected = annot[annot["img_path"].isin(tile_names)]

    # If annot_selected is empty, skip the region
    if annot_selected.empty:
        print(f"Skipping {region}")
        continue

    # Rename the class column and attach the hashed class name as "ID".
    # `.assign` returns a new frame instead of writing into a filtered one.
    annot_selected = annot_selected.rename(columns={"group": "ESPECIE"})
    annot_selected = annot_selected.assign(
        ID=annot_selected["ESPECIE"].apply(hash_classname)
    )

    # Build one rectangular polygon per annotation row from its bounding box.
    geoms = [
        Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
        for xmin, xmax, ymin, ymax in zip(
            annot_selected["xmin"],
            annot_selected["xmax"],
            annot_selected["ymin"],
            annot_selected["ymax"],
        )
    ]

    # Extract the cutouts with a 768 px target size and a 384 px minimum size.
    # NOTE(review): the earlier comments quoted 128x128 (padded) / 64x64
    # (unpadded) and a 6x6 downsampling; 768/6 = 128 and 384/6 = 64, but no
    # downsampling happens in this cell -- confirm where it is applied.
    obj_extr = ObjectExtractor(
        images=rgb_filenames,
        geoms=geoms,
        labels=annot_selected["ESPECIE"],
        pixel_size=768,
        min_pixel_size=384,
        max_pixel_size=768,
        encode_labels=True,
    )

    # extract the cutouts
    labels, transform_params, crs, cutouts = obj_extr.get_cutouts()

    # Store the ID -> species lookup as dataset attributes so the encoded
    # labels can be decoded later.
    id_species_mapping = (
        annot_selected[['ESPECIE', 'ID']]
        .drop_duplicates()
        .set_index('ID')
        .to_dict(orient='index')
    )
    ds = xr.Dataset(
        data_vars={
            'X': (['sample', 'x', 'y', 'channel'], cutouts),
            'Y': (['sample'], labels),
        },
        attrs=id_species_mapping,
    )

    ds = ds.chunk({'sample': 10, 'x': 768, 'y': 768, 'channel': 3})
    # Keep only the first three channels.
    ds = ds.isel(channel=range(3))
    ds['Y'] = ds['Y'].astype(int)

    # mode="w" overwrites an existing store; the default ("w-") raises when
    # the store already exists, which would break a re-run of this cell.
    ds.to_zarr(f"{region}.zarr", mode="w")

## Merge Zarr files from all regions

In [None]:
# Directory holding the per-region zarr stores to merge.
# A dedicated name is used instead of rebinding `data_dir`, so the extraction
# cell above still resolves its paths correctly if it is re-run after this one.
extracted_dir = Path('/home/oku/Developments/XAI4GEO/data/reforestree/processed/larger_than_384/extracted_files')

# rglob gives no ordering guarantee; sort so the merge order is deterministic.
zarr_files = sorted(extracted_dir.rglob("*.zarr"))
zarr_files

In [None]:
# Open every per-region store and merge them along the 'sample' dimension.
if not zarr_files:
    # Fail here with a clear message instead of leaving `ds` as None and
    # crashing in a later cell.
    raise FileNotFoundError("No .zarr stores found to merge")

region_datasets = [xr.open_zarr(file) for file in zarr_files]

# A single concat call replaces the pairwise concatenation in a loop.
# combine_attrs="drop_conflicts" keeps the union of the per-region attributes
# (dropping only keys whose values disagree); the default "override" would
# keep the attributes of the first store only.
ds = xr.concat(region_datasets, dim='sample', combine_attrs='drop_conflicts')
ds

In [None]:
# Re-chunk the merged dataset, then write it to a single zarr store,
# overwriting any existing store at that path (mode='w').
merged_chunks = {'sample': 10, 'x': 768, 'y': 768, 'channel': 3}
merged_store = '/home/oku/Developments/XAI4GEO/data/reforestree/processed/larger_than_384/foresttree_largerthan_384.zarr'

ds = ds.chunk(merged_chunks)
ds.to_zarr(merged_store, mode='w')

In [None]:
# Randomly select up to 50 samples and plot them.
# Capping at the number of available samples avoids the ValueError that
# `choice(..., replace=False)` raises when fewer than 50 samples exist.
n_plot = min(50, ds.sizes['sample'])

# Fixed seed so the same subset is drawn on every run.
rng = np.random.default_rng(42)
ds_plot = ds.isel(sample=rng.choice(ds.sizes['sample'], n_plot, replace=False))
ds_plot['X'].plot.imshow(col='sample', col_wrap=5)