In [1]:
import xarray as xr
import numpy as np
import rioxarray
from pathlib import Path
from matplotlib import pyplot as plt

In [None]:
# Map each integer class code to the descriptive name of that class.
# The names are interpolated into the output zarr file names further down.
# NOTE(review): 'emprapa00' in entries 1 and 2 looks like a typo of 'embrapa00' (cf. entry 0);
# it is left as-is because the string ends up in file names — confirm before renaming.
description_labels = {0: 'netflora_murumuru_embrapa00',
                      1: 'netflora_buriti_emprapa00', 
                      2: 'netflora_tucuma_emprapa00',
                      3: 'reforestree_banana',
                      4: 'reforestree_cacao',
                      5: 'reforestree_fruit',}

## Unify the netflora cutout sizes

In [None]:
# Root directory of the netflora cutouts; each class is read from its own sub-directory (e.g. 'label_0')
dir_netflora = Path('./output_cutouts')

In [None]:
def get_padsize(taget_size, img_size_x, img_size_y):
    """Compute the padding that centres an image in a ``taget_size`` square.

    Parameters
    ----------
    taget_size : int
        Edge length of the square the image should be padded to.
    img_size_x, img_size_y : int
        Current size of the image along x and y.

    Returns
    -------
    tuple of int
        ``(xleft, xright, ytop, ybottom)``. For each axis the two values sum
        to ``taget_size - img_size``; when that gap is odd the extra pixel
        goes to the second value (right / bottom).
    """
    def _split(gap):
        # floor division: the leading side gets the smaller half
        leading = gap // 2
        return leading, gap - leading

    pad_x = _split(taget_size - img_size_x)
    pad_y = _split(taget_size - img_size_y)
    return (*pad_x, *pad_y)

In [None]:
# Load the cutouts of each class as xarray, zero-pad them to a common square size,
# then save each class as one zarr file
target_size = 128


def _load_padded_cutout(f_img, size):
    """Open one cutout and zero-pad it to ``size`` x ``size`` pixels.

    Parameters
    ----------
    f_img : path-like
        Image file to open with ``rioxarray.open_rasterio``.
    size : int
        Edge length of the padded square.

    Returns
    -------
    xarray.DataArray or None
        The padded image with a leading ``sample`` dimension of length 1 and
        without band/x/y/spatial_ref coordinates, or ``None`` when the image
        is larger than ``size`` along x or y (such images are skipped).
    """
    img = rioxarray.open_rasterio(f_img)

    # the size test looks at x AND y of this very image
    if img.sizes['x'] > size or img.sizes['y'] > size:
        return None

    xleft, xright, ytop, ybottom = get_padsize(size, img.sizes['x'], img.sizes['y'])
    img_pad = img.pad(pad_width={'x': (xleft, xright), 'y': (ytop, ybottom)}, mode='constant', constant_values=0)
    img_pad = img_pad.expand_dims('sample', axis=0)

    # The pixel coordinates are meaningless after padding and differ per image;
    # removing them here keeps the concatenation a plain stack along 'sample'.
    return img_pad.drop_vars(['band', 'x', 'y', 'spatial_ref'], errors='ignore')


# for label_dir, label_idx in zip(['label_0', 'label_2', 'label_6'], range(3)):
for label_dir, label_idx in zip(['label_0'], range(1)):
    dir_class = dir_netflora / label_dir
    # sorted -> the sample order does not depend on the file-system listing order
    list_imgs = sorted(dir_class.glob('*.jpg'))
    if not list_imgs:
        raise FileNotFoundError(f'No *.jpg cutouts found in {dir_class}')

    # every image (including the first one) goes through the same check + padding
    padded_imgs = [
        img_pad
        for img_pad in (_load_padded_cutout(f_img, target_size) for f_img in list_imgs)
        if img_pad is not None
    ]
    if not padded_imgs:
        raise ValueError(f'All cutouts in {dir_class} are larger than {target_size} pixels')

    # one concatenation instead of growing the array image by image
    imgs = xr.concat(padded_imgs, dim='sample')

    # make a dataset with both images and labels;
    # one label per image that was actually kept (skipped images get no label)
    n_kept = imgs.sizes['sample']
    ds = xr.Dataset({'X': imgs, 'Y': xr.DataArray([label_idx] * n_kept, dims='sample', name='label')})
    ds = ds.transpose('sample', 'x', 'y', 'band')
    ds = ds.rename({'band': 'channel'})
    ds = ds.chunk('auto')

    ds.to_zarr(f'./label{label_idx}_{description_labels[label_idx]}.zarr', mode='w')

In [None]:
# Re-open the class-0 store using the same file-name pattern it was written with
# (f'./label{label_idx}_{description_labels[label_idx]}.zarr' with label_idx = 0).
# NOTE(review): the variable is called ds_label1 although it holds class 0; name kept for compatibility.
ds_label1 = xr.open_zarr(f'./label0_{description_labels[0]}.zarr')
np.unique(ds_label1['Y']) # check the labels, should be 0

## Downsampling reforestree
Downsample the reforestree dataset from ~1 cm to ~6 cm resolution, to match the resolution of the Brazil (netflora) data.

