In [None]:
import xarray as xr
import numpy as np
import rioxarray
from pathlib import Path
from matplotlib import pyplot as plt
from geocoded_object_extractor.utils import hash_classname

In [None]:
# Map every integer class code to the name of the zarr store of that class.
# The uid embedded in each name was generated with the hash_classname function;
# the values are frozen here since the hash does not provide a random seed.
description_labels = dict(enumerate([
    'label142377591163_murumuru.zarr',
    'label244751236943_tucuma.zarr',
    'label174675723264_banana.zarr',
    'label999240878592_cacao.zarr',
    'label370414265344_fruit.zarr',
]))
# Map every integer class code to the uid of its label
description_values = dict(enumerate([
    142377591163,
    244751236943,
    174675723264,
    999240878592,
    370414265344,
]))

## netflora cutouts

In [None]:
# Root directory of the netflora cutouts, relative to the working directory.
# The cells below read *.jpg files from its 'label_0' and 'label_6' sub-directories.
dir_netflora = Path('./output_cutouts_netflora')

In [None]:
def get_padsize(taget_size, img_size_x, img_size_y):
    """Compute the padding that centers an image inside a square canvas.

    Parameters
    ----------
    taget_size : int
        Side length of the square canvas.
    img_size_x : int
        Size of the image along x.
    img_size_y : int
        Size of the image along y.

    Returns
    -------
    tuple
        ``(xleft, xright, ytop, ybottom)``. For each axis the two values add
        up to ``taget_size - img_size``; when that gap is odd, the second
        value (right / bottom) is the larger one. The values are negative
        when the image is larger than the canvas.
    """
    pads = []
    for img_size in (img_size_x, img_size_y):
        before = (taget_size - img_size) // 2
        after = taget_size - before - img_size
        pads.extend((before, after))

    return tuple(pads)

In [None]:
# Load the images of each class as xarray objects, zero-pad them to a
# target_size x target_size square, then save each class as one zarr store.
target_size = 128

for label_dir, label_idx in zip(['label_0', 'label_6'], range(2)):
    dir_class = dir_netflora / label_dir
    list_imgs = list(dir_class.glob('*.jpg'))

    # Pad every image individually and collect the results, so that all
    # images (including the first one) go through exactly the same steps.
    padded_imgs = []
    for f_img in list_imgs:
        img = rioxarray.open_rasterio(f_img)

        # An image larger than the target size in either direction cannot be
        # padded (the pad widths would be negative), so skip it.
        if img.sizes['x'] > target_size or img.sizes['y'] > target_size:
            continue

        xleft, xright, ytop, ybottom = get_padsize(target_size, img.sizes['x'], img.sizes['y'])
        img_pad = img.pad(pad_width={'x': (xleft, xright), 'y': (ytop, ybottom)}, mode='constant', constant_values=0)
        img_pad = img_pad.expand_dims('sample', axis=0)
        # drop_indexes returns a new object, so the result must be assigned.
        # Presumably the indexes are dropped so that the concatenation below
        # does not try to align images on their differing coordinates.
        img_pad = img_pad.drop_indexes(['band', 'x', 'y'])
        padded_imgs.append(img_pad)

    if not padded_imgs:
        raise ValueError(f'No *.jpg cutouts of at most {target_size} pixels found in {dir_class}')

    # Concatenate once instead of growing the array image by image
    imgs = xr.concat(padded_imgs, dim='sample')

    # One uid per sample that was actually kept (skipped images do not count).
    # int64 is required: the uids do not fit in 32 bits.
    uids = np.full(imgs.sizes['sample'], description_values[label_idx], dtype=np.int64)

    # make a dataset with both images and labels
    ds = xr.Dataset({'X': imgs, 'Y': xr.DataArray(uids, dims='sample', name='label')})
    ds = ds.drop_vars(['band', 'x', 'y', 'spatial_ref'])
    ds = ds.transpose('sample', 'x', 'y', 'band')
    ds = ds.rename({'band': 'channel'})
    ds = ds.chunk('auto')

    ds.to_zarr(f'./all_cutouts/{description_labels[label_idx]}', mode='w')

## reforestree cutouts

In [None]:
# Open the reforestree zarr store; judging by its path this is the subset of
# cutouts larger than 384x384 pixels — TODO confirm.
# NOTE(review): this is an absolute, machine-specific path; adjust it when
# running the notebook elsewhere.
f_reforestree =  Path('/home/oku/Developments/XAI4GEO/data/reforestree/processed/larger_than_384/foresttree_largerthan_384.zarr')
data_reforestree = xr.open_zarr(f_reforestree)
# Display the dataset
data_reforestree

In [None]:
# Downsample the reforestree data to the same resolution as the UAV data:
# average every 4x4 block of pixels along x and y. With boundary='trim' the
# excess pixels that do not fill a complete block are dropped.
data_reforestree_downsampled = data_reforestree.coarsen(x=4, y=4, boundary='trim').mean()
# Display the downsampled dataset
data_reforestree_downsampled

In [None]:
# Zero-pad the downsampled cutouts to target_size x target_size pixels
target_size = 128
# Compute the pad widths per axis with get_padsize: it handles x and y
# separately and puts the extra pixel of an odd gap on the right/bottom, so
# the padded size is exactly target_size even for odd gaps or non-square cutouts.
xleft, xright, ytop, ybottom = get_padsize(
    target_size,
    data_reforestree_downsampled.sizes['x'],
    data_reforestree_downsampled.sizes['y'],
)
# Kept for backward compatibility: same value as the former (target_size - size_x)//2
pad_size = xleft
data_reforestree_downsampled_padded = data_reforestree_downsampled.pad(
    mode='constant',
    pad_width={'x': (xleft, xright), 'y': (ytop, ybottom)},
    constant_values=0,
)
# Display the padded dataset
data_reforestree_downsampled_padded

In [None]:
# Pull the label variable Y out of the downsampled dataset as a NumPy array
labels = data_reforestree_downsampled["Y"].values

In [None]:
# Show which distinct label values are present
np.unique(labels)

In [None]:
# Bananas: keep only the samples whose label Y equals 1
is_banana = data_reforestree_downsampled_padded["Y"].compute() == 1
ds_lebel1 = data_reforestree_downsampled_padded.where(is_banana, drop=True)
ds_lebel1

In [None]:
# Cacao: keep only the samples whose label Y equals 2
is_cacao = data_reforestree_downsampled_padded["Y"].compute() == 2
ds_lebel2 = data_reforestree_downsampled_padded.where(is_cacao, drop=True)
ds_lebel2

In [None]:
# Fruit: keep only the samples whose label Y equals 4
is_fruit = data_reforestree_downsampled_padded["Y"].compute() == 4
ds_lebel4 = data_reforestree_downsampled_padded.where(is_fruit, drop=True)
ds_lebel4

In [None]:
# Replace the class codes in Y by the frozen uid of each class and save every
# class (banana, cacao, fruit -> class codes 2, 3, 4) as its own zarr store.
for ds_label, label_idx in zip([ds_lebel1, ds_lebel2, ds_lebel4], range(2, 5)):
    # One uid per sample. Multiplying the uid by the number of samples would
    # give a single scalar, which cannot be used along the 'sample' dimension.
    # int64 is required: the uids do not fit in 32 bits.
    uids = np.full(ds_label.sizes["sample"], description_values[label_idx], dtype=np.int64)
    # assign returns a new dataset, leaving the per-class datasets untouched
    ds = ds_label.assign(Y=xr.DataArray(uids, dims='sample'))

    ds = ds.chunk({'sample': 50, 'x': 128, 'y': 128, 'channel': 3})
    # NOTE(review): the netflora classes are written to './all_cutouts/', these
    # stores go to the working directory — confirm this is intended.
    ds.to_zarr(f'./{description_labels[label_idx]}', mode='w')

In [None]:
# Randomly select up to 50 banana samples and plot them.
# The generator is seeded so that re-running the notebook shows the same samples.
rng = np.random.default_rng(seed=42)
# Never ask for more samples than available: choice(..., replace=False) would raise
n_plot = min(50, ds_lebel1.sizes['sample'])
ds_plot = ds_lebel1.isel(sample=rng.choice(ds_lebel1.sizes['sample'], n_plot, replace=False))
ds_plot['X'].plot.imshow(col='sample', col_wrap=5)