In [None]:
# Check the pixel size of one reforestree orthomosaic
f_orthomosaic = '/home/oku/Developments/XAI4GEO/data/reforestree/wwf_ecuador/RGB Orthomosaics/Carlos Vera Arteaga RGB.tif'
data = rioxarray.open_rasterio(f_orthomosaic)

# Reprojecting is computationally expensive, so only a 1000 x 1000 pixel window is used
window = range(8000, 9000)
data_subset = data.isel(x=window, y=window)

# Reproject to UTM Zone 17S and read the resolution there
data_subset_meter = data_subset.rio.reproject('EPSG:32717')
data_subset_meter.rio.resolution() # This gives the resolution of ~ 1cm

In [None]:
# Check the pixel size of one orthomosaic of the brazil data
f_brazil = '/home/oku/Developments/XAI4GEO/data/brazil_data/original_data/PNM/PROCESSADOS/Map1_Orthomosaic_export_SatJun10172428194829.tif'
data = rioxarray.open_rasterio(f_brazil)
data.rio.resolution() # This gives the resolution of ~ 6cm

In [None]:
# Path of the zarr store with the reforestree samples larger than 200x200 pixels
# (the store itself is opened in the next cell)
f_reforestree =  Path('/home/oku/Developments/XAI4GEO/data/reforestree/processed/foresttree_largerthan_200.zarr')

In [None]:
# Open the reforestree zarr store as an xarray Dataset and display its summary
data_reforestree = xr.open_zarr(f_reforestree)
data_reforestree

In [None]:
# Downsample the reforestree data: average every 6 x 6 block of pixels (~1cm -> ~6cm).
# boundary='trim' discards the trailing pixels that do not fill a complete block.
coarsen_window = {'x': 6, 'y': 6}
data_reforestree_downsampled = data_reforestree.coarsen(coarsen_window, boundary='trim').mean()
data_reforestree_downsampled

## Resizing the image

In [None]:
# Per-side padding needed to bring the downsampled x size up to target_size
size_x_downsampled = data_reforestree_downsampled.sizes['x']
pad_size = (target_size - size_x_downsampled) // 2
pad_size

In [None]:
# Zero-pad the downsampled images to target_size x target_size.
# get_padsize treats x and y separately and puts an odd leftover pixel on the
# right/bottom side, so the result is exactly target_size along both axes
# (a single symmetric pad derived from x alone is one pixel short for odd gaps
# and wrong for y whenever the x and y sizes differ).
xleft, xright, ytop, ybottom = get_padsize(
    target_size,
    data_reforestree_downsampled.sizes['x'],
    data_reforestree_downsampled.sizes['y'],
)
if xleft < 0 or ytop < 0:
    raise ValueError(
        f'Downsampled images ({data_reforestree_downsampled.sizes["x"]} x '
        f'{data_reforestree_downsampled.sizes["y"]}) are larger than target_size={target_size}'
    )
data_reforestree_downsampled_padded = data_reforestree_downsampled.pad(
    mode='constant',
    pad_width={'x': (xleft, xright), 'y': (ytop, ybottom)},
    constant_values=0,
)
data_reforestree_downsampled_padded

In [None]:
# Save the downsampled data to separated zarr files per class.
# The original reforestree labels 1, 2, 4 are re-numbered to the class codes 3, 4, 5
# used in description_labels.
labels = data_reforestree_downsampled_padded["Y"].compute()
for original_label, label_idx in zip([1, 2, 4], range(3, 6)):
    # keep only the samples of this class
    ds = data_reforestree_downsampled_padded.where(labels == original_label, drop=True)
    # Write the new class code once per sample: assigning the bare integer would
    # turn 'Y' into a 0-d scalar, unlike the per-sample 'Y' of the netflora stores.
    ds['Y'] = xr.full_like(ds['Y'], label_idx, dtype=int)
    ds = ds.chunk('auto')
    ds.to_zarr(f'./label{label_idx}_{description_labels[label_idx]}.zarr', mode='w')

In [None]:
# Plot some random samples per class (one row per class)
n_samples_plot = 4
classes = [1, 2, 4]
rng = np.random.default_rng()
# squeeze=False keeps axs 2-D whatever the number of classes / samples
fig, axs = plt.subplots(len(classes), n_samples_plot, figsize=(15, 15), squeeze=False)
labels = data_reforestree_downsampled_padded["Y"].compute()
for ax_row, example_class in enumerate(classes): # Loop through classes
    is_class = labels == example_class
    nsamples = int(is_class.values.sum())
    # never ask for more distinct samples than the class contains
    n_plot = min(n_samples_plot, nsamples)

    if n_plot > 0:
        # replace=False -> no image is shown twice in a row of panels
        idx = np.sort(rng.choice(nsamples, size=n_plot, replace=False))
        images = (
            data_reforestree_downsampled_padded["X"]
            .where(is_class, drop=True)
            .isel(sample=idx)
            .dropna(dim="sample")
            .compute()
        )
        n_available = images.sizes["sample"]
    else:
        images = None
        n_available = 0

    for example_i, ax in enumerate(axs[ax_row]):
        if example_i < n_available:
            images.isel(sample=example_i).astype('int').plot.imshow(ax=ax)
            ax.set_title(f'class {example_class}')
        else:
            # fewer images than panels: leave the panel blank instead of failing
            ax.set_axis_off()