In [None]:
# Randomly select up to 50 cacao samples and plot them.
# The generator is seeded so that re-running the notebook shows the same samples.
rng = np.random.default_rng(seed=42)
# Never ask for more samples than available: choice(..., replace=False) would raise
n_plot = min(50, ds_lebel2.sizes['sample'])
ds_plot = ds_lebel2.isel(sample=rng.choice(ds_lebel2.sizes['sample'], n_plot, replace=False))
ds_plot['X'].plot.imshow(col='sample', col_wrap=5)

In [None]:
# Randomly select up to 50 fruit samples and plot them.
# The generator is seeded so that re-running the notebook shows the same samples.
rng = np.random.default_rng(seed=42)
# Never ask for more samples than available: choice(..., replace=False) would raise
n_plot = min(50, ds_lebel4.sizes['sample'])
ds_plot = ds_lebel4.isel(sample=rng.choice(ds_lebel4.sizes['sample'], n_plot, replace=False))
ds_plot['X'].plot.imshow(col='sample', col_wrap=5)

## Label data cutouts

Loop through all unique species in the `ds` dataset and select the cutouts for each species:
1. Keep only the species that have more than 5 samples (cutouts) in `ds`.
2. Select the center 128x128 pixels (originally it was 256x256)
3. Measure the size of the part of the cutout that is not zero padding
4. If BOTH sides are less than 64 pixels, interpolate the non-zero part to 64x64 pixels
5. Pad the interpolated image to 128x128 pixels with zeros
6. Save the cutout of each species to a separate zarr file

In [None]:
# Open the merged tree-label dataset.
# NOTE(review): absolute, machine-specific path; adjust it when running the
# notebook elsewhere.
data_path = Path('/home/oku/Developments/XAI4GEO/data/brazil_data')
ds_path = data_path / 'Tree_labels_merged' / 'tree_labels_merged.zarr'
ds = xr.open_zarr(ds_path)
ds = ds.compute() # Load data into memory since it's small
# Display the dataset
ds

In [None]:
# Loop through all unique species in the `ds` dataset and process the cutouts of each species:
# 1. Skip the species unless it has more than 5 samples.
# 2. Select the center 128x128 pixels (originally it was 256x256)
# 3. Measure the size of the part of the cutout that is not zero padding
# 4. If BOTH sides are less than 64 pixels, interpolate the non-zero part to 64x64 pixels
# 5. Pad the interpolated image to 128x128 pixels with zeros
# 6. Save the cutouts of each species to a separate zarr file

unique_species_ids = np.unique(ds['Y'].values)
unique_species_names = [ds.attrs[str(species_id)]['ESPECIE'] for species_id in unique_species_ids]

for species_id, species_name in zip(unique_species_ids, unique_species_names):
    ds_species = ds.where(ds['Y'] == species_id, drop=True)
    if ds_species.sizes['sample'] <= 5:
        continue

    # NOTE(review): label-based slices are inclusive in xarray. If x/y carry
    # integer coordinates this selects 129 pixels instead of 128 — confirm
    # that x/y are dimensions without coordinates in this dataset.
    ds_species = ds_species.sel(x=slice(64, 192), y=slice(64, 192))

    samples = []
    for i in range(ds_species.sizes['sample']):
        ds_i = ds_species.isel(sample=i)
        cutout = ds_i['X']

        # Get the size of the non zero part.
        # assumes X is laid out as (x, y, channel) — TODO confirm.
        # A position along x holds data if any pixel over y and channel is
        # non-zero, and vice versa for y; each axis needs its own mask.
        pixels = cutout.values
        has_data_x = ~(pixels == 0).all(axis=(1, 2))
        has_data_y = ~(pixels == 0).all(axis=(0, 2))
        x_size = int(has_data_x.sum())
        y_size = int(has_data_y.sum())
        x_half = x_size // 2
        y_half = y_size // 2

        # The half sizes must be positive: an empty selection cannot be interpolated
        if x_size < 64 and y_size < 64 and x_half > 0 and y_half > 0:
            # Select the non zero part in cutout; this takes a window around
            # the center pixel, i.e. it presumes the content is centered.
            cutout = cutout.isel(x=range(64 - x_half, 64 + x_half), y=range(64 - y_half, 64 + y_half))

            # Interpolate the non zero part to 64x64 pixels
            cutout = cutout.interp(x=np.linspace(cutout.x.min(), cutout.x.max(), 64), y=np.linspace(cutout.y.min(), cutout.y.max(), 64))

            # Reset the x and y coordinates so the 64x64 image sits in the center of a 128x128 grid
            cutout['x'] = range(32, 96)
            cutout['y'] = range(32, 96)

            # Pad the interpolated image to 128x128 pixels with zeros:
            # positions outside the 32..95 range are filled with 0
            cutout = cutout.interp(x=range(0, 128), y=range(0, 128), kwargs={'fill_value': 0})

            # update ds_i
            # NOTE(review): `cutout` now carries x/y coordinates 0..127; whether
            # they line up with the coordinates (if any) of `ds_i` depends on
            # the input data — confirm.
            ds_i['X'] = cutout

        samples.append(ds_i)

    # Concatenate once instead of growing the dataset sample by sample
    ds_output = xr.concat(samples, dim='sample')

    file_name = species_name.replace(' ', '_')
    ds_output.to_zarr(data_path / 'Tree_labels_merged' / f'{species_id}_{file_name}.zarr', mode='w')

In [None]:
# Collect all zarr stores in the Tree_labels_merged directory, sorted by path.
# NOTE(review): the '*.zarr' pattern also matches tree_labels_merged.zarr,
# which is opened from this directory earlier in the notebook — confirm that
# including it is intended.
data_path = Path('/home/oku/Developments/XAI4GEO/data/brazil_data/Tree_labels_merged')
zarr_file_list = sorted(data_path.glob('*.zarr'))
zarr_file_list

In [None]:
# Plot up to 5 samples per dataset
# make each dataset a row in subplot
n_cols = 5

# squeeze=False keeps `axs` two-dimensional even when there is only one zarr file
fig, axs = plt.subplots(len(zarr_file_list), n_cols, figsize=(15, 20), squeeze=False)
for i, zarr_file in enumerate(zarr_file_list):
    # Use a dedicated name so the dataset loaded into `ds` earlier is not overwritten
    ds_file = xr.open_zarr(zarr_file)
    # The species name is looked up in the attributes under the label of the first sample
    species_name = ds_file.attrs[str(ds_file['Y'].isel(sample=0).values.astype(int))]['ESPECIE']

    # Randomly pick the images to plot; the generator is re-created with the
    # same seed for every dataset so the selection is reproducible
    rng = np.random.default_rng(seed=42)
    indices = rng.permutation(ds_file.sizes['sample'])
    # A dataset may hold fewer samples than there are columns
    n_plot = min(n_cols, ds_file.sizes['sample'])
    for j in range(n_plot):
        # Divide by 255, presumably to scale 0-255 pixel values to 0-1 for imshow
        cutout = ds_file['X'].isel(sample=indices[j])/255.
        cutout.plot.imshow(ax=axs[i, j])
        axs[i, j].set_title(f'{species_name}')
    # Hide the axes that stay empty
    for j in range(n_plot, n_cols):
        axs[i, j].axis('off')

In [None]:
# Horizontal bar plot showing how many samples every dataset holds
ds_list = list(map(xr.open_zarr, zarr_file_list))
n_samples = []
labels = []
for opened in ds_list:
    n_samples.append(opened.sizes['sample'])
    # The species name is stored in the attributes under the label of the first sample
    first_label = opened['Y'].isel(sample=0).values.astype(int)
    labels.append(opened.attrs[str(first_label)]['ESPECIE'])

fig, ax = plt.subplots()
ax.barh(labels, n_samples)
# write the sample count just to the right of the end of each bar
for row, count in enumerate(n_samples):
    ax.text(count + 0.25, row, str(count), color='black', va='center')